# Coassociation analysis

# 1) Set up libraries and datasets

## 1.1) Import libraries and models

In [None]:
# Import libraries
import os
import re
import json
import datetime
import statistics
import pandas as pd
import numpy as np
import requests
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.ticker as mticker
from tqdm import tqdm
import networkx as nx
import pyvis
import community
from scipy import stats
from scipy.stats import chi2_contingency, fisher_exact
from scipy.cluster.hierarchy import fcluster, linkage, dendrogram
from statsmodels.stats.multitest import multipletests
from matplotlib.backends.backend_pdf import PdfPages
from node2vec import Node2Vec
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA

print("Success!")

## 1.2) Load datasets

In [None]:
# Set the working directory and file paths
# NOTE(review): the values below look like placeholders — replace them with
# real directory paths before running; os.chdir() raises if the target
# directory does not exist.
input_directory = "INPUT_DIRECTORY"
output_directory = "OUTPUT_DIRECTORY"
variantscape_directory = "VARIANTSCAPE_DIRECTORY"
figure_directory = "FIGURE_DIRECTORY"
variantscape_llm_coas_directory = "VARIANTSCAPE_LLM_COAS_DIRECTORY"

# Make the variantscape directory the current working directory so that
# relative file names used afterwards resolve against it.
os.chdir(variantscape_directory)
print("Current directory:", os.getcwd())

In [None]:
# Load datasets

#### In input directory
# Gene lists are read without a header row, so pandas assigns integer column
# labels; the first ESCAT column is then renamed to "ESCAT_genes".
os.chdir(input_directory)
oncomine_genes = pd.read_csv("oncomine_ngs_panel.csv", header=None)
ESCAT_genes = pd.read_csv("ESCAT_pc_genes.csv", header=None)
ESCAT_genes.rename(columns={ESCAT_genes.columns[0]: "ESCAT_genes"}, inplace=True)

#### In variantscape directory
# Change to the variantscape directory
os.chdir(variantscape_directory)

# Load the variant dataset (cleaned_df_v4.csv) and the metadata mapping
variant_analysis_df = pd.read_csv("cleaned_df_v4.csv", low_memory=False)
metadata_mapping = pd.read_csv("metadata_mapping.csv", low_memory=False)

# Check the lengths and columns of datasets
len_variant_analysis_df_rows, len_variant_analysis_df_cols = variant_analysis_df.shape
len_oncomine_genes_rows, len_oncomine_genes_cols = oncomine_genes.shape
len_ESCAT_genes_rows, len_ESCAT_genes_cols = ESCAT_genes.shape
len_metadata_mapping_rows, len_metadata_mapping_cols = metadata_mapping.shape

# Validate that the number of columns in variant_analysis_df matches the number of
# columns in metadata_mapping (a mismatch only prints a warning; it does not stop the run)
if len_variant_analysis_df_cols == len_metadata_mapping_cols:
    print("The number of columns in variant_analysis_df matches the number of metadata entries.")
else:
    print(f"Warning: The number of columns in variant_analysis_df ({len_variant_analysis_df_cols}) does not match the number of metadata entries ({len_metadata_mapping_cols}).")
print("\nSuccess!")
print(f"\nMerged variant dataset: {len_variant_analysis_df_rows:,} rows, {len_variant_analysis_df_cols:,} columns")
print(f"Number of oncomine genes: {len_oncomine_genes_rows:,} rows, {len_oncomine_genes_cols:,} columns")
print(f"Number of ESCAT genes: {len_ESCAT_genes_rows:,} rows, {len_ESCAT_genes_cols:,} columns")
print("\n")
print(ESCAT_genes)

# Validate that the column names in variant_analysis_df correspond to metadata mapping
# The first row of metadata_mapping is converted into a dictionary keyed by
# column name, and every variant_analysis_df column is looked up in it.
metadata_dict = metadata_mapping.iloc[0].to_dict()
missing_metadata = [col for col in variant_analysis_df.columns if col not in metadata_dict]
if not missing_metadata:
    print("\nAll columns in variant_analysis_df have corresponding metadata.")
else:
    print(f"Warning: The following columns in variant_analysis_df do not have corresponding metadata: {missing_metadata}")

# Step 1) Data Preparation

In [None]:
# Work on a copy of the variant dataset so the loaded frame stays untouched
coas_variant_df = variant_analysis_df.copy()
print(f"Total rows in coas_variant_df: {len(coas_variant_df):,}")
print(f"Total cols in coas_variant_df: {len(coas_variant_df.columns):,}")

# Report the number of columns in the metadata mapping
print(f"Total cols in metadata mapping: {len(metadata_mapping.columns):,}")

# Calculate the total sum for each column from the sixth column onwards
# (assumes the first five columns are non-numeric fields such as 'PaperId' — TODO confirm)
total_sum_per_column = coas_variant_df.iloc[:, 5:].sum()
# Columns whose total is zero have no mention in any row
zero_sum_columns = total_sum_per_column[total_sum_per_column == 0]
if zero_sum_columns.empty:
    print("All columns have non-zero totals!")
else:
    print("Columns with a total sum of zero (no associations):")
    print(zero_sum_columns)

# ==================================================

In [None]:
# Prepare the data for the "top 20" bar charts
dictionary = metadata_mapping.copy()
coas_variant_df = variant_analysis_df.copy()
total_articles = len(coas_variant_df)

# Turn the metadata mapping into a two-column lookup table
# (column name -> category). Assigning exactly two column names only works
# when metadata_mapping has a single row.
dictionary = dictionary.transpose().reset_index()
dictionary.columns = ["Column_Name", "Category"]

# Column names belonging to each category
treatment_columns = dictionary[dictionary['Category'] == "Treatment"]["Column_Name"].tolist()
cancer_columns = dictionary[dictionary['Category'] == "Cancer"]["Column_Name"].tolist()
variant_columns = dictionary[dictionary['Category'] == "Variant"]["Column_Name"].tolist()

# One sub-frame per category
treatment_df = coas_variant_df[treatment_columns]
cancer_df = coas_variant_df[cancer_columns]
variant_df = coas_variant_df[variant_columns]

# Count the mentions for each treatment, cancer, and variant
top_treatments = treatment_df.sum()

# EXCLUDE SPECIFIC TREATMENTS
# These names are removed (case-insensitively) before ranking the treatments.
excluded_treatments = [
    'chemotherapy',
    'immunotherapy',
    'hormone therapy',
    'radiation therapy',
    'adjuvant chemotherapy',
    'radiation ionizing radiotherapy',
    'tyrosine kinase inhibitor',
    'braf inhibitor'
]
top_treatments = top_treatments[~top_treatments.index.str.lower().isin([t.lower() for t in excluded_treatments])]
# Keep the 20 largest column totals per category
top_treatments = top_treatments.sort_values(ascending=False).head(20)
top_cancers = cancer_df.sum().sort_values(ascending=False).head(20)
# str.capitalize() upper-cases the first character and lower-cases the rest
top_cancers.index = top_cancers.index.str.capitalize()
top_variants = variant_df.sum().sort_values(ascending=False).head(20)

# Calculate percentages (column total relative to the number of rows/articles)
top_treatments_percent = (top_treatments / total_articles) * 100
top_cancers_percent = (top_cancers / total_articles) * 100
top_variants_percent = (top_variants / total_articles) * 100
def weighted_co_occurrence_matrix(df1, df2, weights, scaling_factor=1.0):
    """Return the weighted co-occurrence matrix between two indicator frames.

    Both frames are multiplied row-wise by ``weights * scaling_factor`` and the
    transposed first frame is then dotted with the second, so the result has
    one row per column of ``df1`` and one column per column of ``df2``.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Frames sharing the same row index.
    weights : pandas.Series
        Per-row weights, aligned on the row index.
    scaling_factor : float, optional
        Multiplier applied to ``weights`` before weighting (default 1.0).

    Returns
    -------
    pandas.DataFrame
        Matrix indexed by ``df1.columns`` with columns ``df2.columns``.
    """
    scaled = weights * scaling_factor
    # Weight each frame row-wise, then contract over the shared row axis.
    left, right = (frame.mul(scaled, axis=0) for frame in (df1, df2))
    co_occurrence = left.T.dot(right)
    return pd.DataFrame(co_occurrence, index=df1.columns, columns=df2.columns)
# Per-row weights read from the 'Study_weight' column of the variant dataset
weights = coas_variant_df['Study_weight']

def format_variant_label(label):
    """Convert a 'variant_GENE' label into an upper-case 'GENE VARIANT' label.

    Example: ``'v600e_BRAF'`` becomes ``'BRAF V600E'``. A label without an
    underscore is simply upper-cased.

    The label is split on its LAST underscore, so a label containing several
    underscores no longer raises ``ValueError``; everything before the last
    underscore is treated as the variant part (assumes the gene is always the
    final component — TODO confirm against the column naming scheme).
    """
    variant, separator, gene = label.rpartition("_")
    if not separator:
        # No underscore at all: nothing to reorder.
        return label.upper()
    return f"{gene.upper()} {variant.upper()}"
def plot_and_save_bar_chart(data, title, xlabel, ylabel, file_name, total_articles):
    """Draw a bar chart of mention counts, save it to ``file_name`` and show it.

    Parameters
    ----------
    data : pandas.Series
        Mention counts, indexed by the label shown under each bar.
    title : str
        Chart title; ``' out of <total_articles> articles'`` is appended.
    xlabel, ylabel : str
        Axis labels.
    file_name : str
        Path of the image file to write (300 dpi).
    total_articles : int
        Number of articles. Shown in the title and used as the denominator of
        the percentage written above each bar, so the annotation agrees with
        the "out of N articles" title.
    """
    # A zero-height white bar is appended after the real bars; presumably this
    # leaves room on the right for the rotated annotation of the last bar.
    data_with_padding = data.copy()
    data_with_padding['Invisible Padding'] = 0

    fig, ax = plt.subplots(figsize=(16, 10), constrained_layout=True)
    data_with_padding.plot(kind='bar', color=['#1f20b4'] * len(data) + ['white'], edgecolor='black', ax=ax)
    ax.set_title(f'{title.capitalize()} out of {total_articles:,} articles', fontsize=18)
    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)

    peak = data.max()
    # Head-room above the tallest bar for the rotated count/percentage labels.
    ax.set_ylim(0, peak * 1.4)
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{int(x):,}'))
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Only the real bars get tick labels; the padding bar stays unlabeled.
    ax.set_xticks(range(len(data)))
    ax.set_xticklabels(data.index, rotation=45, ha='right', fontsize=14)

    for i, count in enumerate(data.values):
        # Percentage of articles (not of the plotted bars' combined total).
        percentage = (count / total_articles) * 100 if total_articles else 0.0
        ax.text(i, count + (peak * 0.04), f"{int(count):,} ({percentage:.2f}%)",
                ha='left', va='bottom', fontsize=14, rotation=45)

    # NOTE(review): the figure is created with constrained_layout=True, which
    # matplotlib documents as incompatible with subplots_adjust — confirm
    # which of the two layouts is actually wanted.
    plt.subplots_adjust(left=0.05, right=0.99, bottom=0.35, top=0.95)
    fig.savefig(file_name, bbox_inches='tight', pad_inches=0.5, dpi=300)
    plt.show()

# Charts and the statistics file are written into the current working directory
current_directory = os.getcwd()
statistics_file = os.path.join(current_directory, 'output_statistics.txt')
# Replace the raw variant column names with display labels for plotting/reporting
top_variants.index = [format_variant_label(v) for v in top_variants.index]

plot_and_save_bar_chart(top_treatments, 'Top 20 most mentioned treatments', 'Treatment', 'Mentions', 
                        os.path.join(current_directory, 'Top_20_Most_Mentioned_Treatments.png'), total_articles)

plot_and_save_bar_chart(top_cancers, 'Top 20 most mentioned cancers', 'Cancer', 'Mentions', 
                        os.path.join(current_directory, 'Top_20_Most_Mentioned_Cancers.png'), total_articles)

plot_and_save_bar_chart(top_variants, 'Top 20 most mentioned variants', 'Variant', 'Mentions', 
                        os.path.join(current_directory, 'Top_20_Most_Mentioned_Variants.png'), total_articles)

# Generate output
# The text report lists dataset dimensions, the number of columns per category
# and the top-20 counts with their percentage of total_articles. The count and
# percentage series are paired positionally by zip().
with open(statistics_file, 'w') as f:
    f.write("DataFrame Statistics:\n")
    f.write(f"Number of rows: {len(coas_variant_df)}\n")
    f.write(f"Number of columns: {len(coas_variant_df.columns)}\n\n")
    f.write(f"Treatment Columns: {len(treatment_columns)}\n")
    f.write(f"Cancer Columns: {len(cancer_columns)}\n")
    f.write(f"Variant Columns: {len(variant_columns)}\n\n")
    f.write("Top 20 Most Mentioned Treatments (with percentages):\n")
    for treatment, count, percent in zip(top_treatments.index, top_treatments, top_treatments_percent):
        f.write(f"{treatment}: {count} mentions, {percent:.2f}%\n")
    f.write("\nTop 20 Most Mentioned Cancers (with percentages):\n")
    for cancer, count, percent in zip(top_cancers.index, top_cancers, top_cancers_percent):
        f.write(f"{cancer}: {count} mentions, {percent:.2f}%\n")
    f.write("\nTop 20 Most Mentioned Variants (with percentages):\n")
    for variant, count, percent in zip(top_variants.index, top_variants, top_variants_percent):
        f.write(f"{variant}: {count} mentions, {percent:.2f}%\n")
print("Bar charts and statistics have been generated and saved!")

# Coassociation analysis weighted by study design

In [None]:
# Function to calculate percentage-based co-occurrence matrix
def normalized_co_occurrence_matrix(df1, df2, weights, scaling_factor=1.0):
    """Return the weighted co-occurrence matrix scaled by df1's weighted totals.

    Both frames are multiplied row-wise by ``weights * scaling_factor``; the
    transposed first frame is dotted with the second, and each row of the
    result is divided by the weighted column total of the matching ``df1``
    column and multiplied by 100.

    NOTE(review): the numerator carries the weight twice (once from each
    weighted frame) while the denominator carries it once, so values can
    exceed 100 when weights are larger than 1 — confirm this is intended.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Frames sharing the same row index.
    weights : pandas.Series
        Per-row weights, aligned on the row index.
    scaling_factor : float, optional
        Multiplier applied to ``weights`` (default 1.0).

    Returns
    -------
    pandas.DataFrame
        Matrix with one row per ``df1`` column and one column per ``df2`` column.
    """
    scaled = weights * scaling_factor
    left = df1.mul(scaled, axis=0)
    right = df2.mul(scaled, axis=0)

    co_occurrence = left.T.dot(right)
    # Weighted total of every df1 column; used as the per-row denominator.
    denominators = left.sum(axis=0)
    return co_occurrence.div(denominators, axis=0).mul(100)

# Generate normalized co-occurrence matrices
# Rows of each matrix come from the first frame, columns from the second:
# treatment x variant, cancer x variant and treatment x cancer.
print("Generating normalized co-occurrence matrices...")
treatment_variant_matrix_normalized = normalized_co_occurrence_matrix(treatment_df, variant_df, weights)
cancer_variant_matrix_normalized = normalized_co_occurrence_matrix(cancer_df, variant_df, weights)
treatment_cancer_matrix_normalized = normalized_co_occurrence_matrix(treatment_df, cancer_df, weights)
print("Normalized co-occurrence matrices generated successfully!\n")

# Function to apply Fisher's Exact Test and FDR correction
def apply_statistical_testing(co_occurrence_matrix):
    """Run Fisher's exact test per cell and return BH-corrected p-values.

    For every cell with a value greater than zero a 2x2 table is built from
    the cell value and the matrix-wide totals and passed to ``fisher_exact``;
    all other cells get a p-value of 1.0. The p-values are then corrected with
    the Benjamini-Hochberg procedure (``multipletests(method='fdr_bh')``).

    The two matrix-wide totals do not depend on the cell, so they are computed
    once before the loop instead of once per cell.

    NOTE(review): the table uses the grand total in three of its four cells
    rather than the cell's row/column marginals, and ``fisher_exact`` is
    documented for non-negative integer counts while the callers in this file
    pass normalized percentage matrices — confirm the statistical design.

    Parameters
    ----------
    co_occurrence_matrix : pandas.DataFrame
        Matrix of co-occurrence values.

    Returns
    -------
    pandas.DataFrame
        Corrected p-values with the same index and columns as the input.
    """
    print(f"Applying statistical tests to matrix of size {co_occurrence_matrix.shape}...")
    matrix_shape = co_occurrence_matrix.shape
    flattened_matrix = co_occurrence_matrix.values.flatten()
    total_entries = len(flattened_matrix)

    # Loop invariants (previously recomputed for every single cell).
    flat_total = np.sum(flattened_matrix)
    row_sum_total = np.sum(co_occurrence_matrix.sum(axis=1))

    p_values = []
    for value in tqdm(flattened_matrix, desc="Processing statistical tests", total=total_entries):
        if value > 0:
            contingency_table = np.array([[value, flat_total - value],
                                          [row_sum_total - value, flat_total]])
            _, p_value = fisher_exact(contingency_table)
        else:
            # Zero (or NaN, which fails the comparison) cells are not tested.
            p_value = 1.0
        p_values.append(p_value)

    print("Applying Benjamini-Hochberg correction...")
    corrected_p_values = multipletests(p_values, method='fdr_bh')[1]
    corrected_p_values_matrix = np.reshape(corrected_p_values, matrix_shape)
    return pd.DataFrame(corrected_p_values_matrix, index=co_occurrence_matrix.index, columns=co_occurrence_matrix.columns)

# Corrected p-value matrix for each normalized co-occurrence matrix
treatment_variant_pvalues = apply_statistical_testing(treatment_variant_matrix_normalized)
cancer_variant_pvalues = apply_statistical_testing(cancer_variant_matrix_normalized)
treatment_cancer_pvalues = apply_statistical_testing(treatment_cancer_matrix_normalized)
print("Saving results to CSV files...")

# File names are relative, so the CSVs land in the current working directory
treatment_variant_matrix_normalized.to_csv("treatment_variant_matrix_normalized.csv")
cancer_variant_matrix_normalized.to_csv("cancer_variant_matrix_normalized.csv")
treatment_cancer_matrix_normalized.to_csv("treatment_cancer_matrix_normalized.csv")
treatment_variant_pvalues.to_csv("treatment_variant_pvalues.csv")
cancer_variant_pvalues.to_csv("cancer_variant_pvalues.csv")
treatment_cancer_pvalues.to_csv("treatment_cancer_pvalues.csv")
print("Results saved successfully!\n")

In [None]:
# Display top 50 heatmaps
# Reload the normalized matrices from the CSV files in the current working
# directory; the first CSV column is used as the row index.
treatment_variant_matrix = pd.read_csv("treatment_variant_matrix_normalized.csv", index_col=0)
cancer_variant_matrix = pd.read_csv("cancer_variant_matrix_normalized.csv", index_col=0)
treatment_cancer_matrix = pd.read_csv("treatment_cancer_matrix_normalized.csv", index_col=0)

# Function to plot filtered heatmaps
def plot_filtered_heatmap(matrix, title, filename, x_label, y_label, threshold=1.0, top_n=50):
    """Plot, save and show a heatmap of the strongest part of ``matrix``.

    Values not strictly greater than ``threshold`` are replaced by 0; the
    ``top_n`` rows and ``top_n`` columns with the largest remaining totals are
    kept and drawn with seaborn's ``coolwarm`` colour map.

    Parameters
    ----------
    matrix : pandas.DataFrame
        Matrix to visualise.
    title : str
        Figure title.
    filename : str
        Path of the image file to write (300 dpi).
    x_label, y_label : str
        Axis labels.
    threshold : float, optional
        Values must exceed this to be kept (default 1.0).
    top_n : int, optional
        Number of rows and of columns to keep (default 50).
    """
    above_threshold = matrix[matrix > threshold].fillna(0)

    # Rank rows and columns by what is left after thresholding.
    strongest_rows = above_threshold.sum(axis=1).nlargest(top_n).index
    strongest_cols = above_threshold.sum(axis=0).nlargest(top_n).index
    subset = above_threshold.loc[strongest_rows, strongest_cols]

    plt.figure(figsize=(12, 10))
    sns.heatmap(subset, cmap="coolwarm", annot=False, linewidths=0.5, cbar=True)

    plt.title(title, fontsize=16)
    plt.xlabel(x_label, fontsize=12)
    plt.ylabel(y_label, fontsize=12)
    plt.xticks(rotation=45, ha="right", fontsize=7)
    plt.yticks(fontsize=9)

    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.show()

# Plot heatmaps
# One filtered heatmap per matrix; the x-axis label names the matrix columns
# and the y-axis label names the matrix rows.
plot_filtered_heatmap(treatment_variant_matrix, 
                      "Treatment-variant co-occurrence", 
                      "filtered_treatment_variant_heatmap.png", 
                      "Variants", "Treatments", 
                      threshold=1.0, top_n=50)

plot_filtered_heatmap(cancer_variant_matrix, 
                      "Cancer-variant co-occurrence", 
                      "filtered_cancer_variant_heatmap.png", 
                      "Variants", "Cancers", 
                      threshold=1.0, top_n=50)

plot_filtered_heatmap(treatment_cancer_matrix, 
                      "Treatment-cancer co-occurrence", 
                      "filtered_treatment_cancer_heatmap.png", 
                      "Cancers", "Treatments", 
                      threshold=1.0, top_n=50)

# Create focused heatmap for variant-treatment co-associations

In [None]:
# Load CIVIC dataset
final_output_filepath = os.path.join(output_directory, "CIVIC_ncit_df_finalparent_treatmentcategory.csv")
CIVIC_ncit_df_finalparent = pd.read_csv(final_output_filepath)
total_rows = len(CIVIC_ncit_df_finalparent)

# Terms to exclude
# Treatment rows with these (lower-cased) names are dropped from the matrix.
treatments_to_exclude = [
    "chemotherapy", "immunotherapy", "targeted therapy",
    "radiation therapy", "radiation ionizing radiotherapy", "folfox regimen", "iniparib",
    "epidermal growth factor receptor tyrosine kinase inhibitor","tyrosine kinase inhibitor",
    "mitogen-activated protein kinase kinase inhibitor","iodine i-131",
    "anti-vegf monoclonal antibody", "radioactive iodine","egfr tyrosine kinase inhibitor therapy",
    "aromatase inhibitor", "anti-pd-l1 monoclonal antibody", "mrna vaccine", "pd1 inhibitor", 
]
treatments_to_exclude = [t.strip().lower() for t in treatments_to_exclude]

# Load treatment-variant matrix
# Row and column labels are stripped and lower-cased so they can be matched
# against the lower-cased exclusion list and consensus keys.
treatment_variant_matrix = pd.read_csv(
    os.path.join(variantscape_directory, "treatment_variant_matrix_normalized.csv"),
    index_col=0
)
treatment_variant_matrix.columns = treatment_variant_matrix.columns.str.strip().str.lower()
treatment_variant_matrix.index   = treatment_variant_matrix.index.str.strip().str.lower()
treatment_variant_matrix = treatment_variant_matrix[
    ~treatment_variant_matrix.index.isin(treatments_to_exclude)
]

# Load the consensus predictions. The directory variable is spelled exactly as
# it is defined in the path-setup cell (variantscape_llm_coas_directory); the
# previous mixed-case spelling raised a NameError.
df_consensus = pd.read_csv(
    os.path.join(variantscape_llm_coas_directory, "final_variant_treatment_consensus.csv")
)
df_consensus["Variant_Treatment_Pair"] = (
    df_consensus["Variant_Treatment_Pair"]
    .str.lower()
    .str.strip()
)
# Lookup: "<variant> + <treatment>" (lower-cased) -> resolved prediction
consensus_dict = dict(
    zip(df_consensus["Variant_Treatment_Pair"], df_consensus["Resolved_Prediction"])
)

# Apply consensus adjustments
adjusted_matrix = treatment_variant_matrix.copy().astype(float)
print("Applying consensus adjustments...")
for treatment, variant in tqdm(
    [(t, v) for t in adjusted_matrix.index for v in adjusted_matrix.columns],
    desc="Consensus Adjustment",
    unit="pair"
):
    key = f"{variant} + {treatment}".strip().lower()
    consensus = consensus_dict.get(key, None)

    # A pair without an entry, with a missing (non-text) prediction, or marked
    # "no consensus" is blanked out. The isinstance check keeps a NaN
    # prediction from crashing on .lower().
    if not isinstance(consensus, str) or consensus.lower() == "no consensus":
        adjusted_matrix.loc[treatment, variant] = np.nan
    elif consensus.lower() == "resistant":
        # Resistance is encoded as a negative score.
        adjusted_matrix.loc[treatment, variant] = -abs(adjusted_matrix.loc[treatment, variant])
    # others (sensitive, diagnostic, unrelated) remain positive

## Figure 1) Based on strongest positive and negative associations 

In [None]:
# Select strongest positive and negative associations
# unstack() produces a Series keyed by (column, row) = (variant, treatment);
# cells that are NaN are dropped.
flat = adjusted_matrix.unstack().dropna()
sorted_flat = flat.sort_values()

top_n = 50  # number of pairs taken from each end of the sorted scores
top_negative = sorted_flat.head(top_n)
top_positive = sorted_flat.tail(top_n)

# Combine and reformat for plotting
focus_pairs = pd.concat([top_positive, top_negative])
focus_df = focus_pairs.reset_index()
focus_df.columns = ["Variant", "Treatment", "Score"]

# Add manually selected treatments (only top 3 variant associations per treatment)
treatments_to_force = [
    "abiraterone", "trastuzumab", "osimertinib",
    "erlotinib", "gefitinib", "crizotinib", "dabrafenib/trametinib regimen", "dabrafenib", "trametinib",
    "alectinib", "ceritinib", "vemurafenib", "encorafenib",
    "pembrolizumab", "nivolumab", "bevacizumab",
]
treatments_to_force = [t.lower().strip() for t in treatments_to_force]

# For each forced treatment present in the matrix, keep its three non-NaN
# variant scores with the largest absolute value.
manual_rows = []
for treatment in treatments_to_force:
    if treatment in adjusted_matrix.index:
        variant_scores = adjusted_matrix.loc[treatment].dropna()
        top_variants = variant_scores.reindex(variant_scores.abs().sort_values(ascending=False).index).head(3)
        for variant, score in top_variants.items():
            manual_rows.append({"Treatment": treatment, "Variant": variant, "Score": score})

# Pairs that are already in the top/bottom selection are removed as duplicates
manual_df = pd.DataFrame(manual_rows)
focus_df = pd.concat([focus_df, manual_df], ignore_index=True).drop_duplicates()

#### Build the focused matrix
# NOTE: top_variants is re-bound here to the array of selected variant names.
top_variants = focus_df["Variant"].unique()
forced_variant_set = set(manual_df["Variant"].unique())
all_variants = list(pd.Index(top_variants).union(forced_variant_set))
all_treatments = list(pd.Index(focus_df["Treatment"].unique()).union(treatments_to_force))
# Forced treatments that are not rows of adjusted_matrix are skipped
all_treatments = [t for t in all_treatments if t in adjusted_matrix.index]
focus_matrix = adjusted_matrix.loc[all_treatments, all_variants]

# Order rows/columns
# Sorted by the mean score of the selected pairs, highest first; treatments
# without any selected pair are appended at the end of the row order.
col_order = focus_df.groupby("Variant")["Score"].mean().sort_values(ascending=False).index
row_order = focus_df.groupby("Treatment")["Score"].mean().sort_values(ascending=False).index
row_order = pd.Index(row_order.tolist() + [t for t in all_treatments if t not in row_order])
focus_matrix = focus_matrix.loc[row_order, col_order]

# Draw the Figure 1 heatmap: NaN cells are masked, colours diverge around 0
# (negative scores mark pairs whose consensus was "resistant").
plt.figure(figsize=(18, 14))
sns.set(style="white")
sns.heatmap(
    focus_matrix,
    cmap="RdYlGn",
    center=0,
    linewidths=0.5,
    linecolor="lightgray",
    square=False,
    cbar_kws={
        "label": "Association score",
        "shrink": 0.8,
        "orientation": "vertical"
    },
    mask=focus_matrix.isna()
)

# Force display of all Y-axis labels
plt.yticks(
    ticks=np.arange(len(focus_matrix.index)) + 0.5,
    labels=focus_matrix.index,
    fontsize=9,
    rotation=0
)

# Format rest of the plot
cbar = plt.gca().collections[0].colorbar
cbar.ax.set_ylabel("Association score", labelpad=-10)
cbar.ax.yaxis.label.set_rotation(90)

plt.title("Variant–treatment associations", fontsize=18, pad=20)
plt.xlabel("Variants", fontsize=14, labelpad=10)
plt.ylabel("Treatments", fontsize=14, labelpad=10)
plt.xticks(rotation=45, ha="right", fontsize=9)
plt.tight_layout(rect=[0, 0, 1, 0.97])

# Figure 1 is written to its own file. It previously used
# "variant_treatment_heatmap_clean.png", the same name the Figure 2 cell
# saves to, so running both cells silently overwrote this figure.
plt.savefig("variant_treatment_heatmap_figure1.png", dpi=300)
plt.show()
print("Final treatments shown on Y-axis:")
print(focus_matrix.index.tolist())

## Figure 2) Based on strongest positive and negative associations and pre-selected treatments

In [None]:
# Select strongest positive and negative associations
# unstack() produces a Series keyed by (column, row) = (variant, treatment);
# cells that are NaN are dropped.
flat = adjusted_matrix.unstack().dropna()
sorted_flat = flat.sort_values()

# Number of pairs taken from each end of the sorted scores
top_n = 50
top_negative = sorted_flat.head(top_n)
top_positive = sorted_flat.tail(top_n)

# Combine and reformat for plotting
focus_pairs = pd.concat([top_positive, top_negative])
focus_df = focus_pairs.reset_index()
focus_df.columns = ["Variant", "Treatment", "Score"]

# Add manually selected treatments (only top 3 variant associations per treatment)
# NOTE: "dabrafenib" and "trametinib" appear twice in this list; the repeated
# rows they generate are removed by drop_duplicates() below.
treatments_to_force = [
    # Targeted therapies
    "trastuzumab", "pertuzumab", "osimertinib", "erlotinib", "gefitinib",
    "crizotinib", "alectinib", "ceritinib", "larotrectinib", "entrectinib",
    "dabrafenib", "trametinib", "vemurafenib", "encorafenib", "binimetinib",
    "imatinib", "sunitinib", "axitinib", "bevacizumab",

    # Immunotherapies
    "pembrolizumab", "nivolumab", "atezolizumab", "durvalumab", "ipilimumab",

    # Hormonal therapies
    "tamoxifen", "letrozole", "anastrozole", "exemestane",
    "abiraterone", "enzalutamide",

    # PARP inhibitors
    "olaparib", "rucaparib", "niraparib", "talazoparib",

    # Chemotherapy
    "cisplatin", "carboplatin", "paclitaxel", "docetaxel","dabrafenib/trametinib","dabrafenib","trametinib"
]
treatments_to_force = [t.lower().strip() for t in treatments_to_force]

# For each forced treatment present in the matrix, take the three variants
# ranked first by absolute score. NaN scores are not removed beforehand, so a
# treatment with fewer than three scored variants contributes NaN rows.
manual_rows = []
for treatment in treatments_to_force:
    if treatment in adjusted_matrix.index:
        variant_scores = adjusted_matrix.loc[treatment]
        top_variants = variant_scores.reindex(variant_scores.abs().sort_values(ascending=False).index).head(3)
        for variant in top_variants.index:
            score = variant_scores.get(variant, np.nan)
            manual_rows.append({"Treatment": treatment, "Variant": variant, "Score": score})

manual_df = pd.DataFrame(manual_rows)
focus_df = pd.concat([focus_df, manual_df], ignore_index=True).drop_duplicates()

# Build the focused matrix
# Variants from top-N plus manual associations
# NOTE: top_variants is re-bound here to the array of selected variant names.
top_variants = focus_df["Variant"].unique()
forced_variant_set = set(manual_df["Variant"].unique())
all_variants = list(pd.Index(top_variants).union(forced_variant_set))

# Forced treatments to display
all_treatments = list(pd.Index(focus_df["Treatment"].unique()).union(treatments_to_force))
all_treatments = [t for t in all_treatments if t in adjusted_matrix.index]

focus_matrix = adjusted_matrix.loc[all_treatments, all_variants]
# Drop variant columns that have no score for any displayed treatment
focus_matrix = focus_matrix.dropna(axis=1, how='all')
# Rows/columns are sorted by the mean score of the selected pairs, highest
# first; treatments without a selected pair are appended to the row order.
col_order = focus_df.groupby("Variant")["Score"].mean().sort_values(ascending=False).index
row_order = focus_df.groupby("Treatment")["Score"].mean().sort_values(ascending=False).index
row_order = pd.Index(row_order.tolist() + [t for t in all_treatments if t not in row_order])
# Keep only variants that survived the all-NaN column drop
col_order = [v for v in col_order if v in focus_matrix.columns] 
focus_matrix = focus_matrix.loc[row_order, col_order]

# Treatments removed from the final figure by name
treatments_to_drop = [
    "naquotinib", "sotrastaurin acetate", "uprosertib", "rindopepimut", 
    "rilotumumab",
]
treatments_to_drop = [t.lower().strip() for t in treatments_to_drop]

# Filter out from focus_matrix
focus_matrix = focus_matrix[~focus_matrix.index.isin(treatments_to_drop)]
row_order = [t for t in row_order if t in focus_matrix.index]
focus_matrix = focus_matrix.loc[row_order]

# Plot heatmap
# Colours diverge around 0; negative scores mark pairs whose consensus was
# "resistant". No explicit mask is passed for NaN cells here.
plt.figure(figsize=(18, 14))
sns.set(style="white")
sns.heatmap(
    focus_matrix,
    cmap="RdYlGn",
    center=0,
    linewidths=0.5,
    linecolor="lightgray",
    square=False,
    cbar_kws={
        "label": "Association score",
        "shrink": 0.8,
        "orientation": "vertical"
    }
)

# Y-axis (treatments)
# One tick per row, placed at the cell centre (+0.5), so every label is shown
plt.yticks(
    ticks=np.arange(len(focus_matrix.index)) + 0.5,
    labels=focus_matrix.index,
    fontsize=8,
    rotation=0
)

# X-axis (variants)
# One tick per column, placed at the cell centre (+0.5)
plt.xticks(
    ticks=np.arange(len(focus_matrix.columns)) + 0.5,
    labels=focus_matrix.columns,
    fontsize=8,
    rotation=45,
    ha="right"
)
# The colour bar is retrieved from the first collection drawn on the axes
cbar = plt.gca().collections[0].colorbar
cbar.ax.set_ylabel("Association score", labelpad=-10)
cbar.ax.yaxis.label.set_rotation(90)
plt.title("Variant–treatment associations", fontsize=18, pad=20)
plt.xlabel("Variants", fontsize=14, labelpad=10)
plt.ylabel("Treatments", fontsize=14, labelpad=10)
# rect leaves the top 3% of the figure free
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.savefig("variant_treatment_heatmap_clean.png", dpi=300)
plt.show()

In [None]:
# Print all non-NaN variant–treatment associations
# Flatten the matrix into (treatment, variant, score) rows
# (assumes stack() drops NaN cells, as its legacy default does — TODO confirm
# for the installed pandas version)
all_associations = focus_matrix.stack().reset_index()
all_associations.columns = ["Treatment", "Variant", "Score"]
sorted_associations = all_associations.sort_values(by="Score", ascending=False)
# Lift pandas' row display limit so the whole table is printed
pd.set_option("display.max_rows", None) 
print(sorted_associations)

In [None]:
###### Search by Variant ######
# Input: variant name (e.g., "v600e_braf")
search_variant = "c797s_EGFR" #"g1202r_ALK" 
# Normalise the query (strip + lower-case) before looking it up in the
# focus_matrix column labels
search_variant = search_variant.strip().lower()
if search_variant in focus_matrix.columns:
    # Scores of all treatments for this variant, highest first, NaN removed
    results = focus_matrix[search_variant].dropna().sort_values(ascending=False)
    print(f"\n Treatments associated with variant '{search_variant}':\n")
    print(results)
else:
    print(f" Variant '{search_variant}' not found in the heatmap.")

In [None]:
###### Search by Treatment  ######
# Input: treatment name; it is stripped and lower-cased before the lookup
search_treatment = "dabrafenib"  
search_treatment = search_treatment.strip().lower()

# Check and print associated variants
# A boolean mask (rather than .loc[label]) keeps the result a DataFrame and
# returns every row whose index equals the query.
matches = focus_matrix.loc[focus_matrix.index == search_treatment]

if not matches.empty:
    if len(matches) > 1:
        print(f"Note: multiple entries found for treatment '{search_treatment}' — showing all.")
    for i, (index, row) in enumerate(matches.iterrows()):
        print(f"\nEntry {i+1} — Variants associated with treatment '{index}':\n")
        # Scores of all variants for this treatment, highest first, NaN removed
        result = row.dropna().sort_values(ascending=False)
        print(result)
else:
    print(f"Treatment '{search_treatment}' not found in the heatmap.")

In [None]:
# Print top associations for a specific variant of interest
variant_of_interest = "v600e_braf"
if variant_of_interest in focus_matrix.columns:
    print(f"\n Treatments associated with variant: {variant_of_interest.upper()}")
    # Scores sorted from highest to lowest, NaN removed
    variant_scores = focus_matrix[variant_of_interest].dropna().sort_values(ascending=False)
    # First 20 positive scores = the 20 largest
    top_positives = variant_scores[variant_scores > 0].head(20)
    # Last 20 negative scores = the 20 most negative (still listed in
    # descending order, i.e. the most negative one is printed last)
    top_negatives = variant_scores[variant_scores < 0].tail(20)
    if not top_positives.empty:
        print("\nTop POSITIVE associations:")
        for treatment, score in top_positives.items():
            print(f"{treatment} --> {score:.4f}")
    else:
        print("\nNo strong positive associations found.")

    if not top_negatives.empty:
        print("\nTop NEGATIVE associations:")
        for treatment, score in top_negatives.items():
            print(f"{treatment} --> {score:.4f}")
    else:
        print("\nNo strong negative associations found.")
else:
    print(f"\nVariant {variant_of_interest.upper()} not found in heatmap.")

# =====================================================

# Create dot plot

In [None]:
# Load the cancer-variant matrix
# Row and column labels are stripped and lower-cased so they match the
# lower-case cancer names used below.
matrix_path = os.path.join(variantscape_directory, "cancer_variant_matrix_normalized.csv")
cancer_variant_matrix = pd.read_csv(matrix_path, index_col=0)
cancer_variant_matrix.index = cancer_variant_matrix.index.str.strip().str.lower()
cancer_variant_matrix.columns = cancer_variant_matrix.columns.str.strip().str.lower()

# Define cancers and get top 5 strongest variants per cancer (skip "v600")
# NOTE: .loc[cancer] raises KeyError if a target cancer is not a row of the matrix.
target_cancers = ['lung cancer', 'breast cancer', 'colon cancer', 'prostate cancer', 'melanoma']
top_rows = []
for cancer in target_cancers:
    row = cancer_variant_matrix.loc[cancer]
    # Variants ordered by absolute score, largest first
    sorted_row = row.reindex(row.abs().sort_values(ascending=False).index)

    picked = 0
    for variant, score in sorted_row.items():
        # Skip the bare "v600" label and "v600_<x>" labels with a single underscore
        if variant == "v600" or (variant.startswith("v600_") and variant.count("_") == 1):
            continue
        top_rows.append({
            "Cancer": cancer,
            "Variant": variant,
            "Score": score
        })
        picked += 1
        if picked == 5:
            break
top_df = pd.DataFrame(top_rows)

# Print top 5 table
print("\nTop 5 strongest variant associations per cancer:\n")
for cancer in target_cancers:
    subset = top_df[top_df["Cancer"] == cancer]
    print(f"--- {cancer.upper()} ---")
    print(subset[["Variant", "Score"]].to_string(index=False))
    print()

# Dot plot with bubble legend
# Each point is one (variant, cancer) pair; the marker size follows the
# absolute score and the colour follows the score itself.
plt.figure(figsize=(12, 6))
sns.set(style="whitegrid")
dot_plot = sns.scatterplot(
    data=top_df,
    x="Variant",
    y="Cancer",
    size=np.abs(top_df["Score"]),
    hue=top_df["Score"],
    palette="Blues",
    sizes=(50, 400),
    edgecolor="black",
    legend="brief"
)

plt.title("Top variant associations by cancer", fontsize=16)
plt.xlabel("Variant", fontsize=12)
plt.ylabel("Cancer", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.tight_layout()

# Rebuild the legend from the numeric entries only. Entries are de-duplicated
# by label (the dict keeps the last handle seen for a label), and BOTH the
# handles and the labels are taken from that de-duplicated list so they stay
# paired. Previously the labels came from the raw list, which misaligned
# handles and labels whenever the hue and size legends shared a label.
handles, labels = plt.gca().get_legend_handles_labels()
unique = list(dict(zip(labels, handles)).items())
filtered_handles = [h for l, h in unique if l.replace('.', '', 1).isdigit()]
filtered_labels = [l for l, h in unique if l.replace('.', '', 1).isdigit()]
plt.legend(
    filtered_handles,
    filtered_labels,
    title="Weighted co-occurrence (%)",
    loc='upper left',
    bbox_to_anchor=(1.05, 1),
    borderaxespad=0.,
    labelspacing=1.2,
    frameon=True
)
plt.savefig("top_variant_dotplot.png", dpi=300, bbox_inches="tight")
plt.show()

# =======================================================

# Calculate statistically significant co-associations

In [None]:
def save_filtered_matrix_as_csv(matrix, csv_filename, threshold=1.0, top_n=50):
    """Write the strongest part of ``matrix`` to a CSV file.

    Values not strictly greater than ``threshold`` are replaced by 0. The
    ``top_n`` rows and the ``top_n`` columns with the largest remaining totals
    are then selected and written to ``csv_filename``; a confirmation line is
    printed afterwards.

    Parameters
    ----------
    matrix : pandas.DataFrame
        The original matrix to filter.
    csv_filename : str
        Path of the CSV file to write.
    threshold : float, optional
        Values must exceed this to be kept (default 1.0).
    top_n : int, optional
        Number of rows and of columns to keep (default 50).
    """
    strong = matrix[matrix > threshold].fillna(0)

    # Rank rows and columns by what is left after thresholding.
    keep_rows = strong.sum(axis=1).nlargest(top_n).index
    keep_cols = strong.sum(axis=0).nlargest(top_n).index

    strong.loc[keep_rows, keep_cols].to_csv(csv_filename)
    print(f"Filtered matrix saved to {csv_filename}")
    
# Export the filtered treatment-variant matrix. This uses whichever
# treatment_variant_matrix was assigned most recently in the session.
save_filtered_matrix_as_csv(treatment_variant_matrix, 
                        "filtered_treatment_variant_matrix.csv", 
                        threshold=1.0, top_n=50)

In [None]:
# Load the Treatment x Cancer and Cancer x Variant matrices
treatment_cancer_matrix = pd.read_csv("treatment_cancer_matrix_normalized.csv", index_col=0)
cancer_variant_matrix = pd.read_csv("cancer_variant_matrix_normalized.csv", index_col=0)

# Identify top associations for Treatment x Cancer matrix
# unstack() keys each value by (column, row), hence the (Cancer, Treatment) order
top_treatment_cancer = treatment_cancer_matrix.unstack().sort_values(ascending=False).head(50)
top_treatment_cancer_df = top_treatment_cancer.reset_index()
top_treatment_cancer_df.columns = ['Cancer', 'Treatment', 'Association Value']

# Identify top associations for Cancer x Variant matrix
# Same (column, row) ordering: (Variant, Cancer); only the 10 largest are kept
top_cancer_variant = cancer_variant_matrix.unstack().sort_values(ascending=False).head(10)
top_cancer_variant_df = top_cancer_variant.reset_index()
top_cancer_variant_df.columns = ['Variant', 'Cancer', 'Association Value']

print("\nTop Associations - Treatment x Cancer:")
print(top_treatment_cancer_df)
print("\nTop Associations - Cancer x Variant:")
print(top_cancer_variant_df)

In [None]:
# Investigate significant associations 
# Reload the corrected p-values and count every cell of the matrix
treatment_variant_pvalues = pd.read_csv("treatment_variant_pvalues.csv", index_col=0)
total_associations = treatment_variant_pvalues.size

# Identify significant associations
# Cells above the threshold become NaN under the boolean mask and are left out
# of the stacked table (assumes stack() drops NaN, as its legacy default
# does — TODO confirm for the installed pandas version).
# NOTE: the filter is inclusive (<=) while the printed label below says "p < 0.05".
SIGNIFICANCE_THRESHOLD = 0.05
significant_results = treatment_variant_pvalues[treatment_variant_pvalues <= SIGNIFICANCE_THRESHOLD].stack().reset_index()
significant_results.columns = ['Treatment', 'Variant', 'Corrected_p_value']
num_significant = significant_results.shape[0]
percentage_significant = (num_significant / total_associations) * 100

print("\nStatistical Significance Analysis Summary:\n")
print(f"{'Total Associations':<45} {total_associations:>15,}")
print(f"{'Significant Associations (p < 0.05)':<45} {num_significant:>15,}")
print(f"{'Percentage of Significant Associations (%)':<45} {percentage_significant:>15.2f}")