In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import json
from torch.nn.utils.rnn import pad_sequence

if not torch.cuda.is_available():
    print("CUDA is not available. PyTorch will run on CPU.")
else:
    print("CUDA is available. PyTorch can use your GPU.")
    print(f"Number of GPUs available: {torch.cuda.device_count()}")
    # Only device index 0 is reported, even when several GPUs are present.
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

CUDA is available. PyTorch can use your GPU.
Number of GPUs available: 1
GPU Name: NVIDIA GeForce RTX 4070 Laptop GPU


In [None]:
class TrajectoryDataset(Dataset):
    """Goal-conditioned trajectory dataset read from a JSON file.

    The file must contain a top-level ``'data'`` list. Each item provides
    ``item['config']['goal']`` with ``x``/``y``/``z`` keys and
    ``item['trajectory']['waypoints']``, a list of dicts with
    ``x, y, z, qx, qy, qz, qw`` keys.

    ``__getitem__`` returns ``(goal, trajectory)`` as float32 tensors of shape
    ``(3,)`` and ``(seq_len, 7)`` (position followed by quaternion components).
    """

    GOAL_KEYS = ('x', 'y', 'z')
    WAYPOINT_KEYS = ('x', 'y', 'z', 'qx', 'qy', 'qz', 'qw')

    def __init__(self, json_file):
        with open(json_file, 'r') as f:
            self.data_dict = json.load(f)

        self.data = self.data_dict['data']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        goal_cfg = item['config']['goal']
        goal = np.array([goal_cfg[k] for k in self.GOAL_KEYS], dtype=np.float32)  # label, (3,)

        waypoints = [
            [wp[k] for k in self.WAYPOINT_KEYS]
            for wp in item['trajectory']['waypoints']
        ]
        # reshape keeps an empty trajectory at (0, 7) instead of (0,), so that
        # pad_sequence / downstream shape unpacking still see a 2-D tensor.
        trajectory = np.asarray(waypoints, dtype=np.float32).reshape(-1, len(self.WAYPOINT_KEYS))

        return torch.from_numpy(goal), torch.from_numpy(trajectory)
    

def collate_pad(batch):
    """Collate ``(goal, trajectory)`` pairs into padded batch tensors.

    Returns a 4-tuple:
        goals:    goals stacked along a new leading batch dim, (B, ...).
        padded:   trajectories zero-padded along time to the longest one,
                  (B, T_max, features).
        lengths:  (B,) int64 tensor with each trajectory's true length.
        mask:     (B, T_max) bool tensor, True on valid (non-padded) steps.
    """
    goal_items, traj_items = zip(*batch)
    goals = torch.stack(goal_items, dim=0)

    lengths = torch.tensor([traj.size(0) for traj in traj_items], dtype=torch.long)
    padded = pad_sequence(traj_items, batch_first=True)

    batch_size, longest, _ = padded.shape
    # A step index is valid while it is strictly below that row's length.
    step_ids = torch.arange(longest).unsqueeze(0)       # (1, T_max)
    mask = step_ids < lengths.unsqueeze(1)              # (B, T_max) bool

    return goals, padded, lengths, mask

In [None]:
class lstm_trajectory_generator(nn.Module):
    """Single-step LSTM that maps an input vector to a translation prediction.

    Every ``forward`` call runs the LSTM for exactly one timestep and projects its
    output to ``trans_size`` values. A trajectory is produced by calling it
    repeatedly with the same input while threading the returned hidden state back in.
    """

    def __init__(self, input_size, trans_size, quat_size, hidden_size, num_layers):
        """
        :param input_size:   number of features in the per-step input X
        :param trans_size:   number of translation values predicted per step
        :param quat_size:    accepted for interface compatibility; currently unused
                             (no rotation head is built)
        :param hidden_size:  number of features in the hidden state h
        :param num_layers:   number of stacked LSTM layers
        """
        super().__init__()
        # Construction order (LSTM first, then Linear) fixes parameter init under a seed.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.translation_linear = nn.Linear(in_features=hidden_size, out_features=trans_size)

    def forward(self, x_input, encoder_hidden_states):
        """Advance the LSTM by one timestep.

        :param x_input:                (batch_size, input_size) input for this step
        :param encoder_hidden_states:  (h, c) from a previous call, or None to let
                                       nn.LSTM start from zero states
        :return:                       (translation_pred, hidden) where translation_pred
                                       is (batch_size, trans_size) and hidden is the
                                       updated (h, c) tuple
        """
        one_step = x_input[None]  # (1, batch, input_size): a sequence of length 1
        step_out, next_hidden = self.lstm(one_step, encoder_hidden_states)
        return self.translation_linear(step_out[0]), next_hidden

In [None]:
def rollout(
    model,
    goal,
    consecutive,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    eps_step=0.01,
    eps_goal=0.02,
    max_steps=100,
    max_path_length=2.0,
):
    """
    Roll out trajectories for a batch of goals using convergence + proximity
    stopping rules and safety caps.

    At every step the model receives the same goal plus its own recurrent state.
    A sample stops when either
      * it is within ``eps_goal`` of the goal (x, y only) AND has made at least
        ``consecutive`` consecutive steps shorter than ``eps_step``, or
      * its signed path length (see the comment in the loop) exceeds
        ``max_path_length``.
    Unrolling ends once every sample has stopped or ``max_steps`` is reached.

    Args:
        model: LSTM trajectory generator (an nn.Module), called as
               ``model(goal, hidden)`` and returning ``(xy_pred, hidden)``
               with ``xy_pred`` of shape (B, 2).
        goal:  (B, 3) tensor of goal positions.
        consecutive: int, number of consecutive small steps required
                     to declare convergence.
        device: torch device the goal is moved to (the model is not moved).
        eps_step: [m] step size below which a sample is "not moving" (1 cm).
        eps_goal: [m] x,y distance to goal considered "close enough" (2 cm).
        max_steps: maximum number of rollout steps.
        max_path_length: [m] cap on the signed total path length.

    Returns:
        pred_traj: (B, T_used, 2) predicted x,y waypoints, T_used being the number
                   of steps actually unrolled (<= max_steps). Entries beyond a
                   sample's own length are still model outputs, not zeros.
        lengths:   (B,) long tensor with the number of valid steps per sample.

    The model is switched to eval mode for the rollout; its previous train/eval
    mode is restored before returning.
    """
    was_training = model.training
    model.eval()

    goal = goal.to(device)
    B = goal.size(0)
    goal_xy = goal[:, :2]  # (B, 2) — only x, y are predicted for now

    pred_traj = torch.zeros(B, max_steps, 2, device=device)
    lengths = torch.zeros(B, dtype=torch.long, device=device)        # 0 = still running
    finished = torch.zeros(B, dtype=torch.bool, device=device)
    stable_counts = torch.zeros(B, dtype=torch.long, device=device)  # consecutive small steps
    total_length = torch.zeros(B, device=device)                     # signed path length so far

    hidden = None
    prev_pos = None
    last_step = 0  # number of timesteps actually unrolled (global across the batch)

    try:
        with torch.no_grad():
            for t in range(max_steps):
                trans_pred, hidden = model(goal, hidden)  # (B, 2)
                pred_traj[:, t, :] = trans_pred
                last_step = t + 1

                if prev_pos is not None:
                    step_vec = trans_pred - prev_pos              # (B, 2)
                    step_norm = torch.norm(step_vec, dim=-1)      # (B,)

                    # Heuristic: a step with ANY negative x/y component counts as
                    # "backwards" and is subtracted from the path length.
                    # NOTE(review): an oscillating rollout can therefore keep
                    # total_length low, so max_path_length is not a true length cap.
                    moving_backward = (step_vec < 0).any(dim=-1)
                    total_length += torch.where(moving_backward, -step_norm, step_norm)

                    stable_counts = torch.where(
                        step_norm < eps_step,
                        stable_counts + 1,
                        torch.zeros_like(stable_counts),
                    )

                dist_to_goal = torch.norm(trans_pred - goal_xy, dim=-1)  # (B,)
                converged_good = (dist_to_goal < eps_goal) & (stable_counts >= consecutive)
                too_long_path = total_length > max_path_length

                newly_done = (~finished) & (converged_good | too_long_path)
                # A sample stopping at 0-indexed step t has t + 1 waypoints.
                lengths = torch.where(newly_done, torch.full_like(lengths, t + 1), lengths)
                finished |= newly_done
                prev_pos = trans_pred

                if finished.all():
                    break

            # Samples that never met a stopping rule keep every unrolled step.
            lengths = torch.where(
                lengths == 0,
                torch.full_like(lengths, last_step),
                lengths,
            )
    finally:
        model.train(was_training)

    return pred_traj[:, :last_step, :], lengths

In [None]:
def _unroll_translations(model, goal, steps):
    """Call ``model`` ``steps`` times on a constant ``goal`` and stack the outputs.

    Returns a (B, steps, out_dim) tensor; gradients flow through it unless the
    call is made under ``torch.no_grad()``.
    """
    hidden = None
    outputs = []
    for _ in range(steps):
        step_pred, hidden = model(goal, hidden)
        outputs.append(step_pred)
    return torch.stack(outputs, dim=1)


def _weighted_masked_mse(pred_xy, gt_xy, mask):
    """Time-weighted MSE over valid (non-padded) timesteps.

    Squared errors are averaged over the last dim, weighted linearly from 1.0 at
    the first step to 5.0 at the last (padded) step, masked, and normalised by
    the summed weight of the valid steps.
    NOTE(review): the weight ramp spans the batch's padded length, so a given
    timestep's weight varies between batches.
    """
    per_step = (pred_xy - gt_xy).pow(2).mean(dim=-1)  # (B, T)
    weights = torch.linspace(1.0, 5.0, steps=per_step.size(1), device=per_step.device)
    weighted_mask = weights.unsqueeze(0) * mask       # (B, T), zero on padding
    return (per_step * weighted_mask).sum() / weighted_mask.sum()


def _print_rollout_comparison(pred_traj, pred_len, gt_translation, gt_lengths, max_rows=3):
    """Print a per-step table of rolled-out vs ground-truth x, y for the first rows."""
    print("\n=== [VALIDATION] Translation Predictions vs Ground Truth ===")

    for b in range(min(gt_translation.size(0), max_rows)):
        gt_L = int(gt_lengths[b].item())
        pred_L = int(pred_len[b].item())

        print(f"\nTrajectory {b} (GT length={gt_L}, Pred length={pred_L}):")
        print("Step |    Pred (x, y)      |     GT (x, y)       |  Error")
        print("-" * 70)

        for t in range(max(gt_L, pred_L)):
            pred_xy = pred_traj[b, t].tolist() if t < pred_L else None
            gt_xy = gt_translation[b, t].tolist() if t < gt_L else None

            if pred_xy is not None:
                pred_str = f"({pred_xy[0]:7.3f}, {pred_xy[1]:7.3f})"
            else:
                pred_str = "      ---           "
            if gt_xy is not None:
                gt_str = f"({gt_xy[0]:7.3f}, {gt_xy[1]:7.3f})"
            else:
                gt_str = "      ---           "

            # Squared x,y error only where both trajectories have a point.
            if pred_xy is not None and gt_xy is not None:
                err = (pred_xy[0] - gt_xy[0])**2 + (pred_xy[1] - gt_xy[1])**2
                err_str = f"{err:7.4f}"
            else:
                err_str = "  n/a  "

            print(f"{t:4d} | {pred_str} | {gt_str} | {err_str}")

        print("-" * 70)


def train(model, train_loader, val_loader, 
          epochs, learning_rate,
          device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Train the trajectory generator and track train / validation loss.

    Each batch is ``(goal, gt_traj, lengths, mask)`` as produced by
    ``collate_pad``; only the x, y columns of ``gt_traj`` are supervised. The
    model is unrolled for the batch's padded length with the goal as input at
    every step (no teacher forcing) and optimised with Adam on a time-weighted,
    masked MSE. Validation uses the same unroll and loss under ``torch.no_grad``.
    Every 25th epoch (starting with the first) the first validation batch is
    also rolled out with ``rollout`` and printed against the ground truth.

    Args:
        model: module called as ``model(goal, hidden) -> (xy_pred, hidden)``.
        train_loader / val_loader: iterables of collated batches with ``len()``.
        epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: device the model and batches are moved to.

    Returns:
        history: dict with per-epoch ``'train_loss'`` and ``'val_loss'`` lists
        (an entry is NaN when the corresponding loader is empty).
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history = {
        'train_loss': [],
        'val_loss': [],
    }

    for epoch in range(epochs):

        # ---- Training ----
        model.train()
        training_loss = 0.0

        for batch in train_loader:
            goal = batch[0].to(device)                 # (B, 3)
            gt_translation = batch[1][:, :, :2].to(device)  # (B, T, 2) - just x, y
            gt_mask = batch[3].to(device)              # (B, T)

            optimizer.zero_grad()
            pred_translations = _unroll_translations(model, goal, gt_translation.size(1))
            loss = _weighted_masked_mse(pred_translations, gt_translation, gt_mask)
            loss.backward()
            optimizer.step()
            training_loss += loss.item()

        n_train = len(train_loader)
        average_train_loss = training_loss / n_train if n_train else float('nan')
        history['train_loss'].append(average_train_loss)

        print(f'Epoch {epoch+1}/{epochs}: Train Loss = {average_train_loss:.4f}')

        # ---- Validation ----
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                goal = batch[0].to(device)                 # (B, 3)
                gt_translation = batch[1][:, :, :2].to(device)  # (B, T_gt_max, 2)
                gt_lengths = batch[2].to(device)           # (B,)
                gt_mask = batch[3].to(device)              # (B, T_gt_max)

                pred_translations = _unroll_translations(model, goal, gt_translation.size(1))
                val_loss += _weighted_masked_mse(pred_translations, gt_translation, gt_mask).item()

                # The free-running rollout is only needed for the periodic debug table.
                if epoch % 25 == 0 and i == 0:
                    pred_traj, pred_len = rollout(model, goal, consecutive=3, device=device)
                    _print_rollout_comparison(pred_traj, pred_len, gt_translation, gt_lengths)

        n_val = len(val_loader)
        average_val_loss = val_loss / n_val if n_val else float('nan')
        history['val_loss'].append(average_val_loss)

        print(f'Epoch {epoch+1}/{epochs}: Val Loss = {average_val_loss:.4f}')

    return history

In [None]:
# Run configuration — everything a reader may want to tune in one place.
SEED = 0
DATASET_PATH = 'datasets/goal_input_datasets/trajectory_dataset_test.json'
BATCH_SIZE = 8
LEARNING_RATE = 0.01
EPOCHS = 500

# Global seed: fixes model initialisation and the default RNG used for shuffling.
torch.manual_seed(SEED)

dataset = TrajectoryDataset(DATASET_PATH)

N = len(dataset)
train_size = int(0.8 * N)
# Derive the remainder instead of int(0.2 * N): the two truncated sizes usually
# do not add up to N (e.g. N=7 -> 5 + 1), and random_split then raises.
val_size = N - train_size

train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),  # reproducible split
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_pad,
    shuffle=True,
    num_workers=0,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_pad,
    shuffle=False,
    num_workers=0,
)

# quat_size is accepted by the model but no rotation head is built yet.
model = lstm_trajectory_generator(input_size=3, trans_size=2, quat_size=2, hidden_size=128, num_layers=2)

history = train(model=model, train_loader=train_dataloader, val_loader=val_dataloader,
                epochs=EPOCHS, learning_rate=LEARNING_RATE)

In [None]:
import matplotlib.pyplot as plt

# One float per trained epoch in each list.
train_losses = history['train_loss']
val_losses = history.get('val_loss', [])

# Derive the x-axis from the recorded history, not from EPOCHS, so a shorter run still plots.
epochs = range(1, len(train_losses) + 1)

# Draw validation only when there is exactly one value per epoch.
has_val = len(val_losses) > 0 and len(val_losses) == len(train_losses)

fig, ax = plt.subplots()
ax.plot(epochs, train_losses, label='Train')
if has_val:
    ax.plot(epochs, val_losses, label='Validation')
ax.set(
    xlabel='Epoch',
    ylabel='Loss',
    title='Training and Validation Loss' if has_val else 'Training Loss',
)
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()
