In [1]:
# List the dataset root; annotation.json and images/ are passed later as --ann_path / --image_dir.
!ls /kaggle/input/chestxraycaption/mimic_cxr/mimic_cxr/

annotation.json  images


In [2]:
!git clone --branch dev https://github.com/flych3r/R2Gen

Cloning into 'R2Gen'...
remote: Enumerating objects: 251, done.[K
remote: Counting objects: 100% (162/162), done.[K
remote: Compressing objects: 100% (54/54), done.[K
remote: Total 251 (delta 108), reused 152 (delta 105), pack-reused 89[K
Receiving objects: 100% (251/251), 70.24 MiB | 33.02 MiB/s, done.
Resolving deltas: 100% (135/135), done.


In [3]:
%cd R2Gen/

/kaggle/working/R2Gen


In [4]:
# Refresh the checkout from the dev branch: a no-op right after a fresh clone
# ("Already up to date."), but picks up new commits if the clone already existed.
!git pull origin dev

From https://github.com/flych3r/R2Gen
 * branch            dev        -> FETCH_HEAD
Already up to date.


In [5]:
# Install the repository's requirements into the kernel's environment
# (`%pip`, unlike `!pip`, targets the running kernel).
%pip install -r requirements.txt

Collecting gdown
  Downloading gdown-4.5.1.tar.gz (14 kB)
  Installing build dependencies ... [?25l- \ | / - \ done
[?25h  Getting requirements to build wheel ... [?25l- \ | / - done
[?25h  Preparing metadata (pyproject.toml) ... [?25l- \ | / - \ done
Building wheels for collected packages: gdown
  Building wheel for gdown (pyproject.toml) ... [?25l- \ | / - \ | done
[?25h  Created wheel for gdown: filename=gdown-4.5.1-py3-none-any.whl size=14933 sha256=a631d73028ff9b4df3bd467aed3accb385255bdf80bea340d1352fd7c9a8eb38
  Stored in directory: /root/.cache/pip/wheels/3d/ec/b0/a96d1d126183f98570a785e6bf8789fca559853a9260e928e1
Successfully built gdown
Installing collected packages: gdown
Successfully installed gdown-4.5.1
[0mNote: you may need to restart the kernel to use updated packages.


In [6]:
import argparse
import random

import numpy as np
import torch
import wandb

from models.r2gen import R2GenModel
from modules.dataloaders import R2DataLoader
from modules.loss import compute_loss
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.tokenizers import Tokenizer
from modules.trainer import Trainer
from modules.utils import parse_args

In [7]:
# Run configuration, expressed as the command-line flags consumed by `parse_args`.
# The dataset root is defined once so image_dir and ann_path cannot drift apart.
DATA_ROOT = "/kaggle/input/chestxraycaption/mimic_cxr/mimic_cxr"

TRAIN_FLAGS = {
    "image_dir": f"{DATA_ROOT}/images",
    "ann_path": f"{DATA_ROOT}/annotation.json",
    "dataset_name": "mimic_cxr",
    "max_seq_length": 80,
    "threshold": 3,
    "batch_size": 64,
    "steps": 12000,
    "eval_steps": 500,
    "save_dir": "results/mimic_cxr",
    "lr_scheduler_step_size": 3000,
    "lr_scheduler_gamma": 0.8,
    "visual_extractor": "resnet",
    "d_vf": 2048,
    "n_gpu": 1,
    "logger": "wandb",
    "seed": 456789,
}

# "--flag value" pairs joined by single spaces, wrapped in newlines.
args_str = "\n" + " ".join(f"--{flag} {value}" for flag, value in TRAIN_FLAGS.items()) + "\n"

In [8]:
# Parse the flag string into the run configuration used by every later cell
# (seeding, tokenizer, data loaders, model, optimizer, trainer, W&B run name).
args = parse_args(args_str)

In [9]:
 import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
WANDB_KEY = user_secrets.get_secret("WANDB_KEY")

wandb.login(key=WANDB_KEY)
wandb.init(project=f"r2gen-{args.dataset_name}")
wandb.run.name = f'{args.visual_extractor}-{wandb.run.name}'

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mflych3r[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Seed every RNG in play and force deterministic cuDNN kernels so runs repeat.
# Python's `random` is seeded as well; previously it was left unseeded, in case
# the data pipeline or model code (not visible here) draws from it.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)  # seeds the CPU and all CUDA devices
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick non-deterministic kernels

In [11]:
# Create the tokenizer from the parsed args; the data loaders and the model
# below are all built with this same instance.
tokenizer = Tokenizer(args)

In [12]:
# One loader per split, built in train → val → test order; only training data is shuffled.
train_dataloader, val_dataloader, test_dataloader = (
    R2DataLoader(args, tokenizer, split=split_name, shuffle=(split_name == 'train'))
    for split_name in ('train', 'val', 'test')
)

In [13]:
# Build the R2Gen model from the run config and tokenizer. The output below shows
# pretrained ResNet-101 weights being downloaded during construction (--visual_extractor resnet).
model = R2GenModel(args, tokenizer)

Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth


  0%|          | 0.00/171M [00:00<?, ?B/s]

In [14]:
# The loss and the evaluation metrics are plain function handles handed to the Trainer.
criterion, metrics = compute_loss, compute_scores

# The optimizer is created first because the LR scheduler is built on top of it.
optimizer = build_optimizer(args, model)
lr_scheduler = build_lr_scheduler(args, optimizer)

In [15]:
# Wire model, objective, metrics, optimisation and data into the Trainer.
# Arguments are positional, in the order the Trainer constructor expects.
trainer = Trainer(
    model,
    criterion,
    metrics,
    optimizer,
    args,
    lr_scheduler,
    train_dataloader,
    val_dataloader,
    test_dataloader,
)

In [16]:
# Run training for --steps steps. Per the log below, every --eval_steps (500) steps it
# reports train loss and val BLEU/METEOR/ROUGE_L, and saves current/best checkpoints.
trainer.train()

  0%|          | 0/12000 [00:00<?, ?it/s]

	step           : 500
	train/loss     : 3.6446248626708986
	val/BLEU_1     : 0.2470816106794329
	val/BLEU_2     : 0.16667676604750184
	val/BLEU_3     : 0.12121476636987331
	val/BLEU_4     : 0.0902596682246608
	val/METEOR     : 0.11870144319799833
	val/ROUGE_L    : 0.3036614507090103
Saving checkpoint: results/mimic_cxr/current_checkpoint.pth ...
Saving current best: model_best.pth ...
	step           : 1000
	train/loss     : 2.3568642251491547
	val/BLEU_1     : 0.31352685552640064
	val/BLEU_2     : 0.2084676537434328
	val/BLEU_3     : 0.14751915846324373
	val/BLEU_4     : 0.10912654254851817
	val/METEOR     : 0.13101344124612122
	val/ROUGE_L    : 0.3114595631264389
Saving checkpoint: results/mimic_cxr/current_checkpoint.pth ...
Saving current best: model_best.pth ...
	step           : 1500
	train/loss     : 2.0766223075389862
	val/BLEU_1     : 0.3493309728364442
	val/BLEU_2     : 0.23210890425511418
	val/BLEU_3     : 0.1638417859480881
	val/BLEU_4     : 0.12083656626590035
	val/METEOR 

In [17]:
# Evaluate on the test split and print BLEU/METEOR/ROUGE_L.
# NOTE(review): confirm whether this scores the final in-memory weights or the
# saved model_best.pth checkpoint.
trainer.test()

  0%|          | 0/61 [00:00<?, ?it/s]

	BLEU_1         : 0.312483526526434
	BLEU_2         : 0.196094869418219
	BLEU_3         : 0.13368263377080505
	BLEU_4         : 0.0966714134192661
	METEOR         : 0.1261906316225447
	ROUGE_L        : 0.2724552868843175


In [18]:
# Close the W&B run; the output below shows pending files being uploaded before it finishes.
wandb.finish()

VBox(children=(Label(value='1327.137 MB of 1327.137 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.…

0,1
train/loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/BLEU_1,▁▅▇▄▆▇▇▆▇▇█▇▆▇▆▆███▆▇▇▇▇
val/BLEU_2,▁▅▇▄▆▇▇▆▇▇█▇▆▇▆▆█▇█▆▇▆▇▇
val/BLEU_3,▁▅▇▄▆▆▇▆▇▇█▇▆▇▆▅█▇█▆▇▆▇▆
val/BLEU_4,▁▅▇▄▆▆▇▆▇▇█▇▆▇▆▆█▇▇▆▇▇▇▆
val/METEOR,▁▄▆▃▆▆▇▆▇▇█▇▆▇▇▅███▆▇▆▇▇
val/ROUGE_L,▃▆▇▁▄▃▇▇▆▆▆▅▄▆█▇▇▆▆▅▅▅▅▄

0,1
train/loss,1.43874
val/BLEU_1,0.35247
val/BLEU_2,0.22968
val/BLEU_3,0.16001
val/BLEU_4,0.11769
val/METEOR,0.14114
val/ROUGE_L,0.30575
