In [1]:
import torch
import json
from transformers import AutoProcessor, Blip2ForConditionalGeneration, TrainingArguments,AutoTokenizer
import torch
from PIL import Image
import pickle
import random
import peft
from trl import SFTTrainer
from torch.utils.data import Dataset, DataLoader
import copy
import wandb



In [4]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdieplstks[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [27]:
wandb.init(
    project="BLIP-finetune",
    config={
        "r": 16,
        "lora_alpha": 16,
        "lora_dropout": 0.1,
        "modules_to_save": 'q-former',

    },
)

In [2]:
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
# by default `from_pretrained` loads the weights in float32
# we load in float16 instead to save memory
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) 

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")
tokenizer.pad_token = tokenizer.eos_token

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device);

cuda


In [5]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )


In [8]:
# model_clone = copy.deepcopy(model)
print_trainable_parameters(model)

trainable params: 105137664 || all params: 3744679936 || trainable%: 2.81


In [33]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["qformer"],
)
peft_model = get_peft_model(model, lora_config)


In [34]:
print_trainable_parameters(peft_model)

trainable params: 110380544 || all params: 3855060480 || trainable%: 2.86


In [6]:
for param in model.parameters():
    param.requires_grad = False

In [7]:
for param in model.qformer.parameters():
    param.requires_grad = True


In [13]:
def get_data(qid, data):
    image = data[qid]['imageId']
    question = data[qid]['question']
    answer = data[qid]['answer']
    full_answer = data[qid]['fullAnswer']
    return image, question, answer, full_answer

In [14]:
f = open('questions/train_balanced_questions.json')
data = json.load(f)

In [15]:
random_keys = random.sample(list(data.keys()), 20000)
train_data = {key: data[key] for key in random_keys}

In [36]:
with open("train_questions_finetune.json", 'w') as f:
    json.dump(train_data, f)


In [16]:
# Moslty from: https://github.com/huggingface/notebooks/blob/main/peft/Fine_tune_BLIP2_on_an_image_captioning_dataset_PEFT.ipynb
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["question"], 
                                  padding="max_length", max_length=60, return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        # encoding["question"] = item["question"]
        encoding["answer"] = item["answer"]
        return encoding

def collate_fn(batch):
    # pad the input_ids and attention_mask
    processed_batch = {}
    for key in batch[0].keys():
        if key  not in ["answer"]:
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer(
                [example["answer"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch



In [17]:
dataset = []
for k in train_data:
    image, question, answer, _ = get_data(k, train_data)
    image = Image.open('images/'+image+'.jpg')
    question_formatted = f'Answer the following question with one word. Question: {question} Answer:' 
    dataset.append({'image': image, 'question': question_formatted, 'answer': answer})


# Create an instance of ImageCaptioningDataset
train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=100, collate_fn=collate_fn)


In [37]:
peft_model.device

device(type='cuda', index=0)

In [44]:
optimizer = torch.optim.SGD(peft_model.parameters(), lr=5e-4)

device = "cuda" if torch.cuda.is_available() else "cpu"

peft_model.train()
for epoch in range(5):
    print(f'Epoch: {epoch}')
    n = 0
    for idx, batch in enumerate(train_dataloader):        
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device, torch.float32)
        outputs = peft_model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=input_ids)

        loss = outputs.loss
        wandb.log({'training_loss': loss.item()})
        
        if (n+1) % 10 == 0:
            print(f'{n+1} Loss: {loss.item()}')
        n += 1
        loss.backward()
    
        torch.nn.utils.clip_grad_norm_(peft_model.parameters(), 1.0)

        optimizer.step()
        optimizer.zero_grad()



Epoch: 0
10 Loss: 11.2734375
20 Loss: 11.3515625
30 Loss: 10.8671875
40 Loss: 11.421875
50 Loss: 11.1875
60 Loss: 11.5546875
70 Loss: 11.2578125
80 Loss: 11.3671875
90 Loss: 10.6796875
100 Loss: 11.34375
110 Loss: 11.359375
120 Loss: 11.34375
130 Loss: 11.2421875
140 Loss: 11.53125
150 Loss: 11.5859375
160 Loss: 11.1796875
170 Loss: 11.2578125
180 Loss: 10.75
190 Loss: 10.875
200 Loss: 11.390625
Epoch: 1
10 Loss: 11.1015625
20 Loss: 10.6640625
30 Loss: 11.1953125
40 Loss: 10.7890625
50 Loss: 11.0390625
60 Loss: 10.6171875
70 Loss: 11.0390625
80 Loss: 10.734375
90 Loss: 11.1796875
100 Loss: 11.2734375
110 Loss: 10.59375
120 Loss: 11.0546875
130 Loss: 10.8515625
140 Loss: 11.046875
150 Loss: 11.5234375
160 Loss: 10.6328125
170 Loss: 10.84375
180 Loss: 11.2265625
190 Loss: 10.53125
200 Loss: 10.8671875
Epoch: 2
10 Loss: 11.296875
20 Loss: 11.078125
30 Loss: 10.6796875
40 Loss: 10.3125
50 Loss: 10.625
60 Loss: 11.0234375
70 Loss: 10.9140625
80 Loss: 10.2734375
90 Loss: 11.015625
100 Loss: 

In [45]:
torch.save(peft_model.state_dict(), 'finetuneqformer15epochstatedict.torch')


In [46]:
torch.save(peft_model, 'finetuneqformer15epoch.torch')


In [47]:
wandb.finish()

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
training_loss,▆▇▇▇█▇▆▆▆▆▆▆▆▆▅▆▅▆▆▅▅▅▅▄▄▅▄▅▅▄▄▃▃▂▃▂▃▁▁▁

0,1
training_loss,9.76562
