In [None]:
# Reinstall the PyTorch stack: first remove any already-installed builds, then install them again.
# NOTE(review): no versions are pinned, so the installed releases can differ between runs — consider pinning.
!pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata
!pip3 install torch torchaudio torchvision torchtext torchdata

In [None]:
import pickle
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
from torch.amp import GradScaler, autocast
from torch.nn.utils.rnn import pad_sequence
import math
from torch.utils.data import random_split

In [None]:
# Pickled fibers, addressed relative to the notebook's working directory.
pkl_file_path = "processed_fibers_longer_fibers.pkl"

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
with open(pkl_file_path, mode="rb") as fiber_file:
    fibers = pickle.load(fiber_file)

In [None]:
# Hold out 20% of the fibers for testing; the fixed random_state keeps the split repeatable.
train_fibers, test_fibers = train_test_split(
    fibers,
    test_size=0.2,
    random_state=42,
)

# Show the size of each split and the shape of the first fiber on each side.
(
    len(train_fibers),
    len(test_fibers),
    train_fibers[0].shape,
    test_fibers[0].shape,
)

(322933, 80734, torch.Size([178, 5]), torch.Size([119, 5]))

In [None]:
class FiberDataset(Dataset):
    """Sliding-window dataset built from variable-length fibers.

    For every fiber with more than ``predict_steps`` points the dataset keeps:

    * the input sequence: every point except the last ``predict_steps``;
    * the targets: for each input position ``i`` the window
      ``fiber[i : i + predict_steps]``;
    * the number of input positions, ``len(fiber) - predict_steps``.

    Fibers with ``predict_steps`` points or fewer are skipped.

    Inputs and targets are stored as *views* of the original fiber tensors, so
    no window data is copied; mutating a fiber in place after building the
    dataset will therefore also change the samples.

    NOTE(review): the window for position ``i`` starts at ``fiber[i]`` itself,
    so its first element is the current input point rather than the point that
    follows it. The layout is kept unchanged because models evaluated against
    this dataset were presumably trained with it — confirm before changing.
    """

    def __init__(self, fibers, predict_steps=25):
        """
        Args:
            fibers (list of tensors): List of fibers, where each fiber is a tensor of shape (num_points, num_features).
            predict_steps (int): Number of steps in each target window.

        Raises:
            ValueError: If ``predict_steps`` is smaller than 1.
        """
        if predict_steps < 1:
            # fiber[:-0] would silently produce empty inputs, so reject it up front.
            raise ValueError(f"predict_steps must be >= 1, got {predict_steps}")

        self.inputs = []
        self.targets = []
        self.lengths = []
        self.predict_steps = predict_steps

        for fiber in fibers:
            num_inputs = len(fiber) - predict_steps
            if num_inputs <= 0:
                continue

            # All points except the last `predict_steps`.
            self.inputs.append(fiber[:num_inputs])

            # unfold() yields every length-`predict_steps` window along dim 0 as a
            # zero-copy view with the window axis last; keep the first `num_inputs`
            # windows and move the window axis to position 1 so that
            # targets[i] == fiber[i : i + predict_steps].
            windows = fiber.unfold(0, predict_steps, 1)[:num_inputs]
            self.targets.append(windows.movedim(-1, 1))

            self.lengths.append(num_inputs)

    def __len__(self):
        """Return the number of fibers that were long enough to be kept."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(inputs, targets, length)`` for the fiber at position ``idx``."""
        return self.inputs[idx], self.targets[idx], self.lengths[idx]


def collate_fn(batch):
    """
    Collate function to pad sequences and return lengths.

    Both inputs and targets are zero-padded along the sequence axis with
    ``pad_sequence``, so the padded targets keep the dtype and device of the
    per-sample targets (and therefore always agree with the padded inputs).

    Args:
        batch: List of tuples (inputs, targets, lengths).
            inputs: Tensor of shape (seq_len, input_size).
            targets: Tensor of shape (seq_len, predict_steps, input_size).
            lengths: Sequence lengths.

    Returns:
        Padded inputs: Tensor of shape (batch_size, max_seq_len, input_size).
        Padded targets: Tensor of shape (batch_size, max_seq_len, predict_steps, input_size).
        Lengths tensor: Tensor of shape (batch_size,).
    """
    inputs, targets, lengths = zip(*batch)

    # pad_sequence pads only the leading (sequence) axis and leaves trailing
    # dimensions untouched, so it handles the 3-D targets as well as the 2-D inputs.
    inputs_padded = pad_sequence(inputs, batch_first=True)
    targets_padded = pad_sequence(targets, batch_first=True)

    # Unpadded length of every sequence, needed later for packing and masking.
    lengths_tensor = torch.tensor(lengths, dtype=torch.long)

    return inputs_padded, targets_padded, lengths_tensor

train_dataset = FiberDataset(train_fibers, predict_steps=25)
test_dataset = FiberDataset(test_fibers, predict_steps=25)

# Carve a 10% validation subset out of the training data. The explicitly seeded
# generator makes the split identical on every run, matching the seeded
# train/test split performed earlier.
train_subset, val_subset = random_split(
    train_dataset,
    [0.9, 0.1],
    generator=torch.Generator().manual_seed(42),
)

len(train_subset), len(val_subset), len(test_dataset)

(290640, 32293, 80734)

In [None]:
# Training batches are reshuffled on every pass over the loader.
train_loader = DataLoader(
    train_subset,
    batch_size=128,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=12,
    pin_memory=True,
)

# Validation and test batches keep the dataset order.
val_loader = DataLoader(
    val_subset,
    batch_size=128,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=12,
    pin_memory=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=12,
    pin_memory=True,
)

# Model Architectures

In [None]:
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM that emits a window of ``predict_steps`` points per time step.

    The recurrent output at every time step (forward and backward hidden states
    concatenated) is projected by a linear layer to ``predict_steps * input_size``
    values and reshaped to ``(predict_steps, input_size)``.

    NOTE(review): because the LSTM is bidirectional, the output at step ``t``
    also depends on the inputs that come after ``t`` in the same sequence.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, predict_steps=25):
        """
        Args:
            input_size: the number of expected features in the input x
            hidden_size: the number of features in the hidden state h
            num_layers: number of recurrent layers.
            predict_steps: number of points emitted for every time step.
        """
        super().__init__()
        self.bilstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )
        self.predict_steps = predict_steps
        # Both directions are concatenated (hence ``hidden_size * 2``); each time
        # step is mapped to a flattened window of predict_steps * input_size values.
        self.fc = nn.Linear(hidden_size * 2, self.predict_steps * input_size)

    def forward(self, x, lengths):
        """Run the network on a zero-padded batch.

        Args:
            x: Padded inputs of shape (batch_size, max_seq_len, input_size).
            lengths: Unpadded length of every sequence (tensor or list of ints).
                The batch does not need to be sorted by length.

        Returns:
            output: Tensor of shape (batch_size, max_seq_len, predict_steps, input_size).
                The time dimension always equals ``x.size(1)``.
            hidden: Final hidden state returned by the LSTM.
        """
        # pack_padded_sequence requires tensor lengths to live on the CPU.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()

        # Pack the padded sequence so the LSTM skips the padding.
        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # Pass through BiLSTM
        packed_out, (hidden, _) = self.bilstm(packed_x)

        # Unpack back to the padded length of `x`; without total_length the result
        # would only be as long as the longest sequence in `lengths`.
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))

        # Project every time step to a flattened prediction window.
        output = self.fc(out)
        batch_size, seq_len, feature_size = output.size()
        assert feature_size == self.predict_steps * x.size(2), f"Mismatch in feature size: expected {self.predict_steps * x.size(2)}, got {feature_size}"

        # Reshape to (batch_size, seq_len, predict_steps, input_size)
        output = output.view(batch_size, seq_len, self.predict_steps, -1)

        return output, hidden


In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mean-squared-error loss with nn.MSELoss's default settings.
criterion = nn.MSELoss()

In [None]:
# Checkpoint file for the second BiLSTM model.
# NOTE(review): absolute, environment-specific path (looks like a Colab Google Drive mount — confirm);
# no visible cell mounts it, so loading will fail wherever this path does not exist.
model_name_2 = "/content/drive/MyDrive/Fall 2024: ML Models/Model 1: Bi-directional LSTM/bilstm_fiber_25_points_v2.pth"

In [None]:
def load_blstm_model(load_path, device, hidden_size, num_layers, input_size=5, predict_steps=25):
    """Build a BidirectionalLSTM and load saved weights into it.

    Args:
        load_path: Path of a checkpoint containing the model's state dict.
        device: Device the weights are mapped to and the model is moved to.
        hidden_size: Hidden size the checkpoint was saved with.
        num_layers: Number of recurrent layers the checkpoint was saved with.
        input_size: Number of features per point (defaults to 5).
        predict_steps: Number of points predicted per time step (defaults to 25).

    Returns:
        The model with the loaded weights, placed on ``device``.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint does not match
            the architecture described by the arguments.
    """
    model = BidirectionalLSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        predict_steps=predict_steps,
    )
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code while being loaded.
    state_dict = torch.load(load_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    print("Model loaded successfully!")
    return model

In [None]:
# Rebuild the network with hidden_size=256 and 4 recurrent layers, then load the saved weights.
# These values have to match the checkpoint: load_blstm_model calls load_state_dict with its default arguments.
model_2 = load_blstm_model(model_name_2, device, hidden_size=256, num_layers=4)

In [None]:
def test_blstm_model(model, dataloader, device):
  model.eval()
  total_loss = 0.0
  with torch.no_grad():
    for inputs, targets, lengths in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        lengths, perm_idx = lengths.sort(0, descending=True)
        inputs = inputs[perm_idx]
        targets = targets[perm_idx]

        outputs, _ = model(inputs, lengths)
        loss = criterion(outputs, targets)

        total_loss += loss.item()
        print(f"Test Loss for batch: {loss.item()}")

  return total_loss / len(dataloader)

In [None]:
# outputs errors for x, y, z, HA, TA separately
def test_blstm_model_per_feature(model, dataloader, device):
    model.eval()
    all_loss_per_feature = []
    with torch.no_grad():
        for inputs, targets, lengths in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Sort the sequences by length for the BLSTM
            lengths, perm_idx = lengths.sort(0, descending=True)
            inputs = inputs[perm_idx]
            targets = targets[perm_idx]

            outputs, _ = model(inputs, lengths)

            # Compute squared error without reduction:
            squared_error = (outputs - targets) ** 2  # Shape: (B, L, input_dim)
            lengths = lengths.to(device)

            # If your data has variable lengths and padded elements, construct a mask:
            max_seq_len = outputs.shape[1]
            mask = torch.arange(max_seq_len, device=device)[None, :] < lengths[:, None]  # Shape: (B, L)
            mask = mask.unsqueeze(-1).unsqueeze(-1) # Add two more dimensions to match the target tensor
            mask = mask.expand(-1, -1, outputs.shape[2], outputs.shape[3]) # Expand to match the size of squared_error

            # Zero out padded positions:
            squared_error = squared_error * mask

            # Compute the mean error per feature over valid points:
            valid_counts = mask.sum(dim=(0, 1))  # Shape: (1, input_dim)
            loss_per_feature = squared_error.sum(dim=(0, 1)) / valid_counts
            all_loss_per_feature.append(loss_per_feature)

    # Aggregate the losses over all batches:
    aggregated_loss = torch.stack(all_loss_per_feature, dim=0).mean(dim=0)
    return aggregated_loss

In [None]:
#test_blstm_model(model_1, test_dataloader, device)
# NOTE(review): `model_1` is not defined in any cell visible here (only `model_2` is loaded),
# so this cell raises NameError on a fresh kernel unless the model is created elsewhere.
model_1_rseults = test_blstm_model_per_feature(model_1, test_dataloader, device)

In [None]:
# Average over the first axis of the result (the prediction-step axis), leaving one mean squared error per feature.
model_1_rseults.mean(dim=0)

tensor([1.0182e-01, 1.0780e-01, 6.2997e-02, 2.4478e+00, 1.9589e-03],
       device='cuda:0')

In [None]:
#test_blstm_model(model_2, test_dataloader, device)
# Per-step, per-feature squared error of model_2 on the test set; the result is stored but not displayed here.
model_2_rseults = test_blstm_model_per_feature(model_2, test_dataloader, device)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from torch.nn.utils.rnn import pad_packed_sequence

def visualize_blstm_fiber(inputs, targets, outputs, lengths, fiber_idx):
    """
    Draw one fiber of a batch in 3D together with the target and predicted
    point windows attached to its last valid input step.

    Only the first three features of every point are plotted (used as x, y, z).

    Args:
        inputs (Tensor): Input tensor (batch_size, seq_len, input_size).
        targets (Tensor): Target tensor (batch_size, seq_len, predict_steps, input_size).
        outputs (Tensor): Model outputs (batch_size, seq_len, predict_steps, input_size).
        lengths (Tensor): Lengths of valid sequences in the batch.
        fiber_idx (int): Index of the fiber in the batch to visualize.
    """
    # Number of valid (non-padded) steps of the selected fiber.
    n_valid = lengths[fiber_idx].item()

    # Drop the padding and keep the first three features of every point.
    trajectory = inputs[fiber_idx, :n_valid, :3].cpu().numpy()
    target_windows = targets[fiber_idx, :n_valid, :, :3].cpu().numpy()
    predicted_windows = outputs[fiber_idx, :n_valid, :, :3].detach().cpu().numpy()

    # Windows belonging to the last valid input step, each of shape (predict_steps, 3).
    point_sets = (
        (target_windows[-1], "Target Next Points", 'green'),
        (predicted_windows[-1], "Predicted Next Points", 'red'),
    )

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    # Input trajectory as a line.
    xs, ys, zs = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
    ax.plot(xs, ys, zs, label="Input Fiber Trajectory", color='blue')

    # Target window first, then the predicted window, both as scatter points.
    for points, label, color in point_sets:
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], label=label, color=color)

    # Title, axis labels and legend.
    ax.set_title("BLSTM Fiber Visualization: Trajectory and Predictions")
    ax.set_xlabel("X-axis")
    ax.set_ylabel("Y-axis")
    ax.set_zlabel("Z-axis")
    ax.legend()
    plt.show()
# Take the first test batch and run both models on it once; the cells below reuse these outputs.
inputs, targets, lengths = next(iter(test_dataloader))

# Inference only: no_grad() avoids building the autograd graph and holding its activations in memory.
# NOTE(review): `model_1` is not defined in any cell visible here — it must be loaded before this cell runs.
with torch.no_grad():
    inputs_on_device = inputs.to(device)
    outputs_1, _ = model_1(inputs_on_device, lengths)
    outputs_2, _ = model_2(inputs_on_device, lengths)

visualize_blstm_fiber(inputs, targets, outputs_1, lengths, fiber_idx=10)

In [None]:
# Same fiber (index 10 of the batch) as plotted for outputs_1, now with model_2's predictions.
visualize_blstm_fiber(inputs, targets, outputs_2, lengths, fiber_idx=10)

In [None]:
# Plot up to 120 fibers of the batch with model_1's predictions; the bound is
# capped at the batch size so a smaller batch cannot cause an IndexError.
for fiber_idx in range(min(120, len(lengths))):
  visualize_blstm_fiber(inputs, targets, outputs_1, lengths, fiber_idx=fiber_idx)

In [None]:
# Plot up to 120 fibers of the batch with model_2's predictions; the bound is
# capped at the batch size so a smaller batch cannot cause an IndexError.
for fiber_idx in range(min(120, len(lengths))):
  visualize_blstm_fiber(inputs, targets, outputs_2, lengths, fiber_idx=fiber_idx)