In [None]:
# Reinstall the PyTorch stack. Versions are NOT pinned, so what gets installed depends on
# what PyPI serves at run time (the captured output below resolved torch 2.5.1).
# NOTE(review): torchtext and torchdata are not imported anywhere in this notebook — confirm
# they are needed. TODO: pin versions (e.g. `%pip install -q torch==2.5.1 ...`) for reproducibility.
!pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata
!pip3 install torch torchaudio torchvision torchtext torchdata

Found existing installation: torch 2.5.1+cu121
Uninstalling torch-2.5.1+cu121:
  Successfully uninstalled torch-2.5.1+cu121
Found existing installation: torchaudio 2.5.1+cu121
Uninstalling torchaudio-2.5.1+cu121:
  Successfully uninstalled torchaudio-2.5.1+cu121
Found existing installation: torchvision 0.20.1+cu121
Uninstalling torchvision-0.20.1+cu121:
  Successfully uninstalled torchvision-0.20.1+cu121
[0mCollecting torch
  Downloading torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchaudio
  Downloading torchaudio-2.5.1-cp310-cp310-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting torchvision
  Downloading torchvision-0.20.1-cp310-cp310-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torchtext
  Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting torchdata
  Downloading torchdata-0.9.0-cp310-cp310-manylinux1_x86_64.whl.metadata (5.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvi

In [None]:
import pickle
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
from torch.amp import GradScaler, autocast
import math
from torch.utils.data import random_split
from google.colab import drive, files
# Compute device used for the model and batches: the GPU when CUDA is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Mount Google Drive at /content/drive so the pickled fiber data below can be read.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
pkl_file_path = "/content/drive/MyDrive/Research: Computational Cardiovascular Models (1)/Computational Cardiac Models: Point RNN/Fall 2024: ML Models/Data Preparation/processed_fibers_longer_fibers.pkl"

# Load the preprocessed fibers. pickle.load can execute arbitrary code, so only
# point this at files you produced yourself.
with open(pkl_file_path, mode="rb") as pkl_file:
    fibers = pickle.load(pkl_file)

In [None]:
# Hold out 20% of whole fibers for testing; random_state fixes the shuffle.
train_fibers, test_fibers = train_test_split(fibers, random_state=42, test_size=0.2)
(len(train_fibers), len(test_fibers))

(322933, 80734)

# Data Preparation: Creating the Dataset class
Fibers vary in length, so each batch is zero-padded to its longest fiber and then packed with `pack_padded_sequence`; packing lets the LSTM skip the padded positions during the forward pass.

In [None]:
class FiberDataset(Dataset):
    """
    Windowed "predict the next points" dataset built from variable-length fibers.

    For a fiber with ``seq_len`` points and ``P = predict_steps``, one sample is:

    * input:  the first ``seq_len - P`` points, shape ``(seq_len - P, num_features)``;
    * target: for each input position ``t``, the ``P`` points strictly after it,
      ``fiber[t + 1 : t + 1 + P]``, giving shape ``(seq_len - P, P, num_features)``;
    * length: ``seq_len - P``.

    Fibers with ``seq_len <= P`` have no input point with a full target window and are skipped.
    """

    def __init__(self, fibers, predict_steps=25):
        """
        Args:
            fibers (list of tensors): List of fibers, where each fiber is a tensor of shape (num_points, num_features).
            predict_steps (int): Number of future steps to predict (must be >= 1).

        Raises:
            ValueError: if ``predict_steps < 1``.
        """
        if predict_steps < 1:
            raise ValueError(f"predict_steps must be >= 1, got {predict_steps}")

        self.inputs = []
        self.targets = []
        self.lengths = []
        self.predict_steps = predict_steps

        for fiber in fibers:
            seq_len = len(fiber)
            if seq_len <= predict_steps:
                continue
            num_inputs = seq_len - predict_steps

            # All points except the last `predict_steps`.
            self.inputs.append(fiber[:num_inputs])

            # Row t is fiber[t + 1 : t + 1 + predict_steps], i.e. the points that follow
            # input point t. A window starting at t would make the current point its own
            # first "future" target and would never use the fiber's final point.
            # unfold returns a view of the fiber, so no predict_steps-fold copy is stored.
            windows = fiber[1:].unfold(0, predict_steps, 1)  # (num_inputs, num_features, predict_steps)
            self.targets.append(windows.permute(0, 2, 1))    # (num_inputs, predict_steps, num_features)

            self.lengths.append(num_inputs)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(input, target, length)`` for fiber ``idx``."""
        return self.inputs[idx], self.targets[idx], self.lengths[idx]


def collate_fn(batch):
    """
    Collate function to pad sequences and return lengths.

    Args:
        batch: List of tuples (inputs, targets, lengths).
            inputs: Tensor of shape (seq_len, input_size).
            targets: Tensor of shape (seq_len, predict_steps, input_size).
            lengths: Sequence lengths.

    Returns:
        Padded inputs: Tensor of shape (batch_size, max_seq_len, input_size).
        Padded targets: Tensor of shape (batch_size, max_seq_len, predict_steps, input_size),
            with the same dtype as the incoming targets (zero-padded).
        Lengths tensor: int64 tensor of shape (batch_size,).
    """
    inputs, targets, lengths = zip(*batch)

    inputs_padded = pad_sequence(inputs, batch_first=True)
    # pad_sequence pads along dim 0 and accepts any trailing shape, so it handles the
    # 3-D targets directly. It also keeps their dtype; a fresh torch.zeros buffer would
    # always use the default float dtype, regardless of the data's dtype.
    targets_padded = pad_sequence(targets, batch_first=True)

    lengths_tensor = torch.tensor(lengths, dtype=torch.long)

    return inputs_padded, targets_padded, lengths_tensor

import os

# Never request more DataLoader workers than the machine has CPUs: extra workers only
# oversubscribe the CPU (and PyTorch warns about it). Pinned memory only helps host->GPU copies.
NUM_WORKERS = min(12, os.cpu_count() or 1)
loader_kwargs = dict(
    batch_size=256,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=torch.cuda.is_available(),
)

train_dataset = FiberDataset(train_fibers, predict_steps=25)
test_dataset = FiberDataset(test_fibers, predict_steps=25)

# Seed the global RNG so the 90/10 train/validation split is reproducible.
torch.manual_seed(42)
train_subset, val_subset = random_split(train_dataset, [0.9, 0.1])

train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_subset, shuffle=False, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

len(train_subset), len(val_subset), len(test_dataset)

(290640, 32293, 80734)

In [None]:
# Verification for DataLoader and Dataset
def verify_dataloader(dataloader):
    """
    Print shape and value diagnostics for the first batch a dataloader yields.

    Only one batch is drawn; an empty dataloader prints nothing.

    Args:
        dataloader (DataLoader): Iterable of ``(inputs, targets, lengths)`` batches.

    Returns:
        None
    """
    no_batch = object()
    first_batch = next(iter(dataloader), no_batch)
    if first_batch is no_batch:
        return

    inputs, targets, lengths = first_batch
    print("Batch 1:")
    print(f"  Inputs Shape: {inputs.shape}")    # expected (batch_size, max_seq_len, input_size)
    print(f"  Targets Shape: {targets.shape}")  # expected (batch_size, max_seq_len, predict_steps, input_size)
    print(f"  Lengths Shape: {lengths.shape}")  # expected (batch_size,)
    print(f"  Lengths: {lengths}")
    print(f"  Inputs - Min: {inputs.min()}, Max: {inputs.max()}, NaNs: {torch.isnan(inputs).sum()}")
    print(f"  Targets - Min: {targets.min()}, Max: {targets.max()}, NaNs: {torch.isnan(targets).sum()}")
    print("")


# Inspect the first batch of each split.
for split_name, split_loader in (
    ("Train", train_loader),
    ("Validation", val_loader),
    ("Test", test_dataloader),
):
    print(f"Verifying {split_name} Loader...")
    verify_dataloader(split_loader)

Verifying Train Loader...
Batch 1:
  Inputs Shape: torch.Size([256, 513, 5])
  Targets Shape: torch.Size([256, 513, 25, 5])
  Lengths Shape: torch.Size([256])
  Lengths: tensor([381, 117,  75,  12, 266, 123,  46, 179, 253, 288,  38, 351, 103, 227,
        144,  50, 348,  49, 187,  40,  96, 314, 138,  76, 233,  98, 220,   6,
        221, 243, 325, 331, 248, 335, 245,  67, 329, 196,  89, 260, 253, 347,
        105, 281, 104,  29,  34, 219, 144,  81,  57,  83,  82, 219,  50, 149,
        220, 191, 219, 257,  41,  39,  46, 197, 216, 197, 307, 117, 144, 247,
        243, 155,  49, 252, 254, 100,  78, 230, 219,  80,  38,  74,  89, 104,
        291,  43, 248,  60,  42,  57, 256,  47, 293, 247, 291, 456, 334, 192,
        225, 245, 330,  45, 277, 132,  35, 217, 260,  53, 206, 476,  46,  37,
         75, 182,  70,  59, 201, 243, 411, 346, 174, 193,  37, 141, 125, 100,
        206,  51,  35, 256, 269, 113, 306, 255,  30, 141,  99,  34,  49, 177,
        126, 311,  89, 229, 203,  49, 116, 222, 26

# Bidirectional LSTM

In [None]:
class BidirectionalLSTM(nn.Module):
    """
    Bidirectional LSTM that predicts a window of future points at every time step.

    At each position of the (padded) input, the model emits ``predict_steps`` vectors with
    ``input_size`` features each, i.e. the same feature layout as the input points.

    NOTE(review): the LSTM is bidirectional, so the backward direction at position ``t`` has
    already read inputs ``t+1, t+2, ...`` of the same fiber. The targets built by
    ``FiberDataset`` are later points of that same fiber, many of which also appear further
    along in ``x``. The model can therefore partly read its answers from the input — confirm
    this is intended before treating the loss as a forecasting metric.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, predict_steps=25):
        """
        Args:
            input_size: the number of expected features in the input x.
            hidden_size: the number of features in the hidden state h (per direction).
            num_layers: number of recurrent layers.
            predict_steps: number of future points predicted at each position.
        """
        super().__init__()
        self.bilstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )
        self.predict_steps = predict_steps
        # Forward and backward states are concatenated, hence 2 * hidden_size inputs.
        self.fc = nn.Linear(hidden_size * 2, self.predict_steps * input_size)

    def forward(self, x, lengths):
        """
        Args:
            x: padded batch of shape (batch, max_seq_len, input_size).
            lengths: valid length of each sequence (list or 1-D integer tensor).

        Returns:
            output: (batch, x.size(1), predict_steps, input_size). Positions at or beyond a
                sequence's length are the fc layer applied to zero padding (i.e. its bias),
                so mask them out of any loss.
            hidden: final hidden state from nn.LSTM, (num_layers * 2, batch, hidden_size).
        """
        # pack_padded_sequence requires the lengths to live on the CPU.
        if isinstance(lengths, torch.Tensor):
            lengths = lengths.cpu()
        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        packed_out, (hidden, _cell) = self.bilstm(packed_x)

        # total_length pins the time dimension to x's. Without it, the output stops at
        # max(lengths) and no longer lines up with targets padded to x's length.
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))

        output = self.fc(out)
        batch_size, seq_len, feature_size = output.size()
        expected = self.predict_steps * x.size(2)
        assert feature_size == expected, f"Mismatch in feature size: expected {expected}, got {feature_size}"

        # (batch, seq_len, predict_steps * input_size) -> (batch, seq_len, predict_steps, input_size)
        output = output.view(batch_size, seq_len, self.predict_steps, -1)

        return output, hidden


In [None]:
# Smoke test: push a random batch through a throwaway model and check the shapes.
batch_size = 256
seq_len = 50
input_size = 5
predict_steps = 25
hidden_size = 64
num_layers = 2

dummy_input = torch.rand(batch_size, seq_len, input_size).to(device)
# Lengths stay on the CPU, which is what pack_padded_sequence expects.
lengths = torch.randint(low=30, high=seq_len, size=(batch_size,))

# A separate name, so this cell never clobbers the training `model` if run out of order.
demo_model = BidirectionalLSTM(input_size, hidden_size, num_layers, predict_steps).to(device)

output, hidden = demo_model(dummy_input, lengths)
print("Output Shape:", output.shape)  # (batch_size, <= seq_len, predict_steps, input_size)
print("Hidden Shape:", hidden.shape)  # (num_layers * 2, batch_size, hidden_size)

assert output.shape[0] == batch_size and output.shape[2:] == (predict_steps, input_size)
assert hidden.shape == (num_layers * 2, batch_size, hidden_size)

Output Shape: torch.Size([256, 49, 25, 5])
Hidden Shape: torch.Size([256, 64])


# Training

In [None]:
# Model Hyperparameters
input_size = 5  # [x, y, z, angle, depth]
hidden_size = 256
num_layers = 4
predict_steps = 25  # must match FiberDataset(predict_steps=...)

# Move the model to the device *before* creating the optimizer (torch.optim recommends
# this), so the optimizer is built over the parameters that will actually be trained.
model = BidirectionalLSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    predict_steps=predict_steps,
).to(device)

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.1, min_lr=1e-6)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from torch.nn.utils.rnn import pad_packed_sequence

def visualize_blstm_fiber(inputs, targets, outputs, lengths, fiber_idx):
    """
    Show a 3-D plot of one fiber: its input trajectory plus the target and predicted
    future points for its last valid input position.

    Only the first three features (x, y, z) of each point are drawn.

    Args:
        inputs (Tensor): Input tensor (batch_size, seq_len, input_size).
        targets (Tensor): Target tensor (batch_size, seq_len, predict_steps, input_size).
        outputs (Tensor): Model outputs (batch_size, seq_len, predict_steps, input_size).
        lengths (Tensor): Lengths of valid sequences in the batch.
        fiber_idx (int): Index of the fiber in the batch to visualize.
    """
    valid_len = lengths[fiber_idx].item()

    trajectory = inputs[fiber_idx, :valid_len, :3].cpu().numpy()
    # Keep only the prediction window attached to the final valid input position.
    target_window = targets[fiber_idx, :valid_len, :, :3].cpu().numpy()[-1]
    predicted_window = outputs[fiber_idx, :valid_len, :, :3].detach().cpu().numpy()[-1]

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    # Each `.T` turns (n, 3) points into the three coordinate arrays.
    ax.plot(*trajectory.T, label="Input Fiber Trajectory", color='blue')
    for window, label, colour in (
        (target_window, "Target Next Points", 'green'),
        (predicted_window, "Predicted Next Points", 'red'),
    ):
        ax.scatter(*window.T, label=label, color=colour)

    ax.set_title("BLSTM Fiber Visualization: Trajectory and Predictions")
    ax.set_xlabel("X-axis")
    ax.set_ylabel("Y-axis")
    ax.set_zlabel("Z-axis")
    ax.legend()
    plt.show()
def reload_model(load_path, hidden_size, num_layers, input_size=5, predict_steps=25):
  """
  Rebuild a BidirectionalLSTM and a matching Adam optimizer from a save_model checkpoint.

  Args:
      load_path: path of a checkpoint written by save_model.
      hidden_size, num_layers: architecture of the saved model.
      input_size, predict_steps: remaining architecture settings (defaults match training).

  Returns:
      (model, optimizer, start_epoch, loss). `loss` is the stored validation loss; a legacy
      'loss' key is used when present instead, else None. The scheduler state is not restored.
  """
  model = BidirectionalLSTM(
      input_size=input_size,
      hidden_size=hidden_size,
      num_layers=num_layers,
      predict_steps=predict_steps,
  ).to(device)

  # The checkpoint holds only tensors and plain Python containers, so weights_only=True
  # works. It also avoids unpickling arbitrary objects.
  checkpoint = torch.load(load_path, map_location=device, weights_only=True)
  model.load_state_dict(checkpoint['model_state_dict'])

  # Build an optimizer over *this* model's parameters. Loading into the notebook-global
  # `optimizer` would leave it stepping the global model's parameters, not the reloaded ones.
  # load_state_dict then restores lr / weight_decay / moment estimates from the checkpoint.
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

  start_epoch = checkpoint['epoch']
  # save_model writes 'validation loss' (and 'total training loss'), not 'loss'.
  loss = checkpoint.get('validation loss', checkpoint.get('loss'))
  return model, optimizer, start_epoch, loss

def save_model(model, current_epoch, optimizer, scheduler, total_loss, val_loss, save_path):
  """Write a training checkpoint (model, optimizer and scheduler state plus losses) to save_path."""
  checkpoint = {
      'epoch': current_epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'scheduler_state_dict': scheduler.state_dict(),
      'total training loss': total_loss,
      'validation loss': val_loss,
  }
  torch.save(checkpoint, save_path)


In [None]:
def validate(model, dataloader, criterion, visualize=True, viz_fiber_indices=(12, 120)):
    """
    Evaluate `model` on `dataloader` and return the mean per-batch loss.

    Only valid time steps (t < length) enter the loss. Padded positions hold the fc layer's
    response to zero padding, and scoring them against zero-padded targets would distort the metric.

    Args:
        model: BidirectionalLSTM-style model returning (outputs, hidden).
        dataloader: yields (inputs, targets, lengths) batches as produced by collate_fn.
        criterion: loss applied to the valid (output, target) pairs, e.g. nn.MSELoss().
        visualize: if True, plot the fibers at `viz_fiber_indices` of the first batch
            (indices beyond that batch's size are skipped).
        viz_fiber_indices: batch positions (after sorting by length) to plot.

    Returns:
        float: mean loss over batches, or nan if the dataloader is empty.
    """
    model.eval()
    total_val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets, lengths in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # The model packs with enforce_sorted=False, so sorting is not required; it only
            # makes the plotted fibers deterministic (position 0 = longest fiber in the batch).
            lengths, perm_idx = lengths.sort(0, descending=True)
            inputs = inputs[perm_idx]
            targets = targets[perm_idx]

            outputs, _ = model(inputs, lengths)

            # (batch, seq_len) mask of real time steps; indexing keeps only those rows.
            seq_len = outputs.size(1)
            valid = torch.arange(seq_len, device=outputs.device)[None, :] < lengths.to(outputs.device)[:, None]
            loss = criterion(outputs[valid], targets[:, :seq_len][valid])

            if visualize and num_batches == 0:
                for fiber_idx in viz_fiber_indices:
                    if fiber_idx < inputs.size(0):
                        visualize_blstm_fiber(inputs, targets, outputs, lengths, fiber_idx=fiber_idx)

            total_val_loss += loss.item()
            num_batches += 1

    return total_val_loss / num_batches if num_batches else float('nan')

In [None]:
scaler = GradScaler()
num_epochs = 100
for epoch in tqdm(range(num_epochs)):
    model.train()
    total_loss = 0
    for inputs, targets, lengths in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # No sorting needed: the model packs with enforce_sorted=False.

        optimizer.zero_grad()
        with autocast(device.type):  # mixed precision
            outputs, _ = model(inputs, lengths)
            # Score only the valid (non-padded) time steps; padded outputs are the fc layer's
            # response to zero padding and would otherwise be fitted to the zero-padded targets.
            seq_len = outputs.size(1)
            valid = torch.arange(seq_len, device=outputs.device)[None, :] < lengths.to(outputs.device)[:, None]
            loss = criterion(outputs[valid], targets[:, :seq_len][valid])

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_loss:.4f}")

    val_loss = validate(model, val_loader, criterion)
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}")

    scheduler.step(val_loss)
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}\n\n")

    completed_epochs = epoch + 1
    if completed_epochs % 5 == 0:
        save_path = f"/content/BLSTM_25_point_prediction_epoch_{completed_epochs}_version_2.pth"
        # Note: the checkpoint stores the 0-based `epoch` index and the summed (not mean) training loss.
        save_model(model, epoch, optimizer, scheduler, total_loss, val_loss, save_path)

        if completed_epochs % 10 == 0 or completed_epochs == num_epochs:
            files.download(save_path)


# Saving the model

In [None]:
# Persist only the trained weights (the state_dict), not the optimizer/scheduler state.
save_path = "/content/bilstm_fiber_25_points_v2.pth"
final_weights = model.state_dict()
torch.save(final_weights, save_path)
print(f"Model saved to: {save_path}")

In [None]:
# Download the final weights written by the previous cell (uses its `save_path`) to the local machine.
from google.colab import files
files.download(save_path)