In [None]:
import glob
from typing import Optional, Tuple

import lightning as L
import pandas as pd
import torch
from lightning.pytorch.loggers import WandbLogger
from torch import nn

from src.small_dataset import ReviewsDataModule

In [None]:
# --- Hyperparameters ---------------------------------------------------------
# These are read as module-level globals by the model, the data module and the
# trainer set up later in this notebook.

# Data: first argument to ReviewsDataModule (presumably the mini-batch size),
# and the review-history length, which also fixes the network input width
# (2 + 4 * reviews_history_size).
batch_size: int = 64
reviews_history_size: int = 2

# Optimisation: Trainer epoch budget and Adam step size.
epochs: int = 1
learning_rate: float = 0.001

class NeuralNetwork(nn.Module):
    """Three-layer MLP that maps one feature row to a single probability.

    Architecture: Linear(in -> hidden) -> ReLU -> Linear(hidden -> hidden) ->
    ReLU -> Linear(hidden -> 1) -> Sigmoid, where ``in = 2 + 4 * history_size``.
    The feature layout behind that width is produced by the data pipeline and
    is not visible here (presumably 2 base features plus 4 per past review —
    TODO confirm).
    """

    def __init__(self, history_size: Optional[int] = None, hidden_size: int = 32):
        """Build the network.

        Args:
            history_size: Review-history length the inputs were built with; it
                fixes the input width. ``None`` (the default) falls back to the
                module-level ``reviews_history_size`` at construction time, so
                existing ``NeuralNetwork()`` calls behave exactly as before.
            hidden_size: Width of both hidden layers (default 32, the previously
                hard-coded value).
        """
        super().__init__()
        if history_size is None:
            # Keep the original coupling to the notebook hyperparameter as the
            # default, but allow callers to decouple from the global.
            history_size = reviews_history_size
        self.history_size = history_size
        self.input_size = 2 + 4 * history_size
        # Layer order/indices are unchanged, so state_dict keys stay compatible.
        self.model = nn.Sequential(
            nn.Linear(self.input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid outputs in [0, 1] with shape ``(*x.shape[:-1], 1)``.

        ``x`` must have ``self.input_size`` features in its last dimension.
        """
        return self.model(x)


class NNMemoryModel(L.LightningModule):
    """LightningModule that trains a ``NeuralNetwork`` with binary cross-entropy.

    The wrapped network ends in a sigmoid, so ``nn.BCELoss`` (which expects
    probabilities rather than logits) is the matching loss. Targets must be
    float tensors with values in [0, 1]; they come from ``ReviewsDataModule``,
    which is not visible here.
    """

    def __init__(self, lr: Optional[float] = None):
        """Build the model.

        Args:
            lr: Adam learning rate. ``None`` (the default) means "use the
                module-level ``learning_rate`` hyperparameter", looked up when
                the optimizer is configured, exactly as before.
        """
        super().__init__()
        self.model = NeuralNetwork()
        self.loss_fn = nn.BCELoss()
        self.lr = lr

    def _shared_step(self, batch) -> torch.Tensor:
        """Compute the BCE loss of the network's predictions on one ``(x, y)`` batch."""
        x, y = batch
        pred = self.model(x)
        # The final Linear layer has a single output, so for 2-D ``x`` ``pred``
        # is (batch, 1). BCELoss raises on any input/target size mismatch, so
        # match the target's shape: a no-op when ``y`` is already (batch, 1),
        # and it lets a flat (batch,) target work as well.
        return self.loss_fn(pred.reshape(y.shape), y)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for one batch."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped network's parameters."""
        # Resolve the global lazily so the default keeps the original behavior.
        lr = learning_rate if self.lr is None else self.lr
        return torch.optim.Adam(self.model.parameters(), lr=lr)

    def test_step(self, batch, batch_idx):
        """Log the BCE loss for one test batch as ``test_loss``."""
        loss = self._shared_step(batch)
        self.log("test_loss", loss)


model = NNMemoryModel()
data = ReviewsDataModule(batch_size, reviews_history_size)
wandb_logger = WandbLogger(project="Memory ML")

trainer = L.Trainer(max_epochs=epochs, logger=wandb_logger)
try:
    trainer.fit(model, data)
    trainer.test(model, data)
finally:
    # Close the W&B run explicitly. In a notebook the kernel keeps running, and
    # a WandbLogger created while a run is still active reuses that run, so
    # re-executing this cell would otherwise mix metrics from separate trainings.
    wandb_logger.experiment.finish()