**Description**: Compile results and generate confusion matrices for the classification models.

The Python libraries required to run this notebook are listed in `requirements.txt`.

In [4]:
import os
import json
from pathlib import Path
from itertools import product

from IPython import get_ipython
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [10]:
# Constants for results compilation

# Directory holding the per-model result files: ../output/condition_stratification
# relative to the notebook's working directory. Path.cwd() is used instead of
# shelling out to `pwd` so the notebook also runs where that command is
# unavailable (e.g. Windows) and does not depend on the IPython shell.
RESULTS_DIR = Path.cwd().parent / 'output' / 'condition_stratification'

# Integer class labels compared against the 'true_label' / 'predicted_label' columns
NEGATIVE = 0    # 'healthy'
POSITIVE = 1    # 'endometriosis'

# Phase and gene-set prefixes used to build the result file names
PHASES = ['all_phases', 'early_secretory', 'mid_secretory', 'proliferative']
GENE_SETS = ['all_genes', 'all_matrisome', 'core_matrisome']

# Data split whose results are compiled (part of the result file names)
DATA_SPLIT = 'test'

In [14]:
""" Plot configuration section.

Edit the values below to adjust how the confusion matrices are rendered:
    - the display names of the conditions, phases, and gene sets
    - the colormap
    - the font sizes
    - the figure size
"""

# ---------------------------------------------------------------------------
# Display names
# The keys are what the plotting code looks up; only edit the right-hand values.
# ---------------------------------------------------------------------------

# Names shown for the two classes on the axes.
DISPLAYED_CONDITIONS = dict(Negative="Healthy", Positive="Endometriosis")

# Names shown for each phase in the plot title.
DISPLAYED_PHASES = dict([
    ("All Phases", "all phases"),
    ("Proliferative Phase", "proliferative phase"),
    ("Early Secretory Phase", "early secretory phase"),
    ("Mid Secretory Phase", "mid secretory phase"),
])

# Names shown for each gene set in the plot title.
DISPLAYED_GENES = dict([
    ("All Genes", "all genes"),
    ("All Matrisome", "all matrisome genes"),
    ("Core Matrisome", "core matrisome genes"),
])

# ---------------------------------------------------------------------------
# Appearance
# ---------------------------------------------------------------------------

# Any matplotlib colormap name, see
# https://matplotlib.org/stable/tutorials/colors/colormaps.html
COLORMAP = "YlGnBu"

# Font sizes, in order: title, axis labels, condition (tick) labels,
# and the counts printed inside each square.
FONT_SIZE_FOR_TITLE = 14
FONT_SIZE_FOR_LABELS = 12
FONT_SIZE_FOR_CONDITIONS = 12
FONT_SIZE_FOR_COUNTS = 14

# Figure size as (width, height).
FIGURE_SIZE = (8, 6)

# Vertical placement of the title above the confusion matrix (handed to the
# title's `y` argument); values somewhere between 1.0 and 1.1 tend to work.
TITLE_PADDING = 1.05

In [15]:
""" Compile results into results.json file.

For every (phase, gene set) combination, reads
`{phase}_{gene_set}_{DATA_SPLIT}_results.tsv` from RESULTS_DIR, counts the four
confusion-matrix cells from its 'true_label' and 'predicted_label' columns, and
writes the nested counts (phase label -> gene set label -> counts) to
`RESULTS_DIR / 'results.json'`.

Raises FileNotFoundError if any expected results file is missing.
"""

results = {}

for phase, gene_set in product(PHASES, GENE_SETS):
    filepath = RESULTS_DIR / f'{phase}_{gene_set}_{DATA_SPLIT}_results.tsv'
    # Explicit raise instead of `assert`, which is stripped under `python -O`
    if not filepath.exists():
        raise FileNotFoundError(f'{filepath} does not exist')

    # Read data (filepath already includes RESULTS_DIR, so it is not joined again)
    predictions = pd.read_csv(filepath, sep='\t')

    # Build the display labels used as JSON keys, e.g. 'mid_secretory' ->
    # 'Mid Secretory Phase'. Separate names keep the loop variables intact.
    phase_label = phase.replace('_', ' ').title()
    if phase_label != 'All Phases':
        phase_label += ' Phase'
    gene_set_label = gene_set.replace('_', ' ').title()

    # Boolean masks, computed once per file
    actual_negative = predictions['true_label'] == NEGATIVE
    actual_positive = predictions['true_label'] == POSITIVE
    predicted_negative = predictions['predicted_label'] == NEGATIVE
    predicted_positive = predictions['predicted_label'] == POSITIVE

    # Update results. Counts are cast to built-in int because json.dump cannot
    # serialise numpy integers.
    results.setdefault(phase_label, {})[gene_set_label] = {
        "False Negatives": int((actual_positive & predicted_negative).sum()),
        "False Positives": int((actual_negative & predicted_positive).sum()),
        "True Negatives": int((actual_negative & predicted_negative).sum()),
        "True Positives": int((actual_positive & predicted_positive).sum())
    }

# Write results to JSON
with open(RESULTS_DIR / 'results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=4)

In [16]:
""" Confusion matrix section.
Generates one confusion matrix per (phase, gene set) combination and saves
each as a PNG file in RESULTS_DIR.
Data is loaded from the results.json file.
"""

# Load JSON data (read with the same encoding it was written with)
with open(RESULTS_DIR / "results.json", encoding="utf-8") as f:
    data = json.load(f)

# Phases and gene types to plot, taken from the plot configuration so the
# lookups into DISPLAYED_PHASES / DISPLAYED_GENES below cannot go out of sync.
phases = list(DISPLAYED_PHASES)
gene_types = list(DISPLAYED_GENES)

# Row/column labels are the same for every plot
condition_labels = [DISPLAYED_CONDITIONS["Negative"], DISPLAYED_CONDITIONS["Positive"]]

# Generate confusion matrix for each phase and gene type
for phase, gene_type in product(phases, gene_types):
    # Extract values
    vals = data[phase][gene_type]

    # Rows are true labels, columns are predicted labels
    conf_mat = pd.DataFrame(
        [[vals["True Negatives"], vals["False Positives"]],
         [vals["False Negatives"], vals["True Positives"]]],
        columns=condition_labels,
        index=condition_labels
    )

    # Plot confusion matrix, honouring the configured figure size
    fig, ax = plt.subplots(figsize=FIGURE_SIZE)
    sns.heatmap(
        conf_mat,
        annot=True,
        fmt="d",
        cmap=COLORMAP,
        annot_kws={"size": FONT_SIZE_FOR_COUNTS},
        ax=ax
    )

    ax.set_title(
        f"Classification results for {DISPLAYED_GENES[gene_type]} in {DISPLAYED_PHASES[phase]}",
        fontsize=FONT_SIZE_FOR_TITLE,
        y=TITLE_PADDING
    )

    ax.set_ylabel('True Label', fontsize=FONT_SIZE_FOR_LABELS)
    ax.set_xlabel('Predicted Label', fontsize=FONT_SIZE_FOR_LABELS)

    ax.tick_params(axis="both", labelsize=FONT_SIZE_FOR_CONDITIONS)

    # Build a lowercase, underscore-separated file name,
    # e.g. "all_phases_all_genes.png"
    plot_filename = f"{phase}_{gene_type}.png".lower().replace(" ", "_")

    # Save next to the results in RESULTS_DIR, then release the figure
    fig.savefig(RESULTS_DIR / plot_filename, dpi=400)
    plt.close(fig)