In [18]:
from __future__ import print_function, division
from torch.utils.data import Dataset, DataLoader
import scipy.io as scp
import numpy as np
import torch
import time
import torch.nn.functional as F

In [79]:
class ngsimDataset(Dataset):


    def __init__(self, mat_file, t_h=30, t_f=50, d_s=2, enc_size = 64, grid_size = (13,3)):
        self.D = scp.loadmat(mat_file)['traj']
        self.T = scp.loadmat(mat_file)['tracks']
        self.t_h = t_h  # length of track history
        self.t_f = t_f  # length of predicted trajectory
        self.d_s = d_s  # down sampling rate of all sequences
        self.enc_size = enc_size # size of encoder LSTM
        self.grid_size = grid_size # size of social context grid



    def __len__(self):
        return len(self.D)



    def __getitem__(self, idx):

        dsId = self.D[idx, 0].astype(int)
        vehId = self.D[idx, 1].astype(int)
        t = self.D[idx, 2]
        grid = self.D[idx,8:]
        neighbors = []

        # Get track history 'hist' = ndarray, and future track 'fut' = ndarray
        hist = self.getHistory(vehId,t,vehId,dsId)
        fut = self.getFuture(vehId,t,dsId)

        # Get track histories of all neighbours 'neighbors' = [ndarray,[],ndarray,ndarray]
        for i in grid:
            neighbors.append(self.getHistory(i.astype(int), t,vehId,dsId))

        # Maneuvers 'lon_enc' = one-hot vector, 'lat_enc = one-hot vector
        lon_enc = np.zeros([2])
        lon_enc[int(self.D[idx, 7] - 1)] = 1
        lat_enc = np.zeros([3])
        lat_enc[int(self.D[idx, 6] - 1)] = 1

        return hist,fut,neighbors,lat_enc,lon_enc



    ## Helper function to get track history
    def getHistory(self,vehId,t,refVehId,dsId):
        if vehId == 0:
            return np.empty([0,2])
        else:
            if self.T.shape[1]<=vehId-1:
                return np.empty([0,2])
            refTrack = self.T[dsId-1][refVehId-1].transpose()
            vehTrack = self.T[dsId-1][vehId-1].transpose()
            refPos = refTrack[np.where(refTrack[:,0]==t)][0,1:3]

            if vehTrack.size==0 or np.argwhere(vehTrack[:, 0] == t).size==0:
                 return np.empty([0,2])
            else:
                stpt = np.maximum(0, np.argwhere(vehTrack[:, 0] == t).item() - self.t_h)
                enpt = np.argwhere(vehTrack[:, 0] == t).item() + 1
                hist = vehTrack[stpt:enpt:self.d_s,1:3]-refPos

            if len(hist) < self.t_h//self.d_s + 1:
                return np.empty([0,2])
            return hist



    ## Helper function to get track future
    def getFuture(self, vehId, t,dsId):
        vehTrack = self.T[dsId-1][vehId-1].transpose()
        refPos = vehTrack[np.where(vehTrack[:, 0] == t)][0, 1:3]
        stpt = np.argwhere(vehTrack[:, 0] == t).item() + self.d_s
        enpt = np.minimum(len(vehTrack), np.argwhere(vehTrack[:, 0] == t).item() + self.t_f + 1)
        fut = vehTrack[stpt:enpt:self.d_s,1:3]-refPos
        return fut



    ## Collate function for dataloader
    def collate_fn(self, samples):

        # Initialize neighbors and neighbors length batches:
        nbr_batch_size = 0
        for _,_,nbrs,_,_ in samples:
            nbr_batch_size += sum([len(nbrs[i])!=0 for i in range(len(nbrs))])
        maxlen = self.t_h//self.d_s + 1
        nbrs_batch = torch.zeros(maxlen,nbr_batch_size,2)


        # Initialize social mask batch:
        pos = [0, 0]
        mask_batch = torch.zeros(len(samples), self.grid_size[1],self.grid_size[0],self.enc_size)
        mask_batch = mask_batch.byte()


        # Initialize history, history lengths, future, output mask, lateral maneuver and longitudinal maneuver batches:
        hist_batch = torch.zeros(maxlen,len(samples),2)
        fut_batch = torch.zeros(self.t_f//self.d_s,len(samples),2)
        op_mask_batch = torch.zeros(self.t_f//self.d_s,len(samples),2)
        lat_enc_batch = torch.zeros(len(samples),3)
        lon_enc_batch = torch.zeros(len(samples), 2)


        count = 0
        for sampleId,(hist, fut, nbrs, lat_enc, lon_enc) in enumerate(samples):

            # Set up history, future, lateral maneuver and longitudinal maneuver batches:
            hist_batch[0:len(hist),sampleId,0] = torch.from_numpy(hist[:, 0])
            hist_batch[0:len(hist), sampleId, 1] = torch.from_numpy(hist[:, 1])
            fut_batch[0:len(fut), sampleId, 0] = torch.from_numpy(fut[:, 0])
            fut_batch[0:len(fut), sampleId, 1] = torch.from_numpy(fut[:, 1])
            op_mask_batch[0:len(fut),sampleId,:] = 1
            lat_enc_batch[sampleId,:] = torch.from_numpy(lat_enc)
            lon_enc_batch[sampleId, :] = torch.from_numpy(lon_enc)

            # Set up neighbor, neighbor sequence length, and mask batches:
            for id,nbr in enumerate(nbrs):
                if len(nbr)!=0:
                    nbrs_batch[0:len(nbr),count,0] = torch.from_numpy(nbr[:, 0])
                    nbrs_batch[0:len(nbr), count, 1] = torch.from_numpy(nbr[:, 1])
                    pos[0] = id % self.grid_size[0]
                    pos[1] = id // self.grid_size[0]
                    mask_batch[sampleId,pos[1],pos[0],:] = torch.ones(self.enc_size).byte()
                    count+=1

        return hist_batch, nbrs_batch, mask_batch, lat_enc_batch, lon_enc_batch, fut_batch, op_mask_batch

In [7]:
data_dir = '/mnt/e/Northwestern University/Courses/2022winter/pattern recognition/projects/data/'
trSet = ngsimDataset(data_dir+'TrainSet.mat')
valSet = ngsimDataset(data_dir+'ValSet.mat')
#trDataloader = DataLoader(trSet,batch_size=20,shuffle=True,collate_fn=trSet.collate_fn)
#valDataloader = DataLoader(valSet,batch_size=20,shuffle=True,collate_fn=valSet.collate_fn)

In [9]:
for i, data in enumerate(trDataloader):
    hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = data
    print(lat_enc.shape)
    print(lon_enc.shape)


torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 2])
torch.Size([

KeyboardInterrupt: 

In [8]:
trDataloader = DataLoader(trSet,batch_size=1,shuffle=True,collate_fn=trSet.collate_fn)
valDataloader = DataLoader(valSet,batch_size=1,shuffle=True,collate_fn=valSet.collate_fn)

In [5]:
type(trDataloader)

torch.utils.data.dataloader.DataLoader

In [17]:
data_dir = '/mnt/e/Northwestern University/Courses/2022winter/pattern recognition/projects/data/'
def old_padding_batching(data_dir, batch_size):

    Set = ngsimDataset(data_dir)
    Dataloader = DataLoader(Set,batch_size=1,shuffle=True,collate_fn=valSet.collate_fn)
    dataset = torch.empty(1,16,batch_size,14)
    batch_data = torch.empty(16,batch_size,14)
    for i, data in enumerate(Dataloader):
    
        hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = data
   
    # padding for batchsize = 1:
    
        if nbrs.shape[1] >= 6:
            for j in range(6):
                hist = torch.cat((hist, nbrs[:,j,:].view(16,1,2)),2)
        if nbrs.shape[1] < 6:
            for j in range(nbrs.shape[1]):
                hist = torch.cat((hist, nbrs[:,j,:].view(16,1,2)),2)
            
            zeros = torch.zeros(hist.shape[0],6 - nbrs.shape[1], 2)
        
            for j in range(6-nbrs.shape[1]):
                hist = torch.cat((hist, zeros[:,j,:].view(16,1,2)),2)
    
        if i%batch_size == 0 and i != 0:
            dataset = torch.cat((dataset, batch_data.view(1,16,batch_size,14)),0)
            print(dataset.shape)
        
        batch_data[:,i%batch_size, :] = hist.squeeze()
    
        if i ==batch_size-1:
            dataset[0] = batch_data
    
    torch.save(dataset, 'dataset.pt')
        
        
         
            
            

In [3]:
2%64

2

# 新版padding_batching

In [51]:

def padding_batching(data_dir, batch_size):
    
    Set = ngsimDataset(data_dir)
    Dataloader = DataLoader(Set,batch_size=1,shuffle=True,collate_fn=Set.collate_fn)
    dataset = torch.empty(1,16,batch_size,14)
    batch_data = torch.empty(16,batch_size,14)
    batch_lat = torch.empty(batch_size,3)
    batch_lon = torch.empty(batch_size,2)
    batch_fut = torch.empty(25,batch_size,2)
    batch_opmask = torch.empty(25,batch_size,2)
    for i, data in enumerate(Dataloader):
    
        hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = data
   
    # padding for batchsize = 1:
    
        if nbrs.shape[1] >= 6:
            for j in range(6):
                hist = torch.cat((hist, nbrs[:,j,:].view(16,1,2)),2)
        if nbrs.shape[1] < 6:
            
            hist = torch.cat((hist, nbrs.view(16,1,-1)),2)
            pad = (0,2 * (6 - nbrs.shape[1]))
            hist = F.pad(hist, pad, "constant", 0)
            
    # batching:
        if i%batch_size == 0 and i != 0:
            dataset = torch.cat((dataset, batch_data.view(1,16,batch_size,14)),0)
            lat_set = torch.cat((lat_set, batch_lat.view(1,batch_size,3)),0)
            lon_set = torch.cat((lon_set, batch_lon.view(1,batch_size,2)),0)
            fut_set = torch.cat((fut_set, batch_fut.view(1,25,batch_size,2)),0)
            opmask_set = torch.cat((opmask_set, batch_opmask.view(1,25,batch_size,2)),0)
            
        
        batch_data[:,i%batch_size, :] = hist.squeeze()
        batch_lat[i%batch_size] = lat_enc.squeeze()
        batch_lon[i%batch_size] = lon_enc.squeeze()
        batch_fut[:,i%batch_size, :] = fut.squeeze()
        batch_opmask[:,i%batch_size, :] = op_mask.squeeze()
        
        if i ==batch_size-1:
            dataset[0] = batch_data
            lat_set = batch_lat.view(1,batch_size, 3)
            lon_set = batch_lon.view(1,batch_size, 2)
            fut_set =  batch_fut.view(1,25,batch_size,2)
            opmask_set = batch_opmask.view(1,25,batch_size,2)
            
    output = (dataset, lat_set, lon_set,fut_set, opmask_set)
    torch.save(output, 'trainset_batchSize='+str(batch_size)+'.pt')    
    
    return (dataset, lat_set, lon_set,fut_set, opmask_set)

In [53]:
data_dir = '/mnt/e/Northwestern University/Courses/2022winter/pattern recognition/projects/data/TrainSet.mat'
padding_batching(data_dir, 64)

# 如何封装到torch.dataset

In [None]:
load_val = torch.load('valset_batchSize=64.pt')

# test

In [48]:
cout = 3
a = torch.ones(16,1,2)
b = torch.zeros(16,5,2)
a = torch.cat((a,b),1)
c = (a,b)
torch.save(c, 'temp_c'+str(cout)+'.pt')

In [44]:
load_c = torch.load('temp_c.pt')
load_c[0].shape
print(load_c[0].shape)
print(a.shape)
if torch.equal(load_c[0], a):
    print('right')
else:
    print('wrong')

torch.Size([16, 6, 2])
torch.Size([16, 6, 2])
right


In [16]:
print(dataset.shape)
if torch.equal(dataset[432], dataset[131]):
    print('wrong')
else:
    print('right')

torch.Size([2659, 16, 20, 14])
right


In [6]:
ty = torch.ones(1,2,3)
ooo = ty.squeeze()
iii = torch.randn(1,2,3)
if torch.equal(ty, iii):
    print('wrong')
else:
    print('right')

right


In [11]:
ty = torch.ones(2,2,3)
iii = torch.zeros(2,2,3)
ones = torch.ones(2,1,3)
iii[:,0,:] = ones.squeeze()
print(iii)

tensor([[[1., 1., 1.],
         [0., 0., 0.]],

        [[1., 1., 1.],
         [0., 0., 0.]]])


In [12]:
load_val = torch.load('valset_batchSize=64.pt')

In [17]:
print(load_val[3].shape)
print(load_val[0][0].shape)
print(load_val[0][234].shape)

torch.Size([13434, 25, 64, 2])
torch.Size([16, 64, 14])
torch.Size([16, 64, 14])


In [15]:
if torch.equal(load_val[0][0], load_val[0][64]):
    print('wrong')
else:
    print('right')

right


# Data Padding

In [17]:
for i, data in enumerate(trDataloader):
    hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = data
    print(hist.shape)
    print(nbrs.shape)
    print(mask.shape)
    msoc_enc = torch.zeros_like(masks).float()
    soc_enc = soc_enc.masked_scatter_(masks, nbrs_enc)masked_scatter_

torch.Size([16, 1, 2])
torch.Size([16, 10, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 6, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 8, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 5, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 7, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 8, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 10, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 6, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 11, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 3, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 5, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 6, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 5, 2])
torch.Size([1, 3, 13, 64])
torch.Size([16, 1, 2])
torch.Size([16, 15, 2])
t

KeyboardInterrupt: 

In [19]:
t1 = torch.rand(16, 10)
t2 = t1.clone()

m = np.zeros((1, 10), dtype=np.int64)
m[0, :5] = 1
m = torch.tensor(m, dtype=torch.bool)
print(m)
print(t1)

print(t2)

print(t1.masked_scatter(m, t2))

tensor([[ True,  True,  True,  True,  True, False, False, False, False, False]])
tensor([[0.5651, 0.0893, 0.7702, 0.6041, 0.5326, 0.0232, 0.1271, 0.0067, 0.6543,
         0.9711],
        [0.9388, 0.2752, 0.3928, 0.5969, 0.5196, 0.1880, 0.4063, 0.5197, 0.5924,
         0.0421],
        [0.5485, 0.1370, 0.2724, 0.0830, 0.6892, 0.2863, 0.6979, 0.4705, 0.1858,
         0.4961],
        [0.7408, 0.6733, 0.3592, 0.2614, 0.8976, 0.5361, 0.3027, 0.8614, 0.3614,
         0.9039],
        [0.5738, 0.4041, 0.3693, 0.6725, 0.5270, 0.7226, 0.7086, 0.4040, 0.8917,
         0.4651],
        [0.8264, 0.9439, 0.0268, 0.4909, 0.6159, 0.4902, 0.1912, 0.9741, 0.5714,
         0.8919],
        [0.1610, 0.3930, 0.4787, 0.1166, 0.4690, 0.5403, 0.8536, 0.6864, 0.2693,
         0.8966],
        [0.7910, 0.1946, 0.9483, 0.0070, 0.0746, 0.6892, 0.7834, 0.9701, 0.3543,
         0.1137],
        [0.7598, 0.6623, 0.2350, 0.0906, 0.8586, 0.3470, 0.6179, 0.0538, 0.3824,
         0.9718],
        [0.9173, 0.1193, 0.3

In [19]:
data_dir = '/mnt/e/Northwestern University/Courses/2022winter/pattern recognition/projects/data/ValSet.mat'
load_val = scp.loadmat(data_dir)

In [20]:
type(load_val)

dict

In [21]:
load_val.keys()

dict_keys(['__header__', '__version__', '__globals__', 'tracks', 'traj'])

In [36]:
tracks = load_val['tracks']

In [31]:
traj = load_val['traj']

In [48]:
idx = 20
grid = traj[idx,8:]
dsId = traj[idx, 0].astype(int) # traj的第一维度
vehId = traj[idx, 1].astype(int)
t = traj[idx, 2]
t_h=30
t_f=50
d_s=2
enc_size = 64
grid_size = (13,3)
print(grid.shape)
nbrs = []
def getHistory(tracks,vehId,t,refVehId,dsId,t_h=30,d_s=2):
        if vehId == 0:
            return np.empty([0,2])
        else:
            if tracks.shape[1]<=vehId-1:
                return np.empty([0,2])
            refTrack = tracks[dsId-1][refVehId-1].transpose()
            vehTrack = tracks[dsId-1][vehId-1].transpose()
            refPos = refTrack[np.where(refTrack[:,0]==t)][0,1:3]

            if vehTrack.size==0 or np.argwhere(vehTrack[:, 0] == t).size==0:
                 return np.empty([0,2])
            else:
                stpt = np.maximum(0, np.argwhere(vehTrack[:, 0] == t).item() - t_h)
                enpt = np.argwhere(vehTrack[:, 0] == t).item() + 1
                hist = vehTrack[stpt:enpt:d_s,1:3]-refPos

            if len(hist) < t_h//d_s + 1:
                return np.empty([0,2])
            return hist
for i in grid:
            neighbors = getHistory(tracks, i.astype(int), t,vehId,dsId,t_h=30,d_s=2)
            nbrs.append(neighbors)
            print(str(i) + ':')
            print(neighbors)
print(len(nbrs))

(39,)
0.0:
[]
2186.0:
[[ -11.649    -148.25699 ]
 [ -11.675001 -143.25699 ]
 [ -11.701    -138.25699 ]
 [ -11.727001 -133.258   ]
 [ -11.752001 -128.258   ]
 [ -11.778002 -123.257996]
 [ -11.804001 -118.257996]
 [ -11.830002 -113.257996]
 [ -11.854002 -108.257996]
 [ -11.880001 -103.257996]
 [ -11.906002  -98.229   ]
 [ -11.931002  -93.40499 ]
 [ -11.879002  -88.746994]
 [ -11.818001  -83.746994]
 [ -11.758001  -78.746994]
 [ -11.696001  -73.74799 ]]
0.0:
[]
0.0:
[]
0.0:
[]
2182.0:
[[-12.983002  -94.410995 ]
 [-12.9470005 -89.74     ]
 [-12.887001  -84.74     ]
 [-12.826     -79.74     ]
 [-12.765001  -74.741    ]
 [-12.704     -69.741    ]
 [-12.644001  -64.742    ]
 [-12.583     -59.741997 ]
 [-12.522001  -54.742996 ]
 [-12.461     -49.742996 ]
 [-12.401001  -44.744995 ]
 [-12.34      -39.763    ]
 [-12.276001  -34.58     ]
 [-12.199001  -29.276001 ]
 [-12.151001  -24.202988 ]
 [-12.085001  -19.393997 ]]
0.0:
[]
0.0:
[]
0.0:
[]
2173.0:
[]
0.0:
[]
0.0:
[]
0.0:
[]
0.0:
[]
0.0:
[]
2179.

In [51]:
yi = torch.ones(2,3)
lin = torch.zeros(2,3)
yi = torch.cat((yi, lin), 1)
print(yi)

tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.]])


In [60]:
array = np.array([[1,2,3],[],[2,3,4]])
print(array.size)
Ten = torch.from_numpy(array)

3


  array = np.array([[1,2,3],[],[2,3,4]])


TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

In [64]:
ls = np.array([[1,2,3],[1,2,3]])
print(len(ls))
print(torch.from_numpy(ls[:, 0]))

2
tensor([1, 1])


In [65]:
torch.from_numpy(ls).shape[0]

2

In [80]:
class new_ngsimDataset(Dataset):
    
    def __init__(self, mat_file, t_h=30, t_f=50, d_s=2, enc_size = 64, grid_size = (13,3)):
        self.D = scp.loadmat(mat_file)['traj']
        self.T = scp.loadmat(mat_file)['tracks']
        self.t_h = t_h  # length of track history
        self.t_f = t_f  # length of predicted trajectory
        self.d_s = d_s  # down sampling rate of all sequences
        self.enc_size = enc_size # size of encoder LSTM
        self.grid_size = grid_size # size of social context grid


    def __len__(self):
        return len(self.D)

    def __getitem__(self, idx):

        dsId = self.D[idx, 0].astype(int) # traj的第一维度
        vehId = self.D[idx, 1].astype(int)
        t = self.D[idx, 2]
        grid = self.D[idx,8:]
        neighbors = []

        # Get track history 'hist' = ndarray, and future track 'fut' = ndarray
        hist = self.getHistory(vehId,t,vehId,dsId)  # get 待测车的 traj
        fut = self.getFuture(vehId,t,dsId)          # get 待测车的 未来轨迹

        # Get track histories of all neighbours 'neighbors' = [ndarray,[],ndarray,ndarray]
        for i in grid:
            neighbors.append(self.getHistory(i.astype(int), t,vehId,dsId))

        # Maneuvers 'lon_enc' = one-hot vector, 'lat_enc = one-hot vector
        lon_enc = np.zeros([2])
        lon_enc[int(self.D[idx, 7] - 1)] = 1
        lat_enc = np.zeros([3])
        lat_enc[int(self.D[idx, 6] - 1)] = 1

        return hist,fut,neighbors,lat_enc,lon_enc

    ## Helper function to get track history
    def getHistory(self,vehId,t,refVehId,dsId):
        if vehId == 0:
            return np.empty([0,2])
        else:
            if self.T.shape[1]<=vehId-1:
                return np.empty([0,2])
            refTrack = self.T[dsId-1][refVehId-1].transpose()
            vehTrack = self.T[dsId-1][vehId-1].transpose()
            refPos = refTrack[np.where(refTrack[:,0]==t)][0,1:3]

            if vehTrack.size==0 or np.argwhere(vehTrack[:, 0] == t).size==0:
                 return np.empty([0,2])
            else:
                stpt = np.maximum(0, np.argwhere(vehTrack[:, 0] == t).item() - self.t_h)
                enpt = np.argwhere(vehTrack[:, 0] == t).item() + 1
                hist = vehTrack[stpt:enpt:self.d_s,1:3]-refPos

            if len(hist) < self.t_h//self.d_s + 1:
                return np.empty([0,2])
            return hist


    ## Helper function to get track future
    def getFuture(self, vehId, t,dsId):
        vehTrack = self.T[dsId-1][vehId-1].transpose()
        refPos = vehTrack[np.where(vehTrack[:, 0] == t)][0, 1:3]
        stpt = np.argwhere(vehTrack[:, 0] == t).item() + self.d_s
        enpt = np.minimum(len(vehTrack), np.argwhere(vehTrack[:, 0] == t).item() + self.t_f + 1)
        fut = vehTrack[stpt:enpt:self.d_s,1:3]-refPos
        return fut

    ## Collate function for dataloader
    def collate_fn(self, samples):

        maxlen = self.t_h//self.d_s + 1
        # Initialize social mask batch:
        pos = [0, 0]
        mask_batch = torch.zeros(len(samples), self.grid_size[1],self.grid_size[0],self.enc_size)
        mask_batch = mask_batch.byte()

        # Initialize history, history lengths, future, output mask, lateral maneuver and longitudinal maneuver batches:
        hist_batch = torch.zeros(maxlen,len(samples),14)    #加入neighbours后变为14
        fut_batch = torch.zeros(self.t_f//self.d_s,len(samples),2)
        op_mask_batch = torch.zeros(self.t_f//self.d_s,len(samples),2)
        lat_enc_batch = torch.zeros(len(samples),3)
        lon_enc_batch = torch.zeros(len(samples), 2)
        

        count = 0
        for sampleId,(hist, fut, nbrs, lat_enc, lon_enc) in enumerate(samples):
            hist = torch.from_numpy(hist)   # （16， 2）
            it = 0
            for id,nbr in enumerate(nbrs):
                
                if it <= 6:
                    if len(nbr)!=0:
                        nbr = torch.from_numpy(nbr)     #（16，2）
                        hist = torch.cat((hist,nbr),1)  # 拼接ego和nbrs
                        it += 1
            if hist.shape[1] < 14:
                p1d = (0,14-hist.shape[1])
                hist = F.pad(hist, p1d, "constant", 0)
                
            # Set up history, future, lateral maneuver and longitudinal maneuver batches:
            for k in range(14):
                hist_batch[0:hist.shape[0],sampleId,k] = hist[:, k]
            fut_batch[0:len(fut), sampleId, 0] = torch.from_numpy(fut[:, 0])
            fut_batch[0:len(fut), sampleId, 1] = torch.from_numpy(fut[:, 1])
            op_mask_batch[0:len(fut),sampleId,:] = 1
            lat_enc_batch[sampleId,:] = torch.from_numpy(lat_enc)
            lon_enc_batch[sampleId, :] = torch.from_numpy(lon_enc)
            
            

            # Set up neighbor, neighbor sequence length, and mask batches:
            for id,nbr in enumerate(nbrs):
                if len(nbr)!=0:
                    
                    pos[0] = id % self.grid_size[0]
                    pos[1] = id // self.grid_size[0]
                    mask_batch[sampleId,pos[1],pos[0],:] = torch.ones(self.enc_size).byte()
                    count+=1

        return hist_batch, mask_batch, lat_enc_batch, lon_enc_batch, fut_batch, op_mask_batch

In [84]:
data_dir_test = '/mnt/e/Northwestern University/Courses/2022winter/pattern recognition/projects/data/ValSet.mat'
tsSet = ngsimDataset(data_dir_test)
tsDataloader = DataLoader(tsSet,batch_size=1,shuffle=False,collate_fn=tsSet.collate_fn)
for i, data in enumerate(tsDataloader):
    hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = data
    if i!=0:
        print(hist[0,0,:])
        print(nbrs[0,0,:])

tensor([ -0.1790, -88.8260])
tensor([-11.7580, -91.3400])
tensor([ -0.2280, -89.3260])
tensor([-11.8100, -91.3390])
tensor([ -0.2800, -89.3260])
tensor([-11.8600, -91.3390])
tensor([ -0.3320, -89.3260])
tensor([-11.9130, -91.3380])
tensor([ -0.3850, -89.3260])
tensor([-11.9660, -91.3390])
tensor([ -0.4360, -89.3240])
tensor([-12.0180, -91.3370])
tensor([ -0.4880, -89.3240])
tensor([-12.0680, -91.3370])
tensor([ -0.5400, -89.3250])
tensor([-12.1210, -91.3390])
tensor([ -0.5920, -89.3250])
tensor([-12.1730, -91.3360])
tensor([ -0.6440, -89.3250])
tensor([-12.2250, -91.3180])
tensor([ -0.6960, -89.3510])
tensor([-12.2770, -91.3330])
tensor([ -0.7480, -89.3330])
tensor([-12.2910, -91.3640])
tensor([ -0.8000, -89.1910])
tensor([-12.3280, -91.4160])
tensor([ -0.8540, -88.9790])
tensor([-12.3950, -91.5110])
tensor([ -0.9490, -88.7970])
tensor([-12.5250, -91.7800])
tensor([ -1.0430, -88.6950])
tensor([-12.6160, -92.2280])
tensor([ -1.1370, -88.7510])
tensor([-12.7080, -92.7640])
tensor([ -1.23

IndexError: index 0 is out of bounds for dimension 1 with size 0

In [85]:
data_dir_test = '/mnt/e/Northwestern University/Courses/2022winter/pattern recognition/projects/data/ValSet.mat'
Set = new_ngsimDataset(data_dir)
Dataloader = DataLoader(Set,batch_size=1,shuffle=False,collate_fn=Set.collate_fn)
for i, data in enumerate(Dataloader):
    hist, mask, lat_enc, lon_enc, fut, op_mask = data
    print(hist[0,0,:])

    #if i == 20:
        #hist_60 = hist
    #if i ==60:
        #hist_160 = hist
    #if i == 161:
        #if torch.equal(hist_60, hist_160):
            #print('wrong')
        #else:
            #print('right')

tensor([ -0.1280, -88.8260,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000])
tensor([ -0.1790, -88.8260, -11.7580, -91.3400,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000])
tensor([ -0.2280, -89.3260, -11.8100, -91.3390,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000])
tensor([ -0.2800, -89.3260, -11.8600, -91.3390,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000])
tensor([ -0.3320, -89.3260, -11.9130, -91.3380,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000])
tensor([ -0.3850, -89.3260, -11.9660, -91.3390,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000])
tensor([ -0.4360, -89.3240, -12.0180, -91.3370,   0.0000, 

KeyboardInterrupt: 