In [1]:
from argparse import ArgumentParser
import pickle
import time

import gym
import minerl
import torch
import json
import numpy as np
import glob
import cv2
from tqdm.auto import tqdm
import os

from rebeca import REBECA
from data_loader import DataLoader
from openai_vpt.lib.tree_util import tree_map

# Two presets live in this cell. The defaults were designed for a small dataset
# (~20 demonstrations per task) and might not be the best choice for the full
# BASALT dataset (thousands of demonstrations). Flip this flag to switch preset.
USING_FULL_DATASET = False

if USING_FULL_DATASET:
    EPOCHS = 1
    # Must be <= number of videos
    BATCH_SIZE = 64
    # Ideally larger than BATCH_SIZE so the batches vary (otherwise you get a
    # bunch of consecutive samples). Lower this and BATCH_SIZE when memory runs out.
    N_WORKERS = 100
    # Upper bound on the number of batches to train on
    MAX_BATCHES = 2000
else:
    EPOCHS = 2
    # Must be <= number of videos
    BATCH_SIZE = 1
    N_WORKERS = 1
    # Effectively no cap for the small dataset
    MAX_BATCHES = int(1e9)

# Device string handed to the agent
DEVICE = "cuda"

# Print a progress line every this many batches
LOSS_REPORT_RATE = 10

# Found with a bit of trial and error
LEARNING_RATE = 1.81e-4
# The OpenAI VPT BC weight decay was 0.039428; disabled here
WEIGHT_DECAY = 0.0
# KL loss to the original model was not used in OpenAI VPT
KL_LOSS_WEIGHT = 1.0
# Gradient-norm clipping threshold
MAX_GRAD_NORM = 5.0

In [2]:
# Full foundation model: architecture description and its weights
in_model, in_weights = (
    "data/VPT-models/foundation-model-1x.model",
    "data/VPT-models/foundation-model-1x.weights",
)
# Separate weight files for the CNN and transformer parts (judging by the file names)
cnn_weights, trf_weights = (
    "data/VPT-models/foundation-model-1x-cnn.weights",
    "data/VPT-models/foundation-model-1x-trf.weights",
)
# Directory holding the demonstrations used for training
data_dir = "data/MakeWaterfallTrain/"

In [3]:
# Build the REBECA agent from the model description plus the separate CNN and
# transformer weight files, then move it onto the training device.
# NOTE(review): "data/memory_cnn.json" is presumably the agent's memory file - confirm against REBECA.
agent = REBECA(
    in_model,
    cnn_weights,
    trf_weights,
    "data/memory_cnn.json",
    device=DEVICE,
)
agent = agent.to(DEVICE)

In [4]:
# Start from a fully frozen agent ...
agent.requires_grad_(False)

# ... then re-enable gradients for every parameter whose name does NOT start
# with "controller.vpt_transformers". Those are the parameters handed to the optimizer.
trainable_parameters = [
    param
    for name, param in agent.named_parameters()
    if not name.startswith("controller.vpt_transformers")
]
for param in trainable_parameters:
    param.requires_grad = True

In [5]:
# Adam over the unfrozen parameters only; learning rate and weight decay come
# from the configuration cell.
optimizer = torch.optim.Adam(trainable_parameters, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Streams (images, actions, episode ids) batches from the demonstrations in data_dir.
loader_settings = {
    "dataset_dir": data_dir,
    "n_workers": N_WORKERS,
    "batch_size": BATCH_SIZE,
    "n_epochs": EPOCHS,
    # NOTE(review): hard-coded to a single demonstration - confirm this is intended
    "num_demonstrations": 1,
}
data_loader = DataLoader(**loader_settings)

In [7]:
# Behavioural-cloning training loop: for every batch from the data loader, run
# the agent on each sample, accumulate the negative log-likelihood of the
# demonstrated action, and take one optimizer step on the batch mean.
start_time = time.time()

# Keep track of the hidden state per episode/trajectory.
# DataLoader provides unique id for each episode, which will
# be different even for the same trajectory when it is loaded
# up again
episode_hidden_states = {}

# Running sum / count of batch losses since the last progress report, so the
# printed value is a real average rather than just the latest batch.
loss_sum = 0.0
loss_count = 0

for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader):
    sample_losses = []

    for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
        if image is None and action is None:
            # A work-item was done. Drop its hidden state (if we ever stored one)
            episode_hidden_states.pop(episode_id, None)
            continue

        agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
        if agent_action is None:
            # Action was null
            continue

        if episode_id not in episode_hidden_states:
            episode_hidden_states[episode_id] = agent.controller.initial_state(1)
        agent_state = episode_hidden_states[episode_id]

        pred_action, new_agent_state = agent(image, agent_state)

        # Make sure we do not try to backprop through sequence
        # (fails with current accumulation)
        new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
        episode_hidden_states[episode_id] = new_agent_state

        # Negative log-likelihood of the demonstrated action under the
        # predicted button and camera distributions.
        buttons_log_prob = torch.log_softmax(pred_action['buttons'], dim=-1)
        buttons_loss = -buttons_log_prob[0, agent_action['buttons'].long()]
        camera_log_prob = torch.log_softmax(pred_action['camera'], dim=-1)
        camera_loss = -camera_log_prob[0, agent_action['camera'].long()]
        sample_losses.append(buttons_loss + camera_loss)

    # A batch may contain only finished work-items / null actions; torch.stack
    # raises on an empty list, so only update when there is something to learn from.
    if sample_losses:
        # Mean over the samples of this batch
        batch_loss = torch.stack(sample_losses).mean()
        batch_loss.backward()

        torch.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad()

        # .item() so the running sum does not keep the autograd graph alive
        loss_sum += batch_loss.item()
        loss_count += 1

    if batch_i % LOSS_REPORT_RATE == 0 and loss_count > 0:
        time_since_start = time.time() - start_time
        print(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avrg loss: {loss_sum / loss_count:.4f}")
        loss_sum = 0.0
        loss_count = 0

    if batch_i > MAX_BATCHES:
        break

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [10]:
# Inspect the demonstrated button action of the last sample processed by the
# training loop (relies on `agent_action` still being defined from that cell).
agent_action['buttons']

tensor([[1]], device='cuda:0')

In [11]:
# Flat index of the largest button output from the last forward pass
# (relies on `pred_action` still being defined from the training cell).
torch.argmax(pred_action['buttons'])

tensor(1970, device='cuda:0')

In [None]:
# Negative log-likelihood of the demonstrated button action, computed by hand:
# log-softmax over the last dimension, then pick the entry of the taken action.
button_target = agent_action['buttons'].long()
log_prob = torch.log_softmax(pred_action['buttons'], dim=-1)
loss = -log_prob[0, button_target]
loss

In [None]:
# Same quantity via the built-in NLL loss, for comparison with the manual
# computation (uses `log_prob` from the previous cell).
torch.nn.functional.nll_loss(
    input=log_prob,
    target=agent_action['buttons'].squeeze(0),
    reduction='none',
)