### Colab Setting
- torchscript_inference.ipynb, train.ipynb 의 device와 동일하게 설정 필수

In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Device string used later for torch.load(map_location=...) and .to(...).
device = 'cuda'
target_weight = 'checkpoint/weights-0.pt' # weight file to convert to TorchScript

In [None]:
# Switch to the project directory on the mounted Drive, then install its requirements.
%cd /content/drive/MyDrive/bridgeblip
!pip install -r requirements.txt

In [None]:
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor, InstructBlipConfig

import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model

import random
from PIL import Image

In [None]:
# Reproducibility: disable the cuDNN auto-tuner, request deterministic cuDNN
# algorithms, and seed the CUDA, torch and Python RNGs with 0.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed_all(0)
torch.manual_seed(0)
random.seed(0)

### Create **LoRA-Bridge** InstructBLIP - loads with fewer than 3B parameters

In [None]:
def count_parameters(model):
    """Print the total number of elements over all parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` whose items provide ``numel()``.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    print(f'total params      : {n_elements:,}')


In [None]:
class ClassificationHead(nn.Module):
    """Two-layer MLP head: Linear -> GELU -> Dropout -> Linear.

    Args:
        in_features: Size of the input features; also the hidden width.
        out_features: Size of the output produced by the final linear layer.
        dropout: Dropout probability applied after the GELU activation.
    """

    def __init__(self, in_features, out_features, dropout):
        super().__init__()
        # Sub-modules are created in this order and stored under ``layer`` so
        # the parameter names (layer.0.*, layer.3.*) stay the same.
        hidden_proj = nn.Linear(in_features, in_features)
        activation = nn.GELU()
        regularizer = nn.Dropout(dropout)
        output_proj = nn.Linear(in_features, out_features)
        self.layer = nn.Sequential(hidden_proj, activation, regularizer, output_proj)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        projected = self.layer(x)
        return projected

- instructblip.config.json 은 Huggingface - Salesforce/instructblip-flan-t5-xl 의 config.json 을 아래와 같이 수정한 파일입니다.

- change: **text_config.num_decoder_layers** 24 -> 8
- change: **text_config.vocab_size** 32128 -> 0

In [None]:
# Load the pretrained processor, then instantiate the model from the locally
# modified config file (built from the config object, not via from_pretrained).
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl", use_fast=True)
base_blip = InstructBlipForConditionalGeneration(InstructBlipConfig.from_json_file('instructblip.config.json'))


# Per the code of transformers.models.t5.modeling_t5.T5Model,
# shared, encoder.embed_tokens and decoder.embed_tokens share one set of parameters.

# vocab_size is set to 0 in the config so that lm_head is not needlessly created as a
# large nn.Linear(2048, 32128) on first load; the statements below then restore the
# encoder/decoder embeddings to the original InstructBLIP structure.


# nn.Embedding(0, 2048) to nn.Embedding(32128, 2048)
base_blip.language_model.shared = nn.Embedding(32128, 2048)
# Re-tie both embed_tokens attributes to the new shared embedding object.
base_blip.language_model.encoder.embed_tokens = base_blip.language_model.shared
base_blip.language_model.decoder.embed_tokens = base_blip.language_model.shared


# nn.Linear(2048, 0) to ClassificationHead
# The language-model head is replaced by a 4-output classification head.
base_blip.language_model.lm_head = ClassificationHead(
    in_features=base_blip.language_model.config.d_model,
    out_features=4,
    dropout=0.0
)


# decoder total params: 585,944,064 -> 604,818,432 once the LoRA adapters are added
# (figures recorded by the original author).
lora_config = LoraConfig(
    r=96,
    lora_alpha=192,
    target_modules=['q', 'k', 'v'],
    lora_dropout=0.1
)


# Only the decoder is wrapped with LoRA adapters.
base_blip.language_model.decoder = get_peft_model(base_blip.language_model.decoder, lora_config)

In [None]:
count_parameters(base_blip) # total params: 2,939,964,164 (value recorded by the original author)

In [None]:
# Load the checkpoint onto `device` and restore the weights stored under its 'model' key.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
state = torch.load(target_weight, map_location=device)
base_blip.load_state_dict(state['model'])
# Put the model in eval mode before it is traced further down.
base_blip.eval()

### Convert to Torchscript

In [None]:
class BridgeInstructblip(nn.Module):
    """Wrapper that forwards its inputs to ``base_blip`` and returns only ``.logits``.

    Args:
        base_blip: The wrapped model. It is called with keyword arguments and
            its result must expose a ``logits`` attribute.
    """

    def __init__(self, base_blip):
        super().__init__()
        self.base_blip = base_blip

    def forward(
        self,
        input_ids,
        attention_mask,
        qformer_input_ids,
        qformer_attention_mask,
        pixel_values,
        decoder_input_ids,
    ):
        """Call the wrapped model with the given inputs and return its logits."""
        outputs = self.base_blip(
            input_ids=input_ids,
            attention_mask=attention_mask,
            qformer_input_ids=qformer_input_ids,
            qformer_attention_mask=qformer_attention_mask,
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
        )
        return outputs.logits

In [None]:
# Wrap the loaded model and move the wrapper to the target device in one step.
bridge = BridgeInstructblip(base_blip).to(device)

# Print the wrapper's parameter total (it only holds base_blip).
count_parameters(bridge)

In [None]:
# Example keyword inputs for tracing. The question and option list are left
# empty; the text is padded/truncated to max_length=128 by the processor.
instructions = f'Question: {""} Options: {" ".join([f"({chr(i+97)}) {c}" for i, c in enumerate([])])} Short answer:'

# Open the sample image in a context manager so the file handle is closed as
# soon as the processor has turned it into tensors.
with Image.open('competition/train_input_images/TRAIN_000.jpg') as example_image:
    encoded = processor(
        images=example_image,
        text=instructions,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=128,
    ).to(device)

inputs = {
    **encoded,
    # Single start token (id 0) for the decoder, created directly on the device.
    'decoder_input_ids': torch.full((1, 1), 0, dtype=torch.long, device=device),
}

# The TracerWarning messages appear because InstructBlipForConditionalGeneration
# uses `if` statements internally to guard against unexpected input values.

# Trace without autograd: the trace is only used for inference, so recording a
# gradient graph for the forward pass would just waste memory.
with torch.no_grad():
    torch_script_model = torch.jit.trace(bridge, example_kwarg_inputs=inputs)
torch_script_model.save('checkpoint/torchscript.pt')