In [None]:
import Pkg
Pkg.activate("..")

In [None]:
using ArgParse: ArgParseSettings, @add_arg_table!, parse_args
using Statistics: mean
using Printf
using Knet
using Debugger: @enter, @bp, @run
using JSON
using PyCall

include("LPCore.jl")
include("setup.jl")

# Command-line interface of the planning script.
# Every numeric option declares an explicit `arg_type`: ArgParse does not derive
# the type from `default`, so without it a value given on the command line
# (e.g. `--seed 7`) would be stored as the string "7" instead of the integer 7.
s = ArgParseSettings()
@add_arg_table! s begin
    "--dataset"
        help = "which environment to use"
        arg_type = String
        default = "halfcheetah-medium-expert-v2"
    "--exp_name"
        help = "name of the experiment"
        arg_type = String
        default = "debug"
    "--seed"
        help = "seed"
        arg_type = Int
        default = 42
    "--beam_width"
        help = "number of candidates kept after each beam-search step"
        arg_type = Int
        default = 64
    "--n_expand"
        help = "number of latent codes sampled per candidate when expanding the beam"
        arg_type = Int
        default = 4
    "--suffix"
        arg_type = String
        default = ""
    "--config"
        help = "relative jl file path with configurations"
        arg_type = String
        default = "../config/vqvae.jl"
end;

In [None]:
# Python modules used through PyCall.
@pyimport torch
@pyimport numpy
@pyimport random

# Seed every Python-side random number generator so the torch-based sampling
# performed later in this notebook is reproducible.
seed = 42
numpy.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

# Both checkpoints are mapped onto the CPU while loading. Without `map_location`
# a checkpoint saved from a GPU cannot be deserialised on a CPU-only host, and the
# tensors are moved to the CPU anyway before they are converted to Julia arrays.
weights = torch.load("test/files/gpt_trained.pt", map_location=torch.device("cpu"))
prior_weights = torch.load("test/files/prior_trained.pt", map_location=torch.device("cpu"));

In [None]:
# ------------------------------------------------------------------
# Setup: experiment arguments and the directories derived from them
# ------------------------------------------------------------------

# Overrides handed to `parser` in place of command-line arguments.
super_args = Dict{String, Any}(
    "dataset" => "hopper-medium-replay-v2",
    "exp_name" => "T-1-42",
    "seed" => 42,
    "config" => "../config/vqvae.jl",
    "beam_width" => 64,
    "n_expand" => 4,
    "suffix" => "",
)
args = parser(super_args, experiment="plan")

# Expand a leading "~" in the configured directories.
for dir_key in ("logbase", "savepath")
    args[dir_key] = expanduser(args[dir_key])
end

# Artefacts of a run are read from <logbase>/<dataset>/<exp_name>.
args["loadpath"] = joinpath(args["logbase"], args["dataset"], args["exp_name"])

In [None]:
env = load_environment(args["dataset"])
dataset_config = Knet.load(joinpath(args["loadpath"], "dataset_config.jld2"), "config")

# Rebuild the dataset with the options stored in the saved configuration. Each
# keyword passed to `SequenceDataset` is read from the config entry of the same name.
dataset = let option_names = ("penalty", "sequence_length", "step", "discount",
                              "disable_goal", "normalize_raw", "normalize_reward",
                              "max_path_length", "atype")
    SequenceDataset(
        dataset_config["env_name"];
        (Symbol(name) => dataset_config[name] for name in option_names)...
    )
end;

# Representation model init and weight loading

In [None]:
# Load the representation-model configuration from the experiment directory held in
# `args["loadpath"]` (the same directory `dataset_config.jld2` is read from) instead
# of a machine-specific absolute path, so the notebook also runs on other machines.
model_config = Knet.load(joinpath(args["loadpath"], "model_config.jld2"), "config")

In [None]:
# Instantiate the model from the stored configuration; its parameters are replaced
# by the values of the PyTorch checkpoint in the following cell.
gpt = VQContinuousVAE(model_config);

In [None]:
# Copy the parameters of the PyTorch checkpoint `weights` into `gpt`.
# PyCall's dot syntax (`obj.method()`) is used instead of the deprecated
# `obj[:method]()` indexing, matching the other PyCall calls in this notebook.

# Host-side copy of one checkpoint tensor, converted to a Julia array.
gpt_array(key) = weights[key].cpu().numpy()
# The same tensor converted with `atype` and wrapped as a trainable parameter.
gpt_param(key) = Param(atype(gpt_array(key)))

"""
    load_gpt_block!(block, prefix)

Overwrite the layer-norm, attention and MLP parameters of one transformer `block`
with the entries of `weights` whose keys start with `prefix`
(e.g. `"model.encoder.0"`). Return `block`.
"""
function load_gpt_block!(block, prefix)
    block.ln1.a = gpt_param("$prefix.ln1.weight")
    block.ln1.b = gpt_param("$prefix.ln1.bias")
    block.ln2.a = gpt_param("$prefix.ln2.weight")
    block.ln2.b = gpt_param("$prefix.ln2.bias")

    for name in ("key", "query", "value", "proj")
        projection = getproperty(block.attn, Symbol(name))
        projection.w = gpt_param("$prefix.attn.$name.weight")
        projection.b = gpt_param("$prefix.attn.$name.bias")
    end

    # Checkpoint entries `mlp.0` and `mlp.2` (0-based) belong to the Julia layers
    # 1 and 3 (1-based); no checkpoint entry is read for the layer in between.
    for (julia_index, torch_index) in ((1, 0), (3, 2))
        block.mlp.layers[julia_index].w = gpt_param("$prefix.mlp.$torch_index.weight")
        block.mlp.layers[julia_index].b = gpt_param("$prefix.mlp.$torch_index.bias")
    end
    return block
end

# encoder
gpt.model.embed.w = gpt_param("model.embed.weight")
gpt.model.embed.b = gpt_param("model.embed.bias")

# The positional embedding is stored with its three axes in reverse order.
gpt.model.pos_emb = Param(atype(permutedims(gpt_array("model.pos_emb"), (3, 2, 1))))

for i in 1:model_config["n_layer"]
    load_gpt_block!(gpt.model.encoder.layers[i], "model.encoder.$(i-1)")
end

gpt.model.cast_embed.w = gpt_param("model.cast_embed.weight")
gpt.model.cast_embed.b = gpt_param("model.cast_embed.bias")

# Decoder
gpt.model.latent_mixing.w = gpt_param("model.latent_mixing.weight")
gpt.model.latent_mixing.b = gpt_param("model.latent_mixing.bias")

for i in 1:model_config["n_layer"]
    load_gpt_block!(gpt.model.decoder.layers[i], "model.decoder.$(i-1)")
end

gpt.model.ln_f.a = gpt_param("model.ln_f.weight")
gpt.model.ln_f.b = gpt_param("model.ln_f.bias")

gpt.model.predict.w = gpt_param("model.predict.weight")
gpt.model.predict.b = gpt_param("model.predict.bias")

# codebook (the two matrices are transposed relative to the checkpoint layout)
gpt.model.codebook.embedding = Param(atype(gpt_array("model.codebook.embedding")'))
gpt.model.codebook.ema_count = gpt_param("model.codebook.ema_count")
gpt.model.codebook.ema_w = Param(atype(gpt_array("model.codebook.ema_w")'))

# padding vector: the normalised form of an all-zero vector of length transition_dim - 1
gpt.padding_vector = atype(normalize_joined_single(dataset, atype(zeros(gpt.transition_dim-1))));

# TransformerPrior Init and Model Loading

In [None]:
# Parse the arguments again, this time with the "train" experiment settings that
# describe the prior model.
args = parser(super_args, experiment="train")
for dir_key in ("logbase", "savepath")
    args[dir_key] = expanduser(args[dir_key])
end
# NOTE(review): machine-specific absolute path; it overrides the value expanded above.
args["savepath"] = "/Users/mehmeteneserciyes/logs_julia/hopper-medium-replay-v2/T-1-42/"
block_size = div(args["subsampled_sequence_length"], args["latent_step"])
obs_dim = dataset.observation_dim

In [None]:
# The model configuration is the argument dictionary plus derived sizes.
# `n_embd` is multiplied by the number of heads.
model_config = deepcopy(args)
model_config["block_size"] = block_size
model_config["observation_dim"] = obs_dim
model_config["n_embd"] = args["n_embd"] * args["n_head"]

In [None]:
# turn off dropout
for rate_key in ("embd_pdrop", "attn_pdrop", "resid_pdrop")
    model_config[rate_key] = 0.0f0
end

In [None]:
prior = TransformerPrior(model_config);

In [None]:
# Copy the parameters of the PyTorch checkpoint `prior_weights` into `prior`.
# PyCall's dot syntax (`obj.method()`) is used instead of the deprecated
# `obj[:method]()` indexing, matching the other PyCall calls in this notebook.

# Host-side copy of one checkpoint tensor, converted to a Julia array.
prior_array(key) = prior_weights[key].cpu().numpy()
# The same tensor converted with `atype` and wrapped as a trainable parameter.
prior_param(key) = Param(atype(prior_array(key)))

"""
    load_prior_block!(block, prefix)

Overwrite the layer-norm, attention and MLP parameters of one transformer `block`
with the entries of `prior_weights` whose keys start with `prefix`
(e.g. `"blocks.0"`). Return `block`.
"""
function load_prior_block!(block, prefix)
    block.ln1.a = prior_param("$prefix.ln1.weight")
    block.ln1.b = prior_param("$prefix.ln1.bias")
    block.ln2.a = prior_param("$prefix.ln2.weight")
    block.ln2.b = prior_param("$prefix.ln2.bias")

    for name in ("key", "query", "value", "proj")
        projection = getproperty(block.attn, Symbol(name))
        projection.w = prior_param("$prefix.attn.$name.weight")
        projection.b = prior_param("$prefix.attn.$name.bias")
    end

    # Checkpoint entries `mlp.0` and `mlp.2` (0-based) belong to the Julia layers
    # 1 and 3 (1-based); no checkpoint entry is read for the layer in between.
    for (julia_index, torch_index) in ((1, 0), (3, 2))
        block.mlp.layers[julia_index].w = prior_param("$prefix.mlp.$torch_index.weight")
        block.mlp.layers[julia_index].b = prior_param("$prefix.mlp.$torch_index.bias")
    end
    return block
end

# encoder
# The token embedding is transposed and the positional embedding has its three
# axes reversed relative to the checkpoint layout.
prior.tok_emb = Param(atype(prior_array("tok_emb.weight")'))
prior.pos_emb = Param(atype(permutedims(prior_array("pos_emb"), (3, 2, 1))))

prior.state_emb.w = prior_param("state_emb.weight")
prior.state_emb.b = prior_param("state_emb.bias")

for i in 1:model_config["n_layer"]
    load_prior_block!(prior.blocks.layers[i], "blocks.$(i-1)")
end

prior.ln_f.a = prior_param("ln_f.weight")
prior.ln_f.b = prior_param("ln_f.bias")
# no bias
prior.head.w = prior_param("head.weight");

# Planning part

In [None]:
# Switch back to the "plan" experiment settings for the planning part.
args = parser(super_args, experiment="plan")
for dir_key in ("logbase", "savepath")
    args[dir_key] = expanduser(args[dir_key])
end

In [None]:
# Stored observation and simulator state used as the starting point of the plan.
observation_gt = numpy.load("test/files/plan_observation.npy")
state_gt = numpy.load("test/files/plan_state.npy")

In [None]:
discount, observation_dim, action_dim =
    dataset.discount, dataset.observation_dim, dataset.action_dim

# ---------------------------
# main loop
# ---------------------------
# A transition is an observation, an action, one reward entry and one value entry.
REWARD_DIM = 1
VALUE_DIM = 1
transition_dim = observation_dim + action_dim + REWARD_DIM + VALUE_DIM

observation = observation_gt
total_reward, discount_return = 0, 0

# Loop

In [None]:
# Index of the current planning step.
t = 0;

In [None]:
# Visited simulator states, starting with a copy of the stored one.
rollout = [deepcopy(state_gt)]

In [None]:
## previous (tokenized) transitions for conditioning transformer
context, mses = [], []

In [None]:
T = env.max_episode_steps

In [None]:
observation, state = observation_gt, state_gt;

In [None]:
# Bring the observation to the normalised scale when the dataset is normalised.
if dataset.normalized_raw
    println("normalize")
    # TODO: implement normalize_states
    observation = normalize_states(dataset, observation)
end

In [None]:

#######################
####### util functions ########
#######################

"""
    make_prefix(obs, transition_dim)

Return `obs` (converted with `atype`) extended along its first dimension with
zeros so that the result has `transition_dim` rows. When the converted
observation is two-dimensional, both it and the padding are first reshaped to
`(:, 1, 1)`.
"""
function make_prefix(obs, transition_dim)
    obs_part = atype(obs)
    # Zeros for every row of the transition that the observation does not fill.
    padding = atype(zeros(transition_dim - size(obs_part, 1)))
    if ndims(obs_part) == 2
        obs_part = reshape(obs_part, :, 1, 1)
        padding = reshape(padding, :, 1, 1)
    end
    return cat(obs_part, padding, dims=1)
end

"""
    extract_actions(x, observation_dim, action_dim; t=nothing)

Return the rows of `x` that follow the first `observation_dim` rows and span
`action_dim` rows. With the keyword `t`, only column(s) `t` of that slice are
returned; otherwise the whole slice is returned.
"""
function extract_actions(x, observation_dim, action_dim; t=nothing)
    action_rows = (observation_dim + 1):(observation_dim + action_dim)
    actions = x[action_rows, :]
    return t === nothing ? actions : actions[:, t]
end

# Value entry written into a context transition.
VALUE_PLACEHOLDER = 1e6

"""
    update_context(observation, action, reward)

Return a new one-element context list. Its element is the transition formed by
concatenating `observation`, `action`, `reward` and `VALUE_PLACEHOLDER` along the
first dimension, converted with `atype` and reshaped to `(:, 1, 1)`.
"""
function update_context(observation, action, reward)
    transition = cat(observation, action, [reward; VALUE_PLACEHOLDER]; dims=1)
    token = reshape(atype(transition), :, 1, 1)
    return Any[token]
end


In [None]:
# Planner input: the current observation padded with zeros up to `transition_dim`
# rows (see `make_prefix`).
prefix = make_prefix(observation, transition_dim)

In [None]:
# Search for a plan with the prior and the representation model. All search
# hyper-parameters are taken from the "plan" configuration in `args`.
sequence = beam_with_prior(prior, gpt, prefix, dataset, 
                discount=discount, 
                steps=args["horizon"],
                beam_width=args["beam_width"],
                n_expand=args["n_expand"],
                likelihood_weight=args["prob_weight"],
                prob_threshold=args["prob_threshold"]);

In [None]:
# Denormalise two entries of the second-to-last row of the plan: the one in the
# first column and the one in the last column.
# NOTE(review): presumably that row holds the value estimate - confirm against
# the layout returned by `beam_with_prior`.
first_value = denormalize_values(dataset, sequence[end-1, 1])
first_search_value = denormalize_values(dataset, sequence[end-1,end])

In [None]:
sequence_recon = sequence
## [ action_dim ] index into sampled latentplan to grab first action
feature_dim = dataset.observation_dim
action = extract_actions(sequence_recon, feature_dim, action_dim; t=1) 

In [None]:
# For a normalised dataset, map the action (flattened to a vector) and the whole
# plan back to the unnormalised scale.
if dataset.normalized_raw
    action = reshape(denormalize_actions(dataset, action), :)
    sequence_recon = denormalize_joined(dataset, sequence_recon)
end

In [None]:
# Apply the planned action to the environment.
next_observation, reward, terminal, _ = env.step(action)

In [None]:
total_reward += reward
# `t` is 0-based in this notebook (it is initialised to 0), so the reward of the
# first step has to be weighted by discount^0 = 1. The previous exponent `t-1`
# weighted it by 1/discount.
discount_return += reward * (discount ^ t)
score = env.get_normalized_score(total_reward)

In [None]:
# Record the state and rebuild the conditioning context from the latest transition.
push!(rollout, deepcopy(state))
context = update_context(observation, action, reward)

In [None]:
# One-line progress report: step counter, last reward, accumulated return,
# normalised score and the run identifiers. The "time:" field of the format string
# has no argument and therefore prints empty.
@printf("[ plan ] t: %d / %d | r: %.2f | R: %.2f | score: %.4f | time: | %s | %s | %s\n", 
t, T, reward, total_reward, score,
args["dataset"], args["exp_name"], args["suffix"])

In [None]:
# Destination of the rollout summary.
# NOTE(review): nothing visible in this notebook writes to this path - confirm
# whether the summary is meant to be saved here.
json_path = joinpath(args["savepath"], "rollout.json")

In [None]:
# Summary of the rollout so far. "gpt_epoch" is a fixed literal; the prediction
# error entry is disabled.
json_data = Dict(
    "score" => score,
    "step" => t,
    "return" => total_reward,
    "term" => terminal,
    "gpt_epoch" => 123,
    "first_value" => first_value,
    "first_search_value" => first_search_value,
    "discount_return" => discount_return,
    # "prediction_error" => mean(mses)
)

# Beam with prior

In [None]:
# Bind global names that match the arguments passed to `beam_with_prior` earlier
# in this notebook, so the following cells can use them directly.
# `prior`, `dataset` and `discount` are assigned to themselves, which changes nothing.
prior=prior
model=gpt
x=prefix
dataset=dataset
discount=discount;
steps=args["horizon"]
beam_width=args["beam_width"]
n_expand=args["n_expand"]
likelihood_weight=args["prob_weight"]
prob_threshold=args["prob_threshold"]

In [None]:
# Initial search state: no latent codes chosen yet (`contex`), the first
# `prior.observation_dim` rows of the prefix as conditioning state, a single
# accumulated log-probability of zero and an empty record of per-step results.
contex = nothing
state = x[1:prior.observation_dim, 1, :]
acc_probs = atype(zeros(1))
info = Dict();

In [None]:
# Display the conditioning state.
state

In [None]:
# for loop starts
step = 0

In [None]:
logits, _ = prior(contex, state)
probs = softmax(logits[:, end, :], dims=1)
log_probs = log.(probs)

In [None]:
probs

In [None]:
log_probs_gt = numpy.load("test/files/plan_log_probs.npy")'
all(log_probs .≈ log_probs_gt)

In [None]:
nb_samples = step==0 ? beam_width * n_expand : n_expand
samples = torch.multinomial(torch.tensor(probs'), num_samples=nb_samples, replacement=true).numpy()' .+ 1
# samples = numpy.load("test/files/plan_samples.npy")' .+ 1
samples_log_prob = [reshape(a[i], size(a[i])..., 1) for (a, i) in zip(eachslice(log_probs, dims=2), eachslice(samples, dims=2))]

In [None]:
samples_log_prob_gt = numpy.load("test/files/plan_samples_log_prob.npy")'
all(samples_log_prob .≈ samples_log_prob_gt)

In [None]:
acc_probs = repeat_interleave(acc_probs, nb_samples) .+ reshape(samples_log_prob, :)
contex = reshape(samples, step+1, :)

In [None]:
# Decode the candidate code sequences and view the result with one column per
# predicted entry of action_dim + observation_dim + 3 rows.
prediction_raw = decode_from_indices(model, contex, state)
prediction = reshape(prediction_raw, model.action_dim+model.observation_dim+3, :)

In [None]:
# Third-to-last row is read as the reward, second-to-last row as the value.
r_t = prediction[end-2, :]
V_t = prediction[end-1, :]

In [None]:
# With a dataset available, denormalise rewards and values and reshape them so
# that the second dimension matches the last dimension of `contex`.
if dataset !== nothing
    r_t = reshape(denormalize_rewards(dataset, r_t), :, size(contex, ndims(contex)))
end
if dataset !== nothing
    V_t = reshape(denormalize_values(dataset, V_t), :, size(contex, ndims(contex)))
end

In [None]:
# Discount factors grow as a cumulative product along the first dimension.
# Every column is scored by the discounted sum of all rewards except the last row
# plus the value of the last row weighted by the last discount factor.
discounts = cumprod(atype(ones(size(r_t)...)) .* discount, dims=1)
values = dropdims_n(sum(r_t[1:end-1, :] .* discounts[1:end-1, :], dims=1), dims=(1,)) .+ V_t[end, :] .* discounts[end, :]

In [None]:
# Bonus proportional to the accumulated log-probability, clipped to the range
# [-1e5, log(prob_threshold) * (steps ÷ model.latent_step)].
likelihood_bonus = likelihood_weight .* clip(acc_probs, -1e5, log(prob_threshold)*(steps÷model.latent_step))

In [None]:
# Keep `beam_width` candidates on every step except the last one, where only the
# single best candidate is kept.
nb_top = step < steps ÷ model.latent_step - 1 ? beam_width : 1

In [None]:
# Select the best objectives (value + bonus) with torch and convert the 0-based
# indices returned by torch to 1-based Julia indices.
values_with_b, index = torch.topk(torch.tensor(values.+likelihood_bonus), nb_top)
values_with_b = values_with_b.numpy()
index = index.numpy()
index.+=1

In [None]:
# Store the intermediate results of this step, keyed by (step+1)*model.latent_step.
info[(step+1)*model.latent_step] = Dict(
            "predictions"=>cputype(prediction_raw),
            "returns"=>cputype(values),
            "latent_codes"=>cputype(contex),
            "log_probs"=>cputype(acc_probs),
            "objectives"=>cputype(values.+likelihood_bonus),
            "index"=>cputype(index),
            )

In [None]:
# Continue the search with the selected candidates only.
contex = contex[:, index]
acc_probs = acc_probs[index]

In [None]:
# Decoded prediction of the first selected candidate.
optimal = prediction_raw[:,:,index[1]]

In [None]:
size(values)

In [None]:
# NOTE(review): this prints the first entry of `values`, not the entry at
# `index[1]` - confirm which one is meant by "max".
print("predicted max value $(values[1])")