In [None]:
import sys
sys.path.append('../../annotations')
import annotation_metrics as am

import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE, Isomap, MDS
from sklearn.decomposition import PCA

# Annotation dimensions, in the order they are displayed on every chart.
dimensions = [
    'pathophysiology', 'epidemiology', 'etiology', 'history',
    'physical', 'exams', 'differential', 'therapeutic',
]

# Derived column names: raw per-dimension means and normalized means.
dimensions_mean = [f'{dim}_mean' for dim in dimensions]
metrics = [f'{dim}_norm_mean' for dim in dimensions]

# Clustering configurations: "<algorithm> <k> <feature set>".
cluster_methods = [
    'kmeans 3 emb',
    'kmeans 3 cat',
    'agg 3 emb',
    'agg 3 cat',
    'gmm 3 emb',
    'gmm 3 cat',
    'birch 3 emb',
    'birch 3 cat',
]

# Display names indexed by integer cluster label: the first three carry
# proficiency names, the remaining ones are numeric placeholders "4".."20".
cluster_names = ["Novice", "Developing", "Proficient"] + [str(n) for n in range(4, 21)]

In [None]:
annotation_sets = ['Human_84', 'Human_435', 'BioBERTpt', 'BioBERT_Llama', 'Llama']

# Per-annotation metrics for each annotation source: the dimension columns and
# the 'cluster <method>' assignment columns used by the plots further down.
metrics_all = {'Human_84': pd.read_csv('medical_specialist/annotations-dpoc-medical_specialist_metrics_84.csv'),
               'Human_435': pd.read_csv('medical_specialist/aligned_annotations-dpoc-medical_specialist_metrics_435.csv'),
               'BioBERTpt': pd.read_csv('biobert_balanced/annotations-dpoc-biobert_metrics.csv'),
               'BioBERT_Llama': pd.read_csv('biobert-llama_balanced/annotations-dpoc-biobert-llama_metrics.csv'),
               'Llama': pd.read_csv('llama/annotations-dpoc-llm_3_static_shot_metrics.csv')}

# Sort by id so that, after filtering, row i of every frame in a group is the
# same annotation.
metrics_all = {label: frame.sort_values('annotation id') for label, frame in metrics_all.items()}

# Restrict each comparison group to the annotation ids present in ALL of its
# sources. Each frame is filtered with its OWN mask: pandas aligns a boolean
# Series by index label, so a mask computed on one frame cannot be reused to
# index another (it either raises or silently pairs unrelated rows).
alignment_groups = [['Human_435', 'Llama'], ['Human_84', 'BioBERTpt', 'BioBERT_Llama']]
for group in alignment_groups:
    common_ids = set.intersection(*(set(metrics_all[label]['annotation id']) for label in group))
    for label in group:
        frame = metrics_all[label]
        metrics_all[label] = frame[frame['annotation id'].isin(common_ids)].reset_index(drop=True)
    # Cluster labels are later compared position by position across frames,
    # so require exact row alignment (this also catches duplicate ids).
    reference_ids = metrics_all[group[0]]['annotation id'].tolist()
    for label in group[1:]:
        assert metrics_all[label]['annotation id'].tolist() == reference_ids, \
            f'{label} is not row-aligned with {group[0]} on annotation id'

# Cluster-level summary statistics for each annotation source (same keys as
# metrics_all); rows are selected later by their 'method' and 'cluster' columns.
stats_all = {
    label: pd.read_csv(path)
    for label, path in {
        'Human_84': 'medical_specialist/annotations-dpoc-medical_specialist_stats_84.csv',
        'Human_435': 'medical_specialist/annotations-dpoc-medical_specialist_stats_435.csv',
        'BioBERTpt': 'biobert_balanced/annotations-dpoc-biobert_stats.csv',
        'BioBERT_Llama': 'biobert-llama_balanced/annotations-dpoc-biobert-llama_stats.csv',
        'Llama': 'llama/annotations-dpoc-llm_3_static_shot_stats.csv',
    }.items()
}

reduction_method = ['tsne', 'pca', 'isomap', 'mds']

# For every annotation set, project its clustering features (the
# am.cluster_labels columns) to 2-D with each technique in reduction_method.
# Fresh estimators are built per set; t-SNE and MDS are seeded with 42.
reduction_all = {}
for ann_set in annotation_sets:
    features = metrics_all[ann_set][am.cluster_labels].values
    tsne = TSNE(n_components=2, random_state=42, init='pca')
    pca = PCA(n_components=2)
    isomap = Isomap(n_components=2)
    mds = MDS(n_components=2, random_state=42)
    estimators = dict(zip(reduction_method, (tsne, pca, isomap, mds)))
    reduction_all[ann_set] = {
        name: estimator.fit_transform(features)
        for name, estimator in estimators.items()
    }

# Annotation-set pairs compared in the plotting loop.
comparison = [
    ['Human_435', 'Llama'],
    ['Human_84', 'BioBERTpt'],
    ['Human_84', 'BioBERT_Llama'],
]

def create_radar_charts(method_data, method, stats_label, show_scores=True,
                        radar_metrics=None, r_max=0.60):
    """Draw one radar (polar) chart per cluster, all in a single figure.

    Parameters
    ----------
    method_data : pandas.DataFrame
        Expected to hold one row per cluster (only the first row of each
        cluster is used). Must contain a 'cluster' column, the radar metric
        columns and 'annotations mean'. When ``show_scores`` is True it must
        also contain '<score>_min', '<score>_max' and '<score>_mean' for
        'objective test score', 'organization level' and 'global score'.
    method : str
        Clustering method name. Not used by the function; kept for call
        compatibility.
    stats_label : str
        Annotation-set label used as the figure title; '_435' / '_84' are
        stripped and 'BioBERT_Llama' is shown as 'BioBERT-Llama'.
    show_scores : bool, default True
        If True, cluster values are treated as indices into ``cluster_names``
        and score ranges are appended to each panel title. If False, the raw
        cluster value is used as the panel name.
    radar_metrics : list of str, optional
        Columns drawn as spokes; defaults to the eight
        '<dimension>_norm_mean' columns.
    r_max : float, default 0.60
        Upper limit of the radial axis, shared by all panels.
    """
    if radar_metrics is None:
        radar_metrics = ['pathophysiology_norm_mean', 'epidemiology_norm_mean', 'etiology_norm_mean',
                         'history_norm_mean', 'physical_norm_mean', 'exams_norm_mean',
                         'differential_norm_mean', 'therapeutic_norm_mean']
    figure_title = (stats_label.replace('_435', '').replace('_84', '')
                    .replace('BioBERT_Llama', 'BioBERT-Llama'))

    clusters = method_data['cluster'].unique()
    n_clusters = len(clusters)
    if n_clusters == 0:
        return  # nothing to draw; plt.subplots(0, 0) would raise

    n_rows = (n_clusters + 2) // 3  # at most 3 charts per row
    n_cols = min(n_clusters, 3)
    # squeeze=False always returns a 2-D array of Axes, so the single-cluster
    # case needs no special handling (calling .flatten() on a lone Axes failed).
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows),
                            subplot_kw={'projection': 'polar'}, squeeze=False)
    axs = axs.ravel()

    # One spoke per metric; the first angle is repeated to close the polygon.
    theta = np.linspace(0, 2 * np.pi, len(radar_metrics), endpoint=False)
    theta_closed = np.append(theta, theta[0])
    spoke_labels = [m.replace('_norm_mean', '').replace('_', ' ').title() for m in radar_metrics]
    colors = plt.cm.rainbow(np.linspace(0, 1, n_clusters))

    for ax, cluster, color in zip(axs, clusters, colors):
        cluster_data = method_data[method_data['cluster'] == cluster]

        # dtype=float also handles object-dtype frames built from mixed records.
        values = cluster_data[radar_metrics].to_numpy(dtype=float)[0]
        values_closed = np.append(values, values[0])
        ax.plot(theta_closed, values_closed, color=color)
        ax.fill(theta_closed, values_closed, color=color, alpha=0.25)
        ax.set_xticks(theta)
        ax.set_xticklabels(spoke_labels)
        ax.set_ylim(0, r_max)

        name = cluster_names[int(cluster)] if show_scores else cluster
        title = f'{name} Cluster\nannotations {cluster_data["annotations mean"].iloc[0]:.1f} avg'
        if show_scores:
            for column, caption in (('objective test score', 'objective test'),
                                    ('organization level', 'organization level'),
                                    ('global score', 'global score')):
                low = cluster_data[f'{column}_min'].iloc[0]
                high = cluster_data[f'{column}_max'].iloc[0]
                mean = cluster_data[f'{column}_mean'].iloc[0]
                title += f'\n{caption} {low}-{high} / {mean:.1f} avg'
        ax.set_title(title)

    # Remove the unused panels of the last row.
    for ax in axs[n_clusters:]:
        fig.delaxes(ax)

    fig.suptitle(figure_title, fontsize=16, y=1.05)
    plt.tight_layout()
    plt.show()

# Clustering configuration compared across annotation sources (other options
# are listed in `cluster_methods`).
method = 'kmeans 3 cat'
column_name = f'cluster {method}'
year_columns = ['year_1', 'year_2', 'year_3', 'year_4', 'year_5', 'year_6']


def _display_label(label):
    """Readable annotation-set name: drop '_435'/'_84', show 'BioBERT_Llama' as 'BioBERT-Llama'."""
    return label.replace('_435', '').replace('_84', '').replace('BioBERT_Llama', 'BioBERT-Llama')


def _plot_contingency_heatmap(labels_a, labels_b, name_a, name_b):
    """Heatmap of the row-normalized contingency table between two clusterings.

    Cell (i, j) is the fraction of annotations in cluster i under source A
    that are in cluster j under source B. Both arrays must be row-aligned.
    """
    # confusion_matrix indexes both axes by the sorted union of labels; build
    # the tick names from that same list, by label value, so they always match.
    labels = np.union1d(labels_a, labels_b)
    contingency = confusion_matrix(labels_a, labels_b, labels=labels)
    row_sums = contingency.sum(axis=1, keepdims=True)
    # Clusters absent from source A give empty rows: leave them blank (NaN)
    # instead of dividing by zero.
    proportions = np.divide(contingency, row_sums,
                            out=np.full(contingency.shape, np.nan), where=row_sums > 0)

    fig, ax = plt.subplots(figsize=(9, 4.5))
    sns.heatmap(proportions, annot=True, fmt='.2f', cmap='YlOrRd', vmin=0, vmax=1, ax=ax,
                xticklabels=[f'{name_b} {cluster_names[int(c)]}' for c in labels],
                yticklabels=[f'{name_a} {cluster_names[int(c)]}' for c in labels])
    ax.set_title(f'{name_a} vs {name_b}')
    ax.tick_params(axis='x', labelrotation=0)
    fig.tight_layout()
    plt.show()


def _plot_year_bars(method_data):
    """Per-cluster bar charts of student counts per year, plus each cluster's share of every year."""
    clusters = method_data['cluster'].unique()
    if len(clusters) == 0:
        return
    # Fraction of each year's students assigned to each cluster (works for any
    # number of clusters, unlike a dict hard-coded for clusters 0-2).
    per_cluster = method_data.groupby('cluster', sort=False)[year_columns].sum()
    print(per_cluster.div(per_cluster.sum(axis=0), axis=1).round(2))

    fig, axs = plt.subplots(1, len(clusters), figsize=(5 * len(clusters), 6),
                            sharey=True, squeeze=False)
    for ax, cluster in zip(axs[0], clusters):
        counts = method_data.loc[method_data['cluster'] == cluster, year_columns].to_numpy().ravel()
        ax.bar(year_columns, counts)
        ax.set_title(f'{cluster_names[int(cluster)]} Cluster')
        ax.set_xlabel('Year')
        ax.set_ylabel('Students')
        for i, value in enumerate(counts):
            ax.text(i, value, str(int(value)), ha='center', va='bottom')
    fig.tight_layout()
    plt.show()


def _plot_disagreement_radars(pair, labels_a, labels_b, method, disagreements=((0, 1), (1, 0))):
    """Radar profile of annotations the two sources put in different clusters.

    For each (a, b) in `disagreements`, select rows in cluster `a` under
    pair[0] and cluster `b` under pair[1]; for each source, average its
    dimension columns over those rows and normalize the means to sum to 1.
    """
    for source in pair:
        records = []
        for label_a, label_b in disagreements:
            mask = (labels_a == label_a) & (labels_b == label_b)
            dim_means = metrics_all[source].loc[mask, dimensions].mean()
            total = dim_means.sum()
            record = {f'{dim}_norm_mean': dim_means[dim] / total for dim in dimensions}
            record['annotations mean'] = total
            record['cluster'] = f'{cluster_names[label_a]}/{cluster_names[label_b]}'
            records.append(record)
        # Build the frame once from records (keeps numeric columns numeric).
        create_radar_charts(pd.DataFrame(records), method, source, False)


def _plot_reduction_scatter(pair, cluster_column):
    """Grid of 2-D embeddings of one source colored by either source's cluster labels.

    Rows follow (color source, embedding source) = (A, A), (B, B), (A, B), (B, A);
    columns follow `reduction_method`.
    """
    combos = [(0, 0), (1, 1), (0, 1), (1, 0)]
    fig, axs = plt.subplots(len(combos), len(reduction_method),
                            figsize=(5 * len(reduction_method), 4 * len(combos)), squeeze=False)
    for row, (color_idx, embed_idx) in enumerate(combos):
        point_colors = metrics_all[pair[color_idx]][cluster_column].values
        for col, rm in enumerate(reduction_method):
            ax = axs[row, col]
            embedding = reduction_all[pair[embed_idx]][rm]
            scatter = ax.scatter(embedding[:, 0], embedding[:, 1],
                                 c=point_colors, cmap='viridis', alpha=0.6)
            fig.colorbar(scatter, ax=ax, label='Cluster')
            ax.set_title(f'{pair[color_idx]} ann/{pair[embed_idx]}')
            ax.set_xlabel(f'{rm} 1')
            ax.set_ylabel(f'{rm} 2')
    fig.tight_layout()
    plt.show()


for comparison_pair in comparison:
    source_a, source_b = comparison_pair
    cluster1_labels = metrics_all[source_a][column_name].values
    cluster2_labels = metrics_all[source_b][column_name].values
    assert len(cluster1_labels) == len(cluster2_labels), \
        f'{source_a} and {source_b} are not row-aligned'

    # How the clusters of one source map onto the other's.
    _plot_contingency_heatmap(cluster1_labels, cluster2_labels,
                              _display_label(source_a), _display_label(source_b))

    # Radar profile and per-year bar chart of each source's clusters.
    for stats_label in comparison_pair:
        stats = stats_all[stats_label]
        stats['annotations mean'] = stats[dimensions_mean].sum(axis=1)
        method_data = stats[stats['method'] == method]
        create_radar_charts(method_data, method, stats_label)
        _plot_year_bars(method_data)

    # Annotations clustered 0 by one source and 1 by the other (and vice versa).
    _plot_disagreement_radars(comparison_pair, cluster1_labels, cluster2_labels, method)

    # 2-D reductions colored by each source's clusters.
    _plot_reduction_scatter(comparison_pair, column_name)