# Import

In [None]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

from typing import Any, Dict, Optional, Tuple

import pandas as pd
import numpy as np

import torch
from torch import nn

import lightning as L
from lightning import LightningDataModule
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms

from lightning import LightningModule
from torchmetrics import MinMetric, MeanMetric
from torch.nn.functional import mse_loss
from torchmetrics.classification.accuracy import Accuracy

# Model

In [None]:
class SimpleDenseNet(nn.Module):
    """A simple fully-connected neural net for computing predictions.

    The network is three `Linear -> BatchNorm1d -> ReLU` stages followed by a
    final linear projection to `output_size` values.
    """

    def __init__(
        self,
        input_size: int = 6,
        lin1_size: int = 64,
        lin2_size: int = 64,
        lin3_size: int = 64,
        output_size: int = 2,
    ) -> None:
        """Initialize a `SimpleDenseNet` module.

        :param input_size: The number of input features.
        :param lin1_size: The number of output features of the first linear layer.
        :param lin2_size: The number of output features of the second linear layer.
        :param lin3_size: The number of output features of the third linear layer.
        :param output_size: The number of output features of the final linear layer.
        """
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size, lin1_size),
            nn.BatchNorm1d(lin1_size),
            nn.ReLU(),
            nn.Linear(lin1_size, lin2_size),
            nn.BatchNorm1d(lin2_size),
            nn.ReLU(),
            nn.Linear(lin2_size, lin3_size),
            nn.BatchNorm1d(lin3_size),
            nn.ReLU(),
            nn.Linear(lin3_size, output_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform a single forward pass through the network.

        :param x: The input tensor of shape `(batch, ...)`. All dimensions after
            the first are flattened, so their product must equal the
            `input_size` the network was built with.
        :return: A tensor of predictions of shape `(batch, output_size)`.
        """
        # (batch, d1, d2, ...) -> (batch, d1*d2*...). A 2-D input is left as is.
        # `reshape` is used instead of `view` so non-contiguous inputs also work.
        x = x.reshape(x.size(0), -1)

        return self.model(x)


# [Test] Dataset

## Loading Dataset

In [None]:
# Input: file name of the dataset CSV to load.

filename = 'dataset_left_chest_20240627_162502_402545.csv'

In [None]:
# Load the dataset CSV from `<parent of cwd>/data/<filename>` into a DataFrame.

csv_path = os.path.join(
    os.path.dirname(os.getcwd()),  # parent directory of the working directory
    'data',
    filename,
)
df = pd.read_csv(csv_path)

# Report the number of rows, then preview the first rows as the cell output.
print(df.shape[0])
df.head()

## Separation of Training and Validation Set

In [None]:
# Collect the distinct commanded values found in each motor-command column.

unq_theta_left_pan_cmd, unq_theta_tilt_cmd = (
    df[column].unique()
    for column in ('theta_left_pan_cmd_tminus1', 'theta_tilt_cmd_tminus1')
)

# Show both sets of unique values for inspection.
print('unq_theta_left_pan_cmd:', unq_theta_left_pan_cmd)
print('unq_theta_tilt_cmd:', unq_theta_tilt_cmd)

In [None]:
# Split the left-pan commands into a training group and a validation group.

# Left-pan commands held out for validation.
val_theta_left_pan = [-14, -10, -6, -2, 2, 6, 10, 14]

# Every other left-pan command observed in the data is used for training.
train_theta_left_pan = sorted(
    set(unq_theta_left_pan_cmd).difference(val_theta_left_pan)
)

print('train_theta_left_pan:', train_theta_left_pan)
print('val_theta_left_pan:', val_theta_left_pan)

In [None]:
# Training set: rows whose left-pan command is one of the training commands.

train_df = (
    df.loc[df['theta_left_pan_cmd_tminus1'].isin(train_theta_left_pan)]
    .reset_index(drop=True)
)

# Report the number of training rows and preview the first few.
print(train_df.shape[0])
train_df.head()

In [None]:
# Validation set: rows whose left-pan command is one of the held-out commands.

val_df = (
    df.loc[df['theta_left_pan_cmd_tminus1'].isin(val_theta_left_pan)]
    .reset_index(drop=True)
)

# Report the number of validation rows and preview the first few.
print(val_df.shape[0])
val_df.head()

## Preprocessing

In [None]:
# Train Set Motor Command Conversion (deg -> rad)
# Applies np.radians to the columns at positions 3..7 of train_df and writes
# the converted values back into the same columns.
# NOTE(review): assumes exactly these five columns hold angles in degrees —
# confirm against the CSV column layout.
# NOTE: not idempotent — re-running this cell converts the values a second time.

train_df.iloc[:,3:8] = train_df.iloc[:,3:8].apply(np.radians)
train_df.head()

In [None]:
# Validation Set Motor Command Conversion (deg -> rad)
# Applies np.radians to the columns at positions 3..7 of val_df and writes
# the converted values back into the same columns.
# NOTE(review): assumes exactly these five columns hold angles in degrees —
# confirm against the CSV column layout.
# NOTE: not idempotent — re-running this cell converts the values a second time.

val_df.iloc[:,3:8] = val_df.iloc[:,3:8].apply(np.radians)
val_df.head()

## DataLoaders

In [None]:
# Input: number of samples per batch, used by both the training and the
# validation DataLoader.

batch_size = 32

### Training Set

In [None]:
# Convert the training DataFrame to PyTorch tensors:
# the first six columns are the model inputs, columns 6-7 are the targets.
X = torch.tensor(train_df.iloc[:, :6].to_numpy(), dtype=torch.float32)
y = torch.tensor(train_df.iloc[:, 6:8].to_numpy(), dtype=torch.float32)

# Create a PyTorch dataset of (input, target) pairs
train_dataset = TensorDataset(X, y)

# Create a PyTorch data loader (reshuffled every epoch).
# The worker count is capped at the number of CPUs on this machine: asking for
# more worker processes than there are cores makes PyTorch warn and can slow
# loading down. Machines with at least 31 CPUs still get 31 workers.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=min(31, os.cpu_count() or 1),
)

### Validation Set

In [None]:
# Convert the validation DataFrame to PyTorch tensors:
# the first six columns are the model inputs, columns 6-7 are the targets.
X = torch.tensor(val_df.iloc[:, :6].to_numpy(), dtype=torch.float32)
y = torch.tensor(val_df.iloc[:, 6:8].to_numpy(), dtype=torch.float32)

# Create a PyTorch dataset of (input, target) pairs
val_dataset = TensorDataset(X, y)

# Create a PyTorch data loader (fixed order, no shuffling).
# The worker count is capped at the number of CPUs on this machine: asking for
# more worker processes than there are cores makes PyTorch warn and can slow
# loading down. Machines with at least 31 CPUs still get 31 workers.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=min(31, os.cpu_count() or 1),
)

# Lightning Module

In [None]:
class LeftChestLitModule(LightningModule):
    """`LightningModule` that trains a regression network with an MSE loss.

    The wrapped network `net` maps a batch of inputs to a batch of predictions.
    The training, validation and test steps all compute `mse_loss(net(x), y)`
    and log its epoch average as `train/loss`, `val/loss` and `test/loss`.
    The lowest validation loss seen so far is logged as `val/loss_best`.

    Hooks implemented by this module:

    ```python
    def __init__(self, ...):
    # Store the network, the loss function and the metric accumulators.

    def setup(self, stage):
    # Optionally compile the network before fitting.
    # This hook is called on every process when using DDP.

    def training_step(self, batch, batch_idx):
    # The complete training step.

    def validation_step(self, batch, batch_idx):
    # The complete validation step.

    def test_step(self, batch, batch_idx):
    # The complete test step.

    def configure_optimizers(self):
    # Define and configure optimizers and LR schedulers.
    ```

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: Optional[Any] = None,
        scheduler: Optional[Any] = None,
        compile: bool = False,
    ) -> None:
        """Initialize a `LeftChestLitModule`.

        :param net: The model to train.
        :param optimizer: Optional optimizer factory, called as
            `optimizer(params=...)` (e.g. `functools.partial(torch.optim.Adam, lr=1e-3)`).
            An already constructed `torch.optim.Optimizer` is used as is.
            When `None`, Adam with a learning rate of 1e-3 is used.
        :param scheduler: Optional learning-rate scheduler factory, called as
            `scheduler(optimizer=...)`. When `None`, no scheduler is used.
        :param compile: Whether to wrap `net` with `torch.compile` when fitting starts.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function
        self.criterion = mse_loss

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking the best (lowest) validation loss so far
        self.val_loss_best = MinMetric()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform a forward pass through the model `self.net`.

        :param x: A tensor of model inputs.
        :return: A tensor of predictions.
        """
        return self.net(x)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_loss_best.reset()

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of data (a tuple) containing the input tensor and the target tensor.

        :return: A tuple containing (in order):
            - The MSE loss between predictions and targets.
            - A tensor of predictions.
            - A tensor of targets.
        """
        x, y = batch
        preds = self.forward(x)
        loss = self.criterion(preds, y)
        return loss, preds, y

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A batch of data (a tuple) containing the input tensor and the target
            tensor.
        :param batch_idx: The index of the current batch.
        :return: The loss between model predictions and targets.
        """
        loss, _, _ = self.model_step(batch)

        # update and log metrics
        self.train_loss(loss)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self) -> None:
        """Lightning hook that is called when a training epoch ends."""
        pass

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        :param batch: A batch of data (a tuple) containing the input tensor and the target
            tensor.
        :param batch_idx: The index of the current batch.
        """
        loss, _, _ = self.model_step(batch)

        # update and log metrics
        self.val_loss(loss)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        loss = self.val_loss.compute()  # get current val loss
        self.val_loss_best(loss)  # update best so far val loss
        # log `val_loss_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        print("val/loss_best:", self.val_loss_best.compute())  # print
        self.log("val/loss_best", self.val_loss_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor and the target
            tensor.
        :param batch_idx: The index of the current batch.
        """
        loss, _, _ = self.model_step(batch)

        # update and log metrics
        self.test_loss(loss)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate,
        test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        The `optimizer` given at construction time is honoured when present;
        otherwise Adam with a learning rate of 1e-3 is used.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """
        configured = self.hparams.optimizer
        if configured is None:
            # No optimizer supplied: fall back to the default Adam.
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        elif isinstance(configured, torch.optim.Optimizer):
            # An already built optimizer instance is used directly.
            optimizer = configured
        else:
            # Otherwise treat it as a factory (e.g. a functools.partial).
            optimizer = configured(params=self.parameters())

        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


# Training

In [None]:
# Initialization: wrap the dense network in the Lightning module (no custom
# optimizer, no scheduler, no torch.compile) and create the trainer.

dense_net = LeftChestLitModule(
    net=SimpleDenseNet(),
    optimizer=None,
    scheduler=None,
    compile=False,
)
trainer = L.Trainer(
    min_epochs=1,
    max_epochs=100,
)

In [None]:
# Training: fit the module on the training loader, validating on the validation loader.

trainer.fit(
    model=dense_net,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)