In [1]:
using ArgParse: ArgParseSettings, @add_arg_table!, parse_args
using Statistics: mean
using Printf
using Knet

include("../latentplan/LPCore.jl")
include("../latentplan/setup.jl")
using .LPCore

No module named 'flow'
No module named 'carla'
pybullet build time: Nov 12 2022 16:05:06


In [3]:
"""
    losssum(prediction)

Add the 2nd, 3rd and 4th entries of `prediction` together and return the
`mean` of the result. The first entry of `prediction` is ignored.

NOTE(review): presumably `prediction` is the model's forward output and
entries 2-4 are individual loss terms -- confirm against the model definition.
"""
function losssum(prediction)
    combined = prediction[2] + prediction[3] + prediction[4]
    return mean(combined)
end

"""
    vq_train(config, model::VQContinuousVAE, dataset; n_epochs=1, log_freq=100)

Train `model` on `dataset` for `n_epochs` passes over a shuffled `DataLoader`.

Every parameter receives its own `AdamW` optimizer: parameters returned by
`paramlist_decay(model)` use `config["weight_decay"]`, those returned by
`paramlist_no_decay(model)` use a weight decay of `0.0`.

When `config["lr_decay"]` is true the learning rate follows a schedule driven
by the number of processed tokens: linear warmup until
`config["warmup_tokens"]`, then cosine decay towards `config["final_tokens"]`
with the multiplier floored at `0.1`.

# Arguments
- `config`: indexable by string keys; this function reads `"learning_rate"`,
  `"betas"`, `"weight_decay"`, `"grad_norm_clip"`, `"batch_size"`,
  `"warmup_tokens"`, `"final_tokens"` and `"lr_decay"`.
- `model`: the model to optimise; it is called as `model(batch...)`.
- `dataset`: passed to `DataLoader`.

# Keywords
- `n_epochs`: number of passes over the loader.
- `log_freq`: print the current loss every `log_freq` iterations; a
  non-positive value disables per-iteration logging.

# Returns
`nothing`; the model parameters are updated in place.
"""
function vq_train(config, model::VQContinuousVAE, dataset; n_epochs=1, log_freq=100)
    base_lr = config["learning_rate"]
    warmup_tokens = config["warmup_tokens"]

    # set optimizers
    opt_decay = AdamW(lr=base_lr, beta1=config["betas"][1], beta2=config["betas"][2], weight_decay=config["weight_decay"], gclip=config["grad_norm_clip"])
    opt_no_decay = AdamW(lr=base_lr, beta1=config["betas"][1], beta2=config["betas"][2], weight_decay=0.0, gclip=config["grad_norm_clip"])

    # Each parameter gets its own optimizer state, cloned from the template.
    for p in paramlist_decay(model)
        p.opt = clone(opt_decay)
    end
    for p in paramlist_no_decay(model)
        p.opt = clone(opt_no_decay)
    end

    n_tokens = 0
    loader = DataLoader(dataset; shuffle=true, batch_size=config["batch_size"])

    for epoch in 1:n_epochs
        losses = Float64[]
        for (it, batch) in enumerate(loader)
            y = batch[end-1]
            # Total element count of `y`. `prod` collapses the size tuple to a
            # single integer; `cumprod` would return a tuple of running
            # products, which cannot be added to the integer counter.
            n_tokens += prod(size(y))

            if n_tokens < warmup_tokens
                # linear warmup
                lr_mult = float(n_tokens) / float(max(1, warmup_tokens))
            else
                # cosine learning rate decay
                progress = float(n_tokens - warmup_tokens) / float(
                    max(1, config["final_tokens"] - warmup_tokens)
                )
                lr_mult = max(0.1, 0.5 * (1.0 + cos(pi * progress)))
            end

            if config["lr_decay"]
                lr = base_lr * lr_mult
                # TODO: param_group learning rate
                for p in paramlist(model)
                    p.opt.lr = lr
                end
            else
                lr = base_lr
            end

            # forward the model
            total_loss = @diff losssum(model(batch...))
            loss = value(total_loss)
            push!(losses, loss)
            for p in paramlist(model)
                update!(p, grad(total_loss, p))
            end

            if log_freq > 0 && it % log_freq == 0
                @printf("[ training ] epoch %d | iter %d | loss %.5f | lr %.3e\n", epoch, it, loss, lr)
            end
        end

        # Per-epoch summary; skipped when the loader produced no batches.
        if !isempty(losses)
            @printf("[ training ] epoch %d | mean loss %.5f\n", epoch, mean(losses))
        end
    end
    return nothing
end

vq_train (generic function with 1 method)

In [4]:
# Command-line style options for the training run. Each entry in the table
# declares a flag together with its help text, value type and the default
# that applies when the flag is not supplied.
s = ArgParseSettings()
@add_arg_table! s begin
    "--dataset"
        help = "which environment to use"
        arg_type = String
        default = "halfcheetah-medium-expert-v2"
    "--exp_name"
        help = "name of the experiment"
        arg_type = String
        default = "debug"
    "--seed"
        help = "seed"
        arg_type = Int
        default = 42
    "--config"
        help = "relative jl file path with configurations"
        arg_type = String
        default = "../config/vqvae.jl"
end

#######################
######## setup ########
#######################

# An empty argument list is parsed, so every option takes its declared default.
super_args = parse_args([], s)
# `parser` is not defined in this section (presumably it comes from the
# included setup.jl -- confirm). It is called with the parsed options and
# experiment="train"; its result is indexed by string keys further below.
args = parser(super_args, experiment="train");

[ utils/setup ] Reading config: ../config/vqvae.jl:halfcheetah_medium_expert_v2
/Users/enes/logs/halfcheetah-medium-expert-v2/debug/ already exists. Proceeding...
Made directory/Users/enes/logs/halfcheetah-medium-expert-v2/debug/


# Dataset

In [5]:
# Environment id: a name without a "-v" marker gets "-v0" appended.
env_name = args["dataset"]
if !occursin("-v", env_name)
    env_name = env_name * "-v0"
end

# env params
sequence_length = args["subsampled_sequence_length"] * args["step"]

# Expand a leading "~" in the configured directories, then make sure the
# save directory exists.
for key in ("logbase", "savepath")
    args[key] = expanduser(args[key])
end
isdir(args["savepath"]) || mkpath(args["savepath"])

dataset = SequenceDataset(
    env_name;
    penalty = args["termination_penalty"],
    sequence_length = sequence_length,
    step = args["step"],
    discount = args["discount"],
    disable_goal = args["disable_goal"],
    normalize_raw = args["normalize"],
    normalize_reward = args["normalize_reward"],
    max_path_length = args["max_path_length"],
)

obs_dim, act_dim = dataset.observation_dim, dataset.action_dim

# Width of one joined transition: an observation part, the action part and 3
# extra entries. Locomotion tasks use the dataset's observation width; every
# other task type uses a fixed 128.
# NOTE(review): the meaning of 128 and of the 3 extra entries is not visible
# here -- confirm against the model/dataset code.
transition_dim = (args["task_type"] == "locomotion" ? obs_dim : 128) + act_dim + 3

# total number of dimensionalities for a maximum length sequence (T)
block_size = args["subsampled_sequence_length"] * transition_dim

print(
    "Dataset size: $(length(dataset)) |
    Joined dim: $transition_dim
    observation: $obs_dim, action: $act_dim | Block size: $block_size"
)

[ datasets/sequence ] Sequence length: 25 | Step: 1 | Max path length: 1000
[ datasets/sequence ] Loading...


load datafile: 100%|█████████████████████████████| 9/9 [00:05<00:00,  1.65it/s]
[32mGenerating dataset 100%|█████████████████████████████████| Time: 0:00:00[39m


✓
[ datasets/sequence ] Segmenting...
✓
Dataset size: 48 |
    Joined dim: 26
    observation: 17, action: 6 | Block size: 650

# Model

In [6]:
# Start from an independent copy of the run arguments so that the overrides
# below do not modify `args`.
model_config = deepcopy(args)

# Model-specific entries derived from the dataset and the run arguments.
# "n_embd" is stored as the product of the configured "n_embd" and "n_head";
# "vocab_size" is taken from the "N" argument.
for (key, val) in (
    "block_size" => block_size,
    "observation_dim" => obs_dim,
    "action_dim" => act_dim,
    "transition_dim" => transition_dim,
    "n_embd" => args["n_embd"] * args["n_head"],
    "vocab_size" => args["N"],
)
    model_config[key] = val
end

model = VQContinuousVAE(model_config);

In [7]:
# Number of entries returned by `paramlist` for the model.
model |> paramlist |> length

139

In [10]:
# Print the element count of every entry returned by `paramlist(model)`,
# one per line, prefixed with its position.
for (i, param) in enumerate(paramlist(model))
    println(i, " - ", length(param))
end

# The total is computed with a single reduction instead of accumulating into
# a global inside the loop: a top-level `for` loop can only update a global
# that way under the interactive soft-scope rules, so the accumulation would
# fail when this code is run as a script. `init=0` keeps an empty parameter
# list well-defined.
num_params = sum(length, paramlist(model); init=0)
print("Number of parameters in model: ", num_params)

1 - 512
2 - 512
3 - 512
4 - 512
5 - 262144
6 - 512
7 - 262144
8 - 512
9 - 262144
10 - 512
11 - 262144
12 - 512
13 - 1048576
14 - 2048
15 - 1048576
16 - 512
17 - 512
18 - 512
19 - 512
20 - 512
21 - 262144
22 - 512
23 - 262144
24 - 512
25 - 262144
26 - 512
27 - 262144
28 - 512
29 - 1048576
30 - 2048
31 - 1048576
32 - 512
33 - 512
34 - 512
35 - 512
36 - 512
37 - 262144
38 - 512
39 - 262144
40 - 512
41 - 262144
42 - 512
43 - 262144
44 - 512
45 - 1048576
46 - 2048
47 - 1048576
48 - 512
49 - 512
50 - 512
51 - 512
52 - 512
53 - 262144
54 - 512
55 - 262144
56 - 512
57 - 262144
58 - 512
59 - 262144
60 - 512
61 - 1048576
62 - 2048
63 - 1048576
64 - 512
65 - 512
66 - 512
67 - 512
68 - 512
69 - 262144
70 - 512
71 - 262144
72 - 512
73 - 262144
74 - 512
75 - 262144
76 - 512
77 - 1048576
78 - 2048
79 - 1048576
80 - 512
81 - 512
82 - 512
83 - 512
84 - 512
85 - 262144
86 - 512
87 - 262144
88 - 512
89 - 262144
90 - 512
91 - 262144
92 - 512
93 - 1048576
94 - 2048
95 - 1048576
96 - 512
97 - 512
98 - 512
9