## Examine BLIP-2 from LAVIS and Set Up Fine-Tuning

In [3]:
import torch
from lavis.models import model_zoo
from PIL import Image

KeyboardInterrupt: 

In [13]:
# List every registered LAVIS architecture together with the model types it offers.
print(str(model_zoo))

Architectures                  Types
albef_classification           ve
albef_feature_extractor        base
albef_nlvr                     nlvr
albef_pretrain                 base
albef_retrieval                coco, flickr
albef_vqa                      vqav2
alpro_qa                       msrvtt, msvd
alpro_retrieval                msrvtt, didemo
blip_caption                   base_coco, large_coco
blip_classification            base
blip_feature_extractor         base
blip_image_text_matching       base, large
blip_nlvr                      nlvr
blip_pretrain                  base
blip_retrieval                 coco, flickr
blip_vqa                       vqav2, okvqa, aokvqa
blip2_opt                      pretrain_opt2.7b, pretrain_opt6.7b, caption_coco_opt2.7b, caption_coco_opt6.7b
blip2_t5                       pretrain_flant5xl, pretrain_flant5xl_vitL, pretrain_flant5xxl, caption_coco_flant5xl
blip2_feature_extractor        pretrain, pretrain_vitL, coco
blip2                      

We want the feature extractor (`blip2_feature_extractor`): we take its pre-trained (`pretrain`) weights and fine-tune them.

In [14]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [15]:
from lavis.models import load_model_and_preprocess

# Load the BLIP-2 feature extractor (the `pretrain` checkpoint) with is_eval=False so it can be
# fine-tuned, along with its matching image and text preprocessors.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip2_feature_extractor",
    model_type="pretrain",
    is_eval=False,
    device=device,
)

  state_dict = torch.load(cached_file, map_location="cpu")
  checkpoint = torch.load(cached_file, map_location="cpu")


In [16]:
# Display the module tree: a Blip2Qformer whose first child is the VisionTransformer visual encoder.
model

Blip2Qformer(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-38): 39 x Block(
        (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1408, out_features=4224, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1408, out_features=1408, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1408, out_features=6144, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )


In [17]:
# Number of parameters the optimizer will actually update (those with requires_grad=True).
sum(param.numel() for param in filter(lambda param: param.requires_grad, model.parameters()))

186705470

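`load_model_and_preprocess` also returns image and text preprocessors, each with a `train` and an `eval` variant. The image processor differs between the two modes (the train variant applies random resized cropping, horizontal flips and RandAugment). The text processor appears to behave the same in both.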

In [18]:
# Load the sample image as 3-channel RGB. The context manager closes the underlying file
# deterministically instead of leaving it to Pillow's lazy loading / garbage collection.
with Image.open('../../assets/sofa.jpg') as sofa_image:
    img = sofa_image.convert('RGB')

In [19]:
# Training-mode image transform (random resized crop, flip, RandAugment, ToTensor, Normalize).
vis_train_process = vis_processors["train"]

In [20]:
# Peek at the processor's attributes to see the transform pipeline it wraps.
vars(vis_train_process)

{'normalize': Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
 'transform': Compose(
     RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic, antialias=True)
     RandomHorizontalFlip(p=0.5)
     <lavis.processors.randaugment.RandomAugment object at 0x7f0974ce7dd0>
     ToTensor()
     Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
 )}

In [32]:
# Run the training-mode transform on the sample image and show the normalized tensor.
vis_input = vis_processors['train'](img)
print(vis_input)
# Add a batch axis, move to the target device, and stack two copies -> (2, 3, 224, 224).
vis_input = torch.cat([vis_input.unsqueeze(0).to(device)] * 2, dim=0)

tensor([[[ 1.0398,  1.0398,  1.0544,  ...,  0.7625,  0.7625,  0.7771],
         [ 1.1566,  1.1420,  1.0252,  ...,  0.7771,  0.7771,  0.7625],
         [ 1.0982,  1.0106,  1.0982,  ...,  0.7771,  0.7771,  0.7771],
         ...,
         [ 0.0179, -0.0550, -0.1280,  ...,  1.3026,  1.1566,  1.2880],
         [ 0.0033, -0.0842, -0.1572,  ...,  1.0690,  0.8501,  0.7625],
         [-0.0113, -0.0988, -0.1718,  ...,  1.0252,  0.8647,  0.5873]],

        [[ 1.1744,  1.1744,  1.1894,  ...,  0.8893,  0.8893,  0.9193],
         [ 1.2945,  1.2795,  1.1594,  ...,  0.9043,  0.9043,  0.9043],
         [ 1.2344,  1.1444,  1.2344,  ...,  0.9043,  0.9043,  0.9343],
         ...,
         [-0.2963, -0.3714, -0.4314,  ...,  0.8292,  0.7092,  0.7842],
         [-0.3264, -0.4014, -0.4764,  ...,  0.5591,  0.3640,  0.2289],
         [-0.3264, -0.4164, -0.4914,  ...,  0.4991,  0.3640,  0.0338]],

        [[ 1.3496,  1.3496,  1.3638,  ...,  0.9088,  0.9088,  0.8945],
         [ 1.4633,  1.4491,  1.3354,  ...,  0

In [22]:
# Eval-mode transform of the same image, for comparison with the training transform above.
vis_processors["eval"](img)

tensor([[[ 0.9522,  0.9230,  0.9230,  ...,  0.7187,  0.7187,  0.7187],
         [ 0.9084,  0.8501,  0.8938,  ...,  0.7187,  0.7187,  0.7187],
         [ 0.8501,  0.7479,  0.9522,  ...,  0.7187,  0.7187,  0.7187],
         ...,
         [ 0.2515, -0.0550, -0.2156,  ...,  0.9084,  0.8647,  0.8355],
         [ 0.2807,  0.5435,  0.2953,  ...,  0.8792,  0.9376,  0.9814],
         [ 0.5289,  0.3975,  0.3975,  ...,  0.8501,  0.8355,  0.8063]],

        [[ 1.0844,  1.0544,  1.0544,  ...,  0.8593,  0.8593,  0.8593],
         [ 1.0393,  0.9793,  1.0243,  ...,  0.8593,  0.8593,  0.8593],
         [ 0.9793,  0.8743,  1.0844,  ...,  0.8593,  0.8593,  0.8593],
         ...,
         [-0.2663, -0.5815, -0.7316,  ...,  0.9193,  0.8743,  0.8442],
         [-0.2213,  0.0638, -0.1913,  ...,  0.8893,  0.9493,  0.9943],
         [ 0.0488, -0.0862, -0.0712,  ...,  0.8593,  0.8442,  0.8142]],

        [[ 1.2358,  1.2216,  1.2358,  ...,  0.8092,  0.8092,  0.8092],
         [ 1.1932,  1.1363,  1.2074,  ...,  0

In [31]:
txt_input = txt_processors["train"]("Hello, world!")
print(txt_input)
# Duplicate the caption so the text batch has the same length as the two-image batch.
txt_input = [txt_input] * 2

hello, world


In [29]:
# The text processor lower-cases its input (and, per the output above, drops a trailing "!").
txt_processors["train"]("Hello, world")

'hello, world'

### Set Up the Dataset and DataLoader

In [1]:
from lavis.models import load_model_and_preprocess
import torch
from torch.utils.data import Dataset, WeightedRandomSampler, DataLoader
from PIL import Image
import pandas as pd

import os
from tqdm import tqdm
import random

  return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
  return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore


In [2]:
# Pick the GPU if available; everything below moves tensors to this device.
if not torch.cuda.is_available():
    device = 'cpu'
else:
    device = 'cuda'

We use the Marqo Google Shopping dataset (Marqo-GS-10M), which is explored in the Marqo-GS-10M-EDA notebook.

Note that each label randomly puts either the query or the title first (`query: title` or `title: query`), so the model doesn't learn to expect the query first.

In [3]:
class GoogleShoppingDataset(Dataset):
    """Image/text pairs read from a Google Shopping annotations CSV.

    Each CSV row must have an ``image_local`` column (image path relative to ``image_dir``)
    plus ``query`` and ``title`` columns. An item is ``(processed_image, processed_label)``,
    where the label joins query and title with ``': '`` in a randomly chosen order.
    """

    def __init__(self, image_dir: str, annotations_file: str, image_processor: object, text_processor: object):
        """
        Args:
            image_dir: directory that the ``image_local`` paths are relative to.
            annotations_file: path of the annotations CSV (loaded with ``pd.read_csv``).
            image_processor: callable applied to the RGB ``PIL.Image``.
            text_processor: callable applied to the label string.
        """
        self.annotations = pd.read_csv(annotations_file)
        self.image_dir = image_dir
        self.image_processor = image_processor
        self.text_processor = text_processor

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx: int):
        # `.at` is the scalar accessor; cheaper than `.loc` for single-cell reads on a large frame.
        image_path = os.path.join(self.image_dir, self.annotations.at[idx, 'image_local'])
        # Pillow opens files lazily and may keep the handle open after convert() (e.g. for
        # multi-frame images); the context manager releases it deterministically, which matters
        # when millions of images are read across DataLoader workers.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.image_processor(image)

        query = self.annotations.at[idx, 'query']
        title = self.annotations.at[idx, 'title']
        # Randomize which part comes first so the model does not learn a fixed query-first order.
        label_options = (query + ': ' + title,
                         title + ': ' + query)
        label = random.choice(label_options)
        label = self.text_processor(label)
        return image, label

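`WeightedRandomSampler` (used below) needs one weight per row. Each row's weight is $1 / (n_q \cdot n_p)$, where $n_q$ is the number of rows sharing its `query_id` and $n_p$ is the number of rows sharing its `product_id`. This way, frequent queries and products are not over-sampled.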

In [4]:
def get_sample_weights(annotations_file: str):
    """Per-row sampling weights for ``WeightedRandomSampler``.

    Each row gets ``1 / (n_query * n_product)``, where ``n_query`` is how many rows share its
    ``query_id`` and ``n_product`` is how many rows share its ``product_id``.

    Args:
        annotations_file: path of the annotations CSV; must contain ``query_id`` and
            ``product_id`` columns.

    Returns:
        A numpy array with one weight per CSV row, in file order.
    """
    frame = pd.read_csv(annotations_file)

    def occurrences(column):
        # For every row, how many rows in the file share its value in `column`.
        values = frame[column]
        return values.value_counts().loc[values].to_numpy()

    return 1 / (occurrences('query_id') * occurrences('product_id'))

Build a `DataLoader` that wraps `GoogleShoppingDataset` and samples rows using the weights above.

In [5]:
def build_dataloader(images_dir: str, annotations_file: str, image_processor: callable, 
                     text_processor: callable, seed=42, batch_size=64, num_workers=2) -> DataLoader:
    """Build a weighted-sampling DataLoader over a ``GoogleShoppingDataset``.

    Rows are drawn with replacement using ``get_sample_weights``; one pass yields
    ``len(dataset)`` samples.

    Args:
        images_dir: directory the CSV's ``image_local`` paths are relative to.
        annotations_file: annotations CSV path.
        image_processor: callable applied to each RGB image.
        text_processor: callable applied to each label string.
        seed: seeds the sampler and the DataLoader's worker base seed, so both the sampled rows
            and (with ``num_workers > 0``) the random query/title order chosen inside the
            workers are reproducible.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = GoogleShoppingDataset(image_dir=images_dir, annotations_file=annotations_file,
                                      image_processor=image_processor, text_processor=text_processor)
    weights = get_sample_weights(annotations_file)
    sampler_generator = torch.Generator().manual_seed(seed)
    sampler = WeightedRandomSampler(weights, len(weights), replacement=True, generator=sampler_generator)
    # DataLoader derives each worker's base seed (which seeds Python's `random`, used by the
    # dataset for label ordering) from its own `generator`; without one it draws from the global
    # torch RNG, so `seed` would not make the labels reproducible.
    loader_generator = torch.Generator().manual_seed(seed)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, sampler=sampler,
                            num_workers=num_workers, generator=loader_generator)
    return dataloader

In [6]:
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip2_feature_extractor", model_type="pretrain", is_eval=False, device=device)

# Root of the Marqo-GS-10M download; set MARQO_GS_ROOT to point elsewhere instead of editing paths.
DATA_ROOT = os.environ.get('MARQO_GS_ROOT', '/mnt/d/marqo-gs-10m')
images_dir = os.path.join(DATA_ROOT, 'images')
annotations_dir = os.path.join(DATA_ROOT, 'marqo-gs-dataset', 'marqo_gs_full_10m')
train_annotations = os.path.join(annotations_dir, 'query_0_product_id_0.csv')
val_annotations = os.path.join(annotations_dir, 'query_1_product_id_1.csv')

# Loader settings shared by the training and validation loaders.
LOADER_KWARGS = dict(images_dir=images_dir, seed=42, batch_size=24, num_workers=2)

train_dataloader = build_dataloader(annotations_file=train_annotations,
                                    image_processor=vis_processors['train'],
                                    text_processor=txt_processors['train'],
                                    **LOADER_KWARGS)
val_dataloader = build_dataloader(annotations_file=val_annotations,
                                  image_processor=vis_processors['eval'],
                                  text_processor=txt_processors['eval'],
                                  **LOADER_KWARGS)

  state_dict = torch.load(cached_file, map_location="cpu")
  checkpoint = torch.load(cached_file, map_location="cpu")


In [7]:
# AdamW over the trainable parameters only; parameters with requires_grad=False are skipped.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=1e-5,
    betas=(0.9, 0.999),
    weight_decay=0.05,
)

We must initialize a `torch.distributed` process group even when training on a single GPU. LAVIS appears to be only partially set up for local (non-distributed) runs, and some of its distributed calls are not properly guarded.

In [8]:
# Single-process "distributed" group required by LAVIS's unguarded torch.distributed calls.
# setdefault keeps any rendezvous address/port already configured in the environment.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
# Guarded so re-running the cell does not raise "default process group initialized twice";
# NCCL needs CUDA, so fall back to gloo on CPU-only machines.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(
        backend='nccl' if torch.cuda.is_available() else 'gloo',
        world_size=1,
        rank=0,
    )

The training loop

In [9]:
def train_one_epoch(model, dataloader, device, optimizer, log_every=1000, scaler=None):
    """Run one training pass over ``dataloader`` and return windowed mean losses.

    Each batch ``(images, labels)`` is passed to the model as
    ``{"image": images, "text_input": labels}`` and ``output.loss`` is minimized.

    Args:
        model: module whose forward takes that samples dict and returns an object with ``.loss``.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device (string or ``torch.device``) that images are moved to.
        optimizer: optimizer over the model's trainable parameters.
        log_every: number of batches averaged into each reported loss.
        scaler: optional ``torch.amp.GradScaler``. On CUDA one is created when omitted, because
            float16 autocast without loss scaling can underflow gradients.

    Returns:
        List of mean losses, one per ``log_every`` batches, plus one for any trailing partial
        window.
    """
    device_type = torch.device(device).type
    # Autocast (float16 by default on CUDA) only where it applies; elsewhere run full precision.
    use_amp = device_type == "cuda"
    if use_amp and scaler is None:
        scaler = torch.amp.GradScaler("cuda")

    losses = []
    running_loss = 0.0
    window = 0
    batch_num = 0

    model.train()
    for batch_num, (images, labels) in enumerate(tqdm(dataloader), start=1):
        samples = {"image": images.to(device), "text_input": labels}

        optimizer.zero_grad()
        with torch.autocast(device_type=device_type, enabled=use_amp):
            output = model(samples)
        loss = output.loss
        if scaler is not None:
            # Scale the loss so half-precision gradients stay representable; step() unscales
            # and skips the update when gradients contain inf/NaN.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        window += 1
        if window == log_every:
            last_loss = running_loss / window
            losses.append(last_loss)
            print(f'  batch {batch_num} loss: {last_loss}')
            running_loss = 0.0
            window = 0

    # Report batches after the last full window instead of silently dropping them.
    if window:
        last_loss = running_loss / window
        losses.append(last_loss)
        print(f'  batch {batch_num} loss: {last_loss}')

    return losses

The validation loop

In [10]:
@torch.no_grad()
def validate(model, dataloader, device, log_every=1000):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Each batch ``(images, labels)`` is passed to the model as
    ``{"image": images, "text_input": labels}`` and ``output.loss`` is accumulated.

    Args:
        model: module whose forward takes that samples dict and returns an object with ``.loss``.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device (string or ``torch.device``) that images are moved to.
        log_every: number of batches averaged into each reported loss.

    Returns:
        List of mean losses, one per ``log_every`` batches, plus one for any trailing partial
        window. (Previously nothing was returned, so callers always got ``None``.)
    """
    device_type = torch.device(device).type
    # Autocast only on CUDA; hard-coding "cuda" merely disabled it (with a warning) elsewhere.
    use_amp = device_type == "cuda"

    losses = []
    running_loss = 0.0
    window = 0
    batch_num = 0

    model.eval()
    for batch_num, (images, labels) in enumerate(tqdm(dataloader), start=1):
        samples = {"image": images.to(device), "text_input": labels}
        with torch.autocast(device_type=device_type, enabled=use_amp):
            output = model(samples)
        running_loss += output.loss.item()
        window += 1
        if window == log_every:
            last_loss = running_loss / window
            losses.append(last_loss)
            print(f'  batch {batch_num} loss: {last_loss}')
            running_loss = 0.0
            window = 0

    # Report batches after the last full window instead of silently dropping them.
    if window:
        last_loss = running_loss / window
        losses.append(last_loss)
        print(f'  batch {batch_num} loss: {last_loss}')

    return losses

In [11]:
# One full pass over the training loader; collects the windowed mean losses.
losses_train = train_one_epoch(
    model=model,
    dataloader=train_dataloader,
    device=device,
    optimizer=optimizer,
)

  offset = -low * scale
  offset = -low * scale
  0%|          | 106/163616 [01:13<31:41:39,  1.43it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate on the validation loader without gradient tracking.
losses_validate = validate(
    model=model,
    dataloader=val_dataloader,
    device=device,
)

  0%|          | 25/40884 [00:11<5:10:49,  2.19it/s]


KeyboardInterrupt: 