# Coassociation analysis

# 1) Set up libraries and datasets

## 1.1) Import libraries and models

In [None]:
# Import libraries
import os
import json
import re
import datetime
import pandas as pd
import numpy as np
import requests
import pyvis
import networkx as nx
import seaborn as sns
from tqdm import tqdm
from statsmodels.stats.multitest import multipletests
from scipy import stats
from scipy.stats import chi2_contingency, fisher_exact
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import ticker as mtick
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from pyvis.network import Network
from scipy.cluster.hierarchy import fcluster, linkage, dendrogram

print("Success!")

## 1.2) Set working directories

In [None]:
# Directory configuration for the notebook.
# NOTE: every value below is a placeholder string; replace it with a real path
# before running any cell.
input_directory = "INPUT_DIRECTORY"                  # raw inputs, e.g. the Oncomine panel list
output_directory = "OUTPUT_DIRECTORY"                # intermediate CSVs read by the figure cells
variantscape_directory = "VARIATESCAPE_DIRECTORY"    # dataset-filtering summary used for Figure 10
figures_for_publication = "FIGURES_FOR_PUBLICATION"  # destination of the saved PNG figures
LLM_variant_directory = "LLM_VARIANT_DIRECTORY"
classifier_directory = "CLASSIFIER_DIRECTORY"        # study-design classifier output used for Figure 7

# Start in the output directory: several cells open their files by bare name.
os.chdir(output_directory)
print(f"Current directory: {os.getcwd()}")

# =====================================================

# Figure 1

In [None]:
###### 01_Literature trend analysis ######
# Plots article counts per year, per month and per week (4-week centred
# rolling mean) from the `PubDate` column and saves one PNG per plot into
# `figures_for_publication`.

os.chdir(output_directory)
file_path = "clean_df_step4.csv"

# Reuse the DataFrame when an earlier cell already created it; otherwise load
# it from the output directory (the current directory after the chdir above).
if 'clean_df_step4' not in globals():
    clean_df_step4 = pd.read_csv(file_path)
figure_df = clean_df_step4.copy()

# Unparseable dates become NaT; their NaN year fails the range filter below.
figure_df['PubDate'] = pd.to_datetime(figure_df['PubDate'], errors='coerce')
figure_df['Year'] = figure_df['PubDate'].dt.year
figure_df = figure_df[figure_df['Year'].between(2014, 2025)].copy()
figure_df['YearMonth'] = figure_df['PubDate'].dt.to_period('M').astype(str)
# Weekly keys are ISO "YYYY-Www" strings. A weekly Period stringifies as a
# date range ("2013-12-30/2014-01-05"), which can never equal the
# "<year>-W01" tick labels looked up for the weekly plot, so the keys are
# built from the ISO calendar instead (same Monday-Sunday weeks).
iso_calendar = figure_df['PubDate'].dt.isocalendar()
figure_df['YearWeek'] = (
    iso_calendar['year'].astype(str)
    + "-W"
    + iso_calendar['week'].astype(str).str.zfill(2)
)

yearly_counts = figure_df['Year'].value_counts().sort_index()
monthly_counts = figure_df['YearMonth'].value_counts().sort_index()
weekly_counts = figure_df['YearWeek'].value_counts().sort_index()
years = list(range(2015, 2026))
shifted_labels = list(range(2014, 2025))
blue_color = '#1f77b4'

os.chdir(figures_for_publication)

###### Yearly Publications ######
# Bars are drawn at year + 1 and the ticks at 2015..2025 are relabelled
# 2014..2024, so each bar still carries its true year.
# NOTE(review): 2025 rows pass the filter above and are drawn at x=2026, on
# the right axis limit and without a tick label - confirm whether 2025 should
# be shown or excluded from this plot.
plt.figure(figsize=(12, 8))
plt.bar(yearly_counts.index + 1, yearly_counts.values, color=blue_color)
plt.xlabel('Year')
plt.ylabel('Number of articles')
plt.title('Publications per year (2014-2024)')
plt.xlim(2014, 2026)
plt.ylim(0, yearly_counts.max() * 1.2)
plt.xticks(years, shifted_labels)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: f'{int(x):,}'))
plt.savefig("01_Yearly_literature trend analysis.png", bbox_inches='tight', dpi=300)
plt.show()

###### Monthly Publications ######
plt.figure(figsize=(12, 8))
plt.plot(monthly_counts.index, monthly_counts.values, marker='o', linestyle='-', color=blue_color, linewidth=2)
# The x axis is categorical (string keys), so only label Januaries that exist
# in the data; an absent label would be added as an extra category.
month_ticks = [f"{year}-01" for year in range(2014, 2026)]
month_ticks = [tick for tick in month_ticks if tick in monthly_counts.index]
plt.xticks(month_ticks, [tick[:4] for tick in month_ticks])
plt.xlabel('Year')
plt.ylabel('Number of articles')
plt.title('Publications per month (2014-2024)')
plt.xlim(monthly_counts.index.min(), monthly_counts.index.max())
plt.ylim(0, monthly_counts.max() * 1.2)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: f'{int(x):,}'))
# The misspelt file name is kept so existing references to it stay valid.
plt.savefig("01_Monhtly_literature trend analysis.png", bbox_inches='tight', dpi=300)
plt.show()

###### Weekly Publications ######
plt.figure(figsize=(12, 8))
# Centred 4-week rolling mean; the edge weeks of the series are NaN.
weekly_counts_smoothed = weekly_counts.rolling(window=4, center=True).mean()
plt.plot(weekly_counts.index, weekly_counts_smoothed, linestyle='-', color=blue_color, linewidth=2)
years = list(range(2014, 2026))  # Including 2025
year_labels = [f"{year}-W01" for year in years]
# Tick only the first ISO week of each year that is present in the data.
valid_year_labels = [label for label in year_labels if label in weekly_counts.index]
valid_years = [label[:4] for label in valid_year_labels]

plt.xticks(valid_year_labels, valid_years, rotation=0)
plt.xlabel('Year')
plt.ylabel('Number of articles')
plt.title('Publications per week (2014-2025)')
plt.xlim(weekly_counts.index.min(), weekly_counts.index.max())
plt.ylim(0, weekly_counts_smoothed.max() * 1.2)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: f'{int(x):,}'))
plt.savefig("01_Weekly_literature trend analysis.png", bbox_inches='tight', dpi=300)
plt.show()

# =====================================================

# Figure 2

In [None]:
# Bar chart for ranking of summed relevant columns
# Gene names are taken from the first column of the Oncomine panel file and
# used as column names of the BioBERT NER table.
os.chdir(input_directory)
oncomine_df = pd.read_csv("oncomine_ngs_panel.csv", header=None)
os.chdir(output_directory)
NER_gene_list = pd.read_csv("cleaned_BioBERT_data.csv")


def format_thousands(value, _position):
    """Render an axis tick value as an integer with thousands separators."""
    return f'{int(value):,}'


relevant_columns = oncomine_df.iloc[:, 0].tolist()
total_articles = len(NER_gene_list)

# Column totals per panel gene, largest first.
column_sums = NER_gene_list[relevant_columns].sum().sort_values(ascending=False)
column_sums_formatted = column_sums.apply("{:,.0f}".format)

total_mentions = NER_gene_list['Sum_Gene_Mentions'].sum()
# Superseded a few lines below by the average over the panel-gene totals.
average_mentions = NER_gene_list['Sum_Gene_Mentions'].mean()
# How many articles have 0, 1, 2, ... in `Sum_Gene_Mentions` (reused by the
# Figure 3 cell).
mentions_distribution = NER_gene_list['Sum_Gene_Mentions'].value_counts().sort_index()
mentions_distribution_formatted = mentions_distribution.apply("{:,}".format)

top_genes = column_sums.sort_values(ascending=False).head(15)
percentages = (top_genes / total_articles) * 100
average_mentions = column_sums.sum() / total_articles
# Darkest blue for the first bar, fading towards the last one.
n_top = len(top_genes)
colors = [plt.cm.Blues(1 - (rank / n_top)) for rank in range(n_top)]

fig, ax = plt.subplots(figsize=(14, 7))
bars = top_genes.plot(kind='bar', color=colors, edgecolor='black', ax=ax)
ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_thousands))
ax.set_ylim(0, top_genes.max() * 1.2)

# Caption every bar with its count and its share of all articles.
label_offset = top_genes.max() * 0.04
for bar, value, percent in zip(bars.patches, top_genes, percentages):
    x_centre = bar.get_x() + bar.get_width() / 2
    ax.text(
        x_centre,
        bar.get_height() + label_offset,
        f"{int(value):,}\n({percent:.1f}%)",
        ha='center', va='bottom', fontsize=10, fontweight='bold', color='black'
    )

ax.set_title(
    f'Top 15 genes mentioned by articles (n={total_articles:,.0f} articles)',
    fontsize=14, fontweight='bold', pad=15,
)
ax.set_xlabel('Gene (out of 161 Oncomine NGS panel)', fontsize=12)
ax.set_ylabel('Number of articles mentioning gene', fontsize=12)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
fig.tight_layout(rect=[0, 0.05, 1, 0.9])

os.chdir(figures_for_publication)
plt.savefig("02_Top_15_mentioned_genes_in_publications.png", bbox_inches='tight', dpi=300)
plt.show()

# =====================================================

# Figure 3

In [None]:
# Distribution of gene mentions per article

%matplotlib inline  

fig, ax = plt.subplots(figsize=(12, 7))
colors = sns.color_palette("Blues_d", len(mentions_distribution))  
ax = sns.barplot(x=mentions_distribution.index, y=mentions_distribution.values, palette=colors, edgecolor='black')
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: f'{int(x):,}'))
ax.set_title(f'Distribution of gene mentions per article (n={total_articles:,.0f} articles)', fontsize=14, fontweight='bold', pad=25)
plt.figtext(0.5, 0.6, f'Average number of gene mentions per article: {average_mentions:,.2f}', fontsize=12, color='gray', ha='center')
ax.set_xlabel('Number of genes mentioned', fontsize=12, labelpad=10)
ax.set_ylabel('Number of articles', fontsize=12)
ax.set_xticks(range(len(mentions_distribution.index)))

ax.set_xticklabels(mentions_distribution.index, fontsize=10)
for tick in ax.get_xticklabels():
    tick.set_rotation(45)
    tick.set_horizontalalignment('right')
    tick.set_position((tick.get_position()[0] - 0.05, tick.get_position()[1]))
ax.set_ylim(0, max(mentions_distribution.values) * 1.2)
ax.grid(axis='y', linestyle='--', alpha=0.7)

for bar, value in zip(ax.patches, mentions_distribution.values):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + (max(mentions_distribution.values) * 0.02), f"{value:,}", ha='center', va='bottom', fontsize=9, fontweight='bold', color='black')

    fig.tight_layout(rect=[0, 0.05, 1, 0.85])
os.chdir(figures_for_publication)
plt.savefig("03_Distribution_of_gene_mentions_publication.png", bbox_inches='tight', dpi=300)
plt.show()

# =====================================================

# Figure 4

In [None]:
# Most frequently mentioned cancer types
# Bar chart of the 20 largest `count` rows of the cancer-category table, each
# bar captioned with its count and its share of the hard-coded number of
# articles with a specific cancer type.

total_articles_with_specific_cancer_types = 199726
total_articles = len(NER_gene_list)
percentage_specific_cancer_articles = (total_articles_with_specific_cancer_types / total_articles) * 100

file_path = f"{output_directory}/cancer_category_occurrences_with_percentages.csv"
df_cancer_occurrences = pd.read_csv(file_path)
top_20_cancers = df_cancer_occurrences.nlargest(20, "count")
top_20_cancers["percentage"] = (top_20_cancers["count"] / total_articles_with_specific_cancer_types) * 100

colors = ["royalblue"] * len(top_20_cancers)
max_y_value = top_20_cancers["count"].max()
plt_ylim = max_y_value * 1.2

plt.figure(figsize=(14, 7))
bars = plt.bar(top_20_cancers["final_parent"], top_20_cancers["count"], color=colors)
plt.xlabel("Cancer type", fontsize=12, fontweight="bold")
plt.ylabel("Number of mentions", fontsize=12, fontweight="bold")
plt.title(
    f"Most frequently mentioned cancer types\n in {total_articles_with_specific_cancer_types:,} "
    f"({percentage_specific_cancer_articles:.2f}%) out of {total_articles:,} articles",
    fontsize=14, fontweight='bold',
)
plt.xticks(rotation=45, ha="right", fontsize=10)
plt.ylim(0, plt_ylim)
plt.grid(axis="y", linestyle="--", alpha=0.7)

# Caption each bar slightly above its top edge.
label_offset = max_y_value * 0.01
for bar, value, percentage in zip(bars, top_20_cancers["count"], top_20_cancers["percentage"]):
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + label_offset,
        f"{int(value):,}\n({percentage:.1f}%)",
        ha="center", va="bottom", fontsize=10, fontweight='bold'
    )

plt.tight_layout()
plt.savefig(f"{figures_for_publication}/04_Most_frequent_cancer_types.png", bbox_inches='tight', dpi=300)
plt.show()

# =====================================================

# Figure 5

In [None]:
# Figure for treatment categories
# Sums every treatment column of the treatment-mapping table into its parent
# treatment category and plots the category totals as a bar chart.

treatment_articles = 126195
total_articles = len(NER_gene_list)
percentage_treatment_articles = treatment_articles / total_articles * 100
treatment_mapping_df = pd.read_csv(f"{output_directory}/treatment_mapping_with_matches.csv", low_memory=False)
CIVIC_ncit_df_finalparent_treatmentcategory = pd.read_csv(f"{output_directory}/CIVIC_ncit_df_finalparent_treatmentcategory.csv")
metadata_columns = ["PaperTitle", "Abstract", "Sum_treatments", "Treatment_matches"]
therapy_columns = [col for col in treatment_mapping_df.columns if col not in metadata_columns]
treatment_to_parent_mapping = CIVIC_ncit_df_finalparent_treatmentcategory.set_index("name")["parent_treatment_category"].to_dict()

# Accumulate the column total of every mapped treatment under its parent
# category; columns without a mapping entry are skipped.
parent_counts = {}
for treatment in therapy_columns:
    if treatment not in treatment_to_parent_mapping:
        continue
    parent_category = treatment_to_parent_mapping[treatment]
    treatment_sum = treatment_mapping_df[treatment].sum()
    parent_counts[parent_category] = parent_counts.get(parent_category, 0) + treatment_sum

# Build the table once, after the loop. It used to be created inside the
# loop's `else` branch, i.e. only when a new category was first seen, so sums
# added to existing categories afterwards were missing from the figure.
parent_counts_df = pd.DataFrame(list(parent_counts.items()), columns=["Parent Category", "Total Mentions"])
parent_counts_df = parent_counts_df.sort_values(by="Total Mentions", ascending=False)

# Keep the six largest named categories and place "Other therapy" last.
if "Other therapy" in parent_counts_df["Parent Category"].values:
    other_therapy_row = parent_counts_df[parent_counts_df["Parent Category"] == "Other therapy"]
    parent_counts_df = parent_counts_df[parent_counts_df["Parent Category"] != "Other therapy"]
    parent_counts_df = parent_counts_df.head(6)
    parent_counts_df = pd.concat([parent_counts_df, other_therapy_row], ignore_index=True)

max_y_value = parent_counts_df["Total Mentions"].max()
plt_ylim = max_y_value * 1.2
plt.figure(figsize=(12, 6))
colors = plt.cm.Blues(np.linspace(1, 0.5, len(parent_counts_df)))
bars = plt.bar(parent_counts_df["Parent Category"], parent_counts_df["Total Mentions"], color=colors, edgecolor="black")
plt.xticks(rotation=45, ha="right", fontsize=10)
plt.yticks(fontsize=10)
plt.xlabel("Treatment category", fontsize=12)
plt.ylabel("Number of mentions", fontsize=12)
plt.title(f"Most frequent mentions of treatment categories\n"
          f"in {treatment_articles:,} ({percentage_treatment_articles:.2f}%) out of {total_articles:,} articles",
          fontsize=14)

plt.ylim(0, plt_ylim)
plt.gca().yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{int(x):,}"))
plt.grid(axis="y", linestyle="--", alpha=0.7)
for bar, value in zip(bars, parent_counts_df["Total Mentions"]):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2, height + (max_y_value * 0.01), f"{int(value):,}", ha="center", va="bottom", fontsize=10, fontweight='bold')

# Lay out once after all labels are placed (previously ran once per bar).
plt.tight_layout()
plt.savefig(f"{figures_for_publication}/05_Treatment_categories.png", bbox_inches='tight', dpi=300)
plt.show()

# =====================================================

# Figure 6

In [None]:
# Figure for specific treatment mentions
# Bar chart of the treatments listed in the summary CSV, largest count first.
# Uses `treatment_articles`, `percentage_treatment_articles` and
# `total_articles` set by earlier cells.

treatment_summary = (
    pd.read_csv(f"{output_directory}/Top_30_treatments_summary.csv")
    .sort_values(by="Count", ascending=False)
)
treatment_names = treatment_summary["Treatment"]
treatment_counts = treatment_summary["Count"]
highest_count = max(treatment_counts)

plt.figure(figsize=(12, 6))
colors = plt.cm.Blues(np.linspace(1, 0.5, len(treatment_summary)))
bars = plt.bar(treatment_names, treatment_counts, color=colors, edgecolor='black')

plt.title(
    f"Top 30 most mentioned specific treatment mentions\n"
    f"in {treatment_articles:,} ({percentage_treatment_articles:.2f}%) out of {total_articles:,} articles",
    fontsize=14,
)
plt.xlabel("Treatment name", fontsize=12)
plt.ylabel("Number of mentions", fontsize=12)
plt.xticks(rotation=45, ha="right", fontsize=10)
plt.yticks(fontsize=10)
plt.ylim(0, highest_count * 1.2)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.gca().yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{int(x):,}"))

# Slanted count label just above each bar.
label_offset = highest_count * 0.01
for bar, value in zip(bars, treatment_counts):
    x_centre = bar.get_x() + bar.get_width() / 2
    plt.text(x_centre, value + label_offset,
             f"{int(value):,}", ha="center", va="bottom", fontsize=9, fontweight='normal', rotation=45)

plt.tight_layout()
plt.savefig(f"{figures_for_publication}/06_Figure_of_specific_treatments.png", bbox_inches='tight', dpi=300)
plt.show()

# =====================================================

# Figure 7

In [None]:
# Distribution of study designs
# Counts the classifier's `Study_design` labels and plots them as a bar chart
# captioned with counts and percentages.
dataset_to_display = f"{classifier_directory}/final_gc_classificaton_output_199726.csv"
stud_cat = pd.read_csv(dataset_to_display)

# Fold the two catch-all labels into a single "Other" bucket.
merged_labels = {"Behavioral study": "Other", "Undefined": "Other"}
stud_cat["Study_design_clean"] = stud_cat["Study_design"].replace(merged_labels)

# Count and calculate percentages
study_counts = stud_cat["Study_design_clean"].value_counts()
total_count = study_counts.sum()
percentages = (study_counts / total_count * 100).round(2)
# This overwrites the global of the same name; the plot title reads it.
total_articles_with_specific_cancer_types = total_count

study_counts_df = pd.DataFrame({
    "Study Design": study_counts.index,
    "Count": study_counts.values,
    "Percentage": percentages.values,
})

# One fixed colour per study design; "Other" covers the merged labels and any
# design missing from this mapping falls back to dark grey.
color_mapping = {
    "In vitro study": "#636EFA",
    "Clinical study": "#EF553B",
    "Systematic review study": "#00CC96",
    "In vivo/Animal study": "#AB63FA",
    "In silico study": "#FFA15A",
    "Case report study": "#17becf",
    "Observational/RWE study": "#FF6692",
    "Other": "#B6E880",
}
colors = [color_mapping.get(study, "#333333") for study in study_counts_df["Study Design"]]

plt.figure(figsize=(12, 6))
bars = sns.barplot(x=study_counts_df["Study Design"], y=study_counts_df["Count"], palette=colors)
plt.title(f"Distribution of study designs of {total_articles_with_specific_cancer_types:,} articles", fontsize=16)
plt.xlabel("Study design type", fontsize=14)
plt.ylabel("Count", fontsize=14)
plt.xticks(rotation=45, ha='right', fontsize=12)
plt.grid(axis='y', linestyle='--', alpha=0.7)

largest_count = study_counts_df["Count"].max()
plt.ylim(0, largest_count * 1.3)

# Two-line caption (count, then percentage) above every bar.
label_offset = largest_count * 0.01
for bar, count, percentage in zip(bars.patches, study_counts_df["Count"], study_counts_df["Percentage"]):
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + label_offset,
        f"{int(count):,}\n({percentage}%)",
        ha="center", va="bottom", fontsize=10, fontweight='bold'
    )

plt.tight_layout()
plt.savefig(f"{figures_for_publication}/07_study_design_type_barchart.png", bbox_inches='tight', dpi=300)
plt.show()

# =====================================================

# Figure 8

In [None]:
# Top 20 most frequent variant mentions
# Bar chart of the `Count` column per `Variant`, each bar captioned with its
# count and the `Percentage` value from the same CSV.

top_20_csv_path = f"{output_directory}/Top_20_Variants.csv"
top_20_df = pd.read_csv(top_20_csv_path)

if "Percentage" not in top_20_df:
    raise ValueError("The 'Percentage' column is missing from the loaded CSV file. Please check your file.")

# Article total shown in the title, derived by dividing the summed counts by
# the summed percentages.
total_mentions = top_20_df["Count"].sum()
sum_of_percentages = top_20_df["Percentage"].sum()
total_articles_estimate = int(total_mentions / (sum_of_percentages / 100))

highest_count = top_20_df["Count"].max()

plt.figure(figsize=(12, 8))
bars = plt.bar(top_20_df["Variant"], top_20_df["Count"], color='#1f20b4')

plt.title(f'Top 20 most frequent mentions of variants in {total_articles_estimate:,} articles', fontsize=16)
plt.xlabel('Variants', fontsize=14)
plt.ylabel('Mentions', fontsize=14)
plt.xticks(rotation=45, ha="right", fontsize=12)
plt.ylim(0, highest_count * 1.4)
plt.gca().yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{int(x):,}'))
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Slanted labels anchored at each bar's integer x position.
label_offset = highest_count * 0.03
label_rows = zip(top_20_df["Count"], top_20_df["Percentage"])
for position, (count, percentage) in enumerate(label_rows):
    plt.text(position, count + label_offset,
             f"{int(count):,} ({percentage:.2f}%)",
             ha='left', va='bottom', fontsize=10, fontweight='normal',
             rotation=45, rotation_mode='default')

plt.tight_layout()
plt.savefig(f"{figures_for_publication}/08_Top_20_variants_in_identified_articles.png", bbox_inches='tight', dpi=300)
plt.show()

# =====================================================

# Figure 9

In [None]:
# Oncomine figure
# Loads the summary table and prepares the two three-row tables, colours and
# axis ranges used by the left (gene counts) and right (gene mentions) panels.
top_20_csv_path = f"{output_directory}/oncomine_gene_summary_stats_forfigure.csv"
df_summary = pd.read_csv(top_20_csv_path)


def _summary_rows(summary, categories):
    """Return the rows of `summary` for `categories`, in that exact order.

    `Count` and `Percentage` are coerced to numbers (unparseable values become
    NaN); a category absent from `summary` yields a row of NaN values.
    """
    wanted = summary['Category'].isin(categories)
    rows = summary.loc[wanted, ['Category', 'Count', 'Percentage']].copy()
    for column in ('Count', 'Percentage'):
        rows[column] = pd.to_numeric(rows[column], errors='coerce')
    return rows.set_index('Category').reindex(categories).reset_index()


# Define categories to include for each plot and ensure correct order
counts_categories = [
    "Oncomine genes extracted",
    "Oncomine genes not extracted",
    "Oncomine genes total",
]
percentages_categories = [
    "Oncomine gene mentions",
    "Other gene mentions",
    "Total gene mentions",
]

counts_data = _summary_rows(df_summary, counts_categories)
percentages_data = _summary_rows(df_summary, percentages_categories)

# Count of the "Total gene mentions" row, shown in the right panel's title.
is_total_row = percentages_data['Category'] == 'Total gene mentions'
total_mentions_value = int(percentages_data.loc[is_total_row, 'Count'].values[0])
formatted_total_mentions = f"{total_mentions_value:,}"

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10), constrained_layout=True)
royal_blue = '#1A237E'
lighter_blue = '#5C6BC0'
light_light_blue = '#BBDEFB'

# Leave headroom above the tallest bar for the two-line labels.
left_axis_range = counts_data['Count'].max() * 1.26
right_axis_range = percentages_data['Count'].max() * 1.26

# Label offset above each bar: 1% of the axis range plus a fixed amount in
# data units per panel.
left_padding = left_axis_range * 0.01
right_padding = right_axis_range * 0.01
additional_padding_left = 1
additional_padding_right = 300

def plot_labels(ax, data, padding, additional_padding):
    """Write a two-line "count / (percentage%)" label above each row's bar.

    Args:
        ax: Object with a matplotlib-style ``text(x, y, s, **kwargs)`` method.
        data: Table with ``Count`` and ``Percentage`` columns; rows are placed
            at x = 0, 1, 2, ... in iteration order.
        padding: Vertical gap added above the count, in data units.
        additional_padding: Extra vertical gap added on top of ``padding``.

    Rows whose count is missing (NaN) get no label.
    """
    rows = zip(data['Count'], data['Percentage'])
    for position, (count, percentage) in enumerate(rows):
        if pd.isna(count):
            continue
        ax.text(
            position,
            count + padding + additional_padding,
            f"{int(count):,}\n({percentage:.2f}%)",
            ha='center', fontsize=14, fontweight='bold', color='black',
        )

def _style_panel(ax, title, y_limit):
    """Apply the title, axis labels, tick sizes, y range and grid shared by both panels."""
    ax.set_title(title, fontsize=20, fontweight='bold', pad=5)
    ax.set_xlabel('', fontsize=16)
    ax.set_ylabel('Count', fontsize=16)
    ax.tick_params(axis='x', rotation=0, labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    ax.set_ylim(0, y_limit)
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{int(x):,}'))
    ax.grid(axis='y', linestyle='--', alpha=0.7)


# NOTE(review): with a DataFrame and a single y column, pandas may apply only
# the first colour of a list to every bar - confirm the rendered colours.

# Plot 1: Counts (LEFT)
colors_left = [lighter_blue, light_light_blue, royal_blue]
counts_data.plot(kind='bar', x='Category', y='Count', ax=ax1, color=colors_left, edgecolor='black', legend=False)
_style_panel(
    ax1,
    'Extraction of specific variant-associated \nOncomine genes (total of n=161)',
    left_axis_range,
)
plot_labels(ax1, counts_data, left_padding, additional_padding_left)

# Plot 2: Percentages (RIGHT)
# NOTE(review): the title's "n=... articles" is filled with the count of the
# "Total gene mentions" row - confirm the unit.
colors_right = [lighter_blue, light_light_blue, royal_blue]
percentages_data.plot(kind='bar', x='Category', y='Count', ax=ax2, color=colors_right, edgecolor='black', legend=False)
_style_panel(
    ax2,
    f'Mention of variant-associated Oncomine genes across \nall identified gene mentions (n={formatted_total_mentions} articles)',
    right_axis_range,
)
plot_labels(ax2, percentages_data, right_padding, additional_padding_right)

plt.savefig(f"{figures_for_publication}/09_Oncomine_Gene_Analysis.png", bbox_inches='tight', dpi=300)
plt.show()


# =====================================================

# Figure 10

In [None]:
# Progressive merging and filtering of datasets
# Bar chart of the article count at each filtering step; every bar is
# captioned with its count and its percentage of the first step's count.
# The misspelt input and output file names are left untouched because they
# refer to files on disk.

csv_path = f"{variantscape_directory}/conditional_filterting_of_final_datasets_for_coassciaitions_analysis.csv"
dataset_lengths_df = pd.read_csv(csv_path)
counts = dataset_lengths_df['Count'].tolist()
count_1 = counts[0]
# Percentages are relative to the first row; all zero when that row is not positive.
if count_1 > 0:
    percentages = [count / count_1 * 100 for count in counts]
else:
    percentages = [0 for _ in counts]

# One label per CSV row, in row order (the CSV must therefore hold six rows).
labels = [
    "Extracted Oncomine genes",
    "Extracted cancer types",
    "Classified study design",
    "Extracted treatments",
    "Identified variants",
    "All conditions TRUE",
]

fig, ax = plt.subplots(figsize=(12, 6))
colors = plt.cm.Blues(np.linspace(0.3, 1, len(labels)))
bars = ax.bar(labels, counts, color=colors, edgecolor='black', alpha=0.95)

label_offset = max(counts) * 0.01
for bar, count, pct in zip(bars, counts, percentages):
    ax.text(bar.get_x() + bar.get_width() / 2,
            bar.get_height() + label_offset,
            f"{count:,} ({pct:.1f}%)",
            ha='center', va='bottom', fontsize=12, fontweight='bold')

# Title typo fixed ("filterting" -> "filtering"); it is printed on the figure.
ax.set_title("Progressive merging and filtering of datasets", fontsize=16, fontweight='bold', pad=20)
ax.set_ylabel("Number of articles", fontsize=14, fontweight='bold')
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=12, ha="right", rotation=15)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig(f"{figures_for_publication}/10_filterting_prgoressing_of_datasets.png", bbox_inches='tight', dpi=300)
plt.show()


# =====================================================