In [None]:
# Cell 1: Import Statements and Settings
import logging
import lightning as pl
import torch
import wandb
from callbacks import create_callbacks
from lit_datamodule import inD_RecordingModule
from lit_module import LitModule
from utils import create_wandb_logger, get_data_path, build_module
from nn_modules import ConstantVelocityModel, MultiLayerPerceptron, LSTMModel
from select_features import select_features

# Allow float32 matrix multiplications to use reduced internal precision
# ('medium'); PyTorch documents this as trading some accuracy for speed.
torch.set_float32_matmul_precision('medium')
# Autograd anomaly detection is a debugging aid: it adds checks to every
# backward pass and PyTorch documents that it slows execution down.
# NOTE(review): consider disabling this once NaN/inf issues have been debugged.
torch.autograd.set_detect_anomaly(True)

In [None]:
# Cell 2: Initialization
# get_data_path() (project helper) supplies the data and log locations that
# the later cells pass to the data module and the W&B logger.
data_path, log_path = get_data_path()
wandb.login()

project_name = "SS2024_motion_prediction"
recording_ID = 26

# Every model cell branches on `stage`; an unrecognised value would make all of
# them silently do nothing, so reject it here instead of failing later.
stage = input("Please enter the stage (fit or test): ").strip().lower()
if stage not in ("fit", "test"):
    raise ValueError(f"Unknown stage {stage!r}; expected 'fit' or 'test'.")

In [None]:
# Cell 3: Feature Selection
# select_features() is a project helper; it returns the selected features and
# their count. `features` is passed to the data module and
# `number_of_features` sizes the model inputs/outputs in the model cells.
features, number_of_features = select_features()

In [None]:
# Hyperparameters shared by every model cell below.
batch_size = 32  # forwarded to both the data module and the LitModule

# Window layout: number of past steps fed to the model and number of
# future steps; their sum is the full window length.
past_sequence_length, future_sequence_length = 6, 3
sequence_length = sum((past_sequence_length, future_sequence_length))

In [None]:
# Cell 4: Model Selection
# Supported models keyed by the number the user types; the model cells below
# each run only when `model_choice` matches their key.
model_menu = {
    1: "Multi-Layer Perceptron (MLP)",
    2: "Long Short-Term Memory (LSTM)",
    3: "Constant Velocity Model",
    4: "Constant Acceleration Model",
}

print("Please select the model you wish to use:")
for option_number, option_label in model_menu.items():
    print(f"{option_number}: {option_label}")

raw_choice = input("Enter the number corresponding to your model choice: ").strip()
try:
    model_choice = int(raw_choice)
except ValueError:
    # Same exception type as a bare int() failure, but with an actionable message.
    raise ValueError(
        f"Model choice must be a number between 1 and {len(model_menu)}, got {raw_choice!r}."
    ) from None

# An out-of-range number would match none of the model cells, so nothing would
# run and no error would be shown; fail loudly instead.
if model_choice not in model_menu:
    raise ValueError(
        f"Unknown model choice {model_choice}; expected one of {sorted(model_menu)}."
    )

In [None]:
if model_choice == 1:
    # MLP Model Configuration
    # Only "fit" and "test" are handled below. Check before a W&B run is
    # created so a typo does not leave an empty run behind.
    if stage not in ("fit", "test"):
        raise ValueError(f"Unknown stage {stage!r}; expected 'fit' or 'test'.")

    # Input width is features x past steps (presumably the flattened past
    # window -- confirm against MultiLayerPerceptron / inD_RecordingModule).
    input_size = number_of_features * past_sequence_length
    output_size = number_of_features
    hidden_size = 32
    epochs = 3

    mdl = MultiLayerPerceptron(input_size, hidden_size, output_size)

    dm = inD_RecordingModule(
        data_path, recording_ID, sequence_length, past_sequence_length,
        future_sequence_length, features, batch_size=batch_size
    )

    # Kept as a dict because the same arguments are reused when restoring the
    # LitModule from a checkpoint in the test stage.
    model_params = {
        'model': mdl,
        'number_of_features': number_of_features,
        'sequence_length': sequence_length,
        'past_sequence_length': past_sequence_length,
        'future_sequence_length': future_sequence_length,
        'batch_size': batch_size
    }
    model = LitModule(**model_params)

    # Data Module Setup
    dm.setup(stage=stage)

    # Log Path
    print(log_path)

    # Callbacks and WandB Logger
    callbacks = create_callbacks()
    wandb_logger = create_wandb_logger(log_path, project_name, recording_ID)
    wandb_logger.experiment.config.update({
        "batch_size": batch_size,
        "sequence_length": sequence_length
    })
    # Lightning (imported as `lightning`) logs to the "lightning.pytorch"
    # logger. The previous name was prefixed with log_path, which configured a
    # logger nobody writes to and therefore silenced nothing.
    logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

    # Close the W&B run on every path (fit, test, or failure) so it does not
    # stay open in the notebook kernel.
    try:
        # Trainer Setup
        trainer = pl.Trainer(max_epochs=epochs,
                             fast_dev_run=False,
                             devices="auto",
                             accelerator="auto",
                             log_every_n_steps=5,
                             logger=wandb_logger,
                             callbacks=callbacks,
                             check_val_every_n_epoch=1,
                             precision=32)

        # Training or Testing
        if stage == "fit":
            trainer.fit(model, dm)
        else:  # stage == "test", validated at the top of the cell
            checkpoint_path = input("Please enter the checkpoint path: ").strip()
            model = LitModule.load_from_checkpoint(checkpoint_path, **model_params)
            trainer.test(model, dm)
    finally:
        wandb.finish()

In [None]:
if model_choice == 2:
    # LSTM Model Configuration
    # Only "fit" and "test" are handled below. Check before a W&B run is
    # created so a typo does not leave an empty run behind.
    if stage not in ("fit", "test"):
        raise ValueError(f"Unknown stage {stage!r}; expected 'fit' or 'test'.")

    input_size = number_of_features
    # Output width equals the number of selected features (the earlier
    # "2D coordinate" note did not match this assignment).
    output_size = number_of_features
    hidden_size = 64
    num_layers = 2
    epochs = 3
    mdl = LSTMModel(input_size, hidden_size, num_layers, output_size, future_sequence_length)

    # use_lstm=True is passed to both the data module and the LitModule; its
    # exact effect is defined in those project classes.
    dm = inD_RecordingModule(
        data_path, recording_ID, sequence_length, past_sequence_length,
        future_sequence_length, features, batch_size=batch_size,
        use_lstm=True
    )

    # Kept as a dict because the same arguments are reused when restoring the
    # LitModule from a checkpoint in the test stage.
    model_params = {
        'model': mdl,
        'number_of_features': number_of_features,
        'sequence_length': sequence_length,
        'past_sequence_length': past_sequence_length,
        'future_sequence_length': future_sequence_length,
        'batch_size': batch_size,
        'use_lstm': True
    }
    model = LitModule(**model_params)

    # Data Module Setup
    dm.setup(stage=stage)

    # Log Path
    print(log_path)

    # Callbacks and WandB Logger
    callbacks = create_callbacks()
    wandb_logger = create_wandb_logger(log_path, project_name, recording_ID)
    wandb_logger.experiment.config.update({
        "batch_size": batch_size,
        "sequence_length": sequence_length
    })
    # Lightning (imported as `lightning`) logs to the "lightning.pytorch"
    # logger. The previous name was prefixed with log_path, which configured a
    # logger nobody writes to and therefore silenced nothing.
    logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

    # Close the W&B run on every path (fit, test, or failure) so it does not
    # stay open in the notebook kernel.
    try:
        # Trainer Setup
        trainer = pl.Trainer(max_epochs=epochs,
                             fast_dev_run=False,
                             devices="auto",
                             accelerator="auto",
                             log_every_n_steps=5,
                             logger=wandb_logger,
                             callbacks=callbacks,
                             check_val_every_n_epoch=1,
                             precision=32)

        # Training or Testing
        if stage == "fit":
            trainer.fit(model, dm)
        else:  # stage == "test", validated at the top of the cell
            checkpoint_path = input("Please enter the checkpoint path: ").strip()
            model = LitModule.load_from_checkpoint(checkpoint_path, **model_params)
            trainer.test(model, dm)
    finally:
        wandb.finish()


In [None]:
if model_choice == 3:
    # Constant Velocity Model Configuration
    # This cell only implements an evaluation pass; no Trainer or W&B logger is
    # created here.
    mdl = ConstantVelocityModel()

    dm = inD_RecordingModule(
        data_path, recording_ID, sequence_length, past_sequence_length,
        future_sequence_length, features, batch_size=batch_size
    )

    model_params = {
        'model': mdl,
        'number_of_features': number_of_features,
        'sequence_length': sequence_length,
        'past_sequence_length': past_sequence_length,
        'future_sequence_length': future_sequence_length,
        'batch_size': batch_size
    }
    model = LitModule(**model_params)

    # Data Module Setup
    dm.setup(stage=stage)

    # Log Path
    print(log_path)

    # Testing
    if stage == "test":
        # Direct prediction using Constant Velocity Model
        test_dataloader = dm.test_dataloader()
        model.eval()
        batch_losses = []
        with torch.no_grad():
            for batch in test_dataloader:
                x, y = batch
                y_hat = model(x)
                # NOTE(review): relies on LitModule exposing loss_function -- confirm.
                loss = model.loss_function(y_hat, y)
                batch_losses.append(loss.item())
                print(f"Test loss: {batch_losses[-1]}")

        # Per-batch values alone are hard to compare between runs, so also
        # report the unweighted mean of the per-batch losses.
        if batch_losses:
            mean_loss = sum(batch_losses) / len(batch_losses)
            print(f"Mean test loss over {len(batch_losses)} batches: {mean_loss}")
        else:
            print("Test dataloader produced no batches; nothing was evaluated.")
    else:
        # Previously any other stage finished silently without doing anything.
        print(f"Stage {stage!r} is not handled for the Constant Velocity Model; "
              "only 'test' runs an evaluation.")

In [None]:
if model_choice == 4:
    # Constant Acceleration Model Configuration
    class ConstantAccelerationModel(torch.nn.Module):
        """Parameter-free baseline that extrapolates with constant acceleration.

        Velocity and acceleration are estimated per feature from the last
        observed steps with finite differences, then the last observation is
        advanced by ``dt``.

        NOTE(review): every feature column is extrapolated independently. If
        the selected features already contain velocity/acceleration columns,
        a model that reads those columns directly may be preferable -- confirm
        against select_features().

        Args:
            dt: Prediction horizon, measured in units of the sampling interval
                between consecutive input steps (1.0 = one step ahead).
        """

        def __init__(self, dt=1.0):
            super(ConstantAccelerationModel, self).__init__()
            self.dt = dt

        def forward(self, x):
            """Predict the state ``dt`` steps after the last observed one.

            Args:
                x: Tensor indexed as ``x[:, step, :]``; the step axis is
                    assumed to be time-ordered, oldest first -- TODO confirm.

            Returns:
                Tensor with the step axis removed (same shape as ``x[:, -1, :]``).
            """
            num_steps = x.shape[1]
            last = x[:, -1, :]

            # The previous formula was x + dt*x + 0.5*dt^2*x, i.e. it used the
            # position itself as velocity and acceleration and only scaled the
            # last observation by a constant factor.
            if num_steps < 2:
                # A single observation carries no motion information.
                return last

            previous = x[:, -2, :]
            if num_steps < 3:
                # Two observations: velocity only, acceleration taken as zero.
                return last + self.dt * (last - previous)

            before_previous = x[:, -3, :]
            # Second-order backward difference for the velocity at the last
            # step, and the matching second difference for the acceleration
            # (both per sampling interval). Together they reproduce a
            # quadratic trajectory exactly.
            velocity = (3 * last - 4 * previous + before_previous) / 2
            acceleration = last - 2 * previous + before_previous
            return last + self.dt * velocity + 0.5 * (self.dt ** 2) * acceleration

        def loss_function(self, y_hat, y):
            """Mean squared error between prediction and target."""
            return torch.nn.functional.mse_loss(y_hat, y)

    mdl = ConstantAccelerationModel()

    dm = inD_RecordingModule(
        data_path, recording_ID, sequence_length, past_sequence_length,
        future_sequence_length, features, batch_size=batch_size
    )

    model_params = {
        'model': mdl,
        'number_of_features': number_of_features,
        'sequence_length': sequence_length,
        'past_sequence_length': past_sequence_length,
        'future_sequence_length': future_sequence_length,
        'batch_size': batch_size
    }
    model = LitModule(**model_params)

    # Data Module Setup
    dm.setup(stage=stage)

    # Log Path
    print(log_path)

    # Testing
    if stage == "test":
        # Direct prediction using Constant Acceleration Model
        test_dataloader = dm.test_dataloader()
        model.eval()
        batch_losses = []
        with torch.no_grad():
            for batch in test_dataloader:
                x, y = batch
                y_hat = model(x)
                # NOTE(review): this calls loss_function on the LitModule, not
                # on the ConstantAccelerationModel defined above -- confirm the
                # LitModule exposes it.
                loss = model.loss_function(y_hat, y)
                batch_losses.append(loss.item())
                print(f"Test loss: {batch_losses[-1]}")

        # Per-batch values alone are hard to compare between runs, so also
        # report the unweighted mean of the per-batch losses.
        if batch_losses:
            mean_loss = sum(batch_losses) / len(batch_losses)
            print(f"Mean test loss over {len(batch_losses)} batches: {mean_loss}")
        else:
            print("Test dataloader produced no batches; nothing was evaluated.")
    else:
        # Previously any other stage finished silently without doing anything.
        print(f"Stage {stage!r} is not handled for the Constant Acceleration Model; "
              "only 'test' runs an evaluation.")