### Setup & Imports

In [2]:
%pip install Pillow sentence-transformers torch milvus pandas imageio pymilvus onnxruntime ftfy openai-clip matplotlib

Note: you may need to restart the kernel to use updated packages.


In [3]:
%run cliponnx/models.py
%run cliponnx/simple_tokenizer.py

In [4]:
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from PIL import Image
import torch
import glob
import regex as rg
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from cliponnx.models import TextualModel, VisualModel
import math
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import precision_score

from helpers import *

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
torch.set_num_threads(10)

In [6]:
ground_truth_de, ground_truth_en = create_ground_truth_dicts()
# print(ground_truth_de)
# print(ground_truth_en)

### Model 1: CLIP+Vision Transformers (ViT) model with multilingual training

In [7]:
img_model = SentenceTransformer('clip-ViT-B-32').to(device)
text_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1').to(device)
emb_dims=512
cliparts = None
cliparts_txt = None

In [8]:
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)
connections.connect("default", host="localhost", port="19530")

In [9]:
def drop_collections():
    utility.drop_collection("cliparts")
    utility.drop_collection("cliparts_txt")

In [10]:
def setup_collections():
    global cliparts
    global cliparts_txt
    fields = [
        FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="img_embeddings", dtype=DataType.FLOAT_VECTOR, dim=emb_dims),
        FieldSchema(name="text", dtype=DataType.VARCHAR, dim=emb_dims, max_length=1024),
        FieldSchema(name="path", dtype=DataType.VARCHAR, dim=emb_dims, max_length=1024),
        FieldSchema(name='txt_embeddings', dtype=DataType.ARRAY, element_type=DataType.FLOAT, max_capacity=emb_dims)
    ]
    schema = CollectionSchema(fields, "Clipart collection")
    cliparts = Collection("cliparts", schema,
            enable_dynamic_field=True)
    fields = [
        FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="txt_embeddings", dtype=DataType.FLOAT_VECTOR, dim=emb_dims),
        FieldSchema(name="text", dtype=DataType.VARCHAR, dim=emb_dims, max_length=1024),
        FieldSchema(name="path", dtype=DataType.VARCHAR, dim=emb_dims, max_length=1024),
        FieldSchema(name='img_embeddings', dtype=DataType.ARRAY, element_type=DataType.FLOAT, max_capacity=emb_dims)
    ]
    schema = CollectionSchema(fields, "Clipart collection (text vectors)")
    cliparts_txt = Collection("cliparts_txt", schema,
            enable_dynamic_field=True)

In [11]:
def create_indices():
    global cliparts
    global cliparts_txt
    
    cliparts.create_index(
        field_name="img_embeddings",
        index_params={
            "metric_type": "COSINE",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024},
        },
    )
    cliparts_txt.create_index(
        field_name="txt_embeddings",
        index_params={
            "metric_type": "COSINE",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024},
        },
    )

In [18]:
img_paths=glob.glob("data/Cliparts/01_Kate Hadfield/**/*.png", recursive=True)
imgs=pd.DataFrame({'path':img_paths})
imgs

Unnamed: 0,path
0,data/Cliparts/01_Kate Hadfield/logo-cropped.png
1,data/Cliparts/01_Kate Hadfield/EggHunting/khad...
2,data/Cliparts/01_Kate Hadfield/EggHunting/khad...
3,data/Cliparts/01_Kate Hadfield/EggHunting/khad...
4,data/Cliparts/01_Kate Hadfield/EggHunting/khad...
...,...
8212,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...
8213,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...
8214,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...
8215,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...


In [19]:
imgs['text']=imgs['path'].apply(lambda p:" ".join(rg.split(r'[\\|/_\\.-]', p)).strip())
imgs['img_embeddings']=[[]]*imgs.shape[0]
imgs['txt_embeddings']=[[]]*imgs.shape[0]
imgs

Unnamed: 0,path,text,img_embeddings,txt_embeddings
0,data/Cliparts/01_Kate Hadfield/logo-cropped.png,data Cliparts 01 Kate Hadfield logo cropped png,[],[]
1,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,[],[]
2,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,[],[]
3,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,[],[]
4,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,[],[]
...,...,...,...,...
8212,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,[],[]
8213,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,[],[]
8214,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,[],[]
8215,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,[],[]


In [26]:
imgs=imgs[imgs['path'].apply(lambda p:'stamp' not in p.lower())]
imgs

Unnamed: 0,path,text,img_embeddings,txt_embeddings
0,data/Cliparts/01_Kate Hadfield/logo-cropped.png,data Cliparts 01 Kate Hadfield logo cropped png,"[-0.1224354059, 0.3337236047, -0.4869712591, 0...","[0.11748978500000001, -0.033165860900000003, 0..."
1,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,"[0.061049770600000004, -0.1090632156, -0.38299...","[-0.1486460865, -0.0283475928, 0.0670743361, -..."
2,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,"[-0.0441452377, 0.0552457534, -0.2363186777, 0...","[-0.1432264298, -0.0117266253, 0.0473660268, -..."
3,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,"[0.3908630013, 0.1392904371, 0.0251698196, 0.2...","[-0.1174651235, 0.0058466634, 0.0603898764, -0..."
4,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,"[0.1650562286, -0.14565433560000002, 0.0494911...","[-0.1221045926, 0.0345767699, 0.0234132409, -0..."
...,...,...,...,...
8212,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,"[0.1290355921, -0.1939580441, -0.395267278, 0....","[0.1815917343, 0.0183345601, 0.0043674563, -0...."
8213,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,"[0.1826949716, 0.22084504370000002, -0.2322636...","[0.13958533110000002, 0.0363522507, 0.03278619..."
8214,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,"[0.1671190262, 0.2088604718, -0.2348356694, 0....","[0.1449439079, 0.0388322063, 0.0345865823, -0...."
8215,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,"[0.1290355921, -0.1939580441, -0.395267278, 0....","[0.1923192739, 0.0340455398, 0.0081001017, -0...."


In [27]:
n_batch = 256
g = imgs.index % n_batch
for idx,chunk in tqdm(imgs.groupby(g), 'encoding imgs'):    
    img_dat=[openImg(img) for img in chunk['path']]
    img_dict={i:d for i,d in enumerate(img_dat) if d is not None}
    # img_dat=torch.Tensor(img_dat).to(device)
    encodings=img_model.encode(list(img_dict.values()))
    results=[encodings[list(img_dict.keys()).index(i),...] if i in img_dict.keys() else np.zeros(encodings.shape[-1]) for i in range(len(img_dat))]
    result_series=pd.Series(results)
    result_series.index=chunk.index
    imgs.loc[chunk.index,'img_embeddings']=result_series    

encoding imgs:   0%|          | 0/256 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  imgs.loc[chunk.index,'img_embeddings']=result_series


cannot identify image file 'data/Cliparts/01_Kate Hadfield/WeekendAtHome/khadfield_WeekendAtHome_birdhouse.png'
cannot identify image file 'data/Cliparts/01_Kate Hadfield/AVeryGermanChristmas/khadfield_AVeryGermanChristmas_adventcalendar2.png'
cannot identify image file 'data/Cliparts/01_Kate Hadfield/MyKindaPet2/khadfield_MyKindaPet2_spider.png'
cannot identify image file 'data/Cliparts/01_Kate Hadfield/FruitStand/khadfield_FruitStand_lemonslice2.png'
cannot identify image file 'data/Cliparts/01_Kate Hadfield/SolarSystem/khadfield_SolarSystem_moon3.png'


In [28]:
for idx,chunk in tqdm(imgs.groupby(g), 'encoding texts'):
    img_texts=list(chunk['text'])
    encodings=text_model.encode(img_texts)
    results=[encodings[i,...] for i in range(len(img_texts))]
    result_series=pd.Series(results)
    result_series.index=chunk.index
    imgs.loc[chunk.index,'txt_embeddings']=result_series

encoding texts:   0%|          | 0/256 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  imgs.loc[chunk.index,'txt_embeddings']=result_series


In [29]:
imgs.to_json('data/img_embedded_model_visual.json',orient='records')

In [30]:
imgs=pd.read_json('data/img_embedded_model_visual.json',orient='records')
imgs

Unnamed: 0,path,text,img_embeddings,txt_embeddings
0,data/Cliparts/01_Kate Hadfield/logo-cropped.png,data Cliparts 01 Kate Hadfield logo cropped png,"[-0.1224354059, 0.3337236047, -0.4869712591, 0...","[0.1174896359, -0.0331657864, 0.10817032310000..."
1,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,"[0.061049770600000004, -0.1090632156, -0.38299...","[-0.14864604180000002, -0.0283475779, 0.067074..."
2,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,"[-0.0441452377, 0.0552457534, -0.2363186777, 0...","[-0.1432264447, -0.0117265927, 0.0473659486, -..."
3,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,"[0.3908630013, 0.1392904371, 0.0251698196, 0.2...","[-0.1174652576, 0.0058467556000000006, 0.06038..."
4,data/Cliparts/01_Kate Hadfield/EggHunting/khad...,data Cliparts 01 Kate Hadfield EggHunting khad...,"[0.1650562286, -0.14565433560000002, 0.0494911...","[-0.1221045852, 0.034576833200000004, 0.023413..."
...,...,...,...,...
3418,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,"[0.1290355176, -0.1939579695, -0.3952674568, 0...","[0.1815917343, 0.0183345601, 0.0043674563, -0...."
3419,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,"[0.1826951355, 0.2208450288, -0.23226381840000...","[0.13958533110000002, 0.0363522507, 0.03278619..."
3420,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,"[0.1671190709, 0.2088604271, -0.2348358035, 0....","[0.1449439079, 0.0388320014, 0.0345865935, -0...."
3421,data/Cliparts/01_Kate Hadfield/DoctorsAndNurse...,data Cliparts 01 Kate Hadfield DoctorsAndNurse...,"[0.1290355176, -0.1939579695, -0.3952674568, 0...","[0.1923192739, 0.0340455398, 0.0081001017, -0...."


In [31]:
def insert_into_db():
    global cliparts
    global cliparts_txt
    
    for i, rec in tqdm(imgs.iterrows(), total=imgs.shape[0]):
        d=rec.to_dict()
        if d['img_embeddings'] is None:
            print(d['path'], 'skipped!')
            continue
        cliparts.insert(d)
        cliparts_txt.insert(d)

    cliparts.load()
    cliparts_txt.load()

In [32]:
drop_collections()
setup_collections()
create_indices()
insert_into_db()

  0%|          | 0/3423 [00:00<?, ?it/s]