# 测试一组数据，通过tensorboard画结构图

# In [3]:
import argparse
import os
import torch
import pdb

import torch.nn as nn

from attrdict import AttrDict

from sgan.data.loader import data_loader
from sgan.models import TrajectoryGenerator
from sgan.losses import displacement_error, final_displacement_error
from sgan.utils import relative_to_abs, get_dset_path

def load_model(path, split="test", batch_size=1):
    """Load a trained generator checkpoint and build a data loader for it.

    Args:
        path: Path to a checkpoint saved with ``torch.save``. It is loaded
            onto the CPU. It must contain ``'args'``, the training
            hyperparameters, and ``'g_state'``, the generator weights, which
            ``get_generator`` reads.
        split: Dataset split name passed to ``get_dset_path`` together with
            the checkpoint's ``dataset_name`` (default ``"test"``).
        batch_size: Value written into the args as ``batch_size`` before the
            loader is built (default 1).

    Returns:
        A ``(generator, loader)`` tuple: the generator rebuilt by
        ``get_generator`` and the loader returned by ``data_loader``.
    """
    # torch.load returns the dict that was saved: the training args plus the
    # model weights. map_location='cpu' remaps any CUDA tensors onto the CPU.
    checkpoint = torch.load(path, map_location='cpu')
    generator = get_generator(checkpoint)
    # AttrDict wraps the args dict so entries can be read as attributes.
    args = AttrDict(checkpoint['args'])
    path_data = get_dset_path(args.dataset_name, split)
    args.batch_size = batch_size
    _, loader = data_loader(args, path_data)

    return generator, loader


def get_generator(checkpoint):
    """Rebuild the trajectory generator stored in a loaded checkpoint.

    The hyperparameters come from ``checkpoint['args']`` and the weights from
    ``checkpoint['g_state']``. The model is switched to eval mode before it
    is returned.
    """
    hparams = AttrDict(checkpoint['args'])
    # Map TrajectoryGenerator constructor keywords to the saved training args.
    # The generator-specific encoder/decoder sizes are stored with a "_g" suffix.
    ctor_kwargs = dict(
        obs_len=hparams.obs_len,
        pred_len=hparams.pred_len,
        embedding_dim=hparams.embedding_dim,
        encoder_h_dim=hparams.encoder_h_dim_g,
        decoder_h_dim=hparams.decoder_h_dim_g,
        mlp_dim=hparams.mlp_dim,
        num_layers=hparams.num_layers,
        noise_dim=hparams.noise_dim,
        noise_type=hparams.noise_type,
        noise_mix_type=hparams.noise_mix_type,
        pooling_type=hparams.pooling_type,
        pool_every_timestep=hparams.pool_every_timestep,
        dropout=hparams.dropout,
        bottleneck_dim=hparams.bottleneck_dim,
        neighborhood_size=hparams.neighborhood_size,
        grid_size=hparams.grid_size,
        batch_norm=hparams.batch_norm,
    )
    model = TrajectoryGenerator(**ctor_kwargs)
    model.load_state_dict(checkpoint['g_state'])
    # Inference only: disables dropout and puts batch norm in eval mode.
    model.eval()
    return model

# In [4]:
from torch.utils.tensorboard import SummaryWriter


path = 'exp13_with_model.pt'  # checkpoint file, relative to the working dir
generator, loader = load_model(path=path)

# In [7]:
# Trace the generator on the first batch from `loader` and write its
# computation graph to TensorBoard (log dir 'runs/test2'). Then run one
# forward pass on the same batch for inspection.
with SummaryWriter('runs/test2') as writer, torch.no_grad():
    # Only one batch is needed. The result is None if the loader yields
    # nothing, in which case nothing is traced, as with the original loop.
    batch = next(iter(loader), None)
    if batch is not None:
        (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
         non_linear_ped, loss_mask, seq_start_end) = batch

        # add_graph traces the model using these tensors as example inputs.
        writer.add_graph(generator, (obs_traj, obs_traj_rel, seq_start_end))

        # NOTE(review): the original author noted shape [8, 4, 2] here.
        # Presumably (pred_len, num_peds, 2) relative offsets; TODO confirm.
        pred_traj_fake_rel = generator(obs_traj, obs_traj_rel, seq_start_end)