In [1]:
from lavis.models import load_model_and_preprocess
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd

import os
from tqdm import tqdm

  return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
  return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore


In [2]:
# Fine-tuned checkpoints, keyed by the short name selected through `model_type`.
model_paths = {
    'gs': '/mnt/d/marqo-gs-10m/model-saves/pretrain_1epoch.pt',
    'abo': '/mnt/d/abo-dataset/model_saves/pretrain_2epochs.pt',
}

# Root directory of the ABO dataset (images and metadata live below it).
abo_dataset_dir = '/mnt/d/abo-dataset'
# Key into `model_paths`; the value 'pretrain' keeps the stock weights instead.
model_type = 'gs'
# Torch device the model and the image tensors are moved to.
device = 'cuda'
# Output directory for embeddings.
save_path = '/mnt/d/embeddings'

In [3]:
# Locations inside the ABO dataset directory.
images_dir = os.path.join(abo_dataset_dir, 'images', 'small')
metadata_file = os.path.join(abo_dataset_dir, 'abo-listings-final-draft.pkl')
image_metadata_file = os.path.join(abo_dataset_dir, 'images', 'metadata', 'images.csv')

# The stock BLIP-2 feature extractor is always built (on CPU) because its
# image/text processors are used no matter which weights are selected below.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip2_feature_extractor", model_type="pretrain", is_eval=True, device='cpu')

if model_type != 'pretrain':
    # SECURITY: weights_only=False unpickles an arbitrary Python object and can
    # execute code while doing so -- only load checkpoints you trust.
    # map_location='cpu' makes the load independent of the device the
    # checkpoint was saved from; the model is moved to `device` right after.
    model = torch.load(model_paths[model_type], map_location='cpu', weights_only=False)
    # The is_eval=True request above applied to the stock model that was just
    # replaced, so put the loaded one into eval mode explicitly.
    model.eval()
model.to(device)

  state_dict = torch.load(cached_file, map_location="cpu")
  checkpoint = torch.load(cached_file, map_location="cpu")


Blip2Qformer(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-38): 39 x Block(
        (norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1408, out_features=4224, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1408, out_features=1408, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1408, out_features=6144, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )


In [4]:
def abo_image_item_pairs(metadata: pd.DataFrame) -> pd.DataFrame:
    """Flatten the listing metadata into one row per (image, item) link.

    For every row of ``metadata`` the main image (if present) is emitted
    first, followed by the other images in their stored order.

    Args:
        metadata: Frame indexed by item id with the columns
            ``main_image_id`` (scalar or missing) and ``other_image_id``
            (a list of ids, a single id, or missing).

    Returns:
        A frame with the columns ``image_id`` and ``item_id`` and a default
        integer index, one row per image linked to an item.

    Raises:
        KeyError: if either expected column is absent from ``metadata``.
    """
    image_ids = []
    item_ids = []
    # Walk the index and both columns in lockstep instead of doing two
    # label-based ``.loc`` lookups per row: this is much cheaper on large
    # frames and does not break when the index contains repeated item ids.
    rows = zip(metadata.index, metadata['main_image_id'], metadata['other_image_id'])
    for item_id, main_image_id, other_image_ids in rows:
        row_image_ids = []
        if not pd.isna(main_image_id):
            row_image_ids.append(main_image_id)
        if isinstance(other_image_ids, list):
            row_image_ids.extend(other_image_ids)
        elif not pd.isna(other_image_ids):
            # A single id stored directly rather than wrapped in a list.
            row_image_ids.append(other_image_ids)
        image_ids.extend(row_image_ids)
        item_ids.extend([item_id] * len(row_image_ids))
    return pd.DataFrame({'image_id': image_ids, 'item_id': item_ids})

In [11]:
class ABODataset_multimodal(Dataset):
    """Map-style dataset yielding one (image, item text) sample per image link.

    Position ``idx`` refers to row ``idx`` of ``image_item_pairs``. The image
    is read from disk and passed through ``image_processor``; the text is the
    item's selected metadata fields joined into one string and passed through
    ``text_processor``.
    """
    # Note: modified from fine-tuning version

    # Listing columns that contribute to the text, in the order they are emitted.
    _TEXT_COLUMNS = ['item_name', 'brand', 'model_name', 'model_year',
                     'product_description', 'product_type', 'color',
                     'fabric_type', 'style', 'material', 'item_keywords',
                     'pattern', 'finish_type', 'bullet_point']

    def __init__(self, image_dir: str, metadata: pd.DataFrame,
                 image_metadata: pd.DataFrame, image_item_pairs: pd.DataFrame,
                 image_processor: callable, text_processor: callable):
        """Store the inputs and narrow ``metadata`` to the text columns.

        Args:
            image_dir: Directory that the image ``path`` values are relative to.
            metadata: Item listings, looked up by item id via ``.loc``.
            image_metadata: Image table, looked up by image id via ``.loc``;
                must provide a ``path`` column.
            image_item_pairs: Frame with ``image_id`` and ``item_id`` columns;
                its row count is the dataset length.
            image_processor: Called with the RGB PIL image.
            text_processor: Called with the assembled item text.

        Raises:
            KeyError: if ``metadata`` lacks one of the text columns.
        """
        self.image_dir = image_dir
        self.metadata = metadata
        self.image_metadata = image_metadata
        self.image_processor = image_processor
        self.text_processor = text_processor
        self.image_item_pairs = image_item_pairs

        self._reorg_metadata_columns()

    def _reorg_metadata_columns(self):
        """Keep only the text columns of ``self.metadata``, in a fixed order."""
        self.metadata = self.metadata[self._TEXT_COLUMNS]

    def __len__(self):
        """Return the number of (image, item) pairs."""
        return len(self.image_item_pairs)

    def __getitem__(self, idx: int):
        """Return ``(image, label, image_id, item_id)`` for position ``idx``.

        ``idx`` is positional, so the lookup works whatever index
        ``image_item_pairs`` carries, and running past the end raises
        ``IndexError`` (which lets plain ``for sample in dataset`` terminate).
        """
        pairs = self.image_item_pairs
        image_id = pairs['image_id'].iloc[idx]
        item_id = pairs['item_id'].iloc[idx]
        image_path = os.path.join(self.image_dir, self.image_metadata.loc[image_id, 'path'])
        # Close the underlying file as soon as the pixels have been converted;
        # convert() returns an independent image that outlives the with-block.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.image_processor(image)
        label = self.text_processor(self._row_to_str(self.metadata.loc[item_id]))

        return image, label, image_id, item_id

    def _row_to_str(self, row):
        """Join the non-missing values of a metadata row into one line of text.

        List values are flattened. Every value is suffixed with ``;`` and the
        pieces are separated by single spaces; newlines and ``^`` characters
        are replaced by spaces.
        """
        parts = []
        for value in row.dropna():
            values = value if isinstance(value, list) else [value]
            parts.extend(str(entry) + ';' for entry in values)

        return ' '.join(parts).replace('\n', ' ').replace('^', ' ')

In [12]:
# Item listings come from a pickled DataFrame; the image table is keyed by
# image_id so that the dataset can look paths up by id.
metadata = pd.read_pickle(metadata_file)
image_metadata = pd.read_csv(image_metadata_file).set_index('image_id')

# One row per (image, item) link; the row count defines the dataset length.
image_item_pairs = abo_image_item_pairs(metadata)

dataset_multimodal = ABODataset_multimodal(
    image_dir=images_dir,
    metadata=metadata,
    image_metadata=image_metadata,
    image_item_pairs=image_item_pairs,
    image_processor=vis_processors['eval'],
    text_processor=txt_processors['eval'],
)

In [13]:
# Spot check: print the text and the multimodal embedding for the last
# `batch_size` dataset positions ending at `end_item` (inclusive).
model.eval()
end_item = 95323
batch_size = 1
# Inference only: make sure no autograd graph is recorded for the forward
# pass, whether or not extract_features disables gradients itself.
with torch.no_grad():
    for i in range(end_item - batch_size + 1, end_item + 1):
        print(i)
        image, label, image_id, item_id = dataset_multimodal[i]
        image = image.unsqueeze(0).to(device)  # add a leading batch dimension of 1
        sample = {"image": image, "text_input": [label]}
        print(label)
        print(model.extract_features(sample).multimodal_embeds[0, 0, :])

95323
crimini baby bella mushrooms, 6 oz package produce aisle grocery crimini mushrooms vegetable fresh produce ad09 cremini creminni cremni criminimushroom criminimushrooms fungi fungus krimini mushroom produce vegetable fresh produce ad09 cremini creminni cremni criminimushroom criminimushrooms fungi fungus krimini mushroom produce vegetable fresh produce ad09 cremini creminni cremni criminimushroom criminimushrooms fungi


  return torch.cuda.amp.autocast(dtype=dtype)


tensor([ 3.0193e-01, -3.4187e-01,  1.1122e-01, -8.1153e-01, -5.2237e-01,
         3.6934e-01,  2.1672e-01,  2.7581e-01, -7.6398e-01, -5.1356e-01,
         7.6467e-01,  8.8154e-02,  2.2094e-01,  1.5820e+00, -4.6085e-01,
        -3.0875e-01, -2.1006e-01,  4.8221e-01,  5.6548e-01, -3.7220e-01,
         6.5108e-02, -3.2434e-01,  7.7151e-02, -3.2199e-02,  1.5142e-03,
         5.6789e-02, -6.3519e-01, -1.2742e-01,  4.0435e-01,  2.7062e-01,
         1.5673e-02,  5.6183e-01, -3.2920e-01, -4.3117e-01, -2.3878e-01,
        -3.3439e-01,  2.8009e-01,  2.0495e-02,  1.8467e-01,  5.8730e-02,
        -3.7960e-01, -3.2002e-03, -1.5239e-01,  4.6644e-01, -4.6546e-01,
        -6.1064e-01, -1.8004e+00, -3.4972e-02,  2.2180e-01, -3.2609e-01,
         1.9626e-01, -3.7214e-01, -1.1899e-01, -1.8968e-01, -2.2329e-01,
         3.3938e-01, -6.1180e-01, -2.0553e-01,  2.7324e-01, -4.1762e-02,
         2.0424e-02,  1.8374e-01, -3.7495e-01, -7.9368e-01, -7.8656e-02,
        -3.2709e-02, -1.5525e-02,  1.6715e-01,  6.8