In [None]:
import numpy as np
from utils import parse_csi, get_data_files, generate_labels, split_csi
from nn_utils import CSIModel, train, evaluate, ComplexCSIModel


import numpy as np
import torch
from torch.utils.data import DataLoader, random_split, TensorDataset
import torch.nn as nn
import torch.optim as optim

import os 

In [None]:

# Collect only the "library" recordings, in a stable sorted order. The grouping cell below
# merges a numbered file into the group started by the file before it, so ordering matters.
file_paths = sorted(
    path for path in get_data_files(data_dir = '../data/') if 'library' in path
)
print(file_paths)

In [None]:
# MAC address(es) to keep; passed straight through to parse_csi.
macs = ['A0:A3:B3:AA:76:38']

# One entry per recording group: data[i] holds the windowed CSI and found_macs[i] the MACs
# returned by parse_csi for every file merged into group i.
data, found_macs = [], []
for file_path in file_paths:
    csi, mac = parse_csi(file_path, macs)
    print(csi.shape)
    # Window the CSI stream; 20 is forwarded to split_csi as-is (TODO: document its meaning).
    csi = split_csi(csi, 20)
    # A file whose stem ends in a digit (e.g. `library2.csv`) continues the previous recording
    # and is merged into the last group. Since file_paths is sorted, a base file normally sorts
    # before its numbered continuations ('.' < '0'..'9'). basename() is OS-independent, and the
    # [-1:] slice avoids an IndexError on an empty stem.
    stem = os.path.basename(file_path).split('.')[0]
    is_continuation = stem[-1:].isdigit()
    if is_continuation and data:
        data[-1] = np.concatenate((data[-1], csi), axis = 0)
        found_macs[-1] = np.concatenate((found_macs[-1], mac), axis = 0)
    else:
        # Start a new group. This includes a numbered file with no preceding base file, which
        # previously crashed on data[-1] with an empty list.
        data.append(csi)
        found_macs.append(mac)



In [None]:
# Truncate every group to the length of the shortest one, so each group (presumably one per
# class, given generate_labels(data) below) contributes the same number of samples.
min_len = min(len(group) for group in data)
data = [group[:min_len] for group in data]


In [None]:
# Build labels from the list of per-group arrays. generate_labels lives in utils; presumably it
# returns one label array per group — TODO confirm its contract.
# NOTE: this cell is not safe to re-run; after it runs, `data` is a single array, not a list.
labels = generate_labels(data)

# Flatten the per-group arrays into one sample array and one label array.
data = np.concatenate(data, axis = 0)
labels = np.concatenate(labels, axis = 0)
# Fail early with a clear message if samples and labels are misaligned.
assert data.shape[0] == labels.shape[0], (
    f"sample/label count mismatch: {data.shape[0]} samples vs {labels.shape[0]} labels"
)
data.shape, labels.shape

In [None]:

# Reproducibility: seed torch's global RNG, which covers model initialisation and DataLoader
# shuffling. The train/val split gets its own seeded generator, so Restart & Run All gives the
# same split and comparable metrics.
SEED = 42
torch.manual_seed(SEED)

t_data, t_labels = torch.tensor(data, dtype = torch.float32), torch.tensor(labels, dtype = torch.long)
dataset = TensorDataset(t_data, t_labels)

# 80/20 random train/validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator = torch.Generator().manual_seed(SEED)
)

train_loader = DataLoader(train_dataset, batch_size = 32, shuffle = True)
# Shuffling the validation set doesn't change the metrics; a fixed order keeps the
# (tars, preds) ordering reproducible.
val_loader = DataLoader(val_dataset, batch_size = 32, shuffle = False)

# NOTE(review): input_size is hard-coded to 384. Presumably it must match the feature
# dimension of `data`; TODO confirm against ComplexCSIModel and derive it from data.shape.
# output_size counts distinct labels; nn.CrossEntropyLoss assumes labels are 0..output_size-1.
input_size, hidden_size, output_size = 384, 100, len(np.unique(labels))
model = ComplexCSIModel(input_size, hidden_size, num_layers = 20, output_size = output_size, dropout_rate = 0.1)
criterion = nn.CrossEntropyLoss()
# NOTE(review): lr=0.05 is much higher than Adam's default (1e-3); worth tuning.
optimizer = optim.Adam(model.parameters(), lr = 0.05)

train(model, train_loader, criterion, optimizer, num_epochs = 20)


In [None]:
# Run the trained model over the validation loader. evaluate (from nn_utils) returns a pair,
# presumably (true targets, predicted classes) — TODO confirm.
tars, preds = evaluate(
    model,
    val_loader,
)

In [None]:
# calculate metrics 
from sklearn.metrics import classification_report, confusion_matrix

# Confusion matrix: entry (i, j) counts samples of true class i predicted as class j.
cm = confusion_matrix(y_true = tars, y_pred = preds)
print(cm)
# Per-class precision / recall / F1, plus averages.
print(classification_report(y_true = tars, y_pred = preds))
