In [None]:
%%capture
!pip install https://github.com/flych3r/vxr/archive/main.zip

In [None]:
# Run configuration: which dataset to train on, where to report, and the
# sampling setting forwarded to XrayReportData.
DATASET = 'mimic_cxr'  # sub-directory under /kaggle/input/chestxraycaption
LOGGER = 'wandb'       # used as `report_to`; 'wandb' also enables the W&B cells
SAMPLE = 1             # passed as XrayReportData(sample=...) — TODO confirm semantics

In [None]:
if LOGGER == 'wandb':
    import wandb
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    WANDB_KEY = user_secrets.get_secret("WANDB_KEY")

    wandb.login(key=WANDB_KEY)

%env WANDB_LOG_MODEL=true

# Model

In [None]:
import json
from pathlib import Path

import evaluate
import numpy as np
import torch
from transformers import AutoTokenizer, AutoFeatureExtractor
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

from vxr.models.configuration import XrayReportGenerationConfig
from vxr.models.modeling import XrayReportGeneration
from vxr.utils.data import XrayReportData, collate_fn

In [None]:
# Sequence / batch sizes shared by data loading, training and generation.
max_length = 100  # max report length in tokens (also the generation limit)
batch_size = 32   # per-device train batch; evaluation uses half of this

# Pretrained checkpoints: a ViT image encoder and a T5 text decoder.
encoder_arch = 'google/vit-base-patch16-224-in21k'
decoder_arch = 'google/t5-efficient-base'

In [None]:
# Pre-processing for each half of the model: the tokenizer matches the text
# decoder checkpoint, the feature extractor matches the vision encoder.
tokenizer = AutoTokenizer.from_pretrained(decoder_arch)
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_arch)

In [None]:
# Text-generation metrics from the Hugging Face `evaluate` hub, used by compute_metrics.
bleu, meteor = evaluate.load('bleu'), evaluate.load('meteor')


def _decode_token_ids(token_ids):
    """Decode a batch of token-id sequences into strings with the global tokenizer.

    The Hugging Face Trainer uses -100 as the ignore/padding value for labels
    and when concatenating batches of different lengths. -100 is not a valid
    vocabulary id and can make ``batch_decode`` fail, so it is mapped to the
    tokenizer's pad id, which ``skip_special_tokens=True`` then drops.
    """
    token_ids = np.asarray(token_ids)
    token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)


def compute_metrics(output):
    """Decode generated and reference reports and score them.

    Args:
        output: ``(predictions, label_ids)`` pair, e.g. the ``EvalPrediction``
            produced by ``Seq2SeqTrainer`` with ``predict_with_generate=True``;
            both are integer token-id arrays.

    Returns:
        dict with
          * ``bleu-1`` .. ``bleu-4``: cumulative BLEU-n (geometric mean of the
            1..n-gram precisions times the brevity penalty), the usual
            BLEU-n reported for report generation;
          * ``precision-1`` .. ``precision-4``: the individual modified n-gram
            precisions (what earlier versions of this cell logged as bleu-n);
          * ``meteor``.
    """
    prediction_tokens, reference_tokens = output

    predictions = _decode_token_ids(prediction_tokens)
    references = _decode_token_ids(reference_tokens)

    metrics = {f'bleu-{n}': 0.0 for n in range(1, 5)}
    metrics.update({f'precision-{n}': 0.0 for n in range(1, 5)})

    # evaluate's BLEU divides by the total hypothesis/reference length for the
    # brevity penalty and can raise ZeroDivisionError when every string is
    # empty (e.g. an early checkpoint that only emits EOS). Score that as 0
    # instead of aborting the evaluation.
    if any(p.strip() for p in predictions) and any(r.strip() for r in references):
        for n in range(1, 5):
            score = bleu.compute(references=references, predictions=predictions, max_order=n)
            metrics[f'bleu-{n}'] = score['bleu']
        # `score` now holds the max_order=4 result, whose precisions cover n = 1..4.
        for n, precision in enumerate(score['precisions'], start=1):
            metrics[f'precision-{n}'] = precision

    metrics.update(meteor.compute(references=references, predictions=predictions))
    return metrics

In [None]:
# Train / validation / test splits (used below as data.train, data.validation
# and data.test); images and annotations come from the Kaggle
# "chestxraycaption" dataset.
dataset_dir = Path(f'/kaggle/input/chestxraycaption/{DATASET}/{DATASET}')
data = XrayReportData(
    image_dir=dataset_dir / 'images',
    ann_path=dataset_dir / 'annotation.json',
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    max_length=max_length,
    sample=SAMPLE,
)

In [None]:
# Seed before building the model. Seq2SeqTrainer applies `args.seed` only
# inside its own constructor, which runs after this cell. Any weights the
# model initialises randomly here would otherwise differ between runs.
# Keep this value in sync with `seed` in the training arguments.
torch.manual_seed(42)

config = XrayReportGenerationConfig(
    encoder_config=encoder_arch,
    decoder_config=decoder_arch,
)
model = XrayReportGeneration(config)

In [None]:
# Training configuration. Evaluation generates text (predict_with_generate)
# so compute_metrics can score decoded reports, and the best checkpoint by
# 'bleu-4' (greater is better by default for non-loss metrics) is restored
# when training ends.
args = Seq2SeqTrainingArguments(
    # Output and checkpoint selection
    output_dir='xrrg-model',
    save_strategy='steps',
    save_steps=1500,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model='bleu-4',
    push_to_hub=False,
    # Optimisation
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    learning_rate=2e-4,
    optim='adamw_torch',
    fp16=False,
    seed=42,
    # Evaluation and generation
    evaluation_strategy='steps',
    eval_steps=1500,
    per_device_eval_batch_size=batch_size // 2,
    predict_with_generate=True,
    generation_max_length=max_length,
    generation_num_beams=1,
    # Data handling and logging
    remove_unused_columns=False,  # presumably collate_fn needs every dataset field — confirm
    report_to=LOGGER,
    logging_steps=50,
    log_level='warning',
)

In [None]:
if LOGGER == 'wandb':
    # Attach the full model config and training arguments to the run. The
    # JSON round-trip turns both objects into plain, serialisable dicts.
    run_config = {
        "model": json.loads(config.to_json_string()),
        "args": json.loads(args.to_json_string()),
    }
    wandb.init(project=f"vxr-{DATASET}", config=run_config)

In [None]:
# Wire model, training arguments, data splits, batching and metrics together.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=data.train,
    eval_dataset=data.validation,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune. Per `args`, this evaluates and checkpoints every 1500 steps and
# reloads the best 'bleu-4' checkpoint at the end (load_best_model_at_end).
trainer.train()

In [None]:
# Generate reports for the held-out test split with greedy decoding (1 beam);
# compute_metrics scores them and the metric names get a 'test_' prefix.
result = trainer.predict(
    test_dataset=data.test,
    max_length=max_length,
    num_beams=1,
)

In [None]:
# Token ids from `trainer.predict` can contain -100, the Trainer's
# ignore/padding value for labels and for batches of unequal length. -100 is
# not a vocabulary id and can make `batch_decode` fail, so map it to the pad
# id first; skip_special_tokens=True then drops it from the decoded text.
label_ids = np.where(result.label_ids != -100, result.label_ids, tokenizer.pad_token_id)
prediction_ids = np.where(result.predictions != -100, result.predictions, tokenizer.pad_token_id)
metrics = result.metrics

# Decoded reference reports and generated reports, aligned by test sample.
ground_truth = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
inference = tokenizer.batch_decode(prediction_ids, skip_special_tokens=True)

In [None]:
# Persist decoded test-set reports and their metrics for offline inspection.
test_output = {
    'ground_truth': ground_truth,
    'inference': inference,
    'metrics': metrics,
}
with open('test-results.json', 'w') as results_file:
    json.dump(test_output, results_file)

In [None]:
if LOGGER == 'wandb':
    # Group the test metrics in their own W&B section (e.g. test_loss -> test/loss).
    test_metrics = {}
    for name, value in metrics.items():
        test_metrics[name.replace('test_', 'test/')] = value
    wandb.run.log(test_metrics)

    # One table row per test sample: reference report next to generated report.
    test_rows = list(zip(ground_truth, inference))
    wandb.run.log({"test-results": wandb.Table(columns=["ground_truth", "inference"], data=test_rows)})

In [None]:
# Mark the W&B run as finished so the remaining logged data is uploaded.
if LOGGER == 'wandb':
    wandb.finish()