## 🤔 Device check

## 🛠️ Packages

In [1]:
# Show the GPU(s) visible to this kernel together with driver/CUDA versions and current memory use.
!nvidia-smi

Mon Dec 25 09:32:09 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:8A:00.0 Off |                    0 |
| N/A   27C    P0              39W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import warnings
warnings.filterwarnings(action="ignore")

In [3]:

# Library imports used throughout the notebook.
# NOTE(review): Trainer, seed_everything, RichProgressBar, TQDMProgressBar, ModelCheckpoint,
# RichProgressBarTheme, TensorBoardLogger, WandbLogger, Levenshtein, math, pyplot, pd,
# pprint, time, timm and F are not referenced in the cells visible here — confirm before pruning.
import lightning.pytorch as pl
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import RichProgressBar, TQDMProgressBar, ModelCheckpoint
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

import Levenshtein

import math

import logging
# Emit INFO-level log records (e.g. the dataset split sizes printed by the data module).
logging.basicConfig(level="INFO")

# NOTE(review): duplicate of the `import math` above — harmless but redundant.
import math
from matplotlib import pyplot
# Render matplotlib figures inline in the notebook.
%matplotlib inline

import numpy as np

import os

import pandas as pd
from pprint import pprint

import random

import sys
# Make the project's local modules importable (presumably config, dataloader, rt1, utils live in ../src).
sys.path.append("../src")

import time
import timm
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

import wandb

In [4]:
# Automatically re-import edited modules before each cell runs.
%load_ext autoreload
%autoreload 2

# Disable HuggingFace tokenizers' internal parallelism (commonly set to avoid fork warnings with DataLoader workers).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Make CUDA kernel launches synchronous so errors surface at the offending line.
# NOTE(review): this is a debugging aid that slows training noticeably — consider removing for real runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Presumably the time (in seconds) wandb waits for its background service to start — confirm against wandb docs.
os.environ["WANDB__SERVICE_WAIT"] = "300"

In [5]:
import config
from dataloader import BEDataset, BEDataModule
from transformer import make_attn_mask

from rt1 import RTCRAM
from utils.model_utils import plot_attention, fetch_sample_from_batch, calc_edit_distance
import utils.data_utils as data_utils

## 📊 Data Module

In [6]:
# Build the project data module and create its train / val / test datasets
# (their sizes are logged in the output below).
(dm := BEDataModule()).setup()

INFO:root:Training on 3848 samples.
INFO:root:Validating on 660 samples.
INFO:root:Testing on 250 samples.


Total # examples: 4758


In [7]:
%%time
# Time how long the first training batch takes to load, then inspect its structure.
batch = next(iter(dm.train_dataloader()))
print(batch.keys())
# Image batch; the output below shows torch.Size([128, 3, 288, 288]).
batch["in_state"].shape

dict_keys(['sample_id', 'in_state', 'action_desc', 'motor_cmd'])
CPU times: user 575 ms, sys: 636 ms, total: 1.21 s
Wall time: 41.1 s


torch.Size([128, 3, 288, 288])

In [8]:
# Pick a sample out of the batch (random=True) so its model-ready fields can be inspected below.
inp = fetch_sample_from_batch(batch, batch_size=len(batch["in_state"]), random=True)

In [9]:
# Fields returned by fetch_sample_from_batch (listed in the output below).
inp.keys()

dict_keys(['sample_id', 'in_state', 'raw_action_desc', 'ids', 'mask', 'token_type_ids', 'raw_motor_cmd', 'decoder_inp_ids', 'labels'])

In [10]:
# Decode the first description's token ids back to text; special and padding tokens are kept (see output).
dm.train_ds.tokenizer.decode(batch["action_desc"]["ids"][0])

'[CLS] put the cereal in front of mug [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [11]:
# The raw (untokenised) description of the same example, for comparison with the decoded ids above.
batch["action_desc"]["raw"][0]

'put the cereal in front of mug'

## 🤖 RT-CRAM

In [12]:
# Seed every RNG *before* building the model so the layers that are not loaded from
# pretrained weights get a reproducible initialisation (the "Set seed" cell further
# down only runs after the model already exists).
seed_everything(config.SEED)

rt1 = RTCRAM(
    cnn_bacnbone=config.SELECTED_CNN_BACKBONE,  # spelling kept: must match RTCRAM's keyword name
    num_res_blocks=config.NUM_RES_BLOCKS,
    freeze_cnn_backbone=config.FREEZE_CNN,
).to(config.DEVICE)  # follow config.DEVICE like the rest of the notebook instead of hard-coding .cuda()

summary(model=rt1)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/efficientnet_b3.ra2_in1k)
INFO:timm.models._hub:[timm/efficientnet_b3.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Layer (type:depth-idx)                                            Param #
RTCRAM                                                            --
├─RTEncoder: 1-1                                                  --
│    └─TextEncoder: 2-1                                           --
│    │    └─BertModel: 3-1                                        (28,763,648)
│    │    └─Dropout: 3-2                                          --
│    └─FiLMEncoder: 2-2                                           --
│    │    └─ImageFeatureExtractor: 3-3                            10,300,456
│    │    └─ModuleList: 3-4                                       6,340,608
│    └─TokenLearnerV11: 2-3                                       --
│    │    └─Sequential: 3-5                                       134,408
├─RTDecoder: 1-2                                                  --
│    └─TransformerDecoder: 2-4                                    --
│    │    └─EmbeddingLayer: 3-6                                   53

## 🏋️‍ Training config

In [13]:
# Token-level cross-entropy: positions holding the target pad id are ignored,
# and label smoothing comes from the config.
loss_fn = nn.CrossEntropyLoss(
    label_smoothing=config.LABEL_SMOOTHING,
    ignore_index=config.TGT_PAD_TOK_ID,
)

# The optimizer class is resolved by name so it can be switched from config;
# only parameters that still require gradients (i.e. not frozen) are optimised.
optimizer_cls = getattr(torch.optim, config.OPTIMIZER)
opt = optimizer_cls(
    params=list(filter(lambda param: param.requires_grad, rt1.parameters())),
    lr=config.LR,
    weight_decay=config.WEIGHT_DECAY,
)

# The LR scheduler is likewise resolved by name, with its kwargs taken from config.
sched_cfg = config.LR_SCHEDULER
scheduler_cls = getattr(torch.optim.lr_scheduler, sched_cfg["type"])
scheduler = scheduler_cls(optimizer=opt, **sched_cfg["params"])

opt

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1.0
    lr: 1e-05
    maximize: False
    weight_decay: 2e-06
)

## Training step

In [14]:
def training_step(model, batch, loss_fn):
    """Run one forward pass on a training batch and compute the token-level loss.

    This does NOT call ``backward()`` or step the optimizer; the caller does that.

    Args:
        model: callable accepting the keywords ``input_ids``, ``attn_mask``,
            ``token_type_ids``, ``imgs`` and ``decoder_inp``, and returning
            ``(logits, self_attn_ws, cross_attn_ws)``.
        batch: dict providing ``batch["action_desc"]["ids" | "mask" | "token_type_ids"]``,
            ``batch["in_state"]`` and ``batch["motor_cmd"]["decoder_inp_ids" | "labels"]``.
            All tensors are moved to ``config.DEVICE``.
        loss_fn: criterion applied to logits flattened to ``(N, C)`` (C = last logits
            dim) and labels flattened to ``(N,)``.

    Returns:
        ``(loss, logits, self_attn_ws, cross_attn_ws)``.
    """
    device = config.DEVICE
    text = batch["action_desc"]
    motor = batch["motor_cmd"]

    # forward
    logits, self_attn_ws, cross_attn_ws = model(
        input_ids=text["ids"].to(device),
        attn_mask=text["mask"].to(device),
        token_type_ids=text["token_type_ids"].to(device),
        imgs=batch["in_state"].to(device),
        decoder_inp=motor["decoder_inp_ids"].to(device),
    )

    # loss computation; reshape (not view) because view() raises on
    # non-contiguous logits, while reshape copies only when it has to.
    labels = motor["labels"].to(device)
    loss = loss_fn(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))

    return loss, logits, self_attn_ws, cross_attn_ws

In [15]:
# loss, logits, self_attn_ws, cross_attn_ws = training_step(model=rt1, batch=batch, loss_fn=loss_fn)

In [16]:
# logits.shape

In [17]:
# # predictions
# preds = logits.softmax(dim=-1).argmax(dim=-1)
# preds.shape, preds

In [18]:

# # decode predictions
# preds = rt1.decode_predictions(
#     predicted_ids=preds
# )

# preds[0]

## 🛟 Greedy decoding

In [19]:
def greedy_decoding(
    model: "pl.LightningModule",
    batch_inp: dict,
    max_len: "int | None" = None,
    debug: bool = False,
):
    """Greedily decode action tokens for every sample in ``batch_inp``.

    The encoder runs once. The decoder then repeatedly consumes its own arg-max
    prediction for ``max_len`` steps, starting from one ``[SOS]`` token per sample.
    There is no early stop on ``[EOS]``, so every sequence has exactly ``max_len`` tokens.
    Everything runs under ``torch.no_grad()``.

    Side effects: moves ``model`` to ``config.DEVICE`` if it is on the CPU and
    leaves it in eval mode.

    Args:
        model: model exposing ``device``, ``_encode``, ``_decode`` and
            ``decoder.action_generator``.
        batch_inp: dict with ``"ids"``, ``"mask"``, ``"token_type_ids"`` and
            ``"in_state"`` tensors sharing the same leading (batch) dimension.
        max_len: number of tokens to generate. Defaults to ``config.MAX_OUT_SEQ_LEN``,
            read at call time. Must be >= 1.
        debug: currently unused; kept for API compatibility.

    Returns:
        ``(pred_ids, logits, self_attn_ws, cross_attn_ws)``:
        - ``pred_ids``: generated tokens without the leading ``[SOS]``, shape
          ``(B, max_len)``, on CPU.
        - ``logits``: ``decoder.action_generator`` outputs for all ``max_len`` positions
          of the final decoder pass, shape ``(B, max_len, -1)``, left on the model device.
        - ``self_attn_ws`` / ``cross_attn_ws``: attention weights of the final decoder
          pass, on CPU.

    Raises:
        ValueError: if ``max_len`` < 1.
    """
    if max_len is None:
        max_len = config.MAX_OUT_SEQ_LEN
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")

    if model.device.type == "cpu":
        model.to(config.DEVICE)
    model.eval()

    sos_token = config.TARGETS_MAPPING["[SOS]"]

    input_ids = batch_inp["ids"].to(config.DEVICE)
    attn_mask = batch_inp["mask"].to(config.DEVICE)
    token_type_ids = batch_inp["token_type_ids"].to(config.DEVICE)
    imgs = batch_inp["in_state"].to(config.DEVICE)
    batch_size = input_ids.shape[0]

    # Inference only: keep the encoder out of autograd too, not just the decoder loop.
    with torch.no_grad():
        _, learned_tokens = model._encode(
            input_ids=input_ids,
            attn_mask=attn_mask,
            token_type_ids=token_type_ids,
            imgs=imgs,
        )

        # One [SOS] start token per sample (not a single row), so batches of any size decode.
        decoder_inp = torch.full(
            (batch_size, 1), sos_token, dtype=torch.long, device=input_ids.device
        )

        for _ in range(max_len):
            mask = make_attn_mask(dim=decoder_inp.shape[1])
            hidden, self_attn_ws, cross_attn_ws = model._decode(
                decoder_inp=decoder_inp,
                encoder_out=learned_tokens,
                attn_mask=mask,
                return_actions=False,
            )

            # greedy step: project the last position with action_generator and take the arg-max
            next_tok = model.decoder.action_generator(hidden[:, -1]).argmax(dim=-1)

            # update decoder input
            decoder_inp = torch.cat((decoder_inp, next_tok.unsqueeze(1)), dim=1)

        # Return action_generator scores for every position (not the raw _decode output),
        # because validation_step computes its cross-entropy on this tensor.
        # Positions are flattened to 2-D because action_generator is only known to accept
        # (B, D) input (see the per-step call above).
        # NOTE(review): assumes action_generator maps decoder features to per-action scores — confirm.
        n_pos = hidden.shape[1]
        logits = model.decoder.action_generator(
            hidden.reshape(batch_size * n_pos, -1)
        ).reshape(batch_size, n_pos, -1)

    return (
        decoder_inp[:, 1:].detach().cpu(),
        logits,
        self_attn_ws.detach().cpu(),
        cross_attn_ws.detach().cpu(),
    )

## Validation Step

In [20]:
def validation_step(batch, model, loss_fn, debug: bool = False):
    """Greedy-decode one sample drawn from ``batch`` and score it against its label.

    Args:
        batch: validation batch as produced by the data module; the sample is drawn
            with ``fetch_sample_from_batch(..., random=True)``.
        model: model used by ``greedy_decoding``; must also expose
            ``decode_predictions``, ``cer_fn`` and ``wer_fn``.
        loss_fn: criterion applied to the decoded logits flattened to ``(N, C)`` and
            the label ids flattened to ``(N,)``.
        debug: forwarded to ``greedy_decoding``.

    Returns:
        dict with ``val_loss`` (float), ``CER`` and ``WER`` (floats, first sample),
        ``label`` and ``pred_tokens`` (decoded text of the first sample), ``pred_ids``,
        ``self_attn_ws``, ``cross_attn_ws`` and ``dist`` (``calc_edit_distance`` result).
    """
    inp = fetch_sample_from_batch(
        batch,
        batch_size=batch["in_state"].shape[0],
        random=True,
    )

    pred_ids, logits, self_attn_ws, cross_attn_ws = greedy_decoding(
        model=model,
        batch_inp=inp,
        debug=debug,
    )

    label_ids = inp["labels"].to(config.DEVICE)

    pred_texts = model.decode_predictions(predicted_ids=pred_ids)
    label_texts = model.decode_predictions(predicted_ids=label_ids)

    # Edit distance must compare decoded text with decoded text, not with label ids.
    lev_dist = calc_edit_distance(pred_texts, label_texts, batch=True)

    # compute metrics; reshape (not view) tolerates non-contiguous logits.
    with torch.no_grad():
        val_loss = loss_fn(
            logits.reshape(-1, logits.shape[-1]), label_ids.reshape(-1)
        ).item()
    cer = model.cer_fn(pred_texts[0], label_texts[0]).item()  # Character Error Rate
    wer = model.wer_fn(pred_texts[0], label_texts[0]).item()  # Word Error Rate

    return {
        "val_loss": val_loss,
        "CER": cer,
        "WER": wer,
        "label": label_texts[0],
        "pred_ids": pred_ids,
        "pred_tokens": pred_texts[0],
        "self_attn_ws": self_attn_ws,
        "cross_attn_ws": cross_attn_ws,
        "dist": lev_dist,
    }

In [21]:
# %%time

# out = validation_step(model=rt1, batch=batch, loss_fn=loss_fn)

# out

## 🧑🏾‍🍳 Prepare Experiment

In [22]:
def _plot_attention_maps(self_attn_ws, cross_attn_ws, split, epoch):
    """Plot (and log to wandb) the self- and cross-attention maps for one split.

    ``split`` ("train" / "val") is used as the output folder and as the file-name
    prefix (``"<split>_selfattn"`` / ``"<split>_crossattn"``).
    """
    plot_attention(
        self_attn_ws,
        show=False,
        pre_fix=f"{split}_selfattn",
        folder=split,
        epoch=epoch,
        wandb_logging=True,
    )
    plot_attention(
        cross_attn_ws,
        kind="cross",
        pre_fix=f"{split}_crossattn",
        show=False,
        folder=split,
        epoch=epoch,
        wandb_logging=True,
    )


def run_experiment(model, dm, opt, loss_fn, scheduler, scheduler_interval="epoch"):
    """Train ``model`` for ``config.EPOCHS`` epochs, validating once per epoch.

    Each epoch does the following:
    - one pass over ``dm.train_dataloader()``;
    - a decoded train sample (from the last batch) and its CER/WER appended to
      ``config.LOGGING_FILE``;
    - greedy-decoded validation of one random sample of the first validation batch;
    - attention plots for both splits;
    - a checkpoint at ``config.MODEL_PATH/be_model.bin`` whenever the validation loss
      improves;
    - a ``wandb.log`` of the epoch metrics.

    Args:
        model: the model being trained (must expose ``decode_predictions``,
            ``cer_fn`` and ``wer_fn``).
        dm: data module providing ``train_dataloader()`` and ``val_dataloader()``.
        opt: optimizer over the model's trainable parameters.
        loss_fn: token-level criterion passed to ``training_step`` / ``validation_step``.
        scheduler: LR scheduler bound to ``opt``, or ``None`` to keep the LR fixed.
        scheduler_interval: ``"epoch"`` (default) steps the scheduler once per epoch
            after validation; ``"step"`` steps it after every optimizer step.
            ``ReduceLROnPlateau`` is always stepped per epoch with the validation loss.

    Returns:
        The same ``model`` object, trained in place.

    Raises:
        ValueError: on an unknown ``scheduler_interval`` or an empty train dataloader.
    """
    if scheduler_interval not in ("epoch", "step"):
        raise ValueError(
            f"scheduler_interval must be 'epoch' or 'step', got {scheduler_interval!r}"
        )
    is_plateau = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
    step_per_batch = scheduler is not None and scheduler_interval == "step" and not is_plateau

    loss_epoch = np.inf
    val_loss = np.inf
    best_val_loss = np.inf

    # Train CER/WER of the previous epoch, shown in the progress bar.
    cer_ = np.inf
    wer_ = np.inf

    for e in range(config.EPOCHS):
        train_loader = dm.train_dataloader()
        num_steps = len(train_loader)
        if num_steps == 0:
            raise ValueError("dm.train_dataloader() yielded no batches.")
        running_loss = 0.0

        with tqdm(total=num_steps, position=0, leave=True, dynamic_ncols=True) as pbar:
            # ---- training ----
            model.train()
            for step, batch in enumerate(train_loader):
                pct = 100.0 * step / num_steps
                pbar.set_description(f"Epoch {e+1}/{config.EPOCHS} - (Train {pct:.1f}%)")

                opt.zero_grad()
                loss, logits, self_attn_ws, cross_attn_ws = training_step(
                    model=model,
                    batch=batch,
                    loss_fn=loss_fn,
                )
                running_loss += loss.item()

                if step % 10 == 0:
                    pbar.set_postfix(
                        train_loss_step="{:.04f}".format(running_loss / (step + 1)),
                        train_loss="{:.04f}".format(loss_epoch),
                        CER="{:.04f}".format(cer_),
                        WER="{:.04f}".format(wer_),
                        val_loss="{:.04f}".format(val_loss),
                    )

                # backward + adjust learning weights
                loss.backward()
                opt.step()
                if step_per_batch:
                    scheduler.step()

                # Exactly one tick per batch so the bar finishes at num_steps.
                pbar.update()

            loss_epoch = running_loss / num_steps
            final_lr_epoch = float(opt.param_groups[0]["lr"])

            # Attention maps of the last training batch, once per epoch. Plotting on
            # every batch re-rendered and re-logged with identical prefix/folder/epoch.
            _plot_attention_maps(self_attn_ws, cross_attn_ws, "train", e)

            # Decoded sample of the last training batch (argmax of logits == argmax of softmax).
            train_preds = model.decode_predictions(predicted_ids=logits.argmax(dim=-1))
            train_labels = model.decode_predictions(
                predicted_ids=batch["motor_cmd"]["labels"]
            )
            pred, label = train_preds[0], train_labels[0]
            cer_ = model.cer_fn(pred, label).item()
            wer_ = model.wer_fn(pred, label).item()

            # log decoded sentences
            with open(config.LOGGING_FILE, "a") as f:
                f.write(f"Epoch #{e+1}\n")
                f.write("[Train] \n")
                f.write(f"Predicted \t: {pred}\n")
                f.write(f"Actual \t\t: {label}\n")

            # ---- validation: one random sample of the first validation batch ----
            val_batch = next(iter(dm.val_dataloader()))
            out = validation_step(model=model, batch=val_batch, loss_fn=loss_fn)
            val_loss = out["val_loss"]
            val_dist = out["dist"]
            _plot_attention_maps(out["self_attn_ws"], out["cross_attn_ws"], "val", e)

            # ---- per-epoch LR schedule (the scheduler was previously never stepped) ----
            if scheduler is not None and not step_per_batch:
                if is_plateau:
                    scheduler.step(val_loss)
                else:
                    scheduler.step()

            # ---- checkpoint whenever the validation loss improves ----
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    "val_dist": val_dist,
                    "epoch": e,
                }
                if scheduler is not None:
                    checkpoint["scheduler_state_dict"] = scheduler.state_dict()
                torch.save(checkpoint, os.path.join(config.MODEL_PATH, "be_model.bin"))

            pbar.set_postfix(
                train_loss_step="{:.04f}".format(running_loss / (step + 1)),
                train_loss="{:.04f}".format(loss_epoch),
                val_Loss="{:.04f}".format(val_loss),
                val_CER="{:.04f}".format(out["CER"]),
                val_WER="{:.04f}".format(out["WER"]),
                lr_epoch="{:.1e}".format(final_lr_epoch),
            )

            wandb.log({
                "epoch": e,
                "train_loss": loss_epoch,
                "val_loss": val_loss,
                "val_CER": out["CER"],
                "val_WER": out["WER"],  # same "val_" prefix as val_CER
                "lr": final_lr_epoch,
            })

            # log decoded sentences
            with open(config.LOGGING_FILE, "a") as f:
                f.write("[Val] \n")
                f.write(f"Predicted \t: {out['pred_tokens']}\n")
                f.write(f"Actual \t\t: {out['label']}\n")
                f.write(f"Curr val loss \t\t: {val_loss:.5f}\n")
                f.write(f"Val dist \t\t: {val_dist:.5f}\n")
                f.write(f"Best loss: \t\t: {best_val_loss:.5f}\n\n")

        torch.cuda.empty_cache()

    return model

## 🚀 Run Experiment

In [23]:
# Seed Python's `random`, NumPy and PyTorch with config.SEED; when a GPU is present,
# also seed CUDA and make cuDNN use deterministic (non-benchmarked) kernels.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(config.SEED)

if torch.cuda.is_available():
    for _seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(config.SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
##### init experiment
# Start a fresh wandb run, mark the start of the experiment in the text log, then train.
run = wandb.init(
    project='SMF-Be',
    group="RT1-CRAM",
    name="be_model",
    reinit=True,
)

with open(config.LOGGING_FILE, "a") as f:
    f.write("*** New experiment ***\n")

try:
    trained_model = run_experiment(
        model=rt1,
        dm=dm,
        opt=opt,
        loss_fn=loss_fn,
        scheduler=scheduler,
    )
finally:
    # Close the wandb run even when training raises or is interrupted (e.g. KeyboardInterrupt).
    wandb.finish()


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdric225[0m ([33mjepsam-s23[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

In [None]:
# wandb.finish()

## 👨🏿‍🔬 Test / Inference

In [None]:
# input_ids=batch["action_desc"]["ids"].cuda()
# attn_mask=batch["action_desc"]["mask"].cuda()
# token_type_ids=batch["action_desc"]["token_type_ids"].cuda()
# imgs=batch["in_state"].cuda()
# decoder_inp=batch["motor_cmd"]["decoder_inp_ids"].cuda()
# src_mask=(batch["source_mask"].cuda(), batch["source_mask_tokens"].cuda())
# target_mask=batch["target_mask"].cuda()
# labels = batch["motor_cmd"]["labels"].cuda()