In [4]:
import torch
import torch.nn as nn
import numpy as np
from DDBSCAN import Raster_DBSCAN
from Models import *
from torch.utils.data import Dataset,DataLoader
import torch.optim as optim

import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.colors as mcolors
# times new roman font
plt.rcParams["font.family"] = "Times New Roman"
seed = 414

# Random categorical colormap (presumably for cluster-label plots — confirm):
# index 0 is black, followed by NUM_RANDOM_COLORS random RGB colours.
# A local RandomState seeded with `seed` makes the colours reproducible across
# kernel restarts without re-seeding NumPy's global RNG (the former global
# `np.random.seed(seed)` was commented out, leaving `seed` unused).
NUM_RANDOM_COLORS = 600
colors = np.random.RandomState(seed).rand(NUM_RANDOM_COLORS, 3)
colors = np.concatenate([np.array([[0, 0, 0]]), colors], axis=0)
colormap = mcolors.ListedColormap(colors)

In [5]:
class LaneSocialLSTM(nn.Module):
    """
    Lane-based Social LSTM implementation adapted from Alahi et al. 2016.

    Modified to work with a 1D lane cell representation: for every vehicle the
    hidden states of the *other* vehicles are pooled over ``grid_size`` lane
    cells (the "social tensor"), embedded, concatenated with the embedded
    position and fed to an LSTM (or GRU) cell.
    """

    def __init__(self,
                 input_size=1,             # Position coordinate (x only)
                 output_size=2,            # Parameters for output distribution (mean, std)
                 embedding_size=64,        # Embedding dimension
                 rnn_size=128,             # LSTM hidden state size
                 grid_size=32,             # Number of lane cells to consider for social context
                 dropout=0.2,              # Dropout probability
                 use_cuda=True,            # GPU acceleration
                 seq_length=12,            # Sequence length for training
                 gru=False,                # Use GRU instead of LSTM
                 infer=False):             # Inference mode
        """
        Build the embedding layers, the recurrent cell and the output layer.

        ``use_cuda`` is stored (AND-ed with CUDA availability) for callers that
        read it; tensors created inside the model follow the device of the
        tensors passed to ``forward``/``create_social_tensor``.
        When ``infer`` is True the output buffer covers a single frame
        (``seq_length`` is forced to 1).
        """
        super().__init__()

        # Store parameters
        self.input_size = input_size
        self.output_size = output_size
        self.embedding_size = embedding_size
        self.rnn_size = rnn_size
        self.grid_size = grid_size  # Number of lane cells for social context
        self.use_cuda = use_cuda and torch.cuda.is_available()
        self.infer = infer
        self.gru = gru

        # Sequence length depends on training or inference
        self.seq_length = 1 if infer else seq_length

        # Linear embedding layers: position -> embedding, flattened social tensor -> embedding
        self.input_embedding_layer = nn.Linear(self.input_size, self.embedding_size)
        self.tensor_embedding_layer = nn.Linear(self.grid_size * self.rnn_size, self.embedding_size)

        # RNN cell (LSTM or GRU) consumes [input embedding, social embedding]
        if self.gru:
            self.cell = nn.GRUCell(2 * self.embedding_size, self.rnn_size)
        else:
            self.cell = nn.LSTMCell(2 * self.embedding_size, self.rnn_size)

        # Output layer
        self.output_layer = nn.Linear(self.rnn_size, self.output_size)

        # Activation and regularization
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def create_social_tensor(self, lane_occupancy, hidden_states):
        """
        Compute the social tensor by pooling other vehicles' hidden states over lane cells.

        For vehicle ``v`` and cell ``c`` the pooled state is
        ``sum_{o != v, occ[o, c] > 0} occ[o, c] * hidden_states[o]``;
        non-positive occupancy entries contribute nothing and a vehicle never
        pools its own hidden state.

        Args:
            lane_occupancy: Lane occupancy weights [num_vehicles, grid_size];
                            row ``o`` must correspond to ``hidden_states[o]``.
            hidden_states: Hidden states of the same vehicles [num_vehicles, rnn_size]

        Returns:
            social_tensor: Pooled tensor [num_vehicles, grid_size * rnn_size]
                           (cell-major, i.e. cell c occupies columns c*rnn_size .. (c+1)*rnn_size-1)
        """
        num_vehicles = lane_occupancy.size(0)

        occupancy = lane_occupancy.to(device=hidden_states.device, dtype=hidden_states.dtype)
        # Keep only strictly positive weights (the reference loop skipped entries <= 0).
        occupancy = torch.where(occupancy > 0, occupancy, torch.zeros_like(occupancy))

        # contributions[o, c, :] = occupancy[o, c] * hidden_states[o, :]
        contributions = occupancy.unsqueeze(-1) * hidden_states.unsqueeze(1)

        # others[v, o] = 1 for o != v, 0 on the diagonal: excludes self-pooling.
        others = 1 - torch.eye(num_vehicles, dtype=hidden_states.dtype, device=hidden_states.device)

        # Sum the contributions of all other vehicles for every (vehicle, cell) pair.
        social_tensor = others @ contributions.flatten(1)
        return social_tensor.view(num_vehicles, self.grid_size * self.rnn_size)

    def forward(self, input_data, lane_occupancy, hidden_states, cell_states, vehicle_ids, num_vehicles_per_frame, look_up):
        """
        Forward pass for the Lane-based Social LSTM model.

        Args:
            input_data: Input positions [seq_length, max_num_vehicles, input_size]
                        Contains x-coordinate in lane cells
            lane_occupancy: Lane occupancy masks [seq_length, max_num_vehicles, grid_size]
                           Represents lane cell occupancy around each vehicle
            hidden_states: Hidden states [num_total_vehicles, rnn_size] (updated in place)
            cell_states: Cell states [num_total_vehicles, rnn_size] (updated in place;
                         unused when ``gru`` is True)
            vehicle_ids: List of vehicle IDs present in each frame
            num_vehicles_per_frame: Number of vehicles in each frame (not used here;
                                    kept for interface compatibility)
            look_up: Mapping from vehicle ID to index in hidden_states / input rows

        Returns:
            outputs: Predicted distributions [seq_length, num_total_vehicles, output_size];
                     rows of vehicles absent from a frame stay zero
            hidden_states: Updated hidden states
            cell_states: Updated cell states
        """
        num_total_vehicles = len(look_up)
        device = hidden_states.device

        outputs = torch.zeros(self.seq_length, num_total_vehicles, self.output_size, device=device)

        for frame_idx, frame in enumerate(input_data):
            # Get IDs of vehicles present in current frame; skip empty frames.
            current_vehicle_ids = [int(veh_id) for veh_id in vehicle_ids[frame_idx]]
            if not current_vehicle_ids:
                continue

            # Row indices of the present vehicles in the state / input tensors.
            current_indices = [look_up[veh_id] for veh_id in current_vehicle_ids]
            index_tensor = torch.tensor(current_indices, dtype=torch.long, device=device)

            current_input = frame[current_indices]
            # Select the occupancy rows of the *present* vehicles so that row i lines
            # up with current_hidden[i]; passing all max_num_vehicles rows misaligned
            # the pooling and broke the concatenation below.
            current_lane_occupancy = lane_occupancy[frame_idx][current_indices]
            current_hidden = hidden_states[index_tensor]

            # Lane-based social pooling, then embed both inputs.
            social_tensor = self.create_social_tensor(current_lane_occupancy, current_hidden)
            input_embedded = self.dropout(self.relu(self.input_embedding_layer(current_input)))
            social_embedded = self.dropout(self.relu(self.tensor_embedding_layer(social_tensor)))
            concat_embedded = torch.cat((input_embedded, social_embedded), dim=1)

            # One recurrent step for the present vehicles.
            if self.gru:
                next_hidden = self.cell(concat_embedded, current_hidden)
                next_cell = None
            else:
                current_cell = cell_states[index_tensor]
                next_hidden, next_cell = self.cell(concat_embedded, (current_hidden, current_cell))

            # Scatter this frame's predictions into their vehicle rows.
            frame_outputs = self.output_layer(next_hidden)
            outputs[frame_idx, index_tensor] = frame_outputs.to(outputs.dtype)

            # Update hidden and cell states of the present vehicles only.
            hidden_states[index_tensor] = next_hidden
            if not self.gru:
                cell_states[index_tensor] = next_cell

        return outputs, hidden_states, cell_states

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): `time_span` and `batch_size` are not defined in the visible cells
# (and `TrajDataset` presumably comes from `from Models import *`); they must be
# defined before this cell for it to run on a fresh kernel — TODO confirm.
VAL_DIR = r"D:\TimeSpaceDiagramDataset\EncoderDecoder_EvenlySampled_FreeflowAug_0914_5res_lanechange_signal\100_frame\val"
val_dataset = TrajDataset(VAL_DIR, time_span)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=1)

# Fraction of target elements equal to 1 over the whole validation split.
positive_ratio = []   # per-batch ratios (kept for downstream cells)
positive_count = 0    # total number of elements equal to 1
total_count = 0       # total number of target elements
for batch in tqdm(val_loader):
    # post_occ_X / speed are not used here; kept so later cells that read the
    # last batch's tensors still find them.
    post_occ_X = batch['post_occ_X'].to(device)
    target = batch['target'].to(device)
    speed = batch['speed_target'].to(device)

    n_positive = int((target == 1).sum().item())
    n_elements = target.numel()
    positive_ratio.append(np.float32(n_positive / n_elements))
    positive_count += n_positive
    total_count += n_elements

# Element-weighted ratio; unlike np.mean(positive_ratio) it is not skewed by a
# smaller final batch.
overall_positive_ratio = positive_count / total_count if total_count else float("nan")
overall_positive_ratio