In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy, F1Score

In [None]:
class BidirectionalLayer(nn.Module):
    """Stacked bidirectional LSTM that returns the full output sequence.

    Args:
        input_dim: Feature size of each input time step.
        hidden_dim: Hidden size per direction. The output feature size is
            ``2 * hidden_dim`` because forward and backward states are concatenated.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers. It is forced to 0.0 when
            ``num_layers == 1``, where PyTorch would ignore it and emit a warning.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True,
                            dropout=dropout if num_layers > 1 else 0.0)

    def forward(self, x):
        """Run the LSTM over ``x``.

        Args:
            x: Batch-first input, typically ``(batch, seq_len, input_dim)``.

        Returns:
            Output sequence of shape ``(batch, seq_len, 2 * hidden_dim)``.
            The final hidden and cell states are discarded.
        """
        output, _ = self.lstm(x)
        return output


class ResidualLayer(nn.Module):
    """Input projection followed by a stack of residual bidirectional-LSTM blocks.

    Each block is ``BidirectionalLayer -> Linear(2 * hidden_dim, input_dim) -> ReLU``,
    so its output has the same feature size as its input. Every block after the
    first adds its own input back (a skip connection). The final result is
    batch-normalised over the feature dimension.

    Args:
        input_dim: Feature size of the input, and also of the output.
        hidden_dim: Hidden size per LSTM direction.
        num_layers: Stacked LSTM layers inside each block.
        dropout: Dropout between stacked LSTM layers (ignored when ``num_layers == 1``).
        num_blocks: Number of residual blocks.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout, num_blocks):
        super().__init__()

        self.first_linear = nn.Linear(input_dim, input_dim)

        # A bidirectional LSTM concatenates both directions.
        lstm_output_dim = hidden_dim * 2

        self.layers = nn.ModuleList([
            nn.Sequential(
                BidirectionalLayer(input_dim, hidden_dim, num_layers, dropout),
                # Project back to input_dim so the skip connection shapes match.
                nn.Linear(lstm_output_dim, input_dim),
                nn.ReLU(),
            )
            for _ in range(num_blocks)
        ])

        self.batch_norm = nn.BatchNorm1d(input_dim)

    def forward(self, x):
        """Encode ``x`` and return a tensor of the same shape.

        Args:
            x: Batch-first sequence of shape ``(batch, seq_len, input_dim)``.
                It must be 3-D because of the transpose around BatchNorm1d below.

        Returns:
            Tensor of shape ``(batch, seq_len, input_dim)``.
        """
        x = F.relu(self.first_linear(x))

        for i, block in enumerate(self.layers):
            block_output = block(x)
            # The first block has no skip connection; later blocks add their input.
            x = block_output if i == 0 else block_output + x

        # BatchNorm1d normalises dim 1, so move the features there and back.
        x = x.transpose(1, 2)
        x = self.batch_norm(x)
        return x.transpose(1, 2)


class DeepResBidirLSTM(pl.LightningModule):
    """Sequence classifier: a residual bidirectional-LSTM encoder plus a linear head.

    The encoder output at the last time step is mapped to ``output_dim`` class logits.

    Args:
        input_dim: Feature size of each time step.
        hidden_dim: Hidden size per LSTM direction.
        num_layers: Stacked LSTM layers per residual block.
        dropout: Dropout between stacked LSTM layers.
        num_blocks: Number of residual blocks.
        output_dim: Number of classes.
        lr: Adam learning rate. Defaults to the previously hard-coded 1e-3.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout, num_blocks, output_dim, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.residual_layer = ResidualLayer(input_dim, hidden_dim, num_layers, dropout, num_blocks)
        self.final_fc = nn.Linear(input_dim, output_dim)

        # NOTE(review): these metric objects are shared by the train and val steps.
        # The per-batch values returned below are correct, but their accumulated
        # internal state mixes both stages. Use separate instances if
        # epoch-level compute() is ever needed.
        self.accuracy = Accuracy(task='multiclass', num_classes=output_dim)
        self.f1_score = F1Score(num_classes=output_dim, average='weighted', task='multiclass')

    def forward(self, x):
        """Return class logits of shape ``(batch, output_dim)`` for a batch-first sequence ``x``."""
        x = self.residual_layer(x)
        # Use the encoder output at the final time step as the sequence summary.
        x = x[:, -1, :]
        return self.final_fc(x)

    def _shared_step(self, batch, batch_idx):
        """Compute loss, accuracy and weighted F1 for one ``(x, y)`` batch.

        ``y`` may be one-hot / per-class encoded ``(batch, num_classes)`` or
        already be class indices ``(batch,)``.
        """
        x, y = batch

        if y.ndim > 1:
            y = torch.argmax(y, dim=1)
        # cross_entropy needs integer (long) class indices. A float target is
        # interpreted as class probabilities and fails on the shape mismatch.
        y = y.long()

        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        f1 = self.f1_score(preds, y)

        return loss, acc, f1

    def training_step(self, batch, batch_idx):
        loss, acc, f1 = self._shared_step(batch, batch_idx)
        self.log_dict({"train_loss": loss, "train_acc": acc, "train_f1": f1}, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc, f1 = self._shared_step(batch, batch_idx)
        self.log_dict({"val_loss": loss, "val_acc": acc, "val_f1": f1}, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)



In [None]:
from src.data.sensor_datamodule import SensorDataModule
from src.utils import get_partition_paths

from itertools import islice

# Run configuration.
BATCH_SIZE = 32
K_FOLDS = 5
N_FOLDS_TO_TRAIN = 1  # only the first fold for now; set to K_FOLDS for full cross-validation
MAX_EPOCHS = 5
SEED = 42

# Seed python/numpy/torch (and dataloader workers) so runs are reproducible.
pl.seed_everything(SEED, workers=True)

dataset = SensorDataModule(BATCH_SIZE, partition_paths=get_partition_paths("./data/partitions", k_folds=K_FOLDS))
dataset.setup()

for key, fold in islice(dataset.data_dict.items(), N_FOLDS_TO_TRAIN):
    train_dataloader, val_dataloader = fold['train'], fold['validate']
    trainer = pl.Trainer(max_epochs=MAX_EPOCHS,
                         devices=1,
                         # 'auto' still selects MPS on Apple silicon, but also works
                         # on CUDA or CPU-only machines instead of failing there.
                         accelerator='auto',

                         gradient_clip_val=15,
                         gradient_clip_algorithm='norm',

                         log_every_n_steps=10)

    # NOTE(review): input_dim must match the per-timestep feature count produced
    # by SensorDataModule — confirm.
    model = DeepResBidirLSTM(input_dim=16,
                             hidden_dim=24,
                             output_dim=dataset.num_classes,
                             num_layers=3,
                             dropout=0.2,
                             num_blocks=2)

    trainer.fit(model, train_dataloader, val_dataloader)