In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib


# Global matplotlib text settings, applied to every figure created after this cell.
new_rc_params = {
    'text.usetex': False,    # render text with matplotlib itself, not an external LaTeX install
    'svg.fonttype': 'none',  # SVG output only: keep text as editable text instead of outlines
}
matplotlib.rcParams.update(new_rc_params)

In [4]:
# Load the rank table. The first CSV column becomes the index; the plotting
# cell filters that index by kit name and selects the 'Simplicity' and
# 'Cost per sample' columns plus the columns whose name contains a sample type.
df = pd.read_csv('../ranks.csv', index_col=0)


In [5]:
# Sample types handled by this notebook. Defined here so that this cell runs
# on a fresh kernel: the loop below previously relied on `sample_types` being
# created by a later cell, which raises NameError under Restart & Run All.
sample_types = ['Water', 'Sediment', 'Gut flora', 'Feces']

# Kits included for every sample type.
common_kits = ['MagBac', 'MagMic', 'MagSoi', 'MagSto', 'SilMet', 'SilSoi', 'SkySto', 'SkySoi']

# Additional kits included only for particular sample types.
extra_kits = {
    'Water': ['B&T'],
    'Feces': ['PowSoi', 'PowFec'],
    'Gut flora': ['PowSoi'],
    'Sediment': ['PowSoi'],
}

# Map each sample type to its full kit list. Concatenation always builds a
# new list, so no entry aliases `common_kits` (a sample type without extras
# simply gets a copy of the common kits).
kits_dict = {
    sample_type: common_kits + extra_kits.get(sample_type, [])
    for sample_type in sample_types
}

In [34]:
# Define sample types (each one must be a key of `kits_dict`, which is looked up below)
sample_types = ['Water', 'Sediment', 'Gut flora', 'Feces']

# Columns included for every sample type; they are dropped again for the
# "Quality Rank" dotplot.
shared_columns = ['Simplicity', 'Cost per sample']


def _plot_rank_dots(ax, scores, ylabel):
    """Plot per-kit summed ranks as dots, ordered from lowest to highest sum.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    scores : pandas.Series
        Summed rank per kit; the index values are used as x tick labels.
    ylabel : str
        Label for the y axis.

    Returns
    -------
    pandas.Series
        `scores` sorted in ascending order, in the order it was plotted.
    """
    ordered = scores.sort_values()
    ax.plot(ordered, 'o')
    ax.set_xticks(range(len(ordered)))
    ax.set_xticklabels(ordered.index, rotation=90, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_box_aspect(1)  # square plotting area
    return ordered


# One row per sample type, three columns: heatmap, total-rank dots, quality-rank dots.
# The grid height follows the number of sample types (5 inches per row), and
# squeeze=False keeps `axes` two-dimensional even for a single sample type.
n_rows = len(sample_types)
fig, axes = plt.subplots(
    n_rows, 3,
    figsize=(15, 5 * n_rows),
    gridspec_kw={'width_ratios': [1.5, 1, 1]},
    squeeze=False,
)

# Create heatmaps and dotplots for each sample type
for i, sample_type in enumerate(sample_types):
    heatmap_ax, total_ax, quality_ax = axes[i]

    # Shared columns plus every column whose name contains the sample type.
    # regex=False makes this a literal substring match, consistent with the
    # literal replacement used below.
    columns_list = shared_columns + df.columns[
        df.columns.str.contains(sample_type, regex=False)
    ].tolist()

    # Filter dataframe for kits specific to this sample type
    sample_df = df[columns_list]
    sample_df = sample_df[sample_df.index.isin(kits_dict[sample_type])]

    # Remove the '<sample type>_' text from column names
    sample_df.columns = sample_df.columns.str.replace(f'{sample_type}_', '', regex=False)

    # Heatmap of the individual ranks, annotated with the values
    sns.heatmap(sample_df, annot=True, cmap='Blues_r', cbar=False, ax=heatmap_ax,
                annot_kws={'size': 18})
    heatmap_ax.set_ylabel(f'{sample_type.capitalize()}', fontsize=18)
    heatmap_ax.tick_params(axis='both', labelsize=14)  # Increase tick label size
    if i != n_rows - 1:
        # Column labels are shown only under the last heatmap
        heatmap_ax.tick_params(labelbottom=False)

    # Dotplot with all scores
    total_scores = sample_df.sum(axis=1)
    sorted_scores = _plot_rank_dots(total_ax, total_scores, 'Total Rank')

    # Dotplot without Simplicity and Cost
    sample_df_no_cost = sample_df.drop(shared_columns, axis=1)
    total_scores_no_cost = sample_df_no_cost.sum(axis=1)
    sorted_scores_no_cost = _plot_rank_dots(quality_ax, total_scores_no_cost, 'Quality Rank')

# Add legend below the bottom-right panel (position is in axes coordinates;
# y = -0.4 places the box underneath the axes)
legend_text = "Lower rank - better performance\nHigher rank - worse performance"
props = dict(boxstyle='round', facecolor='white', alpha=0.8)
axes[-1, -1].text(1, -0.4, legend_text, transform=axes[-1, -1].transAxes,
                  fontsize=16, verticalalignment='bottom', horizontalalignment='right',
                  bbox=props)

# Operate on `fig` explicitly rather than on pyplot's implicit current figure
fig.tight_layout()
fig.savefig('../plots/sample_type_heatmaps.pdf', format='pdf', bbox_inches='tight')
plt.close(fig)