In [None]:
%%capture
!pip install https://github.com/flych3r/vxr/archive/main.zip

In [None]:
import json
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoFeatureExtractor

from vxr.models.modeling import XrayReportGeneration
from vxr.utils.data import XrayReportData

In [None]:
# Token-length limit: passed to XrayReportData and to every model.generate() call.
max_length = 100
# Number of test samples per DataLoader batch.
batch_size = 32
# Directory (relative to the working directory) from which the feature extractor,
# tokenizer and model are loaded.
artifact_dir = Path('model')

In [None]:
# Text side: tokenizer stored alongside the model artifacts; used for the dataset
# and later to decode generated ids back to text.
tokenizer = AutoTokenizer.from_pretrained(artifact_dir)

# Image side: feature extractor stored in the same directory; handed to the
# dataset as its `transforms`.
transforms = AutoFeatureExtractor.from_pretrained(artifact_dir)

In [None]:
# Single place to repoint the notebook when the dataset is mounted elsewhere;
# both the image folder and the annotation file live under this root.
dataset_root = Path('/kaggle/input/chestxraycaption/mimic_cxr/mimic_cxr')

# Build the dataset wrapper; its `.test` split is what gets evaluated below.
data = XrayReportData(
    image_dir=dataset_root / 'images',
    ann_path=dataset_root / 'annotation.json',
    max_length=max_length,
    tokenizer=tokenizer,
    transforms=transforms
)

In [None]:
# Load the report-generation model from the same artifact directory as the
# tokenizer and feature extractor.
# NOTE(review): presumably a Hugging Face PreTrainedModel subclass (it exposes
# from_pretrained / generate / device) — confirm in vxr.models.modeling.
model = XrayReportGeneration.from_pretrained(artifact_dir)

In [None]:
# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model.to(device)
# Last expression of the cell: displays where the model actually ended up.
model.device

In [None]:
# Per-batch accumulators: the reference token ids plus one list for each
# decoding strategy evaluated in the generation loop.
ground_truths = []
greedy_outputs = []
beam_outputs = []
sample_outputs = []
topk_outputs = []

# shuffle=False keeps the test split in dataset order, so the i-th ground truth
# lines up with the i-th generated report in every output list.
dl = DataLoader(data.test, batch_size=batch_size, shuffle=False)

In [None]:
# Seed torch's global RNG so the two sampling strategies below produce the same
# reports on every Restart & Run All.
torch.manual_seed(42)

# Make sure dropout and similar layers are disabled; harmless if the model is
# already in eval mode.
model.eval()

# (Re)create the accumulators in this cell so re-running it never appends a
# second copy of the test set to lists filled by an earlier run.
ground_truths = []
greedy_outputs = []
beam_outputs = []
sample_outputs = []
topk_outputs = []

# One entry per decoding strategy: (list collecting its output, generate kwargs).
# The order matters for the sampled runs because they draw from the same RNG.
decoding_runs = [
    (greedy_outputs, dict(num_beams=1)),                             # greedy search
    (beam_outputs, dict(num_beams=3, early_stopping=True)),          # beam search
    (sample_outputs, dict(do_sample=True, top_k=0, temperature=0.7)),  # pure sampling (top-k disabled)
    (topk_outputs, dict(do_sample=True, top_k=50)),                  # top-k sampling
]

for batch in tqdm(dl, total=len(dl)):
    # Reference reports are the tokenized targets delivered by the dataset.
    ground_truths.append(batch['input_ids'])

    with torch.no_grad():
        pixel_values = batch['pixel_values'].to(model.device)
        for outputs, generate_kwargs in decoding_runs:
            generated = model.generate(
                pixel_values,
                max_length=max_length,
                **generate_kwargs
            )
            # Keep results on the CPU so device memory is not held by outputs
            # for the whole test set.
            outputs.append(generated.cpu())

In [None]:
def _pad_and_cat(batches, pad_value):
    """Right-pad per-batch token-id tensors to a common width and concatenate them.

    `model.generate` can return a different sequence length for each batch
    (generation stops once every sequence in the batch has finished), so a plain
    `torch.cat` over the collected batches fails with a size mismatch as soon as
    two batches differ in length.

    Args:
        batches: list of integer tensors whose last dimension is the sequence
            length, or an already-concatenated tensor (returned unchanged so the
            cell can be re-run safely).
        pad_value: token id used to fill the added positions.

    Returns:
        One CPU tensor holding all rows, as wide as the longest input batch.
    """
    if isinstance(batches, torch.Tensor):
        return batches
    width = max(t.shape[-1] for t in batches)
    return torch.cat([
        torch.nn.functional.pad(t.cpu(), (0, width - t.shape[-1]), value=pad_value)
        for t in batches
    ])


# Pad with a special token so that decoding with skip_special_tokens=True drops
# the filler again; fall back to EOS for tokenizers without a pad token.
pad_id = tokenizer.pad_token_id
if pad_id is None:
    pad_id = tokenizer.eos_token_id
if pad_id is None:
    raise ValueError('tokenizer defines neither a pad nor an eos token to pad with')

ground_truths = _pad_and_cat(ground_truths, pad_id)
greedy_outputs = _pad_and_cat(greedy_outputs, pad_id)
beam_outputs = _pad_and_cat(beam_outputs, pad_id)
sample_outputs = _pad_and_cat(sample_outputs, pad_id)
topk_outputs = _pad_and_cat(topk_outputs, pad_id)

In [None]:
def _ids_to_text(token_ids):
    """Decode a batch of token ids to strings, dropping special tokens."""
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)


# Reference reports first, then one list of strings per decoding strategy.
text_ground_truth = _ids_to_text(ground_truths)
text_greedy = _ids_to_text(greedy_outputs)
text_beam = _ids_to_text(beam_outputs)
text_sample = _ids_to_text(sample_outputs)
text_topk = _ids_to_text(topk_outputs)

In [None]:
# One output file per decoding strategy. Every file pairs the same reference
# reports with that strategy's generated reports, index-aligned.
reports_by_file = {
    'greedy.json': text_greedy,
    'beam.json': text_beam,
    'sample.json': text_sample,
    'topk.json': text_topk,
}

for file_name, generated_reports in reports_by_file.items():
    with open(file_name, 'w') as f:
        json.dump({
            'ground_truth': text_ground_truth,
            'inference': generated_reports
        }, f, indent=4)