## Device check

In [1]:
!nvidia-smi

Tue Oct 24 02:12:38 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:15:00.0 Off |                    0 |
| N/A   27C    P0    40W / 300W |      0MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Packages

In [2]:
import warnings
warnings.filterwarnings(action="ignore")

In [194]:
from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce

import numpy as np

import logging
logging.basicConfig(level="INFO")
import os

import pandas as pd
from pprint import pprint

import sys
sys.path.append("../src")

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchinfo import summary

import lightning.pytorch as pl

In [4]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
%load_ext autoreload
%autoreload 2

In [212]:
import config
from dataloader import BEDataset, BEDataModule
import token_learner
from transformer import PositionalEncoder, TransformerDecoder, LayerNormalization, FeedFowardLayer

from film_layers import FiLMBlockV2, FiLMEncoder, ResBlockDWConv
from utils.model_utils import TextEncoder, ImageFeatureExtractor

## Load data summary

In [7]:
csv = pd.read_csv(os.path.join(config.DATASET_PATH, "train.csv"))

csv.head()

Unnamed: 0,sample_ID,in_state,goal_state,action_description,motor_cmd,len_action_desc,len_motor_cmd,version
0,7294,0,10,put the fork to the right of buttermilk,:FORK GREEN POSE-9 :BUTTERMILK GREEN POSE-2 :F...,8,11,v2
1,405,0,8,move the bottle backwards,:BOTTLE RED POSE-2 :BOTTLE #'*backward-transf...,4,8,v1
2,4235,0,10,put the bottle to the left of breakfast-cereal,:BOTTLE RED POSE-7 :BREAKFAST-CEREAL BLUE POSE...,8,11,v2
3,6990,0,10,put the milk in front of bottle,:MILK BLUE POSE-8 :BOTTLE RED POSE-4 :MILK #'...,7,11,v2
4,7096,0,10,put the cup in front of glasses,:CUP GREEN POSE-6 :GLASSES RED POSE-2 :CUP #'...,7,11,v2


In [8]:
# building data object
ds = BEDataset(
    df=csv    
)

len(ds)

4876

In [9]:
# fetching example
rand_idx = np.random.randint(low=0, high=len(ds))
ex = ds[rand_idx]

print("Dataset size: ", len(ds))
print("="*100)
print("ID\t: ", ex["sample_id"])
print(">> InState\t: ", ex["in_state"].shape)
print(">> Desc\t:")
pprint(ex["action_desc"])
print(">> Cmd\t:")
pprint(ex["motor_cmd"])
print("="*100)

Dataset size:  4876
ID	:  9401
>> InState	:  torch.Size([3, 224, 224])
>> Desc	:
{'ids': tensor([  101,  2404,  1996,  5442,  2000,  1996,  2157,  1997, 12256, 17130,
         2378,   102,     0,     0,     0,     0]),
 'length': 8,
 'mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]),
 'raw': 'put the knife to the right of mondamin',
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
>> Cmd	:
{'ids': tensor([ 0, 17, 41, 30, 18, 41, 36, 17, 48, 18,  2,  2,  2,  2,  2,  2]),
 'length': 11,
 'mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]),
 'raw': ':KNIFE GREEN POSE-10 :MONDAMIN GREEN POSE-3 :KNIFE  '
        "#'*rightward-transformation*  :MONDAMIN"}


## Data Module

In [84]:
dm = BEDataModule()
dm.setup()

Total # examples: 4876


INFO:root:Training on 3610 samples.
INFO:root:Validating on 1266 samples.


In [85]:
# print("="*100)
# logging.info("\n>> train data loader")
# print(f"# train batches\t: {len(dm.train_dataloader())}")
# for data in dm.train_dataloader():
#     # pprint(data)
#     sample_id, in_state, ad, cmd = data["sample_id"], data["in_state"], data["action_desc"], data["motor_cmd"]
#     print("In \t\t\t: ", in_state.shape)
#     print("Action desc \t\t: ", ad["ids"].shape)
#     print("Action desc (len) \t: ", ad["length"].shape)
#     print("CMD \t\t\t: ", cmd["ids"].shape)
#     print("CMD(len) \t\t: ", cmd["length"].shape)
#     break

# print("\nIDs & decided tokens")
# for data in dm.train_dataloader():
#     print(data["action_desc"]["ids"][0].tolist())
#     print(dm.train_ds._decode_inputs(data["action_desc"]["ids"][0].tolist()))
#     print()
#     print(data["motor_cmd"]["ids"][0].tolist())
#     print(dm.train_ds._decode_outputs(data["motor_cmd"]["ids"][0].tolist()))

#     break
    
# print("="*100)

## Fetch batch

In [86]:
%%time
sample = next(iter(dm.train_dataloader()))
sample["in_state"].shape

CPU times: user 15.9 ms, sys: 271 ms, total: 287 ms
Wall time: 1.71 s


torch.Size([8, 3, 224, 224])

## Model Design

<!-- ![RT1 model architecture](../../imgs/rt1+.png) -->
<center>
    <img src="../imgs/rt1+.png" alt="RT1 model architecture" width="300" height="400">

<center>

### Encoder

#### Test Text Encoder

In [195]:
te = TextEncoder(freeze=True).cuda()
summary(model=te, col_names=["num_params", "trainable"])

Layer (type:depth-idx)                                  Param #                   Trainable
TextEncoder                                             --                        False
├─BertModel: 1-1                                        --                        False
│    └─BertEmbeddings: 2-1                              --                        False
│    │    └─Embedding: 3-1                              (15,627,264)              False
│    │    └─Embedding: 3-2                              (262,144)                 False
│    │    └─Embedding: 3-3                              (1,024)                   False
│    │    └─LayerNorm: 3-4                              (1,024)                   False
│    │    └─Dropout: 3-5                                --                        --
│    └─BertEncoder: 2-2                                 --                        False
│    │    └─ModuleList: 3-6                             (12,609,536)              False
│    └─BertPooler: 2-3         

In [196]:
emb = te(
    inp_ids=sample["action_desc"]["ids"].cuda(),
    mask=sample["action_desc"]["mask"].cuda(),
    tok_type_ids=sample["action_desc"]["token_type_ids"].cuda()
)

emb.shape

torch.Size([8, 512])

#### Test Img Feature Extractor

In [201]:
fe = ImageFeatureExtractor(pretrained=True, arch="resnet34").cuda()

summary(fe, col_names=["num_params", "trainable"])

Layer (type:depth-idx)                        Param #                   Trainable
ImageFeatureExtractor                         --                        Partial
├─ResNet: 1-1                                 --                        Partial
│    └─Conv2d: 2-1                            (9,408)                   False
│    └─BatchNorm2d: 2-2                       (128)                     False
│    └─ReLU: 2-3                              --                        --
│    └─MaxPool2d: 2-4                         --                        --
│    └─Sequential: 2-5                        --                        False
│    │    └─BasicBlock: 3-1                   (73,984)                  False
│    │    └─BasicBlock: 3-2                   (73,984)                  False
│    │    └─BasicBlock: 3-3                   (73,984)                  False
│    └─Sequential: 2-6                        --                        False
│    │    └─BasicBlock: 3-4                   (230,144)       

In [127]:
img_ftrs = fe(sample["in_state"].cuda())

img_ftrs.shape

torch.Size([8, 512, 7, 7])

#### Test FiLM Block

In [205]:
film_block = FiLMBlockV2().cuda()
print(film_block)
summary(model=film_block)

FiLMBlockV2(
  (projection_add): Linear(in_features=512, out_features=512, bias=True)
  (projection_mult): Linear(in_features=512, out_features=512, bias=True)
)


Layer (type:depth-idx)                   Param #
FiLMBlockV2                              --
├─Linear: 1-1                            262,656
├─Linear: 1-2                            262,656
Total params: 525,312
Trainable params: 525,312
Non-trainable params: 0

In [206]:
text_cond_ftrs = film_block(
    img_features=img_ftrs, 
    conditioning=emb
)

text_cond_ftrs.shape

torch.Size([8, 512, 7, 7])

#### Test Residual FiLM Block

In [208]:
dw_res = ResBlockDWConv(512, 512).cuda()
summary(model=dw_res)

Layer (type:depth-idx)                   Param #
ResBlockDWConv                           --
├─Conv2d: 1-1                            262,656
├─ReLU: 1-2                              --
├─Conv2d: 1-3                            5,120
├─Conv2d: 1-4                            262,656
├─BatchNorm2d: 1-5                       1,024
├─FiLMBlockV2: 1-6                       --
│    └─Linear: 2-1                       262,656
│    └─Linear: 2-2                       262,656
├─ReLU: 1-7                              --
Total params: 1,056,768
Trainable params: 1,056,768
Non-trainable params: 0

In [209]:
text_cond_ftrs_res = dw_res(
    img_features=img_ftrs, 
    conditioning=emb
)

text_cond_ftrs_res.shape

torch.Size([8, 512, 7, 7])

#### Test FiLM Encoder

In [213]:
film_encoder = FiLMEncoder(
    arch="resnet34",
    n_res_blocks=6,
).cuda()

# print(film_encoder)
summary(model=film_encoder)

Layer (type:depth-idx)                             Param #
FiLMEncoder                                        --
├─ImageFeatureExtractor: 1-1                       --
│    └─ResNet: 2-1                                 --
│    │    └─Conv2d: 3-1                            (9,408)
│    │    └─BatchNorm2d: 3-2                       (128)
│    │    └─ReLU: 3-3                              --
│    │    └─MaxPool2d: 3-4                         --
│    │    └─Sequential: 3-5                        (221,952)
│    │    └─Sequential: 3-6                        (1,116,416)
│    │    └─Sequential: 3-7                        (6,822,400)
│    │    └─Sequential: 3-8                        (13,114,368)
│    │    └─AdaptiveAvgPool2d: 3-9                 --
│    │    └─Linear: 3-10                           513,000
│    └─Sequential: 2-2                             21,284,672
│    │    └─Conv2d: 3-11                           (recursive)
│    │    └─BatchNorm2d: 3-12                      (recursive)
│  

In [214]:
%%time

out = film_encoder(
    x= sample["in_state"].cuda(),
    conditioning= emb
)

out.shape

CPU times: user 13.2 ms, sys: 13.2 ms, total: 26.4 ms
Wall time: 632 ms


torch.Size([8, 512, 7, 7])

#### Vision-Language Token Extraction

In [234]:
vl_conv = conv(config.EMBEDDING_DIM, 1, 1, 1, 0).cuda()
print(vl_conv)
summary(vl_conv)

Sequential(
  (0): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
  (1): GELU(approximate='none')
  (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            513
├─GELU: 1-2                              --
├─BatchNorm2d: 1-3                       2
Total params: 515
Trainable params: 515
Non-trainable params: 0

In [236]:
vl_tokens = vl_conv(out)
vl_tokens.shape

torch.Size([8, 1, 7, 7])

#### Token Learner

In [215]:
N, C, H, W = out.shape

N, C, H, W

(8, 512, 7, 7)

In [232]:
tokL_v11 = token_learner.TokenLearnerModuleV11(feature_shape=(N, H*W, C))
print(tokL_v11)
summary(model=tokL_v11)

TokenLearnerModuleV11(
  (layer_norm): LayerNormalization(
    (layer): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
  )
  (token_masking): FeedFowardLayer(
    (linear_1): Linear(in_features=512, out_features=64, bias=True)
    (activation): GELU(approximate='none')
    (linear_2): Linear(in_features=64, out_features=8, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
)


Layer (type:depth-idx)                   Param #
TokenLearnerModuleV11                    --
├─LayerNormalization: 1-1                --
│    └─LayerNorm: 2-1                    1,024
├─FeedFowardLayer: 1-2                   --
│    └─Linear: 2-2                       32,832
│    └─GELU: 2-3                         --
│    └─Linear: 2-4                       520
│    └─Dropout: 2-5                      --
Total params: 34,376
Trainable params: 34,376
Non-trainable params: 0

In [233]:
# Create a dummy input tensor with the shape [batch_size, height, width, channels]
dummy_input = torch.randn(N, H * W, C)

# Set the model to evaluation mode
tokL_v11.eval()

# Pass the dummy input through the model
with torch.no_grad():
    output = tokL_v11(dummy_input)

# Print the shape of the output
print("Output shape:", output.shape)

Output shape: torch.Size([8, 512, 8])


#### RT-1 Encoder

In [74]:
class RT1Encoder(nn.Module):
    def __init__(
        self,
        cnn_bacnbone:str="resnet18"
    ):
        super().__init__()
        
        # Text encoder
        self.text_encoder = TextEncoder()
        
        # Image encoder
        self.image_encoder = FiLMEncoder()
        
        # Vision-Language tokens extractor
        self.vl_conv = conv(self.image_encoder.n_channels, 1, 1, 1, 0)
        
        # Token Learner
        self.token_learner = TokenLearner()
        
        # Transformer decoder
        self.transformer = TransformerDecoder()

    def forward(self, input_ids, attn_mask, token_type_ids, imgs):
        """
        """
        text_enc = self.text_encoder(
            inp_ids=input_ids,
            mask=attn_mask,
            tok_type_ids=token_type_ids
        )
        
        # Generage image tokens
        img_tokens = self.image_encoder(
            x= imgs,
            description= text_enc
        )
        
        # Vision-Language tokens extractor
        img_tokens = self.vl_conv(img_tokens)
        
        # Extract learned tokens
        learned_tokens  = self.tok(img_tokens)
        
        return learned_tokens
        

In [81]:
encoder = RT1Encoder()
summary(model=encoder, col_names=["num_params", "trainable"])

Layer (type:depth-idx)                                       Param #                   Trainable
RT1Encoder                                                   --                        Partial
├─TextEncoder: 1-1                                           --                        False
│    └─BertModel: 2-1                                        --                        False
│    │    └─BertEmbeddings: 3-1                              (7,945,728)               False
│    │    └─BertEncoder: 3-2                                 (3,159,040)               False
│    │    └─BertPooler: 3-3                                  (65,792)                  False
│    └─Dropout: 2-2                                          --                        --
├─FiLMEncoder: 1-2                                           --                        Partial
│    └─Linear: 2-3                                           101,376                   True
│    └─FeatureExtractor: 2-4                                 --   

### RT-1 Decoder

#### Transformer Decoder

In [52]:
dec = TransformerDecoder()
summary(model=dec)

Layer (type:depth-idx)                   Param #
TransformerDecoder                       --
├─ModuleList: 1-1                        --
│    └─TransformerDecoderLayer: 2-1      --
│    │    └─LayerNormalization: 3-1      1,024
│    │    └─MultiheadAttention: 3-2      1,050,624
│    │    └─Dropout: 3-3                 --
│    │    └─LayerNormalization: 3-4      1,024
│    │    └─MultiheadAttention: 3-5      1,050,624
│    │    └─Dropout: 3-6                 --
│    │    └─LayerNormalization: 3-7      1,024
│    │    └─FeedFowardLayer: 3-8         1,050,112
│    │    └─Dropout: 3-9                 --
│    └─TransformerDecoderLayer: 2-2      --
│    │    └─LayerNormalization: 3-10     1,024
│    │    └─MultiheadAttention: 3-11     1,050,624
│    │    └─Dropout: 3-12                --
│    │    └─LayerNormalization: 3-13     1,024
│    │    └─MultiheadAttention: 3-14     1,050,624
│    │    └─Dropout: 3-15                --
│    │    └─LayerNormalization: 3-16     1,024
│    │    └─FeedFo

#### Action Generator

In [None]:
class ActionGenerator(nn.Module):
    def __init__(
        self
    ):
        super().__init__()
    
    def forward(self, tokens):
        pass

#### Decoder

In [None]:
class RT1Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.positional_encoder = PositionalEncoder()
        self.transformer = TransformerDecoder()
        # 
        self.action_generator = ActionGenerator()
        

    def _positional_encoding(
        self,
        seq, 
        dim, 
        temperature = 10000, 
        device = None, 
        dtype = torch.float32
    ):
        n = torch.arange(seq, device = device)
        omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
        omega = 1. / (temperature ** omega)

        n = n[:, None] * omega[None, :]
        pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
        
        return pos_emb.type(dtype)

    
    def forward(self, instructions, imgs):
        pass

### RT-1 

In [None]:
class RT1(pl.LightningModule):
    def __init__(
        self
    ):
        super().__init__()
        self.encoder = RT1Encder()
        self.decoder = RT1Decoder()
        
    def forward(self, imgs, instruction):
        pass
    
    def configure_optimizers(self):
        pass
    
    def training_step(self, batch, batch_idx):
        pass
    
    def validation_step(self, batch, batch_idx):
        pass
    
    def test_step(self, batch, batch_idx):
        pass
    
    def compute_loss(self, outputs, targets):
        pass