# Exercise 3 - Features and clustering 

In this exercise we will try to separate some of the objects we collected using different clustering techniques. We will do this using handcrafted features and image embeddings. Image embeddings are high-dimensional numerical representations of visual content: they condense the essential features of an image (e.g., shapes, colors, patterns, and spatial hierarchies) into a vector of fixed size.  

<center><img src=../assets/butterfly-rois.png width="800" ></center>


## Feature extraction
First we load some packages:

In [None]:
import os
import timm
import torch
import phenopype as pp
from torchvision import transforms
from torch import nn 
from PIL import Image
from tqdm import tqdm
import pandas as pd

# NOTE: hard-coded, machine-specific working directory -- change this to the
# folder where you unpacked the workshop materials. All relative paths used
# below ("phenopype/rois", "data", "plots", "data_raw") resolve against it.
os.chdir(r"C:\Users\mluerig\Downloads\phenomics-workshop-aussois-main") 

We will use a comparatively small model (ResNet-50) that was pretrained on the ImageNet dataset. The model can be downloaded using timm (PyTorch Image Models library: https://timm.fast.ai/). After downloading the model, we need to create a transform function that converts our image array to a normalized tensor.

In [None]:
## Download an ImageNet-pretrained ResNet-50 via timm and put it into
## evaluation mode (eval() switches the module from training to inference
## behaviour and returns the module itself)
model = timm.create_model(model_name="resnet50", pretrained=True).eval()

## Strip the last child module (the classification head) so the network
## outputs the feature vector that would otherwise be fed to the classifier.
## NOTE(review): for timm's ResNet the remaining last child is presumably the
## global pooling layer -- confirm the output is (batch, features).
model = torch.nn.Sequential(*list(model.children())[:-1])

# Preprocessing pipeline:
#   ToTensor  -> array/PIL image to float tensor scaled to [0, 1]
#   Normalize -> per-channel standardisation with the ImageNet statistics
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

The following piece of code will 1) load a saved ROI and apply the mask from the alpha channel to the RGB channels, 2) transform the image to a tensor, and 3) extract the image embeddings:

In [None]:
## 1) load one ROI (colour + alpha channel) and black out everything outside the mask.
##    The path is relative to the working directory set above, like in the loop below.
img_path = os.path.join("phenopype", "rois", "mcz-ent-170049_1.jpg_05.png_002.png")
img_RGBA = pp.load_image(img_path)
img, alpha_channel = img_RGBA[:, :, :3], img_RGBA[:, :, 3]
mask = alpha_channel > 0
# NOTE: `img` is a slice (view) of `img_RGBA`, so this also zeroes the colour
# channels of `img_RGBA` outside the mask.
img[~mask] = 0
# NOTE(review): pp.load_image is called without mode="rgb" here (get_offset_img
# passes it explicitly), so the channel order may be BGR rather than the RGB
# order the ImageNet normalisation assumes -- confirm.

# 2) apply the transformations and add a leading batch dimension
img_tens = transform(Image.fromarray(img))
img_tens = img_tens.unsqueeze(0)

# 3) extract embeddings, L2-normalise them, convert back to a numpy array.
#    no_grad() skips building the autograd graph, which inference does not need.
with torch.no_grad():
    features = model(img_tens)
    features = nn.functional.normalize(features)
features = features.detach().numpy()
features

In [None]:
## display the loaded ROI; its colour channels were zeroed outside the mask in
## the previous cell, because `img` there is a view into `img_RGBA`
pp.show_image(img_RGBA)

Let's build a little pipeline to do this for all our ROIs. First, we create a list of all image files:

In [None]:
## list all ROI files; os.listdir returns names in arbitrary order, so sort them
## to make the processing order (and the row order of the saved embeddings,
## which feeds t-SNE) reproducible
list_all_masks = sorted(os.listdir("phenopype/rois"))
list_all_masks

Then we place our inference code into a loop and track its progress with a progress bar:

In [None]:
## create an empty dictionary to store embeddings (file name -> 1-D feature vector)
results_dict_embeddings = dict()

## run inference for every ROI. tqdm wraps the iterable, so the progress bar is
## advanced and closed automatically (the manual bar was never closed before).
## no_grad() skips building the autograd graph, which feature extraction does not need.
with torch.no_grad():
    for img_name in tqdm(list_all_masks, position=0, leave=False, desc="Embedding images"):

        ## construct img_path
        img_path = os.path.join("phenopype", "rois", img_name)

        ## 1) load image and black out pixels outside the alpha mask
        # NOTE(review): loaded without mode="rgb" -- channel order may be BGR; confirm.
        img_RGBA = pp.load_image(img_path)
        img, alpha_channel = img_RGBA[:, :, :3], img_RGBA[:, :, 3]
        mask = alpha_channel > 0
        img[~mask] = 0

        # 2) apply the transformations and add a leading batch dimension
        img_tens = transform(Image.fromarray(img))
        img_tens = img_tens.unsqueeze(0)

        # 3) extract embeddings, L2-normalise, convert back to a numpy array
        features = model(img_tens)
        features = nn.functional.normalize(features)
        features = features.detach().numpy()
        results_dict_embeddings[img_name] = features[0]

In [None]:
## save results
## save results: one row per ROI -- the first column "mask_name" holds the file
## name, followed by one column per embedding dimension
os.makedirs("data", exist_ok=True)
results_df = (
    pd.DataFrame.from_dict(results_dict_embeddings, orient="index")
    .rename_axis("mask_name")
    .reset_index()
)
results_df.to_csv("data/embeddings.csv", index=False)

In [None]:
## display the embedding table (one row per ROI, "mask_name" + embedding columns)
results_df

## Clustering

We need to load some more packages for clustering and plotting:

In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from bokeh.io import show
from bokeh.models import Div, HoverTool, CustomJS, ColumnDataSource, tickers
from bokeh.plotting import figure, output_file
from bokeh.transform import linear_cmap, factor_cmap
from bokeh.palettes import all_palettes
from bokeh import layouts

def get_offset_img(path, zoom=1, max_dim=250):
    """Load an image file and wrap it as a matplotlib ``OffsetImage`` pictogram.

    The image is loaded with ``pp.load_image(path, mode="rgb")`` and resized with
    ``pp.resize_image(..., max_dim=max_dim)`` before wrapping.

    Parameters
    ----------
    path : str
        Path to the image file.
    zoom : float, optional
        Zoom factor passed to ``OffsetImage`` (default 1).
    max_dim : int, optional
        Value passed to ``pp.resize_image`` as ``max_dim`` (default 250, the
        previously hard-coded value); presumably caps the longest image side.

    Returns
    -------
    OffsetImage
        Artist suitable for placement with ``AnnotationBbox``.
    """
    image = pp.load_image(path, mode="rgb")
    image_resized = pp.resize_image(image, max_dim=max_dim)
    return OffsetImage(image_resized, zoom=zoom)

def assign_color(name):
    """Return a plot colour for a feature name based on its channel prefix.

    Names beginning with "red", "blue" or "green" are mapped to that colour
    (checked in that order); every other name (e.g. shape features) is
    mapped to "purple".
    """
    for channel in ("red", "blue", "green"):
        if name.startswith(channel):
            return channel
    return 'purple'

def sample_group(group, fraction, seed=42):
    """Draw a reproducible random subset of rows from ``group``.

    Intended for use with ``DataFrame.groupby(...).apply``. Takes
    ``int(len(group) * fraction)`` rows without replacement, but at least one
    row, and never more rows than the group contains. The upper clamp means
    ``fraction > 1`` and empty groups no longer raise.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows to sample from.
    fraction : float
        Fraction of rows to keep.
    seed : int, optional
        ``random_state`` passed to ``DataFrame.sample`` (default 42).

    Returns
    -------
    pandas.DataFrame
        The sampled rows.
    """
    num_samples = max(1, int(len(group) * fraction))
    # sampling without replacement cannot draw more rows than exist
    num_samples = min(num_samples, len(group))
    return group.sample(n=num_samples, replace=False, random_state=seed)

In [None]:
## we can cluster the embeddings...
## reload the saved table and prefix every embedding column (all but the first,
## "mask_name") with "emb_" so the embedding columns can be selected by name
data_emb = pd.read_csv("data/embeddings.csv")
embedding_columns = list(data_emb.columns[1:])
data_emb = data_emb.rename(columns={c: f'emb_{c}' for c in embedding_columns})
cols_emb = [c for c in data_emb.columns if c.startswith("emb")]
data_emb

In [None]:
## ... for instance using tSNE (t-distributed stochastic neighbor embedding)

## this can take A LOT of time
## scikit-learn requires perplexity < number of samples and raises a ValueError
## otherwise, so cap it for small ROI collections (at most 100, as before)
perplexity = min(100, max(1, len(data_emb) - 1))
tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)  # n_components=2 for 2D visualization
transformed_data = tsne.fit_transform(data_emb[cols_emb])


In [None]:
## prep plot df: ROI names next to their 2-D t-SNE coordinates
data_plot = pd.concat([data_emb[["mask_name"]], pd.DataFrame(
    data = transformed_data, columns = ['Dim1', 'Dim2']) ], axis=1)

## image paths: "mask_path" is relative to the working directory (matplotlib);
## "mask_path_plots" is relative to the plots/ folder, where the HTML output
## is written, hence the leading ".."
data_plot['mask_path'] = [os.path.join("phenopype", "rois", name) for name in data_plot['mask_name']]
data_plot['mask_path_plots'] = [os.path.join("..", path) for path in data_plot['mask_path']]

In [None]:
## let's plot the embeddings! -- plain scatter of the two t-SNE dimensions
fig, ax = plt.subplots(figsize=(7, 7))
plt.tight_layout(pad=3)
ax.set(aspect='equal', title="TSNE - embeddings")
ax.set(xlabel="Dimension 1", ylabel="Dimension 2")
scatter = ax.scatter(data_plot['Dim1'], data_plot['Dim2'], s=20)

In [None]:
## with pictograms: same scatter, with a thumbnail of each ROI drawn at its position
fig, ax = plt.subplots(figsize=(7, 7))
plt.tight_layout(pad=3)
ax.set_aspect('equal')
ax.set_title("TSNE - embeddings")
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
scatter = ax.scatter(data_plot['Dim1'], data_plot['Dim2'], s=20)

## tqdm wraps the iterable, so the progress bar is advanced and closed
## automatically (the manual bar was never closed before)
rows = data_plot.itertuples(index=False)
for row in tqdm(rows, total=len(data_plot), position=0, leave=False, desc="Plotting..."):
    ab = AnnotationBbox(get_offset_img(row.mask_path, zoom=0.1),
        (row.Dim1, row.Dim2), frameon=False)
    ax.add_artist(ab)

os.makedirs("plots", exist_ok=True)
fig.savefig("plots/embeddings_pictograms.png", dpi=300)

In [None]:
## prepare html output; bokeh writes the file when show() is called, so make
## sure the target folder exists even if this cell is run on its own
os.makedirs("plots", exist_ok=True)
filepath = "plots/embeddings_interactive.html"
output_file(filepath)

## convert to datasource (one column per data_plot column)
ds_points = ColumnDataSource(data=dict(data_plot))

## add hover tool: tooltip with the ROI file name
hover=HoverTool(
        tooltips = [
        ("mask_name", "@mask_name"),
        ]
)

## add interactive panel: on hover, show the image of the first hovered point in
## the side panel. mask_path_plots is relative to plots/, where the HTML file lives.
div = Div(text="")
hover.callback = CustomJS(args=dict(div=div, ds=ds_points), code="""
    const hit_test_result = cb_data.index;
    const indices = hit_test_result.indices;
    if (indices.length > 0) {
        div.text = `<img 
        src="${ds.data['mask_path_plots'][indices[0]]}"
        style="float: left; margin: 0px 15px 15px 0px; max-width: 500px; max-height: 500px; width: auto; height: auto;"
        border="2"
        />`;
    }
    """)

## create scatter panel
p = figure(tools=[
    "pan", 'reset', 'wheel_zoom', 'box_zoom', "lasso_select", "tap", hover],
    active_scroll="wheel_zoom",active_drag ="pan", output_backend="webgl",
    width=1200, height=800, match_aspect=True,
    x_axis_label="Dim1", y_axis_label="Dim2")
# NOTE(review): a tick every 1 unit may get crowded if the t-SNE coordinates
# span a wide range -- consider a larger interval
p.xaxis.ticker = tickers.SingleIntervalTicker(interval=1)
p.yaxis.ticker = tickers.SingleIntervalTicker(interval=1)
p.scatter(x='Dim1', y='Dim2', source=ds_points, size=15)

## plot: scatter on the left, image panel on the right
layout = layouts.row(p, div)
show(layout)

In [None]:
## make sure the output folder for plots exists (exist_ok=True makes this a
## no-op if it was already created)
os.makedirs("plots", exist_ok=True)

## A more complete example

Let's load segmentation masks, handcrafted features, and embeddings that I created for a larger Junonia dataset using a pretrained encoder (a Vision Transformer, ViT). First we load all files, then merge them together:  

In [None]:
## handcrafted features and embeddings, one row per mask in each file
df_feat = pd.read_csv("data_raw/data_features.csv")
df_emb = pd.read_csv("data_raw/data_embeddings.csv")

## outer join on mask + species keeps masks present in only one file
## NOTE(review): such masks get NaN in the other file's columns, which the
## scaler/PCA below will reject -- check both files cover the same masks
data_all = df_feat.merge(df_emb, how='outer', on=['mask_name', "species"])

## select column groups by name prefix
feature_prefixes = ("shape", "red", "green", "blue")
cols_emb = [c for c in data_all.columns if c.startswith("emb")]
cols_feat = [c for c in data_all.columns if c.startswith(feature_prefixes)]

In [None]:
## prep pca: standardise the embeddings (PCA input) and the handcrafted
## features (supplementary variables); fit_transform refits the scaler each time
scaler = StandardScaler()
X_scaled_main = scaler.fit_transform(data_all[cols_emb])
X_scaled_support = scaler.fit_transform(data_all[cols_feat].values)

## do pca. n_components may not exceed min(n_samples, n_features), so cap it for
## datasets with fewer samples than embedding dimensions. fit_transform already
## returns the PC scores, so a separate transform() pass is unnecessary.
n_components = min(len(cols_emb), len(data_all))
pca = PCA(n_components=n_components)
components_main = pca.fit_transform(X_scaled_main)

## Supplementary variables: covariance of each standardised handcrafted feature
## with each PC score (both are mean-centred). NOTE: this is approximately
## correlation * std(PC score), i.e. NOT bounded to [-1, 1]; the arrow length
## multiplier in the plot operates on these values.
supp_loadings = np.dot(components_main.T, X_scaled_support) / (X_scaled_support.shape[0] - 1)
data_supp = pd.DataFrame(supp_loadings.T[:,:2], columns=["Dim1","Dim2"])
data_supp["features"] = cols_feat
data_supp['Color'] = data_supp['features'].apply(assign_color)

## percent variance explained by the first two components
var_explained = pca.explained_variance_ratio_[:2]

In [None]:
## do tsne (takes a few seconds)
## scikit-learn requires perplexity < number of samples, so cap it (at most 100)
perplexity = min(100, max(1, len(data_all) - 1))
tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)  # n_components=2 for 2D visualization
transformed_data = tsne.fit_transform(data_all[cols_emb])

In [None]:
## prep plot df: metadata + t-SNE coordinates + first two PC scores
data_plot = pd.concat([data_all[["mask_name", "species"]], pd.DataFrame(data = transformed_data, columns = ['Dim1', 'Dim2'])], axis=1)
data_plot = pd.concat([data_plot, pd.DataFrame(data = components_main[:,:2], columns = ['PC1', 'PC2'])], axis=1)

## image paths: "mask_path" is relative to the working directory (matplotlib);
## "mask_path_plots" is relative to plots/, where the HTML file is written
data_plot['mask_path'] = [
    os.path.join("data_raw", "segmentation_masks", species, name)
    for species, name in zip(data_plot['species'], data_plot['mask_name'])
]
data_plot['mask_path_plots'] = [os.path.join("..", path) for path in data_plot['mask_path']]

## random 10% of each species' masks (at least one per species) for the pictograms
mask_subset = data_plot.groupby('species').apply(sample_group, fraction=0.10, seed=42)
mask_subset.reset_index(drop=True, inplace=True)

# colors for bokeh: one hex colour per species, sampled from nipy_spectral
unique_species = data_plot['species'].unique()
color_map = plt.get_cmap("nipy_spectral", len(unique_species))
color_palette = [mcolors.to_hex(color_map(i)) for i in range(len(unique_species))]
color_mapper = factor_cmap('species', palette=color_palette, factors=unique_species)

## colors for matplotlib: same species -> colour assignment as the bokeh plot
## (data_supp['Color'] is already set in the PCA cell)
category_to_color = dict(zip(unique_species, color_palette))
data_plot['Color_species'] = data_plot['species'].map(category_to_color)

In [None]:
## prepare html output; bokeh writes the file when show() is called, so make
## sure the target folder exists even if this cell is run on its own
os.makedirs("plots", exist_ok=True)
filepath = "plots/embeddings_interactive_complete-ex.html"
output_file(filepath)

## convert to datasource (one column per data_plot column)
ds_points = ColumnDataSource(data=dict(data_plot))

## add hover tool: tooltip with the species name
hover=HoverTool(
        tooltips = [
        ("species", "@species"),
        ]
)

## add interactive panel: on hover, show the image of the first hovered point in
## the side panel. mask_path_plots is relative to plots/, where the HTML file lives.
div = Div(text="")
hover.callback = CustomJS(args=dict(div=div, ds=ds_points), code="""
    const hit_test_result = cb_data.index;
    const indices = hit_test_result.indices;
    if (indices.length > 0) {
        div.text = `<img 
        src="${ds.data['mask_path_plots'][indices[0]]}"
        style="float: left; margin: 0px 15px 15px 0px; max-width: 500px; max-height: 500px; width: auto; height: auto;"
        border="2"
        />`;
    }
    """)

## create scatter panel, points coloured by species
p = figure(tools=[
    "pan", 'reset', 'wheel_zoom', 'box_zoom', "lasso_select", "tap", hover],
    active_scroll="wheel_zoom",active_drag ="pan", output_backend="webgl",
    width=1200, height=800, match_aspect=True,
    x_axis_label="Dim1", y_axis_label="Dim2")
p.xaxis.ticker = tickers.SingleIntervalTicker(interval=50)
p.yaxis.ticker = tickers.SingleIntervalTicker(interval=50)
p.scatter(x='Dim1', y='Dim2', source=ds_points, color=color_mapper, size=10)

## plot: scatter on the left, image panel on the right
layout = layouts.row(p, div)
show(layout)

In [None]:
## scatter plot of the first two principal components, coloured by species
fig, ax = plt.subplots(figsize=(15, 15))
plt.tight_layout(pad=3)
ax.set_aspect('equal')
ax.set_title("PCA - Junonia all")
ax.set_xlabel(f"Dimension 1 - ({var_explained[0]*100:.2f}% variance explained)")
ax.set_ylabel(f"Dimension 2 - ({var_explained[1]*100:.2f}% variance explained)")
scatter = ax.scatter(data_plot['PC1'], data_plot['PC2'], c=data_plot["Color_species"], s=15)

## pictograms for the sampled subset. tqdm wraps the iterable, so the progress
## bar is advanced and closed automatically (the manual bar was never closed).
for row in tqdm(mask_subset.itertuples(index=False), total=len(mask_subset),
                position=0, leave=False, desc="Plotting..."):
    ab = AnnotationBbox(get_offset_img(row.mask_path, zoom=0.2),
        (row.PC1, row.PC2), frameon=False)
    ax.add_artist(ab)

## arrows for a selection of supplementary (handcrafted) features
selected_features = [
    "shape_area", "shape_diameter", "shape_circularity", "shape_hu3",
    "red_mean", "red_variance", "red_uniformity",
    "green_mean", "green_variance", "green_uniformity",
    "blue_mean", "blue_variance", "blue_uniformity"]
data_supp_sub = data_supp[data_supp["features"].isin(selected_features)]
amod = 5  # arrow length multiplier applied to the supplementary loadings
## select columns by name rather than relying on data_supp's column order
for load_x, load_y, feature, color in zip(data_supp_sub["Dim1"], data_supp_sub["Dim2"],
                                          data_supp_sub["features"], data_supp_sub["Color"]):
    ax.arrow(0, 0, load_x*amod, load_y*amod, color=color, alpha=1, width=0.1, head_width=0.5, length_includes_head=True, zorder=4)
    ## place the label at the arrow tip, aligned away from the origin
    ax.text(load_x*amod, load_y*amod, feature, color=color, fontsize=8, zorder=5,
            bbox=dict(facecolor="white", alpha=1, edgecolor=color, linewidth=0.5, boxstyle="Round"),
            ha='right' if load_x < 0 else 'left', va='bottom' if load_y < 0 else 'top')

os.makedirs("plots", exist_ok=True)
fig.savefig("plots/embeddings_pictograms_supp.png", dpi=300)