In [2]:
# Web
import requests as rq

# Data analysis
import numpy as np
import pandas as pd
import json
%matplotlib notebook
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import random as rand

# Convenience
from tqdm import tqdm
from collections import Counter, defaultdict
from functools import partial
import gc
import pickle
sns.set(style='darkgrid', palette='muted')

## Load tables

In [4]:
# Load tables

#big_table = pd.read_hdf("./big_table.hd5")
# table = pd.read_hdf("./table.hd5")

In [None]:
# Load `big_table` (searched by its column 1 below) and the drug-disease table
# with its `drug_ids` / `disease_ids` columns. `table` is loaded here as well so
# the exploration cells that follow also work on a fresh kernel
# (Restart & Run All); otherwise they raise NameError.
big_table = pd.read_hdf("./big_table.hd5")
table = pd.read_hdf("table.hd5", key="w")

In [None]:
# Preview the first 25 rows of the drug-disease table.
table.head(25)

In [None]:
# Rows of big_table whose column 1 mentions atorvastatin.
big_table.loc[big_table[1].str.contains("atorvastatin")]

In [None]:
# Mapped drug IDs of the "Vitamin C" rows (one .loc call instead of chained indexing).
table.loc[table["Drug"] == "Vitamin C", "drug_ids"]

In [None]:
# Show pairs that map to exactly two drug IDs and more than three disease IDs.
for record in table.itertuples(index=False):
    if len(record.drug_ids) == 2 and len(record.disease_ids) > 3:
        print(record.Disease, record.Drug, record.drug_ids, record.disease_ids)

In [None]:
#table = pd.read_pickle("./table_all_indirect.pkl")

In [None]:
def get_first_without_direct(table):
    '''Return the first row of `table` whose "relation_info" equals an empty
    list, or None when no such row exists.'''
    rows_without_direct = (
        row for _, row in table.iterrows() if row["relation_info"] == []
    )
    return next(rows_without_direct, None)

In [None]:
# row = get_first_without_direct(table)
# print(row["disease_ids"], row["drug_ids"], row["Drug"], row["Disease"])

## API call variables

In [None]:
# Base URL of the Euretos Brain web service and the endpoint paths used below.
# NOTE(review): base_uri ends with "/" and every path starts with "/", so the
# joined URLs contain "//" (e.g. ".../spine-ws//login/authenticate"); presumably
# the server tolerates this — confirm before normalizing.
base_uri = "https://euretos-brain.com/spine-ws/"

auth = "/login/authenticate"
search = "/external/concepts/search"
direct_semantic = "/external/concept-to-semantic/direct"
indirect = "/external/concept-to-concept/indirect"
# Query string appended to the indirect endpoint URL.
# NOTE(review): presumably the maximum number of results returned — confirm in the API docs.
size = "?size=10000"

## Authenticate

In [None]:
# Authenticate once and reuse the returned token (sent as the "x-token"
# header) for every later request. The credentials file is read inside a
# `with` block so its handle is closed.
with open('../leiden_login_info/auth_info_main.json') as login_file:
    login_info = json.load(login_file)

login_req = rq.post(base_uri + auth, json=login_info)
# Fail with a clear HTTP error on a rejected login instead of a confusing
# KeyError on 'token' below.
login_req.raise_for_status()

headers = {'x-token': login_req.json()['token']}

## Get indirect relations

In [None]:
def get_indirect(row, table_number, row_number, filt=None):
    '''Fetch indirect concept-to-concept relations between a drug and a
    disease and cache the raw JSON response on disk.

    Parameters
    ----------
    row : mapping or pandas Series
        Must provide iterables "drug_ids" and "disease_ids"; every ID is sent
        as a string (left inputs = drugs, right inputs = diseases).
    table_number : int or str
        Tag used in the output filename (e.g. "real" or a random-table index).
    row_number : int
        Row index used in the output filename.
    filt : optional
        A single positive filter passed to the API (presumably a semantic-type
        ID such as "st:T028" — confirm); no filter when falsy.

    Writes ``./indirect_jsons/{table_number}_{row_number}.json`` and returns None.

    Raises
    ------
    requests.HTTPError
        If the service answers with an error status, so an error payload is
        never cached as if it were a result.
    '''
    import os

    drug_ids = [str(drug_id) for drug_id in row['drug_ids']]
    disease_ids = [str(disease_id) for disease_id in row['disease_ids']]
    query = {
        "additionalFields": ['directionalTriples', 'semanticCategory'],
        "positiveFilters": [filt] if filt else [],
        "leftInputs": drug_ids,
        "relationshipWeightAlgorithm": "PWS",
        "rightInputs": disease_ids,
        "sort": "ASC"
    }
    print('*', end="")  # one progress marker per request
    resp = rq.post(base_uri + indirect + size, json=query, headers=headers)
    resp.raise_for_status()

    out_dir = "./indirect_jsons"
    os.makedirs(out_dir, exist_ok=True)
    filename = "{}/{}_{}.json".format(out_dir, table_number, row_number)
    with open(filename, "w") as file_for_json:
        json.dump(resp.json(), file_for_json)

In [None]:
# Load the real (positive) drug-disease table stored under key "w".
table = pd.read_hdf("table.hd5", key="w")

In [None]:
# Compare the number of distinct drug names with the number of distinct
# mapped drug-ID combinations.
mapped_drugs_set = set(map(tuple, table["drug_ids"]))
drugs_set = set(table["Drug"])
print(len(drugs_set), len(mapped_drugs_set))

In [None]:
# genes_genomes = "st:T028"

In [None]:
# table_number = "real"
# for row_number, row in table.iterrows():
#     get_indirect(row, table_number, row_number)

## Convenience helpers

In [2]:
class SpecialCounter(Counter):
    ''' Counter whose counts can be divided by a number.

    ``counter / k`` returns a new SpecialCounter with every count divided by
    ``k``; the original counter is left unchanged. '''
    def __truediv__(self, other):
        scaled = {key: count / other for key, count in self.items()}
        return SpecialCounter(scaled)

In [3]:
def partial_second(f, second):
    ''' Bind the second positional argument of a two-argument callable.

    Returns a one-argument function ``g`` with ``g(x) == f(x, second)``. '''
    return lambda first: f(first, second)

In [4]:
def check_derangement(start, perm, verbose=False):
    '''Check whether `perm` is a derangement of `start`.

    Elements are compared position by position; returns False as soon as a
    position holds equal elements, True otherwise. Pairing stops at the
    shorter of the two sequences (as with ``zip``).

    verbose : if True, print each compared pair (the former unconditional
        debugging output).
    '''
    for original, permuted in zip(start, perm):
        if verbose:
            print(original, permuted)
        if original == permuted:
            return False
    return True

In [5]:
def gimme_derangement(start, random_state, verbose=False, max_tries=None):
    '''Return a random derangement of `start`, which is a pandas Series.

    Shuffles `start` with ``Series.sample(frac=1)`` using successive integer
    seeds until no element remains in its original position.

    Parameters
    ----------
    start : pandas.Series
    random_state : int
        Seed counter; it is incremented before every attempt, so the first
        attempt uses ``random_state + 1``.
    verbose : bool
        Print "One more time." before every attempt (the former default).
    max_tries : int or None
        Give up after this many attempts; None (default) retries forever.

    Returns
    -------
    (perm, random_state)
        The deranged Series (index reset to 0..n-1) and the seed that
        produced it, so callers can continue from there.

    Raises
    ------
    ValueError
        If `start` has exactly one element (no derangement exists).
    RuntimeError
        If `max_tries` attempts were made without success.
    '''
    if len(start) == 1:
        raise ValueError("A single-element Series has no derangement.")
    tries = 0
    while max_tries is None or tries < max_tries:
        tries += 1
        if verbose:
            print("One more time.")
        random_state += 1
        perm = start.sample(frac=1, random_state=random_state).reset_index(drop=True)
        if check_derangement(start, perm):
            return perm, random_state
    raise RuntimeError("No derangement found in {} tries.".format(max_tries))

In [6]:
def average(lst):
    '''Arithmetic mean of a non-empty sized iterable (list, tuple, ...).'''
    total = sum(lst)
    count = len(lst)
    return total / count

In [7]:
# Quick sanity check: dividing a SpecialCounter scales every count.
sc = SpecialCounter("shazam")
for shown in (sc, sc / 2):
    print(shown)

SpecialCounter({'a': 2, 's': 1, 'z': 1, 'm': 1, 'h': 1})
SpecialCounter({'a': 1.0, 's': 0.5, 'z': 0.5, 'm': 0.5, 'h': 0.5})


## Get semantic type counters

In [8]:
def sem_cat_gener(table, table_number, protocol="semanticCategory"):
    ''' Yield the semantic categories (or types) of the intermediate concepts
    for each drug-disease pair in `table`.

    For every row, reads the cached response
    ``./indirect_jsons/{table_number}_{index}.json`` (the file get_indirect
    writes) and, for each entry of its "content" list, looks at the second
    concept ``entry["concepts"][1]``:

    * protocol "semanticCategory" (default) or "diversity": one item per
      entry, the concept's "semanticCategory";
    * protocol "semanticType": every item of the concept's "semanticTypes",
      translated through the global SEM_TYPES mapping when present there.
      NOTE(review): SEM_TYPES is not defined in the visible notebook; this
      protocol raises NameError unless it is defined elsewhere.

    Yields
    ------
    (sem_cat_list, drug_count, disease_count)
        One tuple per row; the counts are ``len(row["drug_ids"])`` and
        ``len(row["disease_ids"])``.

    Raises
    ------
    ValueError
        Unknown `protocol` (raised when iteration starts; it previously
        yielded empty lists silently).
    '''
    if protocol not in ("semanticType", "diversity", "semanticCategory"):
        raise ValueError("Unknown protocol: {!r}".format(protocol))

    for ind, row in table.iterrows():
        sem_cat_list = []
        drug_count = len(row["drug_ids"])
        disease_count = len(row["disease_ids"])
        json_filename = "./indirect_jsons/{}_{}.json".format(table_number, ind)

        with open(json_filename, "r") as json_file:
            content = json.load(json_file)["content"]

        for entry in content:
            concept = entry["concepts"][1]
            if protocol == "semanticType":
                # extend in place; rebuilding with `list + list` was quadratic per row
                sem_cat_list.extend(SEM_TYPES[sem_type] if sem_type in SEM_TYPES else sem_type
                                    for sem_type in concept["semanticTypes"])
            else:
                sem_cat_list.append(concept["semanticCategory"])
        yield sem_cat_list, drug_count, disease_count

def get_sem_cat_counter_list(table, table_number, protocol="semanticCategory"):
    '''Build one SpecialCounter of intermediate-concept categories per row.

    protocol : "semanticCategory" (default), "semanticType" or "diversity";
        forwarded to sem_cat_gener.

    Returns
    -------
    list
        * "semanticCategory" / "semanticType": one SpecialCounter per row,
          divided by ``drug_count * disease_count``.
        * "diversity": ``(SpecialCounter, drug_count * disease_count)`` tuples,
          skipping rows where that product is 0.

    Raises
    ------
    ValueError
        Unknown `protocol` (previously an `assert`, which is stripped
        under ``python -O``).
    '''
    if protocol not in ("semanticType", "diversity", "semanticCategory"):
        raise ValueError("Unknown protocol: {!r}".format(protocol))

    rows = sem_cat_gener(table, table_number, protocol)
    if protocol == "diversity":
        # keep raw counts and the normaliser separately; zero-product rows are dropped
        return [(SpecialCounter(sem_cat_list), drug_count * disease_count)
                for sem_cat_list, drug_count, disease_count in rows
                if drug_count * disease_count != 0]

    # "semanticCategory" and "semanticType" share the same normalisation.
    # NOTE(review): a row with zero IDs on one side but a non-empty category
    # list raises ZeroDivisionError here (an empty list divides cleanly).
    return [SpecialCounter(sem_cat_list) / (drug_count * disease_count)
            for sem_cat_list, drug_count, disease_count in rows]

def get_hist(counter_list, cat=None, bins=None):
    '''Histogram of per-item values, as returned by ``numpy.histogram``.

    When `cat` is given, each element of `counter_list` is a Counter and the
    value histogrammed is its count for `cat` (0 when absent); otherwise the
    elements themselves are histogrammed. `bins` defaults to 100.
    Returns the ``(counts, bin_edges)`` tuple.
    '''
    if cat is None:
        track = counter_list
    else:
        track = [counter[cat] for counter in counter_list]
    return np.histogram(track, bins=100 if bins is None else bins)

def plot_for_cat(cat, sem_cat_counter_list, color, bins, stds=None, max_lim=None):
    ''' Plot one distribution for semantic category `cat` on the current figure.

    Two modes:

    * `stds` given and non-empty: `sem_cat_counter_list` holds precomputed bar
      heights, one per bin of the edges `bins`. Each bar spans its bin, with
      `stds` drawn as error bars. Sets the axis labels and returns None.
    * otherwise: `sem_cat_counter_list` is a list of Counters. The count of
      `cat` is taken from each one and histogrammed over the same edges
      `bins`. The x-axis is cut at the first edge above a percentile of those
      counts (90th for "Genes & Molecular Sequences", 80th otherwise), and
      that percentile is returned.

    `max_lim` is accepted for call compatibility; neither mode uses it.
    '''
    if stds is not None and len(stds) > 0:
        bin_width = bins[1] - bins[0]
        # align="edge": each bar starts at its bin's left edge and covers the
        # bin, matching the histogram mode (the default centre alignment
        # shifted the bars by half a bin).
        plt.bar(bins[:-1], sem_cat_counter_list, color=color, width=bin_width,
                yerr=stds, alpha=0.5, align="edge")
        plt.xlabel("Normalized number of semantic category occurence between drug and disease")
        plt.ylabel("Drug-disease count")
        return None

    track = [counter[cat] for counter in sem_cat_counter_list]
    percentile = 90 if cat == "Genes & Molecular Sequences" else 80
    max_lim = np.percentile(track, percentile)
    # First edge strictly above max_lim, clamped to the last edge. The old
    # strict-inequality scan left the upper limit unbound when max_lim fell
    # exactly on an edge or at the top edge.
    edge_index = min(int(np.searchsorted(bins, max_lim, side="right")), len(bins) - 1)
    plt.xlim(0, bins[edge_index])
    # Use the same edges as the bar mode so the two distributions overlay bin for bin.
    plt.hist(track, bins=bins, color=color, alpha=0.5)
    return max_lim
    
def plot_both(cat, real_counter, randomized_counter_list, save_dir="/home/explover"):
    ''' Plot the real and randomized distributions of concept counts within
    semantic category `cat`, and save the figure as ``<save_dir>/<cat>.png``.

    The real distribution is the red "positive" histogram. The randomized
    distributions are averaged into the black "negative" bars, with their
    standard deviations drawn as error bars.

    real_counter : list of Counters for the real table.
    randomized_counter_list : list of such lists, one per randomized table.
    save_dir : output directory; the default keeps the previous location.
    '''
    import os

    bins = get_binsizes(cat, real_counter, randomized_counter_list)
    plt.figure()

    stds, avg_hist = get_stds_avghist(randomized_counter_list, bins, cat)

    max_lim = plot_for_cat(cat, real_counter, "red", bins)
    plot_for_cat(cat, avg_hist, "black", bins, stds, max_lim)
    red_patch = mpatches.Patch(color="red", alpha=0.5, label="positive")
    black_patch = mpatches.Patch(color="black", alpha=0.5, label="negative")
    plt.legend(handles=[red_patch, black_patch])
    plt.title(cat)
    # A "/" in a category name would otherwise be read as a sub-directory.
    filename = "%s.png" % cat.replace("/", "_")
    plt.savefig(os.path.join(save_dir, filename), format="png")

def get_stds_avghist(counter_list, bins, cat=None):
    ''' Per-bin standard deviation and mean of bar heights over several
    distributions.

    Each element of `counter_list` is histogrammed with get_hist using the
    same `bins` argument (and `cat`, if given). Returns ``(stds, avghist)``:
    two lists with one entry per bin, holding numpy's population std and the
    mean of that bin's counts across all elements.
    '''
    histograms = [get_hist(counter, cat, bins)[0] for counter in counter_list]
    per_bin = list(zip(*histograms))
    stds = [np.std(heights) for heights in per_bin]
    avghist = [average(heights) for heights in per_bin]
    return stds, avghist

def get_binsizes(cat, real_counter, randomized_counter_list):
    ''' Common histogram bin edges for category `cat`.

    The real counters are repeated once per randomized table and pooled with
    all randomized counters. Returns the edges of get_hist's default-bin
    histogram over that pool. With an integer bin count the edges depend only
    on the pooled minimum and maximum, so the repetition does not change them.
    '''
    pooled_real, pooled_random = [], []
    for counter in randomized_counter_list:
        pooled_real.extend(real_counter)
        pooled_random.extend(counter)
    return get_hist(pooled_real + pooled_random, cat)[1]

## Check intermediate concepts

In [None]:
def check_length(table):
    '''Print the position and "totalElements" value of each entry in the
    table's "all_in_between" column.'''
    for position, response in enumerate(table["all_in_between"]):
        print(position, response['totalElements'])

## Randomize table and fetch indirect relations

In [11]:
# Number of randomized (negative) tables, generated by randomize_drugs_diseases
# or loaded by get_counters.
number = 10

In [12]:
def _sample_unseen_pairs(rng, drug_choices, disease_choices, seen, count):
    '''Draw `count` random (drug_ids, disease_ids) pairs that are not in
    `seen`, adding each drawn pair to `seen` so it is never drawn twice.'''
    # every pair in `seen` is built from these choices, so this is exactly
    # the number of pairs still available
    available = len(drug_choices) * len(disease_choices) - len(seen)
    if available < count:
        raise ValueError("Only {} unused drug-disease pairs left; cannot sample {}."
                         .format(available, count))
    pairs = []
    while len(pairs) < count:
        pair = (rng.choice(drug_choices), rng.choice(disease_choices))
        if pair not in seen:
            seen.add(pair)
            pairs.append(pair)
    return pairs


def randomize_drugs_diseases(number, table, seed=0):
    ''' Build `number` randomized (negative) drug-disease tables, fetch their
    indirect relations, and return their semantic-category counters.

    Each randomized table has len(table) rows. Every row pairs a random
    drug-ID tuple with a random disease-ID tuple, both taken from those
    present in `table`, so that no pair occurs in `table` or in any earlier
    randomized table. For table `num`, get_indirect writes
    ./indirect_jsons/{num}_{row}.json, and the table is saved to
    ./random/little_random_table_{num}.hd5 under key "w".

    seed : seed for a private random.Random instance. Previously the global
        `random` module was reseeded with 0 at the start of every table. Now
        one generator serves all tables (table 0 is unchanged for seed=0),
        and the global random state is left untouched.

    Returns a list with one get_sem_cat_counter_list result per table.
    Raises ValueError when there are too few unused pairs to fill a table.
    '''
    rng = rand.Random(seed)
    random_table_counters = list()

    existing_set = set(zip(map(tuple, table["drug_ids"]), map(tuple, table["disease_ids"])))
    # Materialised once (they used to be rebuilt on every draw). The sets are
    # built exactly as before, so the list order is unchanged.
    drug_choices = list(set(map(tuple, table["drug_ids"])))
    disease_choices = list(set(map(tuple, table["disease_ids"])))

    for num in range(number):
        pairs = _sample_unseen_pairs(rng, drug_choices, disease_choices,
                                     existing_set, len(table))
        # Build the frame in one go instead of growing it row by row with .loc.
        random_table = pd.DataFrame({
            "drug_ids": [drug_ids for drug_ids, _ in pairs],
            "disease_ids": [disease_ids for _, disease_ids in pairs],
        })

        print("Fetching indirect...")
        for row_number, row in random_table.iterrows():
            get_indirect(row, num, row_number)
        counter = get_sem_cat_counter_list(random_table, num)
        random_table_counters.append(counter)
        random_table.to_hdf('./random/little_random_table_%i.hd5' % num, key='w')  # for safety

    return random_table_counters

def get_counters(number, protocol="semanticCategory", path_template="./id_table_neg_%i.hdf"):
    '''Load `number` saved randomized tables and compute their counters.

    Table `num` is read from ``path_template % num``. Its cached indirect
    responses are read from ./indirect_jsons/{num}_{row}.json via
    get_sem_cat_counter_list, with `protocol` forwarded unchanged
    ("semanticCategory", "semanticType" or "diversity"). A tqdm progress bar
    tracks the tables.

    path_template : printf-style path with one %i placeholder; the default is
        the previous hard-coded location. (randomize_drugs_diseases saves to
        "./random/little_random_table_%i.hd5" instead.)

    Returns a list with one get_sem_cat_counter_list result per table.
    '''
    random_table_counters = []
    for num in tqdm(range(number)):
        random_table = pd.read_hdf(path_template % num)
        random_table_counters.append(get_sem_cat_counter_list(random_table, num, protocol))
    return random_table_counters

In [14]:
# Load counters for the previously saved randomized tables. To regenerate the
# tables and re-fetch their relations instead, run
# `random_table_counters = randomize_drugs_diseases(number, table)`.
random_table_counters = get_counters(number)


  0%|          | 0/10 [00:00<?, ?it/s][A
100%|██████████| 10/10 [02:09<00:00, 14.10s/it]


In [15]:
# Counters for the real (positive) table, read from ./indirect_jsons/real_<row>.json.
# `table` must be loaded first (this cell raised NameError when it was not).
nonrandom_counter_list = get_sem_cat_counter_list(table, "real")

NameError: name 'table' is not defined

In [None]:
# Preview selected descriptive columns for the first 20 drug-disease pairs.
table[["Disease", "Drug", "Drugbank ID", "Palliative", "In FDA label", "Relative Efficacy (RE)", "Indication (from DailyMed)", "Note"]].head(20)

## Plot counts within semantic categories

In [None]:
# Get semantic categories
cats = set()
for counter in nonrandom_counter_list:
    cats = cats | set(counter.keys())

In [None]:
# One figure per semantic category. Sorted because the iteration order of a
# set of strings changes between kernel sessions (hash randomisation).
for cat in sorted(cats):
    plot_both(cat, nonrandom_counter_list, random_table_counters)

## Plot diversity

In [None]:
def plot_diversity(nonrandom_counter, random_counter_list, n_bins=13):
    ''' Plot normalised semantic-category diversity for the real pairs
    (red histogram) against the mean over the randomized tables (black bars
    with standard-deviation error bars).

    The diversity of a pair is ``len(counter.keys()) / norm``: the number of
    distinct categories divided by ``norm``. Each ``(counter, norm)`` tuple
    comes from ``get_sem_cat_counter_list(..., protocol="diversity")``, where
    ``norm`` is the drug-ID count times the disease-ID count.

    nonrandom_counter : list of (counter, norm) tuples for the real table.
    random_counter_list : list of such lists, one per randomized table.
    n_bins : number of bins (13 as before; the choice is arbitrary).

    All histograms share one set of edges computed from the pooled data.
    Previously each randomized table was binned over its own range, and the
    averaged heights were drawn against shifted edges of the real data, so the
    bars did not describe the same diversity intervals.
    '''
    random_diversities = [[len(counter.keys()) / norm for counter, norm in random_counter]
                          for random_counter in random_counter_list]
    nonrandom_diversity = [len(counter.keys()) / norm for counter, norm in nonrandom_counter]

    pooled = list(nonrandom_diversity)
    for diversities in random_diversities:
        pooled.extend(diversities)
    _, bins = np.histogram(pooled, bins=n_bins)
    bin_width = bins[1] - bins[0]

    stds, avg_hist = get_stds_avghist(counter_list=random_diversities, bins=bins)

    plt.hist(nonrandom_diversity, bins=bins, color="red", alpha=0.5)
    # align="edge" so each averaged bar covers the same bin as the red histogram
    plt.bar(bins[:-1], avg_hist, color="black", width=bin_width, yerr=stds,
            alpha=0.5, align="edge")

    red_patch = mpatches.Patch(color="red", alpha=0.5, label="positive data")
    black_patch = mpatches.Patch(color="black", alpha=0.5, label="negative data")
    plt.legend(handles=[red_patch, black_patch])
    plt.xlabel("Diversity of semantic categories between drug and disease")
    plt.ylabel("Count")

In [None]:
# (counter, drug x disease ID count) tuples for every randomized table.
# get_counters takes `protocol=`; the former `diversity=True` keyword raised TypeError.
diversity_random_counter_list = get_counters(number, protocol="diversity")

In [None]:
# (counter, drug x disease ID count) tuples for the real table; the function
# takes `protocol=` (the former `diversity=True` keyword raised TypeError).
diversity_nonrandom_counter = get_sem_cat_counter_list(table, table_number="real", protocol="diversity")

In [None]:
# Real vs. randomized diversity distributions on a new figure.
plt.figure()
plot_diversity(diversity_nonrandom_counter, diversity_random_counter_list)

In [None]:
# Re-plot the diversity comparison for `five_nonrandom_counter` on a new figure.
# NOTE(review): `five_nonrandom_counter` is not defined in the visible notebook;
# it must be a list of (counter, norm) tuples defined before this cell runs.
diversity_random_counter_list = get_counters(number, protocol="diversity")
plt.figure()
plot_diversity(five_nonrandom_counter, diversity_random_counter_list)

In [None]:

# Print every row whose disease_ids is an empty list (no mapped disease IDs).
rows_without_diseases = (row for _, row in table.iterrows() if row["disease_ids"] == [])
for row in rows_without_diseases:
    print(row)

1. Count concept types for the random and non-random tables.
2. Count the diversity of types and, for each type, the number of intermediate concepts between drug and disease.
   - Normalize by the drug and disease ID counts.
3. Define good predicate combinations.