In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import glob
from utils.sb_clustering import ClusteringFramework

import glob
import pickle
import numpy as np
import itertools
from tqdm import tqdm
from sklearn.cluster import DBSCAN
from joblib import Parallel, delayed


### Find the best feature weighting and DBSCAN hyperparameters

In [None]:

# Cluster a single vessel and compute its clustering score
def process_vessel(set_, vessel, modality, feature_weights, eps, min_samples, viz=False):
    """Fit DBSCAN on one vessel's stored SuperGlue clusters and score the result.

    Loads ``{modality}_superglue_clusters_using_gt.pkl`` for the vessel. It stacks
    columns 0-4 of every stored cluster into one array and labels each row with the
    index of the cluster it came from. It then fits DBSCAN through
    ``ClusteringFramework`` and evaluates the clustering.

    Args:
        set_: dataset split directory name (the glob cells use the ``val`` folder).
        vessel: vessel directory name inside ``set_``.
        modality: file-name prefix, e.g. ``'ivus'`` or ``'oct'``.
        feature_weights: per-feature weights forwarded to ``ClusteringFramework``.
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN minimum samples per core point.
        viz: if True, also call ``framework.visualize_clusters()``.

    Returns:
        tuple ``(combined_score, framework)``: the ``"Combined Score"`` entry of
        ``evaluate_clustering()`` and the fitted framework itself.
    """
    # NOTE(review): this reads from "Registration Dataset v2", while the glob cells list
    # vessels from "Registration Dataset" — confirm both trees contain the same vessels.
    pkl_path = f'../Data/Registration Dataset v2/{set_}/{vessel}/{modality}_superglue_clusters_using_gt.pkl'
    with open(pkl_path, "rb") as fh:
        clusters = pickle.load(fh)

    # Each stored cluster contributes a block of rows; its index is the reference label.
    blocks = []
    labels = []
    for cluster_idx in range(len(clusters)):
        cluster = clusters[cluster_idx]
        blocks.append(cluster[:, [0, 1, 2, 3, 4]])
        labels.extend([cluster_idx] * cluster.shape[0])
    points = np.concatenate(blocks, axis=0)

    framework = ClusteringFramework(points, labels, feature_weights)
    framework.fit_dbscan(eps=eps, min_samples=min_samples)
    metrics = framework.evaluate_clustering()

    if viz:
        framework.visualize_clusters()

    return metrics["Combined Score"], framework

# Parallel work item: score one vessel and keep only the scalar score.
def _vessel_score(set_, vessel, modality, feature_weights, eps, min_samples):
    """Return only the combined score from ``process_vessel``.

    ``process_vessel`` returns ``(score, framework)``. Dropping the fitted framework
    here lets ``np.mean`` average plain numbers. It also avoids sending every fitted
    framework back from the joblib workers.
    """
    score, _ = process_vessel(set_, vessel, modality, feature_weights, eps, min_samples)
    return score


# Function to optimize global hyperparameters
def optimize_global_hyperparams(modality, feature_weight_ranges=None, eps_range=None,
                                min_samples_range=None, vessel_sets=None, vessel_names=None,
                                n_jobs=-1):
    """Grid-search feature weights and DBSCAN settings, maximizing the mean vessel score.

    Every combination of ``itertools.product(*feature_weight_ranges)``, ``eps_range``
    and ``min_samples_range`` is scored on all vessels in parallel. Each weight tuple
    ``(w0, w1, w2)`` is expanded to ``[w0, w1, w1, w2]``, so features 1 and 2 share a
    weight.

    Args:
        modality: file-name prefix passed to ``process_vessel`` (e.g. 'ivus' or 'oct').
        feature_weight_ranges: candidate lists for the three weights; defaults to the
            notebook global ``FEATURE_WEIGHT_RANGES``.
        eps_range: candidate DBSCAN ``eps`` values; defaults to ``EPS_RANGE``.
        min_samples_range: candidate ``min_samples`` values; defaults to
            ``MIN_SAMPLES_RANGE``.
        vessel_sets: per-vessel split names; defaults to ``all_vessel_sets``.
        vessel_names: per-vessel directory names, aligned with ``vessel_sets``;
            defaults to ``all_vessel_names``.
        n_jobs: joblib worker count (-1 uses all CPUs).

    Returns:
        dict with keys ``"feature_weights"``, ``"eps"``, ``"min_samples"`` and
        ``"score"`` for the highest mean score. On ties, the first combination is kept.
        Returns None if no mean score exceeded ``-inf`` (e.g. all scores were NaN).

    Raises:
        ValueError: if there are no vessels to evaluate.
    """
    # Fall back to the notebook-level globals so existing calls keep working.
    if feature_weight_ranges is None:
        feature_weight_ranges = FEATURE_WEIGHT_RANGES
    if eps_range is None:
        eps_range = EPS_RANGE
    if min_samples_range is None:
        min_samples_range = MIN_SAMPLES_RANGE
    if vessel_sets is None:
        vessel_sets = all_vessel_sets
    if vessel_names is None:
        vessel_names = all_vessel_names

    vessels = list(zip(vessel_sets, vessel_names))
    if not vessels:
        # Otherwise every combination averages an empty list (NaN) and None is returned.
        raise ValueError("No vessels to evaluate: vessel_sets / vessel_names are empty.")

    # Grid size for the progress bar, computed without materializing the grid.
    n_combinations = len(eps_range) * len(min_samples_range)
    for weight_options in feature_weight_ranges:
        n_combinations *= len(weight_options)

    best_score = -np.inf
    best_params = None

    grid = itertools.product(itertools.product(*feature_weight_ranges), eps_range, min_samples_range)
    for weight_combo, eps, min_samples in tqdm(grid, total=n_combinations):
        # The middle weight is shared by features 1 and 2.
        feature_weights = [weight_combo[0], weight_combo[1], weight_combo[1], weight_combo[2]]

        # Parallel computation across all vessels; each worker returns a scalar score.
        scores = Parallel(n_jobs=n_jobs)(
            delayed(_vessel_score)(set_, vessel, modality, feature_weights, eps, min_samples)
            for set_, vessel in vessels
        )

        avg_score = float(np.mean(scores))

        # Track best global parameters (strict '>' keeps the first of tied combinations).
        if avg_score > best_score:
            best_score = avg_score
            best_params = {
                "feature_weights": feature_weights,
                "eps": eps,
                "min_samples": min_samples,
                "score": best_score
            }

        print(f"Params: {feature_weights}, eps: {eps}, min_samples: {min_samples}, Avg Score: {avg_score}")

    return best_params

## IVUS sb clustering - hyperparam search

In [None]:
# Define hyperparameter ranges
FEATURE_WEIGHT_RANGES = [
    [4.9, 5.0, 5.1],  # Feature 0 weight
    [1.6, 1.65, 1.7, 1.75],  # Feature 1+2 weight (shared by features 1 and 2)
    [0.45, 0.5, 0.55, 0.60]   # Feature 3 weight
]
EPS_RANGE = [0.27, 0.28, 0.29]
MIN_SAMPLES_RANGE = [1]

# Load dataset paths. optimize_global_hyperparams reads these globals by default, so
# they must be set before it runs; on a fresh kernel the next cell has not executed yet.
all_vessels = glob.glob('../Data/Registration Dataset/val/*/ivus_superglue_clusters.pkl')
all_vessel_names = [x.split('/')[-2] for x in all_vessels]
all_vessel_sets = [x.split('/')[-3] for x in all_vessels]
assert all_vessels, "No IVUS cluster files found - check the dataset path."

# Run global optimization
best_global_params = optimize_global_hyperparams(modality='ivus')
print(f"\nBest Global Parameters: {best_global_params}")

# Save best global results
with open("best_global_dbscan_hyperparams_ivus.pkl", "wb") as f:
    pickle.dump(best_global_params, f)

print("Global hyperparameter tuning complete! Results saved.")

In [None]:
# Load dataset paths
all_vessels = glob.glob('../Data/Registration Dataset/val/*/ivus_superglue_clusters.pkl')
all_vessel_names = [x.split('/')[-2] for x in all_vessels]
all_vessel_sets = [x.split('/')[-3] for x in all_vessels]

feature_weights = [5.0, 1.65, 1.65, 0.55]
eps = 0.29
min_samples = 1

for vessel, set_ in zip(all_vessel_names, all_vessel_sets):

    results = process_vessel(set_, vessel, 'ivus', feature_weights, eps, min_samples, viz=True)

    print(results)

## OCT sb clustering - hyperparam search

In [None]:
# Vessels to evaluate (optimize_global_hyperparams reads all_vessel_sets / all_vessel_names)
all_vessels = glob.glob('../Data/Registration Dataset/val/*/oct_superglue_clusters.pkl')
all_vessel_sets = [path.split('/')[-3] for path in all_vessels]
all_vessel_names = [path.split('/')[-2] for path in all_vessels]

# OCT search grid: weights for feature 0, features 1+2 (shared) and feature 3,
# followed by the DBSCAN eps / min_samples candidates.
FEATURE_WEIGHT_RANGES = [
    [8.03, 8.05, 8.07],
    [0.98, 1, 1.02],
    [0.07, 0.08, 0.09, 0.1, 0.11, 0.12],
]
EPS_RANGE = [0.29, 0.3, 0.31]
MIN_SAMPLES_RANGE = [1]

best_global_params = optimize_global_hyperparams(modality='oct')
print(f"\nBest Global Parameters: {best_global_params}")

# Persist the winning configuration
with open("best_global_dbscan_hyperparams_oct.pkl", "wb") as f:
    pickle.dump(best_global_params, f)

print("Global hyperparameter tuning complete! Results saved.")

In [None]:
# Load dataset paths
all_vessels = glob.glob('../Data/Registration Dataset/val/*/oct_superglue_clusters.pkl')
all_vessel_names = [x.split('/')[-2] for x in all_vessels]
all_vessel_sets = [x.split('/')[-3] for x in all_vessels]

# NOTE(review): these weights lie outside the OCT search grid defined above (feature-0
# weight 8.03-8.07, feature-3 weight 0.07-0.12). Confirm they are the intended final
# choice rather than the result saved to best_global_dbscan_hyperparams_oct.pkl.
feature_weights = [8.5, 1, 1, 0]
eps = 0.3
min_samples = 1

for vessel, set_ in zip(all_vessel_names, all_vessel_sets):
    # process_vessel returns (combined_score, fitted_framework). Print only the score,
    # not the framework object's repr.
    score, _ = process_vessel(set_, vessel, 'oct', feature_weights, eps, min_samples, viz=True)
    print(f"{set_}/{vessel}: Combined Score = {score}")