In [1]:
from bokeh.plotting import ColumnDataSource, figure, output_file, show
import numpy as np 
import pandas as pd
import bokeh.io
from annoy import AnnoyIndex
from sklearn.decomposition import PCA
from PIL import Image
import requests
from tqdm import tqdm
from collections import defaultdict
from transformers import BeitFeatureExtractor
from sklearn.manifold import TSNE
import plotly.graph_objects as go
bokeh.io.output_notebook()

In [2]:
#Grab Identities and necessary data files 

def init(N_IDENTITIES = 15, random_seed = 42):
    
    np.random.seed(random_seed)
    identities = pd.read_csv("data/identity_CelebA.txt", sep=" ", header=None)
    identities.rename(columns={0: "file", 1: "identity"}, inplace=True)
    
    #select indneities for plotting 
    
    identity_selection = np.random.choice(identities.identity.unique(), N_IDENTITIES)
    files_to_load = identities[identities.identity.isin(identity_selection)].reset_index()
    
    return files_to_load

files_to_load = init(N_IDENTITIES = 25)

In [3]:
from transformers import ViTForImageClassification
from transformers import AutoModelForImageClassification
import torch
from torch import nn
import torchvision
from torchvision import transforms


model_identifier = 'microsoft/beit-base-patch16-224-pt22k-ft22k'


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

latentspace = AutoModelForImageClassification.from_pretrained(model_identifier)
latentspace.classifier = Identity()
latentspace.eval()
latentspace.to(device)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


BeitForImageClassification(
  (beit): BeitModel(
    (embeddings): BeitEmbeddings(
      (patch_embeddings): PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): BeitEncoder(
      (layer): ModuleList(
        (0): BeitLayer(
          (attention): BeitAttention(
            (attention): BeitSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (relative_position_bias): BeitRelativePositionBias()
            )
            (output): BeitSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (interme

In [4]:
#create index

def annoy(files_to_load, DIM = 768, index_build = 128):
    
    feature_extractor = BeitFeatureExtractor.from_pretrained(model_identifier)
    idx_to_identity = files_to_load.to_dict('index')
    identity_to_idx = defaultdict(list)

    for k, v in idx_to_identity.items():
        identity_to_idx[v["identity"]].append(k)
        
    index = AnnoyIndex(DIM, 'euclidean')


    #Build index 
    
    for k, v in tqdm(idx_to_identity.items()):
        im = Image.open("img_align_celeba/" + v["file"])
        encoding = feature_extractor(images=im, return_tensors="pt")
        pixel_values = encoding['pixel_values'].to(device)
        outputs = latentspace(pixel_values)
        embedding = outputs.logits
        index.add_item(k, embedding.squeeze())
        
    index.build(index_build)   
    
    return index, idx_to_identity

index, idx_to_identity = annoy(files_to_load)


100%|█████████████████████████████████████████| 528/528 [04:44<00:00,  1.86it/s]


In [5]:
def reduction(index, idx_to_identity, random_state = 6242, technique = 'pca', n_components = 2):
    
    X = [index.get_item_vector(i) for i in range(len(idx_to_identity))]
    
    #run reduction on n_components =2 from the data 
    
    if technique == 'pca':
        points = PCA(n_components=n_components, random_state=random_state)
        points.fit(X)
        result=pd.DataFrame(points.transform(X), columns=['PCA%i' % i for i in range(n_components)])
        
    else:
        points = TSNE(n_components=n_components, verbose=1, random_state=random_state)
        result=pd.DataFrame(points.fit_transform(X), columns=['TSNE%i' % i for i in range(n_components)])
    
    return result 

result_2d_pca = reduction(index ,idx_to_identity, random_state = 6242, technique = 'pca', n_components = 2)
results_2d_tsne = reduction(index ,idx_to_identity, random_state = 6242, technique = 'tsne', n_components = 2)
result_3d_pca = reduction(index ,idx_to_identity, random_state = 6242, technique = 'pca', n_components = 3)
results_3d_tsne = reduction(index ,idx_to_identity, random_state = 6242, technique = 'tsne', n_components = 3)




[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 528 samples in 0.001s...
[t-SNE] Computed neighbors for 528 samples in 0.021s...
[t-SNE] Computed conditional probabilities for sample 528 / 528
[t-SNE] Mean sigma: 7.762428
[t-SNE] KL divergence after 250 iterations with early exaggeration: 74.077110
[t-SNE] KL divergence after 1000 iterations: 0.760180
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 528 samples in 0.000s...
[t-SNE] Computed neighbors for 528 samples in 0.017s...
[t-SNE] Computed conditional probabilities for sample 528 / 528
[t-SNE] Mean sigma: 7.762428




[t-SNE] KL divergence after 250 iterations with early exaggeration: 119.430634
[t-SNE] KL divergence after 1000 iterations: 2.018383


In [6]:
def plot_2d(results, files_to_load, filename = 'plots/pca_2D_interactive.html', sample_size = 350):
    
    #results = results.sample(sample_size)
    
    output_file(filename)
    
    #Build interactive bokeh_plot
    imgs = []

    for i,v in enumerate(files_to_load['index']):
        img = ("img_align_celeba/" + files_to_load["file"][i])
        imgs.append(img)
    
    if 'pca' in filename:
        title = '2D PCA Interactive Plot with Images'
        source = ColumnDataSource(data=dict(
            x=list(results['PCA0']),
            y=list(results['PCA1']),
            desc=['Image:' + str(i) for i in files_to_load['index']],
            imgs=imgs
                ))
    else:
        title = '2D TSNE Interactive Plot with Images'
        source = ColumnDataSource(data=dict(
            x=list(results['TSNE0']),
            y=list(results['TSNE1']),
            desc=['Image:' + str(i) for i in files_to_load['index']],
            imgs=imgs
                ))

    TOOLTIPS = """
    <div>
        <div>
            <img
                src="@imgs" height="42" alt="@imgs" width="42"
                style="float: left; margin: 0px 15px 15px 0px;"
                border="2"
            ></img>
        </div>
        <div>
            <span style="font-size: 17px; font-weight: bold;">@desc</span>
            <span style="font-size: 15px; color: #966;">[$index]</span>
        </div>
        <div>
            <span style="font-size: 15px;">Location</span>
            <span style="font-size: 10px; color: #696;">($x, $y)</span>
        </div>
    </div>
            """
    
    p = figure(width=900, height=900, tooltips=TOOLTIPS,
               title=title, align = 'center')

    p.circle('x', 'y', size=15, source=source)

    show(p)
    
plot_2d(result_2d_pca, files_to_load, filename = 'plots/pca_2D_interactive.html')

In [7]:
plot_2d(results_2d_tsne, files_to_load, filename = 'plots/tsne_2D_interactive.html')

In [11]:
def plot_3d(technique = 'pca'):
    
    if technique == 'pca':
        data = reduction(index, idx_to_identity, random_state = 6242, technique = 'pca', n_components = 3)
        fig = go.Figure(data=[go.Scatter3d(x=data['PCA0'], y=data['PCA1'], z=data['PCA2'],
                                   mode='markers')])
        fig.update_layout(title_text='3D visualization: Principal Component Analysis', title_x=0.5, title_y = 0.85,
                    scene = dict(
                    xaxis = dict(title = 'PCA0'),
                     yaxis = dict(title = 'PCA1'),
                     zaxis = dict(title = 'PCA2')),
                      width = 800,
                      height = 750)
    
    else:
        
        data = reduction(index, idx_to_identity, random_state = 6242, technique = 'tsne', n_components = 3)
        fig = go.Figure(data=[go.Scatter3d(x=data['TSNE0'], y=data['TSNE1'], z=data['TSNE2'],
                                   mode='markers')])
        fig.update_layout(title_text='3D visualization: TSNE', title_x=0.5, title_y = 0.85,
                    scene = dict(
                    xaxis = dict(title = 'PCA0'),
                     yaxis = dict(title = 'PCA1'),
                     zaxis = dict(title = 'PCA2')),
                     width = 800,
                     height = 750)
    fig.show()
    
plot_3d(technique = 'pca')

In [12]:
plot_3d(technique = 'tsne')


The default initialization in TSNE will change from 'random' to 'pca' in 1.2.


The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.



[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 528 samples in 0.001s...
[t-SNE] Computed neighbors for 528 samples in 0.021s...
[t-SNE] Computed conditional probabilities for sample 528 / 528
[t-SNE] Mean sigma: 7.762428
[t-SNE] KL divergence after 250 iterations with early exaggeration: 119.430634
[t-SNE] KL divergence after 1000 iterations: 2.018383
