In [None]:
import torch
import torch.nn as nn
import lightning as pl
from torchmetrics import Accuracy, F1Score

from src.data.dataset import SensorDataModule
from src.data.partition_helper import get_partitioned_data, get_partition_paths

# Build the data module from the 5-fold partition files under ./data/splits,
# batching 32 samples at a time.
dataset = SensorDataModule(
    get_partition_paths("./data/splits", k_folds=5),
    batch_size=32,
)

# Later cells read `dataset.data_dict` and `dataset.num_classes`
# (presumably populated by setup() — verify against SensorDataModule).
dataset.setup()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy, F1Score

class BidirectionalLayer(nn.Module):
    """Bidirectional, batch-first LSTM that exposes only its output sequence.

    The final hidden/cell states returned by ``nn.LSTM`` are discarded, so the
    module can be chained inside ``nn.Sequential``.

    Args:
        input_dim: Number of features per timestep of the input.
        hidden_dim: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability handed to ``nn.LSTM``; only used when
            ``num_layers > 1``, otherwise 0.0 is passed instead.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        # Dropout is only forwarded for stacked LSTMs; a single-layer LSTM
        # always gets 0.0.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=True,
        )

    def forward(self, x):
        """Run ``x`` through the LSTM and return the per-timestep outputs.

        Because the LSTM is bidirectional, the last dimension of the result is
        ``2 * hidden_dim``.
        """
        sequence_out, _final_state = self.lstm(x)
        return sequence_out

class ResidualLayer(nn.Module):
    """Stack of residual BiLSTM blocks followed by batch normalisation.

    Each block is ``BidirectionalLayer -> Linear -> ReLU`` and produces
    ``2 * hidden_dim`` features per timestep. A skip connection is added around
    every block, and the final sequence is normalised over the feature
    dimension with ``BatchNorm1d``.

    Args:
        input_dim: Number of features per timestep of the input sequence.
        hidden_dim: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers inside each block.
        dropout: Dropout probability forwarded to each ``BidirectionalLayer``.
        num_blocks: Number of residual blocks to stack.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout, num_blocks):
        super().__init__()
        output_dim = hidden_dim * 2
        self.layers = nn.ModuleList()
        self.batch_norm = nn.BatchNorm1d(output_dim)

        for i in range(num_blocks):
            self.layers.append(
                nn.Sequential(
                    # Only the first block sees the raw input width.
                    BidirectionalLayer(input_dim if i == 0 else output_dim, hidden_dim, num_layers, dropout),
                    nn.Linear(output_dim, output_dim),
                    nn.ReLU()
                )
            )

        # The first block's skip connection adds the raw input to a tensor with
        # `output_dim` features, which only works when the widths agree.
        # Project the input when they differ; when they already match this is
        # a parameter-free identity, so existing checkpoints and seeded
        # initialisation are unaffected.
        if input_dim == output_dim:
            self.input_proj = nn.Identity()
        else:
            self.input_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply every residual block, then batch-normalise the features.

        Assumes ``x`` is a 3-D tensor laid out as (batch, seq_len, input_dim) —
        the LSTMs are batch-first and the transposes below require three
        dimensions. Returns a tensor with ``2 * hidden_dim`` features in the
        last dimension.
        """
        for i, layer in enumerate(self.layers):
            if i == 0:
                residual = self.input_proj(x)
            else:
                # NOTE(review): the skip path re-applies the *previous* block's
                # Linear layer to that block's output instead of using the
                # output unchanged. Kept as-is to preserve trained behaviour;
                # confirm this weight sharing is intentional.
                residual = self.layers[i-1][1](x)
            x = layer(x) + residual

        # BatchNorm1d normalises dimension 1, so move features there and back.
        x = x.transpose(1, 2)
        x = self.batch_norm(x)
        x = x.transpose(1, 2)
        return x

class DeepBidirectionalLSTMs(pl.LightningModule):
    """Sequence classifier: residual BiLSTM stack plus a linear head.

    The output of the last timestep of the residual stack is fed to a linear
    layer that produces one logit per class.

    Args:
        input_dim: Number of features per timestep of the input.
        hidden_dim: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers inside each residual block.
        dropout: Dropout probability forwarded to the LSTMs.
        num_blocks: Number of residual blocks.
        output_dim: Number of target classes.
        lr: Learning rate for the Adam optimizer (defaults to the previously
            hard-coded 0.001).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout, num_blocks, output_dim, lr=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.lstm_networks = ResidualLayer(input_dim, hidden_dim, num_layers, dropout, num_blocks)
        self.final_fc = nn.Linear(hidden_dim * 2, output_dim)
        self.accuracy = Accuracy(task='multiclass', num_classes=output_dim)
        self.f1_score = F1Score(num_classes=output_dim, average='weighted', task='multiclass')

    def forward(self, x):
        """Return class logits for a batch of sequences.

        Assumes ``x`` is (batch, seq_len, input_dim) — the indexing below
        requires three dimensions.
        """
        x = self.lstm_networks(x)
        x = x[:, -1, :]  # taking last timestep's output
        x = self.final_fc(x)
        return x

    def _shared_step(self, batch, batch_idx):
        """Compute loss, accuracy and weighted F1 for one batch.

        Assumes ``y`` holds one row per sample with one column per class
        (e.g. one-hot labels): it is used as a float probability target for the
        loss and reduced with argmax for the metrics — TODO confirm against the
        data module.

        Returns:
            Tuple ``(loss, acc, f1)`` for this batch.
        """
        x, y = batch
        logits = self(x)

        # Float targets with the same shape as the logits are treated as class
        # probabilities by cross_entropy.
        loss = F.cross_entropy(logits, y.float())

        preds = torch.argmax(logits, dim=1)
        targets = torch.argmax(y, dim=1)
        acc = self.accuracy(preds, targets)
        f1 = self.f1_score(preds, targets)

        # No per-batch print here: the values are reported through log_dict by
        # the callers, and printing whole tensors every step floods the
        # notebook output.
        return loss, acc, f1

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy/F1 and return the loss for backprop."""
        loss, acc, f1 = self._shared_step(batch, batch_idx)
        self.log_dict({"train_loss": loss, "train_acc": acc, "train_f1": f1}, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy/F1 for one batch."""
        loss, acc, f1 = self._shared_step(batch, batch_idx)
        self.log_dict({"val_loss": loss, "val_acc": acc, "val_f1": f1}, prog_bar=True)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)



In [None]:

# Train on the first cross-validation fold only: the loop exits after a single
# iteration via `break`.
# NOTE(review): presumably a quick smoke run — remove the `break` to train
# every fold.
for key, fold in dataset.data_dict.items():
    train_dataloader, val_dataloader = fold['train'], fold['validate']
    # accelerator='auto' lets Lightning choose the available hardware instead
    # of failing on machines without Apple MPS (the previous hard-coded value).
    trainer = pl.Trainer(max_epochs=20, devices=1, accelerator='auto', log_every_n_steps=10)
    model = DeepBidirectionalLSTMs(input_dim=48, hidden_dim=24, output_dim=dataset.num_classes, num_layers=3, dropout=0.2, num_blocks=2)
    trainer.fit(model, train_dataloader, val_dataloader)
    break


In [None]:
# Sanity-check what the model receives: pull one batch from the validation
# loader stored under key 0 of `dataset.data_dict`.
first_batch = next(iter(dataset.data_dict[0]['validate']))

# Unpack the first batch
data_tensors, labels = first_batch

# Report the container types first; this works whatever the batch holds
print("Data tensors type:", type(data_tensors))
print("Labels type:", type(labels))

if isinstance(data_tensors, dict):
    # A dict has no `.shape` and cannot be sliced, so this check has to come
    # before any shape/slice access; report its keys instead.
    print("Data tensor keys:", data_tensors.keys())
else:
    print("Data tensors shape:", data_tensors.shape)
    # Show only the first few entries to keep the output readable
    print("First few data points:", data_tensors[:5])  # Adjust slicing based on your data size

print("Labels shape:", labels.shape)
print("First few labels:", labels[:5])
