In [9]:
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, AutoConfig
import numpy as np
from datasets import load_dataset
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as T
from vicuna_llava import vicuna_llava, dataset_llava

import zipfile
import wget
from accelerate import Accelerator
# Single Accelerator instance shared by every later `accelerator.prepare(...)` and
# `accelerator.backward(...)` call in this notebook.
accelerator = Accelerator()

In [2]:
# Build the Vicuna-7B based LLaVA wrapper from the Hugging Face config, then hand it to
# Accelerate so it is placed on the available device(s). The wrapper exposes the LLM
# (`.model`), its tokenizer, the vision tower and the image projector (`.im_embedding`).
model_name = "lmsys/vicuna-7b-v1.5"
config = AutoConfig.from_pretrained(model_name)
llava_model = vicuna_llava(config)
vicunallava = accelerator.prepare(llava_model)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [3]:
# Load the pretrained LLaVA linear projector and copy it into `vicunallava.im_embedding`.
# map_location='cpu' makes the load independent of the device the checkpoint was saved on;
# copy_ below moves the values onto whatever device im_embedding lives on.
linear_llava_proj = torch.load('mm_projector.bin', map_location='cpu', weights_only=True)
linear_llava_weights = linear_llava_proj['model.mm_projector.weight']
linear_llava_biases = linear_llava_proj['model.mm_projector.bias']

# Tensor.copy_ broadcasts its source, so a mismatched checkpoint could be accepted silently.
# Fail with a clear message instead.
for param_name, src, dst in (
    ('weight', linear_llava_weights, vicunallava.im_embedding.weight),
    ('bias', linear_llava_biases, vicunallava.im_embedding.bias),
):
    if src.shape != dst.shape:
        raise ValueError(
            f"mm_projector {param_name} has shape {tuple(src.shape)}, "
            f"but im_embedding.{param_name} expects {tuple(dst.shape)}"
        )

# In-place copies into leaf parameters must not be recorded by autograd.
with torch.no_grad():
    vicunallava.im_embedding.weight.copy_(linear_llava_weights)
    vicunallava.im_embedding.bias.copy_(linear_llava_biases)

In [4]:
# Local copy of the CC3M data: the image folder plus the chat.json conversation file.
# The dataset was downloaded locally to work around images that are missing upstream.
im_dir = "CC3M/images/"
chat = 'CC3M/chat.json'

cc3m_dataset = dataset_llava(chat, im_dir)


In [5]:
# Stage 1 trains only the image projector: first switch gradients off for every
# parameter of the combined model, then switch them back on for im_embedding alone.
for param in vicunallava.parameters():
    param.requires_grad_(False)

for param in vicunallava.im_embedding.parameters():
    param.requires_grad_(True)

In [11]:
# Smoke-test generation on one CC3M sample (index 15) with the current projector weights.
testprompt, testimage, _ = cc3m_dataset[15]
transform = T.ToPILImage()  # only needed for the optional image preview below

# Inference only: no autograd graph is required.
with torch.no_grad():
    output = vicunallava.generate(testimage, testprompt, max_new_tokens=10)

# '\n' (not '/n') so the prompt is followed by a blank line rather than a literal "/n".
print(f'input prompt: {testprompt}\n')
# transform(testimage).show()  # uncomment to preview the sample image

print(f'response: {output}')

input prompt: Create a compact narrative representing the image presented.
<image>/n
response: ###://githubwyvernlogicdevopsfilesystem


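Training (Stage 1)

Stage 1 trains only the image projector (`vicunallava.im_embedding`); every other parameter of `vicunallava` stays frozen.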

In [7]:
from torch.optim import AdamW

# Stage-1 hyperparameters. batch_size was originally meant to be 128; 1 keeps a local run
# feasible (the author measured ~10 min locally for 5 batches of size 20).
learning_rate = 2e-3
batch_size = 1
num_epochs = 1
IGNORE_INDEX = -100  # label value that CrossEntropyLoss skips

# Give AdamW only the parameters that still require gradients (just im_embedding once the
# freezing cell has run); frozen weights are never updated, so there is no point tracking them.
trainable_params = [p for p in vicunallava.parameters() if p.requires_grad]
optimizer = accelerator.prepare(AdamW(trainable_params, lr=learning_rate))
loss_fn = nn.CrossEntropyLoss(reduction='none', ignore_index=IGNORE_INDEX)
losses = []  # one float per optimizer step: mean loss over the unmasked target tokens

cc3m_dataloader = accelerator.prepare(DataLoader(cc3m_dataset, batch_size=batch_size))

# The separator id is loop-invariant, so look it up once. If "###" is not a single vocabulary
# entry the lookup returns the unknown-token id. That id would normally never match a label,
# so every target would be masked and training could not make progress. Fail loudly instead.
separator_token_id = vicunallava.tokenizer.convert_tokens_to_ids("###")
if separator_token_id == vicunallava.tokenizer.unk_token_id:
    raise ValueError('"###" maps to the unknown token; cannot locate where the response starts')

vicunallava.train()
for epoch in range(num_epochs):
    batchiter = 0

    for batchprompt, batchimage, batchresp in cc3m_dataloader:
        # Zip instead of indexing range(batch_size): the final batch may be shorter.
        texts = [prompt + '###' + resp for prompt, resp in zip(batchprompt, batchresp)]
        tokenized_input = vicunallava.tokenize(texts)

        # The vision tower runs here only to learn how many image positions precede the
        # text in the model output; no gradients are needed for that.
        with torch.no_grad():
            encoded_batchim = vicunallava.vision_tower(batchimage)

        # call the model on the raw images and the concatenated prompt/response text
        outs = vicunallava(batchimage, texts, batch_size=len(texts))

        # Causal shift: the output at position j is scored against the token at j + 1.
        shifted_outs = outs[:, :-1, :]
        labels = tokenized_input['input_ids'][:, 1:].clone()

        # Mask every label before the first "###"; the separator and the response after it
        # remain targets.
        labels[(labels == separator_token_id).cumsum(dim=1) == 0] = IGNORE_INDEX
        attention_mask = tokenized_input.get('attention_mask')
        if attention_mask is not None:
            # Padding (only present when batch_size > 1) must never be a target.
            labels[attention_mask[:, 1:] == 0] = IGNORE_INDEX

        # Image positions are never targets. Build the whole label vector as int64 on the
        # logits' device; `.type(torch.LongTensor)` would move it to the CPU.
        im_loss_mask = torch.full(
            encoded_batchim.shape[:2], IGNORE_INDEX, dtype=torch.long, device=shifted_outs.device
        )
        lossmask = torch.cat(
            (im_loss_mask, labels.to(device=shifted_outs.device, dtype=torch.long)), dim=1
        )
        flat_targets = lossmask.reshape(-1)

        token_losses = loss_fn(shifted_outs.reshape(-1, shifted_outs.size(-1)), flat_targets)

        # Average over real targets only. Filtering on `loss != 0` would also drop genuine
        # zero-loss tokens. It would also give NaN, which poisons the weights, when nothing
        # is left.
        num_targets = (flat_targets != IGNORE_INDEX).sum()
        if num_targets == 0:
            print(f"iteration:{batchiter}, skipping batch: no unmasked response tokens")
            continue
        loss = token_losses.sum() / num_targets

        # optimize the projector
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        batchiter += 1

        # Record a plain float. Appending the graph-attached loss tensor would keep a device
        # tensor and its autograd nodes alive for every step of the run.
        step_loss = loss.item()
        losses.append(step_loss)
        print(f"iteration:{batchiter}, loss:{step_loss}")

        # Checkpoint every step so an interrupted run keeps its progress.
        # NOTE(review): this pickles the whole module, which
        # `torch.load(..., weights_only=True)` rejects. Saving `im_embedding.state_dict()`
        # would be the more portable format.
        torch.save(vicunallava.im_embedding, 'vicunallava_im_embedding_stage1.pt')

KeyboardInterrupt: 