In [19]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import os, os.path 
import numpy 
import pickle
from glob import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

"""Change to the data folder"""
train_path = "../data/new_train"
val_path = '../data/new_val_in/new_val_in'

In [2]:
class ArgoverseDataset(Dataset):
    
    """Dataset class for Argoverse"""
    def __init__(self, data_path: str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform
        self.pkl_list = glob(os.path.join(self.data_path, '*'))
        self.pkl_list.sort()
        
    def __len__(self): #len(val_dataset)
        return len(self.pkl_list)
    
    def __getitem__(self, idx): #val_dataset[0]
        pkl_path = self.pkl_list[idx]
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
        if self.transform:
            data = self.transform(data)
        return data

In [13]:
init_dataset = ArgoverseDataset(data_path=train_path) 
val_dataset  = ArgoverseDataset(data_path=val_path) 
print(len(init_dataset), len(val_dataset))

205942 3200


In [24]:
lengths = [int(len(init_dataset)*0.7) + 1, int(len(init_dataset)*0.3)]
print(lengths)
train_dataset, test_dataset = random_split(init_dataset, lengths)
print(len(train_dataset), len(test_dataset), len(val_dataset))

[144160, 61782]
144160 61782 3200


In [9]:
batch_sz = 2
def my_collate(batch): #[scene['p_in'][scene['car_mask'].reshape(-1) == 1]]
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out']]) for scene in batch]
    inp = torch.FloatTensor(inp) #LongTensor
    out = torch.FloatTensor(out) #LongTensor
    return [inp, out]
train_loader = DataLoader(train_dataset, batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0)

In [5]:
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class LSTM(nn.Module):
    def __init__(self, input_size= 2280, hidden_layer_size=200, output_size= 3600):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size

        self.lstm = nn.LSTM(input_size, hidden_layer_size, batch_first = True)

        self.linear = nn.Linear(hidden_layer_size, output_size)

        self.hidden_cell = (torch.zeros(1,4,self.hidden_layer_size),
                            torch.zeros(1,4,self.hidden_layer_size))

    def forward(self, input_seq):
        lstm_out, self.hidden_cell = self.lstm(
                                        input_seq.view(len(input_seq), 1, -1), 
                                        self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
#looss function and optimizer use from rose
model = LSTM()
loss_function = nn.MSELoss()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print(model)

cuda
LSTM(
  (lstm): LSTM(2280, 200, batch_first=True)
  (linear): Linear(in_features=200, out_features=3600, bias=True)
)


In [None]:
model.train()
epochs = 5

for i in range(epochs):
    print('epoch', i)
    iterator = tqdm(train_loader, total=int(len(train_loader)))
    for i_batch, sample_batch in enumerate(iterator):
        seq, labels = sample_batch
        #new_seq = seq.permute(0, 2, 1, 3).to(device)
        seq = seq.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        model.hidden_cell = (torch.zeros(1, 2, model.hidden_layer_size).to(device),
                        torch.zeros(1, 2, model.hidden_layer_size).to(device))

        y_pred = model(seq)
        
        single_loss = loss_function(y_pred.reshape(-1), labels.reshape(-1))
        single_loss.backward()
        optimizer.step()
    print(single_loss.item())
    if i%25 == 1:
        print(f'epoch: {i:3} loss: {single_loss.item():10.8f}')

print(f'epoch: {i:3} loss: {single_loss.item():10.10f}')

  0%|          | 14/102971 [00:00<12:44, 134.58it/s]

epoch 0


100%|██████████| 102971/102971 [12:08<00:00, 141.42it/s]
  0%|          | 13/102971 [00:00<13:43, 124.97it/s]

123659.96875
epoch 1


 38%|███▊      | 39332/102971 [05:12<07:34, 139.94it/s]