In [None]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
import pickle


# Load the pickled (train, val, test) datasets.
# NOTE(security): pickle.load can execute arbitrary code — only open files you produced or trust.
with open('datasets.pickle', 'rb') as f:
    dataset = pickle.load(f)

# Seed Python / NumPy / torch RNGs so loader shuffling and weight init are reproducible.
SEED = 42
pl.seed_everything(SEED)

batchsize = 32
trainset = dataset[0]
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchsize, shuffle=True)
# Evaluation loaders keep a fixed order so metrics are comparable between runs.
valset = dataset[1]
valloader = torch.utils.data.DataLoader(valset, batch_size=batchsize, shuffle=False)
testset = dataset[2]
testloader = torch.utils.data.DataLoader(testset, batch_size=batchsize, shuffle=False)

# The model consumes the first 4 feature columns (user, location, hour, day).
input_dim = 4
# Length of the first training target vector — presumably the number of classes (one-hot targets).
output_dim = trainset[0][1].shape[0]

class MLP(pl.LightningModule):
    """MLP classifier over (user id, location id, hour, day) input rows.

    User and location ids are looked up in learned embedding tables, hour and
    day are projected together by a linear layer, and the concatenation is
    passed through four fully connected layers that emit ``out_channels``
    logits.

    Per-epoch training/validation metrics are appended to the ``train_*`` /
    ``val_*`` lists; test metrics are appended per batch to the ``test_*``
    lists.

    Args:
        in_channels: Not used by the network (the input layout is fixed by
            ``forward``); kept so existing call sites keep working.
        hidden_channels: Width of the first hidden layer.
        out_channels: Number of output logits (classes).
        num_users: Rows in the user-id embedding table.
        num_locations: Rows in the location-id embedding table.
        emb_dim: Width of each embedding and of the hour/day projection.
        k: Cut-off used for the logged accuracy@k metric.
        verbose: If True, print metrics for every batch (debugging aid).
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_users=100000, num_locations=100000, emb_dim=32,
                 k=5, verbose=False):
        super().__init__()
        # Record constructor args so the model can be rebuilt from a checkpoint.
        self.save_hyperparameters()
        self.k = k
        self.verbose = verbose

        self.emb1 = torch.nn.Embedding(num_users, emb_dim)      # user ids
        self.emb2 = torch.nn.Embedding(num_locations, emb_dim)  # location ids
        self.timeday = torch.nn.Linear(2, emb_dim)              # (hour, day) -> emb_dim
        # Input = user embedding + location embedding + hour/day projection.
        self.class1 = torch.nn.Linear(3 * emb_dim, hidden_channels)
        self.class2 = torch.nn.Linear(hidden_channels, 32)
        self.class3 = torch.nn.Linear(32, 16)
        self.class4 = torch.nn.Linear(16, out_channels)

        # Metric history, filled by the epoch-end hooks / test_step.
        self.train_losses = []
        self.val_losses = []
        self.test_losses = []
        self.train_acc_at_k = []
        self.val_acc_at_k = []
        self.test_acc_at_k = []
        self.train_mrr = []
        self.val_mrr = []
        self.test_mrr = []
        # NOTE(review): never updated inside this class.
        self.epochs_trained = 0

    def forward(self, data):
        """Return class logits of shape (batch, out_channels).

        Args:
            data: Tensor whose columns 0-3 hold user id, location id, hour and
                day; it must have an integer dtype because columns 0 and 1 go
                through ``nn.Embedding``.
        """
        user = self.emb1(data[:, 0])
        loc = self.emb2(data[:, 1])
        # Hour and day are fed as raw numeric values, not embedded.
        timeday = self.timeday(data[:, 2:4].float())
        x = torch.cat((user, loc, timeday), dim=1)
        x = self.class1(x).relu()
        x = self.class2(x).relu()
        x = self.class3(x).relu()
        return self.class4(x)

    def _shared_step(self, batch, stage):
        """Compute loss / accuracy@k / MRR for one batch and log them as ``{stage}_*``."""
        features, target = batch[0], batch[1]
        output = self(features[:, :4])
        loss = F.cross_entropy(output, target)
        acc_at_k = self.accuracy_at_k(output, target, k=self.k)
        mrr = self.mean_reciprocal_rank(output, target)
        # Only epoch-level aggregates are logged; the curves are built per epoch.
        self.log(f'{stage}_loss', loss, on_epoch=True, on_step=False)
        self.log(f'{stage}_acc_at_k', acc_at_k, on_epoch=True, on_step=False)
        self.log(f'{stage}_mrr', mrr, on_epoch=True, on_step=False)
        return loss, acc_at_k, mrr

    def training_step(self, batch, batch_idx):
        loss, acc_at_k, mrr = self._shared_step(batch, 'train')
        if self.verbose:
            # .item() forces a device sync, so only do it when debugging.
            print(f"Epoch {self.current_epoch} Training loss: {loss.item()} Accuracy@k: {acc_at_k.item()} MRR: {mrr.item()}")
        return loss

    def validation_step(self, batch, batch_idx):
        val_loss, acc_at_k, mrr = self._shared_step(batch, 'val')
        if self.verbose:
            print(f"Epoch {self.current_epoch} Validation loss: {val_loss.item()} Accuracy@k: {acc_at_k.item()} MRR: {mrr.item()}")
        return val_loss

    def test_step(self, batch, batch_idx):
        test_loss, acc_at_k, mrr = self._shared_step(batch, 'test')
        # Test metrics are kept per batch (train/val history is per epoch).
        self.test_losses.append(test_loss.item())
        self.test_acc_at_k.append(acc_at_k.item())
        self.test_mrr.append(mrr.item())
        if self.verbose:
            print(f"Test loss: {test_loss.item()} Accuracy@k: {acc_at_k.item()} MRR: {mrr.item()}")
        return test_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4)

    def accuracy_at_k(self, y_pred, y_true, k=10):
        """Fraction of rows whose true class is among the top-``k`` predicted scores.

        Args:
            y_pred: (batch, num_classes) scores or logits.
            y_true: (batch, num_classes) target; its highest entry marks the true class.
            k: Cut-off; clamped to ``num_classes`` so small label spaces don't error.
        """
        k = min(k, y_pred.size(1))
        top_k = y_pred.topk(k, dim=1).indices
        label = y_true.topk(1, dim=1).indices  # (batch, 1)
        # topk indices within a row are distinct, so each row matches at most once.
        return (top_k == label).any(dim=1).float().mean()

    def mean_reciprocal_rank(self, y_pred, y_true):
        """Mean over the batch of 1 / (1-based rank of the true class in ``y_pred``)."""
        order = y_pred.sort(dim=1, descending=True).indices
        # Inverse permutation: rank[i, c] is the 0-based position of class c in row i.
        rank = order.argsort(dim=1)
        label = y_true.topk(1, dim=1).indices
        return (1.0 / (rank.gather(1, label) + 1)).mean()

    def _append_epoch_metrics(self, stage, name, losses, accs, mrrs):
        """Copy the epoch-aggregated ``{stage}_*`` metrics into the history lists and print them."""
        metrics = self.trainer.callback_metrics
        loss = metrics[f'{stage}_loss'].item()
        acc = metrics[f'{stage}_acc_at_k'].item()
        mrr = metrics[f'{stage}_mrr'].item()
        losses.append(loss)
        accs.append(acc)
        mrrs.append(mrr)
        print(f"End of epoch {self.current_epoch} - {name} loss: {loss}, Accuracy@k: {acc}, MRR: {mrr}")

    def on_train_epoch_end(self):
        self._append_epoch_metrics('train', 'Training', self.train_losses,
                                   self.train_acc_at_k, self.train_mrr)

    def on_validation_epoch_end(self):
        # Skip Lightning's pre-fit sanity-check pass: it only sees a couple of
        # batches and would leave val_* one entry longer than train_*.
        if self.trainer.sanity_checking:
            return
        self._append_epoch_metrics('val', 'Validation', self.val_losses,
                                   self.val_acc_at_k, self.val_mrr)

class MetricsLengthCallback(Callback):
    """Debug callback that prints how many epochs of train/val losses the module has recorded.

    Useful for spotting a mismatch between ``pl_module.train_losses`` and
    ``pl_module.val_losses``.
    """

    # `Callback.on_epoch_end` was removed in Lightning 1.8 (newer versions refuse
    # to fit with it defined); hook the end of each training epoch instead.
    def on_train_epoch_end(self, trainer, pl_module):
        # NOTE: Lightning may run ordinary callbacks' on_train_epoch_end before the
        # LightningModule's own hook, so the training count can lag by one here.
        print(f"Epoch {trainer.current_epoch} - Training losses length: {len(pl_module.train_losses)}, Validation losses length: {len(pl_module.val_losses)}")

numepoch = 200

model = MLP(input_dim, 64, output_dim)
# NOTE(review): not used by MLP (it calls F.cross_entropy directly); kept in case later cells use it.
criterion = torch.nn.CrossEntropyLoss()
trainer = pl.Trainer(
    max_epochs=numepoch,
    log_every_n_steps=1,
    callbacks=[
        # Select and stop on held-out loss: training loss keeps falling while the model overfits.
        ModelCheckpoint(monitor='val_loss', mode='min'),
        EarlyStopping(monitor='val_loss', mode='min', patience=10),
        MetricsLengthCallback(),
    ],
)
trainer.fit(model, trainloader, valloader)


print(model)

# Check the lengths of the per-epoch metric histories after training.
print(f"Final Training losses length: {len(model.train_losses)}")
print(f"Final Validation losses length: {len(model.val_losses)}")

In [None]:

# Reset per-batch test metrics so re-running this cell doesn't append to a previous run's values.
model.test_losses.clear()
model.test_acc_at_k.clear()
model.test_mrr.clear()
# Evaluate the weights of the best checkpoint tracked by ModelCheckpoint, not the last epoch's.
test_result = trainer.test(model, dataloaders=testloader, ckpt_path="best")[0]
test_losses = model.test_losses
train_losses = model.train_losses
val_losses = model.val_losses
print(f"Test Loss: {test_result['test_loss']}")
print(f"Test Accuracy@5: {test_result['test_acc_at_k']}")
print(f"Test MRR: {test_result['test_mrr']}")