In [2]:
import pickle
import numpy as np
import pandas as pd
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import fcluster

# TreeNode class definition
class TreeNode:
    """Node of a binary tree.

    Each node stores references to a left and a right child together with a
    ``filenames`` value and a ``data`` value. All four fields default to
    ``None``. The attribute names are what the rest of the notebook reads
    (``left``, ``right``, ``filenames``), so they must stay as they are.
    """

    def __init__(self, left=None, right=None, filenames=None, data=None):
        # Child links first, then the payload attached to this node.
        self.left, self.right = left, right
        self.filenames, self.data = filenames, data

# Load the tree
def load_tree(input_file):
    """Unpickle and return the object stored in ``input_file``.

    The file is opened in binary mode and closed again before returning.

    NOTE: unpickling can execute arbitrary code, so only load files that come
    from a trusted source.
    """
    with open(input_file, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

# Unpickle the dendrogram.
# NOTE(review): presumably the pickle contains TreeNode instances, which is why
# the class has to be defined in this notebook before loading — confirm.
tree = load_tree('../sample_data/combined_dendrogram.pkl')

# Load the data: the .npz archive of reduced embeddings ...
embeddings = np.load('../sample_data/combined_reduced_embeddings.npz')

# ... and the filenames CSV (no header row), flattened to a 1-D array.
filenames = (
    pd.read_csv('../sample_data/combined_filenames.csv', header=None)
    .to_numpy()
    .flatten()
)

In [3]:
# List the names of the arrays stored in the embeddings archive
embeddings.files


['pca5', 'tsne2', 'umap5', 'umap2']

In [None]:

# Get UMAP2 reduced points: the 'umap2' array of the embeddings archive.
# Its first two columns are used as x/y coordinates by the scatter plots below.
umap2_points = embeddings['umap2']

# Function to get clusters at a given level
def get_clusters(tree, level):
    """Return the filename groups found ``level`` splits below ``tree``.

    Starting from the root, every node is replaced by its children once per
    level, left before right, so the result keeps the left-to-right order of
    the tree. A leaf (a node without children) cannot be split and is carried
    down unchanged; asking for more levels than the tree has therefore yields
    the leaves. A node with only one child is replaced by that child instead
    of failing on the missing side.

    Args:
        tree: Root node exposing ``left``, ``right`` and ``filenames``.
        level: Number of splits to descend. Zero or less returns the root
            itself as a single cluster.

    Returns:
        A list with one entry per cluster; each entry is the ``filenames``
        attribute of the node representing that cluster.
    """
    frontier = [tree]
    remaining = level
    while remaining > 0:
        expanded = []
        split_any = False
        for node in frontier:
            # Keep only the children that exist so a one-sided node does not
            # end up dereferencing None.
            children = [child for child in (node.left, node.right) if child is not None]
            if children:
                expanded.extend(children)
                split_any = True
            else:
                expanded.append(node)
        if not split_any:
            # Only leaves remain; deeper levels cannot change the result.
            break
        frontier = expanded
        remaining -= 1
    return [node.filenames for node in frontier]

# Function to display metrics
def display_metrics(tree, level, embeddings):
    """Print (and return) the silhouette score of the clustering at ``level``.

    Clusters come from ``get_clusters(tree, level)``. Each filename in a
    cluster is mapped to its first position in the module-level ``filenames``
    array, and that position is the row of ``embeddings`` that receives the
    cluster's label.

    Rows that no cluster claims are left out of the score (and reported)
    rather than being counted as members of the first cluster. When the
    clustering has fewer than two clusters, or as many clusters as samples,
    the silhouette score is undefined; this is reported instead of raising so
    that a loop over several levels can continue.

    Args:
        tree: Root node passed on to ``get_clusters``.
        level: Cut level passed on to ``get_clusters``.
        embeddings: Array-like with one row per entry of ``filenames``.

    Returns:
        The silhouette score, or ``None`` when it is undefined.

    Raises:
        IndexError: If a filename in the tree is missing from ``filenames``.
    """
    clusters = get_clusters(tree, level)

    # Build the filename -> row lookup once instead of scanning the whole
    # array for every file; setdefault keeps the first occurrence.
    row_of = {}
    for row, name in enumerate(filenames):
        row_of.setdefault(name, row)

    # -1 marks rows that have not been claimed by any cluster.
    cluster_labels = np.full(len(embeddings), -1, dtype=int)

    for i, cluster in enumerate(clusters):
        try:
            indices = [row_of[fname] for fname in cluster]
        except KeyError as missing:
            raise IndexError(
                f"filename {missing.args[0]!r} from the tree is not in the filenames list"
            ) from None
        cluster_labels[indices] = i

    assigned = cluster_labels >= 0
    n_assigned = int(assigned.sum())
    if n_assigned < len(cluster_labels):
        print(
            f"Level {level}: {len(cluster_labels) - n_assigned} samples belong to "
            f"no cluster and are excluded from the score"
        )

    # silhouette_score only accepts between 2 and n_samples - 1 distinct labels.
    n_labels = len(np.unique(cluster_labels[assigned]))
    if not 2 <= n_labels <= n_assigned - 1:
        print(
            f"Silhouette score for level {level}: undefined "
            f"({n_labels} clusters for {n_assigned} samples)"
        )
        return None

    score = silhouette_score(np.asarray(embeddings)[assigned], cluster_labels[assigned])
    print(f"Silhouette score for level {level}: {score}")
    return score

# Display metrics for various levels.
# The archive has no 'embeddings' array (its keys are 'pca5', 'tsne2', 'umap5'
# and 'umap2'), so indexing it with 'embeddings' raises KeyError.
# NOTE(review): 'umap5' is an assumption — switch to whichever embedding the
# dendrogram was actually built from.
metric_points = embeddings['umap5']  # loaded once; each archive lookup reads from disk
for level in range(1, 6):
    display_metrics(tree, level, metric_points)

# Function to visualize clusters
def visualize_clusters(tree, level, points):
    """Scatter-plot ``points`` with one series per cluster found at ``level``.

    Clusters come from ``get_clusters(tree, level)``. Every filename in a
    cluster is matched to its first position in the module-level ``filenames``
    array; that position selects the row of ``points`` to draw. Columns 0 and
    1 of ``points`` are used as the x and y coordinates. The figure is shown
    immediately and nothing is returned.
    """
    clusters = get_clusters(tree, level)

    _, ax = plt.subplots(figsize=(10, 6))

    for i, members in enumerate(clusters):
        rows = []
        for fname in members:
            matches = np.where(filenames == fname)[0]
            rows.append(matches[0])  # first occurrence of the filename
        ax.scatter(points[rows, 0], points[rows, 1], label=f"Cluster {i + 1}")

    ax.legend()
    ax.set_title(f"UMAP2 reduced points visualization for level {level}")
    ax.set_xlabel("UMAP2-1")
    ax.set_ylabel("UMAP2-2")
    plt.show()

# Visualize clusters at various levels using UMAP2 reduced points:
# one figure per level, for levels 1 through 5.
for level in range(1, 6):
    visualize_clusters(tree, level, umap2_points)