In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
import pandas as pd
import numpy as np

import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import RandomSampler, random_split, SubsetRandomSampler, WeightedRandomSampler

"""Change to the data folder"""
new_path = "./new_train/new_train"
val_path = "./new_val_in/new_val_in" 
# number of sequences in each dataset
# train:205942  val:3200 test: 36272 
# sequences sampled at 10HZ rate

In [2]:
torch.cuda.is_available()

True

### Create a dataset class 

In [38]:
class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""
    def __init__(self, data_path: str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform

        self.pkl_list = glob(os.path.join(self.data_path, '*'))
        self.pkl_list.sort()
        
    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):

        pkl_path = self.pkl_list[idx]
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
            
        if self.transform:
            data = self.transform(data)

        return data


# intialize a dataset
train_dataset  = ArgoverseDataset(data_path=new_path)

### Create a loader to enable batch processing

In [39]:
batch_sz = 32

def my_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]
    out = torch.LongTensor(out)
    inp = torch.tensor(inp, dtype=torch.float)
    out = torch.tensor(out, dtype=torch.float)
    return [inp, out]

train_dataset, val_dataset = random_split(train_dataset, (199078, 6864))

sampler = RandomSampler(train_dataset, num_samples = 50000, replacement=True)
# sampler = RandomSampler(train_dataset)

train_loader = DataLoader(train_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0, sampler = sampler, drop_last=True)
val_loader = DataLoader(val_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0, drop_last=True)


In [40]:
class NN(nn.Module):
    def __init__(self, input_size, output_size, hidden_dim, n_layers):
        super(NN, self).__init__()

        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.gru = nn.GRU(input_size, hidden_dim, n_layers, batch_first=True)   
        self.fc1 = nn.Linear(hidden_dim, output_size)
        self.fc2 = nn.Linear(output_size, output_size)
    
    def forward(self, x):
        batch_size = x.size(0)

        hidden = self.init_hidden(batch_size)
        out, hidden = self.gru(x, hidden)
        out = self.fc1(out)
        out = self.fc2(out)
        
        return out
    
    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim)
        return hidden.to(device)

In [41]:
from tqdm.notebook import tqdm
def train(model, device, train_loader, optimizer, epoch, log_interval=10000):
    model.train()
    iterator = tqdm(train_loader, total=int(len(train_loader)))
    loss_ema = -1
    for batch_idx, (data, target) in enumerate(iterator):
        data, target = torch.reshape(data, (batch_sz, 60, -1)).to(device), torch.reshape(target[:,:,:,:2], (batch_sz, 60, -1)).to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.MSELoss()(output, target)
        loss.backward()
        optimizer.step()
        if loss_ema < 0:
            loss_ema = loss.item()
        loss_ema = loss_ema * 0.99 + loss.item()*0.01
        iterator.set_postfix(loss=loss_ema)
    return running_loss
#         iterator.set_postfix(loss=(loss.item()*data.size(0) / (counter * train_loader.batch_size)))

In [42]:
def test(model, device, test_loader):
    model.eval()
    loss_ema = -1
    with torch.no_grad():
        for data, target in test_loader:
            data, target = torch.reshape(data, (batch_sz, 60, -1)).to(device), torch.reshape(target[:,:,:,:2], (batch_sz, 60, -1)).to(device)
            output = model(data)
            loss = nn.MSELoss()(output, target)
            if loss_ema < 0:
                loss_ema = loss.item()
            loss_ema = loss_ema * 0.99 + loss.item()*0.01
    print("Test loss: {}".format(loss_ema))

In [43]:
# input dimension
input_dim = 76
hidden_dim = 60  # hidden layer dimension
layer_dim = 1   # number of hidden layers
output_dim = 60   # output dimension

if torch.cuda.is_available():  
    device = "cuda:0" 
else:  
    device = "cpu"  
    
learning_rate = 0.0001
momentum = 0.5

In [None]:
net = NN(input_dim, output_dim, hidden_dim, layer_dim)
model = net.to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                      momentum=momentum, weight_decay=1e-5)
num_epoch = 10
losses = []
for epoch in range(1, num_epoch + 1):
    loss = train(model, device, train_loader, optimizer, epoch)
    losses.append(loss)
    test(model, device, val_loader)
    torch.save(model.state_dict(), 'checkpoints/train-epoch-gru-2linear-{}.pth'.format(epoch + 1)) 
    
with open("losses_test.txt", "w") as output:
    output.write(str(losses))

  0%|          | 0/1562 [00:00<?, ?it/s]

  out = torch.LongTensor(out)
  out = torch.tensor(out, dtype=torch.float)


## Test and predict

In [10]:
model=torch.load("checkpoints/train-epoch-lstm-ad5.pth",map_location=device)

In [11]:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False
    
    model.eval()
    
    return model

In [12]:
batch_sz = 4

def my_test_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
#     print(batch)
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    inp = torch.tensor(inp, dtype=torch.float)
    agent = [scene['agent_id'] for scene in batch]
    track = [scene['track_id'] for scene in batch]
    track_ids = []
    for scene in range(len(agent)):
        for track_id in range(len(track[scene])):
            if agent[scene] == track[scene][track_id][0][0]:
                track_ids.append(track_id)
                break
    return [inp, track_ids]

# intialize a dataset
val_dataset  = ArgoverseDataset(data_path=val_path)
val_loader = DataLoader(val_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_test_collate, num_workers=0, drop_last=True)


In [13]:
def predict(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    predictions = []
    with torch.no_grad():
        for data, agent in test_loader:
            data = torch.reshape(data, (4, 60, -1)).to(device)
            output = torch.reshape(model(data), (4, 60, 30, -1))
            for i in range(len(agent)):
                scene = output[i][agent[i]]
                predictions.append(scene)
    predict = [torch.reshape(t, (-1,)) for t in predictions]
    sample = pd.read_csv('sample_submission.csv')
    preds_df = sample.set_index('ID')
    for i in range(3200):
        preds_df.iloc[i] = predict[i].tolist()
    return preds_df

In [14]:
net = NN(input_dim, output_dim, hidden_dim, layer_dim)
net.load_state_dict(torch.load("checkpoints/train-epoch-lstm7.pth",map_location=device))
model = net.to(device)
preds = predict(model, device, val_loader)

In [15]:
preds.to_csv('test_preds_lstmad_5epoch.csv')

In [16]:
preds.head()

Unnamed: 0_level_0,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,...,v51,v52,v53,v54,v55,v56,v57,v58,v59,v60
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
10002,373.583618,409.271149,373.07019,409.814087,373.118713,409.852936,373.660858,410.04187,373.374847,409.721344,...,373.421082,409.322083,373.111237,410.563293,373.395081,409.945557,372.836731,409.021606,373.423096,408.717377
10015,307.479889,594.395691,306.787903,595.172546,306.849792,595.035889,306.841583,595.468994,307.370422,594.602295,...,307.010834,594.812439,306.920563,594.825012,307.250366,595.341003,307.315063,594.517883,306.822601,594.301147
10019,307.479889,594.395691,306.787903,595.172546,306.849792,595.035889,306.841583,595.468994,307.370422,594.602295,...,307.010834,594.812439,306.920563,594.825012,307.250366,595.341003,307.315063,594.517883,306.822601,594.301147
10028,365.634521,397.756958,365.154144,398.212952,365.262207,398.203186,365.680573,398.384827,365.398865,398.197174,...,365.546722,397.687622,365.194733,398.930786,365.552155,398.312927,364.974731,397.506287,365.50946,397.180695
1003,414.335846,468.951721,414.237823,469.79364,413.88385,469.741394,414.674194,470.003113,414.295105,469.289734,...,414.399597,469.401367,413.973083,470.124146,414.219147,469.857361,413.98703,468.89801,414.446289,468.523499
