In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import latentplan.utils as utils
import latentplan.datasets as datasets
from latentplan.models.vqvae import VQContinuousVAE

In [None]:
import os
from torch.utils.data.dataloader import DataLoader
import torch, numpy as np

In [None]:
class Parser(utils.Parser):
    # Default dataset and config identifiers used by this notebook.
    dataset: str = 'halfcheetah-medium-expert-v2'
    config: str = 'config.vqvae'

#######################
######## setup ########
#######################

# Parse the arguments for the 'train' experiment.
args = Parser().parse_args('train')

# Dataset names without a "-v" version marker get "-v0" appended before loading.
env_name = args.dataset if "-v" in args.dataset else args.dataset+"-v0"
env = datasets.load_environment(env_name)

# Raw sequence length: product of the subsampled sequence length and the step.
sequence_length = args.subsampled_sequence_length * args.step

# Expand "~" in the log/save locations.
args.logbase = os.path.expanduser(args.logbase)
args.savepath = os.path.expanduser(args.savepath)
# exist_ok=True creates the directory only when missing, without the
# check-then-create race of testing os.path.exists() first.
os.makedirs(args.savepath, exist_ok=True)

# Dataset

In [None]:
dataset_class = datasets.SequenceDataset

# Describe the dataset construction through utils.Config and then instantiate it.
# Config is handed savepath=(args.savepath, 'data_config.pkl'); presumably it
# stores the arguments there — confirm in utils.Config.
# NOTE(review): the dataset receives args.dataset, not the "-v0"-suffixed
# env_name computed during setup — confirm this is intended.
dataset_config = utils.Config(
    dataset_class,
    savepath=(args.savepath, 'data_config.pkl'),
    env=args.dataset, penalty=args.termination_penalty,
    sequence_length=sequence_length, step=args.step, discount=args.discount,
    disable_goal=args.disable_goal,
    normalize_raw=args.normalize, normalize_reward=args.normalize_reward,
    max_path_length=int(args.max_path_length),
)
dataset = dataset_config()

obs_dim, act_dim = dataset.observation_dim, dataset.action_dim

# Locomotion tasks contribute the observation width; every other task type
# contributes a fixed 128 in its place. Action width and 3 extra entries are
# added in both cases.
transition_dim = (obs_dim if args.task_type == "locomotion" else 128) + act_dim + 3

# total number of dimensions for a maximum-length sequence (T)
block_size = args.subsampled_sequence_length * transition_dim

# One-line summary of the sizes derived above.
print(' | '.join((
    f'Dataset size: {len(dataset)}',
    f'Joined dim: {transition_dim} (observation: {obs_dim}, action: {act_dim})',
    f'Block size: {block_size}',
)))

# Model

In [None]:
# Collect every VQContinuousVAE constructor argument in a utils.Config, then
# instantiate the model from it. Keyword order is kept as-is.
model_config = utils.Config(
    VQContinuousVAE,
    savepath=(args.savepath, 'model_config.pkl'),
    ## discretization
    vocab_size=args.N,
    block_size=block_size,
    K=args.K,
    ## architecture
    n_layer=args.n_layer,
    n_head=args.n_head,
    n_embd=args.n_embd * args.n_head,  # args.n_embd is multiplied by the head count
    ## dimensions
    observation_dim=obs_dim,
    action_dim=act_dim,
    transition_dim=transition_dim,
    ## loss weighting
    action_weight=args.action_weight,
    reward_weight=args.reward_weight,
    value_weight=args.value_weight,
    position_weight=args.position_weight,
    first_action_weight=args.first_action_weight,
    sum_reward_weight=args.sum_reward_weight,
    last_value_weight=args.last_value_weight,
    trajectory_embd=args.trajectory_embd,
    model=args.model,
    latent_step=args.latent_step,
    ma_update=args.ma_update,
    residual=args.residual,
    obs_shape=args.obs_shape,
    ## dropout probabilities
    embd_pdrop=args.embd_pdrop,
    resid_pdrop=args.resid_pdrop,
    attn_pdrop=args.attn_pdrop,
    bottleneck=args.bottleneck,
    masking=args.masking,
    state_conditional=args.state_conditional,
)

model = model_config()

if args.normalize:
    # The padding vector is the dataset-normalized form of an all-zero vector
    # of length transition_dim - 1.
    model.set_padding_vector(
        dataset.normalize_joined_single(
            np.zeros(model.transition_dim-1)
        )
    )

## Model Params

In [None]:
# Number of whole transitions that fit into 650 dimensions (floor division).
# TODO: 650 is an unexplained literal — name it or derive it from the config.
# NOTE(review): transition_dim is read via model.model here but via model
# directly in the padding-vector cell — confirm both attributes exist and agree.
650 // model.model.transition_dim

In [None]:
# List every named parameter with a 1-based running index and its element
# count, then report the total number of elements.
total_param = 0
i = 0  # keeps `i` defined (as 0) even if the model exposes no parameters
for i, (name, param) in enumerate(model.named_parameters(), start=1):
    total_param += param.numel()
    print(i, " - ", name, ": ", param.numel())
print("Total params:", total_param)


# Training process

In [None]:
# Iterate over the dataset in its stored order (no shuffling) and push each
# batch through the model. Only the forward pass runs in this cell: there is no
# backward call or optimizer step, and the losses are overwritten each iteration.
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
for it, batch in enumerate(loader):
    # Gradient tracking is explicitly enabled for the forward pass.
    with torch.enable_grad():
        # The last three outputs are the reconstruction, VQ and commitment losses;
        # any leading outputs are collected into `_`.
        *_, recon_loss, vq_loss, commit_loss = model(*batch)

In [None]:
# Display the model's padding vector (presumably the value installed through
# set_padding_vector when args.normalize is set — confirm in VQContinuousVAE).
model.padding_vector