# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from trajectory_dataset import TrajectoryDataset
from LSTM_NN_model import TrajectoryPredictorLSTM
from CNN_trajectory_pred import ResidualCNNTrajectoryPredictor

class MainVehiclePredictor:
    """Train and query a ``TrajectoryPredictorLSTM`` on a single vehicle.

    ``train`` fits the model to the first entry of a ``TrajectoryDataset``
    (assumed to be the main vehicle — verify against the dataset's agent
    ordering). ``predict`` runs inference; ``visualize_prediction`` plots one
    sample.
    """

    def __init__(self, model_path=None):
        """Build the LSTM and optionally restore saved weights.

        Args:
            model_path: optional path to a ``state_dict`` written with
                ``torch.save``. It is loaded onto the CPU first, so checkpoints
                saved from a GPU also load on CPU-only machines.
        """
        self.base_model = TrajectoryPredictorLSTM()
        if model_path:
            state_dict = torch.load(model_path, map_location='cpu')
            self.base_model.load_state_dict(state_dict)

    def train(self, data_path, epochs=200, batch_size=32, lr=0.001):
        """Fit the model to the single trajectory at index 0 of the dataset.

        Args:
            data_path: path passed to ``TrajectoryDataset``.
            epochs: number of optimization steps (one per epoch).
            batch_size: currently unused; training always uses the one
                sample at index 0. Kept for interface compatibility.
            lr: Adam learning rate.

        Every 10 epochs the loss is printed and the prediction is plotted.
        """
        full_dataset = TrajectoryDataset(data_path)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Slicing [0:1] (rather than [0]) keeps the leading batch dimension.
        # The data never changes, so move it to the device once, not per epoch.
        # NOTE(review): assumes these dataset attributes are torch tensors.
        hist = full_dataset.history_normalized[0:1].to(device)
        fut = full_dataset.future_normalized[0:1].to(device)
        # Broadcast the per-timestep validity over the coordinate axis.
        fut_mask = full_dataset.future_valid[0:1].to(device).unsqueeze(-1)

        self.base_model = self.base_model.to(device)
        optimizer = torch.optim.Adam(self.base_model.parameters(), lr=lr)
        criterion = nn.MSELoss()

        self.base_model.train()
        for epoch in range(epochs):
            optimizer.zero_grad()
            pred = self.base_model(hist)

            # Zero out invalid future steps in both tensors so they add no
            # error. MSELoss still averages over all elements, masked or not.
            loss = criterion(pred * fut_mask, fut * fut_mask)
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.6f}")
                self.visualize_prediction(hist, fut, pred, epoch, full_dataset)

    def predict(self, history):
        """Run the model in eval mode without tracking gradients.

        ``history`` is moved to the device holding the model's parameters
        (if the model has any). The prediction is returned on that device.
        """
        self.base_model.eval()
        first_param = next(self.base_model.parameters(), None)
        if first_param is not None:
            history = history.to(first_param.device)
        with torch.no_grad():
            prediction = self.base_model(history)
        return prediction

    def visualize_prediction(self, history, future, prediction, epoch, dataset):
        """Plot the first sample's history, ground truth and prediction.

        Inputs are normalized batches; they are detached and moved to the CPU
        before ``dataset.denormalize``. ``prediction`` may still require grad
        (it comes straight from the training forward pass), and such tensors
        cannot be converted to NumPy for plotting.
        """
        plt.figure(figsize=(10, 10))

        hist_denorm = dataset.denormalize(history.detach().cpu())
        fut_denorm = dataset.denormalize(future.detach().cpu())
        pred_denorm = dataset.denormalize(prediction.detach().cpu())

        plt.plot(hist_denorm[0, :, 0], hist_denorm[0, :, 1],
                 'b-', linewidth=2, label='History')
        plt.plot(fut_denorm[0, :, 0], fut_denorm[0, :, 1],
                 'g-', linewidth=2, label='Ground Truth')
        plt.plot(pred_denorm[0, :, 0], pred_denorm[0, :, 1],
                 'r--', linewidth=2, label='Prediction')

        plt.title(f'Main Vehicle Trajectory Prediction - Epoch {epoch+1}')
        plt.xlabel('X Position (meters)')
        plt.ylabel('Y Position (meters)')
        plt.axis('equal')
        plt.grid(True)
        plt.legend()
        plt.show()



def visualize_prediction(history, future, prediction, epoch, dataset):
    """Plot a predicted trajectory against its history and ground truth.

    The three trajectories are denormalized with ``dataset.denormalize`` and
    plotted. A fixed-size box is drawn at the last history point, start/end
    markers are added, and the average and final Euclidean position error
    are printed. Only the first sample of each batch is used.

    Args:
        history: normalized historical trajectory, shape ``[1, T_hist, 2]``
            (e.g. ``[1, 11, 2]``).
        future: normalized ground-truth future trajectory, shape
            ``[1, T_fut, 2]`` (e.g. ``[1, 80, 2]``).
        prediction: normalized predicted future trajectory, same shape as
            ``future``. It may still require grad; it is detached before
            conversion.
        epoch: zero-based training epoch (displayed as ``epoch + 1``).
        dataset: object whose ``denormalize`` maps normalized coordinates
            back to plotting units.
    """
    def to_xy(traj):
        # detach(): NumPy conversion fails on tensors that require grad.
        # [0]: keep only the first sample of the batch -> [T, 2].
        return np.asarray(dataset.denormalize(traj.detach().cpu()))[0]

    hist_denorm = to_xy(history)      # [T_hist, 2]
    fut_denorm = to_xy(future)        # [T_fut, 2]
    pred_denorm = to_xy(prediction)   # [T_fut, 2]

    plt.figure(figsize=(15, 10))

    plt.plot(hist_denorm[:, 0], hist_denorm[:, 1],
             'o-', color='blue', alpha=1.0, linewidth=2,
             label='History')
    plt.plot(fut_denorm[:, 0], fut_denorm[:, 1],
             'o-', color='green', alpha=1.0, linewidth=2,
             label='Ground Truth')
    plt.plot(pred_denorm[:, 0], pred_denorm[:, 1],
             '--', color='red', alpha=0.7, linewidth=2,
             label='Prediction')

    # Vehicle box centred on the current position (last history point).
    # The size is a fixed assumption, and no yaw is available, so the box
    # stays axis-aligned.
    current_pos = hist_denorm[-1]
    width, length = 2.0, 4.0
    rect = Rectangle(
        (current_pos[0] - length / 2, current_pos[1] - width / 2),
        length, width,
        angle=0,
        color='blue',
        alpha=0.7,
        linewidth=2
    )
    plt.gca().add_patch(rect)

    plt.plot(hist_denorm[0, 0], hist_denorm[0, 1], 'bo', markersize=10, label='Start')
    plt.plot(fut_denorm[-1, 0], fut_denorm[-1, 1], 'go', markersize=10, label='Ground Truth End')
    plt.plot(pred_denorm[-1, 0], pred_denorm[-1, 1], 'ro', markersize=10, label='Predicted End')

    plt.title(f'Trajectory Prediction - Epoch {epoch + 1}\n'
              f'Blue: History | Green: Ground Truth | Red: Prediction')
    plt.xlabel('X Position (meters)')
    plt.ylabel('Y Position (meters)')
    plt.axis('equal')
    plt.grid(True)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Mean per-step Euclidean error over the horizon, and error at the last step.
    error = np.linalg.norm(fut_denorm - pred_denorm, axis=1).mean()
    final_error = np.linalg.norm(fut_denorm[-1] - pred_denorm[-1])

    print("\nPrediction Statistics:")
    print(f"Average prediction error: {error:.2f} meters")
    print(f"Final position error: {final_error:.2f} meters")
    print(f"Trajectory length - History: {len(hist_denorm)} steps, Future: {len(fut_denorm)} steps")

    plt.tight_layout()
    plt.show()

def train_model(model, train_loader, epochs=2000, lr=0.01, viz_every=50,
                save_path='trajectory_model.pth'):
    """Train a trajectory model with masked MSE and periodically visualize it.

    Args:
        model: an ``nn.Module`` mapping a history batch to a predicted future.
        train_loader: ``DataLoader`` yielding
            ``((hist, hist_valid), (fut, fut_valid))`` batches.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        viz_every: every this many epochs, print the average loss and plot
            one sample prediction. A falsy value disables this.
        save_path: file the final ``state_dict`` is written to.

    Returns:
        ``(model, losses)``: the trained model (on the selected device) and
        the per-epoch average training loss.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    losses = []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        # hist_valid is not used by the model or the loss.
        for (hist, _hist_valid), (fut, fut_valid) in train_loader:
            hist = hist.to(device)
            fut = fut.to(device)
            # Broadcast the per-timestep validity over the coordinate axis.
            mask = fut_valid.to(device).unsqueeze(-1)

            optimizer.zero_grad()
            pred = model(hist)
            # Zero out invalid future steps in both tensors so they add no
            # error. MSELoss still averages over all elements, masked or not.
            loss = criterion(pred * mask, fut * mask)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)

        if viz_every and (epoch + 1) % viz_every == 0:
            print(f"\nEpoch {epoch + 1}/{epochs}")
            print(f"Average loss: {avg_loss:.6f}")

            model.eval()
            with torch.no_grad():
                # Draw ONE batch: calling next(iter(...)) twice on a shuffled
                # loader would pair a history with another sample's future.
                (sample_hist, _), (sample_fut, _) = next(iter(train_loader))
                sample_hist = sample_hist[:1].to(device)
                sample_fut = sample_fut[:1].to(device)
                sample_pred = model(sample_hist)

                visualize_prediction(sample_hist, sample_fut, sample_pred,
                                     epoch, train_loader.dataset)

    print("\nTraining completed!")
    torch.save(model.state_dict(), save_path)

    plt.figure(figsize=(10, 5))
    plt.plot(losses)
    plt.title('Training Loss Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (MSE)')
    plt.grid(True)
    plt.show()

    return model, losses



def main(data_path='processed_data/processed_uncompressed_tf_example_training_training_tfexample.tfrecord-00003-of-01000.npz'):
    """Train the trajectory LSTM on one preprocessed data file.

    Args:
        data_path: path to the preprocessed ``.npz`` file read by
            ``TrajectoryDataset``.

    Returns:
        ``(model, losses)`` as returned by ``train_model``. That function
        already plots the loss curve and saves the weights.
    """
    dataset = TrajectoryDataset(data_path)
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

    # train_model needs an nn.Module (.to / .parameters). MainVehiclePredictor
    # is a plain wrapper class, so pass its underlying LSTM.
    model = MainVehiclePredictor().base_model
    model, losses = train_model(model, train_loader)
    return model, losses


if __name__ == "__main__":
    main()