In [None]:
 # Record the Python interpreter version this notebook was executed with.
 import sys
 sys.version

'3.7.13 (default, Mar 16 2022, 17:37:17) \n[GCC 7.5.0]'

In [None]:
# Install transformers straight from the GitHub repository. No tag or commit is
# given, so whatever the default branch holds at install time is used.
# NOTE(review): consider pinning a release/commit so the imports below keep resolving.
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone


In [None]:
# annoy provides the AnnoyIndex imported below. timm is not imported directly in
# the cells shown -- presumably a backend dependency of one of the models (TODO confirm).
!pip install timm annoy



In [None]:
import pandas as pd
import numpy as np
import torchvision
import requests
import random
import torch
from torch import nn
from torchvision import transforms
from google.colab import drive
from transformers import (
    ViTForImageClassification, ViTModel, ViTFeatureExtractor,
    AutoModelForImageClassification, BeitFeatureExtractor, ViTMAEModel,
    DeiTFeatureExtractor, DeiTModel, DetrFeatureExtractor, DetrForSegmentation,
    AutoFeatureExtractor, ViTMAEForPreTraining,
    DeiTForImageClassificationWithTeacher, ImageGPTForCausalImageModeling,
    ImageGPTFeatureExtractor, ImageGPTForImageClassification,
)

from PIL import Image
from annoy import AnnoyIndex
from tqdm import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from collections import defaultdict

# Seed every RNG the notebook has imported, not only NumPy, so that any
# stochastic step in `random` or PyTorch is repeatable across runs as well.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
drive.mount('/gdrive')
!ls '/gdrive/My Drive/cse6242_project/Data'
!unzip -q '/gdrive/My Drive/cse6242_project/Data/celeba'

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).
accuracy_vs_num_images_18k.csv	celeba.zip  identity_CelebA.txt
replace img_align_celeba/img_align_celeba/000001.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged.

    Used as a drop-in replacement for a layer that should be disabled
    (here: a model's ``classifier`` attribute).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No parameters and no computation; hand the input straight back.
        return x

def init_model(identifier):
    """Load a pretrained model and its feature extractor for embedding extraction.

    Args:
        identifier: Hugging Face model id. Must be one of the ids handled in
            the branches below.

    Returns:
        Tuple ``(model, feature_extractor, dim)``: the model in eval mode on
        the GPU when available (else the CPU), the matching feature extractor,
        and the vector size associated with that model.

    Raises:
        ValueError: if ``identifier`` is not one of the supported model ids.
    """
    print(identifier)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if identifier == 'microsoft/beit-base-patch16-224-pt22k-ft22k':
        feature_extractor = BeitFeatureExtractor.from_pretrained(identifier)
        model = AutoModelForImageClassification.from_pretrained(identifier)
        dim = 768
    elif identifier == 'google/vit-base-patch16-224':
        feature_extractor = ViTFeatureExtractor.from_pretrained(identifier)
        model = ViTForImageClassification.from_pretrained(identifier)
        dim = 768
    elif identifier == 'facebook/vit-mae-base':
        feature_extractor = AutoFeatureExtractor.from_pretrained(identifier)
        model = ViTMAEForPreTraining.from_pretrained(identifier)
        dim = 196
    elif identifier == 'facebook/deit-base-distilled-patch16-224':
        feature_extractor = DeiTFeatureExtractor.from_pretrained(identifier)
        model = DeiTForImageClassificationWithTeacher.from_pretrained(identifier)
        dim = 1000
    elif identifier == 'facebook/detr-resnet-50-panoptic':
        feature_extractor = DetrFeatureExtractor.from_pretrained(identifier)
        model = DetrForSegmentation.from_pretrained(identifier)
        dim = 768
    elif identifier == 'openai/imagegpt-small':
        feature_extractor = ImageGPTFeatureExtractor.from_pretrained(identifier)
        model = ImageGPTForImageClassification.from_pretrained(identifier)
        dim = 2
    elif identifier == 'openai/imagegpt-large':
        feature_extractor = ImageGPTFeatureExtractor.from_pretrained(identifier)
        model = ImageGPTForCausalImageModeling.from_pretrained(identifier)
        dim = 2
    else:
        # Previously an unknown id fell through and failed further down with an
        # UnboundLocalError on `model`; fail early with a clear message instead.
        raise ValueError(f"Unsupported model identifier: {identifier!r}")

    # Swap the classification head for a pass-through module.
    # NOTE(review): this only has an effect on architectures whose head is stored
    # as `.classifier`; on others it presumably just attaches an unused attribute.
    # Confirm that `dim` matches the actual output size for each model.
    model.classifier = Identity()
    model.eval()
    model.to(device)

    return model, feature_extractor, dim

def prepare_dataset(n_identities=None,
                    identity_file="/gdrive/My Drive/cse6242_project/Data/identity_CelebA.txt"):
    """Load the file -> identity table and optionally subsample identities.

    Args:
        n_identities: number of distinct identities to keep, drawn at random
            from NumPy's global RNG. ``None`` keeps every identity. Values
            larger than the number of available identities are clamped.
        identity_file: path to a space-separated, header-less text file whose
            first column is the image file name and second column the identity.

    Returns:
        Tuple ``(df, identity_selection)``. ``df`` holds the rows belonging to
        the selected identities with a fresh 0..n-1 index (the previous index is
        kept as an ``index`` column because ``reset_index`` is called without
        ``drop``); ``identity_selection`` is the array of chosen identities.
    """
    print('loading data...')
    identities = pd.read_csv(identity_file, sep=" ", header=None)
    identities = identities.rename(columns={0: "file", 1: "identity"})

    unique_identities = identities.identity.unique()
    if n_identities is not None:
        # Sample WITHOUT replacement: the default (with replacement) can pick the
        # same identity more than once and so return fewer than n_identities.
        n_select = min(n_identities, len(unique_identities))
        identity_selection = np.random.choice(unique_identities, n_select, replace=False)
    else:
        identity_selection = unique_identities

    df = identities[identities.identity.isin(identity_selection)].reset_index()
    return df, identity_selection

def build_annoy_index(model, feature_extractor, df, dim=768, batch_size=25):
    """Embed every image listed in ``df`` and build a Euclidean Annoy index.

    Args:
        model: callable returning an object with a ``.logits`` tensor for a
            batch of pixel values.
        feature_extractor: callable turning a list of PIL images into a mapping
            with a ``'pixel_values'`` tensor (called with ``return_tensors="pt"``).
        df: DataFrame with at least ``file`` and ``identity`` columns; its index
            values become the Annoy item ids.
        dim: length of each embedding vector added to the index.
        batch_size: number of images embedded per forward pass.

    Returns:
        Tuple ``(index, idx_to_identity, identity_to_idx)``: the built Annoy
        index, a dict mapping item id -> row dict, and a defaultdict mapping
        identity -> list of item ids.
    """
    print('building annoy index...')
    idx_to_identity = df.to_dict('index')
    identity_to_idx = defaultdict(list)
    index = AnnoyIndex(dim, 'euclidean')
    index.on_disk_build("on_disk_index.ann")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    image_paths, keys = [], []
    for k, v in idx_to_identity.items():
        identity_to_idx[v["identity"]].append(k)
        image_paths.append("img_align_celeba/img_align_celeba/" + v["file"])
        keys.append(k)

    # batched computation
    print('num images:', len(image_paths))
    # Inference only: no_grad stops autograd from recording a graph per batch.
    with torch.no_grad():
        for start in tqdm(range(0, len(image_paths), batch_size)):
            batch_keys = keys[start:start + batch_size]

            batch_im = []
            for path in image_paths[start:start + batch_size]:
                # convert() loads the pixels into a new image, so the file
                # handle can be closed right away by the context manager.
                with Image.open(path) as im:
                    batch_im.append(im.convert("RGB"))

            batch_encodings = feature_extractor(images=batch_im, return_tensors="pt")
            batch_pixel_values = batch_encodings['pixel_values'].to(device)
            batch_outputs = model(batch_pixel_values).logits

            # One device->host copy per batch instead of handing Annoy a
            # (possibly CUDA) tensor to iterate element by element.
            batch_vectors = batch_outputs.detach().cpu().numpy()
            for key, vector in zip(batch_keys, batch_vectors):
                index.add_item(key, vector.squeeze())

    index.build(128)
    # NOTE(review): the index is built on disk in "on_disk_index.ann"; confirm
    # whether this save() call writes 'celeba.ann' at all in on-disk mode.
    index.save('celeba.ann')

    return index, idx_to_identity, identity_to_idx

def retrival_accuracy(base_idx, image_idx, idx_to_identity, identity_to_idx):
    """Fraction of retrieved items that share the query item's identity.

    The score is ``hits / min(len(image_idx), n_true_matches)``, where ``hits``
    counts retrieved ids whose identity equals that of ``base_idx`` and
    ``n_true_matches`` is the number of ids listed for that identity.

    Args:
        base_idx: id of the query item.
        image_idx: sequence of retrieved item ids.
        idx_to_identity: mapping id -> dict containing an ``"identity"`` key.
        identity_to_idx: mapping identity -> list of ids with that identity.

    Returns:
        The score as a float; ``0.0`` when ``image_idx`` is empty or the
        query's identity has no entries in ``identity_to_idx`` (previously a
        ZeroDivisionError).
    """
    true_identity = idx_to_identity[base_idx]["identity"]
    # .get() instead of [] so a defaultdict is not mutated by the lookup.
    n_true_matches = len(identity_to_idx.get(true_identity, ()))

    # No progress bar here: this runs once per query, so a bar per call would
    # flood the notebook output.
    hits = sum(
        1 for idx in image_idx
        if idx_to_identity[idx]["identity"] == true_identity
    )

    denominator = min(len(image_idx), n_true_matches)
    if denominator == 0:
        return 0.0
    return hits / denominator

In [None]:
# End-to-end run for one backbone: load the identity table, load the model,
# then embed every image and build the Annoy index.
model_identifier = 'microsoft/beit-base-patch16-224-pt22k-ft22k'
results = defaultdict(list)

df, identity_selection = prepare_dataset()
model, feature_extractor, dim = init_model(model_identifier)

index, idx_to_identity, identity_to_idx = build_annoy_index(
    model,
    feature_extractor,
    df,
    dim,
    batch_size=32,
)

loading data...
microsoft/beit-base-patch16-224-pt22k-ft22k


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


building annoy index...
num images: 202599


100%|██████████| 6332/6332 [1:24:10<00:00,  1.25it/s]


In [None]:
# Ask PyTorch to release cached GPU memory it is no longer using.
torch.cuda.empty_cache()

In [None]:
# Copy the on-disk Annoy index file (the path given to on_disk_build) to Google Drive.
!cp on_disk_index.ann "/gdrive/My Drive/cse6242_project/Data"