In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.chdir("..")

In [5]:
import math
from collections import OrderedDict

import numpy
import os
import ipdb
import random
import torch
import torch.optim as optim
from os import path

import wandb

import planning
import utils
from dataloader import DataLoader

In [11]:
opt = utils.parse_command_line(args=[])
opt.model_file = path.join(opt.model_dir, "policy_networks", "MPUR-" + opt.policy)
utils.build_model_file_name(opt)

os.system("mkdir -p " + path.join(opt.model_dir, "policy_networks"))

random.seed(opt.seed)
numpy.random.seed(opt.seed)
torch.manual_seed(opt.seed)

# Define default device
opt.device = torch.device(
    "cuda" if torch.cuda.is_available() and not opt.no_cuda else "cpu"
)
if torch.cuda.is_available() and opt.no_cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run without -no_cuda"
    )

# load the model

model_path = path.join(opt.model_dir, opt.mfile)
if path.exists(model_path):
    model = torch.load(model_path)
elif path.exists(opt.mfile):
    model = torch.load(opt.mfile)
else:
    raise RuntimeError(f"couldn't find file {opt.mfile}")
print("Loaded model")
if not hasattr(model.encoder, "n_channels"):
    model.encoder.n_channels = 3

if type(model) is dict:
    model = model["model"]
model.opt.lambda_l = opt.lambda_l  # used by planning.py/compute_uncertainty_batch
model.opt.lambda_o = opt.lambda_o  # used by planning.py/compute_uncertainty_batch
if opt.value_model != "":
    value_function = torch.load(
        path.join(opt.model_dir, "value_functions", opt.value_model)
    ).to(opt.device)
    model.value_function = value_function

# Create policy
model.create_policy_net(opt)
optimizer = optim.Adam(model.policy_net.parameters(), opt.lrt)  # POLICY optimiser ONLY!
print("Policy created")
# Load normalisation stats
stats = torch.load("traffic-data/state-action-cost/data_i80_v0/data_stats.pth")
model.stats = stats  # used by planning.py/compute_uncertainty_batch
print("Normalization loaded")
if "ten" in opt.mfile:
    p_z_file = opt.model_dir + opt.mfile + ".pz"
    p_z = torch.load(p_z_file)
    model.p_z = p_z

# Send to GPU if possible
model.to(opt.device)
model.policy_net.stats_d = {}
for k, v in stats.items():
    if isinstance(v, torch.Tensor):
        model.policy_net.stats_d[k] = v.to(opt.device)
if opt.learned_cost:
    print("[loading cost regressor]")
    model.cost = torch.load(path.join(opt.model_dir, opt.mfile + ".cost.model"))[
        "model"
    ]
print("Model setup completed")
model.train()
model.opt.u_hinge = opt.u_hinge
planning.estimate_uncertainty_stats(model, dataloader, n_batches=50, npred=opt.npred)
print("Uncertainty stats estimated")
model.eval()
print("done")

[will save as: models/policy_networks/MPUR-policy-deterministic-model=vae-zdropout=0.5-nfeature=256-bsize=6-npred=30-ureg=0.05-lambdal=0.2-lambdaa=0.0-gamma=0.99-lrtz=0.0-updatez=0-inferz=False-learnedcost=False-seed=1-novalue]
Loaded model
Policy created
Normalization loaded
Model setup completed
done


In [None]:
dataloader = DataLoader(None, opt, opt.dataset)
print("Data loaded")

### Train step

In [16]:
def step(what, nbatches, npred):
    train = True if what == "train" else False
    model.train()
    model.policy_net.train()
    n_updates, grad_norm = 0, 0
    total_losses = dict(
        proximity=0,
        uncertainty=0,
        lane=0,
        offroad=0,
        action=0,
        policy=0,
    )
    inputs, actions, targets, ids, car_sizes = dataloader.get_batch_fm(what, npred)
    pred, actions = planning.train_policy_net_mpur(
            model,
            inputs,
            targets,
            car_sizes,
            n_models=10,
            lrt_z=opt.lrt_z,
            n_updates_z=opt.z_updates,
            infer_z=opt.infer_z,
        )
    pred["policy"] = (
        pred["proximity"]
        + opt.u_reg * pred["uncertainty"]
        + opt.lambda_l * pred["lane"]
        + opt.lambda_a * pred["action"]
        + opt.lambda_o * pred["offroad"]
    )

    if not math.isnan(pred["policy"].item()):
        if train:
            optimizer.zero_grad()
            pred["policy"].backward()  # back-propagation through time!
            print(pred["policy"])
            grad_norm += utils.grad_norm(model.policy_net).item()
            torch.nn.utils.clip_grad_norm_(
                model.policy_net.parameters(), opt.grad_clip
            )
            optimizer.step()
        for loss in total_losses:
            total_losses[loss] += pred[loss].item()
        n_updates += 1
    else:
        print("warning, NaN")  # Oh no... Something got quite fucked up!
        ipdb.set_trace()
step("train", opt.epoch_size, opt.npred)

tensor(0.4407, device='cuda:0', grad_fn=<AddBackward0>)
