In [None]:
import sys
sys.path.append('../../annotations')
import annotation_metrics as am

import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE, Isomap, MDS
from sklearn.decomposition import PCA

# Score columns analysed per annotation: one "<dimension>_norm_mean" column
# for each of the eight dimensions listed below.
metrics = [
    f'{dimension}_norm_mean'
    for dimension in ('pathophysiology', 'epidemiology', 'etiology', 'history',
                      'physical', 'exams', 'differential', 'therapeutic')
]

# Clustering method names follow the pattern "<algorithm> 3 <variant>":
# every algorithm appears once with the "emb" and once with the "cat" variant.
cluster_methods = [
    f'{algorithm} 3 {variant}'
    for algorithm in ('kmeans', 'agg', 'gmm', 'birch')
    for variant in ('emb', 'cat')
]

# Display names looked up by cluster id: the first three ids have descriptive
# names, the remaining ones ("4" .. "20") are plain numbers.
cluster_names = ["Novice", "Developing", "Proficient"] + [str(number) for number in range(4, 21)]

In [None]:
annotation_sets = ['Human_84', 'Human_435', 'BioBERT', 'BioBERT_Llama','Llama', 'Llama_aug']

# Per-annotation metrics tables, one DataFrame per annotation set. Each table
# is keyed by the 'annotation id' column and carries the 'cluster <method>'
# columns read by the plotting loop.
# NOTE(review): 'Llama_aug' currently reads the same CSV as 'Llama'; the
# augmented file is kept commented out below -- confirm this is intentional.
metrics_all = {'Human_84': pd.read_csv('medical_specialist/annotations-dpoc-medical_specialist_metrics_84.csv'),
               'Human_435': pd.read_csv('medical_specialist/aligned_annotations-dpoc-medical_specialist_metrics_435.csv'),
               'BioBERT': pd.read_csv('biobert_balanced/annotations-dpoc-biobert_metrics.csv'),
               'BioBERT_Llama': pd.read_csv('biobert-llama_balanced/annotations-dpoc-biobert-llama_metrics.csv'),
               'Llama': pd.read_csv('llama/annotations-dpoc-llm_10_tf_idf_custom_shot_metrics.csv'),
            #    'Llama_aug': pd.read_csv('llama/annotations-dpoc-llm_augmented_10_tf_idf_custom_shot_metrics.csv')
               'Llama_aug': pd.read_csv('llama/annotations-dpoc-llm_10_tf_idf_custom_shot_metrics.csv')}


def _align_on_annotation_id(frames, labels, id_column='annotation id'):
    """Keep, in every ``frames[label]``, only rows whose id occurs in all listed frames.

    Each frame is filtered with a mask computed from its *own* id column.
    Applying a mask built on one frame to another frame would make pandas
    align the mask by row index instead of by id, which either raises an
    "unalignable boolean Series" error or silently keeps the wrong rows.

    Parameters
    ----------
    frames : dict[str, pandas.DataFrame]
        Mapping that is updated in place with the filtered frames.
    labels : list[str]
        Keys of ``frames`` that must end up covering the same ids.
    id_column : str
        Name of the column holding the annotation id.

    Returns
    -------
    pandas.Series
        Boolean mask over the *unfiltered* ``frames[labels[0]]`` marking the
        rows that were kept.

    Raises
    ------
    ValueError
        If the filtered frames still differ in length (e.g. duplicated ids),
        because the downstream comparisons pair rows by position.
    """
    shared_ids = set(frames[labels[0]][id_column])
    for label in labels[1:]:
        shared_ids &= set(frames[label][id_column])

    reference_mask = frames[labels[0]][id_column].isin(shared_ids)
    for label in labels:
        frame = frames[label]
        frames[label] = frame[frame[id_column].isin(shared_ids)]

    row_counts = {label: len(frames[label]) for label in labels}
    if len(set(row_counts.values())) > 1:
        raise ValueError(
            f'Annotation sets are not row-aligned after filtering on {id_column!r}: {row_counts}')
    return reference_mask


# Sort every table by id so that, once restricted to the same ids, rows of
# different annotation sets line up by position.
for metric_label in metrics_all:
    metrics_all[metric_label] = metrics_all[metric_label].sort_values('annotation id')

# The two groups below are the sets that get compared against each other.
common_ids_435 = _align_on_annotation_id(metrics_all, ['Human_435', 'Llama', 'Llama_aug'])
common_ids_86 = _align_on_annotation_id(metrics_all, ['Human_84', 'BioBERT', 'BioBERT_Llama'])

# Summary statistics tables, one per annotation set; the plotting code selects
# rows from them through their 'method' and 'cluster' columns.
stats_all = {'Human_84': pd.read_csv('medical_specialist/annotations-dpoc-medical_specialist_stats_84.csv'),
             'Human_435': pd.read_csv('medical_specialist/annotations-dpoc-medical_specialist_stats_435.csv'),
             'BioBERT': pd.read_csv('biobert_balanced/annotations-dpoc-biobert_stats.csv'),
             'BioBERT_Llama': pd.read_csv('biobert-llama_balanced/annotations-dpoc-biobert-llama_stats.csv'),
             'Llama': pd.read_csv('llama/annotations-dpoc-llm_10_tf_idf_custom_shot_stats.csv'),
            #  'Llama_aug': pd.read_csv('llama/annotations-dpoc-llm_augmented_10_tf_idf_custom_shot_stats.csv')
             'Llama_aug': pd.read_csv('llama/annotations-dpoc-llm_10_tf_idf_custom_shot_stats.csv')}

# 2-D projections of every annotation set, computed after alignment so that
# projected points match the aligned rows of metrics_all one-to-one.
reduction_method = ['tsne', 'pca', 'isomap', 'mds']
reduction_all = {}
for ann_set in annotation_sets:
    # presumably am.cluster_labels lists the feature columns used for clustering -- TODO confirm
    cluster_features = metrics_all[ann_set][am.cluster_labels].values

    # Fresh estimators per annotation set; seeds are fixed where the estimator takes one.
    tsne = TSNE(n_components=2, random_state=42, init='pca')
    pca = PCA(n_components=2)
    isomap = Isomap(n_components=2)
    mds = MDS(n_components=2, random_state=42)
    reduction_all[ann_set] = {
        'tsne': tsne.fit_transform(cluster_features),
        'pca': pca.fit_transform(cluster_features),
        'isomap': isomap.fit_transform(cluster_features),
        'mds': mds.fit_transform(cluster_features)
    }

# Pairs of annotation sets to compare: [reference, candidate].
comparison = [['Human_435', 'Llama'], ['Human_84', 'BioBERT'], ['Human_84', 'BioBERT_Llama'], ['Llama', 'Llama_aug'], ['Human_435', 'Llama_aug']]

def create_radar_charts(method_data, method, stats_label):
    """Draw one radar (polar) chart per cluster and show the figure.

    Parameters
    ----------
    method_data : pandas.DataFrame
        Stats rows of a single clustering method. The function reads the
        'cluster' column, the eight '<dimension>_norm_mean' columns, and the
        '_min' / '_max' / '_mean' columns of 'objective test score',
        'organization level' and 'global score'.
    method : str
        Clustering method name; only used in the figure title.
    stats_label : str
        Annotation set name; only used in the figure title.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``. Nothing is drawn when
        ``method_data`` contains no clusters.
    """
    metrics = ['pathophysiology_norm_mean', 'epidemiology_norm_mean', 'etiology_norm_mean', 'history_norm_mean',
               'physical_norm_mean', 'exams_norm_mean', 'differential_norm_mean', 'therapeutic_norm_mean']

    clusters = method_data['cluster'].unique()
    n_clusters = len(clusters)
    if n_clusters == 0:
        # plt.subplots cannot build a 0x0 grid; there is nothing to plot.
        return

    # Grid layout: at most 3 charts per row.
    n_rows = (n_clusters + 2) // 3
    n_cols = min(n_clusters, 3)

    # squeeze=False always yields a 2-D array of axes, so the flattening below
    # works for a single chart, a single row and multiple rows alike.
    fig, axs = plt.subplots(
        n_rows, n_cols,
        figsize=(5*n_cols, 5*n_rows),
        subplot_kw={'projection': 'polar'},
        squeeze=False
    )
    axs = axs.flatten()

    # One colour per cluster.
    colors = plt.cm.rainbow(np.linspace(0, 1, n_clusters))

    # Angles and tick labels are the same for every cluster.
    theta = np.linspace(0, 2*np.pi, len(metrics), endpoint=False)
    closed_theta = np.concatenate((theta, [theta[0]]))
    tick_labels = [
        metric.replace('_norm_mean', '').replace('_', ' ').title()
        for metric in metrics
    ]

    # (column prefix in method_data, caption used in the chart title)
    score_captions = (
        ('objective test score', 'objective test'),
        ('organization level', 'organization level'),
        ('global score', 'global score'),
    )

    for ax, cluster, color in zip(axs, clusters, colors):
        cluster_data = method_data[method_data['cluster'] == cluster]
        values = cluster_data[metrics].values.flatten()

        # Close the polygon by repeating the first value.
        values = np.concatenate((values, [values[0]]))

        ax.plot(closed_theta, values, color=color)
        ax.fill(closed_theta, values, color=color, alpha=0.25)

        ax.set_xticks(theta)
        ax.set_xticklabels(tick_labels)

        # Title: cluster name followed by "min-max / mean avg" for each score.
        title_lines = [f'Cluster {cluster_names[cluster]}']
        for column_prefix, caption in score_captions:
            score_min = cluster_data[f'{column_prefix}_min'].values[0]
            score_max = cluster_data[f'{column_prefix}_max'].values[0]
            score_mean = cluster_data[f'{column_prefix}_mean'].values[0]
            title_lines.append(f'{caption} {score_min}-{score_max} / {score_mean:.1f} avg')
        ax.set_title('\n'.join(title_lines))

    # Remove the unused axes of the last grid row, if any.
    for unused_ax in axs[n_clusters:]:
        fig.delaxes(unused_ax)

    fig.suptitle(f'{method} - {stats_label}', fontsize=16)

    plt.tight_layout()
    plt.show()

# Columns of the stats tables plotted in the per-cluster bar charts.
year_columns = ['year_1', 'year_2', 'year_3', 'year_4', 'year_5', 'year_6']

# Scatter rows as [index of the set providing cluster colours,
#                  index of the set providing the 2-D coordinates].
combination = [[0, 0], [1,1], [0,1], [1,0]]

for comparison_pair in comparison:
    for method in cluster_methods:
        # Heatmap
        # -------

        column_name = f'cluster {method}'

        cluster1_labels = metrics_all[comparison_pair[0]][column_name].values
        cluster2_labels = metrics_all[comparison_pair[1]][column_name].values

        # Fix the row/column order to the sorted union of labels of both sets,
        # so tick labels always match the matrix even when one set lacks a cluster.
        matrix_labels = np.unique(np.concatenate((cluster1_labels, cluster2_labels)))
        contingency = confusion_matrix(cluster1_labels, cluster2_labels, labels=matrix_labels)

        # Row-normalise; a cluster absent from the first set has an all-zero
        # row, which is kept at 0 instead of dividing by zero.
        row_sums = contingency.sum(axis=1).reshape(-1, 1)
        row_sums[row_sums == 0] = 1
        proportion_matrix = contingency / row_sums

        # Explicit figure/axes (previously plt.subplot(212), which left the
        # upper half of the figure empty).
        fig, ax = plt.subplots(figsize=(8, 4))
        sns.heatmap(proportion_matrix, annot=True, fmt='.2f', cmap='YlOrRd',
                    xticklabels=[f'{comparison_pair[1]} {cluster_names[int(label)]}' for label in matrix_labels],
                    yticklabels=[f'{comparison_pair[0]} {cluster_names[int(label)]}' for label in matrix_labels],
                    vmin=0, vmax=1, ax=ax)
        ax.set_title(f'{comparison_pair[0]} vs {comparison_pair[1]} - Proportion Matrix for {method}')

        plt.tight_layout()
        plt.show()

        # Radar Plot
        # ----------

        for stats_label in comparison_pair:
            stats = stats_all[stats_label]
            method_data = stats[stats['method'] == method]
            if method_data.empty:
                # No clusters to draw: an empty subplot grid cannot be created.
                print(f'No stats rows for {stats_label} / {method}; skipping radar and bar charts')
                continue

            create_radar_charts(method_data, method, stats_label)

            # Bar Chart
            # ---------

            # One bar chart per cluster, showing the year_1..year_6 values.
            clusters = method_data['cluster'].unique()
            n_clusters = len(clusters)

            # squeeze=False keeps axs an array even for a single cluster.
            fig, axs = plt.subplots(1, n_clusters, figsize=(5 * n_clusters, 6), sharey=True, squeeze=False)

            for ax, cluster in zip(axs.flatten(), clusters):
                cluster_data = method_data[method_data['cluster'] == cluster]
                ax.bar(year_columns, cluster_data[year_columns].values.flatten(), label=f'Cluster {cluster}')
                ax.set_title(f'Cluster {cluster_names[cluster]}')
                ax.set_xlabel('Students per Year')
                ax.set_ylabel('Values')
                ax.legend()

            # Display the bar chart
            plt.tight_layout()
            plt.show()

        # Scatter Plot
        # ------------

        # 4x4 grid: one row per colour/coordinate combination, one column per
        # dimensionality-reduction method.
        fig, axs = plt.subplots(4, 4, figsize=(20, 16))
        axs = axs.flatten()

        pos = 0
        for comb in combination:
            colour_set = comparison_pair[comb[0]]   # provides the cluster assignment
            space_set = comparison_pair[comb[1]]    # provides the 2-D coordinates
            for rm in reduction_method:
                ax = axs[pos]
                scatter = ax.scatter(
                    reduction_all[space_set][rm][:, 0],
                    reduction_all[space_set][rm][:, 1],
                    c=metrics_all[colour_set][column_name].values, cmap='viridis', alpha=0.6)
                fig.colorbar(scatter, ax=ax, label='Cluster')
                ax.set_title(
                    f'{colour_set} ann/{space_set} spc-{method}')
                ax.set_xlabel(f'{rm} 1')
                ax.set_ylabel(f'{rm} 2')
                pos += 1

        plt.tight_layout()
        plt.show()