In [None]:
!pip install tensorflow-text==2.5
!pip install tf-models-official==2.5

In [None]:
import os
import shutil

import pandas as pd
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from official.nlp import optimization
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback

import matplotlib.pyplot as plt
from ml4h.TensorMap import TensorMap, Interpretation
from ml4h.normalizer import ZeroMeanStd1

# Data split folder. The alternative value is kept as a comment instead of
# being assigned and immediately overwritten.
# NOTE(review): drug_folder is not referenced anywhere else in the visible notebook.
#drug_folder = 'split_drugs'
drug_folder = 'split_small_test_all'

# TF-Hub handle of the text preprocessing model paired with the BERT encoder.
preprocess_model = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

# TF-Hub handle of the BERT encoder; the commented line is an alternative encoder.
#base_model = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3"
base_model = "https://tfhub.dev/google/experts/bert/wiki_books/sst2/2"

# Only let TensorFlow's Python logger emit messages at ERROR level.
tf.get_logger().setLevel('ERROR')

In [None]:
# Load the per-testimonial metadata table and keep only rows that have both a
# sex label and a text body.
df = pd.read_csv('./all_drugs_v2022_04_26_meta_data.csv')
# df = df[df.age != 'Not Given']
# .copy() makes the filtered frame independent, so the column assignments
# below do not write into a slice of the raw frame (SettingWithCopyWarning).
df = df[df['sex_int'].notna() & df['text'].notna()].copy()

# Integer class ids, assigned in order of first appearance in the filtered frame.
drug2class = {d: i for i, d in enumerate(df['drug'].unique())}
psychoactive2class = {d: i for i, d in enumerate(df['psychoactive_class'].unique())}

# Inverse-frequency class weights (scaled by 5), keyed by drug class id.
# value_counts() counts every drug in one pass instead of filtering the whole
# frame once per drug.
drug_counts = df['drug'].value_counts()
class2weight = {i: 5 * (len(df) / int(drug_counts.loc[d])) for d, i in drug2class.items()}
print(psychoactive2class)

df['drug_class'] = [drug2class[d] for d in df['drug']]
df['psychoactive_class_int'] = [psychoactive2class[d] for d in df['psychoactive_class']]

# Missing tag indicators are filled with 0.
tag_cols = [f'tag_{i}' for i in range(52)]
df[tag_cols] = df[tag_cols].fillna(0)


# train = df.sample(frac = 0.8)
# test = df.drop(train.index).sample(frac = 0.5)
# validate = df.drop(train.index).drop(test.index)




# Tag names in tag-index order: position k corresponds to column 'tag_k'.
_tag_names = [
    'Small_Group', 'General', 'First_Times', 'Alone',
    'Difficult_Experiences', 'Glowing_Experiences', 'Retrospective_Summary', 'Various',
    'Unknown_Context', 'Mystical_Experiences', 'Health_Problems', 'Combinations',
    'Not_Applicable', 'Bad_Trips', 'Hangover_Days_After', 'Entities_Beings',
    'Music_Discussion', 'Addiction_Habituation', 'Post_Trip_Problems', 'Nature_Outdoors',
    'Relationships', 'Depression', 'Therapeutic_Intent_or_Outcome', 'Overdose',
    'Medical_Use', 'Sex_Discussion', 'Train_Wrecks_Trip_Disasters', 'Guides_Sitters',
    'Rave_Dance_Event', 'Preparation_Recipes', 'Festival_Lg_Crowd', 'Health_Benefits',
    'Large_Group', 'Multi-Day_Experience', 'Club_Bar', 'What_Was_in_That',
    'Personal_Preparation', 'HPPD_Lasting_Visuals', 'Families', 'Second_Hand_Report',
    'Loss_of_Magic', 'Hospital', 'Public_Space', 'School',
    'Poetry', 'Performance_Enhancement', 'Large_Party', 'Group_Ceremony',
    'Workplace', 'Cultivation_Synthesis', 'Pregnancy_Baby', 'Military',
]

# tag name -> column name ('tag_<k>')
tags = {name: f'tag_{idx}' for idx, name in enumerate(_tag_names)}

# column name -> tag name with '(', ')' and '/' stripped
itags = {}
for name, column in tags.items():
    itags[column] = name.replace('(', '').replace(')', '').replace('/', '')

# stripped tag name -> integer tag index
ctags = {name: int(column.replace('tag_', '')) for column, name in itags.items()}

# Names of the receptor columns in df; they are cast to float here and
# standardized in a later cell.
receptors = ['5_ht2a', '5_ht2c', '5_ht2b', '5_ht1a', '5_ht1b', '5_ht1d', '5_ht1e', '5_ht1f', '5_ht3', '5_ht5a', '5_ht6', '5_ht7', 'dopamine_d1', 'dopamine_d2', 'dopamine_d3', 'dopamine_d4', 'dopamine_d5', 'adrenergic_alpha1a', 'adrenergic_alpha1b', 'adrenergic_alpha2a', 'adrenergic_alpha2b', 'adrenergic_beta1', 'adrenergic_beta2', 'sert', 'dat', 'net', 'imidazoline_1', 'sigma_1', 'sigma_2', 'dor', 'kor', 'mor', 'm1', 'm2', 'm3', 'm4', 'm5', 'h1', 'h2', 'h3', 'h4', 'calcium_channel', 'nmda', 'cb1', 'cb2', 'glutamate_ampa', 'gaba_a', 'gaba_b', 'dopamine_d2_long', 'dopamine_d2_short', 'sodium_channel', 'taar1', 'substance_p', 'paf_platelet_activating_factor', 'prostaglandin_e3', 'prostaglandin_e4', 'herg', 'monoamine_oxidase_a', 'monoamine_oxidase_b', 'cholecystokinin_a', 'cholecystokinin_b']
df[receptors] = df[receptors].astype(float)

In [None]:
# Dtypes and non-null counts for the drug label, its class id and the
# testimonial id (the latter two are the merge keys used further down).
df[['drug', 'drug_class', 'testimonial']].info()

In [None]:
# How many rows carry each testimonial value.
df.testimonial.value_counts()

In [None]:
# Every column label of df, as a plain list.
list(df)

In [None]:
# Read the tab-separated CCA table and attach the integer drug class id,
# looked up from the 'drug_idx' column through drug2class.
cca_df = pd.read_csv(
    './testimonial_final_v2022_04_22_max_1000_pca_1000_None_cca_12_on_all_drugs.tsv',
    sep='\t',
)
cca_df = cca_df.assign(drug_int=[drug2class[name] for name in cca_df['drug_idx']])
cca_df.info()

In [None]:
# Rescale cca_0 .. cca_10 to zero mean and a standard deviation of 10.
for col in (f'cca_{i}' for i in range(11)):
    centered = cca_df[col] - cca_df[col].mean()
    cca_df[col] = centered / centered.std() * 10
# Sanity check on one component: should display (approximately) 10.
cca_df.cca_2.std()

In [None]:
# Rescale every receptor column to zero mean and a standard deviation of 10.
for name in receptors:
    centered = df[name] - df[name].mean()
    df[name] = centered / centered.std() * 10

# Keep only receptors with more than 10 distinct (non-null) values as
# regression targets, and draw a histogram for each one that is kept.
receptors2learn = [name for name in receptors if df[name].nunique() > 10]
for name in receptors2learn:
    df[name].plot.hist()

In [None]:
# Number of receptor columns kept as regression targets.
print(len(receptors2learn))

In [None]:
# Inner-join the metadata with the CCA table on (drug class id, testimonial id);
# rows without a match on both keys are dropped.
df = df.merge(
    cca_df,
    how='inner',
    left_on=['drug_class', 'testimonial'],
    right_on=['drug_int', 'testimonial_idx'],
)
cca_cols = [f'cca_{i}' for i in range(11)]
df[cca_cols] = df[cca_cols].astype(float)
df.info()

In [None]:
# Distribution of the cca_2 component after merging, using 150 bins.
df.cca_2.plot.hist(bins=150)

In [None]:
# Write one CSV per data split, selected by the value of the 'set' column.
for split in ('train', 'valid', 'test'):
    df[df['set'] == split].to_csv(f'{split}.csv', index=False)

In [None]:
# Despite its name, 'no_tags' holds the row-wise sum of the 52 tag indicator columns.
df['no_tags'] = df[[f'tag_{i}' for i in range(52) ]].sum(axis=1)
# Frequency of each per-row tag sum.
print(df['no_tags'].value_counts())

In [None]:
# Class balance of tag_0 ('Small_Group' in the tags mapping).
df.tag_0.value_counts()

In [None]:
# Class balance of tag_8 ('Unknown_Context' in the tags mapping).
df.tag_8.value_counts()

In [None]:
# Columns the model predicts: three class labels, 52 tag indicators,
# 11 CCA components and the selected receptor columns.
output_cols = (
    ['drug_class', 'psychoactive_class_int', 'sex_int']
    + [f'tag_{i}' for i in range(52)]
    + [f'cca_{i}' for i in range(11)]
    + receptors2learn
)


def _categorical_map(name, channel_map):
    """Build a categorical TensorMap trained with sparse cross-entropy on logits."""
    return TensorMap(f'{name}', Interpretation.CATEGORICAL, shape=(1,),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=[tf.metrics.SparseCategoricalAccuracy()],
                     channel_map=channel_map)


def _continuous_map(name):
    """Build a continuous TensorMap trained with MSE and tracked with MAE."""
    return TensorMap(f'{name}', Interpretation.CONTINUOUS, shape=(1,),
                     loss=tf.keras.losses.MeanSquaredError(),
                     metrics=[tf.metrics.MeanAbsoluteError()],)


# One TensorMap per output column, in output_cols order. The branch order
# matters: exact names are checked before the substring tests.
tensor_maps_out = []
for oc in output_cols:
    if oc == 'drug_class':
        tm = _categorical_map(oc, {f'drug_{d}': v for d, v in drug2class.items()})
    elif oc == 'psychoactive_class_int':
        tm = _categorical_map(oc, {f'{d}': v for d, v in psychoactive2class.items()})
    elif 'tag_' in oc:
        tm = _categorical_map(oc, {f'no_{itags[oc]}': 0, f'{itags[oc]}': 1})
    elif oc == 'age' or 'cca_' in oc or oc in receptors:
        tm = _continuous_map(oc)
    else:
        # Any remaining column (e.g. sex_int) is treated as a binary label.
        tm = _categorical_map(oc, {f'no_{oc}': 0, f'{oc}': 1})
    tensor_maps_out.append(tm)
        
def make_dataset(csv, in_cols, out_cols, batch_size=32):
    """Build a shuffled, batched (features, targets) tf.data pipeline from a CSV.

    Args:
        csv: Path (or file pattern) of the CSV file to read.
        in_cols: Column names to read as model inputs.
        out_cols: Column names to read as targets; every one is parsed as float32.
        batch_size: Number of examples per emitted batch.

    Returns:
        A prefetching dataset of (features, targets) pairs, each side being a
        mapping from column name to a batch of values.

    NOTE(review): num_epochs is left at make_csv_dataset's default, which per
    the TF docs repeats indefinitely; callers here pass explicit step counts
    or break out of the iteration. Confirm before iterating to exhaustion.
    """
    read_csv = tf.data.experimental.make_csv_dataset
    # Both readers use batch_size=1 and shuffle=False so rows stay aligned when zipped.
    features = read_csv(csv, select_columns=in_cols,
                        batch_size=1, shuffle=False)
    targets = read_csv(csv, select_columns=out_cols,
                       column_defaults=[np.float32] * len(out_cols),
                       batch_size=1, shuffle=False)
    paired = tf.data.Dataset.zip((features, targets))
    # Shuffle the single-row batches, then strip the size-1 batch dimension
    # and regroup into batches of the requested size.
    return (
        paired
        .shuffle(10000)
        .unbatch()
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

# One pipeline per split CSV: the 'text' column is the only input and
# output_cols are the targets.
train_ds = make_dataset('train.csv', ['text'], output_cols)
valid_ds = make_dataset('valid.csv', ['text'], output_cols)
test_ds = make_dataset('test.csv', ['text'], output_cols)

In [None]:
from ml4h.TensorMap import TensorMap, Interpretation

In [None]:
# Peek at a single test batch: print the label mapping, then the first value
# of every input feature in that batch.
for feature_batch, label in test_ds.take(1):
    print(f"label {label}")
    for key, value in feature_batch.items():
        print(f"\n\n\n Key is  {key:20s}: {value[0]}")

In [None]:
# NOTE(review): n_drugs is not referenced anywhere else in the visible notebook.
n_drugs = len(drug2class)
# Dropout rate read as a global by build_classifier_model.
dropout_rate = 0.2

# NOTE(review): this module-level text_input is not used below;
# build_classifier_model creates its own Input layer.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
bert_preprocess_model = hub.KerasLayer(
    preprocess_model)

bert_model = hub.KerasLayer(
    base_model,
    trainable=True)

# Smoke test: run one sentence through preprocessing and the encoder and
# print the resulting keys, shapes and leading values.
text_test = ['this is such an amazing movie!']
text_preprocessed = bert_preprocess_model(text_test)

print(f'Keys       : {list(text_preprocessed.keys())}')
print(f'Shape      : {text_preprocessed["input_word_ids"].shape}')
print(f'Word Ids   : {text_preprocessed["input_word_ids"][0, :12]}')
print(f'Input Mask : {text_preprocessed["input_mask"][0, :12]}')
print(f'Type Ids   : {text_preprocessed["input_type_ids"][0, :12]}')

bert_results = bert_model(text_preprocessed)

print(f'Pooled Outputs Shape:{bert_results["pooled_output"].shape}')
print(f'Pooled Outputs Values:{bert_results["pooled_output"][0, :12]}')
print(f'Sequence Outputs Shape:{bert_results["sequence_output"].shape}')
print(f'Sequence Outputs Values:{bert_results["sequence_output"][0, :12]}')

def weighted_scce(weights):
    """Return a sparse categorical cross-entropy loss with per-class weights.

    Args:
        weights: Mapping (or sequence) from integer class id to weight. The
            ids must be exactly 0 .. len(weights) - 1.

    Returns:
        A function ``my_loss(y_true, y_pred)`` that computes sparse
        categorical cross-entropy on logits, weighting every example by the
        weight of its true class.

    Raises:
        KeyError / IndexError: if a class id in 0 .. len(weights) - 1 is missing.
    """
    # Dense lookup table: position k holds the weight of class k.
    class_weights = [float(weights[k]) for k in range(len(weights))]

    def my_loss(y_true, y_pred):
        # Look the weights up with tensor ops so the loss also works when
        # Keras traces it into a graph (no .numpy() on symbolic tensors).
        labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        sample_weight = tf.gather(tf.constant(class_weights, dtype=tf.float32), labels)
        scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        # The Keras keyword is `sample_weight` (singular).
        return scce(y_true, y_pred, sample_weight=sample_weight)
    return my_loss

def build_classifier_model(tfhub_handle_preprocess, tfhub_handle_encoder, tensor_maps_out):
    """Build a multi-head Keras model on top of a TF-Hub BERT encoder.

    Args:
        tfhub_handle_preprocess: TF-Hub handle of the text preprocessing model.
        tfhub_handle_encoder: TF-Hub handle of the (trainable) BERT encoder.
        tensor_maps_out: Output TensorMaps; one head is created per map that
            is categorical or continuous, in the given order. Maps that are
            neither are skipped.

    Returns:
        A tuple ``(model, losses, metrics)``: the Keras model taking a string
        input named 'text', plus one loss and one metric per created head.

    Reads the module-level ``dropout_rate``.
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    bert_outputs = encoder(encoder_inputs)

    # Every head branches off the dropout-regularised pooled output.
    shared = tf.keras.layers.Dropout(dropout_rate)(bert_outputs['pooled_output'])
    #shared = tf.keras.layers.Dense(256, activation='swish')(shared)
    #shared = tf.keras.layers.Dropout(dropout_rate)(shared)

    heads, head_losses, head_metrics = [], [], []
    for otm in tensor_maps_out:
        if otm.is_categorical():
            # Logits head with one unit per channel.
            head = tf.keras.layers.Dense(len(otm.channel_map), activation=None, name=otm.name)(shared)
            head_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            #head_loss = weighted_scce(class2weight)
            head_metric = tf.metrics.SparseCategoricalAccuracy(
                name=f'{otm.name}_SparseCategoricalAccuracy_met')
        elif otm.is_continuous():
            # Small hidden layer with dropout before the scalar regression output.
            hidden = tf.keras.layers.Dense(otm.annotation_units, activation='swish')(shared)
            hidden = tf.keras.layers.Dropout(dropout_rate)(hidden)
            head = tf.keras.layers.Dense(1, activation=None, name=otm.name)(hidden)
            head_loss = tf.keras.losses.MeanSquaredError()
            head_metric = tf.metrics.MeanAbsoluteError(name=f'{otm.name}_mae')
        else:
            continue
        heads.append(head)
        head_losses.append(head_loss)
        head_metrics.append(head_metric)
    return tf.keras.Model(text_input, heads), head_losses, head_metrics

# Build the multi-head model and run the sample sentence through it as a smoke test.
classifier_model, losses, metrics = build_classifier_model(preprocess_model, base_model, tensor_maps_out)
bert_raw_result = classifier_model(tf.constant(text_test))
# Sigmoid of the last head's raw output for the sample sentence.
print(tf.sigmoid(bert_raw_result[-1]))
tf.keras.utils.plot_model(classifier_model)
# The losses/metrics returned by build_classifier_model are replaced here by
# the ones stored on the TensorMaps: metrics keyed by output name, losses as a
# list in tensor_maps_out order.
metrics = {tm.name: tm.metrics for tm in tensor_maps_out}
losses = [tm.loss for tm in tensor_maps_out]
#loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
#metrics = tf.metrics.SparseCategoricalAccuracy()

In [None]:
# Training schedule: the step counts are derived from the size of the train split.
epochs = 18
batch_size = 32
steps_per_epoch = len(df[df.set=='train'])//batch_size
num_train_steps = steps_per_epoch * epochs
# The first 10% of all training steps are used for learning-rate warm-up.
num_warmup_steps = int(0.1*num_train_steps)
# Label each number for what it is (the old message printed steps_per_epoch
# under the "warm up" label and never showed the warm-up count).
print(f'steps per epoch {steps_per_epoch}, total steps {num_train_steps}, warm up steps {num_warmup_steps}')
init_lr = 1e-5
optimizer = optimization.create_optimizer(init_lr=init_lr,
                                          num_train_steps=num_train_steps,
                                          num_warmup_steps=num_warmup_steps,
                                          optimizer_type='adamw')


classifier_model.compile(optimizer=optimizer, loss=losses, metrics=metrics)

In [None]:
# Checkpoint location for the best model seen during training.
filepath='./models/bertowid'
print(f'Training model with bert, will save to: {filepath}')
# Number of epochs without val_loss improvement before training stops early.
patience = 8
callbacks= [
    # Only overwrite the checkpoint when the monitored quantity improves
    # (the monitor is left at ModelCheckpoint's default).
    ModelCheckpoint(filepath=filepath, verbose=1, save_best_only=True),
    EarlyStopping(monitor='val_loss', patience=patience, verbose=1), 
]


# Explicit step counts are passed for both training and validation.
# NOTE(review): shuffle=True is presumably ignored for tf.data inputs — confirm.
history = classifier_model.fit(x=train_ds, steps_per_epoch=steps_per_epoch, 
                               validation_data=valid_ds, 
                               validation_steps=len(df[df.set=='valid'])//batch_size,
                               epochs=epochs, shuffle=True, callbacks=callbacks,
                               
                              )


In [None]:
# Evaluate on the test pipeline for as many full batches as the test split holds.
# `loss` receives whatever evaluate() returns for the compiled losses and metrics.
loss = classifier_model.evaluate(test_ds, steps=len(df[df.set=='test'])//batch_size)

print(f'Loss: {loss}')

In [None]:
from ml4h.plots import plot_metric_history
# Plot the learning curves recorded by fit() using the ml4h helper.
plot_metric_history(history, steps_per_epoch, 'BERTowid Learning Curves')

In [None]:
# Layer-by-layer overview of the model currently bound to classifier_model.
classifier_model.summary()

In [None]:
#classifier_model.save('bert_48_drug_sex_age_classifier')
#classifier_model.save('bert_cca11_mse_regressor')

In [None]:
# Replace classifier_model with a previously saved model.
# custom_objects maps the serialized optimizer name to its class, so this
# cell no longer depends on the `optimizer` instance created in the training cell.
# NOTE(review): the loaded model's outputs may differ from tensor_maps_out;
# the evaluation cells below only collect predictions for its output_names.
classifier_model = tf.keras.models.load_model('./models/bert_cca11_regressor',
                                              custom_objects={'AdamWeightDecay': optimization.AdamWeightDecay})
# classifier_model = tf.keras.models.load_model(f'bert_48_drug_sex_52_tag_classifier', 
#                                               custom_objects={'AdamWeightDecay':optimizer})

In [None]:
# Number of rows in the test split.
len(df[df.set=='test'])

In [None]:
from collections import defaultdict

# Outputs whose ground truth is real-valued. Casting these to int (as was
# done for every output before) truncates the regression targets.
continuous_outputs = {tm.name for tm in tensor_maps_out if tm.is_continuous()}
n_test = len(df[df.set=='test'])

# Per-output lists of model predictions and ground-truth values.
predictions = defaultdict(list)
truths = defaultdict(list)
n_seen = 0
# The loop variable is called text_batch so it does not overwrite the
# `tensorflow_text as text` module alias imported at the top of the notebook.
for text_batch, labels in test_ds.as_numpy_iterator():
    for label_name, values in labels.items():
        cast = float if label_name in continuous_outputs else int
        truths[label_name].extend(cast(v) for v in values)
    p = classifier_model.predict(text_batch)
    # A single-output model returns a bare array; wrap it so indexing is uniform.
    if len(classifier_model.output_names) == 1:
        p = [p]
    for i, ot in enumerate(classifier_model.output_names):
        predictions[ot].extend(list(p[i]))

    # Stop once at least as many examples as the test split holds were collected.
    n_seen += len(next(iter(labels.values())))
    if n_seen >= n_test:
        break

In [None]:
from ml4h.plots import plot_roc, subplot_rocs, plot_scatter
def make_one_hot(y, num_labels):
    """Convert a 1-D array of class indices into a one-hot float matrix.

    Args:
        y: 1-D array of class indices; every entry is converted with int().
        num_labels: Number of columns in the result.

    Returns:
        Array of shape (len(y), num_labels) holding 1.0 at [i, int(y[i])]
        and 0.0 everywhere else.
    """
    n_rows = y.shape[-1]
    one_hot = np.zeros((n_rows, num_labels))
    columns = np.array([int(y[row]) for row in range(n_rows)], dtype=int)
    one_hot[np.arange(n_rows), columns] = 1.0
    return one_hot

# ROC curves for every categorical output and scatter plots for every
# continuous one; the ROC inputs are also collected for a combined figure.
rocs = []
for otm in tensor_maps_out:
    if otm.is_categorical():
        y_pred = np.array(predictions[otm.name])
        # One-hot ground truth, built once and reused for both plots.
        y_true = make_one_hot(np.array(truths[otm.name]), len(otm.channel_map))
        print(f' otm {otm} {y_pred.shape}')
        plot_roc(y_pred, y_true, otm.channel_map, otm.name)
        rocs.append((y_pred, y_true, otm.channel_map))
    elif otm.is_continuous():
        plot_scatter(np.array(predictions[otm.name]), np.array(truths[otm.name]), otm.name)
subplot_rocs(rocs)

In [None]:
from ml4h.plots import plot_roc, subplot_rocs, plot_scatter
# NOTE(review): this redefines make_one_hot exactly as in the preceding cell.
def make_one_hot(y, num_labels):
    """Convert a 1-D array of class indices into a one-hot float matrix.

    Args:
        y: 1-D array of class indices; every entry is converted with int().
        num_labels: Number of columns in the result.

    Returns:
        Array of shape (len(y), num_labels) holding 1.0 at [i, int(y[i])]
        and 0.0 everywhere else.
    """
    n_rows = y.shape[-1]
    one_hot = np.zeros((n_rows, num_labels))
    columns = np.array([int(y[row]) for row in range(n_rows)], dtype=int)
    one_hot[np.arange(n_rows), columns] = 1.0
    return one_hot

# NOTE(review): this cell repeats the preceding evaluation cell.
# ROC curves for every categorical output and scatter plots for every
# continuous one; the ROC inputs are also collected for a combined figure.
rocs = []
for otm in tensor_maps_out:
    if otm.is_categorical():
        y_pred = np.array(predictions[otm.name])
        # One-hot ground truth, built once and reused for both plots.
        y_true = make_one_hot(np.array(truths[otm.name]), len(otm.channel_map))
        print(f' otm {otm} {y_pred.shape}')
        plot_roc(y_pred, y_true, otm.channel_map, otm.name)
        rocs.append((y_pred, y_true, otm.channel_map))
    elif otm.is_continuous():
        plot_scatter(np.array(predictions[otm.name]), np.array(truths[otm.name]), otm.name)
subplot_rocs(rocs)