In [None]:
# %env MINERL_DATA_ROOT=/Volumes/CORSAIR/data

In [None]:
import minerl  # NOTE: we need gym>=0.13.1,<0.20
import gym
from train import get_agent, get_dynamics_environment, FMC, get_data_handler
from fgz.training.fgz_trainer import FGZTrainer
import torch
import wandb
from tqdm import tqdm
from constants import UNROLL_STEPS

In [None]:
# Ask PyTorch's CUDA caching allocator to release cached blocks that are not
# currently in use. This does not free memory still held by live tensors, so
# it will not by itself repair a genuine leak from lingering references.
torch.cuda.empty_cache()  # release unused cached GPU memory

In [None]:
# Build the three objects the rest of the notebook relies on, in the same
# order as before: the MineRL gym environment, the agent, and the dynamics
# environment derived from the MineRL environment.
minerl_env = gym.make("MineRLBasaltMakeWaterfall-v0")
agent = get_agent()
dynamics_env = get_dynamics_environment(minerl_env)

In [None]:
# Show the agent's `device` attribute as the cell output
# (presumably the torch device the agent runs on -- confirm in `train.get_agent`).
agent.device

In [None]:
# dynamics_env.batched_action_space_sample()

In [None]:
# dummy_initial_state = torch.ones(4096, dtype=float)
# dynamics_env.set_all_states(dummy_initial_state)

In [None]:
# Create the data handler for the agent, then print each entry of its
# `loaders` collection, one per line.
data_handler = get_data_handler(agent)
for data_loader in data_handler.loaders:
    print(data_loader)

In [None]:
# t = data_handler.sample_single_trajectory()
# for window in t:
#     print(len(window))
#     for frame, state_embedding, action in window:
#         print(state_embedding.shape)
#         break
#     break

In [None]:
# Adam optimizer over the parameters of the dynamics function only, with a
# learning rate of 1e-5 and a weight decay of 1e-4.
dynamics_function_optimizer = torch.optim.Adam(
    params=dynamics_env.dynamics_function.parameters(),
    lr=1e-5,
    weight_decay=1e-4,
)

In [None]:
# Single switch for Weights & Biases: it is passed to the trainer and also
# decides whether a wandb run is started below.
use_wandb = True

# Wrap the dynamics environment in FMC, then assemble the trainer from the
# agent, FMC instance, data handler and optimizer.
fmc = FMC(dynamics_env, freeze_best=True)
trainer = FGZTrainer(
    agent,
    fmc,
    data_handler,
    dynamics_function_optimizer,
    unroll_steps=UNROLL_STEPS,
    use_wandb=use_wandb,
)

# Start the wandb run only when logging is enabled.
if use_wandb:
    wandb.init(project="fgz-v0.1.1")

In [None]:
# trainer.save("test_trainer_save.pth")
# loaded_trainer = FGZTrainer.load("test_trainer_save.pth", agent)
# loaded_trainer.evaluate("MineRLBasaltMakeWaterfall-v0", render=True, max_steps=4096, force_no_escape=True)

In [None]:
# trainer.evaluate("MineRLBasaltMakeWaterfall-v0", render=True, max_steps=4096, force_no_escape=True)

In [None]:
# trainer.eval_actions

In [None]:
# Train on one sub-trajectory per step, periodically overwriting a single
# checkpoint file with the trainer's current state.
train_steps = 15000
checkpoint_every = 100
for train_step in tqdm(range(train_steps), desc="Training"):
    # `use_tqdm=False` is passed through to the trainer (presumably to silence
    # its own progress bar, since this loop already shows one -- confirm).
    trainer.train_sub_trajectory(use_tqdm=False)

    # Count completed steps (1-based). Testing the 0-based index instead would
    # checkpoint right after the very first step and then stop one interval
    # short, leaving the last `checkpoint_every - 1` steps unsaved.
    completed_steps = train_step + 1

    # Save on every full interval, and always after the final step so the
    # checkpoint reflects the fully trained state even when `train_steps` is
    # not a multiple of `checkpoint_every`.
    if completed_steps % checkpoint_every == 0 or completed_steps == train_steps:
        trainer.save("fgz_dynamics_checkpoint.pth")

In [None]:
# trainer.save("fgz_dynamics_trained.pth")

In [None]:
# trainer = FGZTrainer.load("fgz_dynamics_trained.pth", agent)

In [None]:
# Run the trainer's evaluation on the waterfall task with rendering enabled,
# capped at 4096 steps, with `force_no_escape` switched on.
trainer.evaluate(
    "MineRLBasaltMakeWaterfall-v0",
    render=True,
    max_steps=4096,
    force_no_escape=True,
)