In [None]:
import sys 
sys.path.append("./tabnet/tf_tabnet/")

import math

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.preprocessing import minmax_scale
from sklearn.decomposition import PCA

import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow.keras.models import Model
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import load_model, save_model
from keras import backend as K

from tensorflow_addons.optimizers import AdamW

from model.arcface_loss import ArcFace
import tabnet_model

from train import make_X_y, encode_y, scale_X, quantile_X, grouped_train_test_split
from eval import recall_at_k

In [None]:
def seed_everything(seed=24):
    """Seed the Python, NumPy and TensorFlow global random number generators.

    Args:
        seed: Integer seed applied to each generator.
    """
    # Local import: the notebook's import cell does not import `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

seed_everything(24)

In [None]:
# Location of the cleaned sample, relative to the notebook's working directory.
DATA_PATH = "../../../../data/clean/clean_sample.parquet"

df = pd.read_parquet(DATA_PATH)
df.shape

In [None]:
# Build the feature / target pair from the loaded frame.
X, y = make_X_y(df)

# y is passed twice: as the target and as the third positional argument
# (presumably the grouping key -- confirm against train.grouped_train_test_split).
(X_train, X_test,
 y_train, y_test) = grouped_train_test_split(X, y, y, test_size=0.2)

# NOTE(review): X_train_scale / X_test_scale are not referenced by the later
# cells visible in this notebook, which build their inputs from X_train / X_test
# directly -- confirm whether the scaled frames were meant to be used.
X_train_scale, X_test_scale = scale_X(X_train, X_test)
#X_train_scale, X_test_scale = quantile_X(X_train, X_test)

# Number of distinct training labels; used later to size the ArcFace head.
num_classes = np.unique(y_train).size
y_train_encode = encode_y(y_train)

In [None]:
BATCH_SIZE = 1024

# The encoded labels appear twice on purpose: once next to the feature dict
# (the model built below takes the labels as an input) and once as the target.
train_dataset = Dataset.from_tensor_slices((dict(X_train), y_train_encode))
label_dataset = Dataset.from_tensor_slices(y_train_encode)

# Shuffle with a buffer covering the whole training set. A buffer of 100 is
# smaller than one batch, so rows would stay almost in their original order and
# each batch would be drawn from a narrow window of the data.
dataset = (
    Dataset.zip((train_dataset, label_dataset))
    .shuffle(buffer_size=len(X_train))
    .batch(BATCH_SIZE)
    .prefetch(1)
)

In [None]:
def create_keras_input_layer(feature_names):
    """Create one float32 ``tf.keras.Input`` of shape ``(1,)`` per feature name.

    Args:
        feature_names: Iterable of names; each one becomes the ``name`` of an input.

    Returns:
        List of Keras input tensors, in the same order as ``feature_names``.
    """
    # Every feature is treated as numeric, so dtype and shape are the same for all.
    return [
        tf.keras.Input(shape=(1,), name=name, dtype=tf.float32)
        for name in feature_names
    ]

In [None]:
def encode_features(keras_inputs, feature_names):
    """Return the inputs unchanged, one per (input, feature name) pair.

    Args:
        keras_inputs: Sequence of model inputs.
        feature_names: Sequence of feature names paired with ``keras_inputs``.

    Returns:
        A new list with the paired inputs; pairing stops at the shorter of
        the two arguments.
    """
    # No encoding for numerical features: each input is passed straight through.
    return [keras_input for keras_input, _ in zip(keras_inputs, feature_names)]

In [None]:
# Keyword arguments unpacked into tabnet_model.TabNetEncoder when the model is built.
tabnet_params = dict(
    decision_dim=16,
    attention_dim=16,
    n_steps=5,
    n_shared_glus=2,
    n_dependent_glus=2,
    relaxation_factor=1.5,
    epsilon=1e-15,
    virtual_batch_size=None,
    momentum=0.98,
    mask_type="entmax",
    lambda_sparse=1e-4,
)

In [None]:
feature_names = list(X_train.columns)
embedding_size = 32

# Keras model using Functional API: one input per feature column.
gene_expression = create_keras_input_layer(feature_names)

x = encode_features(gene_expression, feature_names)
x = tf.keras.layers.Concatenate()(x)
x = tabnet_model.TabNetEncoder(**tabnet_params)(x)
x = layers.Dense(embedding_size, name="embedding")(x)

# L2-normalise each embedding row.
# https://stackoverflow.com/questions/53960965/normalized-output-of-keras-layer
# The layer name is pinned explicitly: Keras auto-names unnamed layers
# ("lambda", then "lambda_1" when this cell is re-run in the same kernel), and a
# later cell looks the layer up as 'lambda'. tf.math.l2_normalize is used rather
# than the standalone `keras` backend so the graph only depends on tf.keras.
l2norm_embedding = layers.Lambda(
    lambda t: tf.math.l2_normalize(t, axis=1),
    name="lambda",
)(x)

# The labels are a model input because the ArcFace layer is called on
# [embedding, labels].
labels = layers.Input(shape=(1,), dtype=np.int32, name="labels")
x = ArcFace(num_classes, BATCH_SIZE, max_m=0.2)([l2norm_embedding, labels])
output = layers.Activation('softmax')(x)

model = Model([gene_expression, labels], output)

In [None]:
EPOCHS = 5

model.compile(
    loss=SparseCategoricalCrossentropy(),
    optimizer=AdamW(learning_rate=5e-3, weight_decay=1e-5),
    metrics=['accuracy'],
)

# `dataset` is already batched with BATCH_SIZE, so `batch_size` is not passed to
# fit(): Keras documents that it must not be specified for tf.data inputs.
# The History object is kept so the training curves can be inspected afterwards.
history = model.fit(
    dataset,
    epochs=EPOCHS,
    verbose=1,
)

In [None]:
# Inference model: feature inputs -> L2-normalised embedding. The embedding
# tensor is referenced directly instead of via model.get_layer('lambda'),
# because that auto-generated layer name changes (e.g. to 'lambda_1') whenever
# the model-building cell is re-run in the same kernel.
embeddings_model = Model(gene_expression, l2norm_embedding)

# One tensor per feature column, keyed by column name.
embed_dataset = {name: tf.convert_to_tensor(value) for name, value in dict(X_test).items()}
embedded = embeddings_model.predict(embed_dataset, verbose=1)

In [None]:
# Draw a sample of the test embeddings; test_size=100 is forwarded to the helper.
# NOTE(review): labs_sample is not used in this cell -- recall_at_k below is
# given the full y_test. Confirm that is what recall_at_k expects.
_, embedded_sample, _, labs_sample = grouped_train_test_split(
    embedded, y_test, y_test, test_size=100
)
print(embedded_sample.shape)

recall = recall_at_k(embedded_sample, embedded, y_test)

# Integers 1 .. n_embedded-1 rescaled to [0, 1]; used as the x-axis of the curve.
n_embedded = embedded.shape[0]
quantile = minmax_scale(np.arange(1, n_embedded), feature_range=(0, 1))

# Area under the recall curve by the trapezoidal rule.
auc = np.trapz(recall, quantile)

In [None]:
# Recall curve over the test set, annotated with its AUC.
auc_lab = f"AUC {auc:.2f}"
props = {"boxstyle": 'round', "facecolor": 'white', "alpha": 0.5}

fig, ax = plt.subplots()
ax.plot(quantile, recall)
# Annotation is positioned in axes-fraction coordinates (transform=ax.transAxes).
ax.text(
    0.73, 0.1, auc_lab,
    transform=ax.transAxes,
    fontsize=14,
    verticalalignment='bottom',
    bbox=props,
)
ax.set_title("Compound Retrieval for Embedded Signatures in Test Set")
ax.set_xlabel("Proportion of Results Included")
ax.set_ylabel("Proportion of Compound Instances Identified")
plt.show()

In [None]:
# Earlier run: learning rate 0.0005, batch size 2048, 25 epochs -> loss 7.83, AUC 0.81
# max_m = 0.15