In [1]:
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from model import VPTEncoder, Controller
from memory import SituationLoader, Memory
from matplotlib import pyplot as plt

from openai_vpt.lib.action_mapping import CameraHierarchicalMapping
from openai_vpt.lib.actions import ActionTransformer
from action_utils import ActionProcessor

In [2]:
# Locations of the VPT foundation model definition and its weights file,
# given relative to the notebook's working directory.
in_model, in_weights = (
    "data/VPT-models/foundation-model-1x.model",
    "data/VPT-models/foundation-model-1x.weights",
)

In [3]:
# Build the VPT encoder from the model file and call .eval() on it
# (presumably the usual nn.Module switch to inference mode — confirm in model.py),
# then wrap it in the loader used below to load and encode expert demonstrations.
# NOTE(review): `in_weights` is defined above but never passed to anything in this
# notebook — confirm that VPTEncoder obtains its weights some other way.
vpt = VPTEncoder(in_model)
vpt.eval()
expert_dataloader = SituationLoader(vpt)

In [4]:
# Load two expert demonstrations (num_demos=2). Each entry is indexed by
# "video" further down when frames are displayed.
demonstrations = expert_dataloader.load_demonstrations(num_demos=2)

Loading expert demonstrations:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Encode the loaded demonstrations with the loader that wraps the VPT encoder.
encoded_demos = expert_dataloader.encode_demonstrations(demonstrations)

Encoding expert demonstrations:   0%|          | 0/2 [00:00<?, ?it/s]

Encoding Trajectory:   0%|          | 0/1029 [00:00<?, ?it/s]

Encoding Trajectory:   0%|          | 0/1255 [00:00<?, ?it/s]

In [None]:
# Turn the encoded demonstrations into a list of situations (stride=64).
# Each situation is read below through its "situation" and "actions" keys.
# NOTE(review): presumably `stride` is the frame step between consecutive
# situations — confirm in memory.py.
situations = expert_dataloader.create_situations(encoded_demos, stride=64)

In [None]:
# Inspect which fields the first encoded demonstration carries.
encoded_demos[0].keys()

In [None]:
# Peek at the button actions stored with the first situation.
situations[0]['actions']['buttons']

In [None]:
# Inspect which fields a situation entry carries.
situations[1].keys()

In [None]:
# Instantiate the situation memory; its index is created from `situations` below.
memory = Memory()

In [None]:
# Build the memory's search index from the extracted situations.
memory.create_index(situations)

In [None]:
# Save the index, passing "data" as the target directory.
# NOTE(review): the notebook later loads 'data/memory.json' — presumably the
# file written here; confirm the file name in memory.py.
memory.save_index(save_dir="data")

In [None]:
# Load the index from 'data/memory.json' into the memory object.
# NOTE(review): presumably a round-trip check of the save step — confirm that
# this path matches what save_index(save_dir="data") writes.
memory.load_index('data/memory.json')

In [None]:
# Use the "situation" entry of the first situation as the search query.
query = situations[0]["situation"]

In [None]:
# Ask the memory for k=5 matches for the query and display the result.
# Each result entry exposes "sit_frame_idx", which the following cells use to
# pick a video frame.
res = memory.search(query, k=5)
res

In [None]:
# Display the video frame that the first search result points to.
# `[..., ::-1]` reverses the last axis of the frame (presumably BGR -> RGB for
# PIL — confirm against how the video is decoded).
# NOTE(review): this always reads from demonstrations[0]; confirm that res[0]
# really belongs to the first demonstration, otherwise the wrong frame is shown.
Image.fromarray(demonstrations[0]["video"][res[0]["sit_frame_idx"]][..., ::-1])

In [None]:
# Display the video frame that the second search result points to.
# `[..., ::-1]` reverses the last axis of the frame (presumably BGR -> RGB for
# PIL — confirm against how the video is decoded).
# NOTE(review): this always reads from demonstrations[1]; confirm that res[1]
# really belongs to the second demonstration, otherwise the wrong frame is shown.
Image.fromarray(demonstrations[1]["video"][res[1]["sit_frame_idx"]][..., ::-1])

In [None]:
# Image.fromarray(list(filter(lambda x: x['demo_id'] == 'Player648-f6ef373a67fd-20220706-200136', demonstrations))[0]['video'][res[3]['situation_idx']])

In [None]:
# Smoke test of cross-attention with nn.MultiheadAttention on random data.
# The layer is used in its default sequence-first layout (batch_first=False):
# the query is (target_len, batch, embed) and key/value are (source_len, batch, embed).
# The tensors carry an `attn_` prefix so that this cell does not overwrite
# `query`, the memory-search query defined earlier in the notebook; otherwise
# re-running the search cell after this one would silently search with noise.
batch_size = 16
seq_len = 10
embed_dim = 1024
num_heads = 8

attn_query = torch.randn(1, batch_size, embed_dim)        # one query step per batch element
attn_key = torch.randn(seq_len, batch_size, embed_dim)    # keys to attend over
attn_value = torch.randn(seq_len, batch_size, embed_dim)  # values paired with the keys

# Create a multi-head attention layer (embed_dim must be divisible by num_heads)
attn = nn.MultiheadAttention(embed_dim, num_heads)

# Apply cross attention:
#   output  -> (1, batch_size, embed_dim)
#   weights -> (batch_size, 1, seq_len), averaged over the heads
output, weights = attn(attn_query, attn_key, attn_value)

In [None]:
# Prototype of the controller's cross-attention inputs: an observation
# embedding is projected to a query, while a situation embedding and the
# situation's keyboard/camera actions are projected to keys and values.

# Define some constants
d_model = 1024 # dimension of the input and output vectors
d_k = 1024 # dimension of the query and key vectors
d_v = 1024 # dimension of the value vectors
n_heads = 4 # number of attention heads
n_key_actions = 8641 # length of a keyboard action one-hot vector
n_cam_actions = 121 # length of a camera action one-hot vector
n_steps = 128 # sequence length of the dummy action tensors
assert d_model % n_heads == 0 # make sure d_model is divisible by n_heads
# MHA below is built without kdim/vdim, so projected queries, keys and values
# must all have d_model features.
assert d_k == d_model and d_v == d_model

# Define linear layers for projection
Wq_obs = nn.Linear(d_model, d_k) # project observation embedding to query vector
Wk_sit = nn.Linear(d_model, d_k) # project situation embedding to key vector
Wv_sit = nn.Linear(d_model, d_v) # project situation embedding to value vector
Wk_key = nn.Linear(n_key_actions, d_k) # project keyboard action one-hot vector to key vector
Wv_key = nn.Linear(n_key_actions, d_v) # project keyboard action one-hot vector to value vector
Wk_cam = nn.Linear(n_cam_actions, d_k) # project camera action one-hot vector to key vector
Wv_cam = nn.Linear(n_cam_actions, d_v) # project camera action one-hot vector to value vector

# Define multi-head attention layer (sequence-first layout, batch_first=False)
MHA = nn.MultiheadAttention(d_model, n_heads)

# Define output layer for concatenation or addition
Wo = nn.Linear(d_v, d_model) # project output vector to original dimension

# Define input tensors
obs = torch.randn(1, 1, d_model) # observation embedding tensor of shape [1, 1, 1024]
sit = torch.randn(1, 1, d_model) # situation embedding tensor of shape [1, 1, 1024]
# The action inputs are genuine one-hot vectors (random class per step, cast to
# float for nn.Linear) rather than Gaussian noise, matching what the
# projection layers are documented to consume.
key = nn.functional.one_hot(
    torch.randint(n_key_actions, (1, n_steps)), num_classes=n_key_actions
).float() # keyboard action one-hot tensor of shape [1, 128, 8641]
cam = nn.functional.one_hot(
    torch.randint(n_cam_actions, (1, n_steps)), num_classes=n_cam_actions
).float() # camera action one-hot tensor of shape [1, 128, 121]

# Project input tensors to query, key and value vectors
q_obs = Wq_obs(obs) # query vector tensor of shape [1, 1, 1024]
k_sit = Wk_sit(sit) # key vector tensor of shape [1, 1, 1024]
v_sit = Wv_sit(sit) # value vector tensor of shape [1, 1, 1024]
k_key = Wk_key(key) # key vector tensor of shape [1, 128, 1024]
v_key = Wv_key(key) # value vector tensor of shape [1, 128, 1024]
k_cam = Wk_cam(cam) # key vector tensor of shape [1, 128, 1024]
v_cam = Wv_cam(cam) # value vector tensor of shape [1, 128, 1024]

In [None]:
# Join the situation, keyboard and camera projections along dim 1 so that
# attention sees a single key/value sequence of length 1 + 128 + 128 = 257.
key_parts = (k_sit, k_key, k_cam)
value_parts = (v_sit, v_key, v_cam)
k_all = torch.cat(key_parts, dim=1)    # keys:   [1, 257, 1024]
v_all = torch.cat(value_parts, dim=1)  # values: [1, 257, 1024]

In [None]:
# Check the shape of the projected query before feeding it to attention.
q_obs.shape

In [None]:
# Apply multi-head attention on the query and key-value pairs.
# MHA was created without batch_first, so it expects (seq, batch, embed):
# keys and values are transposed from [1, 257, 1024] to [257, 1, 1024].
# q_obs is [1, 1, 1024], which reads the same in either layout, so it is passed as-is.
out_obs, _ = MHA(q_obs, k_all.transpose(0 ,1), v_all.transpose(0 ,1)) 
# out_obs has shape [1, 1, 1024]; the attention weights are discarded

# Combine the attention output with the original query vector.
# Concatenation along dim 1 is the variant that is active here.
out_obs = torch.cat([out_obs.transpose(0 ,1), q_obs], dim=1)
# out_obs has shape [1, 2, 1024] after concatenation
# Alternative to the concatenation above (residual addition):
# out_obs = out_obs + q_obs
# which would leave the shape at [1, 1, 1024]

# Optional output projection (currently disabled):
# out_obs = Wo(out_obs)
# the shape would stay [1, 2, 1024], since Wo maps 1024 -> 1024 features

out_obs.shape

In [None]:
# Instantiate the Controller (imported from model.py) with its default arguments.
controller = Controller()

In [None]:
# Assemble example inputs for the controller from the extracted situations.
# The first situation's "situation" entry serves as both the observation and
# the situation, reshaped to [1, 1, -1].
# NOTE(review): the actions are taken from situations[10] while the situation
# comes from situations[0] — confirm that this mismatch is intentional.
first_situation = situations[0]["situation"]
_observation = torch.Tensor(first_situation).reshape(1, 1, -1)
_situation = torch.Tensor(first_situation).reshape(1, 1, -1)

# Add a leading batch axis to each action array; the controller input uses the
# key "keyboard" for what the situation stores under "buttons".
raw_actions = situations[10]["actions"]
_actions = {
    "camera": torch.Tensor(raw_actions["camera"]).unsqueeze(0),
    "keyboard": torch.Tensor(raw_actions["buttons"]).unsqueeze(0),
}

In [None]:
# Forward pass: the controller takes the observation, a situation and the
# action dict, and returns two outputs, unpacked here as keyboard and camera.
out_key, out_cam = controller(_observation, _situation, _actions)

In [None]:
# Shape of the controller's keyboard output.
out_key.shape

In [None]:
# Shape of the controller's camera output.
out_cam.shape