In [1]:
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt

import os
import cv2
import numpy as np
from tqdm import tqdm
import pandas as pd


from sklearn.cluster import KMeans, OPTICS, HDBSCAN
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.cluster.hierarchy import linkage, dendrogram
from sklearn.metrics import silhouette_score
import plotly.express as px
from sklearn.cluster import DBSCAN

# Fall back to the CPU when no CUDA device is present so the notebook still runs end to end.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class PlantDataset(Dataset):
  """Image dataset yielding one preprocessed image per row of a dataframe.

  Item ``i`` is the file ``<images_path>/<id>.jpeg``, where ``id`` is taken from
  the ``i``-th row (by position) of ``df``. The image is converted to RGB and
  passed through ``transform``.

  Args:
    images_path: directory containing the ``<id>.jpeg`` files.
    df: dataframe with an ``id`` column. Rows are addressed by position, so a
      filtered or re-indexed dataframe works too.
    transform: optional callable applied to the RGB PIL image. When omitted, the
      default is Resize(224) -> CenterCrop(224) -> ToTensor -> Normalize with the
      mean/std that torchvision's ImageNet-pretrained models expect.
  """

  def __init__(self, images_path, df, transform=None):
    super().__init__()

    self.imgs_folder = images_path
    self.df = df
    if transform is None:
      # Built once here rather than on every __getitem__ call.
      transform = transforms.Compose([
          transforms.Resize(224),
          # Resize(224) only fixes the shorter side; the crop guarantees every
          # tensor is 224x224 so the default collate can stack a batch.
          transforms.CenterCrop(224),
          transforms.ToTensor(),
          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
      ])
    self.transform = transform

  def __len__(self):
    return len(self.df)

  def __getitem__(self, idx):
    # DataLoader passes positional indices, so use .iloc instead of a label lookup
    # (a label lookup breaks as soon as df has a non-default index).
    image_id = self.df['id'].iloc[idx]
    img_path = os.path.join(self.imgs_folder, f"{image_id}.jpeg")

    # The context manager closes the underlying file once the RGB copy exists.
    with Image.open(img_path) as raw_img:
      img = raw_img.convert('RGB')

    return self.transform(img)

In [3]:
# Paths are relative to the notebook's working directory; change DATA_DIR to relocate the data.
DATA_DIR = 'planttraits2024'
BATCH_SIZE = 256

df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
dataset = PlantDataset(os.path.join(DATA_DIR, 'train_images'), df)
# shuffle=False keeps embedding row i aligned with dataframe row i downstream.
loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# ImageNet-pretrained ResNet-18, used only as a frozen feature extractor.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()  # eval mode propagates to the child modules reused by the encoder below

# Keep every child except the last one (the `fc` classifier in torchvision's ResNet),
# so the encoder outputs pooled features instead of ImageNet logits.
# requires_grad_(False) freezes the shared weights (model and encoder share them), and
# .eval() sets the Sequential wrapper's own mode, so calling the encoder never builds an
# autograd graph or flips BatchNorm to batch statistics.
encoder = (
    torch.nn.Sequential(*list(model.children())[:-1])
    .to(device)
    .requires_grad_(False)
    .eval()
)

In [5]:
# Embed every image with the frozen encoder.
# - inference_mode disables autograd tracking, which the forward pass does not need.
# - Each batch is copied to host memory once, rather than keeping per-row GPU views
#   and copying them individually afterwards.
batch_embeddings = []

with torch.inference_mode():
  for image in tqdm(loader):
    image = image.to(device)
    results = encoder(image)
    # Presumably (B, C, 1, 1) after ResNet's pooling layer (TODO confirm);
    # flatten everything after the batch axis to get (B, C).
    batch_embeddings.append(results.flatten(start_dim=1).cpu().numpy())

# One row per image, in loader order (which is dataframe order because shuffle=False).
vectors = np.concatenate(batch_embeddings, axis=0)

 19%|█▉        | 41/217 [01:19<05:41,  1.94s/it]


KeyboardInterrupt: 

In [None]:
# Guard against stale kernel state (e.g. a partially completed embedding cell):
# every image in the dataset must have exactly one embedding row.
assert len(vectors) == len(dataset), f"expected {len(dataset)} embeddings, got {len(vectors)}"
vectors.shape

(55489, 512)

In [None]:
TSNE_SEED = 42         # t-SNE is stochastic; a fixed seed makes the embedding reproducible
TSNE_ITERATIONS = 300  # NOTE(review): scikit-learn's default is 1000; 300 may stop before convergence

# scikit-learn 1.5 renamed `n_iter` to `max_iter` (the old name is removed in 1.7),
# so pass the iteration budget under whichever keyword the installed version accepts.
_tsne_iter_kw = 'max_iter' if 'max_iter' in TSNE().get_params() else 'n_iter'

tsne = TSNE(
    n_components=3,
    verbose=1,
    perplexity=40,
    random_state=TSNE_SEED,
    **{_tsne_iter_kw: TSNE_ITERATIONS},
)
tsne_results = tsne.fit_transform(vectors)

[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 55489 samples in 0.008s...


[t-SNE] Computed neighbors for 55489 samples in 4.989s...
[t-SNE] Computed conditional probabilities for sample 1000 / 55489
[t-SNE] Computed conditional probabilities for sample 2000 / 55489
[t-SNE] Computed conditional probabilities for sample 3000 / 55489
[t-SNE] Computed conditional probabilities for sample 4000 / 55489
[t-SNE] Computed conditional probabilities for sample 5000 / 55489
[t-SNE] Computed conditional probabilities for sample 6000 / 55489
[t-SNE] Computed conditional probabilities for sample 7000 / 55489
[t-SNE] Computed conditional probabilities for sample 8000 / 55489
[t-SNE] Computed conditional probabilities for sample 9000 / 55489
[t-SNE] Computed conditional probabilities for sample 10000 / 55489
[t-SNE] Computed conditional probabilities for sample 11000 / 55489
[t-SNE] Computed conditional probabilities for sample 12000 / 55489
[t-SNE] Computed conditional probabilities for sample 13000 / 55489
[t-SNE] Computed conditional probabilities for sample 14000 / 55489

In [None]:
# Density-based clustering directly on the raw embeddings; DBSCAN marks noise points with -1.
# NOTE(review): with eps=20 the recorded output assigns all 55489 points to cluster 0,
# so eps probably needs tuning (or the features need scaling) to get meaningful clusters.
dbscan_model = DBSCAN(min_samples=5, eps=20)
dbscan_result = dbscan_model.fit(vectors).labels_

print(np.unique(dbscan_result, return_counts=True))

(array([0]), array([55489]))

In [None]:
# Expect one t-SNE coordinate triple per embedding row.
assert tsne_results.shape == (len(vectors), tsne.n_components)
tsne_results.shape

(55489, 3)


In [None]:
# Interactive 3-D view of the t-SNE embedding (one point per image), colored by DBSCAN cluster.
# Cluster ids are categorical, so cast them to str: plotly then draws a discrete legend
# (noise = "-1") instead of mapping the ids onto a continuous color scale.
fig = px.scatter_3d(
    x=tsne_results[:, 0],
    y=tsne_results[:, 1],
    z=tsne_results[:, 2],
    color=dbscan_result.astype(str),
    labels={'color': 'DBSCAN cluster'},
    width=800,
    height=800,
)
fig.update_layout(
    title='t-SNE of ResNet-18 image embeddings',
    scene=dict(
        xaxis_title='t-SNE 1',
        yaxis_title='t-SNE 2',
        zaxis_title='t-SNE 3',
    ),
)
# Tens of thousands of points: keep markers tiny to limit overplotting.
fig.update_traces(marker=dict(size=1))

fig.show()