In [1]:
import pickle
import glob
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

disease1 = "ibd"
disease2 = "ibs"

files1 = glob.glob('results/'+disease1+'/'+disease1+'_dicts/*')
files2 = glob.glob('results/'+disease2+'/'+disease2+'_dicts/*')
to_classify = ["sample", "domain", "phylum", "class", "order", "family", "genus"]
classifier = "phylum"

In [8]:
# Per-taxon thresholds.
# For each disease, every patient's counts at the `classifier` rank are converted to
# relative abundances (count / that patient's total count) and averaged over all of
# that disease's files; a patient lacking a taxon contributes 0 to that taxon's mean.
# Outputs used by later cells:
#   avg_d1 / avg_d2                      taxon -> mean relative abundance (the threshold)
#   total_organisms1 / total_organisms2  file path -> total count at the `classifier` rank


def _relative_abundance_stats(paths, rank):
    """Mean relative abundance per taxon and total count per file.

    Args:
        paths: pickled result files; each unpickles to a mapping whose ``[rank]``
            entry maps taxon -> count (presumably numeric counts -- TODO confirm).
        rank: taxonomic rank to read, e.g. ``classifier``.

    Returns:
        ``(means, totals)``: ``means[taxon]`` is the average over *all* files of
        ``count / file_total`` (taxa keyed in first-appearance order over ``paths``);
        ``totals[path]`` is the sum of that file's counts at ``rank``.

    Raises:
        ValueError: a file lists taxa at ``rank`` but their counts sum to 0, so
            relative abundance is undefined for it.
    """
    n_files = len(paths)
    means = {}
    totals = {}
    for path in paths:
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
        with open(path, 'rb') as fh:
            counts = pickle.load(fh)[rank]
        total = sum(counts.values())
        totals[path] = total
        if counts and total == 0:
            raise ValueError(f"{path}: counts at rank {rank!r} sum to 0; cannot normalize")
        for taxon, count in counts.items():
            # Dividing by n_files here makes the running sum the mean across all files.
            means[taxon] = means.get(taxon, 0) + count / (n_files * total)
    return means, totals


num_files1 = len(files1)
num_files2 = len(files2)
avg_d1, total_organisms1 = _relative_abundance_stats(files1, classifier)
avg_d2, total_organisms2 = _relative_abundance_stats(files2, classifier)

In [9]:
# Co-occurrence of above-threshold taxa across the two cohorts.
# A taxon is "up" in a patient when its relative abundance (count / patient total)
# is >= the cohort mean computed in the threshold cell (avg_d1 / avg_d2).
# total_files_per_sample[i][j] is the number of (disease1 patient, disease2 patient)
# pairs in which i is up in the first and j is up in the second, i.e.
#   (#disease1 patients with i up) * (#disease2 patients with j up).
# Every disease1 taxon gets a row (an empty dict when it is never up); the columns
# are the disease2 taxa that are up in at least one patient.
# NOTE(review): because each entry factorizes into an outer product, the matrix
# carries no joint information between i and j -- confirm this is the intended
# statistic before reading the heatmap as a correlation.


def _count_above_threshold(paths, rank, means, totals):
    """Count, per taxon, the files in which it is at or above its threshold.

    Args:
        paths: pickled result files (same layout as in the threshold cell).
        rank: taxonomic rank to read.
        means: taxon -> threshold (mean relative abundance).
        totals: path -> total count for that file at ``rank``.

    Returns:
        taxon -> number of files with ``count / totals[path] >= means[taxon]``.
        Taxa that never reach their threshold are omitted; keys follow the order
        in which each taxon first qualifies while scanning ``paths``.
    """
    hits = {}
    for path in paths:
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
        with open(path, 'rb') as fh:
            counts = pickle.load(fh)[rank]
        for taxon, count in counts.items():
            if count / totals[path] >= means[taxon]:
                hits[taxon] = hits.get(taxon, 0) + 1
    return hits


# Each result file is unpickled exactly once per cohort.
up_counts1 = _count_above_threshold(files1, classifier, avg_d1, total_organisms1)
up_counts2 = _count_above_threshold(files2, classifier, avg_d2, total_organisms2)

# avg_d1 holds every disease1 taxon in first-appearance order, which fixes the row order.
total_files_per_sample = {
    i: ({j: up_counts1[i] * n_j for j, n_j in up_counts2.items()} if i in up_counts1 else {})
    for i in avg_d1
}


In [None]:
# Heatmap of the co-occurrence counts: rows are disease1 taxa, columns disease2 taxa.
# NOTE(review): DataFrame.from_dict(orient='index') appears to drop rows whose inner
# dict is empty, so disease1 taxa never above threshold would not be shown -- confirm.
df = pd.DataFrame.from_dict(total_files_per_sample, orient='index')

# NOTE(review): 15x15 in at 600 dpi renders ~9000x9000 px inline; consider a lower dpi
# for display and fig.savefig(..., dpi=600) for export.
fig, ax = plt.subplots(figsize=(15, 15), dpi=600)
coloring = sns.color_palette("ch:s=-.2,r=.6", as_cmap=True)
coloring.set_bad("#E3D8C1")  # colour used for missing (NaN) cells
heatmap = sns.heatmap(df, annot=False, cmap=coloring, ax=ax,
                      cbar_kws={'label': 'patients that both are upregulated in'})
heatmap.set_title("Correlation between taxa in " + disease1.upper() + " and " + disease2.upper() + " ("+classifier+")")
heatmap.set_xlabel(disease2.upper() + " " + classifier)
heatmap.set_ylabel(disease1.upper() + " " + classifier)
plt.show()