In [4]:
import os

from biobeaker.utils import get_angles, positional_encoding
from biobeaker import BEAKER

import tensorflow as tf

import numpy as np
import time
import pyracular

#from lib.useful import (
#    calc_kmer_numeric_tuple,
#    convert_tuple_to_string,
#    calc_distance,
#    convert_tuple_to_np,
#    cos_sim,
#    convert_string_to_nparray,
#    convert_string_to_nparray_tuple,
#)
from tensorflow.keras.layers import (
    Dense,
    Embedding,
    Flatten,
    Lambda,
    Subtract,
    Input,
    Concatenate,
    AveragePooling1D,
    LocallyConnected1D,
    Conv1D,
    GaussianNoise,
    BatchNormalization,
    Reshape,
    GlobalAveragePooling1D,
    Dropout,
)
from tensorflow.keras.models import Model, Sequential

# Hyper parameters
k = 21  # k-mer length; each k-mer row has k * 5 features (5 per position)
window_size = 32  # up to 511: window_size + 1 (CLS) must fit in max_positions
num_layers = 8
embedding_dims = 32
output_dims = 128  # Output dims are also internal dims!
intermediate_dims = 256
num_heads = 8
dropout_rate = 0.15
max_positions = 512
batch_size = 64

# Generator hyperparameters, named instead of inlined in the BEAKER call so they
# can be read and tuned next to the discriminator (transformer) settings above.
# Positional slots mirror the `transformer` call below.
generator_num_layers = 12
generator_num_heads = 16
generator_intermediate_dims = 256
generator_dropout_rate = 0.15

# Every window is prefixed with a CLS row, so it occupies window_size + 1 positions.
assert window_size + 1 <= max_positions, "window_size + 1 (CLS) must not exceed max_positions"

# Discriminator encoder.
transformer = BEAKER(
    num_layers,
    embedding_dims,
    output_dims,
    num_heads,
    intermediate_dims,
    max_positions,
    dropout=dropout_rate,
    attention_dropout=dropout_rate,
    activation=tf.keras.activations.gelu,
)

# Generator encoder (trained to replace masked tokens).
generator = BEAKER(
    generator_num_layers,
    embedding_dims,
    output_dims,
    generator_num_heads,
    generator_intermediate_dims,
    max_positions,
    dropout=generator_dropout_rate,
    attention_dropout=generator_dropout_rate,
    activation=tf.keras.activations.gelu,
)


def matched_layer():
    """Build the "Matched" head: one linear output unit.

    Returns:
        A ``tf.keras.Sequential`` named "Matched" that wraps a single
        ``Dense(1)`` layer with linear activation.
    """
    head = tf.keras.layers.Dense(1, activation="linear")
    return tf.keras.Sequential([head], name="Matched")


def rc_layer(neurons, activation="relu"):
    """Build the "Rc" head: two hidden Dense layers, then one linear output unit.

    Args:
        neurons: width of each of the two hidden layers.
        activation: activation applied by the hidden layers.

    Returns:
        A ``tf.keras.Sequential`` named "Rc".
    """
    hidden = [tf.keras.layers.Dense(neurons, activation=activation) for _ in range(2)]
    output = tf.keras.layers.Dense(1, activation="linear")
    return tf.keras.Sequential(hidden + [output], name="Rc")


def discriminator_layer(neurons, activation="relu"):
    """Build the "Discriminator" head: one hidden Dense layer, then one linear unit.

    Args:
        neurons: width of the hidden layer.
        activation: activation applied by the hidden layer.

    Returns:
        A ``tf.keras.Sequential`` named "Discriminator".
    """
    head = tf.keras.Sequential(name="Discriminator")
    head.add(tf.keras.layers.Dense(neurons, activation=activation))
    head.add(tf.keras.layers.Dense(1, activation="linear"))
    return head


#def reverso_layer():
#    return tf.keras.Sequential(
#        [
#            tf.keras.layers.Dense(
#                k * 5 * 3 * embedding_dims, activation=tf.nn.swish, name="Reverso"
#            ),
#            tf.keras.layers.Dense(k * 5, name="ReversoOutput"),
#            #tf.keras.layers.Reshape((window_size, k, 5)),
#            #tf.keras.layers.Softmax(axis=-1),
#            tf.keras.layers.Reshape((window_size, k, 5)),
#        ],
#        name="Reverso",
#    )

# Frozen "reverso" decoder: embeddings -> k * 5 unnormalised scores per position
# (no softmax), reshaped to (window_size, k, 5). Its weights are loaded from disk
# via reverso_layer.load_weights later in this cell and are not trained here.
reverso = Dense(k * 5 * 3 * embedding_dims, activation=tf.nn.swish, name="Reverso")
reverso_output = Dense(k * 5, name="ReversoOutput")
reshaped = tf.keras.layers.Reshape((window_size, k, 5))

for frozen_layer in (reshaped, reverso_output, reverso):
    frozen_layer.trainable = False

reverso_layer = tf.keras.Sequential([reverso, reverso_output, reshaped], name="Reverso")

# Frozen linear projection from k-mer rows to embedding_dims ("Magic");
# its kernel is set from saved weights later in this cell.
magic = Dense(
    embedding_dims,
    name="Magic",
    activation="linear",
    use_bias=False,
    dtype=tf.float32,
    trainable=False,
)
EPOCHS = 12

# CLS row prepended to every window: k * 5 ones (same width as one k-mer row).
cls = np.asarray([[1] * (k * 5)])

# Define the model
# Input is 2 with mask (and CLS token)
# input is another 2 without mask (also with CLS token)
batch_input = Input(
    shape=(2, window_size + 1, k * 5), dtype="float32", name="BatchInput"
)

# NOTE(review): "Mask" is accepted as a model input but no layer below consumes it
# (the generator's mask= arguments are commented out).
mask = Input(shape=(2, window_size), dtype="float32", name="Mask")
truth = Input(shape=(2, window_size + 1, k * 5), dtype="float32", name="BatchInputTrue")

# Frozen Magic projection of each (masked) window: k * 5 features -> embedding_dims.
contexts_a = magic(batch_input[:, 0])
contexts_b = magic(batch_input[:, 1])

# Embeddings of the unmasked windows; currently only referenced by the
# commented-out generator loss further down.
contexts_a_true = magic(truth[:, 0])
contexts_b_true = magic(truth[:, 1])

# Projects generator outputs (output_dims wide) back to embedding_dims.
BackToEmbeddings = tf.keras.layers.Dense(embedding_dims, use_bias=False)

# Generator - Train to replace mask token
generated_a, _, _ = generator(contexts_a, training=True) #, mask=mask[:, 0])
generated_b, _, _ = generator(contexts_b, training=True) #, mask=mask[:, 1])
generated_a = BackToEmbeddings(generated_a)
generated_b = BackToEmbeddings(generated_b)

# Need to softmax generated_a, b to one-hot embeddings to feed into the discriminator
# Also probably a good idea to train on the embedding weights rather than the one-hot kmers, need to think about it
# Also need to replace the [MASK] tokens with what the generator guesses, so that it's not entirely random (prevent generator and discriminator talking to each other)

enc_outputs_a, _, _ = transformer(generated_a, training=True)
enc_outputs_b, _, _ = transformer(generated_b, training=True)

Matched = matched_layer()

Rc = rc_layer(512)
DropRc = Dropout(dropout_rate)

Discriminator = discriminator_layer(512)
DropDiscriminator = Dropout(dropout_rate)

CosSim = tf.keras.layers.Dot(axes=-1, normalize=True)
CosSimNoise = tf.keras.layers.GaussianNoise(0.05)

# Position 0 of each encoder output is the CLS token (cls is prepended to every
# window), so the Matched and Rc heads see only the CLS summaries of each window.
# out1 = Matched(DropMatched(tf.concat([enc_outputs_b[:, 0], enc_outputs_a[:, 0]], axis=-1)))
out1 = Matched(CosSimNoise(CosSim([enc_outputs_a[:, 0], enc_outputs_b[:, 0]])))
out2 = Rc(DropRc(tf.concat([out1, enc_outputs_b[:, 0], enc_outputs_a[:, 0]], axis=-1)))

# Per-position discriminator scores for the k-mer positions (CLS excluded).
# Squeeze only the trailing size-1 unit axis: a bare tf.squeeze would also drop the
# batch axis whenever a batch holds a single example (e.g. the last partial batch).
out0a = tf.squeeze(Discriminator(DropDiscriminator(enc_outputs_a[:, 1:])), axis=-1, name="Dis0")
out0b = tf.squeeze(Discriminator(DropDiscriminator(enc_outputs_b[:, 1:])), axis=-1, name="Dis1")

# Decode generator outputs (k-mer positions only) with the frozen reverso layer.
generator1_reversed = reverso_layer(generated_a[:, 1:])  # * mask[:, 0]
generator2_reversed = reverso_layer(generated_b[:, 1:])  # * mask[:, 1]

#gen_loss_a = tf.math.reduce_sum(tf.math.square(generated_a - contexts_a_true), axis=-1)
#gen_loss_b = tf.math.reduce_sum(tf.math.square(generated_b - contexts_b_true), axis=-1)

model = Model(
    inputs=[batch_input, mask, truth],
    outputs=[out0a, out0b, out1, out2, generator1_reversed, generator2_reversed] #, generated_a, generated_b],
)

# Load up the pretrained, frozen Magic and Reverso weights.
# Older Magic weights: "weights/weights_wide_singlelayer_k21_3Aug2020model_21_dims_32_epochs256.npy"
MAGIC_WEIGHTS_PATH = "weights/wide_singlelayer_k21_23Apr2023_linear_nucleotide_model_magic_dims_32_epochs_256.npy"
REVERSO_WEIGHTS_PATH = "weights/wide_singlelayer_k21_23Apr2023_linear_nucleotide_model_reverso_dims_32_epochs_256"

# NOTE(review): allow_pickle=True unpickles object arrays, which can execute
# arbitrary code -- only load weight files from a trusted source.
weights = np.load(MAGIC_WEIGHTS_PATH, allow_pickle=True)

# Element 0 of the saved array becomes Magic's only weight (its kernel; use_bias=False).
magic.set_weights([weights[0]])

reverso_layer.load_weights(REVERSO_WEIGHTS_PATH)

# Define the generators
# The generators below prepend the `cls` row (defined with the model inputs earlier
# in this cell) to every window.


def valid_gen():
    """Yield validation examples in the same structure as ``gen``.

    The output matches ``output_sig``, so this can be wrapped with
    ``tf.data.Dataset.from_generator(valid_gen, output_signature=output_sig)``:

    * inputs: ``(kmers, fakemask, kmers_true)``. ``kmers`` holds the two masked
      windows ``kmers1``/``kmers2``; ``kmers_true`` holds the two unmasked
      windows ``kmers3``/``kmers4``. Every window is prefixed with ``cls``.
    * targets: ``(truth1, truth2, matched, reversecomplement, kmers3, kmers4)``.
      The last two are reshaped to ``(window_size, k, 5)``.

    Windows whose k-mer or truth sequences are not exactly ``window_size``
    long are skipped.
    """
    # NOTE(review): same file and same final argument (42) as gen(). If that
    # argument is an RNG seed, validation may replay the training windows -- confirm.
    fasta = pyracular.TripleLossKmersGenerator(
        k,
        "/mnt/data/nt/nt.sfasta",
        0.15,
        window_size,
        8192,
        2,
        42,
    )
    for i in fasta:
        sequences = (i.kmers1, i.kmers2, i.kmers3, i.kmers4, i.truth1, i.truth2)
        if any(len(seq) != window_size for seq in sequences):
            continue

        kmers = [
            np.concatenate([cls, i.kmers1]).tolist(),
            np.concatenate([cls, i.kmers2]).tolist(),
        ]
        kmers_true = [
            np.concatenate([cls, i.kmers3]).tolist(),
            np.concatenate([cls, i.kmers4]).tolist(),
        ]

        yield (kmers, fakemask, kmers_true), (
            i.truth1,
            i.truth2,
            i.matched,
            i.reversecomplement,
            np.reshape(i.kmers3, (window_size, k, 5)),
            np.reshape(i.kmers4, (window_size, k, 5)),
        )
    print("=================Finished Validation generator=================")


# All-ones placeholder for the model's "Mask" input, one row per window in the pair.
# NOTE(review): the Mask input is not consumed by any layer of the model (the
# generator's mask= arguments are commented out), so its values have no effect.
fakemask = np.ones(shape=(2, window_size), dtype=float)


def gen():
    """Yield training examples for ``model`` from the nt k-mer stream.

    Each example is ``(inputs, targets)`` and matches ``output_sig``:

    * inputs: ``(kmers, fakemask, kmers_true)``. ``kmers`` holds the two masked
      windows ``kmers1``/``kmers2``; ``kmers_true`` holds the two unmasked
      windows ``kmers3``/``kmers4``. Every window is prefixed with the ``cls``
      row, giving ``window_size + 1`` rows each.
    * targets: ``(truth1, truth2, matched, reversecomplement, kmers3, kmers4)``.
      The last two are reshaped to ``(window_size, k, 5)`` as targets for the
      frozen reverso decoder outputs.

    Windows whose sequences are not exactly ``window_size`` long are skipped
    (there is no padding/masking support yet). This also covers ``kmers3`` and
    ``kmers4``, which would otherwise make ``np.reshape`` raise and abort the
    whole tf.data stream.
    """
    fasta = pyracular.TripleLossKmersGenerator(
        k,
        "/mnt/data/nt/nt.sfasta",
        0.15,
        window_size,
        8192,
        4,
        42,
    )
    for i in fasta:
        sequences = (i.kmers1, i.kmers2, i.kmers3, i.kmers4, i.truth1, i.truth2)
        if any(len(seq) != window_size for seq in sequences):
            continue

        kmers = [
            np.concatenate([cls, i.kmers1]).tolist(),
            np.concatenate([cls, i.kmers2]).tolist(),
        ]
        kmers_true = [
            np.concatenate([cls, i.kmers3]).tolist(),
            np.concatenate([cls, i.kmers4]).tolist(),
        ]

        yield (kmers, fakemask, kmers_true), (
            i.truth1,
            i.truth2,
            i.matched,
            i.reversecomplement,
            np.reshape(i.kmers3, (window_size, k, 5)),
            np.reshape(i.kmers4, (window_size, k, 5)),
        )
    print("=================Finished Training generator=================")


# Element specs for the (inputs, targets) tuples yielded by gen(); tf.data converts
# the yielded Python lists/bools to these dtypes.
kmer_pair_spec = tf.TensorSpec(shape=(2, window_size + 1, k * 5), dtype=tf.int16)
mask_spec = tf.TensorSpec(shape=(2, window_size), dtype=tf.int16)
truth_spec = tf.TensorSpec(shape=window_size, dtype=tf.int16)
flag_spec = tf.TensorSpec(shape=(), dtype=tf.int16)
reverso_target_spec = tf.TensorSpec(shape=(window_size, k, 5), dtype=tf.float32)

output_sig = (
    # inputs: masked windows, mask, unmasked windows
    (kmer_pair_spec, mask_spec, kmer_pair_spec),
    # targets: truth1, truth2, matched, reversecomplement, reverso targets a/b
    (
        truth_spec,
        truth_spec,
        flag_spec,
        flag_spec,
        reverso_target_spec,
        reverso_target_spec,
    ),
)

# ds = tf.data.Dataset.from_generator(valid_gen, output_signature=output_sig)
# validation_generator = ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

ds = tf.data.Dataset.from_generator(gen, output_signature=output_sig)
training_generator = (
    ds
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

checkpoint_path = "beaker_medium_nt_triple2023_generator/model_{epoch:04d}.ckpt"

# Resume from the most recent checkpoint in the run directory, if there is one.
latest = tf.train.latest_checkpoint("beaker_medium_nt_triple2023_generator/")
if not latest:
    print("Checkpoint NOT loaded")
else:
    print("Loading checkpoint")
    print(latest)
    # expect_partial() silences warnings about checkpointed values this model does not restore.
    model.load_weights(latest).expect_partial()
    print("Checkpoint loaded")

Loading checkpoint
beaker_medium_nt_triple2023_generator/model_0022.ckpt
Checkpoint loaded


In [5]:
# Draw one (inputs, targets) example from a fresh training stream for inspection.
kmer_generator = gen()
test_data = next(kmer_generator)

In [28]:
# Embed the first masked window of the sample (CLS + kmers1) with the frozen Magic
# projection: (window_size + 1, k * 5) rows -> (window_size + 1, embedding_dims).
ctx_a = magic(np.reshape(test_data[0][0][0], (window_size + 1, k * 5)))

In [46]:
# Run the generator on the single embedded window, adding a leading batch axis of 1.
# NOTE(review): training=True presumably enables the generator's dropout (0.15), so
# repeated calls will differ; use training=False if deterministic inspection is intended.
(generator_output, generator_attention, generator_all_layer_outputs) = generator(np.reshape(ctx_a, (1, window_size+1, embedding_dims)), training=True)

In [47]:
# Generator output for batch item 0 at position 0 (the CLS slot).
generator_output[0][0]

<tf.Tensor: shape=(128,), dtype=float32, numpy=
array([ 2.44001322e-03,  1.44889008e-03,  1.29383616e-03, -1.21712673e-03,
        3.37229460e-03, -5.86072169e-03,  6.66994601e-05, -1.88085437e-03,
        5.69372962e-04, -2.30453094e-03, -6.03464316e-04,  1.64905377e-03,
       -8.55598832e-04, -5.82502689e-05, -3.92143120e-05,  3.05759604e-03,
        1.08458754e-03, -7.56028632e-04, -2.13995762e-03,  2.63250829e-03,
       -2.28333450e-03, -5.74389212e-02,  3.04994499e-03, -9.44274478e-04,
        8.56356397e-02, -1.53848098e-03,  4.69975639e-05,  5.65687835e-04,
       -1.15663686e-04, -2.28852988e-03, -6.87562686e-04, -5.12992847e-04,
       -4.02321981e-04,  1.16367929e-03, -4.45898128e-04,  1.96087733e-03,
        7.23734731e-04,  1.15475792e-03,  2.47768499e-03, -7.60250492e-04,
       -1.23476842e-03, -1.55980468e-01,  9.50815913e-04, -1.54257624e-03,
        3.27342749e-03,  2.10379739e-03,  2.88845808e-03,  1.32316153e-03,
       -1.41607912e-03,  1.04532810e-03, -1.20757357

In [48]:
# Position 1 (first k-mer). In the recorded output it is nearly identical to position 0.
generator_output[0][1]

<tf.Tensor: shape=(128,), dtype=float32, numpy=
array([ 2.49457802e-03,  1.39353215e-03,  1.30054983e-03, -1.14617858e-03,
        3.34637635e-03, -5.84970368e-03,  1.59585034e-05, -1.93903805e-03,
        5.06711542e-04, -2.26109102e-03, -5.56816987e-04,  1.61606190e-03,
       -8.28404271e-04, -4.48737992e-05, -5.61327761e-05,  3.03424033e-03,
        1.03552872e-03, -8.24744289e-04, -2.11091898e-03,  2.62213894e-03,
       -2.29112641e-03, -5.74309081e-02,  2.99665006e-03, -9.72570269e-04,
        8.57128277e-02, -1.48681039e-03,  1.16241863e-05,  5.69468248e-04,
       -8.81489832e-05, -2.25469610e-03, -7.02082063e-04, -5.08196303e-04,
       -4.37965791e-04,  1.10184506e-03, -4.46045655e-04,  2.01555691e-03,
        6.48476125e-04,  1.11689349e-03,  2.47819489e-03, -7.70954997e-04,
       -1.33185578e-03, -1.55852929e-01,  9.17433004e-04, -1.56795653e-03,
        3.30685778e-03,  2.04468891e-03,  2.87029659e-03,  1.31656940e-03,
       -1.46325061e-03,  1.09212927e-03, -1.35131890

In [54]:
# Element [0][0][0] of the generator's third return value (per-layer outputs -- TODO confirm structure).
generator_all_layer_outputs[0][0][0]

<tf.Tensor: shape=(256,), dtype=float32, numpy=
array([-1.32322896e+00,  1.54648006e+00,  3.00923109e-01,  2.52391934e-01,
       -3.61271560e-01,  9.63469207e-01, -1.79152284e-03,  4.56660688e-01,
        7.82799304e-01,  3.85050356e-01, -1.14001071e+00, -6.28199995e-01,
       -1.28034878e+00,  5.46714425e-01, -1.99703217e-01, -3.70031536e-01,
        5.02494097e-01, -7.38070428e-01, -5.92890501e-01, -9.04862210e-03,
       -9.62855071e-02, -2.94313729e-01,  3.81421417e-01,  1.53125036e+00,
        5.87318480e-01,  7.61968195e-01,  3.83370519e-01, -2.12361240e+00,
        9.20791209e-01,  2.38686487e-01, -1.20400560e+00, -7.24162817e-01,
        1.13833308e+00,  1.06043684e+00, -5.61481416e-01, -1.21233284e-01,
       -7.33564615e-01, -7.44567275e-01,  3.46341319e-02,  8.89643490e-01,
       -6.07737780e-01,  9.78815079e-01, -8.91083479e-02, -8.29121768e-01,
       -4.27011877e-01,  1.51709485e+00, -1.74635148e+00,  3.97614211e-01,
       -2.26654336e-01, -1.18676141e-01, -2.27535224

In [55]:
# Same per-layer output at index 15, for comparison with index 0.
generator_all_layer_outputs[0][0][15]

<tf.Tensor: shape=(256,), dtype=float32, numpy=
array([-1.30507898e+00,  1.79238689e+00,  1.90181136e-01, -4.80215810e-02,
       -4.23910707e-01,  8.13567936e-01,  1.74033523e-01,  7.93106318e-01,
        1.18630683e+00, -3.28992568e-02, -5.85105777e-01, -1.84512150e+00,
       -1.04040444e+00,  2.07654357e-01, -1.20440453e-01, -4.57692385e-01,
        4.38072532e-01, -8.13308120e-01, -3.99727076e-01, -5.26226424e-02,
        3.91571671e-01, -4.86130774e-01,  6.08992398e-01,  1.34339035e+00,
        2.72873789e-01,  6.62016690e-01, -1.57131076e-01, -1.68226826e+00,
        3.66032815e+00,  6.85176849e-01, -8.76659036e-01,  2.02806666e-01,
        9.22871888e-01,  7.88081348e-01, -5.09468257e-01, -6.03425577e-02,
       -1.33620739e+00, -7.29181826e-01, -3.38607393e-02,  4.61449921e-01,
       -7.52877891e-01,  7.56183863e-01,  1.60372376e-01, -8.11421692e-01,
       -9.81155932e-01,  9.72118855e-01, -1.59920001e+00,  6.68141484e-01,
       -3.34355012e-02, -8.93137679e-02, -2.58905578

In [59]:
# Inspect the kernel and bias of the generator encoder's dense1 layer.
generator.encoder.dense1.weights

[<tf.Variable 'beaker_3/encoder_3/dense_70/kernel:0' shape=(32, 240) dtype=float32, numpy=
 array([[ 0.05771315,  0.13419619,  0.04603862, ...,  0.01010219,
         -0.12283175,  0.14422467],
        [-0.12149652, -0.04348774, -0.0843328 , ...,  0.03878966,
          0.12934925,  0.11111604],
        [-0.10178564,  0.09992684,  0.11263457, ..., -0.01479641,
          0.10504636,  0.00113237],
        ...,
        [-0.03085393,  0.11351393,  0.02083959, ..., -0.0644629 ,
         -0.00628267, -0.00065875],
        [ 0.09761   ,  0.11956098, -0.05428982, ...,  0.09440955,
          0.11596511, -0.05242577],
        [ 0.04012604, -0.13326737, -0.06339702, ...,  0.02430919,
         -0.06000339,  0.11635048]], dtype=float32)>,
 <tf.Variable 'beaker_3/encoder_3/dense_70/bias:0' shape=(240,) dtype=float32, numpy=
 array([ 4.8744013e-03,  7.1627819e-03, -6.1904499e-03,  7.4617970e-03,
         5.9656464e-03, -6.0227849e-03, -5.7997690e-03, -5.4343157e-03,
        -4.2299074e-03, -6.4181075e-

In [60]:
# Targets of the sampled example: (truth1, truth2, matched, reversecomplement, kmers3, kmers4).
test_data[1]

([True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  True,
  True,
  False,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  True],
 [True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  True,
  False,
  True,
  True,
  False,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True],
 False,
 False,
 array([[[ True, False, False, False, False],
         [ True, False, False, False, False],
         [False, False, False, False,  True],
         ...,
         [False, False, False,  True, False],
         [ True, False, False, False, False],
         [ True, False, False, False, False]],
 
        [[False, False, False,  True, False],
         [ True, False, False, False, False],
         [ True, False, False, False, False],
         ...,
         [Fals