In [44]:
import os
import json
import glob
import pickle
import time
import cv2
import gym
import minerl
import torch
import torch as th
import numpy as np
from tqdm.auto import tqdm
import faiss

from openai_vpt.agent import PI_HEAD_KWARGS, MineRLAgent
from openai_vpt.lib.policy import MinecraftPolicy
from data_loader import DataLoader
from openai_vpt.lib.tree_util import tree_map

In [2]:
USING_FULL_DATASET = False

EPOCHS = 1 if USING_FULL_DATASET else 1
# Needs to be <= number of videos
BATCH_SIZE = 64 if USING_FULL_DATASET else 16
# Ideally more than batch size to create
# variation in datasets (otherwise, you will
# get a bunch of consecutive samples)
# Decrease this (and batch_size) if you run out of memory
N_WORKERS = 100 if USING_FULL_DATASET else 16
DEVICE = "cuda"

LOSS_REPORT_RATE = 100

# Tuned with bit of trial and error
LEARNING_RATE = 0.000181
# OpenAI VPT BC weight decay
# WEIGHT_DECAY = 0.039428
WEIGHT_DECAY = 0.0
# KL loss to the original model was not used in OpenAI VPT
KL_LOSS_WEIGHT = 1.0
MAX_GRAD_NORM = 5.0

MAX_BATCHES = 2000 if USING_FULL_DATASET else int(1e9)

In [3]:
def load_model_parameters(path_to_model_file):
    agent_parameters = pickle.load(open(path_to_model_file, "rb"))
    policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
    pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
    return policy_kwargs, pi_head_kwargs

In [4]:
in_model = "data/VPT-models/foundation-model-1x.model"
in_weights = "data/VPT-models/foundation-model-1x.weights"

In [5]:
agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)
policy = MinecraftPolicy(**agent_policy_kwargs, single_output=True)
policy.to(DEVICE)
# agent = MineRLAgent(device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
# agent.load_weights(in_weights)
# policy = agent.policy

for param in policy.parameters():
    param.requires_grad = False

In [6]:
video_path = "data/MakeWaterfall/Player571-f153ac423f61-20220707-110239.mp4"

In [7]:
def load_frames(video_path):
    frames = []
    cap = cv2.VideoCapture(video_path)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.resize(frame, (128, 128), interpolation=cv2.INTER_LINEAR)
        frames.append(frame)
    cap.release()
    return frames

In [8]:
def load_jsonl(jsonl_path):
    with open(jsonl_path) as f:
        return [json.loads(line) for line in f]

In [9]:
video = load_frames(video_path)

In [10]:
len(video)

1115

In [11]:
len(load_jsonl("data/MakeWaterfall/Player571-f153ac423f61-20220707-110239.jsonl"))

1114

In [12]:
def preprocess(obs_frame):
    """
    Turn observation from MineRL environment into model's observation

    Returns torch tensors.
    """
    agent_input = cv2.resize(obs_frame, (128, 128), interpolation=cv2.INTER_LINEAR)[None]
    agent_input = {"img": th.from_numpy(agent_input).to(DEVICE)}
    return agent_input

In [18]:
def display_video(frames):
    for frame in frames:
        cv2.imshow("image", frame)
        cv2.waitKey(0)
    cv2.destroyAllWindows()

In [14]:
unique_ids = glob.glob(os.path.join('data/MakeWaterfall/', "*.mp4"))
unique_ids = list(set([os.path.basename(x).split(".")[0] for x in unique_ids]))

In [20]:
class SituationsLoader():
    '''Load the data from the MakeWaterfall dataset and create situations'''
    def __init__(self, data_dir='data/MakeWaterfall/'):
        unique_ids = glob.glob(os.path.join(data_dir, "*.mp4"))
        unique_ids = list(set([os.path.basename(x).split(".")[0] for x in unique_ids]))

        self.demonstration_tuples = []
        for unique_id in unique_ids:
            video_path = os.path.abspath(os.path.join(data_dir, unique_id + ".mp4"))
            json_path = os.path.abspath(os.path.join(data_dir, unique_id + ".jsonl"))
            self.demonstration_tuples.append((unique_id, video_path, json_path))
    
    def __len__(self):
        return len(self.demonstration_tuples)
    
    def _load_video(self, video_path):
        frames = []
        cap = cv2.VideoCapture(video_path)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, (128, 128), interpolation=cv2.INTER_LINEAR)
            frames.append(frame)
        cap.release()
        return frames
    
    def _load_jsonl(self, jsonl_path):
        with open(jsonl_path) as f:
            return [json.loads(line) for line in f]
        
    def build_situations(self, window_size=128, stride=2):
        
        situations = []
        for unique_id, video_path, json_path in tqdm(self.demonstration_tuples):
            video = self._load_video(video_path)
            jsonl = self._load_jsonl(json_path)
            
            for i in range(window_size, len(video) - window_size, stride):
                situation = {}
                situation['demo_id'] = unique_id
                situation['situation_idx'] = i
                situation['situation_obs'] = video[i-window_size:i+1] # 128 context + 1 current
                situations.append(situation)
                
        return situations
    
    def save_situations(self, situations, save_path):
        with open(save_path, 'wb') as f:
            pickle.dump(situations, f)

    def load_situations(self, save_path):
        with open(save_path, 'rb') as f:
            return pickle.load(f)

In [21]:
situation_loader = SituationsLoader()

In [22]:
situations = situation_loader.build_situations()

  0%|          | 0/30 [00:00<?, ?it/s]

In [51]:
len(situations)

12200

In [52]:
# save_path = "data/situations.pkl"
# situation_loader.save_situations(situations, save_path)

In [40]:
# pkl_loaded_data = situation_loader.load_situations(save_path)

In [61]:
dummy_first = th.from_numpy(np.array((False,))).to(DEVICE).unsqueeze(1)

situation_latents = []

with th.inference_mode():
    for situation in tqdm(situations[:10]):
        situation_obs = situation['situation_obs']

        initial_state = policy.initial_state(1)
        states = [initial_state]

        for obs in situation_obs:
            obs = preprocess(obs)
            obs = tree_map(lambda x: x.unsqueeze(1), obs)
            pi_latent, state_out = policy(obs, states[-1], context={"first": dummy_first})
            states.append(state_out)
        
        situation_latents.append({
            'demo_id': situation['demo_id'],
            'situation_idx': situation['situation_idx'],
            'situation_latent': pi_latent.squeeze().detach().cpu().numpy()
        })

  0%|          | 0/10 [00:00<?, ?it/s]

In [39]:
situation_loader.save_situations(situation_latents, "data/situation_latents.pkl")

In [62]:
situation_latents_array = np.array([x['situation_latent'] for x in situation_latents])
situation_latents_array.shape

(10, 1024)

In [101]:
class Memory():

    def create_index(self, situation_latents_array):
        self.index = faiss.IndexFlatL2(1024)
        self.index.add(situation_latents_array)

    def save_index(self, save_path):
        faiss.write_index(self.index, save_path)
    
    def load_index(self, save_path):
        self.index = faiss.read_index(save_path)

    def search(self, query, k=4):
        distances, nearest_indices = self.index.search(query.reshape(1, 1024), k)
        return distances[0], nearest_indices[0]

In [102]:
memory = Memory()

In [98]:
memory.load_index("data/memory.faiss")

In [99]:
situation_query = situation_latents_array[-1]

In [100]:
memory.search(situation_query)

(array([   0.    ,  741.8623, 1323.79  , 1340.2336], dtype=float32),
 array([9, 8, 1, 0]))