In [1]:
import os
import glob
import random
from collections import Counter
import umap


import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm


from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans, DBSCAN
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GAE
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch.nn import Linear

  Referenced from: <59E7CF6E-B8F0-3584-A1A7-85D47809EB30> /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch_scatter/_version_cpu.so
  Expected in:     <772DF335-D7CB-318F-A275-48A16B0A0CA8> /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib


On utilise un ResNet-50 pré-entraîné (poids torchvision chargés depuis un fichier local) pour extraire les features de chaque tableau.

In [None]:
# Placeholder path: point it at the locally downloaded torchvision checkpoint
# (the automatic download failed with an SSL error, see the output below).
RESNET50_WEIGHTS = 'chemin/vers/resnet50-0676ba61.pth'

resnet50 = models.resnet50()
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine.
# weights_only=True restricts unpickling to tensors/containers; a plain
# torch.load would execute arbitrary pickle code from the file.
state_dict = torch.load(RESNET50_WEIGHTS, map_location='cpu', weights_only=True)
resnet50.load_state_dict(state_dict)
# Replacing the classification head by an identity makes the forward pass
# return the features that would have been fed to `fc`.
resnet50.fc = torch.nn.Identity()
resnet50.eval()  # inference mode: BatchNorm uses its running statistics

# Preprocessing: resize to 224x224, convert to a tensor, then normalise with
# the ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def extract_features(image_path):
    """Return the ResNet-50 feature vector of a single image.

    The image is converted to RGB, preprocessed with ``transform`` and passed
    through ``resnet50`` (whose ``fc`` head was replaced by an identity).

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        A 1-D numpy array (the batch dimension of size 1 is squeezed out);
        with the ResNet-50 backbone this is expected to hold 2048 values.
    """
    # The context manager closes the file handle immediately; a bare
    # Image.open keeps it open until garbage collection, which accumulates
    # when looping over many paintings.
    with Image.open(image_path) as img:
        img_rgb = img.convert('RGB')  # convert() loads the pixels, so it outlives the file
    batch = transform(img_rgb).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        features = resnet50(batch)
    return features.squeeze(0).numpy()



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/celio/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:997)>

Pour l'instant, on ne sélectionne que 1000 images tirées au hasard, sinon l'extraction des features prend environ 2 h (vraiment).

In [None]:
N_IMAGES = 1000  # subset size: extracting features for all images takes hours
SAMPLE_SEED = 0  # fixes which paintings end up in the subset

# sorted() removes the filesystem-dependent glob order, and a dedicated seeded
# RNG makes the random subset identical from one run to the next.
file_paths = sorted(glob.glob('./wikiart/wikiart/*.jpg'))
if not file_paths:
    raise FileNotFoundError("No .jpg file found in ./wikiart/wikiart/ - check the dataset path")
random.Random(SAMPLE_SEED).shuffle(file_paths)
file_paths_subset = file_paths[:N_IMAGES]

# embeddings: array of shape (N, feature_dim), one feature row per painting
embeddings = np.array([extract_features(fp) for fp in tqdm(file_paths_subset)])
# painters: list of N strings, the file-name prefix before the first '_' (painter name)
painters = [os.path.basename(fp).split('_')[0] for fp in file_paths_subset]

embeddings.shape

In [None]:
# Group painting indices by painter.
painter_to_indices = {}
for idx, painter in enumerate(painters):
    painter_to_indices.setdefault(painter, []).append(idx)

# Initial graph: every ordered pair of distinct paintings by the same painter
# becomes a directed edge, so each painter forms a fully connected clique.
edge_list = [
    [i, j]
    for indices in painter_to_indices.values()
    for i in indices
    for j in indices
    if i != j
]

# If no painter has two paintings the list is empty; torch.tensor([]) would be
# 1-D with shape (0,), whereas PyG expects edge_index of shape (2, num_edges).
if edge_list:
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
else:
    edge_index = torch.empty((2, 0), dtype=torch.long)

x = torch.tensor(embeddings, dtype=torch.float)
# Node features are L2-normalised row-wise; `x` itself keeps the raw features.
data = Data(x=F.normalize(x, p=2, dim=-1), edge_index=edge_index)
data.x = F.normalize(data.x, p=2, dim=-1)

Pipeline : features des nœuds (x) + arêtes (edge_index) → GCNEncoder → embeddings → produit scalaire → prédiction des arêtes.

In [None]:
class GCNEncoder(torch.nn.Module):
    """Two-layer graph convolutional encoder used inside the GAE.

    Layer sizes: in_channels -> 2 * out_channels (ReLU) -> out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Encode node features `x` over the graph given by `edge_index`."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

EMBEDDING_DIM = 128    # size of the learned node embeddings (hidden GCN layer is 2x this)
LEARNING_RATE = 0.001  # same value as Adam's default learning rate

# Seed torch before building the model so the weight initialisation (and any
# later torch random draws) are reproducible.
torch.manual_seed(0)

input_dim = data.x.shape[1]
model = GAE(GCNEncoder(input_dim, EMBEDDING_DIM))
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)



def train():
    """Run one full-batch optimisation step of the GAE.

    Encodes all nodes of ``data`` with the current weights, computes
    ``model.recon_loss`` against ``data.edge_index``, back-propagates and
    updates the parameters with ``optimizer``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    node_embeddings = model.encode(data.x, data.edge_index)
    # The reconstruction target is the same edge set that is fed to the encoder.
    # Open question (from the original note): should the loss also take the
    # initial graph into account in some other way?
    reconstruction_loss = model.recon_loss(node_embeddings, data.edge_index)
    reconstruction_loss.backward()
    optimizer.step()
    return reconstruction_loss.item()

Entraînement du modèle

In [None]:
N_EPOCHS = 300  # chosen arbitrarily for now; no validation-based stopping
LOG_EVERY = 50  # print the loss every LOG_EVERY epochs

# train() already switches the model to training mode on every call.
loss_history = []  # one loss per epoch, kept for later inspection/plotting
for epoch in range(1, N_EPOCHS + 1):
    loss = train()
    loss_history.append(loss)
    if epoch % LOG_EVERY == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}')

In [None]:
N_CLUSTERS = 30  # number of KMeans clusters on the learned embeddings

# Final node embeddings; no_grad avoids building an autograd graph that would
# only be detached afterwards.
model.eval()
with torch.no_grad():
    z = model.encode(data.x, data.edge_index).cpu().numpy()

kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=0)
labels = kmeans.fit_predict(z)

def plot_cluster_images(file_paths, labels, cluster_id, n_images=5, rng=None):
    """Show up to ``n_images`` randomly chosen paintings of one cluster in a row.

    Args:
        file_paths: sequence of image paths, aligned index-by-index with ``labels``.
        labels: 1-D array-like of cluster ids, one per image.
        cluster_id: the cluster to display.
        n_images: maximum number of paintings to show.
        rng: optional ``numpy.random.Generator`` for a reproducible selection;
            when None the global ``np.random`` state is used.

    Nothing is drawn when the cluster is empty or ``n_images`` is not positive.
    """
    indices = np.flatnonzero(np.asarray(labels) == cluster_id)
    n_shown = min(len(indices), n_images)
    if n_shown <= 0:
        # plt.subplots(1, 0) would raise; there is simply nothing to show.
        return
    chooser = np.random if rng is None else rng
    selected_indices = chooser.choice(indices, n_shown, replace=False)

    # squeeze=False always yields a 2-D array of Axes; with the default, a
    # single-image row returns a bare Axes object and zip() over it fails.
    fig, axes = plt.subplots(1, n_shown, figsize=(15, 4), squeeze=False)
    fig.suptitle(f"Cluster {cluster_id}", fontsize=16)

    for ax, idx in zip(axes[0], selected_indices):
        ax.imshow(mpimg.imread(file_paths[idx]))
        ax.axis('off')
        ax.set_title(os.path.basename(file_paths[idx]), fontsize=9)

    plt.tight_layout()
    plt.show()

# Preview the first seven clusters, showing up to eight paintings each.
previewed_clusters = list(range(7))
for cluster_id in previewed_clusters:
    plot_cluster_images(
        file_paths_subset,
        labels,
        cluster_id,
        n_images=8,
    )

In [None]:
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

N_PREVIEW_CLUSTERS = 7        # clusters that get thumbnail previews on the map
N_THUMBNAILS_PER_CLUSTER = 3  # thumbnails drawn per previewed cluster

# 2-D projection of the learned embeddings
reducer = umap.UMAP(random_state=42)
z_2d = reducer.fit_transform(z)

# 'tab20' only has 20 colours, so with more clusters several clusters end up
# sharing a colour. Resample a colormap to exactly one colour per cluster and
# centre each integer label on its own colour bin via vmin/vmax.
n_clusters = int(np.max(labels)) + 1
cmap = plt.get_cmap('tab20' if n_clusters <= 20 else 'nipy_spectral', n_clusters)

fig, ax = plt.subplots(figsize=(12, 8))
scatter = ax.scatter(z_2d[:, 0], z_2d[:, 1], c=labels, cmap=cmap,
                     vmin=-0.5, vmax=n_clusters - 0.5, s=10)
fig.colorbar(scatter, ax=ax, label='Cluster ID')

# Overlay a few thumbnails per cluster at their projected positions
# (seeded so the same paintings are shown on every run).
thumb_rng = np.random.default_rng(0)
for cluster_id in range(N_PREVIEW_CLUSTERS):
    indices = np.flatnonzero(labels == cluster_id)
    if len(indices) == 0:
        continue
    selected_indices = thumb_rng.choice(
        indices, min(len(indices), N_THUMBNAILS_PER_CLUSTER), replace=False
    )
    for idx in selected_indices:
        imagebox = OffsetImage(mpimg.imread(file_paths_subset[idx]), zoom=0.2)
        ab = AnnotationBbox(imagebox, (z_2d[idx, 0], z_2d[idx, 1]), frameon=False)
        ax.add_artist(ab)

ax.set_title('2D UMAP projection of clusters with sample paintings')
ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
fig.tight_layout()
plt.show()
