In [41]:
import imghdr # builtin
from pathlib import Path
import sqlite3

import faiss
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
#from torchvision.datasets import ImageFolder
from torchvision.io import read_image
from transformers import CLIPProcessor, CLIPModel
from torchvision.transforms import ToTensor
from torchvision import transforms

In [115]:
# push the default CLIP transforms upstream to the data loader, 
# should permit batching iamges since we'll have a resize operation
# https://huggingface.co/transformers/model_doc/clip.html#transformers.CLIPFeatureExtractor
clip_transforms = transforms.Compose([
    #transforms.ToPILImage(),
    transforms.Resize(size=224),
    transforms.CenterCrop(224),
    #transforms.ToPILImage(),
    transforms.ToTensor(),  # PIL -> ToTensor :: [0,255] -> [0.,1.]
    transforms.Normalize([0.48145466, 0.4578275, 0.40821073], 
                         [0.26862954, 0.26130258, 0.27577711]),
    
])

In [111]:
clip_transforms(im)


tensor([[[-0.0113, -0.0988, -0.1134,  ...,  0.5873,  0.6749,  0.8209],
         [-0.0259, -0.0405, -0.0696,  ...,  0.5143,  0.5727,  0.6165],
         [-0.0259, -0.0259, -0.0988,  ...,  0.9084,  0.7187,  0.6311],
         ...,
         [ 0.3099, -0.0113,  0.1785,  ..., -0.0988, -0.1280,  0.0617],
         [ 0.3975,  0.5289,  0.2077,  ..., -0.1280,  0.1493, -0.3470],
         [-0.2302,  0.0179,  0.3683,  ...,  0.5435,  0.7041,  0.1931]],

        [[ 0.6191,  0.5141,  0.5441,  ...,  0.8893,  1.0093,  1.1744],
         [ 0.6041,  0.5891,  0.5891,  ...,  0.8292,  0.9043,  0.9343],
         [ 0.6041,  0.6041,  0.5591,  ...,  1.2344,  1.0243,  0.9643],
         ...,
         [-0.2363, -0.5665, -0.3714,  ..., -0.6115, -0.6565, -0.4614],
         [-0.1613, -0.0262, -0.3414,  ..., -0.6415, -0.3714, -0.8816],
         [-0.7316, -0.5515, -0.2063,  ...,  0.0488,  0.2139, -0.3114]],

        [[ 1.6482,  1.5771,  1.5913,  ...,  1.5202,  1.6055,  1.7620],
         [ 1.6340,  1.6198,  1.6340,  ...,  1

In [72]:
import torchvision
dir(torchvision.transforms.InterpolationMode)

['BICUBIC',
 'BILINEAR',
 'BOX',
 'HAMMING',
 'LANCZOS',
 'NEAREST',
 '__class__',
 '__doc__',
 '__members__',
 '__module__']

In [2]:
class CLIP(torch.nn.Module):
    def __init__(self, model_string="openai/clip-vit-base-patch32"):
        super().__init__()
        self._model_string = model_string
        self.model = CLIPModel.from_pretrained(model_string)
        self.processor = CLIPProcessor.from_pretrained(model_string)
        self.model.eval()
    def project_images(self, images, normalize=True):
        imgs = self.processor(images=images, return_tensors="pt", padding=True)
        feats = self.model.get_image_features(**imgs)
        if normalize:
            feats = self.normalize(feats)
        return feats
    def project_texts(self, texts, normalize=True):
        txts = self.processor(text=texts, return_tensors="pt", padding=True)
        feats = self.model.get_text_features(**txts)
        if normalize:
            feats = self.normalize(feats)
        return feats
    def normalize(self, x):
        return x / x.norm(dim=-1, keepdim=True)

In [116]:
# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files
# NB: might be able to use lightning datamodule to parallelize data processing
class RecusiveImagesPath(Dataset):
    def __init__(self, root='sample_data/images', transforms=clip_transforms):
        self.root = root
        self.get_image_paths()
        self.transforms = transforms
    def get_image_paths(self):
        self.img_paths = []
        for path_obj in Path(self.root).glob('*'):
            if imghdr.what(path_obj) is not None:
                self.img_paths.append(path_obj)
    def __len__(self):
        return len(self.img_paths)
    def __getitem__(self, idx):
        path = str(self.img_paths[idx])
        im = Image.open(path)
        #return Image.open(path), path
        #try:
        #    im = read_image(path)
        #except RuntimeError:
        #    #tensorize = ToTensor()
        #    #im_pil = Image.open(path)
        #    #im = tensorize(im_pil)
        #    im = torch.zeros(3,10,10)
        if self.transforms is not None:
            im = self.transforms(im)
        return im, path
        
dataset = RecusiveImagesPath()
[(im,p) for im,p in dataset]
im = dataset[0][0]
im

tensor([[[-0.0259, -0.0696, -0.0988,  ...,  0.5873,  0.6749,  0.7771],
         [-0.0259, -0.0550, -0.0842,  ...,  0.5727,  0.6165,  0.6457],
         [-0.0550, -0.0405, -0.0696,  ...,  0.8209,  0.6895,  0.5873],
         ...,
         [ 0.3391,  0.3391,  0.3829,  ..., -0.0259,  0.0033, -0.0696],
         [ 0.2661,  0.2953,  0.3245,  ...,  0.0763,  0.0617, -0.0405],
         [-0.0550,  0.1493,  0.3537,  ...,  0.4121,  0.3391,  0.2077]],

        [[ 0.6041,  0.5591,  0.5591,  ...,  0.9043,  1.0093,  1.1294],
         [ 0.6041,  0.5891,  0.5741,  ...,  0.8893,  0.9493,  0.9793],
         [ 0.5741,  0.5891,  0.5741,  ...,  1.1444,  1.0093,  0.9193],
         ...,
         [-0.2063, -0.2213, -0.1613,  ..., -0.5515, -0.5215, -0.5965],
         [-0.2663, -0.2513, -0.2363,  ..., -0.4464, -0.4614, -0.5665],
         [-0.5515, -0.3864, -0.2063,  ..., -0.0862, -0.1613, -0.2963]],

        [[ 1.6340,  1.6055,  1.6055,  ...,  1.5344,  1.6055,  1.7193],
         [ 1.6340,  1.6198,  1.6198,  ...,  1

In [4]:
clip = CLIP()

In [13]:
index_fname = 'sample_index.faissindex'
index=faiss.read_index(index_fname)

In [14]:
index.ntotal

8

In [61]:
images = [im for im,p in dataset]
im_embed = clip.project_images(images)
im_embed.shape, im_embed.requires_grad

(torch.Size([8, 512]), True)

In [35]:
texts = ["my baloney has a first name", "tropical beach", "beach", "sand beach with palm trees on a nice day"]
#txts = clip.processor(text=texts, return_tensors="pt", padding=True)
txt_embed = clip.project_texts(texts)
txt_embed.shape, txt_embed.norm(dim=-1)

(torch.Size([4, 512]),
 tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<CopyBackwards>))

In [62]:
sims = txt_embed @ im_embed.T
#torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logit_scale = clip.model.logit_scale.exp()
sims, sims * logit_scale

(tensor([[0.1683, 0.1853, 0.1770, 0.1981, 0.2090, 0.1492, 0.1972, 0.1972],
         [0.2623, 0.2536, 0.1461, 0.1679, 0.1688, 0.1912, 0.1985, 0.1701],
         [0.2609, 0.2563, 0.1553, 0.1691, 0.1775, 0.1463, 0.1899, 0.1892],
         [0.2747, 0.2699, 0.1176, 0.1454, 0.1415, 0.1525, 0.1747, 0.1144]],
        grad_fn=<MmBackward>),
 tensor([[16.8332, 18.5312, 17.7043, 19.8129, 20.9047, 14.9190, 19.7199, 19.7210],
         [26.2350, 25.3636, 14.6130, 16.7930, 16.8836, 19.1164, 19.8490, 17.0067],
         [26.0949, 25.6321, 15.5259, 16.9058, 17.7491, 14.6251, 18.9940, 18.9170],
         [27.4703, 26.9878, 11.7571, 14.5414, 14.1530, 15.2501, 17.4696, 11.4387]],
        grad_fn=<MulBackward0>))

In [26]:
clip.model.logit_scale.exp() # uh.. ok

tensor(100.0000, grad_fn=<ExpBackward>)

In [63]:
#F.softmax(sims, dim=0)
im_embed.shape

torch.Size([8, 512])

In [70]:
im_embed.norm(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<CopyBackwards>)

In [69]:
# https://github.com/facebookresearch/faiss/wiki/Getting-started
n_feats = 512 # = im_embed.shape
index = faiss.IndexFlatIP(n_feats)

In [71]:
x = im_embed.detach()
index.add(x.numpy())

In [66]:
index.is_trained, index.ntotal

(True, 8)

In [72]:
k = 3
x = txt_embed.detach().numpy()
D, I = index.search(x, k)
D,I

(array([[0.20904651, 0.19812861, 0.19721013],
        [0.26234967, 0.253636  , 0.1984902 ],
        [0.2609494 , 0.25632125, 0.1899397 ],
        [0.27470297, 0.26987785, 0.17469627]], dtype=float32),
 array([[4, 3, 7],
        [0, 1, 6],
        [0, 1, 6],
        [0, 1, 6]], dtype=int64))

In [73]:
faiss.write_index(index, 'sample_index.faissindex')

In [68]:
[
    "my baloney has a first name", 
    "tropical beach", 
    "beach", 
    "sand beach with palm trees on a nice day"]
for i, p in enumerate(dataset.img_paths):
    print(i,p)

0 sample_data\images\beach1.jpg
1 sample_data\images\beach2.jpg
2 sample_data\images\book1.jpg
3 sample_data\images\dog1.png
4 sample_data\images\dog2.png
5 sample_data\images\landscape1.jpg
6 sample_data\images\portrait1.jpg
7 sample_data\images\portrait2.jpg


In [65]:
%%time
# Let's try loading a bigger dataset
# This folder has about 5k images
fpath = r"C:\Users\shagg\Pictures\phone backup"
dataset = RecusiveImagesPath(root=fpath) 
# takes a lot of time just to build this object...
# I'm guessing the imghdr call slows things down a bunch
# trivially fast on reload, so at least we're caching
# Nope, faster after rebuilding the class too. weird.

Wall time: 369 ms


In [67]:
import gc
gc.collect()

77

In [126]:
%%time

# https://stackoverflow.com/questions/42462431/oserror-broken-data-stream-when-reading-image-file/47958486
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# How long does it take to iterate through this? How does batch size impact this?
# What is the limit on batch size for a forward pass through the model?
# Can I calibrate this somehow? What about recording paths of images that we've processed already?

fpath = r"C:\Users\shagg\Pictures\phone backup"
dataset = RecusiveImagesPath(root=fpath)

#img_loader = DataLoader(dataset, batch_size=64, shuffle=False) # 9min 2s
#img_loader = DataLoader(dataset, batch_size=1, shuffle=False) # 12min43s w/o transforms
#img_loader = DataLoader(dataset, batch_size=2, shuffle=False)
#img_loader = DataLoader(dataset, batch_size=4, shuffle=False) # 8m w transforms (necessary for elevated batch size)
#img_loader = DataLoader(dataset, batch_size=8, shuffle=False) # 8min 28s
img_loader = DataLoader(dataset, batch_size=16, shuffle=False) # 8min 32s
# trying again with larger batch after adding transforms

for i, _ in enumerate(img_loader):
    continue
# yiiikes... 12min43s just to iterate through the files and load them into tensors... yeesh.
# seriously doubt I'm multiprocessing here. Possibly low batch size kills multiproc as option?


# TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; 
# found <class 'PIL.JpegImagePlugin.JpegImageFile'>
#
# fixed above error with torchvision.io.read_image
# Now we've got an uncollated shape issue
#   RuntimeError: stack expects each tensor to be equal size, but got [3, 3024, 4032] at entry 0 and [3, 2268, 4032] at entry 16
# could add a collate fn per: https://discuss.pytorch.org/t/dataloader-gives-stack-expects-each-tensor-to-be-equal-size-due-to-different-image-has-different-objects-number/91941
# ... but then the padding would probably fuck up the CLIP featurization
# could go the other way with a crop, but who even knows what I'd be cropping. How even does the model address the input resolution?
# whatever, let's see how long it takes with a batch_size of 1
#
# RuntimeError: Unsupported marker type 0xa3
# 'C:\\Users\\shagg\\Pictures\\phone backup\\20180913_145729.jpg' - panoramic, 35mb. 
#
# crushing error by returning tensor of zeros

Wall time: 8min 32s


In [118]:
# OSError: broken data stream when reading image file
i, _[1] # the fuck? Why is path 4 files? oh right, batched

(46,
 ('C:\\Users\\shagg\\Pictures\\phone backup\\20180913_144244.jpg',
  'C:\\Users\\shagg\\Pictures\\phone backup\\20180913_144258.jpg',
  'C:\\Users\\shagg\\Pictures\\phone backup\\20180913_144303.jpg',
  'C:\\Users\\shagg\\Pictures\\phone backup\\20180913_144648.jpg'))

In [122]:
# OSError: broken data stream when reading image file
i, _[1] # 
# looks like the issue is with 'C:\\Users\\shagg\\Pictures\\phone backup\\20180913_145648.jpg' 

(94,
 ('C:\\Users\\shagg\\Pictures\\phone backup\\20180913_145647.jpg',
  'C:\\Users\\shagg\\Pictures\\phone backup\\20180913_145648.jpg'))

In [58]:
im_test = Image.open('C:\\Users\\shagg\\Pictures\\phone backup\\20180913_145729.jpg')
#im_test # PIL at least can read the last frame of the panorama.
ToTensor()(im_test)

tensor([[[0.5255, 0.5216, 0.5255,  ..., 0.1961, 0.2078, 0.1333],
         [0.5176, 0.5176, 0.5255,  ..., 0.2078, 0.1961, 0.1333],
         [0.5216, 0.5255, 0.5333,  ..., 0.1686, 0.1529, 0.1294],
         ...,
         [0.9804, 0.9765, 0.9765,  ..., 0.0275, 0.0392, 0.0392],
         [0.9765, 0.9765, 0.9765,  ..., 0.0235, 0.0353, 0.0314],
         [0.9765, 0.9725, 0.9765,  ..., 0.0157, 0.0314, 0.0353]],

        [[0.6196, 0.6157, 0.6118,  ..., 0.1569, 0.1686, 0.0941],
         [0.6118, 0.6118, 0.6118,  ..., 0.1686, 0.1569, 0.0941],
         [0.6078, 0.6118, 0.6196,  ..., 0.1294, 0.1137, 0.0902],
         ...,
         [0.9765, 0.9725, 0.9725,  ..., 0.0235, 0.0353, 0.0353],
         [0.9725, 0.9725, 0.9725,  ..., 0.0196, 0.0314, 0.0275],
         [0.9725, 0.9686, 0.9725,  ..., 0.0118, 0.0275, 0.0314]],

        [[0.7608, 0.7569, 0.7569,  ..., 0.1608, 0.1725, 0.0980],
         [0.7529, 0.7529, 0.7569,  ..., 0.1725, 0.1608, 0.0980],
         [0.7529, 0.7569, 0.7647,  ..., 0.1333, 0.1176, 0.

In [57]:
ToTensor()(im_test)

tensor([[[0.5255, 0.5216, 0.5255,  ..., 0.1961, 0.2078, 0.1333],
         [0.5176, 0.5176, 0.5255,  ..., 0.2078, 0.1961, 0.1333],
         [0.5216, 0.5255, 0.5333,  ..., 0.1686, 0.1529, 0.1294],
         ...,
         [0.9804, 0.9765, 0.9765,  ..., 0.0275, 0.0392, 0.0392],
         [0.9765, 0.9765, 0.9765,  ..., 0.0235, 0.0353, 0.0314],
         [0.9765, 0.9725, 0.9765,  ..., 0.0157, 0.0314, 0.0353]],

        [[0.6196, 0.6157, 0.6118,  ..., 0.1569, 0.1686, 0.0941],
         [0.6118, 0.6118, 0.6118,  ..., 0.1686, 0.1569, 0.0941],
         [0.6078, 0.6118, 0.6196,  ..., 0.1294, 0.1137, 0.0902],
         ...,
         [0.9765, 0.9725, 0.9725,  ..., 0.0235, 0.0353, 0.0353],
         [0.9725, 0.9725, 0.9725,  ..., 0.0196, 0.0314, 0.0275],
         [0.9725, 0.9686, 0.9725,  ..., 0.0118, 0.0275, 0.0314]],

        [[0.7608, 0.7569, 0.7569,  ..., 0.1608, 0.1725, 0.0980],
         [0.7529, 0.7529, 0.7569,  ..., 0.1725, 0.1608, 0.0980],
         [0.7529, 0.7569, 0.7647,  ..., 0.1333, 0.1176, 0.