In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook, using the 'retina'
# figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
from unirep_reimplementation import aa_seq_to_int, aa_to_int, one_hots
import jax.numpy as np
from fundl.layers.rnn import mlstm1900
from fundl.weights import add_dense_params
from fundl.layers import dense
from fundl.activations import sigmoid

In [None]:
# Toy dataset: three sequence strings that are encoded with aa_seq_to_int below.
sequences = [
    "MRKGEELFTGVVPILVELDGDVNGHKFSVRGEGEGDATNGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFARYPDHMKQHDFFKSAMPEGYVQERTISFKDDGTYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNFNSHNVYITADKQKNGIKANFKIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSVLSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
    "MRKGEELFTGVVPILVELDGDVGGHKFSVRGEGEGDATNGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFARYPDHMKQHDFFKSAMPEGYVQERTISFKDDGTYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNFNSHNVYITADKQKNGIKANFKIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSVLSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
    "MRKGEELFTGVVPILVELDGDVGGHKFSVRGEGEGDATNGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFARYPDEMKQHDFFKSAMPEGYVQERTISFKDDGTYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNFNSHNVYITADKQKNGIKANFKIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSVLSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
]
# Target strings: every sequence with its first character removed.
next_letters = [seq[1:] for seq in sequences]

# Integer-encode both the inputs and the targets.
# NOTE(review): whether aa_seq_to_int adds start/stop tokens is not visible
# here; the x/y alignment below depends on it -- confirm.
sequences_int = list(map(aa_seq_to_int, sequences))
next_letters_int = list(map(aa_seq_to_int, next_letters))

embeddings = np.load("1900_weights/embed_matrix:0.npy")
# Look up the embedding rows for every encoded sequence and stack them along
# a new leading axis, then drop the last position of the second axis.
x = np.stack([embeddings[idx] for idx in sequences_int], axis=0)
x = x[:, :-1, :]
# Targets: one_hots rows selected by the encoded target sequences.
y = np.stack([one_hots[idx] for idx in next_letters_int], axis=0)

# Load the stored mLSTM weights: one .npy file per parameter name, collected
# under params["unirep"] with the parameter name as the key.
params = {
    "unirep": {
        weight_name: np.load(f"1900_weights/rnn_mlstm_mlstm_{weight_name}:0.npy")
        for weight_name in (
            "gh", "gmh", "gmx", "gx",
            "wh", "wmh", "wmx", "wx",
            "b",
        )
    }
}

# Add parameters for the dense head under the "dense" key.
# Presumably 1900 is the input width and 26 the output width -- TODO confirm
# against add_dense_params.
params = add_dense_params(params, "dense", 1900, 26)

In [None]:
# Sanity check: display the shapes of the model inputs (x) and targets (y).
x.shape, y.shape

In [None]:
def next_sequence_model(params, x):
    """
    Predict the next sequence element at every position of a batch.

    The input is passed through ``mlstm1900`` using ``params["unirep"]``,
    and the result is fed to a ``dense`` layer using ``params["dense"]``
    with a sigmoid nonlinearity.

    :param params: Dictionary holding the ``"unirep"`` and ``"dense"``
        parameter collections.
    :param x: Input batch. Per the original note its shape is
        (:, length_of_sequence, 10), where (:) is the sample dimension.
    :returns: The output of the dense layer. Presumably shaped
        (:, length_of_sequence, 26), because the dense parameters in this
        notebook are created with sizes 1900 and 26 (an earlier version of
        this docstring stated 20) -- TODO confirm. The result is meant to be
        compared against a binary-valued tensor that contains the truth.
    """
    hidden = mlstm1900(params["unirep"], x)
    return dense(params["dense"], hidden, nonlin=sigmoid)

In [None]:
# NOTE(review): vmap is imported here but not used in any cell visible in
# this part of the notebook.
from jax import vmap
# Forward pass of the model on the full input tensor, before any training.
out = next_sequence_model(params, x)


In [None]:
# Display the shape of the model output computed from x.
out.shape

In [None]:
from fundl.losses import neg_cross_entropy_loss
from functools import partial
from jax import grad, jit

# Bind the model into the loss function, so the result is called as
# loss(params, x=..., y=...).
loss = partial(neg_cross_entropy_loss, model=next_sequence_model)

# Evaluate the loss once on the data. The value is neither stored nor
# displayed, because this is not the last expression of the cell.
loss(params, x=x, y=y)
# Compiled gradient of the loss. jax.grad differentiates with respect to the
# first positional argument (params) by default.
dloss = jit(grad(loss))
# Gradient of the loss at the initial parameters.
g = dloss(params, x=x, y=y)

In [None]:
from jax import jit, value_and_grad
from jax.experimental.optimizers import adam

# adam() returns three functions: an optimizer-state initializer, an update
# step, and an accessor that extracts the parameters from the state.
init, update, get_params = adam(step_size=0.005)

# Compute the loss and its gradient in one jitted pass. Previously every
# iteration ran the jitted gradient and then a second, un-jitted forward pass
# only to obtain the loss value for printing.
loss_and_grad = jit(value_and_grad(loss))

state = init(params)
for i in range(100):
    # Loss and gradient at the current parameters, i.e. before this
    # iteration's update is applied. The names `l` and `g` are kept so the
    # notebook's globals stay the same as before.
    l, g = loss_and_grad(params, x=x, y=y)

    state = update(i, g, state)
    params = get_params(state)

    print(i, l)
