In [8]:
import numpy
import torch
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from torchmetrics import Accuracy

In [9]:
# Synthetic two-class data. Both classes are uniform integers in [0, 100); the "left" class is
# offset by +19, so the two classes occupy shifted value ranges ([19, 119) vs [0, 100)).
data_seed:int = 0                          # fixed so Restart & Run All reproduces the same data
n_windows:int = 61                         # windows per class
n_train:int = 45                           # first n_train windows of each class go to training
window_shape:tuple[int, int] = (100, 19)   # per-window shape; the MLP flattens it to 100*19 inputs

data_rng:numpy.random.Generator = numpy.random.default_rng(data_seed)
windows_l:numpy.ndarray = data_rng.integers(0, 100, size=(n_windows, *window_shape)) + 19
windows_r:numpy.ndarray = data_rng.integers(0, 100, size=(n_windows, *window_shape))
y_l:numpy.ndarray = numpy.zeros((n_windows, 1))   # label 0 for the "left" class
y_r:numpy.ndarray = numpy.ones((n_windows, 1))    # label 1 for the "right" class


### Train
windows_train_l:numpy.ndarray = windows_l[:n_train]
windows_train_r:numpy.ndarray = windows_r[:n_train]
y_train_l:numpy.ndarray = y_l[:n_train]
y_train_r:numpy.ndarray = y_r[:n_train]

### Test
windows_test_l:numpy.ndarray = windows_l[n_train:]
windows_test_r:numpy.ndarray = windows_r[n_train:]
y_test_l:numpy.ndarray = y_l[n_train:]
y_test_r:numpy.ndarray = y_r[n_train:]

In [11]:
class SignalsDataset(Dataset):
  """In-memory dataset stacking the "left" and "right" samples into one (features, label) pair set.

  Samples from ``X_l`` come first, followed by those from ``X_r``; the labels ``y_l`` and ``y_r``
  are concatenated in the same order. Features and labels are stored as float32 tensors.

  Raises:
    ValueError: if the total number of feature samples differs from the number of labels.
  """

  def __init__(self, X_l:numpy.ndarray, X_r:numpy.ndarray, y_l:numpy.ndarray, y_r:numpy.ndarray)->None:
    super().__init__()
    features:numpy.ndarray = numpy.concatenate((X_l, X_r), axis=0)
    labels:numpy.ndarray = numpy.concatenate((y_l, y_r), axis=0)
    # Fail early on misaligned inputs instead of an IndexError (or unused labels) at access time.
    if len(features) != len(labels):
      raise ValueError(
        f"feature/label count mismatch: {len(features)} samples vs {len(labels)} labels"
      )
    self.X:torch.Tensor = torch.tensor(features, dtype=torch.float32)
    self.y:torch.Tensor = torch.tensor(labels, dtype=torch.float32)

  def __getitem__(self, index:int)->tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(features, label)`` pair at ``index``."""
    return self.X[index], self.y[index]

  def __len__(self)->int:
    """Number of samples, i.e. ``len(X_l) + len(X_r)``."""
    return len(self.X)

In [5]:
# Build the train/test datasets; each stacks the left windows (labels y_*_l) before the right ones.
train_dataset:SignalsDataset = SignalsDataset(
  X_l=windows_train_l,
  X_r=windows_train_r,
  y_l=y_train_l,
  y_r=y_train_r,
)

test_dataset:SignalsDataset = SignalsDataset(
  X_l=windows_test_l,
  X_r=windows_test_r,
  y_l=y_test_l,
  y_r=y_test_r,
)

In [15]:
# Run-wide configuration. Seeding torch's default generator makes weight initialisation and
# DataLoader shuffling reproducible across Restart & Run All.
seed:int = 0
torch.manual_seed(seed)

device:torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size:int = 10
num_epochs:int = 200
learning_rate:float = 1e-3

In [12]:
# The datasets are already in-memory tensors, so worker processes only add start-up cost; with the
# "spawn" start method (Windows/macOS) they can also fail to find a Dataset class defined in the notebook.
loader_workers:int = 0

train_loader:DataLoader = DataLoader(
  train_dataset,
  batch_size,
  shuffle=True,
  num_workers=loader_workers
)

# Evaluation needs no shuffling; a fixed order keeps per-batch results (and their mean) reproducible.
test_loader:DataLoader = DataLoader(
  test_dataset,
  batch_size,
  shuffle=False,
  num_workers=loader_workers
)

In [13]:
class MLP(torch.nn.Module):
  """Two-layer perceptron mapping each input sample to a single probability in (0, 1).

  Args:
    in_features: number of values per sample after flattening (default 100*19, the size
      previously hard-coded here).
    hidden_features: width of the hidden layer (default 100).
  """

  def __init__(self, in_features:int=100*19, hidden_features:int=100)->None:
    super(MLP, self).__init__()

    self.flatten:torch.nn.Flatten = torch.nn.Flatten()  # (B, *dims) -> (B, prod(dims))
    self.fc1:torch.nn.Linear = torch.nn.Linear(in_features, hidden_features)
    self.fc2:torch.nn.Linear = torch.nn.Linear(hidden_features, 1)  # one output unit per sample
    self.relu:torch.nn.ReLU = torch.nn.ReLU(inplace=True)
    # Sigmoid, not Softmax: a softmax over a size-1 dimension is always exactly 1.0, which made
    # the output constant and every BCELoss gradient zero.
    self.sigmoid:torch.nn.Sigmoid = torch.nn.Sigmoid()

  def forward(self, X:torch.Tensor)->torch.Tensor:
    """Return probabilities of shape (B, 1) for a batch ``X`` of shape (B, ...)."""
    hidden:torch.Tensor = self.relu(self.fc1(self.flatten(X)))
    return self.sigmoid(self.fc2(hidden))

In [16]:
# Instantiate the classifier, then move its parameters to the selected device.
model:MLP = MLP()
model = model.to(device)

In [20]:
def train_model(train_loader:DataLoader, model, num_epochs:int, lr:float, device:torch.device, save_path:str):
    """Train ``model`` with Adam + BCELoss, checkpointing the whole model after every epoch.

    Args:
        train_loader: yields ``(features, labels)`` batches; labels must match the model output
            shape and hold targets in [0, 1] (a BCELoss requirement).
        model: module whose outputs are probabilities in [0, 1]; assumed to already be on ``device``.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device each batch is moved to.
        save_path: file the whole model is written to with ``torch.save`` after each epoch;
            missing parent directories are created.

    Returns:
        The same ``model`` object, trained in place.
    """
    from pathlib import Path

    try:
        from tqdm.auto import tqdm
    except ImportError:  # the progress bar is optional; fall back to plain iteration
        def tqdm(iterable, **_kwargs):
            return iterable

    criterion:torch.nn.BCELoss = torch.nn.BCELoss()
    # Honour the ``lr`` argument rather than a notebook-level global.
    optimizer:torch.optim.Adam = torch.optim.Adam(model.parameters(), lr=lr)
    # torch.save does not create directories (e.g. "./models").
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    for _ in tqdm(range(num_epochs)):
        model.train()
        total_loss:float = 0.0
        n_samples:int = 0
        for features, labels in train_loader:
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs:torch.Tensor = model(features)
            loss:torch.Tensor = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not skew the epoch mean.
            total_loss += loss.item() * features.size(0)
            n_samples += features.size(0)
        torch.save(model, save_path)
        if n_samples:  # an empty loader has no loss to report
            print(f"loss: {total_loss / n_samples:.4f}")
    return model

In [18]:
def eval_model(test_loader:DataLoader, model, device:torch.device)->list[float]:
    """Compute binary accuracy of ``model`` on each batch of ``test_loader``.

    Args:
        test_loader: yields ``(features, labels)`` batches; labels are expected to hold 0/1 values
            (they are cast to integers for the metric).
        model: module producing one score per sample; the metric thresholds float predictions at 0.5.
        device: device the batches and the metric state are placed on.

    Returns:
        One accuracy per batch, as Python floats. Batches can differ in size (the last one is
        often shorter), so the plain mean of this list is not exactly the whole-set accuracy.
    """
    # The metric holds state tensors; they must live on the same device as its inputs.
    accuracy:Accuracy = Accuracy(task='binary').to(device)
    accuracies:list[float] = []
    model.eval()
    with torch.no_grad():
        for features, labels in test_loader:
            outputs:torch.Tensor = model(features.to(device))
            batch_accuracy:torch.Tensor = accuracy(outputs, labels.to(device).long())
            accuracies.append(batch_accuracy.item())
    return accuracies

In [None]:
# Train the model; the returned model is the cell's displayed output.
train_model(
  train_loader=train_loader,
  model=model,
  num_epochs=num_epochs,
  lr=learning_rate,
  device=device,
  save_path="./models/best_mlp_model.pth",
)

In [None]:
# Per-batch accuracies on the held-out windows.
accuracies:list[float] = eval_model(
  test_loader=test_loader,
  model=model,
  device=device,
)

In [None]:
# Unweighted mean of the per-batch accuracies.
mean_accuracy = numpy.mean(accuracies)
print(f"Accuracy: {mean_accuracy}")