## 🤔 Device check

In [1]:
!nvidia-smi

Wed Nov 22 04:39:14 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:B3:00.0 Off |                    0 |
| N/A   28C    P0    40W / 300W |      0MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## 🛠️ Packages

In [2]:
import warnings
warnings.filterwarnings(action="ignore")

In [3]:

import lightning.pytorch as pl
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import RichProgressBar, TQDMProgressBar, ModelCheckpoint
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

import math

import logging
logging.basicConfig(level="INFO")

import math
from matplotlib import pyplot
%matplotlib inline

import numpy as np

import os

import pandas as pd
from pprint import pprint

import random

import sys
sys.path.append("../src")

import time
import timm
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

import wandb

In [4]:
%load_ext autoreload
%autoreload 2

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["WANDB__SERVICE_WAIT"] = "300"

In [5]:
import config
from dataloader import BEDataset, BEDataModule
from transformer import generate_causal_attention_mask

from rt1 import RT1CRAM
from utils.model_utils import plot_attention, fetch_sample_from_batch
import utils.data_utils as data_utils

## 📊 Data Module

In [6]:
dm = BEDataModule()
dm.setup()

INFO:root:Training on 4049 samples.
INFO:root:Validating on 459 samples.
INFO:root:Testing on 250 samples.


Total # examples: 4758


In [7]:
%%time
batch = next(iter(dm.train_dataloader()))
print(batch.keys())
batch["in_state"].shape

dict_keys(['sample_id', 'in_state', 'action_desc', 'source_mask_tokens', 'source_mask', 'motor_cmd', 'target_mask'])
CPU times: user 1.64 s, sys: 2.11 s, total: 3.75 s
Wall time: 32.9 s


torch.Size([32, 3, 288, 288])

## 🤖 RT1-CRAM

In [8]:
rt1 = RT1CRAM(
    cnn_bacnbone=config.SELECTED_CNN_BACKBONE, 
    num_res_blocks=config.NUM_RES_BLOCKS,
    freeze_cnn_backbone=config.FREEZE_CNN
).cuda()
# print(rt1)

summary(model=rt1)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/efficientnet_b3.ra2_in1k)
INFO:timm.models._hub:[timm/efficientnet_b3.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Layer (type:depth-idx)                                            Param #
RT1CRAM                                                           --
├─RT1Encoder: 1-1                                                 --
│    └─TextEncoder: 2-1                                           --
│    │    └─BertModel: 3-1                                        (28,763,648)
│    │    └─Dropout: 3-2                                          --
│    └─FiLMEncoder: 2-2                                           --
│    │    └─ImageFeatureExtractor: 3-3                            10,300,456
│    │    └─ModuleList: 3-4                                       3,170,304
│    └─TokenLearnerV11: 2-3                                       --
│    │    └─Sequential: 3-5                                       134,408
├─RT1Decoder: 1-2                                                 --
│    └─Embedding: 2-4                                             26,624
│    └─TransformerDecoder: 2-5                                  

## 🏋️‍ Training

In [9]:
loss_fn = nn.CrossEntropyLoss(
    ignore_index=config.TGT_PAD_TOK_ID, 
    label_smoothing=config.LABEL_SMOOTHING
)

opt = getattr(torch.optim, config.OPTIMIZER)(
    params=[p for p in rt1.parameters() if p.requires_grad], 
    lr=config.LR,
    weight_decay=config.WEIGHT_DECAY
)

scheduler = getattr(torch.optim.lr_scheduler, config.LR_SCHEDULER["type"])(**config.LR_SCHEDULER["params"], optimizer=opt)

opt

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0003
    maximize: False
    weight_decay: 2e-06
)

In [10]:
def training_step(model, batch, loss_fn):

    input_ids=batch["action_desc"]["ids"].to(config.DEVICE)
    attn_mask=batch["action_desc"]["mask"].to(config.DEVICE)
    token_type_ids=batch["action_desc"]["token_type_ids"].to(config.DEVICE)
    imgs=batch["in_state"].to(config.DEVICE)
    decoder_inp=batch["motor_cmd"]["decoder_inp_ids"].to(config.DEVICE)
    src_mask=(batch["source_mask"].to(config.DEVICE), batch["source_mask_tokens"].to(config.DEVICE))
    target_mask=batch["target_mask"].to(config.DEVICE)
    
    # forward
    logits, self_attn_ws, cross_attn_ws_seq, cross_attn_ws_tokens = model(
        input_ids=input_ids, 
        attn_mask=attn_mask, 
        token_type_ids=token_type_ids, 
        imgs=imgs,
        decoder_inp=decoder_inp, 
        src_mask=src_mask, 
        target_mask=target_mask 
    )

    # loss computation
    labels = batch["motor_cmd"]["labels"].to(config.DEVICE)
    loss = loss_fn(logits.view(-1, logits.shape[2]), labels.view(-1))
        
    return loss, logits, self_attn_ws, cross_attn_ws_seq, cross_attn_ws_tokens

## 🛟 Greedy decoding

In [11]:
def greedy_decoding(
    model:pl.LightningModule, 
    batch_inp:dict, 
    max_len:int=config.MAX_OUT_SEQ_LEN, 
    debug:bool=False
):
    if model.device.type == "cpu":
        model.to(config.DEVICE)
    model.eval()
    
    sos_token = config.TARGETS_MAPPING["[SOS]"]
    eos_token = config.TARGETS_MAPPING["[EOS]"]
    
    input_ids=batch_inp["ids"].to(config.DEVICE)
    attn_mask=batch_inp["mask"].to(config.DEVICE)
    token_type_ids=batch_inp["token_type_ids"].to(config.DEVICE)
    imgs=batch_inp["in_state"].to(config.DEVICE)
    src_mask=(
        batch_inp["source_mask"].to(config.DEVICE), 
        batch_inp["source_mask_tokens"].to(config.DEVICE)
    )

    text_enc_last_h, learned_tokens = model._encode(
        input_ids=input_ids, 
        attn_mask=attn_mask, 
        token_type_ids=token_type_ids, 
        imgs=imgs    
    )
    
    decoder_inp = torch.empty(1, 1, dtype=torch.long, device=input_ids.device).fill_(sos_token)

    # decoding procedure
    for t in range(max_len):
        
        decoder_mask = generate_causal_attention_mask(
            dim=decoder_inp.shape[1]
        ).type_as(attn_mask)
        
        # generate predictions
        with torch.no_grad():
            logits, self_attn_ws, cross_attn_ws_seq, cross_attn_ws_tokens = model._decode(
            decoder_inp=decoder_inp, 
            encoder_outs=(text_enc_last_h, learned_tokens), 
            src_mask=src_mask, 
            target_mask=decoder_mask,
            debug=debug,
            return_actions=False
        )

        # perform greedy decoding
        probs = model.decoder.action_generator(logits[:, -1])
            
        _, next_tok = torch.max(probs, dim=-1)
        # update decoder input
        decoder_inp = torch.cat((decoder_inp, next_tok.unsqueeze(1)), dim=1)
            
    return decoder_inp[:, 1:].cpu().detach(), logits, self_attn_ws.cpu().detach(), cross_attn_ws_seq.cpu().detach(), cross_attn_ws_tokens.cpu().detach()

In [12]:
def validation_step(batch, model, loss_fn, debug:bool=False):
    inp = fetch_sample_from_batch(
        batch, 
        batch_size=batch["in_state"].shape[0],
        random=True
    )
    
    pred_ids, logits, self_attn_ws, cross_attn_ws_seq, cross_attn_ws_tokens = greedy_decoding(
        model=model, 
        batch_inp=inp, 
        debug=debug
    )
    
    labels = inp["labels"].to(config.DEVICE)
    
    preds = model.decode_predictions(
            predicted_ids=pred_ids
    )[0]

    label = model.decode_predictions(
        predicted_ids=labels
    )[0]  
    
    # compute metrics
    val_loss = loss_fn(logits.view(-1, logits.shape[2]), labels.view(-1)).item()  # loss
    cer = model.cer_fn(preds, label).item() # Character Error Rate
    wer = model.wer_fn(preds, label).item() # Word Error Rate
    
    output = {
        "val_loss"              : val_loss,
        "CER"                   : cer,
        "WER"                   : wer,
        "label"                 : label,
        "pred_ids"              : pred_ids,
        "pred_tokens"           : preds,
        "self_attn_ws"          : self_attn_ws, 
        "cross_attn_ws_seq"     : cross_attn_ws_seq, 
        "cross_attn_ws_tokens"  : cross_attn_ws_tokens
    }
    
    return output

In [13]:
out = validation_step(model=rt1, batch=batch, loss_fn=loss_fn)
out

{'val_loss': 6.728468418121338,
 'CER': 1.0,
 'WER': 1.7777777910232544,
 'label': ":MILK RED POSE-8 :CAP RED POSE-1 :MILK #'*leftward-transformation* :CAP",
 'pred_ids': tensor([[40, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46]]),
 'pred_tokens': 'POSE-2 BLUE BLUE BLUE BLUE BLUE BLUE BLUE BLUE BLUE BLUE BLUE BLUE BLUE BLUE BLUE',
 'self_attn_ws': tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 5.8140e-16,  ..., 0.0000e+00,
            2.0092e-37, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 4.3933e-23,  ..., 1.8333e-12,
            3.6262e-28, 1.0000e+00],
           ...,
           [0.0000e+00, 0.0000e+00, 2.9587e-21,  ..., 1.3839e-11,
            4.5463e-30, 1.0000e+00],
           [0.0000e+00, 0.0000e+00, 4.1969e-22,  ..., 4.4405e-12,
            1.2085e-30, 1.0000e+00],
           [0.0000e+00, 0.0000e+00, 1.2594e-21,  ..., 1.4109e-13,
            3.9368e-27, 1.0000e+00]],
 

## 🧑🏾‍🍳 Prepare Experiment

In [14]:
def run_experiment(model, dm, opt, loss_fn, scheduler):
    
    loss_epoch = np.inf
    val_loss = np.inf
    best_val_loss = np.inf
    
    cer_ = np.inf
    wer_ = np.inf
    
    for e in range(config.EPOCHS):        
        running_loss = 0.
        num_steps = len(dm.train_dataloader())
        
        pbar = tqdm(
            range(num_steps),
            position=0,
            leave=True,
            dynamic_ncols=True,
            total = num_steps
        )
        
        # training
        model.train()
        for step, batch in enumerate(dm.train_dataloader()):            
            pct = 100. * step / num_steps
            pbar.set_description(
                f"Epoch {e+1}/{config.EPOCHS} - (Train {pct:.1f}%)"
            )
            pbar.update()
            
            opt.zero_grad()

            # training step
            loss, logits, self_attn_ws, cross_attn_ws_seq, _ = training_step(
                model=model, 
                batch=batch, 
                loss_fn=loss_fn
            )
            
            # plot attention weights
            plot_attention(
                self_attn_ws, 
                show=False, 
                pre_fix="train_selfattn", 
                folder="train",
                epoch=e,
                wandb_logging=True
            )

            plot_attention(
                cross_attn_ws_seq,
                kind="cross", 
                pre_fix="train_crossattn", 
                show=False, 
                folder="train",
                epoch=e,
                wandb_logging=True
            )   
            
            running_loss += loss.item()         
            
            # logging
            if step % 10 == 0:
                pbar.set_postfix(
                    train_loss_step="{:.04f}".format(running_loss/(step+1)),
                    train_loss="{:.04f}".format(loss_epoch),
                    CER="{:.04f}".format(cer_),
                    WER="{:.04f}".format(wer_),
                    val_loss="{:.04f}".format(val_loss),
                )
                pbar.update()

            # backward
            loss.backward()
            
            # Adjust learning weights
            opt.step()
            
        loss_epoch = running_loss / len(dm.train_dataloader())   
        final_lr_epoch = float(opt.param_groups[0]['lr'])
        
        # predictions
        preds = logits.softmax(dim=-1).argmax(dim=-1)

        # decode predictions
        preds = model.decode_predictions(
            predicted_ids=preds
        )

        labels = model.decode_predictions(
            predicted_ids=batch["motor_cmd"]["labels"]
        )         
            
        # log decoded sentenses
        with open(config.LOGGING_FILE, "a") as f:            
            f.write(f"Epoch #{e+1}\n")
            f.write(f"[Train] \n")
            
            pred = preds[0]
            label = labels[0]
            
            cer_ = model.cer_fn(pred, label).item()
            wer_ = model.wer_fn(pred, label).item()
            f.write(f"Predicted \t: {pred}\n")
            f.write(f"Actual \t\t: {label}\n")
                
        # validation
        out = validation_step(model=rt1, batch=batch, loss_fn=loss_fn)
        val_loss = out["val_loss"]
        
        # start scheduling lr after epoch X
        # X set to 30 to start us of
        if e >=30:
            scheduler.step(val_loss)
       
        # plot attention weights
        plot_attention(
            out["self_attn_ws"], 
            show=False, 
            pre_fix="val_selfattn", 
            folder="val",
            epoch=e,
            wandb_logging=True
        )

        plot_attention(
            out["cross_attn_ws_seq"],
            kind="cross", 
            pre_fix="val_crossattn", 
            show=False, 
            folder="val",
            epoch=e,
            wandb_logging=True
        )   

        # plot_attention(
        #     out["cross_attn_ws_tokens"], 
        #     pre_fix="val_crossattn_tokens", 
        #     show=False, 
        #     folder="val",
        #     epoch=e,
        #     wandb_logging=True
        # )   
        
        # update best score
        if val_loss < best_val_loss:
            # save checkpoint
            path = os.path.join(config.MODEL_PATH, "be_model.bin")
            torch.save({
                'model_state_dict'      :model.state_dict(),
                'optimizer_state_dict'  :opt.state_dict(),
                'val_loss'              : val_loss, 
                'epoch'                 : e
                }, path)
            
            # update best score
            best_val_loss = val_loss        
        
        pbar.set_postfix(
            train_loss_step="{:.04f}".format(running_loss/(step+1)),
            train_loss="{:.04f}".format(loss_epoch),
            # CER="{:.04f}".format(cer_),
            # WER="{:.04f}".format(wer_),
            val_Loss="{:.04f}".format(val_loss),
            val_CER="{:.04f}".format(out["CER"]),
            val_WER="{:.04f}".format(out["WER"]),
            lr_epoch="{:.1e}".format(final_lr_epoch),
        )  
        pbar.update()
        
        logs_dict = {
            "epoch" :e,
            "train_loss":loss_epoch,
            "val_loss":val_loss,
            "val_CER":out["CER"],
            "valWER":out["WER"],
            "lr":final_lr_epoch
        }
        wandb.log(logs_dict)
        
        # log decoded sentenses
        with open(config.LOGGING_FILE, "a") as f:                        
            pred = out["pred_tokens"]
            label = out["label"]
            
            f.write(f"[Val] \n")            
            f.write(f"Predicted \t: {pred}\n")
            f.write(f"Actual \t\t: {label}\n") 
            f.write(f"Curr val loss \t\t: {val_loss:.5f}\n") 
            f.write(f"Best loss: \t\t: {best_val_loss:.5f}\n\n") 
            
        pbar.close()
        torch.cuda.empty_cache()
        
    return model

## 🚀 Run Experiment

In [15]:
# Set seed

random.seed(config.SEED)
np.random.seed(config.SEED)
torch.manual_seed(config.SEED)

if torch.cuda.is_available(): 
    torch.cuda.manual_seed(config.SEED)
    torch.cuda.manual_seed_all(config.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
##### init experiment
run = wandb.init(
    project='SMF-Be', 
    group="RT1-CRAM", 
    name="be_model", 
    reinit=True
)

with open(config.LOGGING_FILE, "a") as f:   
    f.write("*** New experiment ***\n")
    
    
trained_model = run_experiment(
    model=rt1, 
    dm=dm, 
    opt=opt, 
    loss_fn=loss_fn,
    scheduler=scheduler
)


wandb.finish()


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdric225[0m ([33mjepsam-s23[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

In [None]:
# wandb.finish()

## 👨🏿‍🔬 Test / Inference

In [None]:
# input_ids=batch["action_desc"]["ids"].cuda()
# attn_mask=batch["action_desc"]["mask"].cuda()
# token_type_ids=batch["action_desc"]["token_type_ids"].cuda()
# imgs=batch["in_state"].cuda()
# decoder_inp=batch["motor_cmd"]["decoder_inp_ids"].cuda()
# src_mask=(batch["source_mask"].cuda(), batch["source_mask_tokens"].cuda())
# target_mask=batch["target_mask"].cuda()
# labels = batch["motor_cmd"]["labels"].cuda()