In [1]:
import numpy as np
import pickle
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import weight_norm
import time
import torch.nn.functional as F
import math
from torch.autograd import Variable
import os
import argparse

In [2]:
SEED=5
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
train_datasets = r"/root/data_apolloscape/train_data_119.pkl"
test_datasets = r"/root/data_apolloscape/test_data_119.pkl"
model_save_path = r"/root/data_apolloscape/model_save/"
max_object_num = 115
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
lr = 0.001
batch_size = 32
epochs = 200
dropout = 0.2
in_channels = 6
out_channels = 2
hidden_size = 64
heads = 8
layers = 5
history_frames = 6
future_frames = 6
kernel_size = [2,3]
paddings = [[1,2,2],[2,2,2]]
dilations = [[1,2,2],[1,1,1]]

# kernel_size = [2,5]
# paddings = [[1,2,2],[4,4]]
# dilations = [[1,2,2],[1,1]]

In [4]:
class Feeder(torch.utils.data.Dataset):
    def __init__(self,
                 data_path,
                 data_cache,
                 train_percent=0.8,
                 train_val_test='train'):

        self.data_path = data_path
        self.data_cache = data_cache
        self.train_val_test = train_val_test

        self.load_data()

        total_num = len(self.all_data)
        # equally choose validation set
        train_id_list = list(np.linspace(0, total_num-1, int(total_num*train_percent)).astype(int))
        val_id_list = list(set(list(range(total_num))) - set(train_id_list))

        # # last 20% data as validation set
        if train_val_test.lower() == 'train':
            self.all_data = self.all_data[train_id_list]
        elif train_val_test.lower() == 'val':
            self.all_data = self.all_data[val_id_list]

    def load_data(self):

        with open(self.data_cache, 'rb') as reader:
            [self.all_data] = pickle.load(reader)

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, idx):

        data = self.all_data[idx].copy()

        if self.train_val_test.lower() == 'train':
            th = np.random.random() * np.pi * 2
            data['features'][:, :, 0] = data['features'][:, :, 0] * np.cos(th) - data['features'][:, :, 0] * np.sin(th)
            data['features'][:, :, 1] = data['features'][:, :, 1] * np.sin(th) + data['features'][:, :, 1] * np.cos(th)
            
        if self.train_val_test.lower() == 'test':
            return data['features'],data['masks'],data['origin'],data['distance_adj'],data['heading_adj'],data["mean"]
        else:
            return data['features'],data['masks'],data['distance_adj'],data['heading_adj'],data["mean"]

In [5]:
trainLoader = Feeder(r"/kaggle/input/",train_datasets, 0.8, 'train')
train_loader = DataLoader(dataset=trainLoader, batch_size=batch_size,shuffle=True,num_workers=2)
valLoader = Feeder(r"/kaggle/input/", train_datasets, 0.8, 'val')
val_loader = DataLoader(dataset=valLoader,batch_size=batch_size,shuffle=True,num_workers=2)

In [6]:
def get_lap(adj):
    adj = torch.nan_to_num(adj / torch.sum(adj, dim=-1).unsqueeze(dim=-1), nan=0)
    return adj


# def get_lap(adj):
#     # (64,6,115,115)
#     batch,step,num_object = adj.shape[0], adj.shape[1],adj.shape[2]
#     adj = adj.reshape(batch*step,num_object,num_object) #(64*6,115,115)
#     D = torch.sum(adj, dim=-1)
#     D = D**(-0.5)
#     D[D==torch.inf]=0
#     D = torch.diag_embed(D)
#     lap = torch.matmul(torch.matmul(D,adj),D)
#     lap = lap.reshape(batch,step,num_object,num_object)
#     return lap


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self._linear1 = nn.Linear(d_model, d_ff)
        self._linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self._linear2(F.relu(self._linear1(x)))


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # 创建位置编码矩阵
        pe = torch.zeros(self.max_seq_len, self.d_model).cuda()

        # 计算位置编码的值
        position = torch.arange(0, self.max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model))

        # 调整位置编码矩阵的值
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # 添加一个维度作为可学习的参数
        pe = pe.unsqueeze(0)  # (1,6,h)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # 将位置编码加到输入张量中
        # (32,115,6,h)
        x = x + self.pe[:, :x.shape[2]]
        return x


# class PositionalEncoding(nn.Module):
#     def __init__(self, d_model, max_seq_len, period=6):
#         super(PositionalEncoding, self).__init__()
#         self.d_model = d_model
#         self.max_seq_len = max_seq_len

#         # 创建位置编码矩阵
#         pe = torch.zeros(self.max_seq_len, self.d_model).cuda()

#         pos = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
#         pe = torch.sin(pos * 2 * np.pi / period)
#         pe = pe.repeat((1, d_model))

#         # 添加一个维度作为可学习的参数
#         pe = pe.unsqueeze(0)  # (1,6,h)
#         self.register_buffer('pe', pe)

#     def forward(self, x):
#         # 将位置编码加到输入张量中
#         # (32,115,6,h)
#         x = x + self.pe[:, :x.shape[2]]
#         return x


class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        其实这就是一个裁剪的模块，裁剪多出来的padding
        """
        return x[:, :, :-self.chomp_size].contiguous()


class TCNBlock(nn.Module):
    def __init__(self, hidden_size, padding, dilation, kernel_size, dropout):
        super().__init__()

        self.conv1 = weight_norm(
            nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=(kernel_size, 1),
                      padding=(padding, 0),
                      stride=1,
                      dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(
            nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=(kernel_size, 1),
                      padding=(padding, 0),
                      stride=1,
                      dilation=dilation))
        self.chomp2 = Chomp1d(padding)  # 裁剪掉多出来的padding部分，维持输出时间步为seq_len
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        # self.residual =  nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=(1, 1))

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)

        self.init_weights()

    def init_weights(self):
        """
        参数初始化
        """
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        # self.residual.weight.data.normal_(0, 0.01)

    def forward(self, x):
        # x:(64,T,114,32)
        x = x.transpose(2, 3).transpose(1, 2)  # (64,32,6,114)
        out = self.net(x)  # (64,32,6,114)
        out = out + x
        out = out.transpose(1, 2).transpose(2, 3)
        return F.relu(out)


class TCN(nn.Module):
    def __init__(self, hidden_size, paddings, dilations, kernel_size, dropout):
        super().__init__()
        layers = []
        for i in range(len(dilations)):
            dilation_size = dilations[i]  # 膨胀系数：1，2，2
            padding_size = paddings[i]
            layers += [
                TCNBlock(hidden_size=hidden_size, dilation=dilation_size, padding=padding_size, kernel_size=kernel_size,
                         dropout=dropout)]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        # x:(64,T,114,32)
        return self.network(x)
    

class Matmul(nn.Module):
    def __init__(self):
        super().__init__()
        pass

    def forward(self, x, A):
        x = torch.matmul(A, x)
        return x.contiguous()


class DiffusionConv(nn.Module):
    def __init__(self, c_in, c_out, dropout, support_len=2, order=2):
        super().__init__()
        self.nconv = Matmul()
        c_in = (order * support_len + 1) * c_in
        self.mlp = nn.Linear(c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, dist_adj, heading_adj):
        support = [dist_adj,heading_adj]
        out = [x]
        for a in support:
            x1 = self.nconv(x, a)
            out.append(x1)
            for k in range(2, self.order + 1):
                x2 = self.nconv(x1, a)
                out.append(x2)
                x1 = x2

        h = torch.cat(out, dim=-1)
        h = self.mlp(h)
        h = F.dropout(h, self.dropout, training=self.training)
        return h


class GlobalSpatialMHA(nn.Module):
    def __init__(self, heads, hidden_size):
        super().__init__()
        self.d_v = hidden_size // heads
        self.heads = heads
        self.hidden_size = hidden_size
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.W_Q = nn.Linear(hidden_size, hidden_size)
        self.W_K = nn.Linear(hidden_size, hidden_size)
        self.W_V = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_q, input_k, input_v):
        # x(64,115,h)
        batch, num_object = input_v.shape[0], input_v.shape[1]
        Q = self.W_Q(input_q).reshape(batch, num_object, self.heads, self.d_v).transpose(1, 2)  # (64, heads, 115, d_v)
        K = self.W_K(input_k).reshape(batch, num_object, self.heads, self.d_v).transpose(1, 2)  # (64, heads, 115, d_v)
        V = self.W_V(input_v).reshape(batch, num_object, self.heads, self.d_v).transpose(1, 2)  # (64, heads, 115, d_v)
        attention = torch.matmul(Q, K.transpose(-1, -2)) / (self.d_v ** 0.5)  # (64, heads, 115,115)
        attention = F.softmax(attention, dim=-1)
        context = torch.matmul(attention, V)  # (64, heads, 115,h)
        context = context.transpose(1, 2)
        context = context.reshape(batch, num_object, self.heads * self.d_v)  # (64, 115, heads*d_v)
        context = self.fc(context)  # x:(64,115,h)
        return context


class SpatialMultiHeadAttention(nn.Module):
    def __init__(self, heads, hidden_size, history_frames):
        super().__init__()
        self.d_v = hidden_size // heads
        self.heads = heads
        self.history_frames = history_frames
        self.hidden_size = hidden_size
        self.W_Q = nn.Linear(hidden_size, hidden_size)
        self.W_K = nn.Linear(hidden_size, hidden_size)
        self.W_V = nn.Linear(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_q, input_k, input_v):
        # x:(64,6,115,h)
        batch, num_object = input_q.shape[0], input_q.shape[2]
        Q = self.W_Q(input_q).reshape(batch, self.history_frames, num_object, self.heads, self.d_v).transpose(2, 3)
        K = self.W_K(input_k).reshape(batch, self.history_frames, num_object, self.heads, self.d_v).transpose(2, 3)
        V = self.W_V(input_v).reshape(batch, self.history_frames, num_object, self.heads, self.d_v).transpose(2, 3)
        attention = torch.matmul(Q, K.transpose(-1, -2)) / (self.d_v ** 0.5)  # (64, 6, heads, 115,115)
        attention = F.softmax(attention, dim=-1)
        context = torch.matmul(attention, V)  # (64, 6, heads, 115,h)
        context = context.transpose(2, 3)
        context = context.reshape(batch, self.history_frames, num_object,
                                  self.heads * self.d_v)  # (64, 6, 115, heads*h)
        context = self.fc(context)  # x:(64,6,115,h)
        # context = self.dropout(context)  # x:(64,115,6,h)
        return context


class TemporalMultiHeadAttention(nn.Module):
    def __init__(self, heads, hidden_size, history_frames, dropout):
        super().__init__()
        self.d_v = hidden_size // heads
        self.heads = heads
        self.history_frames = history_frames
        self.hidden_size = hidden_size
        self.W_Q = nn.Linear(hidden_size, hidden_size)
        self.W_K = nn.Linear(hidden_size, hidden_size)
        self.W_V = nn.Linear(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.masks = torch.tril(torch.ones((history_frames, history_frames)), diagonal=0).cuda()
        self.layerNorm1 = nn.LayerNorm(hidden_size)
        self.layerNorm2 = nn.LayerNorm(hidden_size)
        self.feedForward = FeedForward(hidden_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input_q, input_k, input_v):
        residual = input_q
        batch, num_object = input_q.shape[0], input_q.shape[1]
        Q = self.W_Q(input_q).reshape(batch, num_object, self.history_frames, self.heads, self.d_v).transpose(2, 3)
        K = self.W_K(input_k).reshape(batch, num_object, self.history_frames, self.heads, self.d_v).transpose(2, 3)
        V = self.W_V(input_v).reshape(batch, num_object, self.history_frames, self.heads, self.d_v).transpose(2, 3)
        attention = torch.matmul(Q, K.transpose(-1, -2)) / (self.d_v ** 0.5)  # (64, 115, heads, 6, 6)
        context = torch.matmul(F.softmax(attention, dim=-1), V)  # (64, 115, heads, 6,h)
        context = context.transpose(2, 3)
        context = context.reshape(batch, num_object, self.history_frames,
                                  self.heads * self.d_v)  # (64, 115, 6, heads*h)
        context = self.fc(context)  # x:(64,115,6,h)
        context = self.dropout(context)  # x:(64,115,6,h)
        context = self.layerNorm1(context + residual)  # (64,6,114,h)

        context_1 = self.feedForward(context)  # (64,6,114,h)
        context_1 = self.dropout(context_1)  # (64,6,114,h)
        last_out = self.layerNorm2(context_1 + context)  # (64,6,114,h)
        last_out = last_out.transpose(1, 2)
        return last_out


class GlobalSpatial(nn.Module):
    def __init__(self, heads, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.gsa = GlobalSpatialMHA(heads, hidden_size)

    def forward(self, x):
        # x(64,T,114,h)
        batch, step, num_object = x.shape[0], x.shape[1], x.shape[2]
        x = x.transpose(1, 2)
        x = x.reshape(batch * num_object, step, self.hidden_size)  # x(64*115,T,h)
        x,gt = self.gru(x)[1][0],self.gru(x)[0]  # (1,64*115,h)
        x = x.squeeze().reshape(batch, num_object, self.hidden_size)  # (64,115,h)
        x = self.gsa(x, x, x)  # (64,115,h)
        x = torch.unsqueeze(x, dim=1)  # (64,1,115,h)
        gt = gt.reshape(batch, num_object, step, self.hidden_size).transpose(1,2)
        return x,gt # (64,1,115,h)


class TemporalBlock(nn.Module):
    def __init__(self, heads, history_frames, hidden_size, paddings, dilations, kernel_size, dropout):
        super().__init__()
        self.gt = TemporalMultiHeadAttention(heads, hidden_size, history_frames, dropout)
        # self.pe = PositionalEncoding(hidden_size, 32)
        # self.fc_1 = nn.Linear(hidden_size, hidden_size)
        # self.fc_2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        x = x.transpose(1, 2)  # (32,115,6,h)
        # x = self.pe(x)

        Q, K, V = x, x, x
        tmha = self.gt(Q, K, V)
        
        # z = torch.sigmoid(self.fc_1(tmha) + self.fc_2(gt))
        # last_out = z * tmha + (1 - z) * gt
        return tmha


class SpatialBlock(nn.Module):
    def __init__(self, heads, hidden_size, history_frames, dropout):
        super().__init__()
        self.satt = SpatialMultiHeadAttention(heads, hidden_size, history_frames)
        self.dc = DiffusionConv(hidden_size, hidden_size, dropout)
        self.fc_1 = nn.Linear(hidden_size, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, hidden_size)
        self.fc_3 = nn.Linear(hidden_size, hidden_size)
        self.fc_4 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, dist_adj, heading_adj, gs_out):
        # (64,6,3,115,115)
        Q, K, V = x, x, x
        satt_out = self.satt(Q, K, V)
        gcn_out = self.dc(x, dist_adj, heading_adj)
        
        z = torch.sigmoid(self.fc_1(gcn_out) + self.fc_2(satt_out))
        last_out = z * gcn_out + (1 - z) * satt_out

        a = torch.sigmoid(self.fc_3(last_out) + self.fc_4(gs_out))
        last_out = a * last_out + (1 - a) * gs_out
        return last_out


class SpatialTemporal(nn.Module):
    def __init__(self, hidden_size, history_frames, heads, paddings, dilations, kernel_size, dropout):
        super().__init__()
        self.gs = GlobalSpatial(heads, hidden_size)
        self.spatial = SpatialBlock(heads, hidden_size, history_frames, dropout)
        self.temporal = TemporalBlock(heads, history_frames, hidden_size, paddings, dilations, kernel_size, dropout)

    def forward(self, x, dist_adj, heading_adj):
        # x:(64,T,114,h), (64,6,115,115), (64,6,115,115)
        gs_out,gt = self.gs(x)
        x = self.spatial(x, dist_adj, heading_adj, gs_out)
        x = self.temporal(x)
        x = x + gt
        return x


# class Seq2Seq(nn.Module):
#     def __init__(self, hidden_size, out_channels, history_frames, future_frames, max_object_num):
#         super().__init__()
#         self.out_channels = out_channels
#         self.hidden_size = hidden_size
#         self.future_frames = future_frames
#         self.max_object_num = max_object_num
#         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
#         self.gru_cell = nn.GRU(hidden_size + hidden_size, hidden_size, batch_first=True)
#         self.W_e = nn.Linear(hidden_size,hidden_size)
#         self.W_d = nn.Linear(hidden_size,hidden_size)
#         self.W_a = nn.Linear(hidden_size, 1)
#         self.dropout = nn.Dropout(p=0.2)
#         self.fc_2 = nn.Linear(hidden_size, out_channels)
# #         self.reset_parameters()

# #     def reset_parameters(self):
# #         nn.init.xavier_uniform_(self.W_e)
# #         nn.init.xavier_uniform_(self.W_d)
# #         nn.init.xavier_uniform_(self.W_a)

#     def forward(self, h, last_position, teacher_location=None):
#         # (64,6,114,h), (64,114,2)
#         if teacher_location is not None:
#             teacher_location = teacher_location.transpose(1, 2)
#             teacher_location = teacher_location.reshape(-1, teacher_location.shape[2], 2)  # (64*114,6,2)
#         h = h.transpose(1, 2)
#         h = h.reshape(-1, h.shape[2], self.hidden_size)  # (64*114, 6, h)
#         last_position = last_position.reshape(-1, 2).unsqueeze(dim=1)  # (64*114, 1, 2)
#         last_out = torch.zeros((h.shape[0], self.future_frames, 2)).cuda()  # (64*114,6,2)

#         x = torch.zeros((h.shape[0], 1, self.hidden_size)).cuda()  # (64*114,1,h)
#         x = torch.cat([last_position, x], dim=-1)  # (64*114,1,h+2)

#         output, h_t = self.gru(h)  # (64*114,6,h), (1,64*114,h)
#         for step in range(self.future_frames):
#             if step == 0:
#                 new_out, h_t = self.gru_cell(x, h_t)  # (64*114,1,h), (1,64*114,h)
#                 # new_out =   # (64*114, 1, 2)
#                 # new_out = new_out + last_position  # (64*114, 1, 2)
#                 last_out[:, step:step + 1, :] = self.fc_2(new_out)
#             else:
#                 a = F.softmax(self.W_a(torch.tanh(self.W_e(output) + self.W_d(h_t.transpose(0, 1)))),dim=1)  # (64*114,6,1)
#                 c = torch.matmul(output.transpose(1, 2), a).transpose(1, 2)  # (64*114,1,h)
#                 teacher_force = np.random.random() < 0.5
#                 new_out = (teacher_location[:, step - 1:step] if (type(teacher_location) is not type(
#                     None)) and teacher_force else new_out)
#                 new_out, h_t = self.gru_cell(torch.cat([new_out, c], dim=-1), h_t)  # (64*114,1,h), (64*114,1,h)
#                 last_out[:, step:step + 1, :] = self.fc_2(new_out)
#         last_out = last_out.reshape(-1, self.max_object_num, self.future_frames, self.out_channels)
#         last_out = last_out.transpose(1, 2)
#         return last_out


# class Seq2Seq(nn.Module):
#     def __init__(self, hidden_size, out_channels, history_frames, future_frames, max_object_num):
#         super().__init__()
#         self.out_channels = out_channels
#         self.hidden_size = hidden_size
#         self.future_frames = future_frames
#         self.max_object_num = max_object_num
#         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
#         self.gru_cell = nn.GRU(hidden_size + hidden_size, hidden_size, batch_first=True)
#         self.fc = nn.Linear(out_channels,hidden_size)
#         self.W_e = nn.Linear(hidden_size,hidden_size)
#         self.W_d = nn.Linear(hidden_size,hidden_size)
#         self.W_a = nn.Linear(hidden_size, 1)
#         self.dropout = nn.Dropout(p=0.2)
#         self.fc_2 = nn.Linear(hidden_size, out_channels)

#     def forward(self, h, last_position, teacher_location=None):
#         # (64,6,114,h), (64,114,2)
#         if teacher_location is not None:
#             teacher_location = teacher_location.transpose(1, 2)
#             teacher_location = teacher_location.reshape(-1, teacher_location.shape[2], 2)  # (64*114,6,2)
#             teacher_location = self.fc(teacher_location)
#         h = h.transpose(1, 2)
#         h = h.reshape(-1, h.shape[2], self.hidden_size)  # (64*114, 6, h)
#         last_position = last_position.reshape(-1, 2).unsqueeze(dim=1)  # (64*114, 1, 2)
#         last_position = self.fc(last_position)
#         last_out = torch.zeros((h.shape[0], self.future_frames, 2)).cuda()  # (64*114,6,2)

#         x = torch.zeros((h.shape[0], 1, self.hidden_size)).cuda()  # (64*114,1,h)
#         x = torch.cat([last_position, x], dim=-1)  # (64*114,1,h+2)

#         output, h_t = self.gru(h)  # (64*114,6,h), (1,64*114,h)
#         for step in range(self.future_frames):
#             if step == 0:
#                 new_out, h_t = self.gru_cell(x, h_t)  # (64*114,1,h), (1,64*114,h)
#                 last_out[:, step:step + 1, :] = self.fc_2(new_out)
#             else:
#                 a = F.softmax(self.W_a(torch.tanh(self.W_e(output) + self.W_d(h_t.transpose(0, 1)))),dim=1)  # (64*114,6,1)
#                 c = torch.matmul(output.transpose(1, 2), a).transpose(1, 2)  # (64*114,1,h)
#                 teacher_force = np.random.random() < 0.5
#                 new_out = (teacher_location[:, step - 1:step] if (type(teacher_location) is not type(
#                     None)) and teacher_force else new_out)
#                 new_out, h_t = self.gru_cell(torch.cat([new_out, c], dim=-1), h_t)  # (64*114,1,h), (64*114,1,h)
#                 last_out[:, step:step + 1, :] = self.fc_2(new_out)
#         last_out = last_out.reshape(-1, self.max_object_num, self.future_frames, self.out_channels)
#         last_out = last_out.transpose(1, 2)
#         return last_out
    

class Model(nn.Module):
    def __init__(self, in_channels, out_channels, heads, hidden_size, layers, history_frames, max_object_num, paddings,
                 dilations, kernel_size, dropout):
        super().__init__()
        self.embed = nn.Linear(in_channels, hidden_size)
        self.layers = layers
        self.hidden_size = hidden_size
        self.history_frames = history_frames
        self.max_object_num = max_object_num
        self.st_block = SpatialTemporal(hidden_size, history_frames, heads, paddings, dilations, kernel_size, dropout)
        self.tcn_1 = TCN(hidden_size, paddings[0], dilations[0], kernel_size[0], dropout)
        self.tcn_2 = TCN(hidden_size, paddings[1], dilations[1], kernel_size[1], dropout)
        self.fc_1 = nn.Linear(hidden_size*2, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, out_channels)
        # self.seq2seq = Seq2Seq(hidden_size,out_channels,history_frames,history_frames,max_object_num)

    def forward(self, x, dist_adj, heading_adj):
        # last_position = x[:,-1,:,:2]
        x = self.embed(x)  # (32,6,115,h)
        residual = x  # (64,T,115,h)

        dist_adj = get_lap(dist_adj)
        heading_adj = get_lap(heading_adj)
     
        for layer in range(self.layers):
            x = self.st_block(x, dist_adj, heading_adj) + residual # (64,6,115,h)
            residual = x
        # last_out = self.seq2seq(x,last_position,teacher_location)
        # last_out = self.tcn_1(x)
        last_out = torch.cat([self.tcn_1(x),self.tcn_2(x)],dim=-1)
        last_out = self.fc_1(last_out)
        last_out = self.fc_2(last_out)
        return last_out

In [7]:
model = Model(in_channels, out_channels, heads, hidden_size, layers, history_frames, max_object_num, paddings, dilations, kernel_size, dropout)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, threshold=0.01,threshold_mode="abs",verbose=False)

In [8]:
def inverse_transform(prediction, last_position):
    # (64,6,115,2),(64,115,2)
    for step in range(prediction.shape[1]):
        prediction[:,step,:,:] = prediction[:,step,:,:] + last_position
        last_position = prediction[:,step,:,:]
    return prediction


def eva(predicted, ground_truth, future_masks):
    #(64,6,114,2),(64,6,114,8),(64,6,114,1)
    predicted = predicted.transpose(2,3).transpose(1,2)  # (64,2,6,114)
    ground_truth = ground_truth.transpose(2,3).transpose(1,2)  # (64,8,6,114)
    future_masks = future_masks.transpose(2,3).transpose(1,2)  # (64,1,6,114)
    
    category_mask = ground_truth[:, 2:3, :, :]  # (N, C, T, V)=(64, 1, 6, 114)
    
    ### overall dist
    overall_sum_time, overall_num = compute_RMSE(predicted, ground_truth[:,-2:,:,:], future_masks)

    ### car dist
    car_mask = (((category_mask == 1) + (category_mask == 2)) > 0).float().to(device)
    car_mask = future_masks * car_mask
    car_sum_time, car_num = compute_RMSE(predicted, ground_truth[:,-2:,:,:], car_mask)

    ### human dist
    human_mask = (category_mask == 3).float().to(device)
    human_mask = future_masks * human_mask
    human_sum_time, human_num = compute_RMSE(predicted, ground_truth[:,-2:,:,:], human_mask)

    ### bike dist
    bike_mask = (category_mask == 4).float().to(device)
    bike_mask = future_masks * bike_mask
    bike_sum_time, bike_num = compute_RMSE(predicted, ground_truth[:,-2:,:,:], bike_mask)
    
    return overall_num,overall_sum_time, car_num,car_sum_time, human_num,human_sum_time,bike_num,bike_sum_time


def compute_RMSE(predicted, ground_truth, masks, error_order=2):
    predicted = predicted * masks  # (N, C, T, V)=(N, 2, 6, 114)
    ground_truth = ground_truth * masks  # (N, C, T, V)=(N, 2, 6, 114)

    x2y2 = torch.sum(torch.abs(predicted - ground_truth) ** error_order, dim=1)  # x^2+y^2, (N, C, T, V)->(N, T, V)=(64, 6, 114)
    total_sum_time = x2y2.sum(dim=-1)  # (N, T, V) -> (N, T)=(64, 6)
    total_mask = masks.sum(dim=1).sum(dim=-1)  # (N, C, T, V) -> (N, T)=(N, 6)

    return total_sum_time.detach().cpu().numpy(), total_mask.detach().cpu().numpy()


def display_result(pra_results, pra_pref='Train_epoch'):
    # all_overall_sum_list,all_overall_num_list:(num_batch*batch_size,6)
    all_overall_sum_list, all_overall_num_list = pra_results
    overall_sum_time = np.sum(all_overall_sum_list ** 0.5, axis=0)
    overall_num_time = np.sum(all_overall_num_list, axis=0)
    overall_loss_time = (overall_sum_time / overall_num_time)
    return overall_loss_time


def show_result(result_car,result_human,result_bike,stage="val"):
    result = 0.20 * result_car + 0.58 * result_human + 0.22 * result_bike
    WSADE = np.sum(result)/6
    ADE_v = np.sum(result_car)/6
    ADE_p = np.sum(result_human)/6
    ADE_b = np.sum(result_bike)/6
    
    WSFDE = result[-1]
    FDE_v = result_car[-1]
    FDE_p = result_human[-1]
    FDE_b = result_bike[-1]
    
    if stage=="val":
        log = 'val ADEv: {:.4f}, val ADEp: {:.4f}, val ADEb: {:.4f}, val WSADE: {:.4f}, val FDEv: {:.4f}, val FDEp: {:.4f}, val FDEb: {:.4f},val WSFDE: {:.4f}'
        print(log.format(ADE_v, ADE_p, ADE_b, WSADE,FDE_v, FDE_p, FDE_b, WSFDE), flush=True)   
    
    return WSADE,WSFDE

In [9]:
train_time = []
best_wsade = []
best_epoch = []
best_wsfde = []


for epoch in range(1, epochs + 1):
    print("Train start")
    print("Epoch:", epoch)
    for param_group in optimizer.param_groups:
        print(f'Learning Rate: {param_group["lr"]}')
        
    model.train()
    epoch_start_time = time.perf_counter()
    for i, batch_data in enumerate(train_loader):
        features,masks,distance_adj,heading_adj,mean = batch_data
        masks = masks.to(device)
        distance_adj = distance_adj.to(device)
        heading_adj = heading_adj.to(device)
        # category_adj = category_adj.to(device)
        features_x = features[:, :history_frames, :, :6].to(device) # (64,6,114,6)
        features_y = features[:, history_frames:, :, :].to(device)  # (64,6,114,8)
        future_masks = masks[:, history_frames:, :, :]  # (64,6,114,1)
        
        prediction = model(features_x,distance_adj,heading_adj)  # (64,6,114,2)
        
        loss = torch.sum(torch.abs(prediction*future_masks - features_y[:,:,:,:2]*future_masks))/torch.sum(future_masks)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 40 == 0:
            log = 'Train Iter: {:03d}, Train Loss: {:.4f}'
            print(log.format(i, loss.item()), flush=True)
    
    # val
    all_overall_sum_list = []
    all_overall_num_list = []
    all_car_sum_list = []
    all_car_num_list = []
    all_human_sum_list = []
    all_human_num_list = []
    all_bike_sum_list = []
    all_bike_num_list = []
    model.eval()
    for i, batch_data in enumerate(val_loader):
        features,masks,distance_adj,heading_adj,mean = batch_data
        masks = masks.to(device)
        distance_adj = distance_adj.to(device)
        heading_adj = heading_adj.to(device)
        # category_adj = category_adj.to(device)
        features_x = features[:, :history_frames, :, :6].to(device)  # (64,6,114,6)
        features_y = features[:, history_frames:, :, :].to(device)  # (64,6,114,8)
        
        future_masks = masks[:, history_frames:, :, :]  # (64,6,114,1)
        last_position = features[:,history_frames - 1,:,-2:].to(device)
        
        prediction = model(features_x,distance_adj,heading_adj)  # (64,6,114,2)
        prediction = inverse_transform(prediction, last_position)
        
        a,b,c,d,e,f,g,h = eva(prediction,features_y,future_masks)
        all_overall_num_list.extend(a)
        all_overall_sum_list.extend(b)
        all_car_num_list.extend(c)
        all_car_sum_list.extend(d)
        all_human_num_list.extend(e)
        all_human_sum_list.extend(f)
        all_bike_num_list.extend(g)
        all_bike_sum_list.extend(h)
        
    result_car = display_result([np.array(all_car_sum_list), np.array(all_car_num_list)], pra_pref='car')
    result_human = display_result([np.array(all_human_sum_list), np.array(all_human_num_list)], pra_pref='human')
    result_bike = display_result([np.array(all_bike_sum_list), np.array(all_bike_num_list)], pra_pref='bike')
    WSADE,WSFDE = show_result(result_car,result_human,result_bike)
    torch.save(model.state_dict(), model_save_path + "epoch_" + str(epoch) + "_" + str(round(WSADE, 4)) + ".pth")
    best_wsade.append(WSADE)
    best_wsfde.append(WSFDE)
    best_epoch.append(epoch)
    epoch_end_time = time.perf_counter()
    print("epoch spend time: %.4f" %(epoch_end_time-epoch_start_time))
    print("-----------------------------------------------------------")
    scheduler.step(WSADE)

Train start
Epoch: 1
Learning Rate: 0.001
Train Iter: 000, Train Loss: 2.9891
Train Iter: 040, Train Loss: 1.7460
Train Iter: 080, Train Loss: 1.3850
Train Iter: 120, Train Loss: 1.2637
val ADEv: 1.9133, val ADEp: 0.6939, val ADEb: 2.0662, val WSADE: 1.2397, val FDEv: 3.2484, val FDEp: 1.1793, val FDEb: 3.3169,val WSFDE: 2.0634
epoch spend time: 31.7196
-----------------------------------------------------------
Train start
Epoch: 2
Learning Rate: 0.001
Train Iter: 000, Train Loss: 1.0143
Train Iter: 040, Train Loss: 1.0160
Train Iter: 080, Train Loss: 1.1911
Train Iter: 120, Train Loss: 1.1380
val ADEv: 1.7318, val ADEp: 0.6653, val ADEb: 1.8684, val WSADE: 1.1433, val FDEv: 2.9653, val FDEp: 1.1607, val FDEb: 3.0358,val WSFDE: 1.9341
epoch spend time: 31.3398
-----------------------------------------------------------
Train start
Epoch: 3
Learning Rate: 0.001
Train Iter: 000, Train Loss: 1.0089
Train Iter: 040, Train Loss: 1.0776
Train Iter: 080, Train Loss: 0.9674
Train Iter: 120, T

In [10]:
testLoader = Feeder(r"/content/drive/MyDrive/data/Apolloscape/", test_datasets, 0.8, 'test')
test_loader = torch.utils.data.DataLoader(dataset=testLoader,batch_size=1,shuffle=False,num_workers=2)

In [11]:
bestid = np.argmin(best_wsade) # index

In [12]:
best_wsade[bestid]

0.7371586163838705

In [13]:
best_epoch[bestid]

200

In [14]:
model.load_state_dict(torch.load(model_save_path + "epoch_" + str(best_epoch[bestid]) + "_" + str(round(best_wsade[bestid], 4)) + ".pth"))

<All keys matched successfully>

In [15]:
# model.load_state_dict(torch.load(r"/root/data_apolloscape/epoch_179_0.7353.pth"))

In [16]:
def inverse_transform_1(data, last_position,mean_xy):
    # (64,6,115,2),(64,115,2),(64,2)
    mean_xy = mean_xy.unsqueeze(dim=1)
    last_position = last_position + mean_xy
    for step in range(data.shape[1]):
        data[:,step,:,:] = data[:,step,:,:] + last_position
        last_position = data[:,step,:,:]
    return data


def save_result1(prediction, origin,text):
    # prediction:(1,6,115,2), origin:(1,1,115,3)
    with open("/root/data_apolloscape/prediction_result/prediction_result/prediction_result_"+text+".txt", 'a') as writer:
        for step in range(prediction.shape[1]):
            idx = torch.where(origin[0,0,:,0]!=0)[0]
            step_info = prediction[0,step][idx]
            front = origin[0,0][idx]
            front[:,0] = front[:,0] + step + 1
            all_info = torch.cat([front,step_info],dim=1)
            for i in range(all_info.shape[0]):
                a = str(int(all_info[i,0])) + " " + str(int(all_info[i,1])) + " " + str(int(all_info[i,2]))+ " " + str(float(all_info[i,3]))+ " " + str(float(all_info[i,4]))+ "\n"
                writer.write(a)

In [17]:
text = "no_pe"

In [18]:
for i, batch_data in enumerate(test_loader):
    features, masks,origin,distance_adj,heading_adj,mean_xy = batch_data
    masks = masks.to(device)
    features = features.to(device)
    mean_xy = mean_xy.to(device)  # (64,6,114,1)
    origin = origin.to(device)
    heading_adj = heading_adj.to(device)  # (64,6,115,115)
    distance_adj = distance_adj.to(device)  # (64,6,115,115)
    # category_adj = category_adj.to(device)
    last_position = features[:,history_frames - 1,:,-2:]
    with torch.no_grad():
        prediction = model(features[:,:,:,:6],distance_adj,heading_adj)  # (64,6,114,2)
    prediction = inverse_transform_1(prediction, last_position,mean_xy)
    save_result1(prediction, origin,text)

In [19]:
# layer=2:269
# layer=4:270

In [20]:
# heading_adj_ = heading_adj[0,-1,:15,:15]
# #heading_adj_ = heading_adj_[:,[0,1,2,3,6,7,9,11,12,13]]

In [21]:
# distance_adj_ = distance_adj[0,-1,[0,1,2,3,6,7,9,11,12,13]]
# #distance_adj_ = distance_adj_[:,[0,1,2,3,6,7,9,11,12,13]]

In [22]:
# aaa = (gs_attention[0] + gs_attention[1] + gs_attention[2] + gs_attention[3] + gs_attention[4] + gs_attention[5])/6

In [23]:
# gs_attention_ = aaa[0,-1,:,:14,:14]
# #gs_attention_ = gs_attention_[:,:,[0,1,2,3,6,7,9,11,12,13]]
# gs_attention_ = torch.sum(gs_attention_,dim=0)/8

In [24]:
# weight_matrix = np.array(gs_attention_.cpu())

# # 绘制热力图
# plt.imshow(weight_matrix, cmap='Reds', interpolation='nearest')

# # 添加颜色条
# plt.colorbar()
# # for i in range(11):
# #     for j in range(11):
# #         plt.text(j, i, f'{weight_matrix[i, j]:.4f}', ha='center', va='center', color='black',fontsize=6)
# # 显示图形
# plt.title('GS_Attention Weight Matrix')
# plt.show()


In [25]:
# fps_attention_ = fps_attention[0,:,:15,:15]
# # fps_attention_ = fps_attention_[:,:,[0,1,2,3,6,7,9,11,12,13]]
# fps_attention_ = torch.sum(fps_attention_,dim=0)/8

In [26]:
# import matplotlib.pyplot as plt

In [27]:
# weight_matrix = np.array(fps_attention_.cpu())

# # 绘制热力图
# plt.imshow(weight_matrix, cmap='Reds', interpolation='nearest')

# # 添加颜色条
# plt.colorbar()
# # for i in range(11):
# #     for j in range(11):
# #         plt.text(j, i, f'{weight_matrix[i, j]:.4f}', ha='center', va='center', color='black',fontsize=6)
# # 显示图形
# plt.title('heading Weight Matrix')
# plt.show()


In [28]:
# weight_matrix = np.array(distance_adj_.cpu())

# # 绘制热力图
# plt.imshow(weight_matrix, cmap='Reds', interpolation='nearest')

# # 添加颜色条
# plt.colorbar()
# # for i in range(11):
# #     for j in range(11):
# #         plt.text(j, i, f'{weight_matrix[i, j]:.4f}', ha='center', va='center', color='black',fontsize=6)
# # 显示图形
# plt.title('distance Weight Matrix')
# plt.show()


In [29]:
# weight_matrix = np.array(gs_attention_.cpu())

# # 绘制热力图
# plt.imshow(weight_matrix, cmap='Reds', interpolation='nearest')

# # 添加颜色条
# plt.colorbar()
# # for i in range(11):
# #     for j in range(11):
# #         plt.text(j, i, f'{weight_matrix[i, j]:.4f}', ha='center', va='center', color='black',fontsize=6)
# # 显示图形
# plt.title('GS_Attention Weight Matrix')
# plt.show()


In [30]:
# weight_matrix = np.array(fps_attention_.cpu())

# # 绘制热力图
# plt.imshow(weight_matrix, cmap='Reds', interpolation='nearest')

# # 添加颜色条
# plt.colorbar()
# # for i in range(11):
# #     for j in range(11):
# #         plt.text(j, i, f'{weight_matrix[i, j]:.4f}', ha='center', va='center', color='black',fontsize=6)
# # 显示图形
# plt.title('FPS_Attention Weight Matrix')
# plt.show()
