# MCC125 - Wireless Link Project

## 1. Loading the data

In [68]:
# importing libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import csv
import os
import pandas as pd
import numpy as np
from tqdm import tqdm 
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import model_Transformer_modified

In [69]:
input_file1 = 'feature_symbol_dataset.csv'
output_file_feature = 'feature_symbol_dataset_processing.csv'

input_file2 = 'target_symbol_dataset.csv'
output_file_target = 'target_symbol_dataset_processing.csv'

with open(input_file1, 'r', newline='') as infile, open(output_file_feature, 'w', newline='') as outfile:
    reader = csv.reader(infile)
    writer = csv.writer(outfile)

    for row in reader:
        new_row = [cell.replace('i', 'j') for cell in row]
        writer.writerow(new_row)

with open(input_file2, 'r', newline='') as infile, open(output_file_target, 'w', newline='') as outfile:
    reader = csv.reader(infile)
    writer = csv.writer(outfile)

    for row in reader:
        new_row = [cell.replace('i', 'j') for cell in row]
        writer.writerow(new_row)

## 2. Pre-processing

In [70]:
import ast  # Used to parse complex numbers from string
# Define a function to parse complex numbers from string
def complex_parser(s):
    try:
        return ast.literal_eval(s)
    except (ValueError, SyntaxError):
        return s

In [71]:
df_target = pd.read_csv(output_file_target, header=None, converters={col: complex_parser for col in range(len(pd.read_csv(output_file_target).columns))})

df_feature = pd.read_csv(output_file_feature, header=None, converters={col: complex_parser for col in range(len(pd.read_csv(output_file_feature).columns))})

# filter zero values rows in the dataframe
df_target = df_target[df_target.loc[:,0] != 0j]
df_feature = df_feature[df_feature.loc[:,0] != 0j]

In [72]:
df_target

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
0,0.114876-0.574380j,-0.268044+1.110467j,0.804132+0.650964j,1.033883+0.038292j,0.727548-0.191460j,0.421212+1.110467j,-1.110467+0.880716j,1.187051+0.191460j,-0.114876+0.880716j,-0.344628-0.114876j,0.191460+1.033883j,-1.110467+0.880716j,1.033883-0.497796j,-0.650964+0.191460j,0.191460+0.114876j
1,0.268044-0.804132j,-0.650964-0.344628j,-0.268044+0.727548j,1.110467-0.421212j,1.187051+0.497796j,-0.957299-0.574380j,0.574380-0.114876j,-0.880716+1.033883j,-0.650964+1.033883j,0.650964-1.187051j,1.110467+0.957299j,-0.421212+0.421212j,0.344628-0.574380j,-0.114876+1.187051j,0.421212+0.727548j
2,-0.880716+0.650964j,0.497796-0.650964j,-0.344628+0.727548j,1.033883-0.344628j,0.880716-1.110467j,-0.038292+0.268044j,0.191460-0.344628j,0.421212+1.033883j,-0.880716+0.804132j,-0.191460-0.268044j,0.957299-0.344628j,0.727548-0.038292j,0.880716-0.804132j,-0.880716+0.421212j,1.187051+0.574380j
3,0.421212-0.804132j,-0.880716+0.880716j,0.344628+0.497796j,0.344628+1.110467j,0.497796+0.727548j,0.038292-0.804132j,1.033883-0.880716j,0.727548+0.574380j,0.650964+1.110467j,-0.114876-0.191460j,-1.033883-0.268044j,-0.421212+1.033883j,1.110467+0.957299j,-0.727548-0.727548j,-0.880716-0.804132j
4,0.268044+0.957299j,0.880716+1.033883j,-0.574380-0.191460j,0.574380-0.344628j,-1.187051+1.187051j,0.957299+0.191460j,-0.497796+0.727548j,-0.191460-0.497796j,0.421212+0.038292j,0.650964+0.804132j,0.344628+0.344628j,1.033883-0.497796j,0.191460+0.191460j,-0.191460+0.421212j,-0.574380-0.114876j
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
145,-0.880716+0.727548j,-0.344628-1.187051j,-0.804132-0.804132j,-0.727548+0.957299j,0.038292+0.957299j,0.957299-0.650964j,1.187051+0.344628j,1.110467-0.344628j,-1.187051-0.114876j,0.191460-0.574380j,-0.497796+0.727548j,0.804132-1.033883j,-0.497796+1.187051j,0.421212+0.191460j,0.957299+1.110467j
146,-1.187051-0.727548j,0.344628-1.110467j,-0.114876-0.421212j,1.033883+0.727548j,-0.727548+0.650964j,0.574380-0.114876j,-0.574380+1.033883j,-0.344628+0.574380j,0.880716-0.114876j,1.187051+0.574380j,-0.880716-0.727548j,-0.191460+0.344628j,-1.110467+1.110467j,0.344628+0.650964j,0.957299+0.268044j
147,-0.880716+0.727548j,-0.344628-0.727548j,-0.650964-0.880716j,-1.110467+1.033883j,0.497796+0.268044j,-0.268044-0.804132j,-0.497796-0.574380j,-1.033883+1.110467j,-1.187051+0.574380j,0.804132-0.038292j,-0.727548-0.727548j,0.650964+0.957299j,0.421212+0.957299j,-1.033883+0.268044j,-0.497796-0.268044j
148,-0.880716-1.187051j,1.110467-1.033883j,0.650964-0.268044j,1.110467+0.957299j,-0.804132-0.727548j,0.421212+0.344628j,-0.727548-0.727548j,0.957299+0.421212j,-0.038292+0.038292j,0.574380-0.038292j,-0.574380-0.344628j,0.268044-0.114876j,-0.727548-0.650964j,0.957299-0.650964j,0.880716+0.574380j


In [73]:
df_feature

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,140,141,142,143,144,145,146,147,148,149
0,0.603278-0.092069j,0.509914-0.036396j,0.354361+0.003164j,0.146608+0.028546j,-0.097360+0.043150j,-0.357368+0.051472j,-0.611558+0.058731j,-0.839062+0.070304j,-1.022597+0.091143j,-1.150416+0.125263j,...,-0.034822+0.260713j,-0.022812+0.266143j,-0.017089+0.250383j,-0.014873+0.219416j,-0.013996+0.179617j,-0.013044+0.136955j,-0.011413+0.096349j,-0.009085+0.061301j,-0.006428+0.033768j,-0.003976+0.014256j
1,0.574348-0.712335j,0.545541-0.713478j,0.461712-0.711128j,0.334159-0.707626j,0.178102-0.704154j,0.010467-0.700196j,-0.152631-0.693281j,-0.298097-0.679136j,-0.417379-0.652310j,-0.507174-0.607082j,...,0.196721+0.855423j,0.199359+0.792985j,0.184515+0.717407j,0.157780+0.628023j,0.125042+0.526805j,0.091643+0.418203j,0.061773+0.308379j,0.038127+0.204153j,0.021770+0.111928j,0.012345+0.036682j
2,0.875279+0.974010j,0.809136+0.886516j,0.684224+0.767567j,0.505408+0.618022j,0.283969+0.442698j,0.036850+0.250368j,-0.214807+0.053078j,-0.447815-0.134948j,-0.639468-0.298860j,-0.770197-0.425031j,...,0.549205-1.418380j,0.518208-1.440372j,0.478376-1.376485j,0.428478-1.241401j,0.369037-1.055448j,0.302406-0.841566j,0.232382-0.622274j,0.163582-0.417072j,0.100667-0.240540j,0.047521-0.101406j
3,0.857477+0.401416j,0.766454+0.364486j,0.616267+0.266248j,0.416402+0.116723j,0.181517-0.068507j,-0.070192-0.270148j,-0.319149-0.467290j,-0.546647-0.639872j,-0.736961-0.770884j,-0.878933-0.848174j,...,0.796642-0.966437j,0.751355-0.912577j,0.689842-0.834409j,0.612176-0.733404j,0.520900-0.614489j,0.420846-0.485276j,0.318394-0.354916j,0.220436-0.232674j,0.133290-0.126532j,0.061744-0.042055j
4,-0.572997-0.701675j,-0.665814-0.705072j,-0.749554-0.701304j,-0.828681-0.692237j,-0.905926-0.679646j,-0.981102-0.664744j,-1.050507-0.647941j,-1.107159-0.628800j,-1.141736-0.606234j,-1.144146-0.578875j,...,0.533896-0.101798j,0.519004-0.078724j,0.483237-0.061527j,0.430283-0.048783j,0.365324-0.039002j,0.294340-0.030933j,0.223364-0.023680j,0.157733-0.016795j,0.101479-0.010201j,0.056995-0.004104j
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
145,0.646598+0.854384j,0.577590+0.815908j,0.452744+0.766608j,0.277978+0.706989j,0.065072+0.639050j,-0.169781+0.566310j,-0.408098+0.493535j,-0.631652+0.426289j,-0.824930+0.370231j,-0.977091+0.330323j,...,1.023905-0.864738j,1.002944-0.836888j,0.941459-0.780316j,0.845032-0.697652j,0.722251-0.594447j,0.583657-0.478448j,0.440521-0.358604j,0.303536-0.243842j,0.181690-0.141924j,0.081396-0.058517j
146,-1.104657-0.975513j,-0.984444-1.018058j,-0.825874-1.053621j,-0.634933-1.083944j,-0.421624-1.109882j,-0.199137-1.131134j,0.017518-1.146315j,0.213200-1.153354j,0.374314-1.150091j,0.490377-1.134840j,...,0.969035+0.421856j,0.925914+0.379784j,0.848542+0.334614j,0.741587+0.286999j,0.613148+0.237976j,0.473701+0.188990j,0.334566+0.141800j,0.206348+0.098310j,0.097476+0.060288j,0.013324+0.029141j
147,0.098147-1.150394j,0.120822-1.075045j,0.170590-0.957223j,0.242345-0.802858j,0.328929-0.622248j,0.422341-0.428766j,0.514919-0.236998j,0.600535-0.060718j,0.675377+0.089095j,0.738335+0.205709j,...,0.525901-0.149427j,0.491115-0.109931j,0.449165-0.084624j,0.399447-0.069968j,0.342853-0.062005j,0.281683-0.057105j,0.219259-0.052449j,0.159404-0.046309j,0.105698-0.038056j,0.060886-0.028025j
148,0.469930+1.641743j,0.316117+1.638584j,0.124937+1.592317j,-0.100350+1.509389j,-0.351531+1.398794j,-0.615954+1.270666j,-0.877827+1.134878j,-1.120170+0.999945j,-1.326876+0.872310j,-1.484804+0.756113j,...,-0.690954-0.965989j,-0.623995-0.963485j,-0.555329-0.916302j,-0.482467-0.830053j,-0.404630-0.714130j,-0.323068-0.580352j,-0.240785-0.441346j,-0.161972-0.308895j,-0.091155-0.192523j,-0.032329-0.098591j


In [74]:
target_symbol = df_target.to_numpy()

real_target_symbol = np.real(target_symbol)
imag_target_symbol = np.imag(target_symbol)

y = np.array([real_target_symbol, imag_target_symbol]).transpose(1, 2, 0)


feature_symbol = df_feature.to_numpy()
real_feature_symbol = np.real(feature_symbol)
imag_feature_symbol = np.imag(feature_symbol)

X = np.array([real_feature_symbol, imag_feature_symbol]).transpose(1, 2, 0)

In [75]:
y

array([[[ 0.11487594, -0.57437969],
        [-0.26804385,  1.11046739],
        [ 0.80413156,  0.65096364],
        ...,
        [ 1.03388343, -0.49779573],
        [-0.65096364,  0.1914599 ],
        [ 0.1914599 ,  0.11487594]],

       [[ 0.26804385, -0.80413156],
        [-0.65096364, -0.34462781],
        [-0.26804385,  0.7275476 ],
        ...,
        [ 0.34462781, -0.57437969],
        [-0.11487594,  1.18705135],
        [ 0.42121177,  0.7275476 ]],

       [[-0.88071552,  0.65096364],
        [ 0.49779573, -0.65096364],
        [-0.34462781,  0.7275476 ],
        ...,
        [ 0.88071552, -0.80413156],
        [-0.88071552,  0.42121177],
        [ 1.18705135,  0.57437969]],

       ...,

       [[-0.88071552,  0.7275476 ],
        [-0.34462781, -0.7275476 ],
        [-0.65096364, -0.88071552],
        ...,
        [ 0.42121177,  0.95729948],
        [-1.03388343,  0.26804385],
        [-0.49779573, -0.26804385]],

       [[-0.88071552, -1.18705135],
        [ 1.11046739, -1.03

In [76]:
X

array([[[ 6.03277520e-01, -9.20685975e-02],
        [ 5.09914361e-01, -3.63963129e-02],
        [ 3.54361090e-01,  3.16440467e-03],
        ...,
        [-9.08470535e-03,  6.13007800e-02],
        [-6.42816120e-03,  3.37681903e-02],
        [-3.97554236e-03,  1.42558204e-02]],

       [[ 5.74347854e-01, -7.12334767e-01],
        [ 5.45541404e-01, -7.13478488e-01],
        [ 4.61712294e-01, -7.11127668e-01],
        ...,
        [ 3.81273672e-02,  2.04153046e-01],
        [ 2.17701975e-02,  1.11927685e-01],
        [ 1.23449857e-02,  3.66821135e-02]],

       [[ 8.75278510e-01,  9.74009919e-01],
        [ 8.09135610e-01,  8.86516386e-01],
        [ 6.84224342e-01,  7.67566661e-01],
        ...,
        [ 1.63581902e-01, -4.17071719e-01],
        [ 1.00666936e-01, -2.40540019e-01],
        [ 4.75212822e-02, -1.01405987e-01]],

       ...,

       [[ 9.81466321e-02, -1.15039433e+00],
        [ 1.20821923e-01, -1.07504485e+00],
        [ 1.70590054e-01, -9.57223402e-01],
        ...,
     

In [77]:
print(X.shape)
print(y.shape)
print(X.shape[0])

(150, 150, 2)
(150, 15, 2)
150


In [78]:
y[1]

array([[ 0.26804385, -0.80413156],
       [-0.65096364, -0.34462781],
       [-0.26804385,  0.7275476 ],
       [ 1.11046739, -0.42121177],
       [ 1.18705135,  0.49779573],
       [-0.95729948, -0.57437969],
       [ 0.57437969, -0.11487594],
       [-0.88071552,  1.03388343],
       [-0.65096364,  1.03388343],
       [ 0.65096364, -1.18705135],
       [ 1.11046739,  0.95729948],
       [-0.42121177,  0.42121177],
       [ 0.34462781, -0.57437969],
       [-0.11487594,  1.18705135],
       [ 0.42121177,  0.7275476 ]])

In [79]:
class SymbolDataset(Dataset):
    """Class to create the torch Dataset object
    """
    def __init__(self, X, y):
        super(SymbolDataset, self).__init__()

        self.X = X
        self.y = y
    
    def __getitem__(self, index):
        return self.X[index], self.X[index]
    
    def __len__(self):
        """Total number of samples"""
        return len(self.X[:,0,0])

## 3. Data loaders

In [80]:
# lenght of sequence given to encoder
gt = 8
# length of sequence given to decoder
horizon = 12

# defining batch size
batch_size = 32

# creating torch dataset and dataloaders
symbol_dataset = SymbolDataset(X,y)
train_loader = DataLoader(symbol_dataset, batch_size, shuffle=True, num_workers=0)

In [81]:
len(train_loader)

5

## 4. Training

#### 4.1 create a model

In [82]:
# defining model save location
save_location = "./Transformer_models"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tf_model = model_Transformer_modified.Transformer(encoder_input_size=2, decoder_input_size=2,
                                embedding_size=512, num_heads=8, num_layers=6, feedforward_size=2048).to(device)

#### 4.2 The training loop

In [83]:
def train_epoch(model, optimizer, loss_fn, train_loader, device, print_every):
    # Train:
    model.train()
    train_loss_batches = []

    num_batches = len(train_loader)
    
    for batch_index, (x, y) in enumerate(train_loader, 1):
        
        feature, target = x.to(device), y.to(device)
        optimizer.zero_grad()

        predictions = model.forward(feature, target)
        
        loss = loss_fn(predictions.view(X.size(0), -1), target.contiguous().view(X.size(0), -1))

        loss.backward()
        optimizer.step()

        train_loss_batches.append(loss.item())

        # If you want to print your progress more often than every epoch you can
        # set `print_every` to the number of batches you want between every status update.
        # Note that the print out will trigger a full validation on the full val. set => slows down training
        if print_every is not None and batch_index % print_every == 0:
            model.train()
            print(f"\tBatch {batch_index}/{num_batches}: "
                  f"\tTrain loss: {sum(train_loss_batches[-print_every:])/print_every:.3f}")

    return model, train_loss_batches

In [84]:
def training_loop(model, optimizer, loss_fn, train_loader, num_epochs, print_every):
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    train_losses = []

    for epoch in tqdm(range(1, num_epochs+1)):
        model, train_loss = train_epoch(model, optimizer, loss_fn, train_loader, device, print_every)

        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {sum(train_loss)/len(train_loss):.3f}")
        
        train_losses.extend(train_loss)

        if (epoch)%50 == 0:
            # Saving model, loss and error log files
            torch.save({
                'model_state_dict': model.state_dict(),
                'training_loss': train_loss
                }, os.path.join(save_location, 'Transformer_based_channel_model_epoch{}.pth'.format(epoch)))

    return model, train_losses

#### 4.3 Train the model

In [85]:
# loss function
loss_fn = nn.MSELoss()

# creating optimizer
# optimizer = torch.optim.SGD(tf_model.parameters(), lr=1e-4, momentum=0.9, weight_decay=1e-3, nesterov=True)
optimizer = torch.optim.Adam(tf_model.parameters(), lr=1e-4)
# optimizer = torch.optim.AdamW(tf_model.parameters(), lr=1e-2, betas=(0.9, 0.95), weight_decay=1e-1)

# number of epochs 
num_epochs = 100

Trained_model, train_losses = train_epoch(tf_model, optimizer, loss_fn, train_loader, device, print_every = 1)

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/Users/chenbingcheng/Library/Python/3.10/lib/python/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/06/gmp4cfpd1kn490jzvs864zh00000gn/T/ipykernel_41908/3049097336.py", line 12, in <module>
    Trained_model, train_losses = train_epoch(tf_model, optimizer, loss_fn, train_loader, device, print_every = 1)
  File "/var/folders/06/gmp4cfpd1kn490jzvs864zh00000gn/T/ipykernel_41908/2242708943.py", line 15, in train_epoch
    loss = loss_fn(predictions.view(X.size(0), -1), target.contiguous().view(X.size(0), -1))
TypeError: 'int' object is not callable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/chenbingcheng/Library/Python/3.10/lib/python/site-packages/IPython/core/interactiveshell.py", line 1997, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/Users/