In [46]:
import pandas as pd
import matplotlib.pyplot as plt
import os

In [47]:
# Numeric gene_id that identifies IL23R in this dataset (presumably an
# NCBI/Entrez Gene ID — TODO confirm against the data source).
IL23R = 149233

# Gene/year patent table. Columns used by later cells:
# gene_id, patent_year, gwas_year, num_patents.
SPILLOVERS_CSV = "th17_il23r_spillovers.csv"
combined_df = pd.read_csv(SPILLOVERS_CSV)

In [48]:
# A row survives the filter when ANY of the following is true:
#   * the gene has no GWAS year recorded (there is nothing to cut against),
#   * the gene is IL23R (all of its patent years are kept), or
#   * the patent year is strictly earlier than the gene's GWAS year.
# A row with a GWAS year but a missing patent_year is dropped, because a NaN
# comparison evaluates to False (assuming both year columns are numeric).
no_gwas_year = combined_df["gwas_year"].isna()
is_il23r = combined_df["gene_id"].eq(IL23R)
patented_before_gwas = combined_df["patent_year"].lt(combined_df["gwas_year"])

mask_keep = no_gwas_year | is_il23r | patented_before_gwas
# Copy so later cells can modify combined_pre without touching combined_df.
combined_pre = combined_df.loc[mask_keep].copy()


In [49]:
# Plot yearly patent counts for IL23R alone, for all other genes, and for all
# genes combined. Each figure is saved as a PNG under OUTPUT_DIR.

KEY_EVENT_YEAR = 2007  # year marked by the dotted red reference line on every plot
OUTPUT_DIR = "graphs"


def _yearly(df, how):
    """Aggregate ``num_patents`` per ``patent_year``.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows with at least ``patent_year`` and ``num_patents`` columns.
    how : str
        Aggregation name passed to ``GroupBy.agg`` ("sum" or "mean").

    Returns
    -------
    pandas.DataFrame
        One row per patent year, with columns ``patent_year`` and ``num_patents``.
    """
    return df.groupby("patent_year")["num_patents"].agg(how).reset_index()


def _save_yearly_bar_plot(yearly, title, ylabel, filename):
    """Draw one light-blue bar chart of ``num_patents`` by ``patent_year`` and save it.

    Parameters
    ----------
    yearly : pandas.DataFrame
        Output of ``_yearly``: ``patent_year`` and ``num_patents`` columns.
    title, ylabel : str
        Figure title and y-axis label.
    filename : str
        PNG file name, written inside ``OUTPUT_DIR``.
    """
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.bar(yearly["patent_year"], yearly["num_patents"], color="skyblue")
    ax.set_title(title)
    ax.set_xlabel("Patent Year")
    ax.set_ylabel(ylabel)
    ax.grid(True, axis="y", linestyle="--", linewidth=0.5)
    ax.axvline(x=KEY_EVENT_YEAR, color="red", linestyle=":", linewidth=2,
               label=f"Key Event ({KEY_EVENT_YEAR})")
    ax.legend()
    fig.savefig(os.path.join(OUTPUT_DIR, filename))
    # Close this specific figure so re-running the cell doesn't pile up open figures.
    plt.close(fig)


# Make sure output folder exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# (1) IL23R only. The filter cell keeps every patent year for IL23R.
il23r_df = combined_pre[combined_pre["gene_id"] == IL23R]
il23r_total = _yearly(il23r_df, "sum")

# (2) Everything except IL23R (pre-GWAS rows only, per gene).
non_il23r_df = combined_pre[combined_pre["gene_id"] != IL23R]
non_il23r_mean = _yearly(non_il23r_df, "mean")
non_il23r_total = _yearly(non_il23r_df, "sum")

# (3) All genes together.
all_mean = _yearly(combined_pre, "mean")
all_total = _yearly(combined_pre, "sum")

# NOTE(review): "mean" averages over the rows present for each year. If
# gene-years with zero patents are absent rather than stored as 0, the mean is
# taken only over genes that have a row that year — confirm this is intended.
plot_specs = [
    (il23r_total, "IL23R: Number of Patents per Year (all years kept)",
     "Total Patents", "il23r_patents.png"),
    (non_il23r_mean, "Mean Patents per Year (excluding IL23R; pre-GWAS only by gene)",
     "Mean Number of Patents", "non_il23r_mean.png"),
    (non_il23r_total, "Total Patents per Year (excluding IL23R; pre-GWAS only by gene)",
     "Total Number of Patents", "non_il23r_total.png"),
    (all_mean, "Mean Patents per Year (all genes; pre-GWAS rule except IL23R)",
     "Mean Number of Patents", "all_genes_mean.png"),
    (all_total, "Total Patents per Year (all genes; pre-GWAS rule except IL23R)",
     "Total Number of Patents", "all_genes_total.png"),
]

for yearly, title, ylabel, filename in plot_specs:
    _save_yearly_bar_plot(yearly, title, ylabel, filename)

print("All five plots saved in 'graphs/' with light blue bars.")

All five plots saved in 'graphs/' with light blue bars.
