In [None]:
# Activate the project environment found in the parent directory.
import Pkg
Pkg.activate("..")

In [None]:
using Test
using PyCall
using Knet
using Debugger: @enter, @bp, @run
using CUDA

# Array type used for model data: Knet's GPU array when CUDA is usable, a plain
# host array otherwise.
atype = CUDA.functional() ? KnetArray{Float32} : Array{Float32}
# Host array type, independent of GPU availability.
cputype = Array{Float32}

In [None]:
@pyimport torch
@pyimport numpy

# PyTorch checkpoints whose tensors are copied into the Julia models below.
# Both are mapped onto the CPU at load time so that a checkpoint saved from a
# GPU process can still be read on a machine without CUDA.
weights = torch.load("test/files/gpt_trained.pt", map_location=torch.device("cpu"))
prior_weights = torch.load("test/files/prior_model.pt", map_location=torch.device("cpu"))

In [None]:
# Project sources, evaluated into the current module.
# NOTE(review): the order presumably matters (later files may rely on
# definitions from earlier ones) — confirm before reordering.
include("datasets/sequence.jl")
include("models/common.jl")
include("models/transformers.jl")
include("models/vqvae.jl")
include("setup.jl")

In [None]:
# Settings handed to the project's `parser`.
super_args = Dict{String, Any}(
    "dataset" => "hopper-medium-replay-v2",
    "exp_name" => "debug",
    "seed" => 42,
    "config" => "../config/vqvae.jl",
    "representation_path" => "", # TODO
)

args = parser(super_args; experiment="plan")

# Expand a leading `~` in the path-valued settings.
for path_key in ("logbase", "savepath")
    args[path_key] = expanduser(args[path_key])
end

# Dataset names without a "-v" tag get "-v0" appended.
env_name = if occursin("-v", args["dataset"])
    args["dataset"]
else
    args["dataset"] * "-v0"
end

In [None]:
# Dataset settings stored under the key "config" in the run directory.
dataset_config = let run_dir = "/Users/mehmeteneserciyes/logs_julia/hopper-medium-replay-v2/T-1-42"
    Knet.load(joinpath(run_dir, "dataset_config.jld2"), "config")
end

# Rebuild the dataset from exactly the stored settings.
dataset = let cfg = dataset_config
    SequenceDataset(
        cfg["env_name"];
        penalty=cfg["penalty"],
        sequence_length=cfg["sequence_length"],
        step=cfg["step"],
        discount=cfg["discount"],
        disable_goal=cfg["disable_goal"],
        normalize_raw=cfg["normalize_raw"],
        normalize_reward=cfg["normalize_reward"],
        max_path_length=cfg["max_path_length"],
        atype=cfg["atype"],
    )
end

obs_dim, act_dim = dataset.observation_dim, dataset.action_dim
# One more than the dataset's joined dimension.
transition_dim = dataset.joined_dim + 1

# Representation model init and weight loading

In [None]:
# Representation-model settings stored under the key "config" in the run directory.
model_config = let run_dir = "/Users/mehmeteneserciyes/logs_julia/hopper-medium-replay-v2/T-1-42"
    Knet.load(joinpath(run_dir, "model_config.jld2"), "config")
end

In [None]:
# Build the representation model from the loaded configuration; its parameters
# are overwritten with the PyTorch weights in the next cell.
representation = VQContinuousVAE(model_config);

In [None]:
# Copy the PyTorch state dict `weights` into `representation`.
# The helpers are defined inside a `let` block so they stay local to this cell.
let net = representation.model
    # Fetch state-dict entry `key` as a Julia array (moved to the CPU first).
    host_array(key) = weights[key][:cpu]()[:numpy]()
    # Wrap state-dict entry `key` in a `Param` of type `atype`.
    as_param(key) = Param(atype(host_array(key)))

    # The `weight`/`bias` entries under `prefix` go into the fields `w`/`b`.
    function load_linear!(layer, prefix)
        layer.w = as_param("$(prefix).weight")
        layer.b = as_param("$(prefix).bias")
        return layer
    end

    # The `weight`/`bias` entries under `prefix` go into the fields `a`/`b`.
    function load_norm!(layer, prefix)
        layer.a = as_param("$(prefix).weight")
        layer.b = as_param("$(prefix).bias")
        return layer
    end

    # One transformer block: both layer norms, the four attention maps and the MLP.
    function load_block!(block, prefix)
        load_norm!(block.ln1, "$(prefix).ln1")
        load_norm!(block.ln2, "$(prefix).ln2")

        load_linear!(block.attn.key, "$(prefix).attn.key")
        load_linear!(block.attn.query, "$(prefix).attn.query")
        load_linear!(block.attn.value, "$(prefix).attn.value")
        load_linear!(block.attn.proj, "$(prefix).attn.proj")

        # State-dict entries `mlp.0` and `mlp.2` map to the Julia layers 1 and 3.
        load_linear!(block.mlp.layers[1], "$(prefix).mlp.0")
        load_linear!(block.mlp.layers[3], "$(prefix).mlp.2")
        return block
    end

    # A stack of `n_layer` blocks; state-dict block numbers start at 0, Julia's at 1.
    function load_stack!(stack, prefix)
        for i in 1:model_config["n_layer"]
            load_block!(stack.layers[i], "$(prefix).$(i - 1)")
        end
        return stack
    end

    # Encoder
    load_linear!(net.embed, "model.embed")
    # The positional embedding is stored with its three axes in reverse order.
    net.pos_emb = Param(atype(permutedims(host_array("model.pos_emb"), (3, 2, 1))))
    load_stack!(net.encoder, "model.encoder")
    load_linear!(net.cast_embed, "model.cast_embed")

    # Decoder
    load_linear!(net.latent_mixing, "model.latent_mixing")
    load_stack!(net.decoder, "model.decoder")
    load_norm!(net.ln_f, "model.ln_f")
    load_linear!(net.predict, "model.predict")

    # Codebook: `embedding` and `ema_w` are transposed on the way in, `ema_count` is not.
    net.codebook.embedding = Param(atype(host_array("model.codebook.embedding")'))
    net.codebook.ema_count = as_param("model.codebook.ema_count")
    net.codebook.ema_w = Param(atype(host_array("model.codebook.ema_w")'))
end

# Padding vector: an all-zero vector of length `transition_dim - 1`, normalized by the dataset.
representation.padding_vector = atype(normalize_joined_single(dataset, atype(zeros(representation.transition_dim-1))));

# TransformerPrior Init and Model Loading

In [None]:
# Re-parse the settings for the "train" experiment (this replaces the "plan" `args`).
args = parser(super_args, experiment="train")
args["logbase"] = expanduser(args["logbase"])
# NOTE(review): the save path is pinned to one local run directory; whatever
# `parser` produced for "savepath" is discarded, so it is no longer expanded first.
args["savepath"] = "/Users/mehmeteneserciyes/logs_julia/hopper-medium-replay-v2/T-1-42/"
# Integer division of the subsampled sequence length by the latent step.
block_size = args["subsampled_sequence_length"] ÷ args["latent_step"]
obs_dim = dataset.observation_dim

In [None]:
# Start from an independent copy of the run arguments, then add the derived
# model settings. `n_embd` becomes the product of `n_embd` and `n_head`.
model_config = deepcopy(args)
for (setting, val) in (
    "block_size" => block_size,
    "observation_dim" => obs_dim,
    "n_embd" => args["n_embd"] * args["n_head"],
)
    model_config[setting] = val
end

In [None]:
# turn off dropout
model_config["embd_pdrop"] = 0.0f0
model_config["attn_pdrop"] = 0.0f0
model_config["resid_pdrop"] = 0.0f0

In [None]:
# Build the prior model from the configuration; its parameters are overwritten
# with the PyTorch weights in the next cell.
model = TransformerPrior(model_config);

In [None]:
# Copy the PyTorch state dict `prior_weights` into the prior `model`.
# The helpers are defined inside a `let` block so they stay local to this cell.
let
    # Fetch state-dict entry `key` as a Julia array (moved to the CPU first).
    host_array(key) = prior_weights[key][:cpu]()[:numpy]()
    # Wrap state-dict entry `key` in a `Param` of type `atype`.
    as_param(key) = Param(atype(host_array(key)))

    # The `weight`/`bias` entries under `prefix` go into the fields `w`/`b`.
    function load_linear!(layer, prefix)
        layer.w = as_param("$(prefix).weight")
        layer.b = as_param("$(prefix).bias")
        return layer
    end

    # The `weight`/`bias` entries under `prefix` go into the fields `a`/`b`.
    function load_norm!(layer, prefix)
        layer.a = as_param("$(prefix).weight")
        layer.b = as_param("$(prefix).bias")
        return layer
    end

    # One transformer block: both layer norms, the four attention maps and the MLP.
    function load_block!(block, prefix)
        load_norm!(block.ln1, "$(prefix).ln1")
        load_norm!(block.ln2, "$(prefix).ln2")

        load_linear!(block.attn.key, "$(prefix).attn.key")
        load_linear!(block.attn.query, "$(prefix).attn.query")
        load_linear!(block.attn.value, "$(prefix).attn.value")
        load_linear!(block.attn.proj, "$(prefix).attn.proj")

        # State-dict entries `mlp.0` and `mlp.2` map to the Julia layers 1 and 3.
        load_linear!(block.mlp.layers[1], "$(prefix).mlp.0")
        load_linear!(block.mlp.layers[3], "$(prefix).mlp.2")
        return block
    end

    # Embeddings: the token table is transposed, the positional embedding has
    # its three axes reversed.
    model.tok_emb = Param(atype(host_array("tok_emb.weight")'))
    model.pos_emb = Param(atype(permutedims(host_array("pos_emb"), (3, 2, 1))))
    load_linear!(model.state_emb, "state_emb")

    # Transformer blocks; state-dict block numbers start at 0, Julia's at 1.
    for i in 1:model_config["n_layer"]
        load_block!(model.blocks.layers[i], "blocks.$(i - 1)")
    end

    load_norm!(model.ln_f, "ln_f")

    model.head.w = as_param("head.weight") # no bias
end;

# Check that the results match the saved reference arrays

In [None]:
# Loader over the dataset with shuffling disabled and the configured batch size.
# NOTE(review): shuffling is presumably off so the first batch matches the saved
# reference arrays — confirm.
loader = DataLoader(dataset; shuffle=false, batch_size=args["batch_size"]);

In [None]:
# Take the first batch from the loader.
# `first` performs the same single iteration step as the earlier loop-and-break,
# but does not rely on assigning a global from inside a `for` body (which only
# works under the notebook's soft-scope rules), and it throws on an empty loader
# instead of leaving `batch` set to `nothing`.
batch = first(loader);

In [None]:
# Rows 1:model.observation_dim at index 1 of the second axis of the first batch element.
# NOTE(review): presumably the first observation of every sequence — confirm the axis layout.
states = batch[1][1:model.observation_dim, 1, :]
# Encode the first batch element with the representation model; the last batch
# element is passed along as well (presumably a mask — TODO confirm).
indices = encode(representation, batch[1], batch[end])

In [None]:
# Reference indices loaded from disk and transposed.
indices_gt = numpy.load("test/files/indices.npy")'
# The Julia indices are shifted down by one before the element-wise comparison
# (presumably because the reference indices are 0-based — confirm).
all(indices .- 1 .== indices_gt)

In [None]:
# Run the prior on all index rows but the last, together with `states` and the
# full `indices`, and record the second output under `@diff`
# (presumably the loss — confirm against `TransformerPrior`).
loss = @diff model(indices[1:end-1, :], states, indices)[2]

In [None]:
# Show the plain value held by the recorded result.
value(loss)

# Embedding call with indices

In [None]:
# Every row of `indices` except the last.
idx = indices[1:end-1, :];

In [None]:
# Row and column counts of `idx`
# (presumably sequence length and batch size — confirm).
t, b = size(idx)

In [None]:
# Look up the embedding column of every index except those in the last row, and
# prepend one all-zero slice along the second axis.
token_embeddings = cat(
    atype(zeros(Float32, model.embedding_dim, 1, b)),
    model.tok_emb[:, indices[1:end-1, :]];
    dims=2,
);

In [None]:
# Reference token embeddings loaded from disk, with the three axes reversed.
token_embeddings_gt = permutedims(numpy.load("test/files/token_embeddings_prior.npy"), (3,2,1))
# Element-wise approximate equality against the reference.
all(token_embeddings .≈ token_embeddings_gt)

In [None]:
# The first t+1 entries along the second axis of the positional embedding.
position_embeddings = model.pos_emb[:, 1:t+1, :];

In [None]:
# Reference states loaded from disk and transposed.
states_gt = numpy.load("test/files/state_prior.npy")'
# Element-wise approximate equality against the reference.
all(states .≈ states_gt)

In [None]:
# Embed the states, then insert a singleton second axis right after the first one.
state_embeddings = let emb = model.state_emb(states)
    reshape(emb, size(emb, 1), 1, size(emb)[2:end]...)
end;

In [None]:
# Load the reference here: this cell runs before the one that otherwise defines
# `state_embeddings_gt`, so on a fresh top-to-bottom run the name would be undefined.
state_embeddings_gt = permutedims(numpy.load("test/files/state_embeddings_prior.npy"), (3,2,1))
# Coarse check: the overall means agree approximately.
mean(state_embeddings) ≈ mean(state_embeddings_gt)

In [None]:
# Inspect the shape of the reference state embeddings.
# NOTE(review): `state_embeddings_gt` is only assigned in a later cell.
size(state_embeddings_gt)

In [None]:
# Reference state embeddings loaded from disk, with the three axes reversed.
state_embeddings_gt = permutedims(numpy.load("test/files/state_embeddings_prior.npy"), (3,2,1))
# Every element must be within an absolute difference of 1e-6 of the reference.
all(abs.(state_embeddings .- state_embeddings_gt) .< 1e-6)

### Embeddings test

In [None]:
# Reference summed embeddings loaded from disk, with the three axes reversed.
embeddings_gt = permutedims(numpy.load("test/files/embeddings_prior.npy"), (3,2,1));

In [None]:
# Broadcast-sum the token, position and state embeddings and pass the result
# through the model's `drop` layer (the dropout rates were set to zero above).
x = model.drop(token_embeddings .+ position_embeddings .+ state_embeddings)
# Every element must be within an absolute difference of 1e-6 of the reference.
all(abs.(x .- embeddings_gt) .< 1e-6)

In [None]:
# Apply the model's block stack to the embeddings.
x = model.blocks(x)

In [None]:
# Reference block-stack output loaded from disk, with the three axes reversed.
blocks_output_gt = permutedims(numpy.load("test/files/blocks_output_prior.npy"), (3,2,1))
# Looser tolerance than the embedding checks: absolute difference below 1e-5.
all(abs.(x .- blocks_output_gt) .< 1e-5)

In [None]:
# Apply the final layer norm `ln_f` to the block-stack output.
x = model.ln_f(x)

# Test

In [None]:
# Load "state_123.jld2" from the save path and rebind `representation` to the result.
# NOTE(review): no variable name is passed to `Knet.load` here, unlike the config
# loads above — confirm that the returned object is the model itself and not a
# container of saved variables.
representation = Knet.load(joinpath(args["savepath"], "state_123.jld2"))