In [1]:
# Record the GPU, driver and CUDA version available to this run.
!nvidia-smi

Tue Jan  2 08:46:18 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:15:00.0 Off |                    0 |
| N/A   32C    P0              60W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import warnings
warnings.filterwarnings(action="ignore")

In [3]:
import lightning.pytorch as pl
from lightning.pytorch import Trainer, seed_everything

import os

from pprint import pprint

import sys
sys.path.append("../src")

from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torchinfo import summary

In [4]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Environment switches read by third-party libraries. They are set here, before
# the data module and the model are created.

# Turn off the Hugging Face tokenizers' internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Presumably the number of seconds to wait for the W&B service to come up
# (verify against the installed wandb version).
os.environ["WANDB__SERVICE_WAIT"] = "300"

# CUDA_LAUNCH_BLOCKING=1 forces every CUDA kernel launch to run synchronously.
# That is useful for getting accurate stack traces while debugging device-side
# errors, but it removes CPU/GPU overlap and slows training. Keep it opt-in so
# long training runs are not throttled by a leftover debugging switch.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [5]:
import config

from dataloader import BEDataModule

from rt1 import RTCRAM
from utils.model_utils import training_step, validation_step, run_experiment 

## Build data module


In [6]:
# Instantiate the project's data module with its default arguments.
dm = BEDataModule()
# setup() is deliberately left uncalled here; presumably run_experiment (or
# Lightning) triggers it later — TODO confirm.
# dm.setup()

## Build model


In [7]:
# Seed the Python, NumPy and PyTorch RNGs before any weights are initialised so
# that building the model is repeatable across kernel restarts and cell re-runs.
# Uses config.SEED when the config module defines one, otherwise falls back to 42.
seed_everything(getattr(config, "SEED", 42), workers=True)

# Build the network from the values in config and move it to the GPU.
# NOTE(review): the keyword is spelled `cnn_bacnbone`. The recorded output shows
# this call succeeding, so it presumably mirrors the RTCRAM signature — if the
# spelling is corrected it must be changed in rt1.py and here together.
rt1 = RTCRAM(
    cnn_bacnbone=config.SELECTED_CNN_BACKBONE, 
    num_res_blocks=config.NUM_RES_BLOCKS,
    freeze_cnn_backbone=config.FREEZE_CNN,
    args=None
).cuda()

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/efficientnet_b3.ra2_in1k)
INFO:timm.models._hub:[timm/efficientnet_b3.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


In [8]:
# Per-layer parameter counts of the network; print(rt1) would show the raw
# module repr instead.
summary(rt1)

Layer (type:depth-idx)                                            Param #
RTCRAM                                                            --
├─RTEncoder: 1-1                                                  --
│    └─TextEncoder: 2-1                                           --
│    │    └─BertModel: 3-1                                        (28,763,648)
│    │    └─Dropout: 3-2                                          --
│    └─FiLMEncoder: 2-2                                           --
│    │    └─ImageFeatureExtractor: 3-3                            10,300,456
│    │    └─ModuleList: 3-4                                       6,340,608
│    └─TokenLearnerV11: 2-3                                       --
│    │    └─Sequential: 3-5                                       134,408
├─RTDecoder: 1-2                                                  --
│    └─TransformerDecoder: 2-4                                    --
│    │    └─EmbeddingLayer: 3-6                                   52

## Training config

In [9]:
# Cross-entropy loss: targets equal to the padding token id are ignored and
# label smoothing is taken from the config.
loss_fn = nn.CrossEntropyLoss(
    label_smoothing=config.LABEL_SMOOTHING,
    ignore_index=config.TGT_PAD_TOK_ID,
)

# The optimizer class is looked up by name in torch.optim; it only receives the
# parameters that still require gradients.
optimizer_cls = getattr(torch.optim, config.OPTIMIZER)
trainable_params = [weight for weight in rt1.parameters() if weight.requires_grad]
opt = optimizer_cls(
    params=trainable_params,
    lr=config.LR,
    weight_decay=config.WEIGHT_DECAY,
)

# The scheduler class is likewise resolved by name, with its keyword arguments
# coming straight from the config entry.
scheduler_cls = getattr(lr_scheduler, config.LR_SCHEDULER["type"])
scheduler = scheduler_cls(**config.LR_SCHEDULER["params"], optimizer=opt)

## Run experiment

In [None]:
# Hand the model, data module, loss, optimizer and scheduler to the project's
# experiment runner and keep whatever it returns.
# NOTE(review): the output saved with this cell shows val_Loss repeating
# (7.0422 for epochs 1-3, 6.7168 for epochs 4-6) while val_CER / val_WER change
# every epoch — confirm that run_experiment refreshes the validation loss on
# every epoch rather than reporting a stale value.
trained_model = run_experiment(
    model=rt1,
    dm=dm,
    loss_fn=loss_fn,
    opt=opt,
    scheduler=scheduler,
)

INFO:root:Training on 3868 samples.
INFO:root:Validating on 640 samples.
INFO:root:Testing on 250 samples.


Total # examples: 4758


Epoch 1/80 - (Train 96.7%): : 34it [06:12, 10.94s/it, train_loss=4.0983, train_loss_step=4.0983, val_CER=0.6855, val_Loss=7.0422, val_WER=0.9375]                          
Epoch 2/80 - (Train 96.7%): : 34it [06:03, 10.68s/it, train_loss=3.4416, train_loss_step=3.4416, val_CER=0.7967, val_Loss=7.0422, val_WER=0.9375]                          
Epoch 3/80 - (Train 96.7%): : 34it [06:04, 10.73s/it, train_loss=3.1508, train_loss_step=3.1508, val_CER=0.6694, val_Loss=7.0422, val_WER=0.7500]                          
Epoch 4/80 - (Train 96.7%): : 34it [06:00, 10.60s/it, train_loss=2.8846, train_loss_step=2.8846, val_CER=1.1652, val_Loss=6.7168, val_WER=0.6250]                          
Epoch 5/80 - (Train 96.7%): : 34it [06:05, 10.75s/it, train_loss=2.6442, train_loss_step=2.6442, val_CER=0.7402, val_Loss=6.7168, val_WER=0.8125]                          
Epoch 6/80 - (Train 96.7%): : 34it [06:07, 10.80s/it, train_loss=2.4402, train_loss_step=2.4402, val_CER=1.1870, val_Loss=6.7168, val_WER=0.