In [18]:
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, serial, Sigmoid
from jax.nn import relu, sigmoid
from jax.experimental.stax import elementwise
from jax import random

from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden

In [55]:
import pandas as pd
from os import path
import numpy as np

In [23]:
DATA_DIR = "../../data"

In [102]:
train_df = pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_train_data.csv"), index_col=0)
train_df.head()

Unnamed: 0,Antibody_ID,heavy,light,Y
2073,6aod,EVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLE...,DIVMTKSPSSLSASVGDRVTITCRASQGIRNDLGWYQQKPGKAPKR...,0
1517,4yny,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,EFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1
2025,5xcv,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,QFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1
2070,6and,EVQLVESGGGLVQPGGSLRLSCAASGYEFSRSWMNWVRQAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCRSSQSIVHSVGNTFLEWYQQKPG...,1
666,2xqy,QVQLQQPGAELVKPGASVKMSCKASGYSFTSYWMNWVKQRPGRGLE...,DIVLTQSPASLALSLGQRATISCRASKSVSTSGYSYMYWYQQKPGQ...,0


In [103]:
train_h_seqs = list(train_df["heavy"])
train_labels = list(train_df["Y"])
len(train_h_seqs)

1338

In [45]:
from jax.experimental.optimizers import adam
from jax import grad, jit

In [31]:
from jax_unirep.utils import seq_to_oh

In [17]:
init_fun, apply_fun = serial(
    AAEmbedding(10),
    mLSTM(1900),
    mLSTMAvgHidden(),
    Dense(512), 
    elementwise(relu),
    Dense(1),
    elementwise(sigmoid)
)

In [95]:
rng_key = random.PRNGKey(0)

batch_size = 8
num_classes = 2
input_shape = (173, 26, batch_size)
step_size = 0.1
num_steps = 10

In [50]:
_, init_params = init_fun(rng_key, input_shape)

In [110]:
def loss(params, batch):
    inputs, targets = batch
    print(inputs.shape)
    #print(params)
    logits = apply_fun(params, inputs)
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    
    return -targets * log_p - (1. - targets) * log_not_p

In [73]:
def pad_seq(seq, max_len=171):
    if len(seq) < max_len:
        seq = seq + "-" * (max_len - len(seq))
    return seq

In [75]:
train_h_seqs = [pad_seq(seq) for seq in train_h_seqs]

In [97]:
def get_batches():
    oh_seqs = [seq_to_oh(seq) for seq in train_h_seqs]
    labels = train_labels
    num_batches = len(labels) // batch_size
    for i in range(num_batches):
        x = np.swapaxes(np.swapaxes(np.asarray(oh_seqs[i*batch_size : (i+1)*batch_size]), 0, 1), 1, 2)
        
        yield x, np.asarray(labels[i*batch_size : (i+1)*batch_size])
    

In [98]:
opt_init, opt_update, get_params = adam(step_size)
batches = get_batches()

In [99]:
@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

In [100]:
opt_state = opt_init(init_params)

In [111]:
for i in range(num_steps):
    opt_state = update(i, opt_state, next(batches))
trained_params = get_params(opt_state)

(173, 26, 8)


TypeError: dot_general requires contracting dimensions to have the same shape, got [8] and [26].