In [1]:
# Shell escape: print the GPU status table (driver/CUDA version, memory usage,
# utilisation) before any model is created.
!nvidia-smi

Sat Jan  6 08:17:56 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:89:00.0 Off |                    0 |
| N/A   26C    P0              39W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import warnings
warnings.filterwarnings(action="ignore")

In [16]:
import lightning.pytorch as pl
from lightning.pytorch import Trainer, seed_everything
import logging
# Configure the root logger at INFO level so the `INFO:root:...` progress
# messages emitted by the project modules show up in cell outputs.
logging.basicConfig(level="INFO")

import os

import pandas as pd
from pprint import pprint


from rich.progress import Progress

import sys

# Make the project sources (`../src`, relative to the notebook's working
# directory) importable. Guarded so that re-running this cell does not keep
# appending duplicate entries to `sys.path`.
if "../src" not in sys.path:
    sys.path.append("../src")

import time

from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torchinfo import summary

import wandb


In [4]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Process-wide environment flags, set in one place:
#   TOKENIZERS_PARALLELISM - "false" turns off tokenizer-level parallelism.
#   CUDA_LAUNCH_BLOCKING   - "1" makes CUDA kernel launches synchronous
#                            (debugging aid; NOTE(review): slows training,
#                            confirm it is still wanted).
#   WANDB__SERVICE_WAIT    - "300"; presumably the W&B service start-up
#                            timeout in seconds - TODO confirm.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "1",
        "WANDB__SERVICE_WAIT": "300",
    }
)

In [22]:
import config

from dataloader import BEDataModule

from rt1 import RTCRAM
import utils.model_utils as model_utils

## Build data module


In [10]:
# Seed the Python, NumPy and PyTorch RNGs (and dataloader workers) before the
# first stochastic step so that a Restart & Run All is repeatable.
# NOTE(review): the logged train/val/test sizes differ between executions of
# this notebook, which suggests the split is random - confirm that
# BEDataModule draws from the global RNGs so this seed actually pins it.
SEED = 42
seed_everything(SEED, workers=True)

# Build the data module and materialise its splits; `setup()` logs the number
# of training / validation / test samples.
dm = BEDataModule()
dm.setup()

INFO:root:Training on 4054 samples.
INFO:root:Validating on 454 samples.
INFO:root:Testing on 249 samples.


Total # examples: 4757


## Build model


In [15]:
# Disabled cell (kept commented out): instantiates the RTCRAM model from the
# CNN-backbone, residual-block and backbone-freezing settings in `config`
# and moves it to the GPU.
# NOTE(review): the keyword `cnn_bacnbone` looks like a misspelling of
# "backbone" - confirm it matches the RTCRAM constructor before re-enabling.
# rt1 = RTCRAM(
#     cnn_bacnbone=config.SELECTED_CNN_BACKBONE, 
#     num_res_blocks=config.NUM_RES_BLOCKS,
#     freeze_cnn_backbone=config.FREEZE_CNN,
#     args=None
# ).cuda()

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/efficientnet_b3.ra2_in1k)
INFO:timm.models._hub:[timm/efficientnet_b3.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


In [16]:
# Disabled cell: prints the model's module tree and its torchinfo
# layer / parameter-count summary.
# print(rt1)
# summary(model=rt1)

Layer (type:depth-idx)                                            Param #
RTCRAM                                                            --
├─RTEncoder: 1-1                                                  --
│    └─TextEncoder: 2-1                                           --
│    │    └─BertModel: 3-1                                        (28,763,648)
│    │    └─Dropout: 3-2                                          --
│    └─FiLMEncoder: 2-2                                           --
│    │    └─ImageFeatureExtractor: 3-3                            10,300,456
│    │    └─ModuleList: 3-4                                       6,340,608
│    └─TokenLearnerV11: 2-3                                       --
│    │    └─Sequential: 3-5                                       134,408
├─RTDecoder: 1-2                                                  --
│    └─TransformerDecoder: 2-4                                    --
│    │    └─EmbeddingLayer: 3-6                                   53

## Training config

In [17]:
# Disabled cell: training objects, all driven by `config`.
#   loss_fn   - cross-entropy that ignores targets equal to
#               config.TGT_PAD_TOK_ID and applies label smoothing.
#   opt       - optimizer class looked up by name on `torch.optim`, given
#               only the parameters of `rt1` that require gradients.
#   scheduler - LR scheduler class looked up by name on
#               `torch.optim.lr_scheduler`, bound to `opt`.
# loss_fn = nn.CrossEntropyLoss(
#     ignore_index=config.TGT_PAD_TOK_ID, 
#     label_smoothing=config.LABEL_SMOOTHING
# )
# opt = getattr(torch.optim, config.OPTIMIZER)(
#     params=[p for p in rt1.parameters() if p.requires_grad], 
#     lr=config.LR,
#     weight_decay=config.WEIGHT_DECAY
# )

# scheduler = getattr(lr_scheduler, config.LR_SCHEDULER["type"])(**config.LR_SCHEDULER["params"], optimizer=opt)

In [18]:
# Disabled debug cell: calls `model_utils.validation_step` on one batch.
# NOTE(review): `batch` is not defined in any cell visible in this notebook.
# out = model_utils.validation_step(model=rt1, batch=batch, loss_fn=loss_fn)

In [19]:
# Disabled debug cell: inspects which keys `out` contains.
# out.keys()

In [20]:
# Disabled debug cell: passes `out["logits"]` through the decoder's action
# generator and takes the argmax over the last dimension.
# preds= rt1.decoder.action_generator(out["logits"]).argmax(-1)
# preds

In [21]:
# Disabled debug cell: decodes `preds` with `rt1.decode_predictions`.
# rt1.decode_predictions(preds)

## Run experiment

In [22]:
# Fresh data module for this run. `setup()` is not called here; the split
# sizes logged in this cell's output suggest `run_experiment` triggers it -
# TODO confirm.
dm = BEDataModule()

# Build the model and training objects inside this cell so that it works on a
# fresh kernel (Restart & Run All): `rt1`, `loss_fn`, `opt` and `scheduler`
# are not defined by any other active cell.
rt1 = RTCRAM(
    # NOTE(review): keyword spelling copied from the previously executed call;
    # confirm it matches the RTCRAM constructor.
    cnn_bacnbone=config.SELECTED_CNN_BACKBONE,
    num_res_blocks=config.NUM_RES_BLOCKS,
    freeze_cnn_backbone=config.FREEZE_CNN,
    args=None,
).cuda()

# Cross-entropy over target tokens; padding positions are ignored.
loss_fn = nn.CrossEntropyLoss(
    ignore_index=config.TGT_PAD_TOK_ID,
    label_smoothing=config.LABEL_SMOOTHING,
)

# Optimizer class is selected by name from `config`; only parameters that
# require gradients are optimised.
opt = getattr(torch.optim, config.OPTIMIZER)(
    params=[p for p in rt1.parameters() if p.requires_grad],
    lr=config.LR,
    weight_decay=config.WEIGHT_DECAY,
)

# LR scheduler class and its keyword arguments also come from `config`.
scheduler = getattr(lr_scheduler, config.LR_SCHEDULER["type"])(
    **config.LR_SCHEDULER["params"], optimizer=opt
)

# Start (or re-initialise) the Weights & Biases run that tracks this experiment.
run = wandb.init(
    dir='../',
    project='SMF-Be', 
    group="RT1-CRAM", 
    name="be_model", 
    reinit=True
)

# Launch the experiment with the resume options set: resume_training=True and
# epoch_resume=60 (the cell output shows a checkpoint being loaded).
trained_model = model_utils.run_experiment(
    model=rt1, 
    dm=dm, 
    opt=opt, 
    loss_fn=loss_fn,
    scheduler=scheduler,
    resume_training=True,
    epoch_resume=60
)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdric225[0m ([33mjepsam-s23[0m). Use [1m`wandb login --relogin`[0m to force relogin


INFO:root:Training on 4052 samples.
INFO:root:Validating on 456 samples.
INFO:root:Testing on 250 samples.


Total # examples: 4758
Loading model from checkpoint...
Loading model from checkpoint...Complete!


  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

  0%|                                                                                                         …

## Test/Inference Procedure

In [7]:
# Display the configured devices as a (TEST_DEVICE, DEVICE) tuple.
config.TEST_DEVICE, config.DEVICE

('cpu', 'cuda')

In [18]:
# Shell escape: re-check GPU memory usage before running inference.
!nvidia-smi

Sat Jan  6 08:26:26 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:89:00.0 Off |                    0 |
| N/A   29C    P0              53W / 300W |   1105MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [27]:
# Run the inference helper over the held-out test split. Per the log output,
# the helper loads the model from a checkpoint itself and prints an
# evaluation report (Levenshtein distance and success rate).
# NOTE(review): config.TEST_DEVICE is 'cpu' in the cell above, yet inference
# is run on config.DEVICE here - confirm this is intended.
test_loader = dm.test_dataloader()
out = model_utils.inference_step(test_loader=test_loader, mode="eval", device=config.DEVICE)

# Show which result columns came back.
out.keys()

INFO:root:Loading model from checkpoint...
INFO:root:Creating instance of RTCRAM...
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/efficientnet_b3.ra2_in1k)
INFO:timm.models._hub:[timm/efficientnet_b3.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:root:Preparing checkpoint...
INFO:root:loading model state dict...
INFO:root:Loading model from checkpoint...Complete!
INFO:root:Running inference now...


Running inference:   0%|          | 0/4 [00:00<?, ?it/s]

Generating motor commands:   0%|          | 0/64 [00:00<?, ?it/s]

Generating motor commands:   0%|          | 0/64 [00:00<?, ?it/s]

Generating motor commands:   0%|          | 0/64 [00:00<?, ?it/s]

Generating motor commands:   0%|          | 0/64 [00:00<?, ?it/s]

**** Evaluatiion Report *****
> Test Lev. distance	: 8.4180
> Success Rate		: 43.7500%
**** Evaluatiion Report *****


Index(['prediction', 'label', 'correct', 'distance'], dtype='object')

In [28]:
# Display the per-sample results table (prediction, label, correct, distance);
# pandas' default repr truncates the middle rows.
out

Unnamed: 0,prediction,label,correct,distance
0,:FORK RED POSE-13 :FORK BLUE POSE-11 :FORK #'*...,:FORK RED POSE-13 :MONDAMIN BLUE POSE-11 :FORK...,0.0,14
1,:BREAKFAST-CEREAL BLUE POSE-3 :BREAKFAST-CEREA...,:BREAKFAST-CEREAL BLUE POSE-3 :BREAKFAST-CEREA...,1.0,0
2,:CEREAL BLUE POSE-11 :CEREAL #'*forward-transf...,:CEREAL BLUE POSE-11 :CEREAL #'*forward-transf...,1.0,0
3,:BREAKFAST-CEREAL RED POSE-7 :PLATE GREEN POSE...,:BREAKFAST-CEREAL GREEN POSE-7 :PLATE RED POSE...,0.0,6
4,:KNIFE BLUE POSE-10 :KNIFE GREEN POSE-3 :KNIFE...,:KNIFE GREEN POSE-2 :KNIFE #'*forward-transfor...,0.0,21
...,...,...,...,...
251,:SPOON RED POSE-9 :CAP BLUE POSE-1 :SPOON #'*f...,:SPOON RED POSE-9 :CAP BLUE POSE-1 :SPOON #'*f...,1.0,0
252,:CUP RED POSE-6 :BOTTLE BLUE POSE-2 :CUP #'*ri...,:CUP RED POSE-6 :CAP BLUE POSE-2 :CUP #'*right...,0.0,7
253,RED POSE-6 :CUP RED POSE-1 :CUP #'*forward-tra...,:CUP RED POSE-6 :RED-METAL-PLATE RED POSE-1 :C...,0.0,33
254,:CEREAL RED POSE-11 :CEREAL #'*leftward-transf...,:CEREAL RED POSE-11 :CEREAL #'*leftward-transf...,1.0,0
