### Colab Setting

- torchscript.ipynb, train.ipynb 의 device 와 동일하게 설정 필수

In [None]:
# Mount Google Drive so the project files under /content/drive are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Device string used later for the processor inputs and for loading the
# TorchScript checkpoint. As the note above says, it must match the device
# used in torchscript.ipynb and train.ipynb.
device = 'cuda'

In [None]:
# Switch to the project folder on the mounted Drive and install the packages
# listed in its requirements.txt.
%cd /content/drive/MyDrive/bridgeblip
!pip install -r requirements.txt

In [None]:
from transformers import InstructBlipProcessor

import pandas as pd
import torch

from PIL import Image
from tqdm.auto import tqdm

import itertools

In [2]:
# Reproducibility: seed the CUDA and default generators with 0, then turn
# cuDNN benchmark mode off and cuDNN deterministic mode on.
torch.cuda.manual_seed_all(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def count_parameters(model):
    """Print the total number of elements over all parameters of ``model``.

    Args:
        model: Object exposing ``parameters()``, an iterable of items that
            each provide ``numel()``.

    Returns:
        None. The count is printed with thousands separators.
    """
    n_elements = 0
    for tensor in model.parameters():
        n_elements += tensor.numel()
    print(f'total params      : {n_elements:,}')

In [4]:
# Load the InstructBLIP processor published as "Salesforce/instructblip-flan-t5-xl",
# requesting the fast implementation (use_fast=True). predict() uses it to
# turn an image plus prompt text into model inputs.
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl", use_fast=True)

In [None]:
def predict(model, question, choices, image_id, image_dir='competition/test_input_images'):
    """Answer one multiple-choice visual question and report the confidence.

    Builds the prompt ``Question: ... Options: (a) ... (b) ... Short answer:``,
    runs the module-level ``processor`` on the image and prompt, feeds the
    result to ``model`` and selects the highest-scoring output index.

    Args:
        model: Callable accepting the processor outputs plus
            ``decoder_input_ids`` as keyword arguments and returning a tensor
            of scores. Assumed to hold one score per answer option for a
            single example once squeezed -- TODO confirm against the export.
        question: Question text inserted into the prompt.
        choices: Answer option texts in display order; option ``i`` is
            labelled with the lowercase letter ``chr(97 + i)``.
        image_id: File stem of the image; ``{image_dir}/{image_id}.jpg`` is
            opened.
        image_dir: Directory holding the images. Defaults to the competition
            test image folder, which keeps existing four-argument calls
            unchanged.

    Returns:
        tuple: ``(answer, conf)`` where ``answer`` is the uppercase letter of
        the arg-max index (``'A'`` for index 0) and ``conf`` is the softmax
        probability of that index as a Python float.
    """
    options = " ".join(f"({chr(i + 97)}) {c}" for i, c in enumerate(choices))
    instructions = f'Question: {question} Options: {options} Short answer:'

    # Open the image in a context manager so its file handle is released
    # deterministically; this function is called many times per test row.
    with Image.open(f'{image_dir}/{image_id}.jpg') as image:
        encoded = processor(
            images=image,
            text=instructions,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=128,
        ).to(device)

    inputs = {
        **encoded,
        # Single decoder token with id 0 (presumably the decoder start token
        # the exported model expects -- verify against torchscript.ipynb).
        'decoder_input_ids': torch.full((1, 1), 0, dtype=torch.long).to(device),
    }

    with torch.no_grad():
        output = model(**inputs).squeeze()
        # Compute the arg-max once and reuse it for both the label and the
        # confidence lookup.
        best_idx = output.argmax(-1)
        answer = chr(65 + best_idx.item())
        conf = torch.softmax(output, dim=-1)[best_idx].item()
    return answer, conf


In [None]:
# Load the TorchScript model from checkpoint/torchscript.pt, mapping its
# storages onto `device`, and switch it to evaluation mode for inference.
bridge_instructblip = torch.jit.load('checkpoint/torchscript.pt', map_location=device)
bridge_instructblip.eval()

# Print the total parameter count of the loaded model.
count_parameters(bridge_instructblip)

In [None]:
# Inference with test-time augmentation over answer orderings.
# For every test row the four options are presented in each of the 24
# possible orders. Each prediction is mapped back to the option's original
# position and its confidence is added to that position's running total.
# The first prediction whose confidence reaches THRESH is accepted
# immediately; if none does, the option with the largest total wins.
df = pd.read_csv('competition/test.csv')

LABELS  = ['A', 'B', 'C', 'D']
THRESH  = 0.9
perms   = list(itertools.permutations(range(4)))

answers = []
for _, row in tqdm(df.iterrows(), total=len(df)):
    question         = row['Question']
    original_choices = [row[label] for label in LABELS]

    votes  = torch.zeros(4)
    chosen = None

    for order in perms:
        # order[k] is the original index of the option shown in slot k.
        shuffled = [original_choices[src] for src in order]
        letter, confidence = predict(bridge_instructblip, question, shuffled, row['ID'])

        # Translate the predicted slot back to the option's original index.
        source_idx = order[LABELS.index(letter)]
        votes[source_idx] += confidence

        if confidence >= THRESH:
            chosen = LABELS[source_idx]
            break

    # No single ordering was confident enough: use the accumulated votes.
    if chosen is None:
        chosen = LABELS[votes.argmax()]

    answers.append(chosen)

    

In [None]:
# Write the predictions into the sample submission's `answer` column and save
# the result as submission.csv without the index column. `answers` is in
# test.csv row order; this presumably matches the sample submission's row
# order -- verify.
df = pd.read_csv('competition/sample_submission.csv').assign(answer=answers)
df.to_csv('submission.csv', index=False)