In [1]:
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, serial, Sigmoid
from jax.nn import relu, sigmoid
from jax.experimental.stax import elementwise
from jax import random
import jax
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden
import jax.numpy as jnp



In [2]:
import pandas as pd
from os import path
import numpy as np

In [3]:
DATA_DIR = "../../data"

In [4]:
train_df = pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_train_data.csv"), index_col=0)
train_df.head()

Unnamed: 0,Antibody_ID,heavy,light,Y
2073,6aod,EVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLE...,DIVMTKSPSSLSASVGDRVTITCRASQGIRNDLGWYQQKPGKAPKR...,0
1517,4yny,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,EFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1
2025,5xcv,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,QFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1
2070,6and,EVQLVESGGGLVQPGGSLRLSCAASGYEFSRSWMNWVRQAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCRSSQSIVHSVGNTFLEWYQQKPG...,1
666,2xqy,QVQLQQPGAELVKPGASVKMSCKASGYSFTSYWMNWVKQRPGRGLE...,DIVLTQSPASLALSLGQRATISCRASKSVSTSGYSYMYWYQQKPGQ...,0


In [5]:
test_df = pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_test_data.csv"), index_col=0)
valid_df = pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_valid_data.csv"), index_col=0)
test_df = pd.concat([test_df, valid_df])


In [6]:
train_h_seqs = list(train_df["heavy"])
train_labels = list(train_df["Y"])
len(train_h_seqs)

1338

In [7]:
test_h_seqs = list(test_df["heavy"])
test_labels = list(test_df["Y"])
len(test_h_seqs)

239

In [8]:
from jax.experimental.optimizers import adam
from jax import grad, jit

In [9]:
from jax_unirep.utils import seq_to_oh
from jax_unirep.utils import load_params

In [44]:
init_fun, apply_fun = serial(
    AAEmbedding(10),
    mLSTM(1900),
    mLSTMAvgHidden(),
    Dense(512), 
    elementwise(relu),
    Dense(1)
)

In [45]:
rng_key = random.PRNGKey(0)

batch_size = 1
num_classes = 2
#input_shape = (173, 26, batch_size)
input_shape = (batch_size, 153, 26)
step_size = 0.1
num_steps = 10

In [46]:
_, init_params = init_fun(rng_key, input_shape)
params = load_params(paper_weights=1900)

In [47]:
def loss(params, batch):
    inputs, targets = batch
    logits = apply_fun(params, inputs)
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(1-logits)
    #cross_entropy = -targets * np.log(logits) - (1 - targets)*np.log(1 - logits)
    res = -targets * log_p - (1. - targets) * log_not_p
    
    return res.mean()
    # without the .mean(), res has shape (25,)
    #return res


In [14]:
def pad_seq(seq, max_len=151):
    seq = seq[:max_len]
    if len(seq) < max_len:
        seq = seq + "-" * (max_len - len(seq))
    return seq

In [15]:
train_h_seqs = [pad_seq(seq) for seq in train_h_seqs]
test_h_seqs = [pad_seq(seq) for seq in test_h_seqs]

In [48]:
def get_batches():
    oh_seqs = [seq_to_oh(seq) for seq in train_h_seqs]
    labels = train_labels
    num_batches = len(labels) // batch_size
    for i in range(num_batches):
        #x = np.swapaxes(np.swapaxes(np.asarray(oh_seqs[i*batch_size : (i+1)*batch_size]), 0, 1), 1, 2)
        #x = np.asarray(oh_seqs[i*batch_size : (i+1)*batch_size])
        #x = np.swapaxes(np.asarray(oh_seqs[i*batch_size : (i+1)*batch_size]), 0,1)
        x = np.asarray(oh_seqs[i])#.transpose()
        yield x, np.asarray(labels[i*batch_size : (i+1)*batch_size])
    

In [49]:
opt_init, opt_update, get_params = adam(step_size)
batches = get_batches()

In [56]:
#@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    l = loss(params, batch)
    print(l)
    return opt_update(i, grad(loss)(params, batch), opt_state)

In [52]:
#opt_state = opt_init(init_params)
opt_state = opt_init(params)

In [57]:
for i in range(num_steps):
    print(f"Training epoch {i}...")
    opt_state = update(i, opt_state, next(batches))
trained_params = get_params(opt_state)

Training epoch 0...
2.1457455
Training epoch 1...
0.3132617
Training epoch 2...
0.3132617
Training epoch 3...
0.3132617
Training epoch 4...
0.3132617
Training epoch 5...
0.3132617
Training epoch 6...
0.31411317
Training epoch 7...
0.3132617
Training epoch 8...
0.3132617
Training epoch 9...
0.3132617


In [None]:
test_oh_seqs = [seq_to_oh(seq) for seq in test_h_seqs]

apply_fun(trained_params, test_oh_seqs[0])

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [24]:
apply_fun(trained_params, test_oh_seqs[0]).shape

(25,)