In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm  # Import tqdm for progress tracking


class TrajectoryDataset(Dataset):
    """Sliding-window dataset over a table of univariate trajectories.

    Each row of the input table is treated as one trajectory, with one value
    per column. Every contiguous run of ``window_length + 1`` values within a
    row becomes one sample. The first ``window_length`` values are the model
    input and the final value is the next-step target; the training loop
    splits them as ``data[:, :-1]`` / ``data[:, -1]``.

    Args:
        dataframe: 2-D table of shape (n_trajectories, n_timesteps). A pandas
            DataFrame, or anything ``np.asarray`` accepts; a 1-D sequence is
            treated as a single trajectory.
        window_length: Number of input steps per sample.
    """

    def __init__(self, dataframe, window_length=100):
        # Convert to float32 *before* windowing. The reshape in
        # custom_transformation materialises every window, so doing it in
        # float32 rather than float64 halves that copy. The values are the
        # same as converting afterwards (identical float64 -> float32 rounding).
        values = np.atleast_2d(np.asarray(dataframe, dtype=np.float32))
        windows = self.custom_transformation(values, window_length=window_length)
        # torch.from_numpy shares the buffer instead of making another full
        # copy, but it needs a writable, C-contiguous array.
        if not (windows.flags.c_contiguous and windows.flags.writeable):
            windows = windows.copy()
        self.data = torch.from_numpy(windows)

    def __len__(self):
        """Number of windows (samples) across all trajectories."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return window ``idx``: ``window_length`` inputs followed by the target."""
        return self.data[idx]

    def custom_transformation(self, dataframe_array, window_length):
        """Cut every row of ``dataframe_array`` into overlapping windows.

        Args:
            dataframe_array: 2-D array of shape (n_rows, n_cols).
            window_length: Input steps per window. One extra step is appended
                to each window to serve as the target.

        Returns:
            Array of shape (n_rows * (n_cols - window_length), window_length + 1).
            All windows of row 0 come first, then those of row 1, and so on.

        Raises:
            ValueError: (from numpy) if ``window_length + 1`` exceeds n_cols.
        """
        window_length += 1  # one extra step per window: the target

        # sliding_window_view allocates nothing. It returns a strided,
        # read-only view of shape (n_rows, n_cols - window_length + 1, window_length).
        windows = np.lib.stride_tricks.sliding_window_view(
            dataframe_array, window_shape=(window_length,), axis=1
        )

        # Flattening the first two axes generally cannot be expressed as a
        # view, so this is where the windows are actually copied into memory.
        return windows.reshape(-1, window_length)


class Attention(nn.Module):
    """Collapse a sequence of BiLSTM states into a single context vector.

    A bias-free linear layer scores every time step. The scores are
    softmax-normalised across the sequence (dim 1), and the context is the
    score-weighted sum of the states.

    Args:
        hidden_size: Per-direction LSTM hidden size. The layer expects
            features of width ``2 * hidden_size`` (forward + backward).
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Maps each 2*hidden_size state to one unnormalised score.
        self.attention = nn.Linear(2 * hidden_size, 1, bias=False)

    def forward(self, lstm_output):
        """Return the attention-weighted sum of ``lstm_output`` over dim 1.

        ``lstm_output`` is assumed to be (batch, seq_len, 2 * hidden_size), as
        produced by a batch_first bidirectional LSTM. The result is then
        (batch, 2 * hidden_size).
        """
        scores = self.attention(lstm_output)  # one score per time step
        weights = scores.softmax(dim=1)
        weighted_states = weights * lstm_output
        return weighted_states.sum(dim=1)


# Define the BiLSTM with Attention model
class BiLSTMWithAttention(nn.Module):
    """Bidirectional LSTM encoder, attention pooling over time, then a linear head.

    Args:
        input_size: Features per time step (1 for a univariate series).
        hidden_size: Hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.
        output_size: Width of the prediction produced per sequence.
        dropout: Dropout between stacked LSTM layers. nn.LSTM only applies it
            between layers, so it is passed as 0 for a single-layer LSTM,
            which would otherwise warn and ignore it.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.5):
        super(BiLSTMWithAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.attention = Attention(hidden_size)
        self.fc = nn.Linear(hidden_size * 2, output_size)  # *2: forward + backward states

    def forward(self, x):
        """Map a batch of sequences to one prediction per sequence.

        Args:
            x: (batch, seq_len, input_size) tensor in batch_first layout.

        Returns:
            (batch, output_size) tensor.
        """
        # With no (h0, c0) argument, nn.LSTM starts from zeros created on the
        # input's device *and dtype*. Hand-built zeros were always float32,
        # which breaks a model converted with .double() or .half().
        lstm_output, _ = self.lstm(x)  # (batch, seq_len, 2 * hidden_size)
        context = self.attention(lstm_output)
        return self.fc(context)

## Training loop

In [3]:
import os

dataset_path = "./dataset/"
# dataset_path = "/content/drive/Othercomputers/My Laptop/Sem_3/CSE_575_SML/Projects/Individual_Project/dataset/"

# CSV location of each split. The directory and file name are passed to
# os.path.join separately, so the path is correct whether or not
# dataset_path ends with a separator.
train_path = os.path.join(dataset_path, "train.csv")
val_path = os.path.join(dataset_path, "val.csv")
test_path = os.path.join(dataset_path, "test.csv")

# Downstream code treats each row as one trajectory and each column as a
# time step. The "ids" column is not part of the series, so drop it.
train_df = pd.read_csv(train_path, header=0).drop(columns="ids")
val_df = pd.read_csv(val_path, header=0).drop(columns="ids")
test_df = pd.read_csv(test_path, header=0).drop(columns="ids")

# Report the (rows, columns) shape of every split
print(f"Training data shape: {train_df.shape}")
print(f"Validation data shape: {val_df.shape}")
print(f"Testing data shape: {test_df.shape}")

Training data shape: (963, 7560)
Validation data shape: (963, 1500)
Testing data shape: (963, 1500)


In [4]:
# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device.type.upper()}")

# Seed torch's global RNG (torch.manual_seed covers CPU and CUDA). This makes
# weight init, dropout masks and the DataLoader's shuffle order repeatable
# across Restart & Run All. cuDNN LSTM kernels may still be nondeterministic
# on GPU.
seed = 42
torch.manual_seed(seed)

# Hyperparameters
window_length = 100  # input steps per training sample / forecast context
batch_size = 64
input_size = 1  # For univariate time series
hidden_size = 128
num_layers = 2
output_size = 1  # For univariate time series prediction
learning_rate = 0.001
num_epochs = 10

Using CUDA


In [5]:
# Training samples: every (window_length + 1)-step slice of each training
# row, served in shuffled mini-batches.
dataset = TrajectoryDataset(dataframe=train_df, window_length=window_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Objective: mean squared error on the next-step prediction.
criterion = nn.MSELoss()

# Model, moved to the selected device.
model = BiLSTMWithAttention(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    output_size=output_size,
    dropout=0.5,
).to(device)

# Adam; weight_decay adds an L2 penalty on the parameters. No LR scheduler is used.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

In [6]:
import os

# File the training loop below writes its checkpoint to (and resumes from).
checkpoint_path = "model_checkpoint.pth"

# Function to save the model checkpoint
def save_checkpoint(model, optimizer, epoch, loss, path):
    """Serialise model and optimizer state so training can resume later.

    The checkpoint is a dict with the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss', which is the layout load_checkpoint reads.

    The data is first written to ``<path>.tmp`` and then moved over ``path``
    with os.replace (an atomic rename on POSIX). An interruption mid-save
    (Ctrl-C, runtime disconnect) therefore leaves the previous checkpoint
    intact, rather than a truncated file that would break the next resume.
    """
    tmp_path = f"{path}.tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, tmp_path)
    os.replace(tmp_path, path)

# Function to load the model checkpoint
def load_checkpoint(model, optimizer, path):
    """Restore model and optimizer state from ``path`` if that file exists.

    Args:
        model: Module that receives the checkpoint's 'model_state_dict'.
        optimizer: Optimizer that receives 'optimizer_state_dict'.
        path: Checkpoint file in the layout written by save_checkpoint.

    Returns:
        The (epoch, loss) stored in the checkpoint. Returns (0, inf) when no
        file exists, so training starts from scratch.
    """
    if not os.path.isfile(path):
        print("No checkpoint found, starting from scratch.")
        return 0, float('inf')

    # Deserialise onto the CPU first. Without this, a checkpoint written on a
    # GPU machine tries to recreate CUDA tensors and fails on a CPU-only
    # runtime. load_state_dict then copies the values to wherever the model's
    # and optimizer's parameters live.
    # NOTE(review): torch.load can run pickled code unless weights_only=True
    # (the default only from PyTorch 2.6); load only trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded: epoch {epoch}, loss {loss:.4f}")
    return epoch, loss

# Resume from the checkpoint if one exists (otherwise start at epoch 0).
start_epoch, _ = load_checkpoint(model, optimizer, checkpoint_path)
if start_epoch >= num_epochs:
    print(f"Checkpoint already covers {start_epoch} of {num_epochs} epochs; nothing to train. "
          f"Delete {checkpoint_path} or raise num_epochs to train further.")

# Training loop with tqdm for progress tracking
for epoch in tqdm(range(start_epoch, num_epochs), desc="Epochs", unit="epoch"):
    model.train()
    running_loss = 0.0

    # Iterating the DataLoader directly (not an enumerate wrapper) lets tqdm
    # read len(dataloader) and show a real total and ETA for the batch bar.
    for data in tqdm(dataloader, desc=f"Epoch {epoch + 1}", unit="batch", leave=False):
        # Each row holds window_length inputs followed by the next value (target).
        inputs = data[:, :-1].unsqueeze(2).to(device)  # (batch, window_length, 1)
        targets = data[:, -1].to(device)               # (batch,)

        optimizer.zero_grad()
        outputs = model(inputs)  # (batch, output_size); output_size is 1 here

        # Use squeeze(-1), not squeeze(). A final batch of size 1 would
        # otherwise collapse to a 0-d tensor and hit MSELoss's
        # size-mismatch broadcasting warning.
        loss = criterion(outputs.squeeze(-1), targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses (a smaller last batch counts as much as a full one).
    epoch_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # Checkpoint after every epoch. The stored epoch is the number of completed
    # epochs, which is also the index training resumes from.
    save_checkpoint(model, optimizer, epoch + 1, epoch_loss, checkpoint_path)

No checkpoint found, starting from scratch.


Epochs:   0%|          | 0/10 [00:00<?, ?epoch/s]
Epoch 1: 0batch [00:00, ?batch/s][A
Epoch 1: 1batch [00:01,  1.73s/batch][A
Epoch 1: 21batch [00:01, 15.81batch/s][A
Epoch 1: 43batch [00:01, 35.88batch/s][A
Epoch 1: 68batch [00:02, 62.06batch/s][A
Epoch 1: 93batch [00:02, 89.87batch/s][A
Epoch 1: 118batch [00:02, 117.55batch/s][A
Epoch 1: 143batch [00:02, 143.51batch/s][A
Epoch 1: 168batch [00:02, 166.49batch/s][A
Epoch 1: 193batch [00:02, 185.70batch/s][A
Epoch 1: 218batch [00:02, 201.11batch/s][A
Epoch 1: 243batch [00:02, 213.14batch/s][A
Epoch 1: 268batch [00:02, 221.89batch/s][A
Epoch 1: 293batch [00:02, 228.78batch/s][A
Epoch 1: 318batch [00:03, 233.71batch/s][A
Epoch 1: 343batch [00:03, 237.50batch/s][A
Epoch 1: 368batch [00:03, 239.85batch/s][A
Epoch 1: 393batch [00:03, 241.54batch/s][A
Epoch 1: 418batch [00:03, 242.68batch/s][A
Epoch 1: 443batch [00:03, 243.27batch/s][A
Epoch 1: 468batch [00:03, 243.71batch/s][A
Epoch 1: 493batch [00:03, 244.16batch/s][A


Epoch [1/10], Loss: 0.0008



Epoch 2: 0batch [00:00, ?batch/s][A
Epoch 2: 1batch [00:00,  4.54batch/s][A
Epoch 2: 26batch [00:00, 99.76batch/s][A
Epoch 2: 51batch [00:00, 151.80batch/s][A
Epoch 2: 76batch [00:00, 182.98batch/s][A
Epoch 2: 101batch [00:00, 202.64batch/s][A
Epoch 2: 126batch [00:00, 215.52batch/s][A
Epoch 2: 151batch [00:00, 224.03batch/s][A
Epoch 2: 176batch [00:00, 229.74batch/s][A
Epoch 2: 201batch [00:01, 233.75batch/s][A
Epoch 2: 226batch [00:01, 236.45batch/s][A
Epoch 2: 251batch [00:01, 238.17batch/s][A
Epoch 2: 276batch [00:01, 239.72batch/s][A
Epoch 2: 301batch [00:01, 240.79batch/s][A
Epoch 2: 326batch [00:01, 241.43batch/s][A
Epoch 2: 351batch [00:01, 241.66batch/s][A
Epoch 2: 376batch [00:01, 242.12batch/s][A
Epoch 2: 401batch [00:01, 242.63batch/s][A
Epoch 2: 426batch [00:01, 242.70batch/s][A
Epoch 2: 451batch [00:02, 242.94batch/s][A
Epoch 2: 476batch [00:02, 243.13batch/s][A
Epoch 2: 501batch [00:02, 242.56batch/s][A
Epoch 2: 526batch [00:02, 242.51batch/s][A
E

Epoch [2/10], Loss: 0.0005



Epoch 3: 0batch [00:00, ?batch/s][A
Epoch 3: 1batch [00:00,  4.48batch/s][A
Epoch 3: 26batch [00:00, 98.94batch/s][A
Epoch 3: 51batch [00:00, 150.80batch/s][A
Epoch 3: 76batch [00:00, 182.17batch/s][A
Epoch 3: 101batch [00:00, 201.95batch/s][A
Epoch 3: 126batch [00:00, 214.94batch/s][A
Epoch 3: 151batch [00:00, 223.51batch/s][A
Epoch 3: 176batch [00:00, 229.67batch/s][A
Epoch 3: 201batch [00:01, 233.35batch/s][A
Epoch 3: 226batch [00:01, 236.51batch/s][A
Epoch 3: 251batch [00:01, 238.60batch/s][A
Epoch 3: 276batch [00:01, 239.78batch/s][A
Epoch 3: 301batch [00:01, 240.82batch/s][A
Epoch 3: 326batch [00:01, 241.46batch/s][A
Epoch 3: 351batch [00:01, 241.81batch/s][A
Epoch 3: 376batch [00:01, 242.03batch/s][A
Epoch 3: 401batch [00:01, 242.26batch/s][A
Epoch 3: 426batch [00:01, 242.59batch/s][A
Epoch 3: 451batch [00:02, 242.71batch/s][A
Epoch 3: 476batch [00:02, 242.73batch/s][A
Epoch 3: 501batch [00:02, 242.74batch/s][A
Epoch 3: 526batch [00:02, 242.83batch/s][A
E

Epoch [3/10], Loss: 0.0005



Epoch 4: 0batch [00:00, ?batch/s][A
Epoch 4: 1batch [00:00,  5.74batch/s][A
Epoch 4: 26batch [00:00, 114.09batch/s][A
Epoch 4: 51batch [00:00, 165.09batch/s][A
Epoch 4: 76batch [00:00, 192.68batch/s][A
Epoch 4: 101batch [00:00, 209.66batch/s][A
Epoch 4: 126batch [00:00, 220.55batch/s][A
Epoch 4: 151batch [00:00, 227.77batch/s][A
Epoch 4: 176batch [00:00, 232.11batch/s][A
Epoch 4: 201batch [00:00, 235.44batch/s][A
Epoch 4: 226batch [00:01, 237.79batch/s][A
Epoch 4: 251batch [00:01, 239.39batch/s][A
Epoch 4: 276batch [00:01, 240.60batch/s][A
Epoch 4: 301batch [00:01, 241.17batch/s][A
Epoch 4: 326batch [00:01, 241.55batch/s][A
Epoch 4: 351batch [00:01, 241.96batch/s][A
Epoch 4: 376batch [00:01, 242.33batch/s][A
Epoch 4: 401batch [00:01, 242.55batch/s][A
Epoch 4: 426batch [00:01, 242.07batch/s][A
Epoch 4: 451batch [00:02, 242.15batch/s][A
Epoch 4: 476batch [00:02, 242.57batch/s][A
Epoch 4: 501batch [00:02, 242.71batch/s][A
Epoch 4: 526batch [00:02, 242.89batch/s][A


Epoch [4/10], Loss: 0.0005



Epoch 5: 0batch [00:00, ?batch/s][A
Epoch 5: 1batch [00:00,  4.62batch/s][A
Epoch 5: 26batch [00:00, 100.83batch/s][A
Epoch 5: 51batch [00:00, 152.87batch/s][A
Epoch 5: 76batch [00:00, 183.71batch/s][A
Epoch 5: 101batch [00:00, 203.18batch/s][A
Epoch 5: 126batch [00:00, 215.93batch/s][A
Epoch 5: 151batch [00:00, 224.22batch/s][A
Epoch 5: 176batch [00:00, 229.80batch/s][A
Epoch 5: 201batch [00:01, 233.70batch/s][A
Epoch 5: 226batch [00:01, 236.66batch/s][A
Epoch 5: 251batch [00:01, 238.62batch/s][A
Epoch 5: 276batch [00:01, 239.80batch/s][A
Epoch 5: 301batch [00:01, 240.50batch/s][A
Epoch 5: 326batch [00:01, 241.15batch/s][A
Epoch 5: 351batch [00:01, 241.77batch/s][A
Epoch 5: 376batch [00:01, 242.10batch/s][A
Epoch 5: 401batch [00:01, 242.45batch/s][A
Epoch 5: 426batch [00:01, 242.65batch/s][A
Epoch 5: 451batch [00:02, 242.64batch/s][A
Epoch 5: 476batch [00:02, 242.38batch/s][A
Epoch 5: 501batch [00:02, 242.48batch/s][A
Epoch 5: 526batch [00:02, 242.81batch/s][A


Epoch [5/10], Loss: 0.0005



Epoch 6: 0batch [00:00, ?batch/s][A
Epoch 6: 1batch [00:00,  4.39batch/s][A
Epoch 6: 25batch [00:00, 95.06batch/s][A
Epoch 6: 50batch [00:00, 148.35batch/s][A
Epoch 6: 75batch [00:00, 180.34batch/s][A
Epoch 6: 100batch [00:00, 200.19batch/s][A
Epoch 6: 125batch [00:00, 213.60batch/s][A
Epoch 6: 150batch [00:00, 222.58batch/s][A
Epoch 6: 175batch [00:00, 228.60batch/s][A
Epoch 6: 200batch [00:01, 232.79batch/s][A
Epoch 6: 225batch [00:01, 235.78batch/s][A
Epoch 6: 250batch [00:01, 237.78batch/s][A
Epoch 6: 275batch [00:01, 239.06batch/s][A
Epoch 6: 300batch [00:01, 240.13batch/s][A
Epoch 6: 325batch [00:01, 241.13batch/s][A
Epoch 6: 350batch [00:01, 241.66batch/s][A
Epoch 6: 375batch [00:01, 242.21batch/s][A
Epoch 6: 400batch [00:01, 242.59batch/s][A
Epoch 6: 425batch [00:01, 242.70batch/s][A
Epoch 6: 450batch [00:02, 242.78batch/s][A
Epoch 6: 475batch [00:02, 242.57batch/s][A
Epoch 6: 500batch [00:02, 242.22batch/s][A
Epoch 6: 525batch [00:02, 242.28batch/s][A
E

Epoch [6/10], Loss: 0.0005



Epoch 7: 0batch [00:00, ?batch/s][A
Epoch 7: 1batch [00:00,  4.96batch/s][A
Epoch 7: 25batch [00:00, 102.17batch/s][A
Epoch 7: 50batch [00:00, 155.68batch/s][A
Epoch 7: 75batch [00:00, 186.39batch/s][A
Epoch 7: 100batch [00:00, 205.65batch/s][A
Epoch 7: 125batch [00:00, 217.97batch/s][A
Epoch 7: 150batch [00:00, 226.03batch/s][A
Epoch 7: 175batch [00:00, 231.25batch/s][A
Epoch 7: 200batch [00:01, 235.24batch/s][A
Epoch 7: 225batch [00:01, 237.47batch/s][A
Epoch 7: 250batch [00:01, 239.48batch/s][A
Epoch 7: 275batch [00:01, 240.47batch/s][A
Epoch 7: 300batch [00:01, 241.49batch/s][A
Epoch 7: 325batch [00:01, 241.92batch/s][A
Epoch 7: 350batch [00:01, 241.99batch/s][A
Epoch 7: 375batch [00:01, 242.48batch/s][A
Epoch 7: 400batch [00:01, 242.36batch/s][A
Epoch 7: 425batch [00:01, 242.14batch/s][A
Epoch 7: 450batch [00:02, 242.60batch/s][A
Epoch 7: 475batch [00:02, 242.91batch/s][A
Epoch 7: 500batch [00:02, 243.13batch/s][A
Epoch 7: 525batch [00:02, 242.89batch/s][A


Epoch [7/10], Loss: 0.0005



Epoch 8: 0batch [00:00, ?batch/s][A
Epoch 8: 1batch [00:00,  5.20batch/s][A
Epoch 8: 25batch [00:00, 105.07batch/s][A
Epoch 8: 50batch [00:00, 157.71batch/s][A
Epoch 8: 75batch [00:00, 188.06batch/s][A
Epoch 8: 100batch [00:00, 206.32batch/s][A
Epoch 8: 125batch [00:00, 218.32batch/s][A
Epoch 8: 150batch [00:00, 226.31batch/s][A
Epoch 8: 175batch [00:00, 231.76batch/s][A
Epoch 8: 200batch [00:01, 235.42batch/s][A
Epoch 8: 225batch [00:01, 237.84batch/s][A
Epoch 8: 250batch [00:01, 239.86batch/s][A
Epoch 8: 275batch [00:01, 241.14batch/s][A
Epoch 8: 300batch [00:01, 241.84batch/s][A
Epoch 8: 325batch [00:01, 242.38batch/s][A
Epoch 8: 350batch [00:01, 242.98batch/s][A
Epoch 8: 375batch [00:01, 243.35batch/s][A
Epoch 8: 400batch [00:01, 243.54batch/s][A
Epoch 8: 425batch [00:01, 243.52batch/s][A
Epoch 8: 450batch [00:02, 243.47batch/s][A
Epoch 8: 475batch [00:02, 243.26batch/s][A
Epoch 8: 500batch [00:02, 243.23batch/s][A
Epoch 8: 525batch [00:02, 243.42batch/s][A


Epoch [8/10], Loss: 0.0005



Epoch 9: 0batch [00:00, ?batch/s][A
Epoch 9: 1batch [00:00,  6.42batch/s][A
Epoch 9: 25batch [00:00, 118.19batch/s][A
Epoch 9: 50batch [00:00, 169.21batch/s][A
Epoch 9: 74batch [00:00, 194.89batch/s][A
Epoch 9: 98batch [00:00, 210.12batch/s][A
Epoch 9: 123batch [00:00, 220.39batch/s][A
Epoch 9: 147batch [00:00, 226.40batch/s][A
Epoch 9: 172batch [00:00, 230.81batch/s][A
Epoch 9: 197batch [00:00, 233.96batch/s][A
Epoch 9: 222batch [00:01, 236.22batch/s][A
Epoch 9: 246batch [00:01, 237.22batch/s][A
Epoch 9: 271batch [00:01, 238.16batch/s][A
Epoch 9: 296batch [00:01, 239.10batch/s][A
Epoch 9: 321batch [00:01, 239.56batch/s][A
Epoch 9: 346batch [00:01, 240.01batch/s][A
Epoch 9: 371batch [00:01, 240.20batch/s][A
Epoch 9: 396batch [00:01, 240.23batch/s][A
Epoch 9: 421batch [00:01, 240.21batch/s][A
Epoch 9: 446batch [00:02, 240.65batch/s][A
Epoch 9: 471batch [00:02, 240.93batch/s][A
Epoch 9: 496batch [00:02, 240.93batch/s][A
Epoch 9: 521batch [00:02, 240.63batch/s][A
E

Epoch [9/10], Loss: 0.0005



Epoch 10: 0batch [00:00, ?batch/s][A
Epoch 10: 1batch [00:00,  4.68batch/s][A
Epoch 10: 26batch [00:00, 101.46batch/s][A
Epoch 10: 51batch [00:00, 153.26batch/s][A
Epoch 10: 76batch [00:00, 183.76batch/s][A
Epoch 10: 101batch [00:00, 203.30batch/s][A
Epoch 10: 126batch [00:00, 216.01batch/s][A
Epoch 10: 151batch [00:00, 224.72batch/s][A
Epoch 10: 176batch [00:00, 230.17batch/s][A
Epoch 10: 201batch [00:01, 234.07batch/s][A
Epoch 10: 226batch [00:01, 236.65batch/s][A
Epoch 10: 251batch [00:01, 238.70batch/s][A
Epoch 10: 276batch [00:01, 239.90batch/s][A
Epoch 10: 301batch [00:01, 240.81batch/s][A
Epoch 10: 326batch [00:01, 241.43batch/s][A
Epoch 10: 351batch [00:01, 241.99batch/s][A
Epoch 10: 376batch [00:01, 242.30batch/s][A
Epoch 10: 401batch [00:01, 242.47batch/s][A
Epoch 10: 426batch [00:01, 242.55batch/s][A
Epoch 10: 451batch [00:02, 242.58batch/s][A
Epoch 10: 476batch [00:02, 242.73batch/s][A
Epoch 10: 501batch [00:02, 242.88batch/s][A
Epoch 10: 526batch [00

Epoch [10/10], Loss: 0.0005





## Evaluation Loop

In [7]:
# Autoregressive prediction function
def autoregressive_predict(model, input_matrix, prediction_length):
    """
    Roll the model forward one step at a time, feeding each prediction back in.

    Args:
    - model: The trained PyTorch model. It is called on a (batch, window, 1)
      tensor, and each output becomes the next input column, so it is assumed
      to return shape (batch, 1).
    - input_matrix: 2-D tensor of seed values with one row per trajectory
      (e.g. shape (963, window_length)). Its width is the context window.
    - prediction_length: Number of future steps to generate.

    Returns:
    - output_matrix: Tensor of shape (batch, prediction_length) on the model's
      device. Column t holds the prediction for step t.

    Side effect: puts ``model`` in eval mode and leaves it there.
    """
    device = next(model.parameters()).device
    model.eval()  # inference behaviour (e.g. dropout disabled)
    current_input = input_matrix.to(device)

    # Preallocate the result. Calling torch.cat once per step would re-copy
    # the whole growing matrix on every iteration (quadratic in prediction_length).
    output_matrix = torch.empty(
        current_input.shape[0], prediction_length, device=device, dtype=current_input.dtype
    )

    with torch.no_grad():  # No need to calculate gradients for prediction
        for step in range(prediction_length):
            next_pred = model(current_input.unsqueeze(2))
            output_matrix[:, step:step + 1] = next_pred
            # Slide the window: drop the oldest value and append the prediction.
            current_input = torch.cat((current_input[:, 1:], next_pred), dim=1)

    return output_matrix

In [8]:
# Materialise every split as a float32 tensor on the training device.
train_set = torch.tensor(train_df.to_numpy(dtype=np.float32), device=device)
val_set = torch.tensor(val_df.to_numpy(dtype=np.float32), device=device)
test_set = torch.tensor(test_df.to_numpy(dtype=np.float32), device=device)


# Validation forecast: seed each row with its last window_length training
# values and roll forward for as many steps as val has columns. This assumes
# each val row continues the same-index train row in time.
initial_input = train_set[:, -window_length:]
val_predictions_tensor = autoregressive_predict(model, initial_input, val_set.shape[1])
print(f"Validation Predictions Tensor Shape: {val_predictions_tensor.shape}")

# MSE between the forecast and the observed validation trajectories.
mse_loss = nn.MSELoss()
mse = mse_loss(val_predictions_tensor, val_set)
print(f"Autoregressive Validation MSE: {mse.item():.4f}")

Validation Predictions Tensor Shape: torch.Size([963, 1500])
Autoregressive Validation MSE: 0.0128


## Plot predicted vs. actual validation trajectories

In [9]:
# Compare forecast and ground truth for the first few validation rows.
# Each figure is saved to trajectory_<row>.png instead of being shown inline.
n_plot_rows = 3
for row_idx in range(n_plot_rows):
    predicted_trajectory = val_predictions_tensor[row_idx].cpu().numpy()
    actual_trajectory = val_set[row_idx].cpu().numpy()

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.plot(range(len(actual_trajectory)), actual_trajectory,
            label="Actual Trajectory", color="blue", marker="o")
    ax.plot(range(len(predicted_trajectory)), predicted_trajectory,
            label="Predicted Trajectory", color="red", linestyle="--", marker="x")
    ax.set(title=f"Actual vs Predicted Trajectory (Row {row_idx})",
           xlabel="Time Step", ylabel="Value")
    ax.legend()
    ax.grid(True)
    fig.savefig(f"trajectory_{row_idx}.png", dpi=200)
    plt.close(fig)  # free each figure; nothing is displayed inline

In [10]:
# Test forecast. This assumes each test row continues the matching
# validation row in time. The validation values are observed (val_set is what
# the validation MSE was scored against), so seed from the last window_length
# *observed* values. Seeding from val_predictions_tensor would instead stack
# the test forecast on top of ~1500 steps of compounded autoregressive error.
initial_input = val_set[:, -window_length:]
test_predictions_tensor = autoregressive_predict(model, initial_input, test_set.shape[1])

print(f"Test Predictions Tensor Shape: {test_predictions_tensor.shape}")

Test Predictions Tensor Shape: torch.Size([963, 1500])


In [11]:
def _submission_rows(pred_tensor, original_df, split):
    """Flatten one split's predictions into parallel (ids, values), column by column."""
    columns = original_df.columns[1:]  # the first column is assumed to be 'ids'
    n_rows = original_df.shape[0]
    # Copy to host memory once. Calling float(tensor[r, c]) per element forces
    # a device sync for every value when the tensor lives on CUDA. float64
    # matches the Python floats the CSV would otherwise be built from.
    preds = pred_tensor.detach().cpu().numpy().astype(np.float64)
    assert preds.shape == (n_rows, len(columns)), (
        f"{split} predictions have shape {preds.shape}, expected {(n_rows, len(columns))}"
    )
    ids = [f"{col}_traffic_{split}_{row}" for col in columns for row in range(n_rows)]
    # Transposing then flattening gives the same column-by-column order as ids.
    values = preds.T.reshape(-1)
    return ids, values


def generate_submissions_v4(pred_val_tensor, pred_test_tensor, original_val_path, original_test_path,
                            output_path="submissions_v3.csv"):
    """Write the val and test forecasts in the long (ids, value) submission format.

    For every non-id column ``col`` and row index ``r`` of an original split
    CSV, one row is emitted with id ``f"{col}_traffic_val_{r}"`` (or
    ``..._traffic_test_...``) and the prediction at ``[r, column position]``.
    Rows are ordered column by column (all rows of the first column, then the
    second, ...), validation first, then test. NaN predictions are written as 100.

    Args:
        pred_val_tensor: (n_rows, n_cols) predictions for the validation CSV,
            where n_cols excludes the id column.
        pred_test_tensor: Same, for the test CSV.
        original_val_path: Validation CSV supplying column names and row count.
            Its first column is assumed to be 'ids' and is skipped.
        original_test_path: Same, for the test split.
        output_path: Destination CSV. The default keeps the historical
            "submissions_v3.csv" name so existing upload steps still find it.

    Returns:
        The submission DataFrame that was written.

    Raises:
        AssertionError: if a prediction tensor does not match its CSV's shape.
    """
    original_val_df = pd.read_csv(original_val_path)
    original_test_df = pd.read_csv(original_test_path)

    val_ids, val_values = _submission_rows(pred_val_tensor, original_val_df, "val")
    test_ids, test_values = _submission_rows(pred_test_tensor, original_test_df, "test")

    submissions_df = pd.DataFrame({
        "ids": val_ids + test_ids,
        "value": np.concatenate([val_values, test_values]),
    })
    # Impute any NaN predictions with a fixed placeholder value of 100.
    submissions_df = submissions_df.fillna(100)

    expected_rows = (original_val_df.shape[0] * (original_val_df.shape[1] - 1)
                     + original_test_df.shape[0] * (original_test_df.shape[1] - 1))
    assert list(submissions_df.columns) == ["ids", "value"]
    assert submissions_df.shape[0] == expected_rows

    submissions_df.to_csv(output_path, index=False)
    return submissions_df


# Write the validation and test forecasts to the submission CSV
# (trailing ';' keeps the notebook from echoing the call's return value).
generate_submissions_v4(
    pred_val_tensor=val_predictions_tensor,
    pred_test_tensor=test_predictions_tensor,
    original_val_path=val_path,
    original_test_path=test_path,
);