# Load Data

In [None]:
#! pip install pandas scikit-learn

In [None]:
!pip install tensorflow_ranking

In [None]:
# Colab-specific upload helper; this cell only runs inside Google Colab.
from google.colab import files
# The result is consumed later as a mapping: its first value is wrapped in
# BytesIO and parsed as CSV.
uploaded = files.upload()

In [None]:
import tensorflow_ranking as tfr

In [None]:
from io import BytesIO

import pandas as pd

# Take the first uploaded payload and wrap it in BytesIO so pandas can read it
# like a file.
_first_upload = list(uploaded.values())[0]
df = pd.read_csv(BytesIO(_first_upload))

In [None]:
# 80/20 train/test split. The fixed random_state makes the split identical on
# every "Restart & Run All"; without it each run trains and evaluates on
# different rows, so results are not comparable between runs.
df_train = df.sample(frac=0.8, random_state=42)
# Every row that was not sampled into the training set becomes the test set.
# Rows are removed by index label, so this relies on df having a unique index.
df_test = df.drop(df_train.index)

In [None]:
# Histogram of the regression target using 100 bins.
df["bind_avg"].hist(bins=100)

In [None]:
# Mean squared deviation of the target from its own mean, i.e. the MSE a model
# would score by always predicting the mean.
df["bind_avg"].sub(df["bind_avg"].mean()).pow(2).mean()

In [None]:
import numpy as np
import tensorflow as tf

# Upper bound on the number of distinct tokens the vectorizer keeps.
VOCAB_SIZE = 30

# Character-level text vectorizer: every character of a sequence string becomes
# one integer token id.
encoder = tf.keras.layers.TextVectorization(
    input_shape=(1,),
    split="character",
    max_tokens=VOCAB_SIZE,
)
# Learn the vocabulary from the Sequence column of the full data frame.
encoder.adapt(df.Sequence)


In [None]:
# Sanity check: run the vectorizer on the sequence stored under index label 0.
encoder(df["Sequence"][0])

In [None]:
# Vocabulary learned by the vectorizer, kept as a NumPy array.
vocab = np.asarray(encoder.get_vocabulary())


In [None]:
# Token ids for the sequence stored under index label 0, converted to NumPy.
encoded_example = encoder(df["Sequence"][0]).numpy()
encoded_example

# Baseline LSTM Model

In [None]:
# Baseline regressor: vectorizer -> embedding -> bidirectional LSTM -> dense head
# with a single linear output.
model = tf.keras.Sequential()
model.add(encoder)
model.add(
    tf.keras.layers.Embedding(
        input_dim=len(encoder.get_vocabulary()),
        output_dim=64,
        # Use masking to handle the variable sequence lengths
        mask_zero=True,
    )
)
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1))


In [None]:
import numpy as np
from scipy.stats import spearmanr
from typing import Tuple


def norm(x):
    """Standardise ``x``: subtract its mean and divide by its standard deviation."""
    return (x - x.mean()) / x.std()


def mse(x):
    """Return the mean squared deviation of ``x`` from its own mean."""
    return ((x - x.mean()) ** 2).mean()


def spearman_correlation(y_true, y_pred):
    """Spearman rank correlation between two equally sized collections.

    Both inputs are flattened first, so ``(n,)`` and ``(n, 1)`` shapes can be
    mixed freely.

    Args:
        y_true: array-like of target values.
        y_pred: array-like of predicted values.

    Returns:
        The correlation coefficient as ``np.float32``. ``0.0`` is returned when
        the coefficient is undefined (fewer than two samples, or either input
        is constant).
    """
    true_flat = np.asarray(y_true, dtype=np.float64).ravel()
    pred_flat = np.asarray(y_pred, dtype=np.float64).ravel()
    # Rank correlation is undefined for a single point or a constant vector;
    # report "no correlation" instead of a NaN that would poison the running
    # average Keras keeps over an epoch.
    if (
        true_flat.size < 2
        or np.all(true_flat == true_flat[0])
        or np.all(pred_flat == pred_flat[0])
    ):
        return np.float32(0.0)
    # spearmanr returns (correlation, p-value); only the correlation is the
    # metric. Returning the whole result would mix the p-value into it.
    return np.float32(spearmanr(true_flat, pred_flat)[0])


def spearman_rankcor(y_true, y_pred):
    """Keras-compatible metric: Spearman rank correlation of a batch.

    The SciPy computation runs in Python, so it is wrapped in
    ``tf.py_function`` to be usable from within a compiled Keras model.

    Args:
        y_true: tensor of target values.
        y_pred: tensor of model predictions.

    Returns:
        A ``tf.float32`` tensor holding the correlation coefficient.
    """
    return tf.py_function(
        spearman_correlation,
        [tf.cast(y_true, tf.float32), tf.cast(y_pred, tf.float32)],
        Tout=tf.float32,
    )

# Mock Data

In [None]:
# Two trivially separable toy classes, repeated 1000 times each, as a smoke test
# that the baseline model can fit at all.
mock_data_x = ["AAAAAAAA", "BBBBBB"] * 1000
mock_data_y = [-1., 1.] * 1000

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss="mse",
    metrics=["mse", spearman_rankcor],
)
model.fit(mock_data_x, mock_data_y, batch_size=10, epochs=1000)

In [None]:
# Predictions of the fitted baseline model on the mock inputs.
model.predict(x=mock_data_x)

# LSTM with MSE loss

In [None]:
# Same architecture as the baseline, trained on the real data with a plain MSE
# objective.
model_mse = tf.keras.Sequential(
    layers=[
        encoder,
        tf.keras.layers.Embedding(
            input_dim=len(encoder.get_vocabulary()),
            output_dim=64,
            # Use masking to handle the variable sequence lengths
            mask_zero=True,
        ),
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1),
    ]
)
model_mse.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="mse",
    metrics=["mse", spearman_rankcor],
)
model_mse.fit(
    df_train.Sequence,
    df_train.bind_avg,
    batch_size=512,
    epochs=100,
    validation_data=(df_test.Sequence, df_test.bind_avg),
)

# LSTM with ranking loss

In [None]:
pwl = tfr.keras.losses.PairwiseHingeLoss()
# NOTE: this rebinds the module-level name `mse`, replacing the helper function
# of the same name defined in an earlier cell.
mse = tf.keras.losses.MeanSquaredError()


def nested_pairwise_loss(x, y, *, pairwise_weight=1.0, mse_weight=1.0):
    """Weighted sum of a pairwise hinge ranking loss and a mean squared error.

    Keras calls a loss as ``fn(y_true, y_pred)``, so ``x`` receives the targets
    and ``y`` the predictions.

    Args:
        x: target values for the batch.
        y: predicted values for the batch.
        pairwise_weight: multiplier for the ranking term (default 1.0).
        mse_weight: multiplier for the regression term (default 1.0).

    Returns:
        Scalar loss tensor. With the default weights this is the unweighted sum
        of both terms.
    """
    # Treat the whole batch as ONE ranked list of shape (1, batch): the pairwise
    # loss compares items within a list, so every sample is compared with every
    # other sample of the batch.
    true_row = tf.reshape(x, (1, -1))
    pred_row = tf.reshape(y, (1, -1))
    # The MSE uses the same flattened tensors, so a (batch,) target can never
    # broadcast against a (batch, 1) prediction into a (batch, batch) error
    # matrix. The mean over all elements is unchanged by the reshape.
    return pairwise_weight * pwl(true_row, pred_row) + mse_weight * mse(true_row, pred_row)


In [None]:
# Same architecture again, trained with the combined ranking + MSE objective.
model_ranking = tf.keras.Sequential()
model_ranking.add(encoder)
model_ranking.add(
    tf.keras.layers.Embedding(
        input_dim=len(encoder.get_vocabulary()),
        output_dim=64,
        # Use masking to handle the variable sequence lengths
        mask_zero=True,
    )
)
model_ranking.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)))
model_ranking.add(tf.keras.layers.Dense(64, activation='relu'))
model_ranking.add(tf.keras.layers.Dense(1))

model_ranking.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=nested_pairwise_loss,
    metrics=["mse", spearman_rankcor],
)
model_ranking.fit(
    df_train.Sequence,
    df_train.bind_avg,
    batch_size=1024,
    epochs=500,
    validation_data=(df_test.Sequence, df_test.bind_avg),
)

In [None]:
# Keras needs a loss *instance* (or a function of (y_true, y_pred)); the bare
# PairwiseHingeLoss class is neither, so it is instantiated here.
_pairwise_hinge = tfr.keras.losses.PairwiseHingeLoss()


def _batch_pairwise_hinge(y_true, y_pred):
    """Pairwise hinge loss over the whole batch treated as a single ranked list.

    The model emits one score per sample, i.e. shape (batch, 1). Used directly,
    every "list" would contain a single item and there would be no pairs to
    compare, so both tensors are reshaped to (1, batch) first.
    """
    return _pairwise_hinge(tf.reshape(y_true, (1, -1)), tf.reshape(y_pred, (1, -1)))


model.compile(
    loss=_batch_pairwise_hinge,
    optimizer=tf.keras.optimizers.Adam(1e-5),
    metrics=["mse", spearman_rankcor],
)
model.fit(
    df_train.Sequence,
    df_train.bind_avg,
    validation_data=(df_test.Sequence, df_test.bind_avg),
    epochs=10,
    batch_size=128,
)

In [None]:
# NOTE(review): the plotting cells below call `plt`, `scatter` and `xlim`
# without importing them, so they presumably rely on the names this magic
# injects into the namespace — confirm before replacing it with explicit imports.
%pylab
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Compute the predictions this plot needs inside the cell so that it also runs
# on a fresh kernel.
predictions = model_ranking.predict(df.Sequence)

# Compare the distribution of predicted values for targets below and above -3,
# drawn as two normalised step histograms with 100 bins.
plt.hist([predictions[df.bind_avg<-3].squeeze(),predictions[df.bind_avg>-3].squeeze()],100, histtype='step', stacked=True, fill=False, density=True)

In [None]:
# Compute the predictions in this cell so it does not depend on leftover kernel
# state.
predictions = model_ranking.predict(df.Sequence)
# Predicted value (x) against measured bind_avg (y).
scatter(predictions,df.bind_avg, marker='.')
# Zoom the x-axis into a very narrow band of predicted values.
xlim(-1.02,-1.01)

# ESM-2 8M model

In [None]:
! pip install transformers evaluate datasets requests pandas

In [None]:
from datasets import Dataset
from transformers import AutoTokenizer

# Pretrained checkpoint whose tokenizer (and, further down, weights) is used.
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Tokenize the raw sequence strings of each split.
train_tokenized = tokenizer(df_train.Sequence.tolist())
test_tokenized = tokenizer(df_test.Sequence.tolist())

# Wrap the tokenizer output in Dataset objects and attach the regression target
# under the column name "labels".
train_dataset = Dataset.from_dict(train_tokenized).add_column("labels", df_train.bind_avg)
test_dataset = Dataset.from_dict(test_tokenized).add_column("labels", df_test.bind_avg)
train_dataset

In [None]:
from transformers import TFAutoModelForSequenceClassification

# A sequence-classification head with a single label is used as a regressor.
model_transformer = TFAutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    ignore_mismatched_sizes=True,
    num_labels=1,
)

In [None]:
# prepare_tf_dataset is a method of the Hugging Face model, so it must be
# called on `model_transformer`. `model` is the Keras Sequential LSTM defined
# earlier in this notebook and has no such method.
tf_train_set = model_transformer.prepare_tf_dataset(
    train_dataset,
    batch_size=64,
    shuffle=True,
    tokenizer=tokenizer
)

# Evaluation data keeps its original order and uses a smaller batch size.
tf_test_set = model_transformer.prepare_tf_dataset(
    test_dataset,
    batch_size=8,
    shuffle=False,
    tokenizer=tokenizer
)

In [None]:
from transformers import AdamWeightDecay

# The rank-correlation metric defined in this notebook is `spearman_rankcor`;
# the previously referenced `get_spearman_rankcor` does not exist and raised a
# NameError.
model_transformer.compile(
    optimizer=AdamWeightDecay(2e-5),
    loss="mse",
    metrics=["mse", spearman_rankcor],
)

In [None]:
# Fine-tune for three epochs, validating on the held-out set after each one.
model_transformer.fit(
    tf_train_set,
    epochs=3,
    validation_data=tf_test_set,
)