In [None]:
import os
import cv2
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from model import VPTEncoder, Controller
from memory import SituationLoader, Memory
from matplotlib import pyplot as plt

from openai_vpt.lib.action_mapping import CameraHierarchicalMapping
from openai_vpt.lib.actions import ActionTransformer
from action_utils import ActionProcessor

In [None]:
# Pretrained VPT foundation model: the model/config file and the weights file that
# keeps only the `net.` sub-module (produced by the commented-out conversion cell).
in_model, in_weights = (
    "data/VPT-models/foundation-model-1x.model",
    "data/VPT-models/foundation-model-1x-net.weights",
)

In [None]:
# state = torch.load(in_weights, map_location="cpu")
# # keep weights only for the net part
# state = {k: v for k, v in state.items() if k.startswith("net.")}
# # remove the "net." prefix
# state = {k[4:]: v for k, v in state.items()}
# torch.save(state, "data/VPT-models/foundation-model-1x-net.weights")

In [None]:
# VPT encoder (switched to eval mode) handed to SituationLoader, which presumably uses it
# to encode the expert demonstrations in the cells below.
# NOTE(review): Retriever builds its own separate VPTEncoder, so two copies of the
# weights are loaded; consider sharing one instance if memory is tight.
vpt = VPTEncoder(in_model, in_weights)
vpt.eval()
expert_dataloader = SituationLoader(vpt)

In [None]:
# Load 2 expert demonstrations. Each entry is indexed later as demo['video'] (a sequence of frames).
demonstrations = expert_dataloader.load_demonstrations(num_demos=2)

In [None]:
# Encode the demonstrations, presumably with the VPT encoder passed to SituationLoader — TODO confirm.
encoded_demos = expert_dataloader.encode_demonstrations(demonstrations)

In [None]:
# Split the encoded demonstrations into situations.
# NOTE(review): stride=64 is presumably the frame step between consecutive situations — TODO confirm.
situations = expert_dataloader.create_situations(encoded_demos, stride=64)

In [None]:
# Architecture outline
#
# Retriever:
#   - VPT encoder (frozen)
#   - Memory
#
# REBECA policy:
#   - Retriever
#   - VPT backbone (trainable)
#   - Controller
#
# Forward pass:
#   1. Retrieve situations
#   2. obs = VPT backbone(obs)
#   3. Preprocess the retrieved situations
#   4. key, cam = Controller(obs, situations, actions)

In [None]:
class Retriever():
    """Find the stored situation closest to the current observation.

    Wraps a ``VPTEncoder`` (put in eval mode) that embeds observations while carrying
    a recurrent hidden state across calls, and a ``Memory`` index that is searched
    with the resulting embedding.

    Args:
        encoder_model: Path to the VPT model file.
        encoder_weights: Path to the VPT weights file.
        memory_path: Path of the index loaded via ``Memory.load_index``.
    """

    def __init__(self, encoder_model, encoder_weights, memory_path):
        self.vpt = VPTEncoder(encoder_model, encoder_weights)
        self.vpt.eval()
        self.memory = Memory()
        self.memory.load_index(memory_path)

        self.reset()

    def encode_query(self, query_obs):
        """Embed one observation and advance the encoder's recurrent state.

        Returns:
            The embedding as a NumPy array with singleton dimensions squeezed out.
        """
        # Retrieval never back-propagates into the encoder. If the encoder's parameters
        # require grad, running without no_grad would (a) make .numpy() raise and
        # (b) chain autograd graphs through the hidden state carried between calls,
        # growing memory without bound.
        with torch.no_grad():
            query_obs_vec, state_out = self.vpt(query_obs, self.hidden_state)
        self.hidden_state = state_out
        return query_obs_vec.squeeze().cpu().numpy()

    def retrieve(self, query_obs, k=2, exact_match_tol=0.0):
        """Return the best memory match for ``query_obs``, skipping an exact self-match.

        Args:
            query_obs: Observation passed to the encoder.
            k: Number of neighbours to request. At least 2 are always requested so
                there is an alternative when the top hit is the query itself.
            exact_match_tol: A top hit with ``abs(distance) <= exact_match_tol`` is
                treated as the query's own situation. The default 0.0 means only an
                exact zero distance counts.

        Returns:
            One entry of ``Memory.search``'s result (a mapping with at least a
            ``'distance'`` key).

        Raises:
            IndexError: If the memory search returns no results.
        """
        query_obs_vec = self.encode_query(query_obs)
        # assumes search() returns results ordered best-first — TODO confirm
        results = self.memory.search(query_obs_vec, k=max(k, 2))
        if len(results) == 0:
            raise IndexError("Memory.search returned no results")

        # Skip an exact match to avoid returning the same situation and overfitting.
        if abs(results[0]['distance']) <= exact_match_tol:
            print("Same situation found")
            if len(results) > 1:
                return results[1]
            # Memory holds nothing else: fall back to the self-match instead of crashing.
        return results[0]

    def reset(self):
        """Reset the encoder's recurrent hidden state via ``policy.initial_state(1)``."""
        self.hidden_state = self.vpt.policy.initial_state(1)

In [None]:
# Retriever backed by its own eval-mode VPT encoder and the memory index in data/memory.json.
retriever = Retriever(
    encoder_model=in_model,
    encoder_weights=in_weights,
    memory_path="data/memory.json",
)

In [None]:
def _load_video(video_path, size=(128, 128)):
    """Decode every frame of a video file, resizing each one to ``size``.

    Args:
        video_path: Path to a video file readable by OpenCV.
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults to 128x128.

    Returns:
        List of frames as returned by ``cv2.VideoCapture.read`` (OpenCV BGR channel
        order), each resized with bilinear interpolation.

    Raises:
        OSError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    # Fail loudly: an unopenable path would otherwise yield an empty list, which only
    # surfaces later as an IndexError when a frame is looked up.
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open video: {video_path}")

    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR))
    finally:
        # Release the capture even if decoding or resizing raises.
        cap.release()
    return frames

In [None]:
def load_demonstration(demo_id, root_dir="data/MakeWaterfall"):
    """Decode ``<root_dir>/<demo_id>.mp4`` with ``_load_video`` and return its frame list."""
    return _load_video(f"{root_dir}/{demo_id}.mp4")

def load_situation(situation_id, demo_frames):
    """Return the element of ``demo_frames`` at index ``situation_id``."""
    return demo_frames[situation_id]

In [None]:
# Sanity check: feed frames 150-199 of the first demonstration through the retriever and
# show each query frame next to the situation frame it retrieved.
retriever.reset()  # start from the initial recurrent state so re-running this cell is reproducible
demo_frames_cache = {}  # demo_id -> decoded frames, so each video is decoded only once
for obs in demonstrations[0]['video'][150:200]:
    result = retriever.retrieve(obs)
    demo_id = result['demo_id']
    if demo_id not in demo_frames_cache:
        demo_frames_cache[demo_id] = load_demonstration(demo_id)
    res_demo = demo_frames_cache[demo_id]
    res_situation = load_situation(result['sit_frame_idx'], res_demo)

    fig, (ax_query, ax_match) = plt.subplots(1, 2, figsize=(8, 4))
    # Frames are presumably BGR (OpenCV order); reverse the channel axis for matplotlib.
    ax_query.imshow(obs[..., ::-1])
    ax_query.set_title("query")
    ax_match.imshow(res_situation[..., ::-1])
    ax_match.set_title(
        f"retrieved: demo {demo_id}, frame {result['sit_frame_idx']}\n"
        f"distance {result['distance']}"
    )
    for ax in (ax_query, ax_match):
        ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating 50 open figures

In [None]:
class REBECA(nn.Module):
    """Retrieval-augmented policy: retrieve a memory situation, then act with a controller.

    Args:
        encoder_model: Path to the VPT model file (used by the Retriever and ``self.vpt``).
        encoder_weights: Path to the VPT weights file.
        memory_path: Path of the memory index loaded by ``Retriever``.
        controller_path: Passed through to ``Controller``.
        controller_weights: Passed through to ``Controller``.
    """

    def __init__(self, encoder_model, encoder_weights, memory_path, controller_path, controller_weights):
        super().__init__()
        # Retriever is a plain object (not an nn.Module), so its encoder is not
        # registered as a sub-module of REBECA.
        self.retriever = Retriever(encoder_model, encoder_weights, memory_path)
        # NOTE(review): self.vpt is created but never used by forward(); the intended
        # design may encode obs with it before the controller — TODO confirm.
        self.vpt = VPTEncoder(encoder_model, encoder_weights)
        self.controller = Controller(controller_path, controller_weights)
        # demo_id -> decoded frames; avoids decoding an entire video on every forward call.
        self._demo_cache = {}

    def _load_demo(self, demo_id):
        """Return the decoded frames for ``demo_id``, decoding the video only on first use."""
        frames = self._demo_cache.get(demo_id)
        if frames is None:
            frames = load_demonstration(demo_id)
            self._demo_cache[demo_id] = frames
        return frames

    def forward(self, obs, actions):
        """Retrieve the closest situation for ``obs`` and return the controller's action.

        Args:
            obs: Current observation, passed to the retriever and the controller.
            actions: Passed through unchanged to the controller.

        Returns:
            Whatever ``Controller`` returns for ``(obs, retrieved_situation, actions)``.
        """
        result = self.retriever.retrieve(obs)
        res_demo = self._load_demo(result['demo_id'])
        res_situation = load_situation(result['sit_frame_idx'], res_demo)
        action = self.controller(obs, res_situation, actions)
        return action