# In [None]:
#!/usr/bin/env python
from pointer_network import PointerNetwork
from options import get_options
from problem import TSP
import torch
import pprint as pp
from critic_network import CriticNetwork, CriticBaseline
import torch.optim as optim
from tensorboard_logger import Logger as TbLogger

# Fetch the run options and echo them so each run records its arguments.
opts = get_options()
pp.pprint(vars(opts))

# Seed PyTorch's random number generator from the configured seed.
torch.manual_seed(opts.seed)

# Use the first CUDA device when one is available, otherwise fall back to the
# CPU. The choice is stored on `opts` so later setup code reads it from there.
if torch.cuda.is_available():
    opts.device = torch.device("cuda:0")
else:
    opts.device = torch.device("cpu")


# Load previously saved data from opts.load_path or opts.resume (at most one
# of the two may be set). `load_data` stays an empty dict when neither is given.
load_data = {}
# Raise explicitly rather than `assert`: assertions are stripped when Python
# runs with -O, which would silently accept both options at once.
if opts.load_path is not None and opts.resume is not None:
    raise ValueError("Only one of load path and resume can be given")
load_path = opts.load_path if opts.load_path is not None else opts.resume
if load_path is not None:
    print('  [*] Loading data from {}'.format(load_path))
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    # The map_location callable returns every storage unchanged, i.e. tensors
    # are not remapped to the device they were saved from.
    load_data = torch.load(load_path, map_location=lambda storage, loc: storage)

# Build the pointer network for the TSP problem from the run options.
model = PointerNetwork(
    opts.embedding_dim,
    opts.hidden_dim,
    TSP,
    n_encode_layers=opts.n_encode_layers,
    mask_inner=True,
    mask_logits=True,
    normalization=opts.normalization,
    tanh_clipping=opts.tanh_clipping,
    checkpoint_encoder=opts.checkpoint_encoder,
    shrink_size=opts.shrink_size,
)
# Move the network to the device selected earlier.
model = model.to(opts.device)

# Start from the model's current state and overwrite it with whatever entries
# the loaded data provides under 'model' (none when nothing was loaded), so
# keys missing from the loaded data keep their current values.
merged_state = dict(model.state_dict())
merged_state.update(load_data.get('model', {}))
model.load_state_dict(merged_state)


# Build the baseline selected by opts.baseline. Only the 'critic' baseline is
# constructed in this script.
if opts.baseline == 'critic':
    baseline = CriticBaseline(
        CriticNetwork(
            2,  # presumably the per-node input dimension (2-D coordinates) — TODO confirm
            opts.embedding_dim,
            opts.hidden_dim,
            opts.n_encode_layers,
            opts.normalization
        ).to(opts.device)
    )
else:
    # Fail fast with a clear message: the optimizer setup that follows uses
    # `baseline` unconditionally, so any other value would otherwise surface
    # as an unrelated-looking NameError.
    raise ValueError(
        "Unsupported baseline {!r}: only 'critic' is available in this script".format(opts.baseline)
    )
    

# Assemble the Adam parameter groups: the model always gets its own group and
# learning rate; the baseline contributes a second group, with the critic
# learning rate, only when it reports learnable parameters.
param_groups = [{'params': model.parameters(), 'lr': opts.lr_model}]
if len(baseline.get_learnable_parameters()) > 0:
    param_groups.append(
        {'params': baseline.get_learnable_parameters(), 'lr': opts.lr_critic}
    )

# Create the optimizer over all collected parameter groups.
optimizer = optim.Adam(param_groups)