In [1]:
using ArgParse: ArgParseSettings, @add_arg_table!, parse_args
using Statistics: mean
using Printf
using Knet

include("../latentplan/LPCore.jl")
include("../latentplan/setup.jl")
using .LPCore

No module named 'flow'
No module named 'carla'
pybullet build time: Nov 12 2022 16:05:06


In [2]:
"""
    losssum(prediction)

Add the 2nd, 3rd and 4th entries of `prediction` together and return the mean of
that sum. The 1st entry of `prediction` is not used.
"""
function losssum(prediction)
    combined = prediction[2] + prediction[3] + prediction[4]
    return mean(combined)
end

"""
    vq_train(config, model::VQContinuousVAE, dataset; n_epochs=1, log_freq=100)

Train `model` on `dataset` for `n_epochs` passes, updating every parameter returned
by `paramlist(model)` in place.

# Arguments
- `config`: key-indexable settings. Keys read here: `"learning_rate"`, `"betas"`
  (indexed at 1 and 2), `"weight_decay"`, `"grad_norm_clip"`, `"batch_size"`,
  `"warmup_tokens"`, `"final_tokens"` and `"lr_decay"` (used as an `if` condition,
  so it must be a `Bool`).
- `model`: the model to optimise; it is called as `model(batch...)`.
- `dataset`: passed to `DataLoader` with shuffling enabled.

# Keywords
- `n_epochs`: number of passes over the loader.
- `log_freq`: emit an `@info` record every `log_freq` iterations; a value `<= 0`
  disables logging.

# Returns
`nothing`.
"""
function vq_train(config, model::VQContinuousVAE, dataset; n_epochs=1, log_freq=100)
    base_lr = config["learning_rate"]

    # Two optimizer templates that differ only in `weight_decay`; every parameter
    # receives its own clone of the template matching its group.
    opt_decay = AdamW(
        lr=base_lr,
        beta1=config["betas"][1],
        beta2=config["betas"][2],
        weight_decay=config["weight_decay"],
        gclip=config["grad_norm_clip"],
    )
    opt_no_decay = AdamW(
        lr=base_lr,
        beta1=config["betas"][1],
        beta2=config["betas"][2],
        weight_decay=0.0,
        gclip=config["grad_norm_clip"],
    )

    for p in paramlist_decay(model)
        p.opt = clone(opt_decay)
    end
    for p in paramlist_no_decay(model)
        p.opt = clone(opt_no_decay)
    end

    # Running count of processed target elements; it drives the schedule below and
    # is deliberately NOT reset between epochs.
    n_tokens = 0
    loader = DataLoader(dataset; shuffle=true, batch_size=config["batch_size"])

    for epoch in 1:n_epochs
        losses = Float64[]
        for (it, batch) in enumerate(loader)
            y = batch[end-1]
            # Number of elements in `y`. `prod` collapses the size tuple to a scalar;
            # `cumprod` would return a tuple, which cannot be added to an integer.
            n_tokens += prod(size(y))

            warmup_tokens = config["warmup_tokens"]
            if n_tokens < warmup_tokens
                # linear warmup
                lr_mult = float(n_tokens) / float(max(1, warmup_tokens))
            else
                # cosine learning rate decay, floored at 10% of the base rate
                progress = float(n_tokens - warmup_tokens) /
                           float(max(1, config["final_tokens"] - warmup_tokens))
                lr_mult = max(0.1, 0.5 * (1.0 + cos(pi * progress)))
            end

            if config["lr_decay"]
                lr = base_lr * lr_mult
                # TODO: param_group learning rate
                for p in paramlist(model)
                    p.opt.lr = lr
                end
            else
                lr = base_lr
            end

            # forward the model under gradient recording, then step every parameter
            total_loss = @diff losssum(model(batch...))
            push!(losses, value(total_loss))
            for p in paramlist(model)
                update!(p, grad(total_loss, p))
            end

            if log_freq > 0 && it % log_freq == 0
                @info "vq_train" epoch it loss = losses[end] mean_loss = mean(losses) lr
            end
        end
    end
    return nothing
end

vq_train (generic function with 1 method)

In [3]:
# Command-line option table for this experiment. Each quoted flag below is followed
# by its help text, value type and default.
s = ArgParseSettings()
@add_arg_table! s begin
    "--dataset"
        help = "which environment to use"
        arg_type = String
        default = "halfcheetah-medium-expert-v2"
    "--exp_name"
        help = "name of the experiment"
        arg_type = String
        default = "debug"
    "--seed"
        help = "seed"
        arg_type = Int
        default = 42
    "--config"
        help = "relative jl file path with configurations"
        arg_type = String
        default = "../config/vqvae.jl"
end

#######################
######## setup ########
#######################

# An empty argument list is parsed on purpose, so every option above takes its
# default value rather than anything from the real command line.
super_args = parse_args([], s)
# NOTE(review): `parser` is not defined in this file — presumably it comes from the
# included setup.jl and expands the parsed options into the full experiment
# configuration used below; confirm there.
args = parser(super_args, experiment="train");

[ utils/setup ] Reading config: ../config/vqvae.jl:halfcheetah_medium_expert_v2
/Users/enes/logs/halfcheetah-medium-expert-v2/debug/ already exists. Proceeding...
Made directory/Users/enes/logs/halfcheetah-medium-expert-v2/debug/


# Dataset

In [4]:
# Environment id: the configured dataset name, with "-v0" appended when the name
# carries no "-v" version marker of its own.
env_name = args["dataset"]
if !occursin("-v", env_name)
    env_name = env_name * "-v0"
end

# env params: expand "~" in the log/save locations and make sure the save
# directory exists.
args["logbase"] = expanduser(args["logbase"])
args["savepath"] = expanduser(args["savepath"])
isdir(args["savepath"]) || mkpath(args["savepath"])

# Raw sequence length = subsampled length times the subsampling step.
sequence_length = args["subsampled_sequence_length"] * args["step"]

# Keyword settings forwarded unchanged to the dataset constructor.
dataset_options = (
    penalty=args["termination_penalty"],
    sequence_length=sequence_length,
    step=args["step"],
    discount=args["discount"],
    disable_goal=args["disable_goal"],
    normalize_raw=args["normalize"],
    normalize_reward=args["normalize_reward"],
    max_path_length=args["max_path_length"],
)
dataset = SequenceDataset(env_name; dataset_options...)

obs_dim, act_dim = dataset.observation_dim, dataset.action_dim
# Locomotion tasks use the dataset's observation width; every other task type uses
# a fixed width of 128. NOTE(review): the meaning of the 128 and of the three extra
# slots is not visible here — confirm against the model definition.
state_width = args["task_type"] == "locomotion" ? obs_dim : 128
transition_dim = state_width + act_dim + 3

# total number of dimensionalities for a maximum length sequence (T)
block_size = args["subsampled_sequence_length"] * transition_dim

print(
    "Dataset size: $(length(dataset)) |
    Joined dim: $transition_dim
    observation: $obs_dim, action: $act_dim | Block size: $block_size"
)

[ datasets/sequence ] Sequence length: 25 | Step: 1 | Max path length: 1000
[ datasets/sequence ] Loading...


load datafile: 100%|█████████████████████████████| 9/9 [00:07<00:00,  1.28it/s]
[32mGenerating dataset 100%|█████████████████████████████████| Time: 0:00:00[39m


✓
[ datasets/sequence ] Segmenting...
✓
Dataset size: 48 |
    Joined dim: 26
    observation: 17, action: 6 | Block size: 650

# Model

In [5]:
# The model configuration starts as an independent deep copy of the experiment
# arguments, so the overrides below never touch `args` itself.
model_config = deepcopy(args)

# Dataset-derived sizes plus two remapped settings: "n_embd" becomes the configured
# value multiplied by "n_head", and "vocab_size" is taken from the "N" argument.
for (key, val) in (
    "block_size" => block_size,
    "observation_dim" => obs_dim,
    "action_dim" => act_dim,
    "transition_dim" => transition_dim,
    "n_embd" => args["n_embd"] * args["n_head"],
    "vocab_size" => args["N"],
)
    model_config[key] = val
end

model = VQContinuousVAE(model_config);



## Model Params

In [6]:
# component = model.model.ln_f
# num_params = 0
# for (i, param) in enumerate(paramlist(model))
#     println(i, " - ", length(param))
#     num_params += length(param)
# end
# print("Number of parameters in model: ", num_params)

# Training process

In [7]:
loader = DataLoader(dataset; shuffle=true, batch_size=args["batch_size"])
# Take the first batch with a single `iterate` call instead of assigning to the
# global `batch` from inside a top-level `for` loop: that assignment only reaches
# the global under interactive soft scope and would silently leave `batch` as
# `nothing` when this code is run as a script. An empty loader still yields
# `nothing`.
first_step = iterate(loader)
batch = first_step === nothing ? nothing : first(first_step);

In [8]:
# Unpack the sampled batch into its four components, in the order the loader
# yields them.
joined_inputs, targets, mask, terminals = batch;

In [9]:
# Split the three-axis size of `joined_inputs` into named extents.
# NOTE(review): the names suggest a (feature, time, batch) layout — confirm
# against the DataLoader.
joined_dimension, t, b = size(joined_inputs)

(25, 24, 48)

In [12]:
# Store on the model the result of `normalize_joined_single` applied to an
# all-zeros vector of length `transition_dim - 1`.
# NOTE(review): presumably this is the dataset-normalized "zero transition" used
# to pad short sequences — confirm against normalize_joined_single and the model.
model.padding_vector = normalize_joined_single(dataset, zeros(model.transition_dim-1))

25-element Vector{Float64}:
  1.3586551130902185
  0.228389940535605
  0.7069261656678192
  1.3224555746290494
  2.506223180702559
  0.3803132917548237
  0.5203201193541377
  0.17306247156993976
 -2.0574827651805627
  0.13590436198433534
 -0.06275258940279395
 -0.04819511565385873
 -0.007849192285634327
 -0.07929695236955941
  0.02990876731896082
  0.005835718021037472
  0.017623983572789888
  0.21693206086149935
  0.9879023298435688
  1.4429752437532077
  0.3731244195973748
  0.4378495563353325
  0.15207652016219186
 -1.8588657195968
 -2.1642992947245396