In [1]:
from argparse import ArgumentParser
import pickle
import time

import gym
import minerl
import torch
import json
import numpy as np
import glob
import cv2
from tqdm.auto import tqdm
import os

from rebeca import REBECA
from data_loader import DataLoader
from openai_vpt.lib.tree_util import tree_map

# Originally this code was designed for a small dataset of ~20 demonstrations per task.
# The settings might not be the best for the full BASALT dataset (thousands of demonstrations).
# Use this flag to switch between the two settings
USING_FULL_DATASET = True

EPOCHS = 1 if USING_FULL_DATASET else 2
# Needs to be <= number of videos
BATCH_SIZE = 64 if USING_FULL_DATASET else 16
# Ideally more than batch size to create
# variation in datasets (otherwise, you will
# get a bunch of consecutive samples)
# Decrease this (and batch_size) if you run out of memory
N_WORKERS = 100 if USING_FULL_DATASET else 20
DEVICE = "cuda"

LOSS_REPORT_RATE = 100

# Tuned with bit of trial and error
LEARNING_RATE = 0.000181
# OpenAI VPT BC weight decay
# WEIGHT_DECAY = 0.039428
WEIGHT_DECAY = 0.0
# KL loss to the original model was not used in OpenAI VPT
KL_LOSS_WEIGHT = 1.0
MAX_GRAD_NORM = 5.0

MAX_BATCHES = 2000 if USING_FULL_DATASET else int(1e9)

In [2]:
in_model = "data/VPT-models/foundation-model-1x.model"
in_weights = "data/VPT-models/foundation-model-1x-net.weights"

In [3]:
rebeca = REBECA(in_model, in_weights, "data/memory.json", device=DEVICE)
rebeca.to(rebeca.device)
print('Model loaded')

Model loaded


In [4]:
class DemoLoader:
    def __init__(self, data_dir="data/MakeWaterfallTrain/"):
        self.load_expert_data(data_dir)
        self.generator = self.load_demonstrations() # create a generator object

    def load_expert_data(self, data_dir):
        """Load expert demonstrations from data_dir"""

        unique_ids = glob.glob(os.path.join(data_dir, "*.mp4"))
        unique_ids = list(set([os.path.basename(x).split(".")[0] for x in unique_ids]))
        unique_ids.sort()

        self.demonstration_tuples = []
        for unique_id in unique_ids:
            video_path = os.path.abspath(os.path.join(data_dir, unique_id + ".mp4"))
            json_path = os.path.abspath(os.path.join(data_dir, unique_id + ".jsonl"))
            self.demonstration_tuples.append((unique_id, video_path, json_path))

    def load_demonstrations(self):
        """Load expert demonstrations from demonstration tuples"""
        _demonstration_tuples = self.demonstration_tuples

        for unique_id, video_path, json_path in tqdm(
            _demonstration_tuples, desc="Loading expert demonstrations"
        ):
            video = self._load_video(video_path)
            jsonl = self._load_jsonl(json_path)

            yield video, jsonl

    def _load_video(self, video_path):
        frames = []
        cap = cv2.VideoCapture(video_path)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, (128, 128), interpolation=cv2.INTER_LINEAR)
            frames.append(frame)
        cap.release()
        return frames

    def _load_jsonl(self, jsonl_path):
        with open(jsonl_path) as f:
            return [json.loads(line) for line in f]

    def __iter__(self):
        return self # return the iterator object itself

    def __next__(self):
        return next(self.generator) # return the next value from the generator

In [5]:
data_loader = DemoLoader()

In [10]:
for video, jsonl in data_loader:
    rebeca.reset()
    for frame, env_action in tqdm(zip(video, jsonl), total=len(video)):
        pred_action = rebeca(frame)
    break

  0%|          | 0/5330 [00:00<?, ?it/s]

In [11]:
pred_action

(tensor([[-0.3194, -0.2304,  0.1311,  ...,  0.0514,  0.1191,  0.3249]],
        device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
 tensor([[ 0.1215, -0.4680, -0.0967,  0.0766, -0.0943, -0.1302,  0.1931,  0.0405,
           0.0266, -0.1493, -0.4175,  0.1956,  0.1271, -0.0098,  0.0247,  0.0814,
          -0.1454, -0.1170, -0.2893, -0.1270, -0.0932, -0.0964, -0.3983, -0.2037,
           0.0006,  0.1556,  0.2214, -0.0284, -0.1083,  0.2841, -0.1103, -0.1082,
          -0.1439, -0.2897, -0.0180, -0.1113,  0.2394,  0.0360,  0.2240, -0.0315,
           0.2324, -0.0670,  0.2696,  0.0841, -0.0163,  0.3667, -0.3701,  0.0058,
          -0.1509, -0.4441,  0.0619, -0.0256, -0.0868, -0.1526,  0.0702,  0.2592,
           0.1436, -0.2668, -0.0836,  0.1447,  0.3319,  0.0959,  0.1219, -0.0486,
          -0.1173,  0.2014, -0.1207,  0.0532, -0.1319,  0.0784, -0.2219,  0.1838,
          -0.1637,  0.0094, -0.2201, -0.0925,  0.0219, -0.1162, -0.2440,  0.0585,
          -0.2819, -0.0107,  0.0316,  0.1053, -0.