## Device check

In [1]:
!nvidia-smi

Fri Oct 20 03:18:37 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:15:00.0 Off |                    0 |
| N/A   28C    P0    54W / 300W |    725MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Packages

In [2]:
from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce

import numpy as np

import logging
logging.basicConfig(level="INFO")
import os

import pandas as pd
from pprint import pprint

import sys
sys.path.append("../")

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models
from torchinfo import summary

from transformers import (AutoTokenizer, AutoModel, AdamW, AutoConfig, get_linear_schedule_with_warmup)

import lightning.pytorch as pl

  from .autonotebook import tqdm as notebook_tqdm
INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmppwrkadki
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmppwrkadki/_remote_module_non_scriptable.py


In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
%load_ext autoreload
%autoreload 2

In [4]:
import config
from dataloader import BEDataset, BEDataModule
from transformer import PositionalEncoder, TransformerDecoder

## Load data summary

In [5]:
csv = pd.read_csv(os.path.join(config.DATASET_PATH, "train.csv"))

csv.head()

Unnamed: 0,sample_ID,in_state,goal_state,action_description,motor_cmd,len_action_desc,len_motor_cmd,version
0,7294,0,10,put the fork to the right of buttermilk,:FORK GREEN POSE-9 :BUTTERMILK GREEN POSE-2 :F...,8,11,v2
1,405,0,8,move the bottle backwards,:BOTTLE RED POSE-2 :BOTTLE #'*backward-transf...,4,8,v1
2,4235,0,10,put the bottle to the left of breakfast-cereal,:BOTTLE RED POSE-7 :BREAKFAST-CEREAL BLUE POSE...,8,11,v2
3,6990,0,10,put the milk in front of bottle,:MILK BLUE POSE-8 :BOTTLE RED POSE-4 :MILK #'...,7,11,v2
4,7096,0,10,put the cup in front of glasses,:CUP GREEN POSE-6 :GLASSES RED POSE-2 :CUP #'...,7,11,v2


In [6]:
# building data object
ds = BEDataset(
    df=csv    
)

len(ds)

4876

In [8]:
# fetching example
rand_idx = np.random.randint(low=0, high=len(ds))
ex = ds[rand_idx]

print("Dataset size: ", len(ds))
print("="*100)
print("ID\t: ", ex["sample_id"])
print(">> InState\t: ", ex["in_state"].shape)
print(">> Desc\t:")
pprint(ex["action_desc"])
print(">> Cmd\t:")
pprint(ex["motor_cmd"])
print("="*100)

Dataset size:  4876
ID	:  2636
>> InState	:  torch.Size([3, 224, 224])
>> Desc	:
{'ids': tensor([ 101, 2404, 1996, 4605, 2006, 2327, 1997, 5127,  102,    0,    0,    0,
           0,    0,    0]),
 'length': 7,
 'mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]),
 'raw': 'put the bowl on top of plate',
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
>> Cmd	:
{'ids': tensor([ 0, 21, 31, 47, 24, 31, 27, 21, 42, 24,  2,  2,  2,  2,  2]),
 'length': 11,
 'mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]),
 'raw': ":BOWL RED POSE-6 :PLATE RED POSE-4 :BOWL  #'*on-transformation*  "
        ':PLATE'}


## Data Module

In [15]:
dm = BEDataModule()
dm.setup()

Total # examples: 4876


INFO:root:Training on 3610 samples.
INFO:root:Validating on 1266 samples.


In [29]:
print("="*100)
logging.info("\n>> train data loader")
print(f"# train batches\t: {len(dm.train_dataloader())}")
for data in dm.train_dataloader():
    # pprint(data)
    sample_id, in_state, ad, cmd = data["sample_id"], data["in_state"], data["action_desc"], data["motor_cmd"]
    print("In \t\t\t: ", in_state.shape)
    print("Action desc \t\t: ", ad["ids"].shape)
    print("Action desc (len) \t: ", ad["length"].shape)
    print("CMD \t\t\t: ", cmd["ids"].shape)
    print("CMD(len) \t\t: ", cmd["length"].shape)
    break

print("\nIDs & decided tokens")
for data in dm.train_dataloader():
    print(data["action_desc"]["ids"][0].tolist())
    print(dm.train_ds._decode_inputs(data["action_desc"]["ids"][0].tolist()))
    print()
    print(data["motor_cmd"]["ids"][0].tolist())
    print(dm.train_ds._decode_outputs(data["motor_cmd"]["ids"][0].tolist()))

    break
    
print("="*100)

INFO:root:
>> train data loader


# train batches	: 451
In 			:  torch.Size([8, 3, 224, 224])
Action desc 		:  torch.Size([8, 16])
Action desc (len) 	:  torch.Size([8])
CMD 			:  torch.Size([8, 16])
CMD(len) 		:  torch.Size([8])

IDs & decided tokens
[101, 2404, 1996, 5835, 2000, 1996, 2187, 1997, 14690, 7068, 102, 0, 0, 0, 0, 0]
put the bottle to the left of spatula

[0, 8, 46, 50, 10, 31, 40, 8, 43, 10, 2, 2, 2, 2, 2, 2]
:BOTTLE BLUE POSE-7 :SPATULA RED POSE-2 :BOTTLE #'*leftward-transformation* :SPATULA


## Model Design

<!-- ![RT1 model architecture](../../imgs/rt1+.png) -->
<center>
    <img src="../../imgs/rt1+.png" alt="RT1 model architecture" width="500" height="300">

<center>

### Encoder

#### Image Feature Extractor

In [77]:
def get_backbone(model:nn.Module):
    return nn.Sequential(*list(model.children())[:-2])

def conv(ic, oc, k, s, p, activation:str="GELU"):
    """
        Courtesy of [Kim Minjong](https://github.com/caffeinism):
        Adapted from: https://github.com/caffeinism/FiLM-pytorch/blob/master/networks.py
    """    
    return nn.Sequential(
        nn.Conv2d(ic, oc, k, s, p),
        getattr(nn, activation)(),
        nn.BatchNorm2d(oc),
    )


class FeatureExtractor(nn.Module):
    """
        Courtesy of [Kim Minjong](https://github.com/caffeinism):
        Adapted from: https://github.com/caffeinism/FiLM-pytorch/blob/master/networks.py
    """
    def __init__(
        self, 
        pretrained:bool=True, 
        arch:str="resnet34",
        freeze:bool=True
    ):
        super(FeatureExtractor, self).__init__()

        self.pretrained = pretrained
        self.freeze = freeze

        if self.pretrained:
            self.arch   = getattr(models, arch)(weights="IMAGENET1K_V1")
            self.fe     = get_backbone(model=self.arch)
        else:
            self.fe = nn.Sequential(
            conv(3, 128, 5, 2, 2),
            conv(128, 128, 3, 2, 1),
            conv(128, 128, 3, 2, 1),
            conv(128, 128, 3, 1, 1),
            conv(128, 128, 3, 1, 1),
        )
            
        if self.freeze:
            self._freeze_model()
            
    def _freeze_model(self):
        for param in self.fe.parameters():
            param.requires_grad = False 

    def forward(self, x, flat_out:bool=False):
        if self.pretrained:
            enc = self.fe(x)
            if flat_out:
                return torch.flatten(enc, 1)
            else:
                return enc
        else:
            return self.fe(x)


class Head(nn.Module):
    """
        Courtesy of [Kim Minjong](https://github.com/caffeinism):
        Adapted from: https://github.com/caffeinism/FiLM-pytorch/blob/master/networks.py
    """
    def __init__(self, prev_channels, n_classes):
        super(Head, self).__init__()

        self.conv = nn.Conv2d(prev_channels, 512, 1, 1, 0)
        self.relu = nn.ReLU(inplace=True)
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.fc = nn.Sequential(nn.Linear(512, 1024),
                                   nn.ReLU(inplace=True),
                                   nn.Linear(1024, 1024),
                                   nn.ReLU(inplace=True),
                                   nn.Linear(1024, n_classes))

    def forward(self, x, return_feats:bool=True):

        x = self.conv(x)
        feats = self.global_max_pool(x)

        if return_feats:
            return torch.flatten(feats, 1)
        else:
            x = feats.view(feats.size(0), feats.size(1))
            x = self.fc(x)
            return x




In [79]:
fe = FeatureExtractor(pretrained=True, arch="efficientnet_b3").cuda()

summary(fe, col_names=["num_params", "trainable"])

Layer (type:depth-idx)                                       Param #                   Trainable
FeatureExtractor                                             --                        Partial
├─EfficientNet: 1-1                                          --                        Partial
│    └─Sequential: 2-1                                       --                        False
│    │    └─Conv2dNormActivation: 3-1                        (1,160)                   False
│    │    └─Sequential: 3-2                                  (3,504)                   False
│    │    └─Sequential: 3-3                                  (48,118)                  False
│    │    └─Sequential: 3-4                                  (110,912)                 False
│    │    └─Sequential: 3-5                                  (638,700)                 False
│    │    └─Sequential: 3-6                                  (1,387,760)               False
│    │    └─Sequential: 3-7                                  (

In [80]:
img_ftrs = fe(ex["in_state"].unsqueeze(0).cuda())

img_ftrs.shape

torch.Size([1, 1536, 7, 7])

#### Film Block

In [33]:
class FiLMBlock(nn.Module):
    """
        Courtesy of [Kim Minjong](https://github.com/caffeinism):
        Adapted from: https://github.com/caffeinism/FiLM-pytorch/blob/master/networks.py
    """
    def __init__(self):
        super(FiLMBlock, self).__init__()

    def forward(self, x, gamma, beta):
        beta = beta.view(x.size(0), x.size(1), 1, 1)
        gamma = gamma.view(x.size(0), x.size(1), 1, 1)

        x = gamma * x + beta

        return x

#### Film Residual

In [34]:
class ResBlock(nn.Module):
    """
        Courtesy of [Kim Minjong](https://github.com/caffeinism):
        Adapted from: https://github.com/caffeinism/FiLM-pytorch/blob/master/networks.py
    """
    def __init__(self, in_place, out_place):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_place, out_place, 1, 1, 0)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_place, out_place, 3, 1, 1)
        self.norm2 = nn.BatchNorm2d(out_place)
        self.film = FiLMBlock()
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x, beta, gamma):
        x = self.conv1(x)
        x = self.relu1(x)
        identity = x

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.film(x, beta, gamma)
        x = self.relu2(x)

        x = x + identity

        return x

#### Text Encoder

In [71]:
class TextEncoder(nn.Module):
    def __init__(self, dropout_rate:float=config.TEXT_ENC_DROPOUT, freeze:bool=True):
        super().__init__()
        
        self.freeze = freeze
        
        model_config = AutoConfig.from_pretrained(config.LANG_MODEL_NAME)
        self.text_encoder = AutoModel.from_pretrained(config.LANG_MODEL_NAME, config=model_config)
        self.dropout = nn.Dropout(p=dropout_rate)
        
        if self.freeze:
            self._freeze_model()

    def _freeze_model(self):
        for param in self.text_encoder.parameters():
            param.requires_grad = False 
        
    def forward(self, inp_ids, mask, tok_type_ids):
        # embed NL instructions
        text_enc = self.text_encoder(
            input_ids=inp_ids,
            attention_mask=mask,
            token_type_ids=tok_type_ids
        ).pooler_output
        
        # print(text_enc.shape)
        text_enc = self.dropout(text_enc)
        
        return text_enc

In [72]:
te = TextEncoder(freeze=True).cuda()
te._freeze_model()
summary(model=te, col_names=["num_params", "trainable"])

Layer (type:depth-idx)                                  Param #                   Trainable
TextEncoder                                             --                        False
├─BertModel: 1-1                                        --                        False
│    └─BertEmbeddings: 2-1                              --                        False
│    │    └─Embedding: 3-1                              (7,813,632)               False
│    │    └─Embedding: 3-2                              (131,072)                 False
│    │    └─Embedding: 3-3                              (512)                     False
│    │    └─LayerNorm: 3-4                              (512)                     False
│    │    └─Dropout: 3-5                                --                        --
│    └─BertEncoder: 2-2                                 --                        False
│    │    └─ModuleList: 3-6                             (3,159,040)               False
│    └─BertPooler: 2-3         

In [37]:
ex["action_desc"]

{'raw': 'put the bowl on top of plate',
 'ids': tensor([ 101, 2404, 1996, 4605, 2006, 2327, 1997, 5127,  102,    0,    0,    0,
            0,    0,    0]),
 'mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'length': 7}

In [38]:
emb = te(
    inp_ids=ex["action_desc"]["ids"].unsqueeze(0).cuda(),
    mask=ex["action_desc"]["mask"].unsqueeze(0).cuda(),
    tok_type_ids=ex["action_desc"]["token_type_ids"].unsqueeze(0).cuda()
)

In [39]:
emb.shape

torch.Size([1, 256])

#### Film Encoder

In [40]:
class FiLMEncoder(nn.Module):
    """
        Adapted from: https://github.com/caffeinism/FiLM-pytorch/blob/master/networks.py
    """
    def __init__(
        self,
        n_res_blocks:int=3,
        out_dim:int=256,
        n_channels:int=512,
        dim_description:int=32,
        arch:str="resnet18"
    ):
        super(FiLMEncoder, self).__init__()
        
        if arch in ["resnet18", "resnet34"]:
            n_channels = 512
        elif "resnet" in arch:
            n_channels = 2048
        elif "convnext" in arch:
            n_channels = 768

        self.dim_description = dim_description
        self.film_generator = nn.Linear(self.dim_description, 2 * n_res_blocks * n_channels)
        self.feature_extractor = FeatureExtractor(arch=arch)
        self.res_blocks = nn.ModuleList()

        for _ in range(n_res_blocks):
            self.res_blocks.append(ResBlock(n_channels + 2, n_channels))

        # self.head = Head(n_channels, out_dim)

        self.n_res_blocks = n_res_blocks
        self.n_channels = n_channels

    def forward(self, x, description):
        batch_size = x.size(0)

        x = self.feature_extractor(x)
        film_vector = self.film_generator(description).view(
            batch_size, self.n_res_blocks, 2, self.n_channels)
        
        d = x.size(-1)
        coordinate = torch.arange(-1, 1 + 0.00001, 2 / (d-1)).cuda()
        coordinate_x = coordinate.expand(batch_size, 1, d, d)
        coordinate_y = coordinate.view(d, 1).expand(batch_size, 1, d, d)
        # print(f"x.shape: {x.shape} - coordinate_x.shape: {coordinate_x.shape} - coordinate_y: {coordinate_y.shape}")

        for i, res_block in enumerate(self.res_blocks):
            beta = film_vector[:, i, 0, :]
            gamma = film_vector[:, i, 1, :]

            x = torch.cat([x, coordinate_x, coordinate_y], 1)
            x = res_block(x, beta, gamma)
        
        print("pre-classifier: ", x.shape)
        feats = x #self.head(x, return_feats=False)

        return feats

In [41]:
film_encoder = FiLMEncoder(
    arch="resnet18",
    n_res_blocks=2,
    dim_description=256 # (emb) action_description_size
).cuda()

# print(film_encoder)
summary(model=film_encoder)

Layer (type:depth-idx)                             Param #
FiLMEncoder                                        --
├─Linear: 1-1                                      526,336
├─FeatureExtractor: 1-2                            --
│    └─ResNet: 2-1                                 --
│    │    └─Conv2d: 3-1                            9,408
│    │    └─BatchNorm2d: 3-2                       128
│    │    └─ReLU: 3-3                              --
│    │    └─MaxPool2d: 3-4                         --
│    │    └─Sequential: 3-5                        147,968
│    │    └─Sequential: 3-6                        525,568
│    │    └─Sequential: 3-7                        2,099,712
│    │    └─Sequential: 3-8                        8,393,728
│    │    └─AdaptiveAvgPool2d: 3-9                 --
│    │    └─Linear: 3-10                           513,000
│    └─Sequential: 2-2                             11,176,512
│    │    └─Conv2d: 3-11                           (recursive)
│    │    └─BatchNorm2

In [42]:
out = film_encoder(
    x= ex["in_state"].unsqueeze(0).cuda(),
    description= emb
)

out.shape

pre-classifier:  torch.Size([1, 512, 7, 7])


torch.Size([1, 512, 7, 7])

In [43]:
vl_conv = conv(512, 1, 1, 1, 0).cuda()
print(vl_conv)
summary(model=vl_conv)

Sequential(
  (0): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
  (1): GELU(approximate='none')
  (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            513
├─GELU: 1-2                              --
├─BatchNorm2d: 1-3                       2
Total params: 515
Trainable params: 515
Non-trainable params: 0

In [44]:
img_tokens = vl_conv(out)

img_tokens.shape

torch.Size([1, 1, 7, 7])

#### Token Learner

In [45]:
class TokenLearner(nn.Module):
    """
        TokenLearner version 1.1
        MLP (2 dense layers with gelu) for generating attention map
    """
    def __init__(
        self,
        *,
        dim:int=512,
        ff_mult = 2,
        num_output_tokens = 8,
        num_layers = 2
    ):
        super().__init__()
        inner_dim = dim * ff_mult * num_output_tokens

        self.num_output_tokens = num_output_tokens
        self.net = nn.Sequential(
            nn.Conv2d(dim * num_output_tokens, inner_dim, 1, groups = num_output_tokens),
            nn.GELU(),
            nn.Conv2d(inner_dim, num_output_tokens, 1, groups = num_output_tokens),
        )

    def forward(self, x):
        x, ps = pack_one(x, '* c h w')
        x = repeat(x, 'b c h w -> b (g c) h w', g = self.num_output_tokens)
        attn = self.net(x)

        attn = rearrange(attn, 'b g h w -> b 1 g h w')
        x = rearrange(x, 'b (g c) h w -> b c g h w', g = self.num_output_tokens)

        x = reduce(x * attn, 'b c g h w -> b c g', 'mean')
        x = unpack_one(x, ps, '* c n')
        return x

In [46]:
tokL = TokenLearner()
summary(model=tokL)

Layer (type:depth-idx)                   Param #
TokenLearner                             --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       4,284,896
│    └─GELU: 2-2                         --
│    └─Conv2d: 2-3                       8,280
Total params: 4,293,176
Trainable params: 4,293,176
Non-trainable params: 0

#### RT-1 Encoder

In [74]:
class RT1Encoder(nn.Module):
    def __init__(
        self,
        cnn_bacnbone:str="resnet18"
    ):
        super().__init__()
        
        # Text encoder
        self.text_encoder = TextEncoder()
        
        # Image encoder
        self.image_encoder = FiLMEncoder()
        
        # Vision-Language tokens extractor
        self.vl_conv = conv(self.image_encoder.n_channels, 1, 1, 1, 0)
        
        # Token Learner
        self.token_learner = TokenLearner()
        
        # Transformer decoder
        self.transformer = TransformerDecoder()

    def forward(self, input_ids, attn_mask, token_type_ids, imgs):
        """
        """
        text_enc = self.text_encoder(
            inp_ids=input_ids,
            mask=attn_mask,
            tok_type_ids=token_type_ids
        )
        
        # Generage image tokens
        img_tokens = self.image_encoder(
            x= imgs,
            description= text_enc
        )
        
        # Vision-Language tokens extractor
        img_tokens = self.vl_conv(img_tokens)
        
        # Extract learned tokens
        learned_tokens  = self.tok(img_tokens)
        
        return learned_tokens
        

In [81]:
encoder = RT1Encoder()
summary(model=encoder, col_names=["num_params", "trainable"])

Layer (type:depth-idx)                                       Param #                   Trainable
RT1Encoder                                                   --                        Partial
├─TextEncoder: 1-1                                           --                        False
│    └─BertModel: 2-1                                        --                        False
│    │    └─BertEmbeddings: 3-1                              (7,945,728)               False
│    │    └─BertEncoder: 3-2                                 (3,159,040)               False
│    │    └─BertPooler: 3-3                                  (65,792)                  False
│    └─Dropout: 2-2                                          --                        --
├─FiLMEncoder: 1-2                                           --                        Partial
│    └─Linear: 2-3                                           101,376                   True
│    └─FeatureExtractor: 2-4                                 --   

### RT-1 Decoder

#### Transformer Decoder

In [52]:
dec = TransformerDecoder()
summary(model=dec)

Layer (type:depth-idx)                   Param #
TransformerDecoder                       --
├─ModuleList: 1-1                        --
│    └─TransformerDecoderLayer: 2-1      --
│    │    └─LayerNormalization: 3-1      1,024
│    │    └─MultiheadAttention: 3-2      1,050,624
│    │    └─Dropout: 3-3                 --
│    │    └─LayerNormalization: 3-4      1,024
│    │    └─MultiheadAttention: 3-5      1,050,624
│    │    └─Dropout: 3-6                 --
│    │    └─LayerNormalization: 3-7      1,024
│    │    └─FeedFowardLayer: 3-8         1,050,112
│    │    └─Dropout: 3-9                 --
│    └─TransformerDecoderLayer: 2-2      --
│    │    └─LayerNormalization: 3-10     1,024
│    │    └─MultiheadAttention: 3-11     1,050,624
│    │    └─Dropout: 3-12                --
│    │    └─LayerNormalization: 3-13     1,024
│    │    └─MultiheadAttention: 3-14     1,050,624
│    │    └─Dropout: 3-15                --
│    │    └─LayerNormalization: 3-16     1,024
│    │    └─FeedFo

#### Action Generator

In [None]:
class ActionGenerator(nn.Module):
    def __init__(
        self
    ):
        super().__init__()
    
    def forward(self, tokens):
        pass

#### Decoder

In [None]:
class RT1Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.positional_encoder = PositionalEncoder()
        self.transformer = TransformerDecoder()
        # 
        self.action_generator = ActionGenerator()
        

    def _positional_encoding(
        self,
        seq, 
        dim, 
        temperature = 10000, 
        device = None, 
        dtype = torch.float32
    ):
        n = torch.arange(seq, device = device)
        omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
        omega = 1. / (temperature ** omega)

        n = n[:, None] * omega[None, :]
        pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
        
        return pos_emb.type(dtype)

    
    def forward(self, instructions, imgs):
        pass

### RT-1 

In [None]:
class RT1(pl.LightningModule):
    def __init__(
        self
    ):
        super().__init__()
        self.encoder = RT1Encder()
        self.decoder = RT1Decoder()
        
    def forward(self, imgs, instruction):
        pass
    
    def configure_optimizers(self):
        pass
    
    def training_step(self, batch, batch_idx):
        pass
    
    def validation_step(self, batch, batch_idx):
        pass
    
    def test_step(self, batch, batch_idx):
        pass
    
    def compute_loss(self, outputs, targets):
        pass