In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split

## Define parameters

In [None]:
# Optimisation
learning_rate = 1e-3  # Adam step size passed to the model
num_epochs = 100

# Network dimensions (features in -> values out)
input_size = 146
output_size = 10

# Data loading
batch_size = 64
num_workers = 8  # DataLoader worker processes

## Define the neural network architecture
Layer options (More info at https://pytorch.org/docs/stable/nn.html):
+ Linear: fully connected layer
+ Conv1d/Conv2d: Convolutional layers
+ BatchNorm2d/LayerNorm/InstanceNorm2d: Normalization layers
+ Dropout: Dropout layer
+ MaxPool2d/AvgPool2d: Pooling layers

In [None]:
class NN(pl.LightningModule):
    """Two-layer MLP trained with MSE loss and Adam.

    Architecture: Linear(input_size, hidden_size) -> ReLU -> Linear(hidden_size, output_size).

    Args:
        learning_rate: learning rate for the Adam optimizer built in ``configure_optimizers``.
        input_size: number of input features per sample.
        output_size: number of values predicted per sample.
        hidden_size: width of the hidden layer (defaults to 20).
    """

    def __init__(self, learning_rate, input_size, output_size, hidden_size=20):
        super().__init__()
        # configure_optimizers reads this attribute, so it must be stored here.
        self.learning_rate = learning_rate
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.loss_fn = nn.MSELoss()

    def forward(self, x):
        """Map a batch of input features to predictions."""
        x = self.relu(self.layer1(x))
        return self.layer2(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def _shared_step(self, batch, stage):
        """Compute and log the MSE loss for one batch; ``stage`` prefixes the metric name."""
        inputs, targets = batch
        outputs = self(inputs)
        # NOTE(review): assumes targets are float tensors shaped like `outputs`;
        # MSELoss silently broadcasts (with a warning) otherwise — confirm against the dataset.
        loss = self.loss_fn(outputs, targets)
        # on_epoch=True aggregates the batch losses into one value per epoch,
        # giving a loss-over-epochs curve in the logger.
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")



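## Custom data module (TODO)
`CustomDataModule` splits the training data into train/validation subsets and serves a separate test set; the data loading itself still has to be implemented.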

In [None]:
class CustomDataModule(pl.LightningDataModule):
    """Serve train/val/test DataLoaders, splitting one dataset into train and validation subsets.

    TODO: implement ``_load_dataset`` for the real data in ``data_dir``, or pass
    ready-made ``torch.utils.data.Dataset`` objects via ``dataset`` / ``test_dataset``.

    Args:
        data_dir: directory that holds the raw data.
        batch_size: samples per batch for every DataLoader.
        num_workers: number of DataLoader worker processes.
        dataset: optional dataset to split into train/val instead of loading from ``data_dir``.
        test_dataset: optional test dataset instead of loading from ``data_dir``.
        val_fraction: fraction of the training data held out for validation
            (default 1/6, the same ratio as a 50,000/10,000 split).
        seed: seed for the train/val split so repeated ``setup`` calls yield the same split.

    Raises:
        ValueError: if ``val_fraction`` is not in ``[0, 1)``.
    """

    def __init__(self, data_dir, batch_size, num_workers, dataset=None, test_dataset=None,
                 val_fraction=1 / 6, seed=42):
        super().__init__()
        if not 0.0 <= val_fraction < 1.0:
            raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = dataset
        self.test_dataset = test_dataset
        self.val_fraction = val_fraction
        self.seed = seed

    def prepare_data(self):
        # TODO: download / preprocess raw files into self.data_dir if needed.
        pass

    def _load_dataset(self, train):
        """Load the training (``train=True``) or test (``train=False``) dataset from ``data_dir``."""
        # TODO: return a torch.utils.data.Dataset yielding (inputs, targets) pairs.
        # Presumably float inputs of length input_size and float targets matching the
        # model output, since NN uses nn.Linear + MSELoss — confirm against the data.
        raise NotImplementedError(
            "CustomDataModule._load_dataset is not implemented yet: pass `dataset`/"
            "`test_dataset` or implement loading from data_dir."
        )

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` (``None`` builds all of them)."""
        if stage in (None, "fit", "validate"):
            full_ds = self.dataset if self.dataset is not None else self._load_dataset(train=True)
            n_val = round(len(full_ds) * self.val_fraction)
            n_train = len(full_ds) - n_val
            # A dedicated seeded generator keeps the split identical across setup() calls
            # without depending on the global RNG state.
            generator = torch.Generator().manual_seed(self.seed)
            self.train_ds, self.val_ds = random_split(full_ds, [n_train, n_val], generator=generator)
        if stage in (None, "test"):
            self.test_ds = (
                self.test_dataset if self.test_dataset is not None else self._load_dataset(train=False)
            )

    def _make_loader(self, ds, shuffle):
        """Create a DataLoader over ``ds`` with the module's batch size and worker count."""
        return DataLoader(
            ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_ds, shuffle=False)

In [None]:
class DistanceCallback(pl.Callback):
    """After training, measure and plot the per-sample distance between predictions and targets.

    Args:
        dataset: dataset yielding ``(inputs, targets)`` pairs to evaluate.
        batch_size: batch size for the evaluation forward pass; distances are still
            computed per sample, so this only affects speed.
    """

    def __init__(self, dataset, batch_size=256):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.distances = []  # filled in by on_train_end for later inspection

    def on_train_end(self, trainer, pl_module):
        # Runs exactly once, after the final epoch, even if training stops before
        # max_epochs (e.g. early stopping). The generic `on_epoch_end` hook was
        # deprecated and removed from the Callback API, and in older versions it
        # also fired at the end of validation epochs.
        self.distances = self.calculate_distances(pl_module, self.dataset)
        self.plot_distances(self.distances)

    def calculate_distances(self, pl_module, dataset):
        """Return a list with one distance (see ``calculate_distance``) per sample in ``dataset``."""
        was_training = pl_module.training
        pl_module.eval()
        distances = []
        try:
            with torch.no_grad():
                for inputs, targets in DataLoader(dataset, batch_size=self.batch_size):
                    # The Trainer may have placed the model on a GPU while the loader yields
                    # CPU tensors; compare on the CPU afterwards.
                    outputs = pl_module(inputs.to(pl_module.device)).cpu()
                    distances.extend(
                        self.calculate_distance(output, target)
                        for output, target in zip(outputs, targets)
                    )
        finally:
            # Restore the mode the model was in before evaluation.
            pl_module.train(was_training)
        return distances

    def calculate_distance(self, output, target):
        """Distance between one sample's prediction and its target (L2 norm of the difference)."""
        # Replace this with your distance calculation method
        distance = torch.norm(output - target)
        return distance.item()

    def plot_distances(self, distances):
        """Histogram of ``distances`` with the average and maximum marked by dashed lines."""
        if len(distances) == 0:
            # np.max raises on an empty sequence and there is nothing to plot.
            return
        avg_distance = np.mean(distances)
        max_distance = np.max(distances)

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.hist(distances, bins=50, color='skyblue', edgecolor='black')
        ax.axvline(avg_distance, color='red', linestyle='dashed', linewidth=2, label=f'Average Distance: {avg_distance:.2f}')
        ax.axvline(max_distance, color='green', linestyle='dashed', linewidth=2, label=f'Max Distance: {max_distance:.2f}')
        ax.set(title='Distance Distribution', xlabel='Distance', ylabel='Frequency')
        ax.legend()
        plt.show()

In [None]:
# Seed Python, NumPy and torch RNGs (including DataLoader workers) for reproducible runs.
pl.seed_everything(42, workers=True)

# Record which hardware is available. The Lightning Trainer moves the model to its
# accelerator itself, so the model is not moved manually here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lightning_model = NN(learning_rate=learning_rate, input_size=input_size, output_size=output_size)

# TODO: point data_dir at the real data and implement loading in CustomDataModule.
data_dir = './data'
dm = CustomDataModule(data_dir, batch_size, num_workers)

## Train the model, then plot the loss over epochs and the average and maximum per-sample distance between predictions and targets after the last epoch

In [None]:
# Build the splits now so the distance callback can be given the training subset.
# NOTE: trainer.fit may call dm.setup("fit") again; the split must be deterministic
# (seeded) for the callback to evaluate the same samples the model was trained on.
dm.prepare_data()
dm.setup("fit")

callbacks = [
    pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
    DistanceCallback(dataset=dm.train_ds),
]
# A progress bar is added by default. `accelerator`/`devices` replace the `gpus`
# argument (removed in Lightning 2.0) and fall back to the CPU when no GPU exists.
trainer = pl.Trainer(
    max_epochs=num_epochs,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=callbacks,
)

# Train the model
trainer.fit(lightning_model, datamodule=dm)

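## Evaluate on the test set
Collect the test metrics: loss, accuracy, recall, F1 score, etc. The model is trained with an MSE loss (regression), so only the loss is available unless classification metrics are added to `test_step` and the task is framed as classification.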

In [None]:
# trainer.test builds the test split through the data module (dm.setup("test") /
# dm.test_dataloader()) and returns a list with one dict of logged metrics per test loader.
test_results = trainer.test(lightning_model, datamodule=dm)
test_results
