In [None]:
import ast
import os, sys
import pickle
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd

In [None]:
## Read data
data_loc = './model_evaluation/data/drug_class_experiment'

# All known (true-positive) drug-disease pairs: tab-separated, first row is the header.
all_true_positive_pairs = pd.read_csv(os.path.join(data_loc, 'all_true_positive_pairs.txt'), sep='\t', header=0)

# A `drug_class` cell is either a plain string or the text form of a Python list.
# Parse the list form with `ast.literal_eval`, which accepts only literals; the previous
# `eval` would execute any expression embedded in the data file. Non-string cells
# (e.g. NaN for a missing value) are passed through unchanged instead of raising
# TypeError on the `'[' in value` check.
all_true_positive_pairs['drug_class'] = all_true_positive_pairs['drug_class'].apply(
    lambda value: ast.literal_eval(value) if isinstance(value, str) and '[' in value else value
)

# Per-disease info for the test diseases. Later cells index it as
# drug_class_dict[disease_curie]['top_500_drugs'].
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open(os.path.join(data_loc, 'test_diseases_info.pkl'), 'rb') as f:
    drug_class_dict = pickle.load(f)

In [None]:
## Draw drug class distribution in train set
# Count how often each drug class occurs among the training pairs. A `drug_class` cell is
# either a single class-name string or a list whose items carry the class name at index 1
# (presumably (id, name) pairs -- TODO confirm against the data file). Any other value,
# e.g. NaN for a pair without a class, is skipped.
train_drug_class_counts = Counter()
train_drug_class_list = all_true_positive_pairs.loc[all_true_positive_pairs['data_type'] == 'train', 'drug_class'].to_list()
for this_drug_class in train_drug_class_list:
    if isinstance(this_drug_class, str):
        train_drug_class_counts[this_drug_class] += 1
    elif isinstance(this_drug_class, list):
        train_drug_class_counts.update(x[1] for x in this_drug_class)
# Fail loudly here rather than with a ZeroDivisionError / empty pie further down.
assert train_drug_class_counts, "no drug classes found for the train split"

## Sort drug class by frequency, calculate the percentage, and draw pie chart
# most_common() orders by descending count; ties keep first-seen order.
train_drug_class = train_drug_class_counts.most_common()
total = sum(count for _, count in train_drug_class)

# Show the `top_n` most frequent classes individually and pool the rest into "Others".
# Slicing (instead of indexing over range(top_n)) avoids an IndexError when there are
# fewer than `top_n` classes; "Others" is only added when it is non-empty.
top_n = 10
labels = [f"{name} ({count/total * 100:.1f}%)" for name, count in train_drug_class[:top_n]]
freqs = [count for _, count in train_drug_class[:top_n]]
other_freq = sum(count for _, count in train_drug_class[top_n:])
if other_freq > 0:
    labels.append(f"Others ({other_freq/total * 100:.1f}%)")
    freqs.append(other_freq)

# One color per wedge, in wedge order (up to top_n classes + "Others").
colors = ['#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd','#8c564b','#e377c2','#7f7f7f','#bcbd22','#17becf','#aec7e8']

fig1, ax1 = plt.subplots()
patches, texts = ax1.pie(freqs, colors=colors, startangle=90)
ax1.legend(patches, labels, loc='center left', bbox_to_anchor=(1, 0.5))
ax1.axis('equal')  # draw the pie as a circle
fig1.tight_layout()
# savefig raises if the output directory does not exist yet.
os.makedirs('draw_figures', exist_ok=True)
fig1.savefig('draw_figures/train_drug_class_distribution.svg')


In [None]:
## Draw drug class distribution in test set
# Count how often each drug class occurs among the test pairs. A `drug_class` cell is
# either a single class-name string or a list whose items carry the class name at index 1
# (presumably (id, name) pairs -- TODO confirm against the data file). Any other value,
# e.g. NaN for a pair without a class, is skipped.
test_drug_class_counts = Counter()
test_drug_class_list = all_true_positive_pairs.loc[all_true_positive_pairs['data_type'] == 'test', 'drug_class'].to_list()
for this_drug_class in test_drug_class_list:
    if isinstance(this_drug_class, str):
        test_drug_class_counts[this_drug_class] += 1
    elif isinstance(this_drug_class, list):
        test_drug_class_counts.update(x[1] for x in this_drug_class)
# Fail loudly here rather than with a ZeroDivisionError / empty pie further down.
assert test_drug_class_counts, "no drug classes found for the test split"

## Sort drug class by frequency, calculate the percentage, and draw pie chart
# most_common() orders by descending count; ties keep first-seen order.
test_drug_class = test_drug_class_counts.most_common()
total = sum(count for _, count in test_drug_class)

# Show the `top_n` most frequent classes individually and pool the rest into "Others".
# Slicing (instead of indexing over range(top_n)) avoids an IndexError when there are
# fewer than `top_n` classes; "Others" is only added when it is non-empty.
top_n = 10
labels = [f"{name} ({count/total * 100:.1f}%)" for name, count in test_drug_class[:top_n]]
freqs = [count for _, count in test_drug_class[:top_n]]
other_freq = sum(count for _, count in test_drug_class[top_n:])
if other_freq > 0:
    labels.append(f"Others ({other_freq/total * 100:.1f}%)")
    freqs.append(other_freq)

# One color per wedge, in wedge order (up to top_n classes + "Others"). This palette differs
# from the train chart's, presumably so classes shared by both charts keep the same color
# -- TODO confirm when the class ranking changes.
colors = ['#1f77b4','#ff7f0e','#2ca02c','#8c564b','#7f7f7f','#d62728','#bcbd22','#ffbb78','#e377c2','#9467bd','#aec7e8']

fig1, ax1 = plt.subplots()
patches, texts = ax1.pie(freqs, colors=colors, startangle=90)
ax1.legend(patches, labels, loc='center left', bbox_to_anchor=(1, 0.5))
ax1.axis('equal')  # draw the pie as a circle
fig1.tight_layout()
# savefig raises if the output directory does not exist yet.
os.makedirs('draw_figures', exist_ok=True)
fig1.savefig('draw_figures/test_drug_class_distribution.svg')

In [None]:
## Calculate, for each disease in the test set, the number of distinct drug classes
## that do not appear in the training drug-disease pairs
# For every disease: from its 'top_500_drugs' table, keep rows whose 'is_in_tarin' flag is
# False, take the first 100 of those rows, then count the distinct class labels among rows
# whose 'known_drug_class' flag is False.
# NOTE: 'is_in_tarin' (sic) is the column name as stored in the pickle; do not "fix" it here.
# Presumably each 'drug_class' cell is an iterable of class labels -- TODO confirm.
disease_curie_stat = {}
for disease_curie, disease_info in drug_class_dict.items():
    top_drugs = disease_info['top_500_drugs']
    not_in_train = top_drugs[top_drugs['is_in_tarin'] == False][:100]
    unseen_class_lists = not_in_train.loc[not_in_train['known_drug_class'] == False, 'drug_class']
    unseen_classes = {drug_class for class_list in unseen_class_lists for drug_class in class_list}
    disease_curie_stat[disease_curie] = len(unseen_classes)

In [None]:
## Draw histogram
# One value per test disease: its number of unseen drug classes (an integer count).
values = list(disease_curie_stat.values())

# The values are integer counts, so use one unit-width bin centred on each integer
# (edges at k - 0.5). A fixed `bins=100` spreads 100 equal bins over [min, max]; their
# edges do not line up with whole numbers, so some bars merge neighbouring counts while
# others stay empty. With no data, fall back to 100 bins.
if values:
    bin_edges = [k - 0.5 for k in range(min(values), max(values) + 2)]
else:
    bin_edges = 100

fig, ax = plt.subplots()
ax.hist(values, bins=bin_edges, edgecolor='black')

# Extend the x-axis to at least 100.5, but never below the automatic upper limit, so
# counts above 100 are not clipped out of view.
min_v, max_v = ax.get_xlim()
ax.set_xlim(min_v, max(100.5, max_v))

# add labels and a title
ax.set_xlabel('Number of Drug Classes')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of Unseen Drug Classes \n in Top 100 Predicted Non-train Drugs for Each Disease in Test Set')

fig.tight_layout()
# Save before showing (the order matplotlib recommends); savefig raises if the output
# directory does not exist yet.
os.makedirs('draw_figures', exist_ok=True)
fig.savefig('draw_figures/histogram_unseen_drug_classes.svg')
plt.show()