# MCC125 - Wireless Link Project

## 1. Loading the data

In [167]:
# importing libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import csv
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import model_Transformer_modified
import ast  # Used to parse complex numbers from string

In [168]:
input_file1 = 'simulation_feature_dataset.csv'
output_file_feature = 'simulation_feature_dataset_processing.csv'

input_file2 = 'simulation_target_dataset.csv'
output_file_target = 'simulation_target_dataset_processing.csv'

with open(input_file1, 'r', newline='') as infile, open(output_file_feature, 'w', newline='') as outfile:
    reader = csv.reader(infile)
    writer = csv.writer(outfile)

    for row in reader:
        new_row = [cell.replace('i', 'j') for cell in row]
        writer.writerow(new_row)

with open(input_file2, 'r', newline='') as infile, open(output_file_target, 'w', newline='') as outfile:
    reader = csv.reader(infile)
    writer = csv.writer(outfile)

    for row in reader:
        new_row = [cell.replace('i', 'j') for cell in row]
        writer.writerow(new_row)

## 2. Pre-processing

In [169]:
# Define a function to parse complex numbers from string
def complex_parser(s):
    try:
        return ast.literal_eval(s)
    except (ValueError, SyntaxError):
        return s

In [170]:
df_target = pd.read_csv(output_file_target, header=None, converters={col: complex_parser for col in range(len(pd.read_csv(output_file_target).columns))})

df_feature = pd.read_csv(output_file_feature, header=None, converters={col: complex_parser for col in range(len(pd.read_csv(output_file_feature).columns))})

# filter zero values rows in the dataframe
# df_target = df_target[df_target.loc[:,0] != 0j]
# df_feature = df_feature[df_feature.loc[:,0] != 0j]

In [171]:
index_of_rows = df_feature.index[df_feature.iloc[:, 0] == 0].tolist()
len(index_of_rows)

0

In [172]:
df_target = df_target.drop(index=index_of_rows)
df_target

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
0,0.574380+0.114876j,0.574380+0.038292j,0.191460+0.038292j,-0.421212+0.497796j,1.110467-0.344628j,-0.344628-1.110467j,-0.804132-0.727548j,-0.804132-0.038292j,0.344628-0.038292j,0.344628-0.114876j,-0.038292+0.497796j,0.574380+0.957299j,-1.187051+0.727548j,0.114876-0.344628j,-0.038292+0.957299j
1,-0.880716+0.114876j,0.880716+1.187051j,-0.650964+0.421212j,-0.268044+0.191460j,0.344628-0.268044j,-0.114876+0.650964j,-1.110467-0.344628j,1.033883+0.421212j,-0.497796-0.957299j,0.650964+0.880716j,-0.114876-0.114876j,0.574380-0.268044j,0.650964+0.880716j,-0.727548+0.268044j,0.957299+0.497796j
2,0.421212-0.114876j,-0.114876-0.880716j,-1.033883+0.957299j,0.804132+0.497796j,-0.114876+1.033883j,0.957299+0.191460j,0.191460-1.110467j,0.191460-0.957299j,0.344628+0.727548j,0.574380-0.650964j,0.497796+0.574380j,-0.650964+0.421212j,-1.033883+0.574380j,-0.268044+0.114876j,-0.191460-1.110467j
3,0.574380+0.650964j,0.574380-0.804132j,-0.804132-0.344628j,-0.957299-0.957299j,0.421212+0.804132j,-0.114876+0.650964j,0.038292+0.268044j,0.114876+0.268044j,0.497796-0.574380j,-0.727548+1.110467j,-0.650964+0.268044j,-0.421212+0.344628j,-1.110467+0.804132j,-0.114876+1.187051j,-0.880716+1.187051j
4,-0.880716-0.344628j,1.033883+0.804132j,-0.650964-0.421212j,0.268044-0.727548j,-1.110467+0.114876j,-1.187051-1.110467j,0.574380+0.421212j,0.497796+0.421212j,0.880716+0.957299j,-0.574380-0.344628j,-0.038292-0.038292j,-1.033883+0.114876j,-0.421212+1.187051j,-0.191460-0.114876j,-0.344628+0.191460j
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29995,0.574380+0.268044j,0.650964-0.268044j,-1.110467-0.804132j,0.574380-0.574380j,-0.114876+0.268044j,0.804132+0.268044j,-0.804132+0.957299j,1.110467+1.110467j,-0.727548-0.804132j,-1.033883+0.191460j,-0.497796+0.344628j,-0.344628-0.421212j,-0.268044+0.114876j,-0.344628+0.574380j,0.880716+0.114876j
29996,-0.727548+0.421212j,0.344628+0.114876j,-0.421212+0.574380j,-0.650964-0.650964j,1.033883+0.268044j,1.110467-0.268044j,-1.033883-0.344628j,0.114876-0.650964j,0.038292+0.804132j,0.574380-0.191460j,0.650964-0.804132j,1.187051+0.114876j,0.727548-0.804132j,1.033883+0.344628j,0.114876+0.727548j
29997,-1.033883-0.497796j,0.421212-0.191460j,0.727548+0.574380j,-0.114876-0.114876j,0.038292+0.268044j,-1.187051+0.497796j,-0.344628-0.650964j,0.727548-0.114876j,-0.344628+0.804132j,-0.038292-0.344628j,-1.033883-0.421212j,0.114876-0.191460j,-0.038292-1.033883j,0.497796-0.191460j,-0.268044+0.344628j
29998,-1.033883+0.114876j,0.957299-0.804132j,-0.191460-0.574380j,0.804132-1.110467j,0.344628+0.038292j,-0.497796+0.421212j,-0.344628-0.727548j,-0.268044-0.038292j,-1.110467+0.727548j,0.421212+0.344628j,0.114876-0.957299j,1.187051-0.650964j,-1.033883+0.191460j,-1.110467+0.957299j,-0.957299-0.574380j


In [173]:
df_feature = df_feature[df_feature.loc[:,0] != 0j]

In [174]:
df_feature

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,465,466,467,468,469,470,471,472,473,474
0,-1.245799-0.292176j,-1.330345-0.318523j,-1.333179-0.327656j,-1.295741-0.327539j,-1.257863-0.325232j,-1.242229-0.324840j,-1.249740-0.327369j,-1.265334-0.330762j,-1.270650-0.331395j,-1.258896-0.326735j,...,0.223757-0.379588j,0.244452-0.201714j,0.092920+0.175302j,-0.114768+0.635493j,-0.277130+1.017293j,-0.343242+1.189098j,-0.315785+1.108918j,-0.231035+0.835668j,-0.132680+0.489787j,-0.052745+0.190412j
1,1.106911+0.491718j,1.177100+0.521656j,1.172034+0.520385j,1.131914+0.507528j,1.095133+0.499688j,1.082941+0.503957j,1.095336+0.517099j,1.116791+0.529557j,1.128315+0.531811j,1.121248+0.520812j,...,0.951374+0.052800j,0.759483-0.036866j,0.339731-0.312025j,-0.170124-0.652597j,-0.597819-0.921273j,-0.816160-1.019368j,-0.794928-0.922378j,-0.599373-0.680918j,-0.342885-0.389879j,-0.126636-0.142680j
2,0.018773+1.244856j,0.021368+1.319193j,0.024165+1.313594j,0.023490+1.275121j,0.018128+1.244869j,0.010646+1.241904j,0.005653+1.261579j,0.006373+1.284680j,0.012243+1.291802j,0.018959+1.277444j,...,0.140435+0.337597j,-0.133219+0.260295j,-0.506815+0.242203j,-0.917298+0.242280j,-1.249729+0.230142j,-1.389536+0.196261j,-1.283955+0.146636j,-0.973925+0.091823j,-0.575025+0.040488j,-0.219282-0.000987j
3,1.295474-0.280640j,1.375535-0.299172j,1.371563-0.299915j,1.331481-0.291760j,1.298031-0.283309j,1.291666-0.280071j,1.308322-0.283263j,1.329003-0.290310j,1.335002-0.297024j,1.322205-0.300643j,...,-0.129886-1.593995j,-0.059573-1.652626j,0.190969-1.739536j,0.509871-1.828353j,0.767413-1.860173j,0.870409-1.767205j,0.795386-1.512713j,0.587682-1.121238j,0.331428-0.676545j,0.108752-0.284741j
4,-0.802838-0.766358j,-0.852347-0.813439j,-0.847930-0.812228j,-0.820891-0.790825j,-0.799941-0.773245j,-0.799005-0.770460j,-0.814688-0.779914j,-0.832167-0.790823j,-0.836501-0.792250j,-0.823647-0.781384j,...,-0.087930-0.242369j,-0.068578-0.280977j,-0.157946-0.247598j,-0.289956-0.190257j,-0.395097-0.141364j,-0.429473-0.108417j,-0.386520-0.082776j,-0.290129-0.055175j,-0.177237-0.025308j,-0.080350-0.000589j
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29995,-0.591411-0.946587j,-0.631333-1.003882j,-0.632693-0.998537j,-0.615195-0.966185j,-0.596590-0.939667j,-0.586312-0.935106j,-0.584751-0.948876j,-0.586570-0.964652j,-0.585496-0.966410j,-0.579427-0.950975j,...,-0.745403-0.004432j,-0.603040+0.153681j,-0.333364+0.382821j,-0.013515+0.632996j,0.257330+0.827564j,0.404964+0.897226j,0.411455+0.815245j,0.314687+0.612213j,0.180684+0.361055j,0.067738+0.140979j
29996,1.414226-0.558196j,1.500999-0.596972j,1.492131-0.602040j,1.437303-0.592868j,1.383923-0.585465j,1.359225-0.585929j,1.364790-0.590859j,1.383077-0.592651j,1.391935-0.585949j,1.381941-0.572388j,...,-1.675473+0.076402j,-1.568001-0.163546j,-1.316565-0.408754j,-1.014707-0.661175j,-0.744910-0.870518j,-0.544256-0.962972j,-0.402459-0.892619j,-0.287181-0.677981j,-0.175038-0.397339j,-0.067040-0.147374j
29997,0.722211-1.194098j,0.766435-1.271573j,0.760854-1.273961j,0.735113-1.245743j,0.714173-1.224247j,0.708246-1.225443j,0.713131-1.243518j,0.717808-1.260203j,0.714080-1.258104j,0.703292-1.233477j,...,-0.105733+0.736649j,-0.217215+0.551022j,-0.246886+0.251756j,-0.244449-0.097723j,-0.239509-0.398960j,-0.232995-0.565116j,-0.210805-0.563894j,-0.164268-0.430281j,-0.101104-0.242452j,-0.040537-0.078296j
29998,0.749197+1.069434j,0.793965+1.148567j,0.792032+1.157250j,0.771860+1.130366j,0.757186+1.101022j,0.757277+1.087773j,0.766513+1.091279j,0.771764+1.099125j,0.762402+1.096343j,0.738889+1.077893j,...,1.846746+0.453473j,1.664992+0.588680j,1.291616+0.852089j,0.824895+1.155062j,0.401593+1.383355j,0.122026+1.440803j,0.004537+1.290577j,-0.004731+0.972788j,0.017895+0.586556j,0.022631+0.244009j


In [175]:
target_symbol = df_target.to_numpy()

real_target_symbol = np.real(target_symbol)
imag_target_symbol = np.imag(target_symbol)

y = np.array([real_target_symbol, imag_target_symbol]).transpose(1, 2, 0)


feature_symbol = df_feature.to_numpy()
real_feature_symbol = np.real(feature_symbol)
imag_feature_symbol = np.imag(feature_symbol)

X = np.array([real_feature_symbol, imag_feature_symbol]).transpose(1, 2, 0)

In [176]:
y

array([[[ 0.57437969,  0.11487594],
        [ 0.57437969,  0.03829198],
        [ 0.1914599 ,  0.03829198],
        ...,
        [-1.18705135,  0.7275476 ],
        [ 0.11487594, -0.34462781],
        [-0.03829198,  0.95729948]],

       [[-0.88071552,  0.11487594],
        [ 0.88071552,  1.18705135],
        [-0.65096364,  0.42121177],
        ...,
        [ 0.65096364,  0.88071552],
        [-0.7275476 ,  0.26804385],
        [ 0.95729948,  0.49779573]],

       [[ 0.42121177, -0.11487594],
        [-0.11487594, -0.88071552],
        [-1.03388343,  0.95729948],
        ...,
        [-1.03388343,  0.57437969],
        [-0.26804385,  0.11487594],
        [-0.1914599 , -1.11046739]],

       ...,

       [[-1.03388343, -0.49779573],
        [ 0.42121177, -0.1914599 ],
        [ 0.7275476 ,  0.57437969],
        ...,
        [-0.03829198, -1.03388343],
        [ 0.49779573, -0.1914599 ],
        [-0.26804385,  0.34462781]],

       [[-1.03388343,  0.11487594],
        [ 0.95729948, -0.80

In [177]:
X

array([[[-1.24579929e+00, -2.92176498e-01],
        [-1.33034466e+00, -3.18523467e-01],
        [-1.33317900e+00, -3.27656059e-01],
        ...,
        [-2.31035171e-01,  8.35667791e-01],
        [-1.32679638e-01,  4.89787191e-01],
        [-5.27448015e-02,  1.90412222e-01]],

       [[ 1.10691094e+00,  4.91718224e-01],
        [ 1.17710002e+00,  5.21655622e-01],
        [ 1.17203411e+00,  5.20384510e-01],
        ...,
        [-5.99372615e-01, -6.80917566e-01],
        [-3.42885332e-01, -3.89878988e-01],
        [-1.26636151e-01, -1.42679808e-01]],

       [[ 1.87729445e-02,  1.24485557e+00],
        [ 2.13683243e-02,  1.31919320e+00],
        [ 2.41648558e-02,  1.31359442e+00],
        ...,
        [-9.73924909e-01,  9.18232059e-02],
        [-5.75024614e-01,  4.04881646e-02],
        [-2.19282050e-01, -9.86593105e-04]],

       ...,

       [[ 7.22211107e-01, -1.19409807e+00],
        [ 7.66435403e-01, -1.27157321e+00],
        [ 7.60853510e-01, -1.27396071e+00],
        ...,
     

In [178]:
print(X.shape)
print(y.shape)
print(X.shape[0])

(30000, 475, 2)
(30000, 15, 2)
30000


In [179]:
y[1].shape

(15, 2)

In [180]:
class SymbolDataset(Dataset):
    """Class to create the torch Dataset object
    """
    def __init__(self, X, y):
        super(SymbolDataset, self).__init__()

        self.X = X
        self.y = y
    
    def __getitem__(self, index):
        return self.X[index], self.X[index]
    
    def __len__(self):
        """Total number of samples"""
        return len(self.X[:,0,0])

## 3. Data loaders

In [181]:
# lenght of sequence given to encoder
gt = 475
# length of sequence given to decoder
horizon = 15

# defining batch size
batch_size = 64

# creating torch dataset and dataloaders
symbol_dataset = SymbolDataset(X,y)
train_loader = DataLoader(symbol_dataset, batch_size, shuffle=True, num_workers=0)

In [182]:
len(train_loader)

469

## 4. Training

#### 4.1 create a model

In [183]:
# defining model save location
save_location = "./Transformer_models"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tf_model = model_Transformer_modified.Transformer(encoder_input_size=2, decoder_input_size=2,
                                embedding_size=32, num_heads=4, num_layers=6, feedforward_size=1024).to(device)

#### 4.2 The training loop

In [184]:
def train_epoch(model, optimizer, loss_fn, train_loader, device, print_every):
    # Train:
    model.train()
    train_loss_batches = []

    num_batches = len(train_loader)
    
    for batch_index, (x, y) in enumerate(train_loader, 1):
        
        feature, target = x.float().to(device), y.float().to(device)

        # initializing a sequence for decoding by concatenating a start-of-sequence tensor with the target tensor along the second dimension.
        start_of_seq = torch.Tensor([0,1]).unsqueeze(0).unsqueeze(1).repeat(target.shape[0],1, 1).to(device)
        dec_input = torch.cat((start_of_seq, target[:,:-1,:]), 1)

        optimizer.zero_grad()

        predictions = model.forward(feature, dec_input)
        
        loss = loss_fn(predictions.view(feature.size(0), -1), target.contiguous().view(feature.size(0), -1))

        loss.backward()
        optimizer.step()

        train_loss_batches.append(loss.item())

        # If you want to print your progress more often than every epoch you can
        # set `print_every` to the number of batches you want between every status update.
        # Note that the print out will trigger a full validation on the full val. set => slows down training
        if print_every is not None and batch_index % print_every == 0:
            model.train()
            print(f"\tBatch {batch_index}/{num_batches}: "
                  f"\tTrain loss: {sum(train_loss_batches[-print_every:])/print_every:.3f}")

    return model, train_loss_batches

In [185]:
def training_loop(model, optimizer, loss_fn, train_loader, num_epochs, print_every):
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    train_losses = []

    for epoch in range(1, num_epochs+1):
        model, train_loss = train_epoch(model, optimizer, loss_fn, train_loader, device, print_every)

        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {sum(train_loss)/len(train_loss):.3f}")
        
        train_losses.extend(train_loss)

        if (epoch)%10 == 0:
            # Saving model, loss and error log files
            torch.save({
                'model_state_dict': model.state_dict(),
                'training_loss': train_loss
                }, os.path.join(save_location, 'Transformer_based_channel_model_fsfd3_epoch{}.pth'.format(epoch)))
        
        # if (epoch)%10 == 0:
        # # Saving model, loss and error log files
        #     torch.save(model, os.path.join(save_location, 'Transformer_based_channel_model_epoch{}.pt'.format(epoch)))

    return model, train_losses

#### 4.3 Train the model

In [186]:
# loss function
loss_fn = nn.MSELoss().float()

# creating optimizer
# optimizer = torch.optim.SGD(tf_model.parameters(), lr=1e-4, momentum=0.9, weight_decay=1e-3, nesterov=True)
optimizer = torch.optim.Adam(tf_model.parameters(), lr=1e-4)
# optimizer = torch.optim.AdamW(tf_model.parameters(), lr=1e-2, betas=(0.9, 0.95), weight_decay=1e-1)

# number of epochs 
num_epochs = 100

Trained_model, train_losses = training_loop(tf_model, optimizer, loss_fn, train_loader, num_epochs, print_every=100)

Starting training
	Batch 100/469: 	Train loss: 0.101
	Batch 200/469: 	Train loss: 0.068
	Batch 300/469: 	Train loss: 0.065
	Batch 400/469: 	Train loss: 0.063
Epoch 1/100: Train loss: 0.072
	Batch 100/469: 	Train loss: 0.061
	Batch 200/469: 	Train loss: 0.058
	Batch 300/469: 	Train loss: 0.055
	Batch 400/469: 	Train loss: 0.050
Epoch 2/100: Train loss: 0.054
	Batch 100/469: 	Train loss: 0.038
	Batch 200/469: 	Train loss: 0.031
	Batch 300/469: 	Train loss: 0.026
	Batch 400/469: 	Train loss: 0.021
Epoch 3/100: Train loss: 0.027
	Batch 100/469: 	Train loss: 0.015
	Batch 200/469: 	Train loss: 0.014
	Batch 300/469: 	Train loss: 0.013
	Batch 400/469: 	Train loss: 0.012
Epoch 4/100: Train loss: 0.013
	Batch 100/469: 	Train loss: 0.011
	Batch 200/469: 	Train loss: 0.011
	Batch 300/469: 	Train loss: 0.011
	Batch 400/469: 	Train loss: 0.011
Epoch 5/100: Train loss: 0.011
	Batch 100/469: 	Train loss: 0.010
	Batch 200/469: 	Train loss: 0.010
	Batch 300/469: 	Train loss: 0.010
	Batch 400/469: 	Train

KeyboardInterrupt: 

In [118]:
import torch, gc
gc.collect()
torch.cuda.empty_cache() 