In [56]:
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, serial, Sigmoid
from jax.nn import relu, sigmoid
from jax.experimental.stax import elementwise
from jax import random
import jax
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden
import jax.numpy as jnp

In [2]:
import pandas as pd
from os import path
import numpy as np

In [3]:
DATA_DIR = "../../data"

In [4]:
train_df = pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_train_data.csv"), index_col=0)
train_df.head()

Unnamed: 0,Antibody_ID,heavy,light,Y
2073,6aod,EVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLE...,DIVMTKSPSSLSASVGDRVTITCRASQGIRNDLGWYQQKPGKAPKR...,0
1517,4yny,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,EFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1
2025,5xcv,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,QFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1
2070,6and,EVQLVESGGGLVQPGGSLRLSCAASGYEFSRSWMNWVRQAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCRSSQSIVHSVGNTFLEWYQQKPG...,1
666,2xqy,QVQLQQPGAELVKPGASVKMSCKASGYSFTSYWMNWVKQRPGRGLE...,DIVLTQSPASLALSLGQRATISCRASKSVSTSGYSYMYWYQQKPGQ...,0


In [71]:
test_df = pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_test_data.csv"), index_col=0)
valid_df = pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_valid_data.csv"), index_col=0)
test_df = pd.concat([test_df, valid_df])


In [5]:
train_h_seqs = list(train_df["heavy"])
train_labels = list(train_df["Y"])
len(train_h_seqs)

1338

In [72]:
test_h_seqs = list(test_df["heavy"])
test_labels = list(test_df["Y"])
len(test_h_seqs)

239

In [6]:
from jax.experimental.optimizers import adam
from jax import grad, jit

In [7]:
from jax_unirep.utils import seq_to_oh
from jax_unirep.utils import load_params

In [8]:
init_fun, apply_fun = serial(
    AAEmbedding(10),
    mLSTM(1900),
    mLSTMAvgHidden(),
    Dense(512), 
    elementwise(relu),
    Dense(1),
    elementwise(sigmoid)
)

In [9]:
rng_key = random.PRNGKey(0)

batch_size = 1
num_classes = 2
#input_shape = (173, 26, batch_size)
input_shape = (batch_size, 173, 26)
step_size = 0.1
num_steps = 10

In [10]:
_, init_params = init_fun(rng_key, input_shape)
params = load_params(paper_weights=1900)

In [105]:
def loss(params, batch):
    inputs, targets = batch
    #print(inputs.shape)
    logits = apply_fun(params, inputs)
    #print(logits.shape)
    
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(1-logits)
    #cross_entropy = -targets * np.log(logits) - (1 - targets)*np.log(1 - logits)
    res = -targets * log_p - (1. - targets) * log_not_p
    #print(type(res))
    #print(res)
    #return res.mean()
    return res
    
    #return -targets * log_p - (1. - targets) * log_not_p

In [106]:
def pad_seq(seq, max_len=171):
    if len(seq) < max_len:
        seq = seq + "-" * (max_len - len(seq))
    return seq

In [101]:
train_h_seqs = [pad_seq(seq) for seq in train_h_seqs]
test_h_seqs = [pad_seq(seq) for seq in test_h_seqs]

In [102]:
def get_batches():
    oh_seqs = [seq_to_oh(seq) for seq in train_h_seqs]
    labels = train_labels
    num_batches = len(labels) // batch_size
    for i in range(num_batches):
        #x = np.swapaxes(np.swapaxes(np.asarray(oh_seqs[i*batch_size : (i+1)*batch_size]), 0, 1), 1, 2)
        #x = np.asarray(oh_seqs[i*batch_size : (i+1)*batch_size])
        #x = np.swapaxes(np.asarray(oh_seqs[i*batch_size : (i+1)*batch_size]), 0,1)
        x = np.asarray(oh_seqs[i])#.transpose()
        yield x, np.asarray(labels[i*batch_size : (i+1)*batch_size])
    

In [103]:
opt_init, opt_update, get_params = adam(step_size)
batches = get_batches()

In [107]:
@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    l = loss(params, batch)
    print(l.shape)
    return opt_update(i, grad(loss)(params, batch), opt_state)

In [108]:
#opt_state = opt_init(init_params)
opt_state = opt_init(params)

In [109]:
for i in range(num_steps):
    opt_state = update(i, opt_state, next(batches))
trained_params = get_params(opt_state)

(25,)


TypeError: Gradient only defined for scalar-output functions. Output had shape: (25,).

In [95]:
test_oh_seqs = [seq_to_oh(seq) for seq in test_h_seqs]

apply_fun(trained_params, test_oh_seqs[0])

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [96]:
apply_fun(trained_params, test_oh_seqs[0]).shape

(25,)