In [1]:
from collections import defaultdict
import numpy as np
import argparse
import logging
import random
import time
import sys
import gc
import os


import torch.backends.cudnn as cudnn
from attrdict import AttrDict
import torch.optim as optim
import torch.nn as nn
import torch

from sgan.losses import displacement_error, final_displacement_error
from sgan.losses import gan_g_loss, gan_d_loss, l2_loss

from sgan.utils import int_tuple, bool_flag, get_total_norm
from sgan.utils import relative_to_abs, get_dset_path

from sgan.models import TrajectoryGenerator, TrajectoryDiscriminator
from sgan.data.loader import data_loader

import train

def set_seed(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(seed)
    
set_seed()

  from .autonotebook import tqdm as notebook_tqdm


# Modle Load

hotel_8_model.pt : hotel이 아닌 다른 데이터로 학습하고 hotel에서 테스트할 모델. 예측 길이는 8

Distillation에서는 일반적으로 generator만 학습하므로 우선 generator만 가져옴

In [2]:
def get_generator(args):
    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)
    generator.load_state_dict(checkpoint['g_state'])
    generator.cuda()
    generator.train()
    
    return generator

def get_discriminator(args):
    discriminator = TrajectoryDiscriminator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        h_dim=args.encoder_h_dim_d,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        batch_norm=args.batch_norm,
        d_type='local')
#         activation='leakyrelu')
    
    return discriminator

In [3]:
def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight)
        
def get_dtypes(args):
    long_dtype = torch.LongTensor
    float_dtype = torch.FloatTensor
    if args.use_gpu == 1:
        long_dtype = torch.cuda.LongTensor
        float_dtype = torch.cuda.FloatTensor
    return long_dtype, float_dtype


In [4]:
checkpoint = torch.load("./models/sgan-p-models/hotel_8_model.pt")
args = AttrDict(checkpoint['args'])
args.output_dir = "./"
long_dtype, float_dtype = get_dtypes(args)

generator_T = get_generator(args)
generator_T.load_state_dict(checkpoint['g_state'])

generator_S = get_generator(args)
generator_S.apply(init_weights)
generator_S.type(float_dtype).train()

discriminator_S = get_discriminator(args)
discriminator_S.apply(init_weights)
discriminator_S.type(float_dtype).train()

g_loss_fn = gan_g_loss
d_loss_fn = gan_d_loss

optimizer_g = optim.Adam(generator_S.parameters(), lr=args.g_learning_rate)
optimizer_d = optim.Adam(
    discriminator_S.parameters(), lr=args.d_learning_rate
)


# Data Loader

In [5]:
train_path = get_dset_path('hotel', 'train')
_, train_loader = data_loader(args, train_path)

val_path = get_dset_path('hotel', 'val')
_, val_loader = data_loader(args, val_path)


# 모델 학습


원본 코드의 학습 구조가 조금 이상하게 되어 있음


```python

while(t < args.num_iterations):
    d_steps_left = args.d_steps
    g_steps_left = args.g_steps
    for batch in train_loader:
        if d_steps_left > 0:
            Train discriminator
            d_steps_left -=1
           
        elif g_steps_left > 0:        
            Train generator
            g_steps_left -=1
        
        if d_steps_left > 0 or g_steps_left > 0:
            continue
        
        if t % args.checkpoint_every == 0:
            evaluate with val_loader
            save model

        t += 1
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps            
```


In [6]:
def discriminator_step(args, batch, generator, discriminator, d_loss_fn, optimizer_d):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)

    generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)

    pred_traj_fake_rel = generator_out
    pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

    traj_real = torch.cat([obs_traj, pred_traj_gt], dim=0)
    traj_real_rel = torch.cat([obs_traj_rel, pred_traj_gt_rel], dim=0)
    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    scores_real = discriminator(traj_real, traj_real_rel, seq_start_end)
    
    # Compute loss with optional gradient penalty
    data_loss = d_loss_fn(scores_real, scores_fake)
    losses['D_data_loss'] = data_loss.item()
    loss += data_loss
    losses['D_total_loss'] = loss.item()
    
    
    optimizer_d.zero_grad()
    loss.backward()
    if args.clipping_threshold_d > 0:
        nn.utils.clip_grad_norm_(discriminator.parameters(),
                                 args.clipping_threshold_d)
        
    optimizer_d.step()

    return losses

In [7]:
checkpoint = {
    'args': args.__dict__,
    'G_losses': defaultdict(list),
    'D_losses': defaultdict(list),
    'losses_ts': [],
    'metrics_val': defaultdict(list),
    'metrics_train': defaultdict(list),
    'sample_ts': [],
    'restore_ts': [],
    'norm_g': [],
    'norm_d': [],
    'counters': {
        't': None,
        'epoch': None,
    },
    'g_state': None,
    'g_optim_state': None,
    'd_state': None,
    'd_optim_state': None,
    'g_best_state': None,
    'd_best_state': None,
    'best_t': None,
    'g_best_nl_state': None,
    'd_best_state_nl': None,
    'best_t_nl': None,
}

In [8]:
def generator_step(args, batch, generator_S, generator_T, discriminator, g_loss_fn, optimizer_g, mode='lrp'):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []
    g_distill_loss = []
    
    
    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out_S, feat_S = generator_S(obs_traj, obs_traj_rel, seq_start_end, is_feat=True)
        
        if mode == 'lrp':
            obs_traj_ref, obs_traj_rel_ref = get_lrp(generator_T, obs_traj, obs_traj_rel, pred_traj_gt_rel, seq_start_end)
        elif mode == 'random_noise':
#             obs_traj_ref += random_noise
#             obs_traj_rel_ref += random_noise2
            assert False, "random noise is not ready yet!!!"
            pass
            
        generator_out_T, feat_T = generator_T(obs_traj_ref, obs_traj_rel_ref, seq_start_end, is_feat=True)
#         generator_out_S2, feat_S2 = generator_S(obs_traj_ref, obs_traj_rel_ref, seq_start_end, is_feat=True)

        pred_traj_fake_rel_S = generator_out_S
        pred_traj_fake_rel_T = generator_out_T

        pred_traj_fake_S = relative_to_abs(pred_traj_fake_rel_S, obs_traj[-1])
        pred_traj_fake_T = relative_to_abs(pred_traj_fake_rel_T, obs_traj[-1])

        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel_S,
                pred_traj_gt_rel,
                loss_mask,
                mode='raw'))
            
            g_distill_loss.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel_S,
                pred_traj_fake_rel_T,
                loss_mask,
                mode='raw'))
            
            
    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    g_distill_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
            
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

        
        g_distill_loss = torch.stack(g_distill_loss, dim=1)
        for start, end in seq_start_end.data:
            _g_distill_loss = g_l2_loss_rel[start:end]
            _g_distill_loss = torch.sum(_g_distill_loss, dim=0)
            _g_distill_loss = torch.min(_g_distill_loss) / torch.sum(loss_mask[start:end])
            g_distill_loss_sum_rel += _g_distill_loss
            
        losses['g_distill_loss'] = g_distill_loss_sum_rel.item()
        loss += g_distill_loss_sum_rel
        
        loss_feat = 0
        for i in range(len(feat_S)):
            if isinstance(feat_S[i], tuple):
                for j in range(len(feat_S[i])):
                    loss_feat += torch.mean((feat_S[i][j] - feat_T[i][j]) ** 2)
            else:
                loss_feat += torch.mean((feat_S[i] - feat_T[i]) ** 2)
                
        losses['loss_feat'] = loss_feat.item()
        loss += loss_feat
        
        
    traj_fake = torch.cat([obs_traj, pred_traj_fake_S], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel_S], dim=0)
    
    if discriminator != None:
        scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
        discriminator_loss = g_loss_fn(scores_fake)
        loss += discriminator_loss
        losses['G_discriminator_loss'] = discriminator_loss.item()
        losses['G_total_loss'] = loss.item()
    else:
        discriminator_loss = 0
    
    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(
            generator_S.parameters(), args.clipping_threshold_g
        )
    optimizer_g.step()

    return losses

In [9]:
def get_lrp(generator_T, obs_traj, obs_traj_rel, pred_traj_gt_rel, seq_start_end, alpha = 390, negative = 1):
    generator_T.train()
    
    obs_traj.requires_grad = True
    obs_traj_rel.requires_grad = True
    
    pred = generator_T(obs_traj, obs_traj_rel, seq_start_end)

    loss = torch.mean((pred - pred_traj_gt_rel) ** 2)
    loss.backward()

    #  ===================================================================
    obs_traj_lrp = obs_traj - (obs_traj.grad * torch.abs(obs_traj) * alpha * negative)
    obs_traj_rel_lrp = obs_traj_rel - (obs_traj_rel.grad * torch.abs(obs_traj_rel) * alpha * negative)

    return obs_traj_lrp, obs_traj_rel_lrp

In [None]:
from tqdm import tqdm

t = 0
epoch = 0
while t < args.num_iterations:
    gc.collect()
    d_steps_left = args.d_steps
    g_steps_left = args.g_steps
    epoch += 1
    
    pbar = tqdm(train_loader)
    for batch in pbar:
        
        if d_steps_left > 0:
            step_type = 'd'
            losses_d = discriminator_step(args, batch, generator_S,
                                          discriminator_S, d_loss_fn,
                                          optimizer_d)
            d_steps_left -= 1
        elif g_steps_left > 0:
            step_type = 'g'
            losses_g = generator_step(args, batch, generator_S, generator_T,
                                      discriminator_S, g_loss_fn,
                                      optimizer_g)
            g_steps_left -= 1

        # 여기 밑으로는 그냥 evaluation하고 모델 저장하는 부분
        if d_steps_left > 0 or g_steps_left > 0:
            continue
        
        pbar.set_postfix({
            "G_l2" : losses_g['G_l2_loss_rel'],
            "G_adv" : losses_g['G_discriminator_loss'],
            "G_distill" : losses_g['g_distill_loss'],
            "G_feat" : losses_g['loss_feat'],
            "D" : losses_d['D_total_loss']
        })
        
        # Maybe save a checkpoint
        if t > 0 and t % args.checkpoint_every == 0:
#         if True:
            print('Checking stats on val ...')
            metrics_val = train.check_accuracy(
                args, val_loader, generator_S, discriminator_S, d_loss_fn
            )
            print('Checking stats on train ...')
            metrics_train = train.check_accuracy(
                args, train_loader, generator_S, discriminator_S,
                d_loss_fn, limit=True
            )

            for k, v in sorted(metrics_val.items()):
                print('  [val] {}: {:.3f}'.format(k, v))
                checkpoint['metrics_val'][k].append(v)
            for k, v in sorted(metrics_train.items()):
                print('  [train] {}: {:.3f}'.format(k, v))
                checkpoint['metrics_train'][k].append(v)

            min_ade = min(checkpoint['metrics_val']['ade'])
            min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

            if metrics_val['ade'] == min_ade:
                print('New low for avg_disp_error')
                checkpoint['best_t'] = t
                checkpoint['g_best_state'] = generator_S.state_dict()
                checkpoint['d_best_state'] = discriminator_S.state_dict()

            if metrics_val['ade_nl'] == min_ade_nl:
                print('New low for avg_disp_error_nl')
                checkpoint['best_t_nl'] = t
                checkpoint['g_best_nl_state'] = generator_S.state_dict()
                checkpoint['d_best_nl_state'] = discriminator_S.state_dict()


            checkpoint['g_state'] = generator_S.state_dict()
            checkpoint['g_optim_state'] = optimizer_g.state_dict()
            
            checkpoint['d_state'] = discriminator_S.state_dict()
            checkpoint['d_optim_state'] = optimizer_d.state_dict()
            os.makedirs("saved_models", exist_ok=True)
            
            checkpoint_path = os.path.join(
                args.output_dir, f'saved_models/S_{args.dataset_name}_{args.pred_len}_model.pt')
            print('Saving checkpoint to {}'.format(checkpoint_path))

            torch.save(checkpoint, checkpoint_path)


        t += 1
        d_steps_left = args.d_steps
        g_steps_left = args.g_steps
        if t >= args.num_iterations:
            break

100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=17, G_adv=0.68, G_distill=17, G_feat=6.83, D=1.02]      
100%|██████████| 46/46 [02:13<00:00,  2.89s/it, G_l2=10.1, G_adv=0.687, G_distill=10.1, G_feat=3.4, D=1.32]  
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=9.5, G_adv=0.693, G_distill=9.5, G_feat=1.62, D=1.01]   
100%|██████████| 46/46 [02:10<00:00,  2.83s/it, G_l2=5.41, G_adv=0.693, G_distill=5.41, G_feat=1.05, D=1.4]  
100%|██████████| 46/46 [02:11<00:00,  2.87s/it, G_l2=6.87, G_adv=0.684, G_distill=6.87, G_feat=0.765, D=0.934]
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=3.49, G_adv=0.666, G_distill=3.49, G_feat=0.579, D=1.07] 
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=2.09, G_adv=0.691, G_distill=2.09, G_feat=0.483, D=1.32]
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=1.25, G_adv=0.692, G_distill=1.25, G_feat=0.398, D=1.17]
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=1.64, G_adv=0.692, G_distill=1.64, G_feat=0.298, D=1.34]
100%|███

Checking stats on val ...
Checking stats on train ...


  9%|▊         | 4/46 [00:15<03:12,  4.59s/it, G_l2=1.72, G_adv=0.691, G_distill=1.72, G_feat=0.269, D=1.21]

  [val] ade: 1.514
  [val] ade_l: 3.183
  [val] ade_nl: 2.886
  [val] d_loss: 1.290
  [val] fde: 2.737
  [val] fde_l: 5.755
  [val] fde_nl: 5.219
  [val] g_l2_loss_abs: 1.049
  [val] g_l2_loss_rel: 1.049
  [train] ade: 1.452
  [train] ade_l: 3.009
  [train] ade_nl: 2.806
  [train] d_loss: 1.303
  [train] fde: 2.622
  [train] fde_l: 5.434
  [train] fde_nl: 5.067
  [train] g_l2_loss_abs: 0.964
  [train] g_l2_loss_rel: 0.964
New low for avg_disp_error
New low for avg_disp_error_nl
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:13<00:00,  2.91s/it, G_l2=1.01, G_adv=0.689, G_distill=1.01, G_feat=0.293, D=1.19]
100%|██████████| 46/46 [02:13<00:00,  2.89s/it, G_l2=0.949, G_adv=0.675, G_distill=0.949, G_feat=0.28, D=1.14]
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.867, G_adv=0.693, G_distill=0.867, G_feat=0.242, D=1.42]
100%|██████████| 46/46 [02:13<00:00,  2.89s/it, G_l2=0.841, G_adv=0.693, G_distill=0.841, G_feat=0.284, D=1.4]
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=0.795, G_adv=0.693, G_distill=0.795, G_feat=0.268, D=1.3]
100%|██████████| 46/46 [02:10<00:00,  2.84s/it, G_l2=1.05, G_adv=0.688, G_distill=1.05, G_feat=0.226, D=1.37]
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=0.776, G_adv=0.692, G_distill=0.776, G_feat=0.252, D=1.37]
100%|██████████| 46/46 [02:14<00:00,  2.92s/it, G_l2=0.707, G_adv=0.693, G_distill=0.707, G_feat=0.24, D=1.14] 
100%|██████████| 46/46 [02:10<00:00,  2.84s/it, G_l2=0.759, G_adv=0.689, G_distill=0.759, G_feat=0.27, D=1.38] 

Checking stats on val ...
Checking stats on train ...


 15%|█▌        | 7/46 [00:21<01:57,  3.01s/it, G_l2=0.999, G_adv=0.693, G_distill=0.999, G_feat=0.19, D=1.31]

  [val] ade: 1.196
  [val] ade_l: 2.516
  [val] ade_nl: 2.281
  [val] d_loss: 1.340
  [val] fde: 2.263
  [val] fde_l: 4.758
  [val] fde_nl: 4.315
  [val] g_l2_loss_abs: 0.698
  [val] g_l2_loss_rel: 0.698
  [train] ade: 1.232
  [train] ade_l: 2.542
  [train] ade_nl: 2.391
  [train] d_loss: 1.325
  [train] fde: 2.305
  [train] fde_l: 4.756
  [train] fde_nl: 4.473
  [train] g_l2_loss_abs: 0.738
  [train] g_l2_loss_rel: 0.738
New low for avg_disp_error
New low for avg_disp_error_nl
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:14<00:00,  2.93s/it, G_l2=0.881, G_adv=0.682, G_distill=0.881, G_feat=0.179, D=1.31]
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.874, G_adv=0.693, G_distill=0.874, G_feat=0.209, D=1.3] 
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=0.625, G_adv=0.693, G_distill=0.625, G_feat=0.252, D=1.3] 
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.734, G_adv=0.67, G_distill=0.734, G_feat=0.155, D=1.32] 
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.542, G_adv=0.693, G_distill=0.542, G_feat=0.181, D=1.32]
100%|██████████| 46/46 [02:13<00:00,  2.91s/it, G_l2=0.595, G_adv=0.691, G_distill=0.595, G_feat=0.151, D=1.4] 
100%|██████████| 46/46 [02:13<00:00,  2.90s/it, G_l2=0.813, G_adv=0.692, G_distill=0.813, G_feat=0.179, D=1.24]
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.579, G_adv=0.692, G_distill=0.579, G_feat=0.148, D=1.35]
100%|██████████| 46/46 [02:11<00:00,  2.87s/it, G_l2=0.674, G_adv=0.691, G_distill=0.674, G_feat=0.245, 

Checking stats on val ...
Checking stats on train ...


 20%|█▉        | 9/46 [00:26<01:47,  2.90s/it, G_l2=0.966, G_adv=0.691, G_distill=0.966, G_feat=0.167, D=1.25]

  [val] ade: 1.008
  [val] ade_l: 2.119
  [val] ade_nl: 1.921
  [val] d_loss: 1.334
  [val] fde: 1.937
  [val] fde_l: 4.073
  [val] fde_nl: 3.694
  [val] g_l2_loss_abs: 0.503
  [val] g_l2_loss_rel: 0.503
  [train] ade: 1.002
  [train] ade_l: 2.070
  [train] ade_nl: 1.940
  [train] d_loss: 1.315
  [train] fde: 1.928
  [train] fde_l: 3.985
  [train] fde_nl: 3.736
  [train] g_l2_loss_abs: 0.498
  [train] g_l2_loss_rel: 0.498
New low for avg_disp_error
New low for avg_disp_error_nl
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:14<00:00,  2.93s/it, G_l2=0.54, G_adv=0.692, G_distill=0.54, G_feat=0.183, D=1.24]  
100%|██████████| 46/46 [02:10<00:00,  2.84s/it, G_l2=0.637, G_adv=0.693, G_distill=0.637, G_feat=0.169, D=1.38]
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=0.635, G_adv=0.69, G_distill=0.635, G_feat=0.159, D=1.3]  
100%|██████████| 46/46 [02:11<00:00,  2.87s/it, G_l2=0.765, G_adv=0.692, G_distill=0.765, G_feat=0.227, D=1.35]
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=0.664, G_adv=0.692, G_distill=0.664, G_feat=0.153, D=1.39]
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.776, G_adv=0.692, G_distill=0.776, G_feat=0.168, D=1.25]
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=0.53, G_adv=0.693, G_distill=0.53, G_feat=0.164, D=1.38]  
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.983, G_adv=0.687, G_distill=0.983, G_feat=0.169, D=1.35]
100%|██████████| 46/46 [02:10<00:00,  2.84s/it, G_l2=0.558, G_adv=0.693, G_distill=0.558, G_feat=0.198, 

Checking stats on val ...
Checking stats on train ...


 24%|██▍       | 11/46 [00:33<01:45,  3.03s/it, G_l2=0.932, G_adv=0.692, G_distill=0.932, G_feat=0.13, D=1.39]

  [val] ade: 0.927
  [val] ade_l: 1.949
  [val] ade_nl: 1.767
  [val] d_loss: 1.327
  [val] fde: 1.812
  [val] fde_l: 3.811
  [val] fde_nl: 3.456
  [val] g_l2_loss_abs: 0.446
  [val] g_l2_loss_rel: 0.446
  [train] ade: 0.902
  [train] ade_l: 1.851
  [train] ade_nl: 1.760
  [train] d_loss: 1.288
  [train] fde: 1.754
  [train] fde_l: 3.599
  [train] fde_nl: 3.422
  [train] g_l2_loss_abs: 0.423
  [train] g_l2_loss_rel: 0.423
New low for avg_disp_error
New low for avg_disp_error_nl
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:15<00:00,  2.94s/it, G_l2=0.876, G_adv=0.689, G_distill=0.876, G_feat=0.194, D=1.4] 
100%|██████████| 46/46 [02:13<00:00,  2.91s/it, G_l2=0.598, G_adv=0.692, G_distill=0.598, G_feat=0.151, D=1.29]
100%|██████████| 46/46 [02:13<00:00,  2.91s/it, G_l2=0.627, G_adv=0.688, G_distill=0.627, G_feat=0.161, D=1.4] 
100%|██████████| 46/46 [02:13<00:00,  2.91s/it, G_l2=0.537, G_adv=0.69, G_distill=0.537, G_feat=0.172, D=1.17] 
100%|██████████| 46/46 [02:12<00:00,  2.87s/it, G_l2=0.554, G_adv=0.693, G_distill=0.554, G_feat=0.158, D=1.33]
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=0.825, G_adv=0.69, G_distill=0.825, G_feat=0.18, D=1.4]   
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=0.74, G_adv=0.676, G_distill=0.74, G_feat=0.144, D=1.35]  
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=0.863, G_adv=0.684, G_distill=0.863, G_feat=0.202, D=1.2] 
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=0.859, G_adv=0.688, G_distill=0.859, G_feat=0.155, 

Checking stats on val ...
Checking stats on train ...


 28%|██▊       | 13/46 [00:38<01:36,  2.94s/it, G_l2=0.685, G_adv=0.693, G_distill=0.685, G_feat=0.144, D=1.14]

  [val] ade: 0.871
  [val] ade_l: 1.832
  [val] ade_nl: 1.661
  [val] d_loss: 1.318
  [val] fde: 1.711
  [val] fde_l: 3.598
  [val] fde_nl: 3.263
  [val] g_l2_loss_abs: 0.408
  [val] g_l2_loss_rel: 0.408
  [train] ade: 0.845
  [train] ade_l: 1.775
  [train] ade_nl: 1.612
  [train] d_loss: 1.275
  [train] fde: 1.653
  [train] fde_l: 3.473
  [train] fde_nl: 3.154
  [train] g_l2_loss_abs: 0.378
  [train] g_l2_loss_rel: 0.378
New low for avg_disp_error
New low for avg_disp_error_nl
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:14<00:00,  2.91s/it, G_l2=0.549, G_adv=0.689, G_distill=0.549, G_feat=0.169, D=1.38]
100%|██████████| 46/46 [02:11<00:00,  2.87s/it, G_l2=0.522, G_adv=0.691, G_distill=0.522, G_feat=0.142, D=1.12]
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=0.75, G_adv=0.654, G_distill=0.75, G_feat=0.179, D=1.12]  
100%|██████████| 46/46 [02:13<00:00,  2.91s/it, G_l2=0.545, G_adv=0.685, G_distill=0.545, G_feat=0.197, D=1.14] 
100%|██████████| 46/46 [02:08<00:00,  2.80s/it, G_l2=0.812, G_adv=0.689, G_distill=0.812, G_feat=0.155, D=1.36] 
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=0.65, G_adv=0.688, G_distill=0.65, G_feat=0.145, D=1.03]   
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.501, G_adv=0.691, G_distill=0.501, G_feat=0.232, D=0.931]
100%|██████████| 46/46 [02:10<00:00,  2.85s/it, G_l2=0.465, G_adv=0.687, G_distill=0.465, G_feat=0.187, D=1.38] 
100%|██████████| 46/46 [02:10<00:00,  2.84s/it, G_l2=0.481, G_adv=0.673, G_distill=0.481, G_feat=0.

Checking stats on val ...
Checking stats on train ...


 33%|███▎      | 15/46 [00:44<01:32,  2.98s/it, G_l2=0.852, G_adv=0.687, G_distill=0.852, G_feat=0.124, D=1.15]

  [val] ade: 0.888
  [val] ade_l: 1.868
  [val] ade_nl: 1.694
  [val] d_loss: 1.206
  [val] fde: 1.754
  [val] fde_l: 3.689
  [val] fde_nl: 3.345
  [val] g_l2_loss_abs: 0.423
  [val] g_l2_loss_rel: 0.423
  [train] ade: 0.839
  [train] ade_l: 1.711
  [train] ade_nl: 1.644
  [train] d_loss: 1.115
  [train] fde: 1.657
  [train] fde_l: 3.382
  [train] fde_nl: 3.250
  [train] g_l2_loss_abs: 0.383
  [train] g_l2_loss_rel: 0.383
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:14<00:00,  2.92s/it, G_l2=0.657, G_adv=0.654, G_distill=0.657, G_feat=0.146, D=1.1]  
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=0.566, G_adv=0.692, G_distill=0.566, G_feat=0.173, D=1.12] 
100%|██████████| 46/46 [02:13<00:00,  2.90s/it, G_l2=0.771, G_adv=0.692, G_distill=0.771, G_feat=0.206, D=0.946]
100%|██████████| 46/46 [02:10<00:00,  2.83s/it, G_l2=0.872, G_adv=0.686, G_distill=0.872, G_feat=0.244, D=1.25] 
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=0.621, G_adv=0.673, G_distill=0.621, G_feat=0.174, D=1.51] 
100%|██████████| 46/46 [02:10<00:00,  2.83s/it, G_l2=0.677, G_adv=0.681, G_distill=0.677, G_feat=0.173, D=0.945]
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=0.696, G_adv=0.692, G_distill=0.696, G_feat=0.212, D=1.01] 
100%|██████████| 46/46 [02:12<00:00,  2.87s/it, G_l2=0.639, G_adv=0.659, G_distill=0.639, G_feat=0.151, D=1.23] 
100%|██████████| 46/46 [02:10<00:00,  2.84s/it, G_l2=0.565, G_adv=0.685, G_distill=0.565, G_feat

Checking stats on val ...
Checking stats on train ...


 37%|███▋      | 17/46 [00:51<01:29,  3.08s/it, G_l2=0.767, G_adv=0.664, G_distill=0.767, G_feat=0.134, D=1.23]

  [val] ade: 0.863
  [val] ade_l: 1.816
  [val] ade_nl: 1.646
  [val] d_loss: 0.956
  [val] fde: 1.712
  [val] fde_l: 3.599
  [val] fde_nl: 3.264
  [val] g_l2_loss_abs: 0.409
  [val] g_l2_loss_rel: 0.409
  [train] ade: 0.839
  [train] ade_l: 1.795
  [train] ade_nl: 1.574
  [train] d_loss: 1.174
  [train] fde: 1.653
  [train] fde_l: 3.539
  [train] fde_nl: 3.102
  [train] g_l2_loss_abs: 0.370
  [train] g_l2_loss_rel: 0.370
New low for avg_disp_error
New low for avg_disp_error_nl
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:15<00:00,  2.95s/it, G_l2=0.622, G_adv=0.656, G_distill=0.622, G_feat=0.177, D=0.809]
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.484, G_adv=0.687, G_distill=0.484, G_feat=0.187, D=1.48] 
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=0.762, G_adv=0.692, G_distill=0.762, G_feat=0.165, D=1.09] 
100%|██████████| 46/46 [02:10<00:00,  2.84s/it, G_l2=0.652, G_adv=0.653, G_distill=0.652, G_feat=0.175, D=0.711]
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=0.852, G_adv=0.637, G_distill=0.852, G_feat=0.16, D=0.711] 
100%|██████████| 46/46 [02:14<00:00,  2.91s/it, G_l2=0.72, G_adv=0.652, G_distill=0.72, G_feat=0.159, D=1.07]   
100%|██████████| 46/46 [02:10<00:00,  2.84s/it, G_l2=0.501, G_adv=0.674, G_distill=0.501, G_feat=0.137, D=0.865]
100%|██████████| 46/46 [02:12<00:00,  2.87s/it, G_l2=0.768, G_adv=0.686, G_distill=0.768, G_feat=0.201, D=1.35] 
100%|██████████| 46/46 [02:12<00:00,  2.87s/it, G_l2=0.518, G_adv=0.675, G_distill=0.518, G_feat

Checking stats on val ...
Checking stats on train ...


 41%|████▏     | 19/46 [00:58<01:26,  3.19s/it, G_l2=0.599, G_adv=0.68, G_distill=0.599, G_feat=0.116, D=0.652]

  [val] ade: 0.846
  [val] ade_l: 1.779
  [val] ade_nl: 1.613
  [val] d_loss: 1.093
  [val] fde: 1.683
  [val] fde_l: 3.539
  [val] fde_nl: 3.209
  [val] g_l2_loss_abs: 0.409
  [val] g_l2_loss_rel: 0.409
  [train] ade: 0.802
  [train] ade_l: 1.663
  [train] ade_nl: 1.549
  [train] d_loss: 1.014
  [train] fde: 1.596
  [train] fde_l: 3.310
  [train] fde_nl: 3.083
  [train] g_l2_loss_abs: 0.353
  [train] g_l2_loss_rel: 0.353
New low for avg_disp_error
New low for avg_disp_error_nl
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:17<00:00,  2.99s/it, G_l2=0.667, G_adv=0.67, G_distill=0.667, G_feat=0.191, D=0.525] 
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=0.668, G_adv=0.689, G_distill=0.668, G_feat=0.134, D=0.816]
100%|██████████| 46/46 [02:12<00:00,  2.89s/it, G_l2=0.591, G_adv=0.692, G_distill=0.591, G_feat=0.273, D=1.42] 
100%|██████████| 46/46 [02:09<00:00,  2.82s/it, G_l2=0.493, G_adv=0.69, G_distill=0.493, G_feat=0.152, D=1.22]  
100%|██████████| 46/46 [02:15<00:00,  2.93s/it, G_l2=0.793, G_adv=0.661, G_distill=0.793, G_feat=0.146, D=0.997]
100%|██████████| 46/46 [02:16<00:00,  2.96s/it, G_l2=0.637, G_adv=0.67, G_distill=0.637, G_feat=0.124, D=1.52]  
100%|██████████| 46/46 [02:11<00:00,  2.86s/it, G_l2=0.513, G_adv=0.684, G_distill=0.513, G_feat=0.126, D=1.65] 
100%|██████████| 46/46 [02:13<00:00,  2.89s/it, G_l2=0.533, G_adv=0.641, G_distill=0.533, G_feat=0.17, D=0.94]  
100%|██████████| 46/46 [02:14<00:00,  2.92s/it, G_l2=0.747, G_adv=0.689, G_distill=0.747, G_feat

Checking stats on val ...
Checking stats on train ...


 46%|████▌     | 21/46 [01:02<01:15,  3.03s/it, G_l2=0.757, G_adv=0.674, G_distill=0.757, G_feat=0.15, D=0.935]

  [val] ade: 0.793
  [val] ade_l: 1.667
  [val] ade_nl: 1.512
  [val] d_loss: 0.989
  [val] fde: 1.585
  [val] fde_l: 3.333
  [val] fde_nl: 3.022
  [val] g_l2_loss_abs: 0.362
  [val] g_l2_loss_rel: 0.362
  [train] ade: 0.776
  [train] ade_l: 1.640
  [train] ade_nl: 1.473
  [train] d_loss: 0.885
  [train] fde: 1.556
  [train] fde_l: 3.287
  [train] fde_nl: 2.953
  [train] g_l2_loss_abs: 0.340
  [train] g_l2_loss_rel: 0.340
New low for avg_disp_error
New low for avg_disp_error_nl
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:15<00:00,  2.95s/it, G_l2=0.611, G_adv=0.687, G_distill=0.611, G_feat=0.156, D=1.3]  
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=0.609, G_adv=0.671, G_distill=0.609, G_feat=0.152, D=0.65] 
100%|██████████| 46/46 [02:13<00:00,  2.89s/it, G_l2=0.616, G_adv=0.671, G_distill=0.616, G_feat=0.174, D=1.27] 
100%|██████████| 46/46 [02:11<00:00,  2.87s/it, G_l2=0.643, G_adv=0.691, G_distill=0.643, G_feat=0.212, D=1.1]  
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=0.606, G_adv=0.688, G_distill=0.606, G_feat=0.121, D=1.4]  
100%|██████████| 46/46 [02:13<00:00,  2.89s/it, G_l2=0.587, G_adv=0.691, G_distill=0.587, G_feat=0.119, D=1.29] 
100%|██████████| 46/46 [02:13<00:00,  2.90s/it, G_l2=0.606, G_adv=0.654, G_distill=0.606, G_feat=0.195, D=0.453]
100%|██████████| 46/46 [02:12<00:00,  2.87s/it, G_l2=0.846, G_adv=0.691, G_distill=0.846, G_feat=0.126, D=0.781]
100%|██████████| 46/46 [02:12<00:00,  2.87s/it, G_l2=0.59, G_adv=0.684, G_distill=0.59, G_feat=0

Checking stats on val ...
Checking stats on train ...


 50%|█████     | 23/46 [01:06<01:06,  2.89s/it, G_l2=0.648, G_adv=0.688, G_distill=0.648, G_feat=0.112, D=0.835]

  [val] ade: 0.866
  [val] ade_l: 1.821
  [val] ade_nl: 1.651
  [val] d_loss: 0.831
  [val] fde: 1.718
  [val] fde_l: 3.613
  [val] fde_nl: 3.276
  [val] g_l2_loss_abs: 0.413
  [val] g_l2_loss_rel: 0.413
  [train] ade: 0.825
  [train] ade_l: 1.674
  [train] ade_nl: 1.626
  [train] d_loss: 1.117
  [train] fde: 1.649
  [train] fde_l: 3.347
  [train] fde_nl: 3.250
  [train] g_l2_loss_abs: 0.369
  [train] g_l2_loss_rel: 0.369
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:14<00:00,  2.91s/it, G_l2=0.719, G_adv=0.672, G_distill=0.719, G_feat=0.119, D=1.27] 
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=0.694, G_adv=0.692, G_distill=0.694, G_feat=0.166, D=1.12] 
100%|██████████| 46/46 [02:03<00:00,  2.69s/it, G_l2=0.919, G_adv=0.687, G_distill=0.919, G_feat=0.156, D=0.463]
100%|██████████| 46/46 [02:11<00:00,  2.87s/it, G_l2=0.497, G_adv=0.688, G_distill=0.497, G_feat=0.128, D=0.928]
100%|██████████| 46/46 [02:14<00:00,  2.92s/it, G_l2=0.525, G_adv=0.692, G_distill=0.525, G_feat=0.256, D=1.07]  
100%|██████████| 46/46 [02:12<00:00,  2.88s/it, G_l2=0.606, G_adv=0.692, G_distill=0.606, G_feat=0.139, D=1.43] 
100%|██████████| 46/46 [02:11<00:00,  2.85s/it, G_l2=0.64, G_adv=0.682, G_distill=0.64, G_feat=0.175, D=0.977]  
100%|██████████| 46/46 [02:10<00:00,  2.83s/it, G_l2=0.594, G_adv=0.688, G_distill=0.594, G_feat=0.121, D=0.606]
100%|██████████| 46/46 [02:12<00:00,  2.87s/it, G_l2=0.513, G_adv=0.692, G_distill=0.513, G_fea

Checking stats on val ...
Checking stats on train ...


 54%|█████▍    | 25/46 [01:11<01:01,  2.93s/it, G_l2=0.756, G_adv=0.679, G_distill=0.756, G_feat=0.123, D=0.556]

  [val] ade: 0.813
  [val] ade_l: 1.710
  [val] ade_nl: 1.550
  [val] d_loss: 0.949
  [val] fde: 1.626
  [val] fde_l: 3.419
  [val] fde_nl: 3.100
  [val] g_l2_loss_abs: 0.383
  [val] g_l2_loss_rel: 0.383
  [train] ade: 0.779
  [train] ade_l: 1.586
  [train] ade_nl: 1.531
  [train] d_loss: 0.850
  [train] fde: 1.569
  [train] fde_l: 3.195
  [train] fde_nl: 3.084
  [train] g_l2_loss_abs: 0.346
  [train] g_l2_loss_rel: 0.346
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:13<00:00,  2.89s/it, G_l2=0.631, G_adv=0.691, G_distill=0.631, G_feat=0.154, D=0.928]
100%|██████████| 46/46 [02:09<00:00,  2.82s/it, G_l2=0.755, G_adv=0.69, G_distill=0.755, G_feat=0.188, D=1.1]   
100%|██████████| 46/46 [02:09<00:00,  2.82s/it, G_l2=0.473, G_adv=0.673, G_distill=0.473, G_feat=0.194, D=1.61] 
100%|██████████| 46/46 [02:27<00:00,  3.21s/it, G_l2=0.896, G_adv=0.677, G_distill=0.896, G_feat=0.231, D=0.767]
100%|██████████| 46/46 [04:14<00:00,  5.54s/it, G_l2=0.574, G_adv=0.687, G_distill=0.574, G_feat=0.15, D=1.18]  
100%|██████████| 46/46 [04:32<00:00,  5.93s/it, G_l2=0.524, G_adv=0.687, G_distill=0.524, G_feat=0.12, D=0.36]  
100%|██████████| 46/46 [04:02<00:00,  5.28s/it, G_l2=0.501, G_adv=0.683, G_distill=0.501, G_feat=0.151, D=1.62] 
100%|██████████| 46/46 [02:09<00:00,  2.81s/it, G_l2=0.714, G_adv=0.668, G_distill=0.714, G_feat=0.144, D=0.803]
100%|██████████| 46/46 [02:09<00:00,  2.81s/it, G_l2=0.62, G_adv=0.687, G_distill=0.62, G_feat=0

Checking stats on val ...
Checking stats on train ...


 57%|█████▋    | 26/46 [01:22<01:43,  5.16s/it, G_l2=0.9, G_adv=0.677, G_distill=0.9, G_feat=0.12, D=0.739]

  [val] ade: 0.826
  [val] ade_l: 1.737
  [val] ade_nl: 1.575
  [val] d_loss: 0.824
  [val] fde: 1.652
  [val] fde_l: 3.474
  [val] fde_nl: 3.151
  [val] g_l2_loss_abs: 0.386
  [val] g_l2_loss_rel: 0.386
  [train] ade: 0.780
  [train] ade_l: 1.578
  [train] ade_nl: 1.542
  [train] d_loss: 0.698
  [train] fde: 1.571
  [train] fde_l: 3.178
  [train] fde_nl: 3.105
  [train] g_l2_loss_abs: 0.340
  [train] g_l2_loss_rel: 0.340
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [02:42<00:00,  3.53s/it, G_l2=0.666, G_adv=0.686, G_distill=0.666, G_feat=0.157, D=0.745]
100%|██████████| 46/46 [03:59<00:00,  5.20s/it, G_l2=0.557, G_adv=0.691, G_distill=0.557, G_feat=0.158, D=0.853]
100%|██████████| 46/46 [03:55<00:00,  5.12s/it, G_l2=0.709, G_adv=0.69, G_distill=0.709, G_feat=0.169, D=0.765] 
100%|██████████| 46/46 [03:57<00:00,  5.15s/it, G_l2=0.644, G_adv=0.686, G_distill=0.644, G_feat=0.17, D=1.22]  
100%|██████████| 46/46 [03:55<00:00,  5.13s/it, G_l2=0.849, G_adv=0.684, G_distill=0.849, G_feat=0.132, D=1.25] 
100%|██████████| 46/46 [03:51<00:00,  5.03s/it, G_l2=0.682, G_adv=0.689, G_distill=0.682, G_feat=0.128, D=1.42] 
100%|██████████| 46/46 [03:41<00:00,  4.82s/it, G_l2=0.819, G_adv=0.687, G_distill=0.819, G_feat=0.169, D=1.25] 
100%|██████████| 46/46 [03:44<00:00,  4.89s/it, G_l2=0.443, G_adv=0.691, G_distill=0.443, G_feat=0.124, D=1.07] 
100%|██████████| 46/46 [03:42<00:00,  4.84s/it, G_l2=0.69, G_adv=0.692, G_distill=0.69, G_feat=0

Checking stats on val ...
Checking stats on train ...


 61%|██████    | 28/46 [02:23<02:09,  7.20s/it, G_l2=0.755, G_adv=0.687, G_distill=0.755, G_feat=0.135, D=0.896]

  [val] ade: 0.829
  [val] ade_l: 1.744
  [val] ade_nl: 1.581
  [val] d_loss: 0.966
  [val] fde: 1.659
  [val] fde_l: 3.489
  [val] fde_nl: 3.164
  [val] g_l2_loss_abs: 0.383
  [val] g_l2_loss_rel: 0.383
  [train] ade: 0.788
  [train] ade_l: 1.594
  [train] ade_nl: 1.558
  [train] d_loss: 0.910
  [train] fde: 1.574
  [train] fde_l: 3.185
  [train] fde_nl: 3.112
  [train] g_l2_loss_abs: 0.342
  [train] g_l2_loss_rel: 0.342
Saving checkpoint to ./saved_models/S_hotel_8_model.pt


100%|██████████| 46/46 [03:50<00:00,  5.01s/it, G_l2=0.502, G_adv=0.684, G_distill=0.502, G_feat=0.12, D=0.664] 
100%|██████████| 46/46 [03:39<00:00,  4.78s/it, G_l2=0.555, G_adv=0.691, G_distill=0.555, G_feat=0.159, D=0.901]
100%|██████████| 46/46 [03:42<00:00,  4.84s/it, G_l2=0.636, G_adv=0.691, G_distill=0.636, G_feat=0.123, D=0.71]  
100%|██████████| 46/46 [03:41<00:00,  4.81s/it, G_l2=0.555, G_adv=0.689, G_distill=0.555, G_feat=0.197, D=0.948]
100%|██████████| 46/46 [03:42<00:00,  4.84s/it, G_l2=0.884, G_adv=0.678, G_distill=0.884, G_feat=0.115, D=1.36] 
100%|██████████| 46/46 [03:46<00:00,  4.92s/it, G_l2=0.616, G_adv=0.687, G_distill=0.616, G_feat=0.163, D=1.64] 
100%|██████████| 46/46 [03:42<00:00,  4.83s/it, G_l2=0.515, G_adv=0.689, G_distill=0.515, G_feat=0.212, D=0.95] 
 85%|████████▍ | 39/46 [03:07<00:29,  4.14s/it, G_l2=0.855, G_adv=0.684, G_distill=0.855, G_feat=0.164, D=0.287]

In [1]:
metrics_val = train.check_accuracy(
    args, val_loader, generator_S, discriminator_S, d_loss_fn
)
print('Checking stats on train ...')
metrics_train = train.check_accuracy(
    args, train_loader, generator_S, discriminator_S,
    d_loss_fn, limit=True
)
for k, v in sorted(metrics_val.items()):
    print('  [val] {}: {:.3f}'.format(k, v))
    checkpoint['metrics_val'][k].append(v)
for k, v in sorted(metrics_train.items()):
    print('  [train] {}: {:.3f}'.format(k, v))
    checkpoint['metrics_train'][k].append(v)

NameError: name 'train' is not defined

In [12]:
for k, v in sorted(metrics_val.items()):
    print('  [val] {}: {:.3f}'.format(k, v))
    checkpoint['metrics_val'][k].append(v)
for k, v in sorted(metrics_train.items()):
    print('  [train] {}: {:.3f}'.format(k, v))
    checkpoint['metrics_train'][k].append(v)

  [val] ade: 0.787
  [val] ade_l: 1.656
  [val] ade_nl: 1.501
  [val] d_loss: 0.817
  [val] fde: 1.606
  [val] fde_l: 3.377
  [val] fde_nl: 3.062
  [val] g_l2_loss_abs: 0.372
  [val] g_l2_loss_rel: 0.372
  [train] ade: 0.719
  [train] ade_l: 1.464
  [train] ade_nl: 1.414
  [train] d_loss: 0.953
  [train] fde: 1.460
  [train] fde_l: 2.972
  [train] fde_nl: 2.870
  [train] g_l2_loss_abs: 0.305
  [train] g_l2_loss_rel: 0.305
