# Imports/setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Fix the global PyTorch RNG so weight initialisation and DataLoader shuffle order are repeatable.
torch.manual_seed(seed=42)

# Loading datasets

In [None]:
class IRMASDataset(Dataset):  # Adapted from PyTorch docs: https://docs.pytorch.org/tutorials/beginner/basics/data_tutorial.html
    """Map-style dataset over feature and label arrays stored as ``.npy`` files.

    ``data_file`` holds the examples stacked along axis 0; ``label_file`` holds
    one label per example (presumably integer class indices, since ``train``
    feeds them to ``CrossEntropyLoss`` -- TODO confirm against the export step).

    Raises:
        ValueError: if the two files disagree on the number of examples.
    """

    def __init__(self, data_file, label_file):
        # Cast to float32 so inputs match the default dtype of nn.Linear /
        # nn.Conv2d weights; a float64 array would otherwise fail at forward
        # time. ``.float()`` is a no-op when the array is already float32.
        self.data = torch.from_numpy(np.load(data_file)).float()
        self.labels = torch.from_numpy(np.load(label_file))
        if self.data.size(0) != self.labels.size(0):
            raise ValueError(
                "data and label files have different lengths: {} vs {}".format(
                    self.data.size(0), self.labels.size(0)
                )
            )

    def __len__(self):
        return self.labels.size(dim=0)

    def __getitem__(self, idx):
        # Keep a leading singleton axis (as the original index_select over [idx]
        # did); after batching it serves as the 1-channel axis nn.Conv2d expects.
        # The label is returned as a Python scalar so default collation builds
        # a 1-D label tensor.
        return self.data[idx].unsqueeze(0), self.labels[idx].item()

In [None]:
# Training and validation splits, loaded from pre-extracted feature/label arrays.
train_set = IRMASDataset(data_file="X_train.npy", label_file="y_train.npy")
val_set = IRMASDataset(data_file="X_val.npy", label_file="y_val.npy")

# Training and validation loops

In [None]:
def _run_epoch(model, dataloader, loss_fn, optimizer=None):
    """Make one pass over ``dataloader`` and return the per-sample mean loss.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in eval mode and run under no_grad.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, total_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in dataloader:
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # loss is a batch mean; weight it by batch size so a short final
            # batch does not skew the epoch average.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples


def train(model, model_name, num_epochs, batch_size, lr, *, train_data=None, val_data=None,
          save_dir="saved_models"):  # Adapted from PyTorch docs: https://docs.pytorch.org/tutorials/beginner/introyt/trainingyt.html
    """Train ``model`` with Adam and cross-entropy, checkpointing each new best epoch.

    Args:
        model: the ``nn.Module`` to optimise (updated in place).
        model_name: prefix for checkpoint file names.
        num_epochs: number of passes over the training data.
        batch_size: mini-batch size for both the train and validation loaders.
        lr: Adam learning rate.
        train_data: dataset to train on; defaults to the module-level ``train_set``.
        val_data: dataset to validate on; defaults to the module-level ``val_set``.
        save_dir: directory for checkpoints; created if it does not exist.

    Whenever an epoch's validation loss beats the best seen so far, the model's
    ``state_dict`` at that moment is written to ``{save_dir}/{model_name}_{epoch}``
    (1-based epoch).

    Returns:
        ``(train_losses, val_losses)``: lists with the per-sample mean loss of
        each epoch.

    Raises:
        ValueError: if the training or validation dataset is empty.
    """
    import os

    train_data = train_set if train_data is None else train_data
    val_data = val_set if val_data is None else val_data
    if len(train_data) == 0 or len(val_data) == 0:
        raise ValueError("train and validation datasets must be non-empty")

    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Order does not affect the validation loss, so there is no need to shuffle.
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    os.makedirs(save_dir, exist_ok=True)

    train_losses, val_losses = [], []
    best_loss = float("inf")
    for epoch in range(num_epochs):  # TODO: calculate accuracy
        print("Epoch: {}".format(epoch + 1))
        train_loss = _run_epoch(model, train_dataloader, loss_fn, optimizer)
        val_loss = _run_epoch(model, val_dataloader, loss_fn)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print("Train loss: {}, val loss: {}".format(train_loss, val_loss))
        if val_loss < best_loss:
            print("New best val loss!")
            best_loss = val_loss
            # Save now: the weights must be captured at the best epoch, not
            # after training has continued past it.
            model_path = "{}/{}_{}".format(save_dir, model_name, epoch + 1)
            torch.save(model.state_dict(), model_path)
    return train_losses, val_losses

def test(model): #TODO
    return

# Models

In [None]:
class BaselineNN(nn.Module):
    """Fully connected baseline: 128*128 inputs -> 128 -> 32 -> 11 output logits.

    ReLU follows each hidden layer; the final layer is left linear because the
    outputs are consumed as raw logits.
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order (linear1, linear2, linear3) fixes both the state_dict
        # keys and the order in which weight initialisation consumes the RNG.
        self.linear1 = nn.Linear(in_features=128 * 128, out_features=128)
        self.linear2 = nn.Linear(in_features=128, out_features=32)
        self.linear3 = nn.Linear(in_features=32, out_features=11)

    def forward(self, x_raw):
        # Reshape into rows of 128*128 features; any leading axes (batch,
        # channel) fold into the row axis.
        flat = x_raw.view(-1, 128 * 128)
        hidden = F.relu(self.linear2(F.relu(self.linear1(flat))))
        return self.linear3(hidden)

In [None]:
# Fit the fully connected baseline and keep its per-epoch loss history.
baseline = BaselineNN()
baseline_train_loss, baseline_val_loss = train(
    model=baseline,
    model_name="baseline",
    num_epochs=20,
    batch_size=64,
    lr=0.005,
)

In [None]:
class PrimaryNN(nn.Module):  # TODO: FINISH BUILDING; THIS IS A PLACEHOLDER
    """Placeholder CNN: 3x3 conv (1 -> 3 channels), ReLU, 2x2 max-pool, linear head to 11 logits.

    For a 1x128x128 input the unpadded conv gives 3x126x126 and pooling gives
    3x63x63, which is where the head's 63*63*3 input width comes from.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, stride=1, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.linear1 = nn.Linear(in_features=63 * 63 * 3, out_features=11)

    def forward(self, x):
        feature_maps = self.pool(F.relu(self.conv1(x)))
        flat = feature_maps.view(-1, 63 * 63 * 3)
        return self.linear1(flat)

In [None]:
# Fit the (placeholder) convolutional model and keep its per-epoch loss history.
primary = PrimaryNN()
primary_train_loss, primary_val_loss = train(
    model=primary,
    model_name="primary",
    num_epochs=20,
    batch_size=64,
    lr=0.001,
)

# Plotting

In [None]:
# Per-epoch train/validation loss for the baseline model (epochs are 1-based).
# Start a fresh figure so these points never land on another cell's axes.
plt.figure()
plt.scatter(range(1, len(baseline_train_loss) + 1), baseline_train_loss, label="Train loss")
plt.scatter(range(1, len(baseline_val_loss) + 1), baseline_val_loss, label="Validation loss")
plt.xlabel("Epoch")
plt.ylabel("Cross-Entropy Loss")
plt.legend()
plt.title("Train and Validation Loss Curves for Baseline Model")
# Render explicitly (also stops the notebook echoing the Text object from plt.title).
plt.show()

In [None]:
# Per-epoch train/validation loss for the primary model (epochs are 1-based).
# Start a fresh figure so these points never land on another cell's axes.
plt.figure()
plt.scatter(range(1, len(primary_train_loss) + 1), primary_train_loss, label="Train loss")
plt.scatter(range(1, len(primary_val_loss) + 1), primary_val_loss, label="Validation loss")
plt.xlabel("Epoch")
plt.ylabel("Cross-Entropy Loss")
plt.legend()
plt.title("Train and Validation Loss Curves for Primary Model")
# Render explicitly (also stops the notebook echoing the Text object from plt.title).
plt.show()