In [None]:
import os
import sys
# Put the notebook's parent directory on the import path (presumably where the
# local `lmmnn` package imported below lives); skip it if it is already there.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import KFold

from lmmnn.layers import NLL
from lmmnn.callbacks import EarlyStoppingWithSigmasConvergence
from lmmnn.menet import menet_fit, menet_predict

from tensorflow.keras.preprocessing import text, sequence
from tensorflow.keras import Model
from tensorflow.keras.layers import Embedding, LSTM, Dense, Reshape, Concatenate, Input, Layer, Dropout, Flatten
from tensorflow.keras.callbacks import EarlyStopping, Callback
import tensorflow.keras.backend as K

In [None]:
# Ask TensorFlow to grow GPU memory on demand rather than reserving it all at start-up.
gpus = tf.config.list_physical_devices('GPU')
for gpu_device in gpus:
    tf.config.experimental.set_memory_growth(gpu_device, True)

In [None]:
# drugs_df.csv is the row-wise concatenation of the train and test TSVs from
# Gräßer et al. (2018), available in the UCI ML repository (see our paper).
drugs = pd.read_csv(filepath_or_buffer='drugs_df.csv')
# Integer-coded drug identifier used as the random-effect grouping column.
RE_col = 'drug_name'

In [None]:
drugs.head(n=5)

Unnamed: 0,id,drugName,condition,review,rating,date,usefulCount,drug_name
0,206461,Valsartan,Left Ventricular Dysfunction,"""It has no side effect, I take it in combinati...",9,"May 20, 2012",27,3428
1,95260,Guanfacine,ADHD,"""My son is halfway through his fourth week of ...",8,"April 27, 2010",192,1542
2,92703,Lybrel,Birth Control,"""I used to take another oral contraceptive, wh...",5,"December 14, 2009",17,1989
3,138000,Ortho Evra,Birth Control,"""This is my first time using any form of birth...",8,"November 3, 2015",10,2456
4,35696,Buprenorphine / naloxone,Opiate Dependence,"""Suboxone has completely turned my life around...",9,"November 27, 2016",37,553


In [None]:
# Text-model hyperparameters.
max_features = 10000     # tokenizer vocabulary size (num_words)
seq_len = 100            # every review is padded / truncated to this many tokens
words_embed_dim = 100    # word-embedding width
lstm_kernels = 64        # number of LSTM units
Z_embed_dim = 10         # drug-embedding width for the `embed` baseline

# NOTE(review): batch_size and epochs are never passed to reg_nn(); it uses its own
# defaults (batch=30, epochs=100), so these two values currently have no effect.
batch_size = 20
epochs = 10

# Number of random-effect levels; assumes codes run from 0 to the max (checked below).
n_cats = 1 + drugs[RE_col].max()

In [None]:
# Number of random-effect levels; sizes the drug Embedding and the b_hat vector.
n_cats

3671

In [None]:
# Sanity check: drug codes should start at 0 so they can index embeddings directly.
drugs.loc[:, RE_col].min()

0

In [None]:
# Sanity check: the largest drug code should equal n_cats - 1.
drugs.loc[:, RE_col].max()

3670

In [None]:
# Map each review to a fixed-length sequence of word ids (0 = padding, added after
# the text), then append the drug code column.
# NOTE(review): the vocabulary is fit on all reviews, including future test folds.
tokenizer = text.Tokenizer(num_words=max_features)
tokenizer.fit_on_texts(drugs['review'])
text_sequences = tokenizer.texts_to_sequences(drugs['review'])
x_cols = ['X{}'.format(pos) for pos in range(seq_len)]
X = pd.concat(
    [
        pd.DataFrame(sequence.pad_sequences(text_sequences, padding='post', maxlen=seq_len),
                     columns=x_cols),
        drugs[RE_col],
    ],
    axis=1,
)

In [None]:
X.loc[0, x_cols].to_numpy()

array([   5,   38,   28,   35,  198,    1,   45,    5,   15,  832,   12,
       2948,   99,  149,    2, 3852, 1585,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0])

In [None]:
# Share of reviews with no padding, i.e. at least seq_len in-vocabulary tokens (longer ones were truncated).
X[x_cols].gt(0).all(axis=1).mean()

0.427009759930811

In [None]:
drugs.at[0, 'review']

'"It has no side effect, I take it in combination of Bystolic 5 Mg and Fish Oil"'

In [None]:
# Cross-check: the first token id of review 0 should be the index of 'it'.
tokenizer.word_index['it']

5

In [None]:
def lstm_ignore():
    """LSTM regressor on the review tokens only; the drug identity is ignored.

    Returns an uncompiled Keras Model mapping (batch, variable length) int32 token
    ids to a (batch, 1) prediction.
    """
    tokens = Input(shape=(None, ), dtype=tf.int32)
    word_vectors = Embedding(max_features + 1, words_embed_dim)(tokens)
    text_state = LSTM(lstm_kernels)(word_vectors)
    prediction = Dense(1)(text_state)
    return Model(inputs=[tokens], outputs=prediction)

def lstm_lmmnn():
    """LMMNN network: LSTM prediction passed, with y_true and the drug code, to NLL.

    Inputs: [tokens (batch, seq_len) int32, y_true (batch, 1), Z (batch, 1) int64].
    The model output is whatever ``NLL`` returns. The NLL layer is the last layer,
    which ``reg_nn_lmm`` relies on to read the variance estimates.
    NOTE(review): NLL(1.0, 1.0) presumably sets the initial sig2e / sig2b; confirm in lmmnn.layers.
    """
    tokens = Input(shape=(seq_len, ), dtype=tf.int32)
    y_true = Input(shape=(1, ),)
    drug_code = Input(shape=(1, ), dtype=tf.int64)
    word_vectors = Embedding(max_features + 1, words_embed_dim)(tokens)
    text_state = LSTM(lstm_kernels)(word_vectors)
    y_pred = Dense(1)(text_state)
    nll_out = NLL(1.0, 1.0)(y_true, y_pred, drug_code)
    return Model(inputs=[tokens, y_true, drug_code], outputs=nll_out)

def lstm_embed():
    """Baseline that learns a Z_embed_dim-wide embedding per drug code.

    Inputs: [tokens (batch, variable length) int32, drug code (batch, 1)].
    The drug vector is concatenated with the LSTM summary before a linear output.
    """
    tokens = Input(shape=(None, ), dtype=tf.int32)
    drug_code = Input(shape=(1,))
    drug_vec = Embedding(n_cats, Z_embed_dim, input_length=1)(drug_code)
    drug_vec = Reshape(target_shape=(Z_embed_dim, ))(drug_vec)  # drop the length-1 axis
    word_vectors = Embedding(max_features + 1, words_embed_dim)(tokens)
    text_state = LSTM(lstm_kernels)(word_vectors)
    joined = Concatenate()([text_state, drug_vec])
    prediction = Dense(1)(joined)
    return Model(inputs=[tokens, drug_code], outputs=prediction)

def lstm_ohe(p):
    """Baseline that concatenates a p-wide one-hot drug vector to the LSTM summary.

    Args:
        p: number of one-hot (dummy) columns fed through the second input.
    """
    tokens = Input(shape=(None, ), dtype=tf.int32)
    one_hot = Input(shape=(p, ))
    word_vectors = Embedding(max_features + 1, words_embed_dim)(tokens)
    text_state = LSTM(lstm_kernels)(word_vectors)
    joined = Concatenate()([text_state, one_hot])
    prediction = Dense(1)(joined)
    return Model(inputs=[tokens, one_hot], outputs=prediction)

In [None]:
def calc_b_hat(Z_train, y_train, y_pred_tr, n_cats, sig2e, sig2b):
    """Best linear predictor of the per-category random effects b_i.

    For each category i in ``range(n_cats)`` with n_i > 0 training rows:

        b_i = n_i * sig2b * (mean(y_i) - mean(y_pred_i)) / (sig2e + n_i * sig2b)
            = sig2b * sum(y_i - y_pred_i) / (sig2e + n_i * sig2b)

    Categories with no training rows get b_i = 0.

    Parameters
    ----------
    Z_train : array-like of int, shape (n,)
        Category code of each training row. Codes outside [0, n_cats) are ignored.
    y_train : array-like, shape (n,)
        Observed targets, in the same row order as ``Z_train``.
    y_pred_tr : array-like, shape (n,) or (n, 1)
        Fixed-effect (network) predictions for the same rows, in the same order.
    n_cats : int
        Number of categories, which is the length of the result.
    sig2e, sig2b : float-like
        Residual and random-effect variance estimates.

    Returns
    -------
    np.ndarray of float, shape (n_cats,)
    """
    n_cats = max(int(n_cats), 0)
    sig2e, sig2b = float(sig2e), float(sig2b)
    codes = np.asarray(Z_train)
    residuals = np.asarray(y_train, dtype=float) - np.asarray(y_pred_tr, dtype=float).reshape(-1)
    # Keep only codes that address a slot of the output vector (as range(n_cats) did).
    in_range = (codes >= 0) & (codes < n_cats)
    codes = codes[in_range].astype(np.int64)
    residuals = residuals[in_range]
    # One O(n) pass instead of a full-column scan per category.
    n_i = np.bincount(codes, minlength=n_cats)
    resid_sum = np.bincount(codes, weights=residuals, minlength=n_cats)
    b_hat = np.zeros(n_cats)
    seen = n_i > 0
    b_hat[seen] = sig2b * resid_sum[seen] / (sig2e + n_i[seen] * sig2b)
    return b_hat

In [None]:
def process_one_hot_encoding(X_train, X_test, RE_col, feature_cols=None):
    """One-hot encode ``RE_col`` for train and test with an identical column layout.

    The dummy columns are defined by the categories present in ``X_train``.
    Test-only categories are dropped. Train categories absent from the test fold
    become all-zero columns, so both frames can feed the same network.

    Parameters
    ----------
    X_train, X_test : pd.DataFrame
        Frames holding the feature columns and the categorical column.
    RE_col : str
        Name of the categorical column to encode.
    feature_cols : list, optional
        Columns kept in front of the dummies; defaults to the notebook-level ``x_cols``.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        ``feature_cols`` followed by one uint8 dummy column per train category.
        The test frame has exactly the train frame's columns, in the same order.
    """
    cols = x_cols if feature_cols is None else list(feature_cols)
    # Explicit dtype: pandas >= 2.0 defaults to bool dummies; keep them uniformly numeric.
    train_dummies = pd.get_dummies(X_train[RE_col], dtype=np.uint8)
    # Align test dummies to the train categories (replaces set-based column
    # selection, which pandas >= 2.0 no longer accepts as an indexer).
    test_dummies = (pd.get_dummies(X_test[RE_col], dtype=np.uint8)
                    .reindex(columns=train_dummies.columns, fill_value=0)
                    .astype(np.uint8))
    X_train_ohe = pd.concat([X_train[cols], train_dummies], axis=1)
    X_test_ohe = pd.concat([X_test[cols], test_dummies], axis=1)
    return X_train_ohe, X_test_ohe

In [None]:
def reg_nn_ignore(X_train, X_test, y_train, y_test, n_cats, batch_size, epochs, patience, deep=False):
    """Baseline that ignores the drug column: LSTM on the review tokens only.

    ``y_test``, ``n_cats`` and ``deep`` are unused. They keep the signature shared
    with the other ``reg_nn_*`` fitters. ``patience=None`` sets the early-stopping
    patience to the full epoch budget.

    Returns:
        (test predictions clipped to [1, 10], (None, None)); there are no variance estimates.
        The clip range presumably matches the 1-10 rating scale.
    """
    stop_patience = patience if patience is not None else epochs
    net = lstm_ignore()
    net.compile(loss='mse', optimizer='adam')
    early_stop = EarlyStopping(monitor='val_loss', patience=stop_patience)
    net.fit(X_train[x_cols], y_train, batch_size=batch_size, epochs=epochs,
            validation_split=0.1, callbacks=[early_stop], verbose=1)
    test_pred = net.predict(X_test[x_cols]).reshape(X_test.shape[0])
    return np.clip(test_pred, 1, 10), (None, None)

def reg_nn_ohe(X_train, X_test, y_train, y_test, n_cats, batch_size, epochs, patience, deep=False):
    """Baseline that feeds a one-hot drug vector next to the LSTM summary.

    ``y_test``, ``n_cats`` and ``deep`` are unused (shared fitter signature).

    Returns:
        (test predictions clipped to [1, 10], (None, None)).
    """
    X_train_ohe, X_test_ohe = process_one_hot_encoding(X_train, X_test, RE_col)
    train_dummies = X_train_ohe.drop(x_cols, axis=1)
    test_dummies = X_test_ohe.drop(x_cols, axis=1)
    stop_patience = patience if patience is not None else epochs

    net = lstm_ohe(train_dummies.shape[1])
    net.compile(loss='mse', optimizer='adam')
    early_stop = EarlyStopping(monitor='val_loss', patience=stop_patience)
    net.fit([X_train_ohe[x_cols], train_dummies], y_train, batch_size=batch_size, epochs=epochs,
            validation_split=0.1, callbacks=[early_stop], verbose=1)
    test_pred = net.predict([X_test_ohe[x_cols], test_dummies]).reshape(X_test_ohe.shape[0])
    return np.clip(test_pred, 1, 10), (None, None)

class PrintSigmas(Callback):
    """Keras callback that prints the NLL layer's variance estimates every few epochs.

    It reads ``(sig2e, sig2b)`` through ``self.model.layers[-1].get_vars()``, the same
    accessor ``reg_nn_lmm`` uses after training. It was previously used here without
    being defined anywhere in the notebook, so a fresh kernel raised NameError.
    """

    def __init__(self, every=10):
        super().__init__()
        self.every = every

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.every == 0:
            sig2e, sig2b = self.model.layers[-1].get_vars()
            print(' sig2e: %.2f, sig2b: %.2f' % (sig2e, sig2b))


def reg_nn_lmm(X_train, X_test, y_train, y_test, n_cats, batch_size, epochs, patience, deep=False):
    """LMMNN: fit the LSTM through the NLL layer, then add predicted random effects.

    Test predictions are the network output plus ``b_hat`` for each row's drug code,
    where ``b_hat`` comes from ``calc_b_hat`` on the (clipped) training predictions.
    ``deep`` is unused (shared fitter signature).

    Returns:
        (test predictions clipped to [1, 10], (sig2e_est, sig2b_est)).
    """
    model = lstm_lmmnn()
    # No `loss=` and y=None in fit: the NLL layer presumably registers the loss itself.
    model.compile(optimizer='adam')

    patience = epochs if patience is None else patience
    callbacks = [EarlyStoppingWithSigmasConvergence(patience=patience), PrintSigmas()]
    model.fit([X_train[x_cols], y_train, X_train[RE_col]], None,
              batch_size=batch_size, epochs=epochs, validation_split=0.1,
              callbacks=callbacks, verbose=1)

    sig2e_est, sig2b_est = model.layers[-1].get_vars()
    y_pred_tr = model.predict([X_train[x_cols], y_train, X_train[RE_col]]).reshape(X_train.shape[0])
    y_pred_tr = np.clip(y_pred_tr, 1, 10)
    b_hat = calc_b_hat(X_train[RE_col], y_train, y_pred_tr, n_cats, sig2e_est, sig2b_est)
    # The model signature requires a y_true input at prediction time; random filler is passed.
    dummy_y_test = np.random.normal(size=y_test.shape)
    y_pred = model.predict([X_test[x_cols], dummy_y_test, X_test[RE_col]]).reshape(X_test.shape[0]) + b_hat[X_test[RE_col]]
    y_pred = np.clip(y_pred, 1, 10)
    return y_pred, (sig2e_est, sig2b_est)

def reg_nn_embed(X_train, X_test, y_train, y_test, n_cats, batch_size, epochs, patience, deep=False):
    """Baseline that learns a drug embedding alongside the LSTM.

    ``y_test``, ``n_cats`` and ``deep`` are unused (shared fitter signature).

    Returns:
        (test predictions clipped to [1, 10], (None, None)).
    """
    net = lstm_embed()
    net.compile(loss='mse', optimizer='adam')

    stop_patience = patience if patience is not None else epochs
    early_stop = EarlyStopping(monitor='val_loss', patience=stop_patience)
    train_inputs = [X_train[x_cols], X_train[RE_col]]
    net.fit(train_inputs, y_train, batch_size=batch_size, epochs=epochs,
            validation_split=0.1, callbacks=[early_stop], verbose=1)
    test_inputs = [X_test[x_cols], X_test[RE_col]]
    test_pred = net.predict(test_inputs).reshape(X_test.shape[0])
    return np.clip(test_pred, 1, 10), (None, None)

def reg_nn_menet(X_train, X_test, y_train, y_test, n_cats, batch_size, epochs, patience, deep=False):
    """MeNets baseline: token-only LSTM fitted with ``menet_fit`` over the drug clusters.

    Works on plain numpy arrays. ``deep`` is unused (shared fitter signature).

    Returns:
        (test predictions clipped to [1, 10], (sig2e_est, None), number of epochs run).
    """
    clusters_train = X_train[RE_col].values
    clusters_test = X_test[RE_col].values
    tokens_train = X_train[x_cols].values
    tokens_test = X_test[x_cols].values
    y_train, y_test = y_train.values, y_test.values

    net = lstm_ignore()
    net.compile(loss='mse', optimizer='adam')

    fitted = menet_fit(net, tokens_train, y_train, clusters_train, n_cats,
                       batch_size, epochs, patience, verbose=True)
    net, b_hat, sig2e_est, n_epochs = fitted[0], fitted[1], fitted[2], fitted[3]
    test_pred = menet_predict(net, tokens_test, clusters_test, n_cats, b_hat)
    return np.clip(test_pred, 1, 10), (sig2e_est, None), n_epochs

def reg_nn(X_train, X_test, y_train, y_test, n_cats, batch=30, epochs=100, patience=5, reg_type='ohe', deep=False):
    """Fit one approach (``reg_type``) and return its test MSE and variance estimates.

    reg_type is one of 'ohe', 'lmm', 'ignore', 'embed' or 'menet'; anything else
    raises ValueError. ``deep`` is forwarded but unused by the fitters.

    Returns:
        (test MSE, sigmas) where sigmas is the fitter's (sig2e, sig2b)-style tuple.
    """
    fitters = {
        'ohe': reg_nn_ohe,
        'lmm': reg_nn_lmm,
        'ignore': reg_nn_ignore,
        'embed': reg_nn_embed,
        'menet': reg_nn_menet,
    }
    if reg_type not in fitters:
        raise ValueError(reg_type + ' is an unknown reg_type')
    outputs = fitters[reg_type](X_train, X_test, y_train, y_test, n_cats, batch, epochs, patience, deep)
    # reg_nn_menet also returns the number of epochs run, which is discarded here.
    if reg_type == 'menet':
        y_pred, sigmas, _ = outputs
    else:
        y_pred, sigmas = outputs
    errors = y_pred - y_test
    return np.mean(errors ** 2), sigmas

In [None]:
# One row per (fold, approach); filled by iterate_reg_types.
res = pd.DataFrame(columns=['experiment', 'exp_type', 'deep', 'mse', 'sigma_e_est', 'sigma_b_est'])
counter = 0

# Approaches are fitted in FIT_ORDER; their rows are written in RESULT_ORDER.
FIT_ORDER = ['ignore', 'lmm', 'ohe', 'embed', 'menet']
RESULT_ORDER = ['ohe', 'lmm', 'ignore', 'embed', 'menet']
PRINT_LABELS = {'ohe': 'lm'}  # label used for 'ohe' in the progress messages


def iterate_reg_types(X_train, X_test, y_train, y_test, deep=True, experiment=None):
    """Fit every approach on one CV fold and append the test MSEs to ``res``.

    Parameters
    ----------
    X_train, X_test, y_train, y_test :
        The fold's data, as produced by the KFold loop below.
    deep : bool
        Recorded in the ``deep`` column and forwarded to ``reg_nn``.
    experiment : int, optional
        Fold index stored in the ``experiment`` column. When omitted it falls back
        to the global loop variable ``i`` (the previous, implicit behaviour); pass
        it explicitly to avoid depending on hidden global state.
    """
    global counter
    if experiment is None:
        experiment = i
    mses, sigmas = {}, {}
    for reg_type in FIT_ORDER:
        mses[reg_type], sigmas[reg_type] = reg_nn(X_train, X_test, y_train, y_test, n_cats,
                                                  reg_type=reg_type, deep=deep)
        print(' finished %s deep=%s, mse: %.2f' % (PRINT_LABELS.get(reg_type, reg_type), deep, mses[reg_type]))
    # Only LMMNN estimates both variances; MeNets contributes sig2e only.
    sigma_e = {'lmm': sigmas['lmm'][0], 'menet': sigmas['menet'][0]}
    sigma_b = {'lmm': sigmas['lmm'][1]}
    mse_dec = 100 * (mses['lmm'] - mses['ohe']) / mses['ohe']
    for offset, reg_type in enumerate(RESULT_ORDER):
        res.loc[counter + offset] = [experiment, reg_type, deep, mses[reg_type],
                                     sigma_e.get(reg_type, np.nan), sigma_b.get(reg_type, np.nan)]
    counter += len(RESULT_ORDER)
    print('iteration %d, deep=%s, mse change from mse_lm: %.2f%%' % (experiment, deep, mse_dec))

kf = KFold(n_splits=5)
y = drugs['rating']

for i, (train_index, test_index) in enumerate(kf.split(X, y)):
    print('iteration %d' % i)
    X_train, X_test = X.loc[train_index], X.loc[test_index]
    y_train, y_test = y[train_index], y[test_index]
    iterate_reg_types(X_train, X_test, y_train, y_test, experiment=i)

iteration 0
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
 finished ignore deep=True, mse: 2.78
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
 sig2e: 1.11, sig2b: 0.06
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
 sig2e: 0.44, sig2b: 0.02
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
 sig2e: 0.23, sig2b: 0.00
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 

KeyboardInterrupt: ignored

In [None]:
# Results collected so far (one row per fold x approach); partial if the CV loop was interrupted.
res

Unnamed: 0,experiment,exp_type,deep,mse,sigma_e_est,sigma_b_est
0,0,ohe,True,2.761191,,
1,0,lmm,True,2.680337,0.092918,0.007628
2,0,ignore,True,2.779226,,
3,0,embed,True,2.735665,,
4,1,ohe,True,2.775354,,
5,1,lmm,True,2.67406,0.108524,0.024346
6,1,ignore,True,2.745295,,
7,1,embed,True,2.907608,,
8,2,ohe,True,2.753927,,
9,2,lmm,True,2.646433,0.091408,0.007817


In [None]:
# Persist the results table; create ../results first so a fresh checkout does not fail with OSError.
os.makedirs(os.path.join('..', 'results'), exist_ok=True)
res.to_csv('../results/drugs.csv')