In [1]:
from collections import defaultdict
import numpy as np
import argparse
import logging
import random
import time
import sys
import gc
import os


import torch.backends.cudnn as cudnn
from attrdict import AttrDict
import torch.optim as optim
import torch.nn as nn
import torch

from sgan.losses import displacement_error, final_displacement_error
from sgan.losses import gan_g_loss, gan_d_loss, l2_loss

from sgan.utils import int_tuple, bool_flag, get_total_norm
from sgan.utils import relative_to_abs, get_dset_path

from sgan.models import TrajectoryGenerator, TrajectoryDiscriminator
from sgan.data.loader import data_loader

import train

def set_seed(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(seed)
    
set_seed()

  from .autonotebook import tqdm as notebook_tqdm


# Modle Load

hotel_8_model.pt : hotel이 아닌 다른 데이터로 학습하고 hotel에서 테스트할 모델. 예측 길이는 8

Distillation에서는 일반적으로 generator만 학습하므로 우선 generator만 가져옴

In [2]:
def get_generator(args):
    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)
    generator.load_state_dict(checkpoint['g_state'])
    generator.cuda()
    generator.train()
    
    return generator

def get_discriminator(args):
    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type='local')
#         activation='leakyrelu')
    
    return discriminator

In [3]:
def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight)
        
def get_dtypes(args):
    long_dtype = torch.LongTensor
    float_dtype = torch.FloatTensor
    if args.use_gpu == 1:
        long_dtype = torch.cuda.LongTensor
        float_dtype = torch.cuda.FloatTensor
    return long_dtype, float_dtype


In [5]:
checkpoint = torch.load("./models/sgan-p-models/hotel_8_model.pt")
args = AttrDict(checkpoint['args'])
args.output_dir = "./"
long_dtype, float_dtype = get_dtypes(args)

generator_T = get_generator(args)
generator_T.load_state_dict(checkpoint['g_state'])

generator_S = get_generator(args)
generator_S.apply(init_weights)
generator_S.type(float_dtype).train()

discriminator_S = get_discriminator(args)
discriminator_S.apply(init_weights)
discriminator_S.type(float_dtype).train()

g_loss_fn = gan_g_loss
d_loss_fn = gan_d_loss

optimizer_g = optim.Adam(generator_S.parameters(), lr=args.g_learning_rate)
optimizer_d = optim.Adam(
    discriminator_S.parameters(), lr=args.d_learning_rate
)


# Data Loader

In [6]:
train_path = get_dset_path('hotel', 'train')
_, train_loader = data_loader(args, train_path)

val_path = get_dset_path('hotel', 'val')
_, val_loader = data_loader(args, val_path)


# 모델 학습


원본 코드의 학습 구조가 조금 이상하게 되어 있음


```python

while(t < args.num_iterations):
    d_steps_left = args.d_steps
    g_steps_left = args.g_steps
    for batch in train_loader:
        if d_steps_left > 0:
            Train discriminator
            d_steps_left -=1
           
        elif g_steps_left > 0:        
            Train generator
            g_steps_left -=1
        
        if d_steps_left > 0 or g_steps_left > 0:
            continue
        
        if t % args.checkpoint_every == 0:
            evaluate with val_loader
            save model

        t += 1
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps            
```


In [7]:
def discriminator_step(args, batch, generator, discriminator, d_loss_fn, optimizer_d):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)

    generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)

    pred_traj_fake_rel = generator_out
    pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

    traj_real = torch.cat([obs_traj, pred_traj_gt], dim=0)
    traj_real_rel = torch.cat([obs_traj_rel, pred_traj_gt_rel], dim=0)
    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    scores_real = discriminator(traj_real, traj_real_rel, seq_start_end)
    
    # Compute loss with optional gradient penalty
    data_loss = d_loss_fn(scores_real, scores_fake)
    losses['D_data_loss'] = data_loss.item()
    loss += data_loss
    losses['D_total_loss'] = loss.item()
    
    
    optimizer_d.zero_grad()
    loss.backward()
    if args.clipping_threshold_d > 0:
        nn.utils.clip_grad_norm_(discriminator.parameters(),
                                 args.clipping_threshold_d)
        
    optimizer_d.step()

    return losses

In [8]:
checkpoint = {
    'args': args.__dict__,
    'G_losses': defaultdict(list),
    'D_losses': defaultdict(list),
    'losses_ts': [],
    'metrics_val': defaultdict(list),
    'metrics_train': defaultdict(list),
    'sample_ts': [],
    'restore_ts': [],
    'norm_g': [],
    'norm_d': [],
    'counters': {
        't': None,
        'epoch': None,
    },
    'g_state': None,
    'g_optim_state': None,
    'd_state': None,
    'd_optim_state': None,
    'g_best_state': None,
    'd_best_state': None,
    'best_t': None,
    'g_best_nl_state': None,
    'd_best_state_nl': None,
    'best_t_nl': None,
}

#### alpha = 390, negative = 1일 때 LRP의 obs_traj에 대한 perturbation의 abs의 평균, 평균, std는 각각 다음과 같다.

* abs의 평균 : 0.0118972

* 평균 : 0.0000887874915

* std : 0.065657713

#### alpha = 390, negative = 1일 때 LRP의 obs_traj_rel에 대한 perturbation의 abs의 평균, 평균, std는 각각 다음과 같다.

* abs의 평균 : 0.0064371

* 평균 : 0.0000498725123

* std : 0.0324582

In [9]:
def generator_step(args, batch, generator_S, generator_T, discriminator, g_loss_fn, optimizer_g, mode):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    g_distill_loss = []
    
    
    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out_S, feat_S = generator_S(obs_traj, obs_traj_rel, seq_start_end, is_feat=True)
        
        if mode == 'lrp':
            obs_traj_ref, obs_traj_rel_ref = get_lrp(generator_T, obs_traj, obs_traj_rel, pred_traj_gt_rel, seq_start_end)
        elif mode == 'random_noise':
            obs_traj_ref = obs_traj + (torch.randn_like(obs_traj) * 0.0656)
            obs_traj_rel_ref = obs_traj_rel + (torch.randn_like(obs_traj_rel) * 0.0324)
            
        generator_out_T, feat_T = generator_T(obs_traj_ref, obs_traj_rel_ref, seq_start_end, is_feat=True)
#         generator_out_S2, feat_S2 = generator_S(obs_traj_ref, obs_traj_rel_ref, seq_start_end, is_feat=True)

        pred_traj_fake_rel_S = generator_out_S
        pred_traj_fake_rel_T = generator_out_T

        pred_traj_fake_S = relative_to_abs(pred_traj_fake_rel_S, obs_traj[-1])
        pred_traj_fake_T = relative_to_abs(pred_traj_fake_rel_T, obs_traj[-1])

        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel_S,
                pred_traj_gt_rel,
                loss_mask,
                mode='raw'))
            
            g_distill_loss.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel_S,
                pred_traj_fake_rel_T,
                loss_mask,
                mode='raw'))
            
            
    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    g_distill_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
            
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

        
        g_distill_loss = torch.stack(g_distill_loss, dim=1)
        for start, end in seq_start_end.data:
            _g_distill_loss = g_l2_loss_rel[start:end]
            _g_distill_loss = torch.sum(_g_distill_loss, dim=0)
            _g_distill_loss = torch.min(_g_distill_loss) / torch.sum(loss_mask[start:end])
            g_distill_loss_sum_rel += _g_distill_loss
            
        losses['g_distill_loss'] = g_distill_loss_sum_rel.item()
        loss += g_distill_loss_sum_rel
        
        loss_feat = 0
        for i in range(len(feat_S)):
            if isinstance(feat_S[i], tuple):
                for j in range(len(feat_S[i])):
                    loss_feat += torch.mean((feat_S[i][j] - feat_T[i][j]) ** 2)
            else:
                loss_feat += torch.mean((feat_S[i] - feat_T[i]) ** 2)
                
        losses['loss_feat'] = loss_feat.item()
        loss += loss_feat
        
        
    traj_fake = torch.cat([obs_traj, pred_traj_fake_S], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel_S], dim=0)
    
    if discriminator != None:
        scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
        discriminator_loss = g_loss_fn(scores_fake)
        loss += discriminator_loss
        losses['G_discriminator_loss'] = discriminator_loss.item()
        losses['G_total_loss'] = loss.item()
    else:
        discriminator_loss = 0
    
    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(
            generator_S.parameters(), args.clipping_threshold_g
        )
    optimizer_g.step()

    return losses

In [10]:
def get_lrp(generator_T, obs_traj, obs_traj_rel, pred_traj_gt_rel, seq_start_end, alpha = 390, negative = 1):
    generator_T.train()
    
    obs_traj.requires_grad = True
    obs_traj_rel.requires_grad = True
    
    pred = generator_T(obs_traj, obs_traj_rel, seq_start_end)

    loss = torch.mean((pred - pred_traj_gt_rel) ** 2)
    loss.backward()

    #  ===================================================================
    obs_traj_lrp = obs_traj - (obs_traj.grad * torch.abs(obs_traj) * alpha * negative)
    obs_traj_rel_lrp = obs_traj_rel - (obs_traj_rel.grad * torch.abs(obs_traj_rel) * alpha * negative)

    return obs_traj_lrp, obs_traj_rel_lrp

In [11]:
from tqdm import tqdm
def main(args):
    t = 0
    epoch = 0
    while t < args.num_iterations:
        gc.collect()
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        epoch += 1

        pbar = tqdm(train_loader)
        for batch in pbar:

            if d_steps_left > 0:
                step_type = 'd'
                losses_d = discriminator_step(args, batch, generator_S,
                                              discriminator_S, d_loss_fn,
                                              optimizer_d)
                d_steps_left -= 1
            elif g_steps_left > 0:
                step_type = 'g'
                losses_g = generator_step(args, batch, generator_S, generator_T,
                                          discriminator_S, g_loss_fn,
                                          optimizer_g, mode = "random_noise")
                g_steps_left -= 1

            # 여기 밑으로는 그냥 evaluation하고 모델 저장하는 부분
            if d_steps_left > 0 or g_steps_left > 0:
                continue

            pbar.set_postfix({
                "G_l2" : losses_g['G_l2_loss_rel'],
                "G_adv" : losses_g['G_discriminator_loss'],
                "G_distill" : losses_g['g_distill_loss'],
                "G_feat" : losses_g['loss_feat'],
                "D" : losses_d['D_total_loss']
            })

            # Maybe save a checkpoint
            if t > 0 and t % args.checkpoint_every == 0:
    #         if True:
                print('Checking stats on val ...')
                metrics_val = train.check_accuracy(
                    args, val_loader, generator_S, discriminator_S, d_loss_fn
                )
                print('Checking stats on train ...')
                metrics_train = train.check_accuracy(
                    args, train_loader, generator_S, discriminator_S,
                    d_loss_fn, limit=True
                )

                for k, v in sorted(metrics_val.items()):
                    print('  [val] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_val'][k].append(v)
                for k, v in sorted(metrics_train.items()):
                    print('  [train] {}: {:.3f}'.format(k, v))
                    checkpoint['metrics_train'][k].append(v)

                min_ade = min(checkpoint['metrics_val']['ade'])
                min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

                if metrics_val['ade'] == min_ade:
                    print('New low for avg_disp_error')
                    checkpoint['best_t'] = t
                    checkpoint['g_best_state'] = generator_S.state_dict()
                    checkpoint['d_best_state'] = discriminator_S.state_dict()

                if metrics_val['ade_nl'] == min_ade_nl:
                    print('New low for avg_disp_error_nl')
                    checkpoint['best_t_nl'] = t
                    checkpoint['g_best_nl_state'] = generator_S.state_dict()
                    checkpoint['d_best_nl_state'] = discriminator_S.state_dict()


                checkpoint['g_state'] = generator_S.state_dict()
                checkpoint['g_optim_state'] = optimizer_g.state_dict()

                checkpoint['d_state'] = discriminator_S.state_dict()
                checkpoint['d_optim_state'] = optimizer_d.state_dict()
                os.makedirs("saved_models", exist_ok=True)

                checkpoint_path = os.path.join(
                    args.output_dir, f'saved_models/{args.dataset_name}_{args.pred_len}_model.pt')
                print('Saving checkpoint to {}'.format(checkpoint_path))

                torch.save(checkpoint, checkpoint_path)


            t += 1
            d_steps_left = args.d_steps
            g_steps_left = args.g_steps
            if t >= args.num_iterations:
                break

In [12]:
main(args)

  7%|▋         | 3/46 [00:10<02:26,  3.40s/it, G_l2=29.6, G_adv=0.67, G_distill=29.6, G_feat=14.9, D=1.36]

KeyboardInterrupt



In [11]:
metrics_val = train.check_accuracy(
    args, val_loader, generator_S, discriminator_S, d_loss_fn
)
print('Checking stats on train ...')
metrics_train = train.check_accuracy(
    args, train_loader, generator_S, discriminator_S,
    d_loss_fn, limit=True
)

for k, v in sorted(metrics_val.items()):
    print('  [val] {}: {:.3f}'.format(k, v))
    checkpoint['metrics_val'][k].append(v)
for k, v in sorted(metrics_train.items()):
    print('  [train] {}: {:.3f}'.format(k, v))
    checkpoint['metrics_train'][k].append(v)

Checking stats on train ...


In [20]:
checkpoint1 = torch.load("saved_models/hotel_8_model.pt")
checkpoint2 = torch.load("models/sgan-p-models/hotel_8_model.pt")

In [21]:
for k in checkpoint2['args'].keys():
    if k not in checkpoint1['args'].keys():
        checkpoint1['args'][k] = checkpoint2['args'][k]
        


In [24]:
torch.save(checkpoint1, "temp.pt")

In [26]:
!mv temp.pt temp/hotel_8_model.pt