In [17]:
from __future__ import print_function, division

import math
import os
import random
import time

import numpy as np
import scipy.io as scp
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from utils import ngsimDataset,maskedNLL,maskedMSE,maskedNLLTest


In [18]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cpu


In [14]:
class Maneuver_class(nn.Module):
    """LSTM classifier producing lateral and longitudinal maneuver probabilities.

    Each time step of the input is embedded with a linear layer (to 64 units),
    passed through a LeakyReLU, and fed to an LSTM. The LSTM output at the last
    time step feeds two independent linear + softmax heads: a 3-way "lateral"
    head and a 2-way "longitudinal" head.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state (default 128).
        num_layers: Number of stacked LSTM layers (default 1).
    """

    def __init__(self, input_size, hidden_size=128, num_layers = 1):

        super(Maneuver_class, self).__init__()
        self.input_size = input_size
        self.embedding = nn.Linear(input_size, 64)
        # Honour the constructor arguments; they used to be silently replaced
        # by the hard-coded values 128 and 1.
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # The LSTM consumes (sequence_size, batch_size, 64) because batch_first=False.
        self.lstm = nn.LSTM(64, hidden_size=self.hidden_size,
                            num_layers=self.num_layers, batch_first=False)
        self.lat_linear = nn.Linear(self.hidden_size, 3)
        self.lon_linear = nn.Linear(self.hidden_size, 2)

        # Activations.
        self.leaky_relu = torch.nn.LeakyReLU(0.1)
        # Normalise over the class axis. dim=-1 equals the former dim=1 for
        # batched (batch, classes) logits and also works for unbatched input.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x_input):
        """Return ``(lat_pred, lon_pred)`` class probabilities for ``x_input``.

        Args:
            x_input: Tensor shaped (sequence_size, batch_size, input_size).

        Returns:
            Tuple of a (batch_size, 3) tensor and a (batch_size, 2) tensor; each
            row is a softmax distribution.
        """
        embedded = self.leaky_relu(self.embedding(x_input))
        # The final (h_n, c_n) state is kept on the module, as before.
        lstm_out, self.hidden = self.lstm(embedded)

        # Only the last time step of the top LSTM layer drives both heads.
        last_step = lstm_out[-1]
        lat_pred = self.softmax(self.lat_linear(last_step))
        lon_pred = self.softmax(self.lon_linear(last_step))

        return  lat_pred, lon_pred

In [19]:
def prepare_data(batch_size, t_h, n_features=14):  # e.g. t_h = 8
    """Create a random dummy input batch for shape-checking a model.

    Args:
        batch_size: Number of sequences in the batch.
        t_h: Number of time steps per sequence.
        n_features: Number of features per time step (default 14, the value
            that used to be hard-coded).

    Returns:
        A standard-normal tensor shaped (t_h, batch_size, n_features), i.e.
        sequence-first.
    """
    return torch.randn(t_h, batch_size, n_features)
# Smoke test: push a random batch (20 sequences, 8 time steps, 14 features)
# through a fresh model and print the shapes of both prediction heads.
input_mc = prepare_data(batch_size=20, t_h=8)
classmodel = Maneuver_class(input_size=14)
lat_pred, lon_pred = classmodel(input_mc)
print(lat_pred.shape, lon_pred.shape, sep="\n")

torch.Size([20, 3])
torch.Size([20, 2])


# Training

In [None]:
def cross_entropy(pred, target, eps=1e-12):
    """Mean cross-entropy between predicted probabilities and targets.

    Computes ``-sum(target * log(pred)) / pred.shape[0]``.

    Args:
        pred: Array of predicted probabilities; the first axis is the sample axis.
        target: Array of the same shape as ``pred`` (e.g. one-hot labels).
        eps: Lower bound applied to ``pred`` before taking the log. Without it a
            probability of exactly 0 gives ``log(0) = -inf`` and ``0 * -inf``
            turns the whole loss into NaN.

    Returns:
        The summed cross-entropy divided by the number of samples.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    # Clamp from below only; valid probabilities are left untouched.
    safe_pred = np.maximum(pred, eps)
    loss = -np.sum(target * np.log(safe_pred))
    return loss/float(pred.shape[0])

In [20]:
def valid(model, valid_loader, loss_fn):
    """Compute the mean per-batch validation loss of ``model``.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` afterwards.

    Args:
        model: Module returning ``(lat_pred, lon_pred)`` when called on ``hist``.
        valid_loader: Iterable with ``len()`` yielding 7-tuples
            ``(hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask)``.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.

    Returns:
        Sum of ``lat_loss + lon_loss`` over all batches divided by the number
        of batches.

    Raises:
        ZeroDivisionError: If ``valid_loader`` is empty.
    """
    model.eval()
    valid_loss = 0.
    # Evaluation needs no gradients; disabling autograd avoids building a graph
    # for every batch, which saves memory and time.
    with torch.no_grad():
        for data in tqdm(valid_loader):
            # Only the history and the two maneuver encodings are used here.
            hist, _nbrs, _mask, lat_enc, lon_enc, _fut, _op_mask = data
            hist, lat_enc, lon_enc = hist.to(device), lat_enc.to(device), lon_enc.to(device)
            lat_pred, lon_pred = model(hist)
            lat_loss = loss_fn(lat_pred, lat_enc)
            lon_loss = loss_fn(lon_pred, lon_enc)

            loss = lat_loss + lon_loss
            valid_loss += loss.item()
    return valid_loss / len(valid_loader)

In [27]:
def train_model(data_dir, input_size, n_epochs=100, _batch_size=1, _lr = 0.001, load_model=False, model_path='output/model0.pth'):
    """Train a ``Maneuver_class`` model on the NGSIM train/validation sets.

    Each epoch runs over the training loader, optimising the sum of the lateral
    and longitudinal BCE losses with Adam. Whenever the average training loss of
    an epoch improves on the best seen so far, a checkpoint is written to
    ``output/model_{epoch}.pth``. Every 10th epoch (starting with epoch 0) the
    validation loss is computed and printed.

    Args:
        data_dir: Prefix prepended verbatim to ``'TrainSet.mat'`` and
            ``'ValSet.mat'``, so it must end with a path separator.
        input_size: Number of features per time step of ``hist``.
        n_epochs: Number of passes over the training set.
        _batch_size: Batch size for both data loaders.
        _lr: Adam learning rate.
        load_model: When true, start from the checkpoint at ``model_path``.
        model_path: Checkpoint file holding a ``'state_dict'`` entry; only read
            when ``load_model`` is true.

    Returns:
        None. Progress is printed and checkpoints are written to disk.
    """
    trSet = ngsimDataset(data_dir+'TrainSet.mat')
    valSet = ngsimDataset(data_dir+'ValSet.mat')
    trDataloader = DataLoader(trSet,batch_size=_batch_size,shuffle=True,collate_fn=trSet.collate_fn)
    # Validation only aggregates a loss, so shuffling is unnecessary; a fixed
    # order keeps the reported number reproducible.
    valDataloader = DataLoader(valSet,batch_size=_batch_size,shuffle=False,collate_fn=valSet.collate_fn)

    # For now only the ego vehicle is considered, hence input_size=2 at the call site.
    model = Maneuver_class(input_size).to(device)

    if load_model:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        ckpt = torch.load(model_path, map_location=device)
        model.load_state_dict(ckpt['state_dict'])

    # NOTE(review): BCELoss needs float targets with the same shape as the
    # predictions — presumably lat_enc/lon_enc are one-hot; confirm.
    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=_lr)

    # Start from +inf so the first epoch always yields a checkpoint, even if
    # its loss is above the old arbitrary threshold of 100.
    min_loss = float('inf')

    # torch.save does not create missing directories.
    os.makedirs('output', exist_ok=True)

    for epoch in range(n_epochs):
        # valid() leaves the model in eval mode, so re-enter train mode each epoch.
        model.train()
        total_loss = 0.
        it = 0
        print(f"Epoch {epoch+1}\n-------------------------------")

        for data in tqdm(trDataloader):
            hist, _nbrs, _mask, lat_enc, lon_enc, _fut, _op_mask = data
            # TODO: a padding step is still missing, so input_size must be 2 for now.
            hist, lat_enc, lon_enc = hist.to(device), lat_enc.to(device), lon_enc.to(device)

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            lat_pred, lon_pred = model(hist)
            lat_loss = loss_fn(lat_pred, lat_enc)
            lon_loss = loss_fn(lon_pred, lon_enc)
            loss = lat_loss + lon_loss

            # Backpropagate and update the weights.
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            it += 1
            if it%500==0:
                print("Training Iteration {} of epoch {} complete. Loss: {}".
                    format(it, epoch, loss.item()))

        epoch_loss = total_loss/len(trDataloader)
        print('epoch:{},avg_loss:{}'.format(epoch, epoch_loss))
        # Save the model whenever the average training loss improves.
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            torch.save({'epoch': epoch, 'state_dict': model.state_dict()},
                           'output/model_{}.pth'.format(epoch))
            print("epoch:%d Model Saved" % epoch)

        # Validate every 10th epoch.
        if epoch % 10 == 0:
            valid_loss = valid(model, valDataloader, loss_fn)
            print('validation epoch:{},valid_loss:{}'.format(epoch, valid_loss))
            model.train()

In [28]:
# Prefix of the folder holding TrainSet.mat / ValSet.mat. The trailing slash is
# required because train_model appends the file names by plain concatenation.
data_dir = '/mnt/e/Northwestern University/Courses/2022winter/pattern recognition/projects/data/'
train_model(
    data_dir=data_dir,
    input_size=2,
    n_epochs=100,
    _batch_size=1,
    _lr=0.001,
    load_model=False,
    model_path='output/model0.pth',
)

Epoch 1
-------------------------------


  0%|          | 527/5922867 [00:04<11:11:46, 146.93it/s]

Training Iteration 500 of epoch 0 complete. Loss: 0.4189721345901489


  0%|          | 1018/5922867 [00:07<12:29:20, 131.71it/s]

Training Iteration 1000 of epoch 0 complete. Loss: 0.2595510482788086


  0%|          | 1515/5922867 [00:11<10:52:24, 151.27it/s]

Training Iteration 1500 of epoch 0 complete. Loss: 0.1260056048631668


  0%|          | 2015/5922867 [00:14<14:33:55, 112.92it/s]

Training Iteration 2000 of epoch 0 complete. Loss: 0.9727218747138977


  0%|          | 2520/5922867 [00:18<13:01:02, 126.33it/s]

Training Iteration 2500 of epoch 0 complete. Loss: 0.37654462456703186


  0%|          | 3013/5922867 [00:22<13:33:48, 121.24it/s]

Training Iteration 3000 of epoch 0 complete. Loss: 2.6680634021759033


  0%|          | 3527/5922867 [00:26<14:19:38, 114.76it/s]

Training Iteration 3500 of epoch 0 complete. Loss: 0.11133483797311783


  0%|          | 4027/5922867 [00:30<11:22:47, 144.48it/s]

Training Iteration 4000 of epoch 0 complete. Loss: 0.07220134139060974


  0%|          | 4528/5922867 [00:33<10:46:56, 152.47it/s]

Training Iteration 4500 of epoch 0 complete. Loss: 0.1188773438334465


  0%|          | 5023/5922867 [00:37<13:13:29, 124.30it/s]

Training Iteration 5000 of epoch 0 complete. Loss: 1.1283870935440063


  0%|          | 5525/5922867 [00:41<12:49:25, 128.18it/s]

Training Iteration 5500 of epoch 0 complete. Loss: 0.25366970896720886


  0%|          | 6025/5922867 [00:45<12:19:23, 133.37it/s]

Training Iteration 6000 of epoch 0 complete. Loss: 0.12776094675064087


  0%|          | 6518/5922867 [00:50<17:49:02, 92.24it/s] 

Training Iteration 6500 of epoch 0 complete. Loss: 0.06721433252096176


  0%|          | 7025/5922867 [00:54<11:48:53, 139.09it/s]

Training Iteration 7000 of epoch 0 complete. Loss: 0.4834219813346863


  0%|          | 7520/5922867 [00:59<12:34:43, 130.63it/s]

Training Iteration 7500 of epoch 0 complete. Loss: 0.42419925332069397


  0%|          | 8022/5922867 [01:03<12:21:40, 132.92it/s]

Training Iteration 8000 of epoch 0 complete. Loss: 0.08208039402961731


  0%|          | 8523/5922867 [01:07<12:58:28, 126.62it/s]

Training Iteration 8500 of epoch 0 complete. Loss: 1.5559724569320679


  0%|          | 9023/5922867 [01:10<11:42:21, 140.33it/s]

Training Iteration 9000 of epoch 0 complete. Loss: 0.09450817108154297


  0%|          | 9525/5922867 [01:14<11:35:28, 141.71it/s]

Training Iteration 9500 of epoch 0 complete. Loss: 2.1284477710723877


  0%|          | 10022/5922867 [01:18<12:32:06, 131.03it/s]

Training Iteration 10000 of epoch 0 complete. Loss: 0.0983513817191124


  0%|          | 10528/5922867 [01:22<11:23:57, 144.07it/s]

Training Iteration 10500 of epoch 0 complete. Loss: 0.21564935147762299


  0%|          | 11022/5922867 [01:26<12:39:23, 129.75it/s]

Training Iteration 11000 of epoch 0 complete. Loss: 0.3915429711341858


  0%|          | 11525/5922867 [01:30<13:45:04, 119.41it/s]

Training Iteration 11500 of epoch 0 complete. Loss: 1.6647355556488037


  0%|          | 12024/5922867 [01:34<14:06:59, 116.31it/s]

Training Iteration 12000 of epoch 0 complete. Loss: 0.1953759640455246


  0%|          | 12524/5922867 [01:38<13:13:49, 124.09it/s]

Training Iteration 12500 of epoch 0 complete. Loss: 4.433499336242676


  0%|          | 13021/5922867 [01:42<12:02:10, 136.39it/s]

Training Iteration 13000 of epoch 0 complete. Loss: 0.03486473113298416


  0%|          | 13528/5922867 [01:46<11:47:18, 139.25it/s]

Training Iteration 13500 of epoch 0 complete. Loss: 0.14394450187683105


  0%|          | 14026/5922867 [01:50<12:11:09, 134.69it/s]

Training Iteration 14000 of epoch 0 complete. Loss: 0.2779296636581421


  0%|          | 14527/5922867 [01:54<12:03:11, 136.16it/s]

Training Iteration 14500 of epoch 0 complete. Loss: 0.25119680166244507


  0%|          | 15013/5922867 [01:58<12:48:00, 128.21it/s]

Training Iteration 15000 of epoch 0 complete. Loss: 0.2099154144525528


  0%|          | 15520/5922867 [02:01<12:32:00, 130.92it/s]

Training Iteration 15500 of epoch 0 complete. Loss: 0.21651123464107513


  0%|          | 16021/5922867 [02:05<12:59:52, 126.23it/s]

Training Iteration 16000 of epoch 0 complete. Loss: 1.0933839082717896


  0%|          | 16514/5922867 [02:10<14:57:15, 109.71it/s]

Training Iteration 16500 of epoch 0 complete. Loss: 0.11075998842716217


  0%|          | 16864/5922867 [02:13<12:58:35, 126.43it/s]


KeyboardInterrupt: 

In [None]:
def test_model(data_dir, _batch_size=16, model_path = 'output/model0.pth'):
    """Load a saved ``Maneuver_class`` checkpoint and run it over the test set.

    Args:
        data_dir: Prefix prepended verbatim to ``'TestSet.mat'``.
        _batch_size: Batch size of the test data loader.
        model_path: Checkpoint file; its ``'state_dict'`` entry is loaded.

    As far as this cell shows, the predictions are computed but not yet
    compared against ``lat_enc`` / ``lon_enc`` or returned.
    """
    testSet = ngsimDataset(data_dir+'TestSet.mat')
    test_loader = DataLoader(testSet,batch_size=_batch_size,shuffle=False,collate_fn=testSet.collate_fn)
    
    # NOTE(review): the model is built with input_size=14, while the training
    # call in this notebook uses input_size=2; a checkpoint from that run would
    # not match the embedding layer's shape — confirm which is intended.
    model = Maneuver_class(14,128).to(device)
    # NOTE(review): train_model saves to 'output/model_{epoch}.pth', which does
    # not match the default 'output/model0.pth' — confirm the intended path.
    # NOTE(review): no map_location is passed, so a checkpoint saved on a GPU
    # may fail to load on a CPU-only machine.
    ckpt = torch.load(model_path)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    # NOTE(review): inference runs with autograd enabled; consider torch.no_grad().
    for data in tqdm(test_loader):
        hist, nbrs, mask, lat_enc, lon_enc, fut, op_mas = data
        hist,nbrs,lat_enc, lon_enc = hist.to(device), nbrs.to(device), lat_enc.to(device), lon_enc.to(device) 
        lat_pred, lon_pred = model(hist)  # TODO: the padding step is still missing here