In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from expm.pytorch_expm.expm_taylor import expm_taylor
from torch.distributions.multivariate_normal import MultivariateNormal

from datetime import datetime

import glob
from dataset import NusceneDataset

import os

import matplotlib.pyplot as plt
from visualizer import plot_to_image, plot_single_batch, plot_predictions_single_batch

# notebook
from IPython.display import clear_output
from tqdm import tqdm

In [2]:
# Make float64 the default floating-point dtype for tensors and module parameters
# created from here on.
torch.set_default_dtype(torch.double)

In [3]:
# Alternative setting: expose three GPUs to this process.
# os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2'
# Expose only physical GPU 1 to this process.
# NOTE(review): this runs after `import torch`; presumably it still takes effect
# because no CUDA call has been made yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# When True, the cells below load model/optimizer state from an existing checkpoint
# under ./log/ and resume training; when False, training starts from scratch.
LOAD_TRAINED = False

In [4]:
# Show the current GPU status (IPython shell escape).
!nvidia-smi

Wed Jun 24 14:26:52 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.59       Driver Version: 440.59       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  TITAN RTX           Off  | 00000000:19:00.0 Off |                  N/A |
|  0%   33C    P8    32W / 280W |    175MiB / 24220MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN RTX           Off  | 00000000:1A:00.0 Off |                  N/A |
|  0%   30C    P8     6W / 280W |      1MiB / 24220MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  TITAN RTX           Off  | 00000000:67:00.0 Off |                  N/A |
|  0%   

In [5]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print ('Available devices ', torch.cuda.device_count())
# print ('Current cuda device ', torch.cuda.current_device())
# print(torch.cuda.get_device_name(device))

# # Change the GPU allocation
# GPU_NUM = 0 # enter the index of the GPU you want to use
# device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
# torch.cuda.set_device(device) # change allocation of current GPU
# print ('Current cuda device ', torch.cuda.current_device()) # check

# #Additional Infos
# if device.type == 'cuda':
#     print(torch.cuda.get_device_name(GPU_NUM))
#     print('Memory Usage:')
#     print('Allocated:', round(torch.cuda.memory_allocated(GPU_NUM)/1024**3,1), 'GB')
#     print('Cached:   ', round(torch.cuda.memory_cached(GPU_NUM)/1024**3,1), 'GB')

In [6]:
# Seed every random number generator used in this notebook with the same value.
np.random.seed(8324)
torch.manual_seed(8324)
torch.cuda.manual_seed(8324)

# Problem sizes shared by the cells below.
B = 8   # training mini-batch size
A = 5   # agents per scene (player + others)
T = 20  # number of predicted future steps
D = 2   # presumably the coordinate dimension — not referenced in the visible cells

# Serialized training samples: keep at most `max_data` files, in sorted order.
dataset_path = "/home/mmc-server2/data/serialized_nuscenes"
max_data = 1024
train_dills = np.asarray(
    sorted(glob.glob(os.path.join(dataset_path, "train", "sdata*")))
)[:max_data]

train_dataset = NusceneDataset(train_dills, max_A=5)
# Shuffled mini-batch loader used by the optimisation loop.
dataloader = DataLoader(
    train_dataset,
    batch_size=B,
    shuffle=True,
    num_workers=4,
)

In [7]:
# Second loader over the same training set, used only for evaluation:
# fixed order and large batches.
train_test_dataloader = DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=False,
    num_workers=4,
)

# Simple network

In [8]:
class SimpleNet(nn.Module):
    """Recurrent generator that rolls out joint future positions for all agents.

    A forward pass has two stages:

    1. *Static*: every agent's past track is encoded with a shared GRU, and the
       overhead feature map is passed through a small CNN.
    2. *Dynamic*: for each future step and each agent, a GRU cell consumes the
       agent's past encoding, a social feature built from displacements to the
       other agents, all current positions and CNN features sampled at the
       current positions. An MLP head outputs a position offset ``m`` and a
       2x2 matrix ``zeta``; ``sigma = expm(zeta + zeta^T)`` and the next
       position is ``2 * current - previous + m + sigma @ z``.

    The layer sizes (``Linear(8, ...)`` and the ``10 + 12*5`` part of the GRU
    cell input) fix the number of agents to 5.

    Args:
        use_cuda: if True, all sub-modules are moved to the current CUDA device
            on construction.
    """

    def __init__(self, use_cuda=True):
        super(SimpleNet, self).__init__()
        # --- static encoders (evaluated once per forward pass) ---
        # Shared encoder for each agent's past 2-D track.
        self.past_rnn = nn.GRU(2, hidden_size=128, num_layers=1, bias=True, batch_first=True)
        # Input: flattened displacements to the 4 other agents (4 * 2 = 8).
        self.social_mlp = nn.Sequential(
            nn.Linear(8, 128),
            nn.Tanh(),
            nn.Linear(128, 100)
        )

        # 5 input channels -> 12 feature channels; 3x3 / stride 1 / padding 1
        # keeps the spatial size unchanged.
        self.cnn = nn.Sequential(
            nn.Conv2d(5, 32, 3, 1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 3, 1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 3, 1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 3, 1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 12, 3, 1, padding=1),
            nn.ReLU(True)
        )

        # --- dynamic loop ---
        # Input = social_mlp (100) + past encoding (128 * 2) + all current
        # positions (A * 2 = 10) + sampled map features (12 channels * A = 60).
        self.future_rnn = nn.GRUCell(100+256+10+12*5, 100)
        # Output: 2 values for the offset m and 4 values for the 2x2 matrix zeta.
        self.future_mlp = nn.Sequential(
            nn.Linear(100, 256),
            nn.Tanh(),
            nn.Linear(256, 6)
        )

        if use_cuda:
            self.past_rnn.cuda()
            self.cnn.cuda()
            self.social_mlp.cuda()
            self.future_rnn.cuda()
            self.future_mlp.cuda()

    def forward(self, player_past, other_pasts, overhead_sdt_features, z):
        """Roll out future positions for every agent.

        Args:
            player_past: (B, T_past, 2) past positions of the player agent.
            other_pasts: (B, A-1, T_past, 2) past positions of the other agents.
            overhead_sdt_features: (B, 5, H, W) input of the CNN.
            z: noise tensor indexed as ``z[:, a, t]`` with each slice of shape
                (B, 2, 1). Its third dimension sets the number of roll-out
                steps (previously hard-coded to 20).

        Returns:
            Tuple ``(prediction_mus, prediction_covs)`` with shapes
            (B, A, T, 2) and (B, A, T, 2, 2): the rolled-out positions and the
            per-step ``sigma`` matrices.
        """
        # Stack player and others along the agent axis: (B, A, T_past, 2).
        pasts = torch.cat((torch.unsqueeze(player_past, dim=1), other_pasts), dim=1)
        B, A, _, _ = pasts.shape
        T = z.shape[2]

        # ---- forward 1: static ----
        # Last GRU hidden state per agent, each of shape (1, B, hidden).
        last_hiddens = [self.past_rnn(pasts[:, a])[1] for a in range(A)]

        # Sum over agents is the same for every agent, so compute it once.
        hidden_sum = torch.sum(torch.cat(last_hiddens, dim=0), dim=0, keepdim=True)  # (1, B, hidden)

        # Per agent: [own encoding, sum of the other agents' encodings] -> (B, hidden*2).
        alpha2_list = [
            torch.squeeze(torch.cat((last_hidden, hidden_sum - last_hidden), dim=2), dim=0)
            for last_hidden in last_hiddens
        ]

        # cnn
        map_features = self.cnn(overhead_sdt_features)  # (B, C=12, H, W)

        # ---- forward 2: dynamic loop ----
        future_rnn_hs = [None] * A  # one GRUCell hidden state per agent

        currents = pasts[..., -1, :]  # (B, A, 2)
        previous = pasts[..., -2, :]  # (B, A, 2)

        # Constants mapping positions to normalized grid coordinates.
        # NOTE(review): H is fixed at 200 regardless of the actual feature-map
        # size, and with align_corners=True the usual divisor would be H - 1 —
        # confirm this is intended.
        PIXELS_PER_METER = 2.0
        H = 200

        nexts_list = []
        sigmas_list = []

        for t in range(T):
            a_next_list = []
            a_sigma_list = []

            # The way the social map feature is built is poor: agent features are
            # concatenated in agent order, so reordering the agents changes the
            # network output.
            currents_in_grid = H // 2 + currents * PIXELS_PER_METER  # (B, A, 2)
            currents_in_grid = currents_in_grid / H  # normalize to between 0 and 1
            currents_in_grid = currents_in_grid * 2 - 1.0  # normalize to between -1 and 1

            currents_in_grid = currents_in_grid.unsqueeze(2)  # (B, A, 1, 2)
            interp_out = nn.functional.grid_sample(map_features, currents_in_grid, align_corners=True)  # (B, C=12, A, 1)
            social_map_features = torch.squeeze(interp_out, dim=-1)  # (B, C=12, A)
            social_map_features = torch.flatten(social_map_features, start_dim=1, end_dim=2)  # (B, C*A)

            # Identical for every agent within one step, so flatten once per step.
            currents_flatten = torch.flatten(currents, start_dim=1)  # (B, A*2)

            for a in range(A):
                a_current = currents[:, a, :]
                a_previous = previous[:, a, :]

                # Displacements from agent `a` to every other agent.
                displacements = torch.stack(
                    [a_current - currents[:, j, :] for j in range(A) if j != a], dim=1
                )  # (B, A-1, 2)
                social_feature = self.social_mlp(torch.flatten(displacements, start_dim=1))

                joint_feature = torch.cat(
                    (alpha2_list[a], social_feature, currents_flatten, social_map_features), dim=1
                )

                h = self.future_rnn(joint_feature, future_rnn_hs[a])
                future_rnn_hs[a] = h

                out = self.future_mlp(h)
                m_at, zeta_at = out[:, :2], out[:, 2:6]
                zeta_at = torch.reshape(zeta_at, (-1, 2, 2))
                # Matrix exponential of the symmetrized zeta.
                sigma_at = expm_taylor(zeta_at + torch.transpose(zeta_at, 1, 2))

                z_at = z[:, a, t]  # (B, 2, 1)

                sigma_prod_z_at = torch.squeeze(torch.matmul(sigma_at, z_at), dim=2)  # (B, 2)
                # Constant-velocity extrapolation plus learned offset plus scaled noise.
                a_next = 2 * a_current - a_previous + m_at + sigma_prod_z_at

                a_next_list.append(a_next)
                a_sigma_list.append(sigma_at)

            # preparing next t
            nexts = torch.stack(a_next_list, dim=1)    # (B, A, 2)
            sigmas = torch.stack(a_sigma_list, dim=1)  # (B, A, 2, 2)
            nexts_list.append(nexts)
            sigmas_list.append(sigmas)

            previous = currents
            currents = nexts

        # q_distribution parameters
        prediction_mus = torch.stack(nexts_list, dim=2)    # (B, A, T, 2)
        prediction_covs = torch.stack(sigmas_list, dim=2)  # (B, A, T, 2, 2)

        return prediction_mus, prediction_covs

# Summary Helper

In [9]:
def get_model_summaries(model):
    """Collect parameter and gradient vectors per top-level sub-module.

    Parameters are grouped by the first component of their dotted name. Within
    a group, parameters whose last name component contains ``'weight'`` go to
    the ``weights`` bucket, those containing ``'bias'`` to the ``biases``
    bucket; any other parameter is ignored.

    Args:
        model: module exposing ``named_parameters()``.

    Returns:
        Tuple ``(model_summary, scalar_model_summary)``:
        ``model_summary`` maps ``'model/<module>/{value,grad}/{weights,biases}'``
        to a list of flattened, detached tensors;
        ``scalar_model_summary`` maps the same keys to the norm of the
        concatenated list, or a zero scalar when the list is empty.

    Parameters without a gradient (``.grad is None``, e.g. before the first
    backward pass) are left out of the ``grad`` buckets instead of raising.
    """
    model_summary = dict()
    for param_name, param_val in model.named_parameters():
        names = param_name.split('.')
        prefix = 'model/' + names[0]

        # Always create all four buckets for a module, in a fixed order.
        for bucket in ('/value/weights', '/value/biases', '/grad/weights', '/grad/biases'):
            model_summary.setdefault(prefix + bucket, [])

        if 'weight' in names[-1]:
            kind = 'weights'
        elif 'bias' in names[-1]:
            kind = 'biases'
        else:
            continue

        # detach(): summaries are for logging only and must not extend the autograd graph.
        model_summary[prefix + '/value/' + kind].append(param_val.detach().reshape(-1))
        if param_val.grad is not None:
            model_summary[prefix + '/grad/' + kind].append(param_val.grad.detach().reshape(-1))

    scalar_model_summary = dict()
    for key, val in model_summary.items():
        # torch.cat raises on an empty list (e.g. a module without biases).
        scalar_model_summary[key] = torch.norm(torch.cat(val)) if val else torch.zeros(())
    return model_summary, scalar_model_summary

# Train

In [10]:
if LOAD_TRAINED:
    # Checkpoint dict written by the training cell (see the torch.save call there).
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    trained = torch.load("./log/trainset_all_b8_wobn_2020-06-21-15-17/model.pt")
    # A bare expression inside an `if` block is not displayed by the notebook,
    # so print the keys explicitly.
    print(list(trained.keys()))

In [11]:
# Build the network and, when resuming, restore its weights from the checkpoint.
model = SimpleNet()
if LOAD_TRAINED:
    model.load_state_dict(trained['model_state_dict'])
model.cuda()

optimizer = optim.Adam(model.parameters(), lr=3e-4)

if LOAD_TRAINED:
    optimizer.load_state_dict(trained['optimizer_state_dict'])
    # The learning rate lives in the optimizer's param groups; assigning
    # `optimizer.lr` only creates an unused attribute and leaves the rate unchanged.
    for param_group in optimizer.param_groups:
        param_group['lr'] = 3e-5
    # Keep logging into the run directory the checkpoint came from.
    logdir = "./log/trainset_all_b8_wobn_2020-06-21-15-17"
else:
    logdir = f"./log/trainset_all_b8_wobn_{datetime.now():%Y-%m-%d-%H-%M}"

writer = SummaryWriter(logdir)

In [12]:
def write_model_summary(model, writer, global_step=None):
    """Log the per-module weight/bias value and gradient norms as scalars.

    Args:
        model: model passed to ``get_model_summaries``.
        writer: object with an ``add_scalar(tag, value, step)`` method.
        global_step: step to log at. When omitted, the notebook-level variable
            ``global_iterations`` is used, which keeps existing two-argument
            calls working.
    """
    if global_step is None:
        global_step = global_iterations

    _, scalar_model_summaries = get_model_summaries(model)
    for key, val in scalar_model_summaries.items():
        writer.add_scalar(key, val, global_step)

In [13]:
def evaluate_dataset(model, dataloader):
    """Evaluate the model on every batch of ``dataloader`` without gradients.

    For each batch the function computes:

    * ``loss``: negative mean log-probability, under the model distribution q,
      of 12 samples drawn from a narrow Gaussian (covariance 0.01 * I) centred
      on the ground-truth futures;
    * ``loss_pos``: ``loss`` minus the same quantity evaluated under that
      narrow Gaussian itself;
    * ``minADE`` / ``minFDE``: average / final displacement error of 12 samples
      from q, minimised over samples, then averaged over agents and batch;
    * ``miss_rate_2``: fraction of all sampled positions farther than 2.0 from
      the ground truth.

    Batch, agent and step counts are read from the ground-truth futures of each
    batch rather than from notebook globals. The model's train/eval mode is
    restored on exit.

    Args:
        model: callable ``model(player_past, other_pasts, overhead_sdt_features, z)``
            returning ``(mus, covs)``.
        dataloader: iterable of dicts with keys ``'player_past'``,
            ``'other_pasts'``, ``'overhead_sdt_features'``, ``'player_expert'``
            and ``'other_experts'``.

    Returns:
        Tuple ``(loss, loss_pos, minADE, minFDE, miss_rate_2)`` of 0-dim CPU
        tensors: the unweighted mean of each per-batch value (NaN when the
        dataloader is empty).
    """
    losses, loss_poss = [], []
    minADEs, minFDEs, miss_rate_2s = [], [], []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in dataloader:
                player_past           = data['player_past'].cuda()
                other_pasts           = data['other_pasts'].cuda()
                overhead_sdt_features = data['overhead_sdt_features'].cuda()

                # target trajectories: (B, A, T, 2)
                player_future = data['player_expert'].cuda()
                other_futures = data['other_experts'].cuda()
                futures       = torch.cat((torch.unsqueeze(player_future, dim=1), other_futures), dim=1)
                B, A, T = futures.shape[:3]

                z = torch.randn((B, A, T, 2, 1)).cuda()

                # model
                mus, covs = model(player_past, other_pasts, overhead_sdt_features, z)

                # q distribution
                q_dist = MultivariateNormal(loc=mus, covariance_matrix=covs)

                # target distribution (p): narrow Gaussian around the ground truth
                p_cov  = 0.01 * torch.eye(2)[None, None, None, ...].repeat(B, A, T, 1, 1).cuda()
                p_dist = MultivariateNormal(loc=futures, covariance_matrix=p_cov)

                # cross entropy loss with noise
                p_samples = p_dist.sample(sample_shape=[12])
                loss      = -q_dist.log_prob(p_samples).mean()
                loss_lb   = -p_dist.log_prob(p_samples).mean()
                loss_pos  = loss - loss_lb

                # displacement metrics on samples from q
                q_dist_samples = q_dist.sample(sample_shape=[12])  # (12, B, A, T, 2)
                err   = q_dist_samples - torch.unsqueeze(futures, dim=0)
                norm2 = torch.norm(err, dim=-1)                    # (12, B, A, T)

                # mean over time, min over samples, mean over agents, mean over batch
                minADE_avg = norm2.mean(dim=-1).min(dim=0).values.mean(dim=-1).mean()
                # same reduction on the last step only
                minFDE_avg = norm2[..., -1].min(dim=0).values.mean(dim=-1).mean()

                miss_rate_2 = (norm2 > 2.0).float().mean()

                losses.append(loss)
                loss_poss.append(loss_pos)
                minADEs.append(minADE_avg)
                minFDEs.append(minFDE_avg)
                miss_rate_2s.append(miss_rate_2)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    def _mean(values):
        # Per-batch scalars may live on the GPU; gather them as Python floats.
        return torch.tensor([float(v) for v in values]).mean()

    return _mean(losses), _mean(loss_poss), _mean(minADEs), _mean(minFDEs), _mean(miss_rate_2s)

In [14]:
def draw_predictions(model, dataset, dataset_idxs, global_iterations, loss, loss_pos, minADE, minFDE, miss_rate, epoch=None):
    """Render a 2x2 figure of ground truth and sampled predictions for four samples.

    Args:
        model: callable ``model(player_past, other_pasts, overhead_sdt_features, z)``
            returning ``(mus, covs)``.
        dataset: indexable dataset whose items are dicts with keys
            ``'player_past'``, ``'other_pasts'``, ``'overhead_sdt_features'``,
            ``'player_expert'`` and ``'other_experts'``.
        dataset_idxs: exactly four dataset indices, one per subplot.
        global_iterations, loss, loss_pos, minADE, minFDE, miss_rate: values
            shown in the figure title.
        epoch: epoch number shown in the title. When omitted, the notebook-level
            variable ``epoch`` is used (0 if it does not exist), which keeps
            existing calls working.

    Returns:
        Whatever ``plot_to_image(fig)`` returns for the finished figure.
    """
    assert len(dataset_idxs) == 4
    if epoch is None:
        epoch = globals().get('epoch', 0)

    # Fetch every sample once instead of once per field.
    samples = [dataset[i] for i in dataset_idxs]

    def _batch(key):
        return torch.cat([torch.Tensor(sample[key][None]) for sample in samples], dim=0).cuda()

    model = model.eval()
    with torch.no_grad():
        # dataset
        player_past           = _batch('player_past')
        other_pasts           = _batch('other_pasts')
        overhead_sdt_features = _batch('overhead_sdt_features')

        player_future = _batch('player_expert')
        other_futures = _batch('other_experts')

        # Sizes come from the data rather than from notebook globals.
        B = player_past.shape[0]
        A = other_pasts.shape[1] + 1
        T = player_future.shape[1]
        z = torch.randn((B, A, T, 2, 1)).cuda()

        # model
        mus, covs = model(player_past, other_pasts, overhead_sdt_features, z)

        # q distribution
        q_dist = MultivariateNormal(loc=mus, covariance_matrix=covs)
        q_dist_samples = q_dist.sample(sample_shape=[7])

    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(12, 12))
    ax = ax.reshape(-1)
    try:
        for b in range(B):
            fig, ax[b] = plot_single_batch(
                player_past[b].cpu(),
                other_pasts[b].cpu(),
                player_future[b].cpu(),
                other_futures[b].cpu(),
                fig, ax[b]
            )

            fig, ax[b] = plot_predictions_single_batch(
                fig, ax[b], q_dist_samples[:, b, ...].cpu()
            )

            ax[b].axis('off')

        plt.tight_layout()
        fig.suptitle(f"epoch={epoch:05d}, iter={global_iterations:07d} \nHpq={loss:.3f}, Hp'q={loss_pos:.3f}, \nminADE={minADE:.2f}, minFDE={minFDE:.3f}, \nmiss_rate_2={miss_rate*100.0:.2f}%", fontsize=24)
        image = plot_to_image(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; `del fig` alone
        # does not release it. Closing an already-closed figure is harmless.
        plt.close(fig)

    return image

In [15]:
# Pick four random training samples (duplicates possible) to visualise during training.
dataset_idxs = np.random.randint(low=0, high=len(train_dataset), size=4)
print(dataset_idxs)

[ 26  14 131 306]


In [16]:
MAX_EPOCHS = int(200000)
A = 5   # agents per scene
T = 20  # predicted future steps

if LOAD_TRAINED:
    best_loss = trained['loss']
    # The key is spelled 'global_iteractions' in saved checkpoints; the spelling
    # is kept so existing files still load.
    global_iterations = trained['global_iteractions']
    epoch_start = trained['epoch']
else:
    # NOTE(review): no checkpoint is written until the evaluated loss drops
    # below this threshold — confirm that is intended.
    best_loss = 10.0
    global_iterations = 0
    epoch_start = 0

try:
    for epoch in range(epoch_start, MAX_EPOCHS):
        for i, data in enumerate(dataloader, 0):
            global_iterations += 1

            player_past           = data['player_past'].cuda()
            other_pasts           = data['other_pasts'].cuda()
            overhead_sdt_features = data['overhead_sdt_features'].cuda()

            B = player_past.shape[0]
            z = torch.randn((B, A, T, 2, 1)).cuda()

            # model (evaluation below switches to eval mode, so switch back every step)
            model = model.train()
            mus, covs = model(player_past, other_pasts, overhead_sdt_features, z)

            # q distribution
            q_dist = MultivariateNormal(loc=mus, covariance_matrix=covs)

            # target distribution (p): narrow Gaussian around the ground-truth futures
            player_future = data['player_expert'].cuda()
            other_futures = data['other_experts'].cuda()
            futures       = torch.cat((torch.unsqueeze(player_future, dim=1), other_futures), dim=1)
            p_cov         = 0.01 * torch.eye(2)[None, None, None, ...].repeat(B, A, T, 1, 1).cuda()
            p_dist        = MultivariateNormal(loc=futures, covariance_matrix=p_cov)

            # cross entropy loss with noised targets
            p_samples = p_dist.sample(sample_shape=[12])
            loss      = -q_dist.log_prob(p_samples).mean()
            loss_lb   = -p_dist.log_prob(p_samples).mean()
            loss_pos  = loss - loss_lb

            # optimize
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            # write train information
            writer.add_scalar("train/batch/Hpq_positive", loss_pos, global_iterations)

            # Every 10 iterations: log parameter/gradient norms and evaluate on the
            # full training set. (global_iterations is always >= 1 here.)
            if global_iterations % 10 == 0:
                write_model_summary(model, writer)

                _loss, _loss_pos, _minADE, _minFDE, _miss_rate_2 = evaluate_dataset(model, train_test_dataloader)
                writer.add_scalar("train/all/Hpq_positive", _loss_pos, global_iterations)
                writer.add_scalar("train/all/Hpq", _loss, global_iterations)
                writer.add_scalar("train/all/minADE", _minADE, global_iterations)
                writer.add_scalar("train/all/minFDE", _minFDE, global_iterations)
                writer.add_scalar("train/all/miss_rate_2", _miss_rate_2*100.0, global_iterations)
                print(f"epoch={epoch:05d}, iter={global_iterations:07d}, Hpq_positive={_loss_pos:.2f}, Hpq={_loss:.2f}, minADE={_minADE:.2f}, minFDE={_minFDE:.2f}, miss_rate_2={_miss_rate_2*100.0:.2f}%")

                # Every 100 iterations: also log a prediction figure.
                if global_iterations % 100 == 0:
                    image = draw_predictions(
                        model, train_dataset, dataset_idxs, global_iterations, _loss, _loss_pos, _minADE, _minFDE, _miss_rate_2)
                    writer.add_image('train_data', image, global_step=global_iterations, dataformats='HWC')
                    del image

                # save model whenever the evaluated loss improves
                if _loss < best_loss:
                    best_loss = _loss
                    torch.save({
                        'epoch': epoch,
                        'global_iteractions':global_iterations,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': _loss,
                    }, logdir + "/model.pt")
finally:
    # The loop is normally stopped with a KeyboardInterrupt; make sure buffered
    # TensorBoard events reach disk.
    writer.flush()

epoch=00000, iter=0000010, Hpq_positive=933.64, Hpq=931.87, minADE=28.61, minFDE=68.05, miss_rate_2=92.47%
epoch=00000, iter=0000020, Hpq_positive=679.54, Hpq=677.78, minADE=22.64, minFDE=53.37, miss_rate_2=89.95%


KeyboardInterrupt: 

In [None]:
# Deliberate stop so that "Run All" does not continue into the scratch cells below.
# `NotImplemented` is a comparison sentinel, not an exception; raising it only
# produces a TypeError, so raise the real exception class instead.
raise NotImplementedError("Scratch cells below - not part of the training run.")

In [None]:
# Scratch: split a dotted parameter name into its components.
# NOTE(review): `name` is only assigned in a later scratch cell, so this cell
# fails on a fresh kernel when run in order.
names = name.split('.')
if len(names) == 4: _, module_name, num, wb = names
elif len(names) == 3: _, module_name, wb = names

In [None]:
# Scratch: inspect a summary dict.
# NOTE(review): in the visible cells this name is only assigned inside
# get_model_summaries, so it is not defined at notebook level.
scalar_model_summary

In [None]:
# Scratch: inspect the q distribution left over from the last training iteration.
q_dist

In [None]:
# Scratch: recompute the loss terms from the distributions left over from the
# last training iteration. Overwrites the notebook-level `p_samples` and `loss`.
p_samples = p_dist.sample(sample_shape=[12])
loss = -q_dist.log_prob(p_samples).mean()
# Displayed value: the same quantity evaluated under p itself.
-p_dist.log_prob(p_samples).mean()

In [None]:
# Scratch: print every (name, parameter) pair of the model.
# Leaves `param` bound to the last pair, which the next cell relies on.
for param in model.named_parameters():
    print(param)

In [None]:
# Scratch: unpack the last (name, parameter) pair left by the loop in the previous cell.
name, val = param

In [None]:
# Scratch: inspect the gradient of the parameter unpacked in the previous cell.
val.grad