In [None]:
# Activate the project found one directory above the current working directory
# before any package is loaded.
import Pkg
Pkg.activate("..")

In [None]:
using ArgParse: ArgParseSettings, @add_arg_table!, parse_args
using Statistics: mean
using Printf
using Knet
using Debugger: @enter, @bp, @run
using JSON
using PyCall

include("LPCore.jl")
include("setup.jl")

# Command-line options of the planning script.
#
# Every option declares an explicit `arg_type` matching its default. Without
# one, ArgParse uses `Any` and keeps a value given on the command line as the raw
# string, so e.g. `--seed 7` would produce "7" while the default is the integer 42.
s = ArgParseSettings()
@add_arg_table! s begin
    "--dataset"
        help = "which environment to use"
        arg_type = String
        default = "halfcheetah-medium-expert-v2"
    "--exp_name"
        help = "name of the experiment"
        arg_type = String
        default = "debug"
    "--seed"
        help = "seed"
        arg_type = Int
        default = 42
    "--beam_width"
        arg_type = Int
        default = 64
    "--n_expand"
        arg_type = Int
        default = 4
    "--suffix"
        arg_type = String
        default = ""
    "--config"
        help = "relative jl file path with configurations"
        arg_type = String
        default = "../config/vqvae.jl"
end;

In [None]:
@pyimport torch
@pyimport numpy

# Load both PyTorch checkpoints onto the CPU. Without `map_location`, a
# checkpoint saved from a CUDA device cannot be deserialised on a machine
# without a GPU. Both checkpoints are treated the same way; every tensor is
# passed through `cpu()` before conversion further down anyway.
weights = torch.load("test/files/gpt_trained.pt", map_location=torch.device("cpu"))
prior_weights = torch.load("test/files/prior_model.pt", map_location=torch.device("cpu"));

In [None]:
#######################
####### setup ########
#######################

# Fixed argument values for this notebook run, passed to `parser` in place of
# command-line input.
super_args = Dict{String, Any}(
    "dataset" => "hopper-medium-replay-v2",
    "exp_name" => "T-1-42",
    "seed" => 42,
    "config" => "../config/vqvae.jl",
    "beam_width" => 64,
    "n_expand" => 4,
    "suffix" => "",
)
args = parser(super_args, experiment="plan")

# Expand a leading `~` in the configured directories.
for dir_key in ("logbase", "savepath")
    args[dir_key] = expanduser(args[dir_key])
end

# Experiment directory: <logbase>/<dataset>/<exp_name>.
args["loadpath"] = joinpath(args["logbase"], args["dataset"], args["exp_name"])

In [None]:
env = load_environment(args["dataset"])
dataset_config = Knet.load(joinpath(args["loadpath"], "dataset_config.jld2"), "config")

# Each keyword passed to `SequenceDataset` is stored in the saved configuration
# under the same name, so the keywords are forwarded by name.
dataset_keywords = (
    "penalty",
    "sequence_length",
    "step",
    "discount",
    "disable_goal",
    "normalize_raw",
    "normalize_reward",
    "max_path_length",
    "atype",
)

dataset = SequenceDataset(
    dataset_config["env_name"];
    (Symbol(name) => dataset_config[name] for name in dataset_keywords)...
);

# Representation model init and weight loading

In [None]:
# Read the representation model's configuration from the experiment directory
# stored in `args["loadpath"]` (the directory the dataset configuration is read
# from) rather than from an absolute path inside one user's home directory.
model_config = Knet.load(joinpath(args["loadpath"], "model_config.jld2"), "config")

In [None]:
# Build the model from the loaded configuration; its parameters are replaced
# with the values from the PyTorch checkpoint in the following cell.
gpt = VQContinuousVAE(model_config);

In [None]:
# Convert the tensor stored under `key` of a PyTorch state dict into a Knet
# `Param` of the configured array type.
torch_param(state, key) = Param(atype(state[key][:cpu]()[:numpy]()))

# Copy "<prefix>.weight" and "<prefix>.bias" from `state` into the fields
# `weight_field` and `bias_field` of `layer`.
function load_affine!(layer, state, prefix; weight_field=:w, bias_field=:b)
    setproperty!(layer, weight_field, torch_param(state, "$(prefix).weight"))
    setproperty!(layer, bias_field, torch_param(state, "$(prefix).bias"))
    return layer
end

# Copy the layer norms, attention projections and MLP layers of one transformer
# block whose checkpoint keys start with `prefix`.
function load_block!(block, state, prefix)
    load_affine!(block.ln1, state, "$(prefix).ln1"; weight_field=:a, bias_field=:b)
    load_affine!(block.ln2, state, "$(prefix).ln2"; weight_field=:a, bias_field=:b)
    for name in (:key, :query, :value, :proj)
        load_affine!(getproperty(block.attn, name), state, "$(prefix).attn.$(name)")
    end
    # Entries 1 and 3 of the Julia MLP take the checkpoint's "mlp.0" and "mlp.2".
    for (julia_index, torch_index) in ((1, 0), (3, 2))
        load_affine!(block.mlp.layers[julia_index], state, "$(prefix).mlp.$(torch_index)")
    end
    return block
end

# encoder
load_affine!(gpt.model.embed, weights, "model.embed")

gpt.model.pos_emb = Param(atype(permutedims(weights["model.pos_emb"][:cpu]()[:numpy](), (3,2,1))))

# Checkpoint layer indices are 0-based, Julia layer indices are 1-based.
for i in 1:model_config["n_layer"]
    load_block!(gpt.model.encoder.layers[i], weights, "model.encoder.$(i-1)")
end

load_affine!(gpt.model.cast_embed, weights, "model.cast_embed")

# Decoder
load_affine!(gpt.model.latent_mixing, weights, "model.latent_mixing")

for i in 1:model_config["n_layer"]
    load_block!(gpt.model.decoder.layers[i], weights, "model.decoder.$(i-1)")
end

load_affine!(gpt.model.ln_f, weights, "model.ln_f"; weight_field=:a, bias_field=:b)

load_affine!(gpt.model.predict, weights, "model.predict")

# codebook (`embedding` and `ema_w` are stored transposed relative to the checkpoint)
gpt.model.codebook.embedding = Param(atype(weights["model.codebook.embedding"][:cpu]()[:numpy]()'))
gpt.model.codebook.ema_count = torch_param(weights, "model.codebook.ema_count")
gpt.model.codebook.ema_w = Param(atype(weights["model.codebook.ema_w"][:cpu]()[:numpy]()'))

# padding vector
gpt.padding_vector = atype(normalize_joined_single(dataset, atype(zeros(gpt.transition_dim-1))));

# TransformerPrior Init and Model Loading

In [None]:
# Parse the same notebook arguments again, this time for the "train" experiment.
args = parser(super_args, experiment="train")
args["logbase"] = expanduser(args["logbase"])
# NOTE(review): the value assigned here is overwritten by the next statement.
args["savepath"] = expanduser(args["savepath"])
# NOTE(review): absolute path inside one user's home directory; it will not
# exist on other machines. Consider deriving it from `args["logbase"]`.
args["savepath"] = "/Users/mehmeteneserciyes/logs_julia/hopper-medium-replay-v2/T-1-42/"
# Integer quotient of the sub-sampled sequence length and the latent step.
block_size = args["subsampled_sequence_length"] ÷ args["latent_step"]
obs_dim = dataset.observation_dim

In [None]:
# Configuration for the prior: an independent copy of the parsed arguments plus
# the derived sizes. `n_embd` becomes the parsed `n_embd` multiplied by `n_head`.
model_config = deepcopy(args)
for (setting, value) in (
    "block_size" => block_size,
    "observation_dim" => obs_dim,
    "n_embd" => args["n_embd"] * args["n_head"],
)
    model_config[setting] = value
end

In [None]:
# Disable all three dropout probabilities.
for dropout_key in ("embd_pdrop", "attn_pdrop", "resid_pdrop")
    model_config[dropout_key] = 0.0f0
end

In [None]:
# Build the prior from the configuration assembled above; its parameters are
# replaced with the values from the PyTorch checkpoint in the following cell.
prior = TransformerPrior(model_config);

In [None]:
# Tensor stored under `key` in the prior checkpoint, after `cpu()` and `numpy()`.
prior_array(key) = prior_weights[key][:cpu]()[:numpy]()
# The same tensor wrapped as a Knet `Param` of the configured array type.
prior_param(key) = Param(atype(prior_array(key)))

# Copy the layer norms, attention projections and MLP layers of one prior block
# whose checkpoint keys start with `tag`.
function load_prior_block!(blk, tag)
    for (norm, name) in ((blk.ln1, "ln1"), (blk.ln2, "ln2"))
        norm.a = prior_param("$(tag).$(name).weight")
        norm.b = prior_param("$(tag).$(name).bias")
    end
    for name in (:key, :query, :value, :proj)
        projection = getproperty(blk.attn, name)
        projection.w = prior_param("$(tag).attn.$(name).weight")
        projection.b = prior_param("$(tag).attn.$(name).bias")
    end
    # Entries 1 and 3 of the Julia MLP take the checkpoint's "mlp.0" and "mlp.2".
    for (julia_index, torch_index) in ((1, 0), (3, 2))
        dense = blk.mlp.layers[julia_index]
        dense.w = prior_param("$(tag).mlp.$(torch_index).weight")
        dense.b = prior_param("$(tag).mlp.$(torch_index).bias")
    end
    return blk
end

# encoder
prior.tok_emb = Param(atype(prior_array("tok_emb.weight")'))
prior.pos_emb = Param(atype(permutedims(prior_array("pos_emb"), (3,2,1))))

prior.state_emb.w = prior_param("state_emb.weight")
prior.state_emb.b = prior_param("state_emb.bias")

# Checkpoint block indices are 0-based, Julia layer indices are 1-based.
for i in 1:model_config["n_layer"]
    load_prior_block!(prior.blocks.layers[i], "blocks.$(i-1)")
end

prior.ln_f.a = prior_param("ln_f.weight")
prior.ln_f.b = prior_param("ln_f.bias")
# no bias
prior.head.w = prior_param("head.weight");

# Planning part

In [None]:
# Reference observation and state loaded from the NumPy files under test/files.
observation_gt = numpy.load("test/files/plan_observation.npy")
state_gt = numpy.load("test/files/plan_state.npy")

In [None]:
# Quantities taken from the dataset.
discount = dataset.discount
observation_dim, action_dim = dataset.observation_dim, dataset.action_dim

#######################
###### main loop ######
#######################
# A transition has one reward slot and one value slot in addition to the
# observation and action dimensions.
REWARD_DIM = 1
VALUE_DIM = 1
transition_dim = observation_dim + action_dim + REWARD_DIM + VALUE_DIM

# Start from the reference observation with zeroed return counters.
observation = observation_gt
total_reward = 0
discount_return = 0

In [None]:
# One-element vector holding an independent copy of the reference state.
rollout = [deepcopy(state_gt)]

In [None]:
## previous (tokenized) transitions for conditioning transformer
context = Any[]
# Second empty, untyped accumulator.
mses = Any[]

In [None]:
# Step limit read from the environment's `max_episode_steps`.
T = env.max_episode_steps

In [None]:
# Point the working observation and state at the loaded reference values.
observation = observation_gt;
state = state_gt;

In [None]:
# Normalise the observation only when the dataset's `normalized_raw` flag is set.
if dataset.normalized_raw
    println("normalize")
    observation = normalize_states(dataset, observation) # TODO: implement normalize_states
end

In [None]:
"""
    make_prefix(obs, transition_dim)

Convert `obs` with `atype` and append zeros along the first dimension so that
this dimension reaches length `transition_dim`.

A two-dimensional `obs` is reshaped to `(:, 1, 1)`, together with the zero
padding, before the two are concatenated. Inputs of any other rank are
concatenated with the one-dimensional padding as they are.
"""
function make_prefix(obs, transition_dim)
    observed = atype(obs)
    padding = atype(zeros(transition_dim - size(observed, 1)))
    if ndims(observed) == 2
        # NOTE(review): the padding length comes from the first dimension only,
        # so this presumably expects a single-column `obs` — confirm with callers.
        observed, padding = map(a -> reshape(a, :, 1, 1), (observed, padding))
    end
    return cat(observed, padding, dims=1)
end

"""
    extract_actions(x, observation_dim, action_dim, t=nothing)

Return the action rows of the matrix `x`: the first `observation_dim` rows are
skipped and the following `action_dim` rows are returned.

When `t` is given, only column `t` of those rows is returned; otherwise all
columns are returned.
"""
function extract_actions(x, observation_dim, action_dim, t=nothing)
    # Julia ranges are 1-based and inclusive: the actions begin one row after the
    # last observation row and span exactly `action_dim` rows.
    action_rows = (observation_dim + 1):(observation_dim + action_dim)
    actions = x[action_rows, :]
    return isnothing(t) ? actions : actions[:, t]
end

In [None]:
# Zero-pad the current observation to the full transition width.
prefix = make_prefix(observation, transition_dim)

# Beam with prior