In [None]:
!pip install tensorflow-text==2.5
!pip install tf-models-official==2.5

In [None]:
import os
import shutil

import pandas as pd
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from official.nlp import optimization
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback

import matplotlib.pyplot as plt
from ml4h.TensorMap import TensorMap, Interpretation

# Name of the data-split folder. The alternative below used to be assigned and
# then immediately overwritten; it is kept only as a commented-out option.
#drug_folder = 'split_drugs'
drug_folder = 'split_small_test_all'

# TF-Hub handle of the text preprocessing model fed to hub.KerasLayer.
preprocess_model = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

# TF-Hub handle of the encoder; the commented line is an alternative encoder.
#base_model = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3"
base_model = "https://tfhub.dev/google/experts/bert/wiki_books/sst2/2"

# Raise the TensorFlow Python logger threshold to ERROR to quiet lower-level messages.
tf.get_logger().setLevel('ERROR')


# Receptor / target column names, grouped by name prefix.
# The order is significant: it is the column order used when indexing the frame.
receptors = [
    '5_ht2a', '5_ht2c', '5_ht2b', '5_ht1a', '5_ht1b', '5_ht1d', '5_ht1e', '5_ht1f',
    '5_ht3', '5_ht5a', '5_ht6', '5_ht7',
    'dopamine_d1', 'dopamine_d2', 'dopamine_d3', 'dopamine_d4', 'dopamine_d5',
    'adrenergic_alpha1a', 'adrenergic_alpha1b', 'adrenergic_alpha2a', 'adrenergic_alpha2b',
    'adrenergic_beta1', 'adrenergic_beta2',
    'sert', 'dat', 'net',
    'imidazoline_1', 'sigma_1', 'sigma_2',
    'dor', 'kor', 'mor',
    'm1', 'm2', 'm3', 'm4', 'm5',
    'h1', 'h2', 'h3', 'h4',
    'calcium_channel', 'nmda', 'cb1', 'cb2', 'glutamate_ampa', 'gaba_a', 'gaba_b',
    'dopamine_d2_long', 'dopamine_d2_short',
    'sodium_channel', 'taar1', 'substance_p', 'paf_platelet_activating_factor',
    'prostaglandin_e3', 'prostaglandin_e4', 'herg',
    'monoamine_oxidase_a', 'monoamine_oxidase_b',
    'cholecystokinin_a', 'cholecystokinin_b',
]

# Tag names in column order: the name at position i is stored in column `tag_i`.
_TAG_NAMES = (
    'Small_Group', 'General', 'First_Times', 'Alone',
    'Difficult_Experiences', 'Glowing_Experiences', 'Retrospective_Summary',
    'Various', 'Unknown_Context', 'Mystical_Experiences', 'Health_Problems',
    'Combinations', 'Not_Applicable', 'Bad_Trips',
    'Hangover_Days_After', 'Entities_Beings', 'Music_Discussion',
    'Addiction_Habituation', 'Post_Trip_Problems', 'Nature_Outdoors',
    'Relationships', 'Depression', 'Therapeutic_Intent_or_Outcome',
    'Overdose', 'Medical_Use', 'Sex_Discussion',
    'Train_Wrecks_Trip_Disasters', 'Guides_Sitters', 'Rave_Dance_Event',
    'Preparation_Recipes', 'Festival_Lg_Crowd', 'Health_Benefits',
    'Large_Group', 'Multi-Day_Experience', 'Club_Bar',
    'What_Was_in_That', 'Personal_Preparation', 'HPPD_Lasting_Visuals',
    'Families', 'Second_Hand_Report', 'Loss_of_Magic', 'Hospital',
    'Public_Space', 'School', 'Poetry', 'Performance_Enhancement',
    'Large_Party', 'Group_Ceremony', 'Workplace',
    'Cultivation_Synthesis', 'Pregnancy_Baby', 'Military',
)

# tag name -> column name ('tag_<i>')
tags = {tag_name: f'tag_{position}' for position, tag_name in enumerate(_TAG_NAMES)}

# column name -> tag name with '(', ')' and '/' removed
_DROP_CHARS = str.maketrans('', '', '()/')
itags = {column: tag_name.translate(_DROP_CHARS) for tag_name, column in tags.items()}

# cleaned tag name -> integer index of its column
ctags = {tag_name: int(column.split('_')[1]) for column, tag_name in itags.items()}

In [None]:
def _encode_by_first_appearance(values):
    """Map each distinct value to an integer in order of first appearance."""
    value2index = {value: index for index, value in enumerate(values.unique())}
    return [value2index[value] for value in values]


def process_df(text_csv, require_sex=False, require_age=False,
               set_psychoactive=None, set_ligand=None, set_drug=None,
               receptor_cols=None):
    """Load a metadata CSV and add the integer label columns used for training.

    Args:
        text_csv: Path (or file-like object) handed to ``pd.read_csv``.
        require_sex: Drop rows whose ``sex_int`` is missing.
        require_age: Drop rows whose ``age`` equals ``'Not Given'``.
        set_psychoactive: If not None, constant written to
            ``psychoactive_class_int`` for every row; otherwise the column is
            derived from the whitespace-stripped ``psychoactive_class``.
        set_ligand: If not None, constant written to ``ligand_chemical_int``;
            otherwise derived from ``ligand_chemical_class``.
        set_drug: If not None, constant written to ``drug_class``; otherwise
            derived from ``drug``.
        receptor_cols: Columns cast to float. Defaults to the notebook-level
            ``receptors`` list.

    Returns:
        The filtered frame (rows without ``text`` removed) with the three
        integer label columns added, the 52 ``tag_<i>`` columns NaN-filled with
        0 and the receptor columns cast to float.
    """
    if receptor_cols is None:
        receptor_cols = receptors

    df = pd.read_csv(text_csv)
    if require_age:
        df = df[df.age != 'Not Given']
    if require_sex:
        df = df[df.sex_int.notna()]
    # Copy so the column assignments below write to an independent frame rather
    # than to a filtered view of the one returned by read_csv.
    df = df[df.text.notna()].copy()

    # Compare against None explicitly: 0 is a valid class index and must not be
    # mistaken for "no override given".
    if set_psychoactive is not None:
        df['psychoactive_class_int'] = set_psychoactive
    else:
        df['psychoactive_class'] = df['psychoactive_class'].str.strip()
        df['psychoactive_class_int'] = _encode_by_first_appearance(df['psychoactive_class'])

    if set_ligand is not None:
        df['ligand_chemical_int'] = set_ligand
    else:
        df['ligand_chemical_int'] = _encode_by_first_appearance(df['ligand_chemical_class'])

    if set_drug is not None:
        df['drug_class'] = set_drug
    else:
        df['drug_class'] = _encode_by_first_appearance(df['drug'])

    tag_cols = [f'tag_{i}' for i in range(52)]
    df[tag_cols] = df[tag_cols].fillna(0)
    receptor_cols = list(receptor_cols)
    df[receptor_cols] = df[receptor_cols].astype(float)
    return df

# Build the heroin-only evaluation set with fixed class indices, round-trip it
# through a CSV file, and wrap that file in a tf.data pipeline.
# NOTE(review): `make_dataset` and `output_cols` are first defined in later
# cells, so on a fresh kernel this cell fails at the last line; the same three
# statements are repeated verbatim two cells below.
heroin_df = process_df('test_heroin_meta_data.csv', set_psychoactive=9, set_ligand=3, set_drug=43)
heroin_df.to_csv('heroin_to_test.csv', index=False)
heroin_ds = make_dataset('heroin_to_test.csv', ['text'], output_cols)

In [None]:
# Load and label the full metadata table through the shared helper.
df = process_df('./all_drugs_v2022_04_26_meta_data.csv')

In [None]:
# Heroin-only evaluation set: constant class indices are written to every row,
# the frame is saved to CSV and read back as a tf.data pipeline.
# NOTE(review): `make_dataset` and `output_cols` are defined further down the
# notebook, so this cell only runs after those cells have been executed.
heroin_df = process_df('test_heroin_meta_data.csv', set_psychoactive=9, set_ligand=3, set_drug=43)
heroin_df.to_csv('heroin_to_test.csv', index=False)
heroin_ds = make_dataset('heroin_to_test.csv', ['text'], output_cols)

In [None]:
# Rebuild the working frame from the full metadata CSV and keep the
# value -> index mappings as notebook-level names (later cells need them).
df = pd.read_csv('./all_drugs_v2022_04_26_meta_data.csv')

# df = df[df.age != 'Not Given']
# df = df[df.sex_int.notna()]
# Copy after filtering so the column assignments below write to an independent
# frame instead of a slice of the one returned by read_csv.
df = df[df.text.notna()].copy()

df['psychoactive_class'] = df['psychoactive_class'].str.strip()

# Integer codes are assigned in order of first appearance.
drug2class = {d: i for i, d in enumerate(df.drug.unique())}
psychoactive2class = {d: i for i, d in enumerate(df.psychoactive_class.unique())}
ligand_chemical2class = {d: i for i, d in enumerate(df.ligand_chemical_class.unique())}

# Inverse-frequency weight per drug class (scaled by 5). The counts come from a
# single value_counts() pass instead of one boolean filter per drug.
drug_counts = df.drug.value_counts()
class2weight = {i: 5 * (len(df) / int(drug_counts[d])) for d, i in drug2class.items()}
print(psychoactive2class)

df['drug_class'] = [drug2class[d] for d in df.drug]
df['psychoactive_class_int'] = [psychoactive2class[d] for d in df.psychoactive_class]
df['ligand_chemical_int'] = [ligand_chemical2class[d] for d in df.ligand_chemical_class]
tag_cols = [f'tag_{i}' for i in range(52)]
df[tag_cols] = df[tag_cols].fillna(0)


# train = df.sample(frac = 0.8)
# test = df.drop(train.index).sample(frac = 0.5)
# validate = df.drop(train.index).drop(test.index)

# The split is read from the `set` column of the CSV, not drawn at random.
df[df.set == 'train'].to_csv('train.csv', index=False)
df[df.set == 'valid'].to_csv('valid.csv', index=False)
df[df.set == 'test'].to_csv('test.csv', index=False)

# `tags`, `itags`, `ctags` and `receptors` are defined once in the setup cell at
# the top of the notebook. The verbatim copies that used to be repeated here only
# re-bound the same names to the same values, so they were removed.
df[receptors] = df[receptors].astype(float)

In [None]:
# Number of distinct (non-null) ligand chemical classes in the frame.
len(df.ligand_chemical_class.value_counts())

In [None]:
# Show the drug name -> class index mapping.
drug2class

In [None]:
# Number of rows per psychoactive class.
df.psychoactive_class.value_counts()

In [None]:
# Row-wise sum of the 52 tag columns, stored as a new `no_tags` column on `df`
# (this mutates the shared frame), followed by the distribution of that sum.
df['no_tags'] = df[[f'tag_{i}' for i in range(52) ]].sum(axis=1)
print(df['no_tags'].value_counts())

In [None]:
# Value distribution of the first tag column.
df.tag_0.value_counts()

In [None]:
# Number of rows per integer-encoded psychoactive class.
df.psychoactive_class_int.value_counts()

In [None]:
# Columns the model predicts: three class-index columns, `sex_int`, and the 52 tag columns.
output_cols = ['psychoactive_class_int', 'ligand_chemical_int', 'drug_class', 'sex_int']
output_cols.extend(f'tag_{i}' for i in range(52))


def _sparse_categorical_map(name, channel_map):
    """TensorMap for an integer-labelled categorical output trained on logits."""
    return TensorMap(name, Interpretation.CATEGORICAL, shape=(1,),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=[tf.metrics.SparseCategoricalAccuracy()],
                     channel_map=channel_map)


def _continuous_map(name, loss):
    """TensorMap for a scalar regression output tracked with mean absolute error."""
    return TensorMap(name, Interpretation.CONTINUOUS, shape=(1,),
                     loss=loss,
                     metrics=[tf.metrics.MeanAbsoluteError()])


def _tensor_map_for(column):
    """Pick the TensorMap flavour for one output column (the first matching rule wins)."""
    if column == 'drug_class':
        return _sparse_categorical_map(column, {f'drug_{d}': v for d, v in drug2class.items()})
    if column == 'psychoactive_class_int':
        return _sparse_categorical_map(column, {f'{d}': v for d, v in psychoactive2class.items()})
    if column == 'ligand_chemical_int':
        return _sparse_categorical_map(column, {f'{d}': v for d, v in ligand_chemical2class.items()})
    if 'tag_' in column:
        tag_name = itags[column]
        return _sparse_categorical_map(column, {f'no_{tag_name}': 0, f'{tag_name}': 1})
    if column == 'age':
        return _continuous_map(column, tf.keras.losses.MeanSquaredError())
    if column in receptors:
        return _continuous_map(column, tf.keras.losses.LogCosh())
    # Anything else is treated as a binary flag.
    return _sparse_categorical_map(column, {f'no_{column}': 0, f'{column}': 1})


# One TensorMap per output column, in the same order as `output_cols`.
tensor_maps_out = [_tensor_map_for(oc) for oc in output_cols]
        
def make_dataset(csv, in_cols, out_cols, batch_size=32, shuffle=True, num_epochs=None,
                 shuffle_buffer=10000):
    """Build a ``(inputs, outputs)`` tf.data pipeline from one CSV file.

    The file is read twice, once for the input columns and once for the output
    columns, both unshuffled and one row at a time so the two streams stay
    aligned when zipped.

    Args:
        csv: File pattern passed to ``make_csv_dataset``.
        in_cols: Column names that form the model inputs.
        out_cols: Column names that form the model targets.
        batch_size: Number of rows per emitted batch.
        shuffle: Shuffle the zipped rows before batching (the previous,
            unconditional behaviour). Pass False for ordered evaluation.
        num_epochs: Forwarded to ``make_csv_dataset``. None, the default and the
            previous behaviour, repeats the file indefinitely, so a consumer
            must stop iterating on its own; pass 1 for a single pass.
        shuffle_buffer: Size of the shuffle buffer.

    Returns:
        A batched, prefetched ``tf.data.Dataset`` of ``(inputs, outputs)`` dicts.
    """
    def _read(columns):
        return tf.data.experimental.make_csv_dataset(
            csv, select_columns=columns, batch_size=1, shuffle=False, num_epochs=num_epochs)

    ds = tf.data.Dataset.zip((_read(in_cols), _read(out_cols)))
    if shuffle:
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.unbatch().batch(batch_size)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)


train_ds = make_dataset('train.csv', ['text'], output_cols)
valid_ds = make_dataset('valid.csv', ['text'], output_cols)
# Evaluate on exactly one ordered pass over the test file. With an endlessly
# repeating, shuffled stream the first N rows are not guaranteed to be the N
# distinct test rows.
test_ds = make_dataset('test.csv', ['text'], output_cols, shuffle=False, num_epochs=1)

In [None]:
from ml4h.TensorMap import TensorMap, Interpretation

In [None]:
# Peek at a single batch: print the targets, then the first element of every input column.
for feature_batch, label in test_ds.take(1):
    print(f"label {label}")
    for key in feature_batch:
        value = feature_batch[key]
        print(f"\n\n\n Key is  {key:20s}: {value[0]}")

In [None]:
# Smoke test of the TF-Hub preprocessing and encoder layers on one sentence.
n_drugs = len(drug2class)
# Read as a global by build_classifier_model below.
dropout_rate = 0.2

# NOTE(review): this `text_input` is not used in this cell;
# build_classifier_model creates its own Input layer.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
bert_preprocess_model = hub.KerasLayer(
    preprocess_model)

bert_model = hub.KerasLayer(
    base_model,
    trainable=True)

text_test = ['this is such an amazing movie!']
text_preprocessed = bert_preprocess_model(text_test)

# Show the preprocessing outputs (first 12 positions of the single example).
print(f'Keys       : {list(text_preprocessed.keys())}')
print(f'Shape      : {text_preprocessed["input_word_ids"].shape}')
print(f'Word Ids   : {text_preprocessed["input_word_ids"][0, :12]}')
print(f'Input Mask : {text_preprocessed["input_mask"][0, :12]}')
print(f'Type Ids   : {text_preprocessed["input_type_ids"][0, :12]}')

bert_results = bert_model(text_preprocessed)

# Show the encoder outputs for the same example.
print(f'Pooled Outputs Shape:{bert_results["pooled_output"].shape}')
print(f'Pooled Outputs Values:{bert_results["pooled_output"][0, :12]}')
print(f'Sequence Outputs Shape:{bert_results["sequence_output"].shape}')
print(f'Sequence Outputs Values:{bert_results["sequence_output"][0, :12]}')

# def weighted_scce(weights):
#     def my_loss(y_true, y_pred):
#         sample_weights = [weights[int(y_true[i].numpy())] for i in 33]
#         scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
#         return scce(y_true, y_pred, sample_weights=sample_weights)
#     return my_loss

def build_classifier_model(tfhub_handle_preprocess, tfhub_handle_encoder, tensor_maps_out,
                           dropout=None):
    """Build a multi-head classifier on top of a TF-Hub text encoder.

    A string input is preprocessed, encoded, and the encoder's ``pooled_output``
    is passed through dropout into one Dense head per tensor map.

    Args:
        tfhub_handle_preprocess: TF-Hub handle of the preprocessing model.
        tfhub_handle_encoder: TF-Hub handle of the (trainable) encoder.
        tensor_maps_out: Output tensor maps; each yields one head named after
            the map. Categorical maps get ``len(channel_map)`` logits,
            continuous maps a single linear unit.
        dropout: Dropout rate applied to the pooled output. Defaults to the
            notebook-level ``dropout_rate``.

    Returns:
        ``(model, losses, metrics)`` where ``losses`` and ``metrics`` are lists
        aligned with the model outputs.

    Raises:
        ValueError: If a tensor map is neither categorical nor continuous.
            Skipping it silently would leave the outputs out of step with
            ``tensor_maps_out``.
    """
    if dropout is None:
        dropout = dropout_rate

    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    net = encoder(encoder_inputs)['pooled_output']
    net = tf.keras.layers.Dropout(dropout)(net)
    #net = tf.keras.layers.Dense(256, activation='swish')(net)
    #net = tf.keras.layers.Dropout(dropout_rate)(net)

    outputs = []
    metrics = []
    losses = []
    for otm in tensor_maps_out:
        if otm.is_categorical():
            # Raw logits: the loss is configured with from_logits=True.
            outputs.append(tf.keras.layers.Dense(len(otm.channel_map), activation=None, name=otm.name)(net))
            losses.append(tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
            metrics.append(tf.metrics.SparseCategoricalAccuracy(name=f'{otm.name}_SparseCategoricalAccuracy_met'))
        elif otm.is_continuous():
            outputs.append(tf.keras.layers.Dense(1, activation='linear', name=otm.name)(net))
            losses.append(tf.keras.losses.MeanSquaredError())
            metrics.append(tf.metrics.MeanAbsoluteError(name=f'{otm.name}_mae'))
        else:
            raise ValueError(f'Unsupported tensor map interpretation for output {otm.name}')
    return tf.keras.Model(text_input, outputs), losses, metrics

# Build the model. The losses/metrics returned by the builder are not kept:
# the ones stored on the tensor maps are used instead (see below).
classifier_model = build_classifier_model(preprocess_model, base_model, tensor_maps_out)[0]

# Sanity check: run the untrained model on the sample sentence and show the last head.
bert_raw_result = classifier_model(tf.constant(text_test))
last_head_logits = bert_raw_result[-1]
print(tf.sigmoid(last_head_logits))
tf.keras.utils.plot_model(classifier_model)

# Per-output metrics keyed by output name; losses listed in tensor-map order.
metrics = {}
losses = []
for tm_out in tensor_maps_out:
    metrics[tm_out.name] = tm_out.metrics
    losses.append(tm_out.loss)
#loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
#metrics = tf.metrics.SparseCategoricalAccuracy()

In [None]:
# Optimiser schedule: AdamW with warm-up over the first 10% of all training steps.
epochs = 14
batch_size = 32
steps_per_epoch = len(df[df.set == 'train']) // batch_size
num_train_steps = steps_per_epoch * epochs
num_warmup_steps = int(0.1 * num_train_steps)
# Report all three quantities. The previous message labelled steps_per_epoch as
# "warm up" and never showed the actual number of warm-up steps.
print(f'steps per epoch {steps_per_epoch}, total train steps {num_train_steps}, '
      f'warm up steps {num_warmup_steps}')
init_lr = 1e-5
optimizer = optimization.create_optimizer(init_lr=init_lr,
                                          num_train_steps=num_train_steps,
                                          num_warmup_steps=num_warmup_steps,
                                          optimizer_type='adamw')


classifier_model.compile(optimizer=optimizer, loss=losses, metrics=metrics)

In [None]:
# Replace the freshly built model with a previously saved one from ./models.
# The optimizer created above is supplied under the name 'AdamWeightDecay' so
# the saved optimizer can be deserialised.
# NOTE(review): an optimizer instance (not a class) is passed in
# custom_objects - confirm this is what load_model expects here.
classifier_model = tf.keras.models.load_model(f'./models/bert_psychoactive_ligand_drug_sex_tag_classifier', 
                                              custom_objects={'AdamWeightDecay':optimizer})
classifier_model.summary()

In [None]:
# Number of rows assigned to the test split.
len(df[df.set=='test'])

In [None]:
from collections import defaultdict

# Collect ground truth and model outputs for the test split, keyed by output name.
predictions = defaultdict(list)
truths = defaultdict(list)

# Look tensor maps up by name instead of pairing them with the label dict by
# position, so the categorical/continuous decision cannot be misaligned.
tm_by_name = {otm.name: otm for otm in tensor_maps_out}
n_test = len(df[df.set == 'test'])  # hoisted: constant across batches

# The loop variable is deliberately not called `text`: that name is bound to
# the `tensorflow_text` module imported at the top of the notebook.
for features, labels in test_ds.as_numpy_iterator():
    for name in labels:
        if tm_by_name[name].is_categorical():
            truths[name].extend(map(int, list(labels[name])))
        else:
            truths[name].extend(list(labels[name]))
    p = classifier_model.predict(features)
    if len(classifier_model.output_names) == 1:
        p = [p]  # single-output models return a bare array
    for i, ot in enumerate(classifier_model.output_names):
        predictions[ot].extend(list(p[i]))

    # Stop once a test-set's worth of rows has been seen; this is needed when
    # the dataset repeats indefinitely.
    if len(truths[name]) >= n_test:
        break

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from ml4h.plots import plot_roc, subplot_rocs, plot_scatter, plot_precision_recall_per_class
def make_one_hot(y, num_labels):
    """One-hot encode integer class labels.

    Args:
        y: Class indices as a 1-D array or sequence. A column vector of shape
            ``(n, 1)`` is also accepted and flattened. Float values are
            truncated to int.
        num_labels: Number of classes, i.e. the width of the result.

    Returns:
        Float array of shape ``(n, num_labels)`` with a single 1.0 per row.
    """
    indices = np.asarray(y).reshape(-1).astype(int)
    ohy = np.zeros((indices.shape[0], num_labels))
    # Vectorised assignment: row k gets a 1.0 in column indices[k].
    ohy[np.arange(indices.shape[0]), indices] = 1.0
    return ohy

# Per-output evaluation plots: precision/recall (plus ROC inputs) for categorical
# outputs, scatter plots for continuous ones. `perfs` is keyed by the tensor map.
rocs = []
perfs = {}
for otm in tensor_maps_out:
    if otm.is_categorical():
        y_pred = np.array(predictions[otm.name])
        n_channels = len(otm.channel_map)
        print(f' otm {otm} {y_pred.shape}')
        perfs[otm] = plot_precision_recall_per_class(
            y_pred,
            make_one_hot(np.array(truths[otm.name]), n_channels),
            otm.channel_map, otm.name)
        roc_entry = (y_pred,
                     make_one_hot(np.array(truths[otm.name]), n_channels),
                     otm.channel_map)
        rocs.append(roc_entry)
    elif otm.is_continuous():
        y_pred = np.array(predictions[otm.name])
        y_true = np.array(truths[otm.name])
        perfs[otm] = plot_scatter(y_pred, y_true, otm.name)
subplot_rocs(rocs)

In [None]:
from collections import Counter

# How often each psychoactive class is the arg-max prediction.
c2c = {index: name for name, index in psychoactive2class.items()}
predicted_index = np.argmax(predictions['psychoactive_class_int'], axis = -1)
hl = Counter(predicted_index)
for k, n_predicted in hl.items():
    print(f'Drug {c2c[k]} has {n_predicted}')

In [None]:
# How often each drug is the arg-max prediction.
c2c = {index: name for name, index in drug2class.items()}
predicted_index = np.argmax(predictions['drug_class'], axis = -1)
hl = Counter(predicted_index)
for k, n_predicted in hl.items():
    print(f'Drug {c2c[k]} has {n_predicted}')

In [None]:
# How often each ligand chemical class is the arg-max prediction.
c2c = {index: name for name, index in ligand_chemical2class.items()}
predicted_index = np.argmax(predictions['ligand_chemical_int'], axis = -1)
hl = Counter(predicted_index)
for k, n_predicted in hl.items():
    print(f'Ligand {c2c[k]} has {n_predicted}')

In [None]:
from collections import defaultdict

# Collect ground truth and model outputs for the heroin-only set.
# NOTE: this overwrites the `predictions` / `truths` gathered for the test split.
predictions = defaultdict(list)
truths = defaultdict(list)
n_heroin = len(heroin_df)  # hoisted: constant across batches

# The loop variable is deliberately not called `text`: that name is bound to
# the `tensorflow_text` module imported at the top of the notebook.
for features, labels in heroin_ds.as_numpy_iterator():
    for l in labels:
        truths[l].extend(map(int, list(labels[l])))
    p = classifier_model.predict(features)
    if len(classifier_model.output_names) == 1:
        p = [p]  # single-output models return a bare array
    for i, ot in enumerate(classifier_model.output_names):
        predictions[ot].extend(list(p[i]))

    # heroin_ds repeats indefinitely, so stop after one file's worth of rows.
    if len(truths[l]) >= n_heroin:
        break

In [None]:
def confusion_heatmap(confusion, labels=None, cutoff=0.001, fmt='2d', figsize=(24, 15),
                      title='Confusion Matrix', figure_path='./confusion_heatmap.png'):
    """Draw, save and show an annotated confusion-matrix heatmap.

    Two heatmaps are layered on the same axes: an un-annotated one covering all
    cells, then an annotated one that masks cells below ``cutoff``, so that
    only cells at or above ``cutoff`` carry a number.

    Requires ``sb`` (seaborn) and ``plt`` to be bound in the notebook namespace
    when the function is called.

    Args:
        confusion: 2-D array; rows are plotted on the y axis, columns on x.
        labels: Tick labels for both axes. If None, integer positions are used
            (previously the argument was mandatory and calls without it failed).
        cutoff: Cells with a value below this are left un-annotated.
        fmt: Annotation format string passed to seaborn.
        figsize: Figure size in inches.
        title: Axes title.
        figure_path: Where the PNG is written; the directory is created if needed.
    """
    confusion = np.asarray(confusion)
    n_rows, n_cols = confusion.shape
    x_labels = labels if labels is not None else list(range(n_cols))
    y_labels = labels if labels is not None else list(range(n_rows))

    fig, ax = plt.subplots(figsize=figsize, dpi=300)
    ax = sb.heatmap(confusion, cmap='Blues', ax=ax, cbar=False)
    ax = sb.heatmap(confusion, mask=confusion < cutoff, cmap='Blues',
                    annot=True, fmt=fmt, cbar_kws={"shrink": .8, 'label': 'Counts'}, ax=ax)

    ax.set_title(title)
    # Tick at the centre of every cell; x follows the columns, y follows the rows.
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_yticks(np.arange(n_rows) + 0.5)

    ax.set_xticklabels(labels=x_labels, ha='right', rotation=30)
    ax.set_yticklabels(labels=y_labels, rotation=0)
    plt.tight_layout()

    # dirname is '' for a bare file name; only create a directory when there is one.
    out_dir = os.path.dirname(figure_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(figure_path)
    plt.show()
    

In [None]:
tm = tensor_maps_out[0]
pmax = np.argmax(np.array(predictions[tm.name]), axis=-1)

pmax.shape

confusion = np.zeros((len(tm.channel_map), len(tm.channel_map)), dtype=np.int32)
for i in range(len(tm.channel_map)):
    for j in range(len(tm.channel_map)):
        confusion[i,j] += sum((pmax==j) * (np.array(truths[tm.name]) == i))
        

# pmax == 24

# truths[tm.name] == 24

# confusion

# np.array(truths[tm.name]) == 24

import seaborn as sb
%matplotlib inline
import matplotlib.pyplot as plt
from io import StringIO

labels = list(tm.channel_map.keys())
confusion_heatmap(confusion, labels, figsize=(8, 5), title='Drug Class Confusion Matrix')

confusion_r = np.zeros((len(tm.channel_map), len(tm.channel_map)))
for i in range(len(tm.channel_map)):
    for j in range(len(tm.channel_map)):
        confusion_r[i,j] += sum((pmax==j) * (np.array(truths[tm.name]) == i)) / sum(np.array(truths[tm.name]) == i)

confusion_heatmap(confusion_r, labels, fmt=".2f", figsize=(8, 5), title='Drug Class Confusion Matrix')  

In [None]:
# Confusion matrices (counts, then row-normalised) for the second output.
tm = tensor_maps_out[1]
pmax = np.argmax(np.array(predictions[tm.name]), axis=-1)
truth = np.array(truths[tm.name])
n_classes = len(tm.channel_map)
labels = list(tm.channel_map.keys())

# confusion[i, j] = number of examples with true class i predicted as class j.
confusion = np.zeros((n_classes, n_classes), dtype=np.int32)
confusion_r = np.zeros((n_classes, n_classes))
for i in range(n_classes):
    is_true_i = truth == i
    n_true_i = np.sum(is_true_i)
    for j in range(n_classes):
        hits = np.sum(is_true_i & (pmax == j))
        confusion[i, j] = hits
        # Classes with no true examples keep an all-zero row instead of
        # producing 0/0 = NaN.
        if n_true_i > 0:
            confusion_r[i, j] = hits / n_true_i

confusion_heatmap(confusion, labels, figsize=(10, 6))
confusion_heatmap(confusion_r, labels, fmt=".2f", figsize=(10, 6))

In [None]:
# Count-based confusion matrix for the third output.
tm = tensor_maps_out[2]
pmax = np.argmax(np.array(predictions[tm.name]), axis=-1)
truth = np.array(truths[tm.name])
n_classes = len(tm.channel_map)

# confusion[i, j] = number of examples with true class i predicted as class j.
confusion = np.zeros((n_classes, n_classes), dtype=np.int32)
for i in range(n_classes):
    is_true_i = truth == i
    for j in range(n_classes):
        confusion[i, j] += np.sum(is_true_i & (pmax == j))
labels = list(tm.channel_map)
confusion_heatmap(confusion, labels, title='Drug Confusion Matrix')

In [None]:
# Count-based confusion matrix for the third output.
tm = tensor_maps_out[2]
pmax = np.argmax(np.array(predictions[tm.name]), axis=-1)
truth = np.array(truths[tm.name])
n_classes = len(tm.channel_map)

# confusion[i, j] = number of examples with true class i predicted as class j.
confusion = np.zeros((n_classes, n_classes), dtype=np.int32)
for i in range(n_classes):
    is_true_i = truth == i
    for j in range(n_classes):
        confusion[i, j] = np.sum(is_true_i & (pmax == j))
# The tick labels have to be passed explicitly; the call used to omit them.
confusion_heatmap(confusion, list(tm.channel_map.keys()))

In [None]:
# Count-based confusion matrix for the third output.
tm = tensor_maps_out[2]
pmax = np.argmax(np.array(predictions[tm.name]), axis=-1)
truth = np.array(truths[tm.name])
n_classes = len(tm.channel_map)

# confusion[i, j] = number of examples with true class i predicted as class j.
confusion = np.zeros((n_classes, n_classes), dtype=np.int32)
for i in range(n_classes):
    is_true_i = truth == i
    for j in range(n_classes):
        confusion[i, j] = np.sum(is_true_i & (pmax == j))
# The tick labels have to be passed explicitly; the call used to omit them.
confusion_heatmap(confusion, list(tm.channel_map.keys()))

In [None]:
# Row-normalised confusion matrix for the current `tm` / `pmax`:
# confusion_r[i, j] = fraction of true class i predicted as class j.
truth = np.array(truths[tm.name])
n_classes = len(tm.channel_map)
confusion_r = np.zeros((n_classes, n_classes))
for i in range(n_classes):
    is_true_i = truth == i
    n_true_i = np.sum(is_true_i)
    if n_true_i == 0:
        continue  # no true examples: keep the zero row rather than dividing by zero
    for j in range(n_classes):
        confusion_r[i, j] = np.sum(is_true_i & (pmax == j)) / n_true_i

# The tick labels have to be passed explicitly; the call used to omit them.
confusion_heatmap(confusion_r, list(tm.channel_map.keys()), fmt=".2f")

In [None]:
# Number of classes of the first output tensor map.
len(tensor_maps_out[0].channel_map)

In [None]:
# Show the per-output performance results collected by the plotting cell.
perfs

In [None]:
# Same display as the previous cell (duplicate).
perfs

In [None]:
import math

# (tag name, score) pairs for every tag output whose score is not NaN.
tag_auc = []
for tm in perfs:
    if 'tag' not in tm.name:
        continue
    #print(f'{itags[tm.name]}, {perfs[tm][itags[tm.name]]} ')
    t = itags[tm.name]
    p = perfs[tm][t]
    if not math.isnan(p):
        tag_auc.append((t, p))

In [None]:
# Order the (tag, score) pairs by ascending score.
tag_auc = sorted(tag_auc, key=lambda name_and_score: name_and_score[1], reverse=False)

In [None]:
# Hard-coded (tag name, ROC AUC) pairs.
# NOTE(review): this assignment replaces the `tag_auc` computed and sorted in
# the two cells above, and it is listed in descending order whereas the cell
# above sorts ascending - confirm which ordering the bar chart should use.
tag_auc = [('Medical_Use', 0.952544387242243),
 ('Rave_Dance_Event', 0.9404877496959814),
 ('Poetry', 0.919463505926388),
 ('Addiction_Habituation', 0.8818512946004521),
 ('Health_Benefits', 0.8463797692990239),
 ('Loss_of_Magic', 0.8341943957968476),
 ('Not_Applicable', 0.814232096584901),
 ('Alone', 0.801082111968865),
 ('Preparation_Recipes', 0.7911125886524822),
 ('Relationships', 0.786775689912729),
 ('Retrospective_Summary', 0.7866495237237311),
 ('Club_Bar', 0.7858709960508994),
 ('Mystical_Experiences', 0.76840692629412),
 ('Unknown_Context', 0.7664176932777185),
 ('Small_Group', 0.7621865858796946),
 ('School', 0.7555041534409401),
 ('Festival_Lg_Crowd', 0.7520782461337283),
 ('Depression', 0.7436038388419342),
 ('First_Times', 0.741708831433602),
 ('Entities_Beings', 0.7378392880270876),
 ('Nature_Outdoors', 0.7303554866907693),
 ('Hospital', 0.7202772739051568),
 ('Performance_Enhancement', 0.7132642300470691),
 ('Public_Space', 0.7060791157649797),
 ('Therapeutic_Intent_or_Outcome', 0.6987281649086753),
 ('Combinations', 0.6972770583104073),
 ('What_Was_in_That', 0.6771832080849962),
 ('Train_Wrecks_Trip_Disasters', 0.6718617121799112),
 ('Health_Problems', 0.6515188989624958),
 ('Glowing_Experiences', 0.6414117808313058),
 ('Difficult_Experiences', 0.6376775952272582),
 ('Multi-Day_Experience', 0.6330706659888155),
 ('Various', 0.6315747002525838),
 ('Personal_Preparation', 0.6128389404971127),
 ('General', 0.6093155886094529),
 ('Bad_Trips', 0.6037562986715529),
 ('Large_Group', 0.5999464391252954),
 ('Music_Discussion', 0.5823361233115107),
 ('Guides_Sitters', 0.5775967327691466),
 ('Hangover_Days_After', 0.5742190488333545),
 ('Post_Trip_Problems', 0.569676289121946),
 ('Overdose', 0.5639120681920143),
 ('HPPD_Lasting_Visuals', 0.5518192744755245),
 ('Second_Hand_Report', 0.5382078798391248),
 ('Sex_Discussion', 0.5337256481827861),
 ('Group_Ceremony', 0.5138755980861244),
 ('Families', 0.5106063450347287),
 ('Workplace', 0.4897781644193128),
 ('Large_Party', 0.35978260869565215)]

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
_ = plt.figure(figsize=(5, 11), dpi=300)
plt.barh(range(len(tag_auc)), [t[1] for t in tag_auc])
plt.axvline(0.5, linestyle='dashed', c='orange')
plt.yticks(np.arange(len(tag_auc)), [t[0].replace('_', ' ') for t in tag_auc], ha='right')
plt.ylabel('Erowid Meta Data Tags')
plt.xlabel('Test Set ROC AUC')
plt.box(False)

# figure_path = f'results/tag_histogram_{title}.png'
# if not os.path.exists(os.path.dirname(figure_path)):
#     os.makedirs(os.path.dirname(figure_path))
# plt.savefig(figure_path)

In [None]:
# Same bar chart as the previous cell: per-tag ROC AUC with a dashed line at 0.5.
fig, ax = plt.subplots(figsize=(5, 11), dpi=300)
bar_positions = range(len(tag_auc))
bar_lengths = [pair[1] for pair in tag_auc]
tick_text = [pair[0].replace('_', ' ') for pair in tag_auc]
ax.barh(bar_positions, bar_lengths)
ax.axvline(0.5, linestyle='dashed', c='orange')
ax.set_yticks(np.arange(len(tag_auc)))
ax.set_yticklabels(tick_text, ha='right')
ax.set_ylabel('Erowid Meta Data Tags')
ax.set_xlabel('Test Set ROC AUC')
ax.set_frame_on(False)

In [None]:
from scipy.stats import pearsonr
import seaborn as sb
from sklearn.cluster import AgglomerativeClustering

In [None]:
# Show the cluster assignment per row.
# NOTE(review): `clusters` is only assigned in a cell further down, so this
# fails on a fresh top-to-bottom run.
clusters

In [None]:
# Show the most recently computed row-normalised confusion matrix.
confusion_r

In [None]:
# Reorder the rows of confusion_r so rows of the same cluster are adjacent
# (cluster 0 first); the columns keep their original order.
sorted_args = [row
               for cluster_id in range(num_clusters)
               for row, assigned in enumerate(clusters)
               if cluster_id == assigned]
previous_args = []
cur_i = len(sorted_args)
sorted_confuse = np.zeros(confusion_r.shape)
for target_row, source_row in enumerate(sorted_args):
    sorted_confuse[target_row, :] = confusion_r[source_row, :]

In [None]:
# Show the cluster-ordered confusion rows built in the previous cell.
sorted_confuse

In [None]:
# Select the third output tensor map for the cells below.
tm = tensor_maps_out[2]


In [None]:
# The tick labels have to be passed explicitly; the call used to omit them.
# NOTE(review): the labels come from the current `tm` - confirm that
# `confusion_r` was computed for the same output and that the title matches it.
confusion_heatmap(confusion_r, list(tm.channel_map.keys()), fmt=".2f", figsize=(10, 6),
                  title='Chemical Class Confusion Matrix')

In [None]:
# Row-normalised confusion matrix for the third output, then hierarchical
# clustering of its rows (classes that are confused in similar ways group together).
tm = tensor_maps_out[2]
pmax = np.argmax(np.array(predictions[tm.name]), axis=-1)
truth = np.array(truths[tm.name])
n_classes = len(tm.channel_map)
confusion_r = np.zeros((n_classes, n_classes))
for i in range(n_classes):
    is_true_i = truth == i
    n_true_i = np.sum(is_true_i)
    if n_true_i == 0:
        # No true examples: keep the zero row. A 0/0 division would put NaNs
        # into the matrix that is clustered below.
        continue
    for j in range(n_classes):
        confusion_r[i, j] = np.sum(is_true_i & (pmax == j)) / n_true_i

num_clusters = 6
# Ward linkage always uses Euclidean distance, which is also the default, so the
# distance keyword (`affinity`, renamed `metric` in newer scikit-learn) is left
# out to stay compatible with both old and new versions.
cluster = AgglomerativeClustering(n_clusters=num_clusters, linkage='ward')
clusters = cluster.fit_predict(confusion_r)
print(clusters)


In [None]:
# Row indices ordered by cluster id.
np.argsort(clusters)

In [None]:
# Permute both rows and columns of confusion_r into cluster order and build
# matching "Cluster:<id>, <name>" tick labels.
cur_i = 0
cluster_order = np.argsort(clusters)
idx2label = {index: name for name, index in tm.channel_map.items()}

sort_confusion = np.zeros(confusion_r.shape)
for new_row, old_row in enumerate(cluster_order):
    for new_col, old_col in enumerate(cluster_order):
        sort_confusion[new_row, new_col] = confusion_r[old_row, old_col]

sort_labels = [f'Cluster:{clusters[old_row]}, {idx2label[old_row].replace("drug_", "")}'
               for old_row in cluster_order]
print(sort_labels)

In [None]:
# Plot the cluster-ordered confusion matrix with its cluster-prefixed labels.
confusion_heatmap(sort_confusion, sort_labels, fmt=".2f", title='Drug Confusion Matrix') 

In [None]:
# Load the split metadata table that the labelling sample is drawn from.
lb = pd.read_csv('split_32_all_drugs_v2022_04_22_meta_data.csv')

In [None]:
# Draw a 3% random sample of the test split and write it out for manual labelling.
# NOTE(review): no random_state is given, so every run selects different rows;
# the cell also overwrites `lb`, so re-running it samples from the previous sample.
lb = lb[lb.set=='test'].sample(frac = 0.03)
lb.to_csv('galen_to_label.csv', index=False)