# Image Clustering
This notebook clusters the dataset images using the pretrained DinoV2 model. The model extracts a feature vector from each image, and these feature vectors are then clustered using KMeans.

### Imports

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
from sklearn.cluster import KMeans
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
import seaborn as sns
import pandas as pd
import numpy as np  
import plotly.express as px

### Load config

In [None]:
# Side length in pixels of the square centre crop that is fed to the model.
img_size = 224
# Root directory that is walked recursively when collecting the images.
datapath = r"D:\Database\animals\dataset"

### Feature Extraction

In [None]:
# Load the pretrained DINOv2 ViT-S/14 backbone through torch.hub.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=True)
# Prefer the GPU, but fall back to the CPU instead of crashing when CUDA is missing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# The model is only used for feature extraction, so leave training mode.
model.eval()
model

In [None]:
# Preprocessing: resize to slightly above the target size, centre-crop to
# img_size x img_size, convert to a tensor and normalise each channel with the
# commonly used ImageNet mean / standard deviation.
transform = transforms.Compose([
    transforms.Resize(int(img_size * 1.1)),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Collect every JPEG below datapath up front so that one progress bar covers
# the whole dataset. Sorting directories and files makes the order of the
# feature vectors independent of the file system's listing order.
image_paths = []
for root, dirs, files in os.walk(datapath):
    dirs.sort()  # in-place sort steers os.walk's traversal order
    image_paths.extend(
        os.path.join(root, name)
        for name in sorted(files)
        # Case-insensitive match so that ".JPG" / ".jpeg" files are not skipped.
        if name.lower().endswith(('.jpg', '.jpeg'))
    )

# Run on whatever device the model already lives on instead of assuming CUDA.
model_device = next(model.parameters()).device

# Iterate over all images and extract one feature vector per image.
outputs = []
with torch.no_grad():
    for path in tqdm(image_paths, desc='Extracting features', unit='img'):
        # The context manager closes the file handle. convert('RGB') guarantees
        # three channels (grayscale / CMYK / palette JPEGs would otherwise
        # break the 3-channel Normalize step).
        with Image.open(path) as img:
            batch = transform(img.convert('RGB')).unsqueeze(0).to(model_device)
        output = model(batch).squeeze(0)
        outputs.append(output.cpu().numpy())

### How Many Clusters?

In [None]:
# Convert the list of feature vectors once instead of on every iteration.
features = np.asarray(outputs)

# The silhouette score is only defined for 2 <= k <= n_samples - 1, so cap the
# upper bound for small datasets instead of letting scikit-learn raise mid-loop.
kmax = min(250, len(features) - 1)
if kmax < 2:
    raise ValueError(
        f"Need at least 3 feature vectors for silhouette analysis, got {len(features)}"
    )

sil = []
K = range(2, kmax + 1)
for k in tqdm(K, desc='Finding best k', unit='k'):
    # Seeded so that the scores (and therefore best_k) are reproducible.
    kmeans = KMeans(n_clusters=k, n_init='auto', random_state=0).fit(features)
    labels = kmeans.labels_
    # NOTE(review): KMeans minimises Euclidean distance while the score below
    # uses the cosine metric - confirm that this mix is intended.
    sil.append(silhouette_score(features, labels, metric='cosine'))

# Look the winner up in K itself so no hand-maintained offset is needed;
# argmax returns the first maximum, like list.index(max(...)).
best_k = K[int(np.argmax(sil))]
best_k

In [None]:
# Draw the silhouette score for every candidate number of clusters.
sil_fig, sil_ax = plt.subplots(figsize=(8, 6))
sil_ax.plot(K, sil)

## KMeans Clustering with best `k`

In [None]:
# Final clustering with the silhouette-optimal k. The fixed seed makes the
# cluster labels reproducible across runs, in line with the seeded t-SNE
# projections that visualise them.
kmeans = KMeans(n_clusters=best_k, n_init='auto', random_state=0)
labels = kmeans.fit_predict(np.asarray(outputs))

### Visualize Clusters with t-SNE

In [None]:
# 2D t-SNE plot
tsne = TSNE(n_components=2, random_state=0)
X = tsne.fit_transform(np.array(outputs))
df = pd.DataFrame(X, columns=['x', 'y'])
# Plotly Express treats a numeric colour column as continuous and then ignores
# color_discrete_sequence, so the cluster ids are stored as strings to get one
# discrete colour per cluster.
df['label'] = np.asarray(labels).astype(str)

fig = px.scatter(df, x='x', y='y', color='label', color_discrete_sequence=px.colors.qualitative.G10)
# Hide the legend and both axes. (`scene` only applies to 3D plots, so it is
# not set on this 2D figure.)
fig.update_layout(showlegend=False, xaxis=dict(visible=False), yaxis=dict(visible=False))
fig.show()

In [None]:
# Embed the feature vectors in three dimensions with a seeded t-SNE.
tsne = TSNE(random_state=0, n_components=3)
X = tsne.fit_transform(np.array(outputs))

# One row per image: the three embedding coordinates plus its cluster label.
df = pd.DataFrame({'x': X[:, 0], 'y': X[:, 1], 'z': X[:, 2], 'label': labels})


In [None]:
import plotly.graph_objects as go

# Interactive 3D scatter of the t-SNE embedding, one discrete colour per cluster.
# The labels are cast to strings because Plotly Express treats numeric colour
# columns as continuous and then ignores the discrete colour arguments.
fig = px.scatter_3d(
    df.assign(label=df['label'].astype(str)),
    x='x',
    y='y',
    z='z',
    color='label',
    color_discrete_sequence=px.colors.qualitative.G10,
    template='plotly_white',
)
fig.update_traces(marker=dict(size=2))
# showlegend is a layout-level property; it is not a valid key inside `scene`.
# NOTE(review): only the x and y axes are hidden, the z axis stays visible -
# confirm that this is intended.
fig.update_layout(
    showlegend=False,
    scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False)),
)

# Initial camera position; the animation frames rotate this point around z.
x_eye = -1.25
y_eye = 2
z_eye = 0.5
fig.update_layout(
    title='Animation Test',
    width=1000,
    height=1000,
    scene_camera_eye=dict(x=x_eye, y=y_eye, z=z_eye),
    # A single "Play" button that steps through the figure's frames.
    updatemenus=[
        dict(
            type='buttons',
            showactive=False,
            y=1,
            x=0.8,
            xanchor='left',
            yanchor='bottom',
            pad=dict(t=45, r=10),
            buttons=[
                dict(
                    label='Play',
                    method='animate',
                    args=[
                        None,
                        dict(
                            frame=dict(duration=5, redraw=True),
                            transition=dict(duration=0),
                            fromcurrent=True,
                            mode='immediate',
                        ),
                    ],
                ),
            ],
        ),
    ],
)

def rotate_z(x, y, z, theta):
    """Rotate the point (x, y, z) by the angle theta (radians) about the z axis.

    The (x, y) pair is treated as the complex number x + iy and multiplied by
    exp(i * theta); z is passed through unchanged.

    Returns:
        The tuple (rotated x, rotated y, z).
    """
    rotated = np.exp(1j * theta) * (x + 1j * y)
    return np.real(rotated), np.imag(rotated), z

# One animation frame per 0.1 rad step of a full turn: each frame only moves
# the camera eye, rotated by -t around the z axis from its initial position.
frames = []
for t in np.arange(0, 2 * np.pi, 0.1):
    xe, ye, ze = rotate_z(x_eye, y_eye, z_eye, -t)
    frames.append(go.Frame(layout=dict(scene_camera_eye=dict(x=xe, y=ye, z=ze))))
fig.frames = frames
# No matplotlib call here: plt.axis('off') does not affect the plotly figure
# and would only create an empty matplotlib figure.
fig.show()

### Concluding Notes
The best number of clusters found deviates from the number of classes in the dataset (90). This could be because the DinoV2 model was trained on a different dataset and may therefore have learned different features, so it might not extract the features needed to cluster the images by their classes. Furthermore, the dataset could contain unsuitable samples that are not representative of their class.