In [None]:
# Sanity check: list the Kaggle dataset directory that --image_dir / --ann_path point into.
!ls /kaggle/input/chestxraycaption/mimic_cxr/mimic_cxr/

In [None]:
# Clone the R2Gen fork (dev branch) only if it is not already present, so re-running
# this cell in the same session does not fail with "destination path already exists".
!test -d R2Gen || git clone --branch dev https://github.com/flych3r/R2Gen

In [None]:
# Work from inside the checkout: requirements.txt below, and the project packages
# imported later (`modules`, `models`), are expected to resolve relative to it.
%cd R2Gen/

In [None]:
# Bring an existing checkout up to date with the dev branch.
# NOTE(review): dev is a moving branch; pin a commit hash for reproducible runs.
!git pull origin dev

In [None]:
# Install the repo's dependencies into the kernel's environment (%pip targets the kernel).
%pip install -q -r requirements.txt

In [None]:
import argparse
import random

import numpy as np
import torch
import wandb

from modules.tokenizers import Tokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.trainer import Trainer
from modules.loss import compute_loss
from models.r2gen import R2GenModel
from modules.utils import parse_args

In [None]:
# Training configuration, expressed as the CLI-style string that `parse_args`
# consumes in the next cell. One `--flag value` pair per entry; the pairs are
# joined with single spaces and the whole string is wrapped in newlines.
# NOTE(review): the semantics of --steps, --eval_steps and
# --lr_scheduler_step_size are defined in modules.utils / modules.trainer —
# confirm there before tuning them.
args_str = "\n" + " ".join([
    "--image_dir /kaggle/input/chestxraycaption/mimic_cxr/mimic_cxr/images",
    "--ann_path /kaggle/input/chestxraycaption/mimic_cxr/mimic_cxr/annotation.json",
    "--dataset_name mimic_cxr",  # also names the W&B project
    "--max_seq_length 100",
    "--threshold 3",
    "--batch_size 32",
    "--steps 8450",
    "--eval_steps 1500",
    "--save_dir results/mimic_cxr",
    "--lr_scheduler_step_size 3000",
    "--lr_scheduler_gamma 0.8",
    "--visual_extractor vit",  # also prefixes the W&B run name
    "--d_vf 768",
    "--n_gpu 1",
    "--logger wandb",
    "--seed 456789",  # feeds the RNG seeding cell
    "--lr_ve 2e-4",
    "--lr_ed 2e-4",
]) + "\n"

In [None]:
# Parse the CLI-style configuration string into a settings object; later cells
# read individual options from it by attribute (e.g. `args.seed`, `args.dataset_name`).
args = parse_args(args_str)

In [None]:
 import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
WANDB_KEY = user_secrets.get_secret("WANDB_KEY")

wandb.login(key=WANDB_KEY)
wandb.init(project=f"r2gen-{args.dataset_name}")
wandb.run.name = f'{args.visual_extractor}-{wandb.run.name}'

In [None]:
# Seed every RNG the pipeline may draw from (stdlib `random`, NumPy, PyTorch)
# with the configured seed. This runs before the tokenizer, data loaders and
# model are built, so their construction is reproducible as well.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)  # per PyTorch docs, seeds the RNG on all devices
# Trade cuDNN autotuning speed for deterministic algorithm selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Build the report tokenizer from the parsed configuration.
# NOTE(review): presumably it builds its vocabulary from --ann_path, keeping
# tokens seen at least --threshold times — confirm in modules.tokenizers.
tokenizer = Tokenizer(args)

In [None]:
# Build one loader per split from the same config and tokenizer. Only the
# training split is shuffled; val and test keep a fixed order.
train_dataloader, val_dataloader, test_dataloader = (
    R2DataLoader(args, tokenizer, split=split_name, shuffle=(split_name == 'train'))
    for split_name in ('train', 'val', 'test')
)

In [None]:
# Instantiate the R2Gen report-generation model from the config and tokenizer.
# NOTE(review): the tokenizer is presumably used to size the decoder vocabulary —
# confirm in models.r2gen.
model = R2GenModel(args, tokenizer)

In [None]:
# The loss and the evaluation-metric routine are passed to the Trainer as plain
# function handles, not called here.
criterion, metrics = compute_loss, compute_scores

# Optimizer for `model` configured from `args`; the LR scheduler wraps it.
optimizer = build_optimizer(args, model)
lr_scheduler = build_lr_scheduler(args, optimizer)

In [None]:
# Wire everything built above into a Trainer. The arguments are positional, so
# their order must match the signature of modules.trainer.Trainer.
trainer = Trainer(
    model,
    criterion,
    metrics,
    optimizer,
    args,
    lr_scheduler,
    train_dataloader,
    val_dataloader,
    test_dataloader,
)

In [None]:
# Run training. Schedule settings such as --steps and --eval_steps are
# presumably read by the Trainer from `args` — see modules.trainer.
trainer.train()

In [None]:
# Evaluate on the test split (the loader passed to the Trainer as test_dataloader).
# NOTE(review): confirm whether this evaluates the final or the best checkpoint.
trainer.test()

In [None]:
# Mark the W&B run opened earlier as finished and upload any remaining data.
wandb.finish()