In [1]:
import torch
import torchvision.transforms as transforms
import torchvision
import os
import glob
import imageio
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

In [2]:
transformers = transforms.Compose([transforms.ToTensor(),
                                  transforms.Resize((112,112))])
dataset = torchvision.datasets.ImageFolder("/home/fred/datasets/caltech101/101_ObjectCategories/", 
                                                              transform=transformers)
len(dataset)

8733

In [3]:
# folders = glob.glob("/home/fred/datasets/caltech101/101_ObjectCategories/*")
# for folder in folders:
#     img_paths = glob.glob(folder+'/*.jpg')
#     for im_path in img_paths:
#         img = imageio.imread(im_path)
#         if len(img.shape) != 3:
#             os.remove(im_path)

In [4]:
loader = torch.utils.data.DataLoader(dataset, batch_size=128)

In [5]:
imgs, labels = next(iter(loader))
imgs.shape

torch.Size([128, 3, 112, 112])

In [6]:
model = torchvision.models.resnet50(pretrained=True)
model = torch.nn.Sequential(*list(model.children())[:-1])

In [7]:
out = model(imgs)
out.shape

torch.Size([128, 2048, 1, 1])

In [8]:
features = []
model = model.cuda(0)
for imgs, labels in tqdm(loader):
    imgs = imgs.cuda(0)
    feat = model(imgs)
    features.append(feat.detach().cpu())
    del imgs

features = torch.cat(features).squeeze()

  0%|          | 0/69 [00:00<?, ?it/s]

In [16]:
# path = "/home/fred/datasets/caltech101/feature_tensor.pt"
# with open(path, 'wb') as f:
#     torch.save(features, f)

In [11]:
class LSH:
    
    def __init__(self, hash_dim):
        self.hash_dim = hash_dim
        self.batch = 128
        self.hash_dict = defaultdict(list)
        self.device = torch.device("cuda:0") \
                if torch.cuda.is_available() else torch.device('cpu')
    
    def _toTuple(self, tensor):
        return tuple(tensor.tolist())
    
    @torch.no_grad()
    def _getHashes(self, features, hyperplanes):
        
        B = features.shape[0]
        hash_ls = []
        for i in trange(0, B, self.batch):
            feature_chunk  = features[i:i+self.batch, :].to(self.device)
            hashes = feature_chunk @ hyperplanes
            # (batch, hash dim)
            hashes[hashes < 0] = 0
            hashes[hashes > 0] = 1
            hash_ls.append(hashes.cpu())
        
        hash_table = torch.cat(hash_ls).to(torch.int8)
        assert hash_table.shape == (B, self.hash_dim),\
                            f"hash table has wrong shape {hash_table.shape}"
        return hash_table
    
    def build(self, features):
        
        feat_dim = features.shape[1]
    
        self.hyperplanes = torch.randn((feat_dim, self.hash_dim),device=self.device)
        
        hash_table = self._getHashes(features, self.hyperplanes)
        
        
        B = hash_table.shape[0]
        for i in range(B):
            tupleHash = self._toTuple(hash_table[i,:])
            self.hash_dict[tupleHash].append(features[i,:])
        
        return self.hash_dict
    
    def query(self, q_feats):
        
        assert q_feats.shape[1] == self.hyperplanes.shape[0],\
        f"query features and hyperplane shape miss match {q_feats.shape[1], self.hyperplanes.shape[0]}"
        
        q_hashes= self._getHashes(q_feats, self.hyperplanes)
        
        out_feats = []
        for q_hash in q_hashes:
            tupleHash = self._toTuple(q_hash)
            out_feats.append(self.hash_dict[tupleHash])
        
        return out_feats
        
            
        

In [12]:
lsh = LSH(hash_dim=8)

hash_dict = lsh.build(features)

  0%|          | 0/69 [00:00<?, ?it/s]

In [13]:
len(hash_dict)

196

In [14]:
features.shape

torch.Size([8733, 2048])