In [1]:
import torch
import torchvision.transforms as transforms
import torchvision
import os
import glob
import imageio
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

In [2]:
# Preprocessing applied to every sample: convert to a tensor, then resize to 112x112.
transformers = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((112, 112)),
])

# One sub-directory of the root per class; every image goes through `transformers`.
dataset = torchvision.datasets.ImageFolder(
    root="/home/fred/datasets/caltech101/101_ObjectCategories/",
    transform=transformers,
)
len(dataset)

8733

In [3]:
# Element 0 of each `dataset.imgs` entry is the file path; show the one at index 9.
sample_path = dataset.imgs[9][0]
sample_path

'/home/fred/datasets/caltech101/101_ObjectCategories/BACKGROUND_Google/image_0011.jpg'

In [None]:
# Read sample 9 straight from disk and push it through the same preprocessing
# pipeline as the dataset, to check the resulting tensor shape.
img = transformers(imageio.imread(dataset.imgs[9][0]))
img.shape

In [None]:
'''
Remove black and white images in caltech101
'''
# WARNING: destructive one-off clean-up, intentionally left commented out.
# When enabled it walks every class folder and deletes from disk each .jpg whose
# decoded array is not 3-dimensional (i.e. has no channel axis).
# folders = glob.glob("/home/fred/datasets/caltech101/101_ObjectCategories/*")
# for folder in folders:
#     img_paths = glob.glob(folder+'/*.jpg')
#     for im_path in img_paths:
#         img = imageio.imread(im_path)
#         if len(img.shape) != 3:
#             os.remove(im_path)

In [None]:
# Sequential (unshuffled) batches of 128, so batch order follows `dataset.imgs`.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=False,
)

In [None]:
# Pull only the first batch out of the loader to inspect its shape.
batch_iter = iter(loader)
imgs, labels = next(batch_iter)
imgs.shape

In [None]:
# Reorder the three axes of sample 9 as (1, 2, 0) before plotting
# (presumably channels-first -> channels-last, which is what imshow expects).
sample_hwc = imgs[9].permute(1, 2, 0)
plt.imshow(sample_hwc)


In [None]:
# Pretrained ResNet-50 used as a fixed feature extractor: keep every child module
# except the last one (the final fully-connected layer in torchvision's ResNet).
model = torchvision.models.resnet50(pretrained=True)
# eval() is required for inference: in training mode BatchNorm normalises with the
# statistics of the current batch (so an image's features depend on its batch
# neighbours) and silently updates its running statistics on every forward pass.
model = torch.nn.Sequential(*list(model.children())[:-1]).eval()

In [None]:
# Shape check only: run the forward pass without autograd so no computation graph
# for the whole batch is kept alive through `out`.
with torch.no_grad():
    out = model(imgs)
out.shape

In [None]:
# Extract one feature vector per image for the whole dataset.
features = []
# Fall back to the CPU when no GPU is present instead of failing on .cuda(0).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# eval(): BatchNorm must use its stored statistics, otherwise the features of an
# image depend on which other images share its batch.
model = model.to(device).eval()

# no_grad(): inference only, so skip building the autograd graph.
with torch.no_grad():
    # Dedicated loop names: the original reused (and `del`-eted) the global `imgs`,
    # which broke earlier cells on re-execution.
    for batch_imgs, _ in tqdm(loader):
        batch_feat = model(batch_imgs.to(device))
        features.append(batch_feat.cpu())

# Collapse everything after the batch axis; unlike squeeze() this keeps the batch
# axis even if only a single image was processed.
features = torch.cat(features).flatten(start_dim=1)

In [None]:
# path = "/home/fred/datasets/caltech101/feature_tensor.pt"
# with open(path, 'wb') as f:
#     torch.save(features, f)

In [None]:
# Reload the feature tensor from disk instead of recomputing it.
# NOTE(review): torch.load unpickles the file — only load files you produced yourself.
path = "/home/fred/datasets/caltech101/feature_tensor.pt"
with open(path, 'rb') as feature_file:
    features = torch.load(feature_file)
features.shape

In [None]:
# Collect the path of every file inside every class folder.
# glob returns entries in arbitrary, filesystem-dependent order, so sort at both
# levels to make the list deterministic.
# NOTE(review): ImageFolder also sorts its samples, but it filters by image
# extension — confirm the two orders agree before indexing `features` with this list.
folders = sorted(glob.glob("/home/fred/datasets/caltech101/101_ObjectCategories/*"))
# Comprehension variables stay local, so the global `path` used by the
# feature-loading cell is no longer overwritten.
img_paths = [
    img_path
    for folder in folders
    for img_path in sorted(glob.glob(folder + "/*"))
]

# Show a summary rather than dumping thousands of paths into the notebook output.
len(img_paths), img_paths[:5]

In [7]:
class LSH:
    """Random-hyperplane locality-sensitive hash index over ResNet-50 image features.

    Each image is embedded with a pretrained ResNet-50 whose last child module has
    been removed, projected onto ``hash_dim`` random hyperplanes and reduced to one
    sign bit per hyperplane. Images whose bit codes are identical share a bucket in
    ``hash_dict``.

    Args:
        hash_dim: Number of random hyperplanes, i.e. bits per hash code.
        batch: Batch size used by ``build`` when embedding a dataset.
    """

    def __init__(self, hash_dim, batch=128):
        self.hash_dim = hash_dim
        self.batch = batch
        # Maps a hash code (tuple of 0/1 ints) to the list of image paths in that bucket.
        self.hash_dict = defaultdict(list)

        self.device = torch.device("cuda:0") \
            if torch.cuda.is_available() else torch.device("cpu")
        self.transformers = transforms.Compose([transforms.ToTensor(),
                                                transforms.Resize((112, 112))])

        model = torchvision.models.resnet50(pretrained=True)
        backbone = torch.nn.Sequential(*list(model.children())[:-1])
        # eval() is essential: in training mode BatchNorm normalises with the current
        # batch's statistics, so the same image would hash differently in a batch of
        # 128 (build) and in a batch of 1 (query).
        self.feature_extractor = backbone.to(self.device).eval()
        self.feat_dim = 2048

        # One column per hyperplane; the sign of the projection gives one hash bit.
        self.hyperplanes = torch.randn((self.feat_dim, self.hash_dim), device=self.device)

    def _imgToFeatures(self, img_tensors):
        """Embed a batch of image tensors and return one flat feature row per image.

        Runs without autograd. The batch axis is always kept, including for a batch
        of one (the previous ``squeeze_()`` dropped it).
        """
        with torch.no_grad():
            return self.feature_extractor(img_tensors).flatten(start_dim=1)

    def _toTuple(self, tensor):
        """Convert a 1-D tensor into a hashable tuple usable as a dictionary key."""
        return tuple(tensor.tolist())

    def _getHashes(self, features):
        """Return the sign-bit codes of ``features`` as an int8 CPU tensor.

        Args:
            features: Tensor of shape (batch, feat_dim), or a single 1-D feature
                vector, on the same device as ``self.hyperplanes``.

        Returns:
            Tensor of shape (batch, hash_dim): 1 where the projection onto a
            hyperplane is strictly positive, 0 otherwise. The input is not modified.
        """
        if features.dim() == 1:
            # Out-of-place so the caller's tensor keeps its shape.
            features = features.unsqueeze(0)

        projections = features @ self.hyperplanes  # (batch, hash dim)
        return (projections > 0).to(torch.int8).cpu()

    def _loadImage(self, path):
        """Read one image file and return a preprocessed 3-channel tensor.

        imageio yields 2-D arrays for grayscale files and 4-channel arrays for RGBA
        files, neither of which the 3-channel network accepts, so grayscale is
        replicated across three channels and extra channels are dropped.
        """
        img = self.transformers(imageio.imread(path))
        channels = img.shape[0]
        if channels < 3:
            img = img[:1].expand(3, -1, -1)
        elif channels > 3:
            img = img[:3]
        return img

    def build(self, root):
        """Hash every image under ``root`` and add its path to ``hash_dict``.

        Args:
            root: Directory laid out for ``torchvision.datasets.ImageFolder``.

        Returns:
            ``self.hash_dict``. Buckets are appended to, not reset, so calling
            ``build`` again adds to the existing table.
        """
        dataset = torchvision.datasets.ImageFolder(root, transform=self.transformers)
        # Default (unshuffled) loader: row i of the hashes must match dataset.imgs[i].
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch)

        all_hashes = []
        for img_batch, _ in tqdm(loader, desc="Hashing"):
            features = self._imgToFeatures(img_batch.to(self.device))
            all_hashes.append(self._getHashes(features))
        all_hashes = torch.cat(all_hashes, dim=0)

        for i in trange(all_hashes.shape[0], desc="Buidling table"):
            bucket_key = self._toTuple(all_hashes[i, :])
            self.hash_dict[bucket_key].append(dataset.imgs[i][0])

        return self.hash_dict

    def query(self, q_paths):
        """Return the indexed image paths that share a bucket with any query image.

        Args:
            q_paths: Iterable of image file paths; a single path is also accepted.

        Returns:
            List of matching paths. For one query this is exactly that query's
            bucket; for several queries the buckets are concatenated in query order
            with each bucket included once. Empty input gives an empty list.
        """
        if isinstance(q_paths, (str, os.PathLike)):
            q_paths = [q_paths]

        match_set = []
        seen_buckets = set()
        for q_path in q_paths:
            img = self._loadImage(q_path).unsqueeze(0).to(self.device)

            feat = self._imgToFeatures(img)
            # squeeze(0) removes only the batch axis, so hash_dim == 1 still works.
            img_hash = self._getHashes(feat).squeeze(0)

            assert len(img_hash.shape) == 1, f"hash has wrong shape {img_hash.shape}"

            bucket_key = self._toTuple(img_hash)
            if bucket_key in seen_buckets:
                continue
            seen_buckets.add(bucket_key)
            # .get(): indexing the defaultdict would insert an empty bucket on a miss.
            match_set.extend(self.hash_dict.get(bucket_key, ()))

        return match_set
            

In [None]:
root = "/home/fred/datasets/caltech101/101_ObjectCategories/"
# LSH draws its hyperplanes with torch.randn in __init__; seeding first makes the
# hyperplanes, and therefore the bucket layout, repeatable across kernel restarts.
torch.manual_seed(0)
lsh = LSH(hash_dim=8)

hash_dict = lsh.build(root)

Hashing:   0%|          | 0/69 [00:00<?, ?it/s]

In [None]:
# File path of the eighth sample from the end of the dataset.
query_path = dataset.imgs[-8][0]
query_path

In [6]:
# Look up the bucket of one indexed image; its own path should be among the matches.
query_paths = [dataset.imgs[-8][0]]
lsh.query(query_paths)

TypeError: _getHashes() takes 2 positional arguments but 3 were given