### Imports and global config

In [1]:
%env TF_CPP_MIN_LOG_LEVEL=2

env: TF_CPP_MIN_LOG_LEVEL=2


In [2]:
from datetime import datetime
from itertools import product
import re
import json
from math import log
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.data import Dataset
from tensorflow.keras import backend

In [3]:
def check_and_set_gpu():
    """Enable on-demand VRAM allocation ("memory growth") on every visible GPU.

    Prints "No available GPU!" when TensorFlow reports no GPU. Memory growth can only
    be configured before the GPUs are initialized; TensorFlow raises RuntimeError
    otherwise (e.g. when this cell is re-run later in the session), and that error is
    printed instead of stopping the notebook.
    """
    gpu_list = tf.config.list_physical_devices('GPU')
    if len(gpu_list) == 0:
        print("No available GPU!")
        return
    try:
        # TensorFlow rejects differing memory-growth settings across GPUs when the
        # runtime initializes, so apply the setting to every device, not just the first.
        for gpu in gpu_list:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Enable VRAM growth")
    except RuntimeError as e:
        print(e)

check_and_set_gpu()

Enable VRAM growth


In [4]:
# Reproducibility: seed the Python, NumPy and TensorFlow RNGs in one call and make
# TensorFlow choose deterministic op implementations (this may cost some speed).
SEED = 42
keras.utils.set_random_seed(seed=SEED)
tf.config.experimental.enable_op_determinism()

# Network-debugging helpers, disabled by default.
# tf.debugging.enable_check_numerics()
# tf.debugging.experimental.enable_dump_debug_info(  #  incompatible with enable_op_determinism()
#     f"./logs/tfdbg2_logdir",
#     tensor_debug_mode="FULL_HEALTH",
#     circular_buffer_size=-1
# )

### Load data

In [5]:
DATA_PATH = "./data/airrship_shm_seed42_10_000.tsv"
# read_table is read_csv with the tab separator preset.
# NOTE(review): the file carries an unnamed first column ("Unnamed: 0") that looks like
# a written-out row index; consider index_col=0 if it is not needed as data.
df_data = pd.read_table(DATA_PATH)
df_data.head()

Unnamed: 0,sequence_id,sequence,productive,stop_codon,vj_in_frame,v_call,d_call,j_call,junction,junction_aa,...,d_sequence_start,d_sequence_end,j_sequence_start,j_sequence_end,shm_events,shm_count,shm_freq,unmutated_sequence,gapped_unmutated_sequence,gapped_mutated_sequence
0,0,CAGGTGCAGCTGCGGGAGTCGGGCCCAGGGCTGGTGAAGCCTTTGG...,T,T,F,IGHV4-61*08,IGHD3-3*02,IGHJ4*01,TGCGCGAGGCCGCCAGGTGTATCAGCATTTAGGAGGACACCCGCTT...,CARPPGVSAFRRTPAWDFDPW,...,307,318,338,382,"14:A>G,30:A>G,44:C>T,68:C>A,84:C>T,89:G>C,90:C...",28,0.073298,CAGGTGCAGCTGCAGGAGTCGGGCCCAGGACTGGTGAAGCCTTCGG...,CAGGTGCAGCTGCAGGAGTCGGGCCCA...GGACTGGTGAAGCCTT...,CAGGTGCAGCTGCGGGAGTCGGGCCCA...GGGCTGGTGAAGCCTT...
1,1,CAGGTCACTTTGAGGGAGTCTGGTCCTGCGCTGGTGAAACCCACAC...,T,T,F,IGHV2-70*19,IGHD2-8*02,IGHJ4*01,TGTGCACGGGGGCATGTCCACGATAGGGTCTTTCCGAGAGTTGACT...,CARGHVHDRVFPRVDFW,...,310,313,329,370,"9:C>T,72:C>T,82:T>C,85:C>A,88:A>T,100:A>G,105:...",20,0.054054,CAGGTCACCTTGAGGGAGTCTGGTCCTGCGCTGGTGAAACCCACAC...,CAGGTCACCTTGAGGGAGTCTGGTCCT...GCGCTGGTGAAACCCA...,CAGGTCACTTTGAGGGAGTCTGGTCCT...GCGCTGGTGAAACCCA...
2,2,GAGGTGCAGCTCCCGGAGTCTGGGGGCGGCCTGGTACAGCCTGGGG...,T,T,F,IGHV3-23*03,IGHD5-24*01,IGHJ4*01,TGTGCGAGAGACGGAAAAAAGAGACCCGACTGG,CARDGKKRPDW,...,305,309,314,349,"12:G>C,13:T>C,14:T>C,27:A>C,31:T>C,71:C>T,90:C...",34,0.097421,GAGGTGCAGCTGTTGGAGTCTGGGGGAGGCTTGGTACAGCCTGGGG...,GAGGTGCAGCTGTTGGAGTCTGGGGGA...GGCTTGGTACAGCCTG...,GAGGTGCAGCTCCCGGAGTCTGGGGGC...GGCCTGGTACAGCCTG...
3,3,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGACCAGCCTGGGA...,T,T,F,IGHV3-33*05,IGHD3/OR15-3a*01,IGHJ4*01,TGTGCGAGAGACAAAAATTTGGGACTGGCCGGGAACTTCTTTGACT...,CARDKNLGLAGNFFDYW,...,304,313,326,367,"35:T>A,72:G>T,92:G>C,98:G>A,132:G>A,151:A>T,15...",17,0.046322,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,CAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCGTGGTCCAGCCTG...,CAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCGTGGACCAGCCTG...
4,4,GAGGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTTCAGCCTGGGG...,T,T,F,IGHV3-74*02,IGHD1-26*01,IGHJ4*01,TGTGCAAGACAAGTGGGGGGCAATATCGACCACCTTTCGAAATACT...,CARQVGGNIDHLSKYYW,...,298,301,329,367,"89:G>C,93:C>T,97:T>C,119:C>T,138:G>A,147:A>C,1...",14,0.038147,GAGGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTTCAGCCTGGGG...,GAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCTTAGTTCAGCCTG...,GAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCTTAGTTCAGCCTG...


In [6]:
# The six segment boundaries used as regression targets, ordered V, D, J x (start, end).
pos_names = [f"{seg}_sequence_{pos}" for seg in "vdj" for pos in ("start", "end")]
sequences = df_data["sequence"]
positions = df_data[pos_names]

display(sequences.notna().all())
display(positions.notna().all())
# Enforce the check instead of only displaying it, so a Run-All stops here rather
# than failing (or training on NaN targets) further down the pipeline.
assert sequences.notna().all(), "missing values in 'sequence'"
assert positions.notna().all().all(), f"missing values in {pos_names}"

True

v_sequence_start    True
v_sequence_end      True
d_sequence_start    True
d_sequence_end      True
j_sequence_start    True
j_sequence_end      True
dtype: bool

In [7]:
# tf.data pipeline of (raw DNA string, int64 vector of the six boundary positions) pairs.
ds_seq = Dataset.from_tensor_slices(sequences.to_numpy())
ds_pos = Dataset.from_tensor_slices(positions.to_numpy())
ds_all = Dataset.zip((ds_seq, ds_pos))

# Show the first example and the number of examples.
display(ds_all.take(1).get_single_element(), ds_all.cardinality().numpy())

(<tf.Tensor: shape=(), dtype=string, numpy=b'CAGGTGCAGCTGCGGGAGTCGGGCCCAGGGCTGGTGAAGCCTTTGGAGACCCTGTCCCTCACCTGCAATGTCTCTGGTGGCTCTGTCACTAGTGGTGGTTACTACTGGAGTTGGGTCCGGCTGACCCCAGGGAAGGGACTGGACTGGATTGGTTTTCTTTATTACAGTGGGAGTACCAATTACAACCCCTCCCTCGAGACTCGAGTCACCATATCAGTAGACACGGCCAAGAACCAGTTCTCTCTGAAGGTGAGCTCTGTGACCGCTGCGGACACGGCCGTGTATTACTGCGCGAGGCCGCCAGGTGTATCAGCATTTAGGAGGACACCCGCTTGGGACTTTGACCCCTGGGGCCATGGAACCCTGGTCACCGTCTCCTCAG'>,
 <tf.Tensor: shape=(6,), dtype=int64, numpy=array([  1, 296, 307, 318, 338, 382])>)

10000

### Functions for encoding

In [8]:
def dna_onehot_tensor(seq):
    """One-hot encode a DNA string tensor over the alphabet A, C, G, T.

    Args:
        seq: tf.string tensor holding the nucleotide sequence (a scalar when mapped
            over `ds_all`).

    Returns:
        float32 one-hot tensor with one row of depth 4 per byte of `seq`. Any byte
        other than upper-case A/C/G/T is looked up as -1, which tf.one_hot turns
        into an all-zero row.
        NOTE(review): an all-zero row is also what padded_batch pads with, so unknown
        bases (e.g. N) may be treated like padding by a downstream zero-mask.
    """
    base_to_index = tf.lookup.StaticHashTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(["A", "C", "G", "T"], dtype=tf.string),
            values=tf.constant([0, 1, 2, 3]),
        ),
        default_value=tf.constant(-1),
    )
    base_indices = base_to_index.lookup(tf.strings.bytes_split(seq))
    return tf.one_hot(base_indices, depth=4)

# test dna_onehot_tensor
# test_oh = ds_all.take(6).map(
#     lambda x, y: (dna_onehot_tensor(x), y)
# )
# for padded_batch in test_oh.padded_batch(3):
#     display(padded_batch)

In [9]:
def get_kmer_tensor(seq, k):
    """Turn a DNA string tensor into a space-separated "sentence" of overlapping k-mers.

    Example: "ACGTA" with k=3 -> "ACG CGT GTA".
    """
    kmer_list = tf.strings.ngrams(tf.strings.bytes_split(seq), k, separator="")
    return tf.strings.reduce_join(kmer_list, separator=" ")

# test get_kmer_tensor()
# test_kmer = ds_all.take(3).map(
#     lambda x, y: (get_kmer_tensor(x, 3), y)
# )
# display(test_kmer.batch(3).take(1).get_single_element())

In [10]:
def get_kmer_vocab(k):
    """Return all 4**k DNA k-mers over "ACGT" in lexicographic order (k=2: "AA", "AC", ..., "TT")."""
    return list(map("".join, product("ACGT", repeat=k)))

# test get_kmer_vocab()
# vocab_kmer = get_kmer_vocab(3)
# " ".join(vocab_kmer)

### Prepare datasets

In [11]:
# Sizes of the sequential train / validation / test split (10,000 examples in total).
TRAIN_SIZE, VALID_SIZE, TEST_SIZE = 8_000, 1_000, 1_000

In [12]:
# Sequential (unshuffled) split: the first TRAIN_SIZE examples train, the next
# VALID_SIZE validate, and everything that remains is the test set.
ds_train = ds_all.take(TRAIN_SIZE)
ds_not_train = ds_all.skip(TRAIN_SIZE)
ds_valid, ds_test = ds_not_train.take(VALID_SIZE), ds_not_train.skip(VALID_SIZE)

# ds_test receives the remainder, so checking its size also confirms that the
# three sizes add up to the whole dataset.
assert int(ds_train.cardinality()) == TRAIN_SIZE
assert int(ds_valid.cardinality()) == VALID_SIZE
assert int(ds_test.cardinality()) == TEST_SIZE

In [13]:
def transform_ds(ds, method, scale_factor, batch_size=None, shuffle_buffer=None, shuffle_seed=None):
    """Encode a (sequence, positions) dataset for one of the model variants.

    Args:
        ds: dataset of (DNA string, boundary positions) pairs.
        method: "onehot", or "<k>mer" (e.g. "3mer") for a k-mer sentence. Only the
            first character is parsed as k, so k must be a single digit.
        scale_factor: positions are divided by this to form the regression targets.
        batch_size: if falsy, the encoded dataset is returned unbatched and unshuffled.
        shuffle_buffer: shuffle buffer size used before batching; must be given
            together with shuffle_seed (both or neither).
        shuffle_seed: seed of the shuffle, which is redone every iteration.

    Returns:
        The encoded dataset. One-hot batches are zero-padded to the longest sequence
        in the batch; k-mer batches are plain string vectors.
    """
    is_onehot = method == "onehot"
    assert is_onehot or method.endswith("mer")
    assert (shuffle_buffer is not None) ^ (shuffle_seed is None)

    if is_onehot:
        encode = dna_onehot_tensor
    else:
        k = int(method[0])
        encode = lambda seq: get_kmer_tensor(seq, k)
    ds = ds.map(lambda seq, pos: (encode(seq), pos / scale_factor))

    if not batch_size:
        return ds
    if shuffle_buffer:
        ds = ds.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=True)
    if is_onehot:
        return ds.padded_batch(batch_size)
    return ds.batch(batch_size)
                    
# # test transform_ds()
# for batch_seq, batch_pos in transform_ds(ds_train.take(6), "onehot", 450, 3, 5, 1):
#     display(batch_seq.shape)
#     display(batch_pos.shape)   
# display(batch_seq)
# display(batch_pos)
    
# for batch_seq, batch_pos in transform_ds(ds_train.take(6), "3mer", 450, 3, 5, 1):
#     display(batch_seq.shape)
#     display(batch_pos.shape)
# display(batch_seq)
# display(batch_pos)

# batched = transform_ds(ds_train, "3mer", 450, 32, 128, SEED)
# display(batched.cardinality().numpy())
# prefetched = batched.prefetch(8)
# display(prefetched.cardinality().numpy())

### Functions and classes for building the model

In [14]:
class GRULNCell(keras.layers.GRUCell):
    """GRU cell with layer normalization (LN) on its input and recurrent projections.

    Overrides `GRUCell.call`, keeping its structure for both `implementation` modes and
    both `reset_after` settings, and normalizes each matrix product before any bias is
    added:
      - LN_x_to_gate:      input projections of the update (z) and reset (r) gates, jointly
      - LN_h_to_gate:      recurrent projections of z and r, jointly
      - LN_x_to_recurrent: input projection of the candidate state
      - LN_h_to_recurrent: recurrent projection of the candidate state

    NOTE(review): LN runs before the GRU bias, so each LN's learned offset (beta)
    overlaps with that bias; harmless, but redundant.
    """
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        # One LN per projection group; the z and r projections share a single LN.
        self.LN_x_to_gate = keras.layers.LayerNormalization(name="LN_x_to_gate")
        self.LN_h_to_gate = keras.layers.LayerNormalization(name="LN_h_to_gate")
        self.LN_x_to_recurrent = keras.layers.LayerNormalization(name="LN_x_to_recurrent")
        self.LN_h_to_recurrent = keras.layers.LayerNormalization(name="LN_h_to_recurrent")

    def call(self, inputs, states, training=None):
        """Run one time step.

        Args:
            inputs: input at the current time step.
            states: previous hidden state, either a tensor or a nested structure whose
                first element is the state.
            training: forwarded to the dropout-mask helpers.

        Returns:
            (h, new_state): the new hidden state, and the same value wrapped in a list
            when `states` was nested.
        """
        h_tm1 = states[0] if tf.nest.is_nested(states) else states  # previous memory

        # Three dropout masks each (for z, r and the candidate) for inputs and state.
        dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
            h_tm1, training, count=3
        )

        if self.use_bias:
            if not self.reset_after:
                input_bias, recurrent_bias = self.bias, None
            else:
                # with reset_after, self.bias stacks the input and recurrent biases
                input_bias, recurrent_bias = tf.unstack(self.bias)

        if self.implementation == 1:
            # Implementation 1: a separate matrix product per gate.
            if 0.0 < self.dropout < 1.0:
                inputs_z = inputs * dp_mask[0]
                inputs_r = inputs * dp_mask[1]
                inputs_h = inputs * dp_mask[2]
            else:
                inputs_z = inputs
                inputs_r = inputs
                inputs_h = inputs

            x_z = backend.dot(inputs_z, self.kernel[:, : self.units])
            x_r = backend.dot(inputs_r, self.kernel[:, self.units : self.units * 2])
            x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2 :])

            # layer normalization: x to gate, x to recurrent
            # (z and r are concatenated so they are normalized together, then split back)
            x_concat = backend.concatenate([x_z, x_r])
            x_concat = self.LN_x_to_gate(x_concat)
            x_z = x_concat[:, :self.units]
            x_r = x_concat[:, self.units:]
            x_h = self.LN_x_to_recurrent(x_h)

            if self.use_bias:
                x_z = backend.bias_add(x_z, input_bias[: self.units])
                x_r = backend.bias_add(x_r, input_bias[self.units : self.units * 2])
                x_h = backend.bias_add(x_h, input_bias[self.units * 2 :])

            if 0.0 < self.recurrent_dropout < 1.0:
                h_tm1_z = h_tm1 * rec_dp_mask[0]
                h_tm1_r = h_tm1 * rec_dp_mask[1]
                h_tm1_h = h_tm1 * rec_dp_mask[2]
            else:
                h_tm1_z = h_tm1
                h_tm1_r = h_tm1
                h_tm1_h = h_tm1

            recurrent_z = backend.dot(
                h_tm1_z, self.recurrent_kernel[:, : self.units]
            )
            recurrent_r = backend.dot(
                h_tm1_r, self.recurrent_kernel[:, self.units : self.units * 2]
            )

            # layer normalization: h to gate
            recurrent_concat = backend.concatenate([recurrent_z, recurrent_r])
            recurrent_concat = self.LN_h_to_gate(recurrent_concat)
            recurrent_z = recurrent_concat[:, :self.units]
            recurrent_r = recurrent_concat[:, self.units:]

            if self.reset_after and self.use_bias:
                recurrent_z = backend.bias_add(
                    recurrent_z, recurrent_bias[: self.units]
                )
                recurrent_r = backend.bias_add(
                    recurrent_r, recurrent_bias[self.units : self.units * 2]
                )

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            # reset gate applied after/before matrix multiplication
            if self.reset_after:
                recurrent_h = backend.dot(
                    h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
                )
                # layer normalization: h to recurrent
                recurrent_h = self.LN_h_to_recurrent(recurrent_h)
                if self.use_bias:
                    recurrent_h = backend.bias_add(
                        recurrent_h, recurrent_bias[self.units * 2 :]
                    )
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = backend.dot(
                    r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
                )
                # layer normalization: h to recurrent
                recurrent_h = self.LN_h_to_recurrent(recurrent_h)

            hh = self.activation(x_h + recurrent_h)
        else:
            # Implementation 2: one fused matrix product for all gates, split afterwards.
            # Only the first input dropout mask is used here.
            if 0.0 < self.dropout < 1.0:
                inputs = inputs * dp_mask[0]

            # inputs projected by all gate matrices at once
            matrix_x = backend.dot(inputs, self.kernel)

            # layer normalization: x to gate, x to recurrent
            # (the z|r part and the candidate part are normalized separately)
            x_concat = matrix_x[:, : self.units * 2]
            x_h = matrix_x[:, self.units * 2: ]
            x_concat = self.LN_x_to_gate(x_concat)
            x_h = self.LN_x_to_recurrent(x_h)
            matrix_x = backend.concatenate([x_concat, x_h])

            if self.use_bias:
                # biases: bias_z_i, bias_r_i, bias_h_i
                matrix_x = backend.bias_add(matrix_x, input_bias)

            x_z, x_r, x_h = tf.split(matrix_x, 3, axis=-1)

            if self.reset_after:
                # hidden state projected by all gate matrices at once
                matrix_inner = backend.dot(h_tm1, self.recurrent_kernel)

                # layer normalization: h to gate, h to recurrent
                recurrent_concat = matrix_inner[:, : self.units * 2]
                recurrent_h = matrix_inner[:, self.units * 2: ]
                recurrent_concat = self.LN_h_to_gate(recurrent_concat)
                recurrent_h = self.LN_h_to_recurrent(recurrent_h)
                matrix_inner = backend.concatenate([recurrent_concat, recurrent_h])

                if self.use_bias:
                    matrix_inner = backend.bias_add(
                        matrix_inner, recurrent_bias
                    )
            else:
                # hidden state projected separately for update/reset and new
                matrix_inner = backend.dot(
                    h_tm1, self.recurrent_kernel[:, : 2 * self.units]
                )
                # layer normalization: h to gate
                matrix_inner = self.LN_h_to_gate(matrix_inner)

            # Without reset_after, matrix_inner is only 2 * units wide, so recurrent_h
            # is empty here and is recomputed from r * h_tm1 below.
            recurrent_z, recurrent_r, recurrent_h = tf.split(
                matrix_inner, [self.units, self.units, -1], axis=-1
            )

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            if self.reset_after:
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = backend.dot(
                    r * h_tm1, self.recurrent_kernel[:, 2 * self.units :]
                )
                # layer normalization: h to recurrent
                recurrent_h = self.LN_h_to_recurrent(recurrent_h)
            hh = self.activation(x_h + recurrent_h)

        # previous and candidate state mixed by update gate
        h = z * h_tm1 + (1 - z) * hh
        new_state = [h] if tf.nest.is_nested(states) else h
        return h, new_state

In [15]:
class ModelWriteGrad(keras.Model):
    """keras.Model whose train_step can log per-variable gradient statistics to TensorBoard.

    Call `set_model(...)` after `compile()` (it reads the optimizer config) and before
    `fit()`. Besides logging, train_step can:
      - replace NaN / Inf gradient entries with finite values;
      - print a message when gradients contain NaN/Inf or exceed the optimizer's clipping
        thresholds.

    NOTE(review): `self.step`, `self.m` and `self.v` are plain Python state updated inside
    train_step. When Keras traces train_step into a graph (the default), such Python
    side effects presumably happen only at trace time, so per-batch logging likely
    requires compiling with run_eagerly=True. Confirm against the compile settings.
    """

    def set_model(
        self, logdir, write_grad,
        scalars=(), scalar_freq=1, histogram_freq=1, update_freq="batch", batch_per_epoch=None,
        inspect_adam=False, inspect_clip=False, replace_nan=None, replace_inf=None
    ):
        """Configure gradient logging and inspection; must be called after compile().

        Args:
            logdir: TensorBoard log directory; summaries go to `logdir + "/gradients"`.
            write_grad: enable summary writing and the NaN/Inf/clip/Adam inspections.
            scalars: subset of {"abs_max", "abs_min", "norm"} logged as scalars per variable.
            scalar_freq, histogram_freq: logging period (0 disables), in batches, or in
                epochs when update_freq == "epoch".
            update_freq: "batch" or "epoch"; "epoch" requires batch_per_epoch.
            batch_per_epoch: batches per epoch, used to convert epoch periods to batches.
            inspect_adam: also log a re-derived Adam state (m, sqrt(v), update) when the
                optimizer is Adam.
            inspect_clip: print messages for NaN/Inf gradients and for gradients above
                the optimizer's clipvalue / clipnorm.
            replace_nan: value substituted for NaN gradient entries; None/False disables.
                0 is a valid value.
            replace_inf: magnitude substituted for +/-Inf gradient entries (sign kept);
                None/False disables. 0 is a valid value.
        """
        # Functions for compute scalars of gradients
        self.scalar_funs = {
            "abs_max": lambda x: tf.reduce_max(tf.abs(x)),
            "abs_min": lambda x: tf.reduce_min(tf.abs(x)),
            "norm": lambda x: tf.norm(x)
        }

        # check arguments
        assert set(scalars).issubset(set(self.scalar_funs.keys()))
        assert update_freq in ("batch", "epoch")
        assert not (update_freq == "epoch" and batch_per_epoch is None)

        # initialize
        self.step = 0
        self.writer = tf.summary.create_file_writer(logdir + "/gradients") if write_grad else None
        self.write_grad = write_grad
        self.scalars = list(scalars)
        self.scalar_freq = scalar_freq if update_freq == "batch" else scalar_freq * batch_per_epoch
        self.histogram_freq = histogram_freq if update_freq == "batch" else histogram_freq * batch_per_epoch
        self.update_freq = update_freq
        self.batch_per_epoch = batch_per_epoch
        self.inspect_clip = inspect_clip
        self.replace_nan = self._replacement_constant(replace_nan)
        self.replace_inf = self._replacement_constant(replace_inf)

        opt_param = self.optimizer.get_config()
        self.inspect_adam = (inspect_adam and opt_param["name"] == "Adam")
        if self.inspect_adam:
            # Running Adam moments, mirrored here only for logging.
            var_num = len(self.trainable_variables)
            self.m = [0 for _ in range(var_num)]
            self.v = [0 for _ in range(var_num)]
            self.lr = opt_param["learning_rate"]
            self.beta_1 = opt_param["beta_1"]
            self.beta_2 = opt_param["beta_2"]
            self.epsilon = opt_param["epsilon"]
        if self.inspect_clip:
            self.clipvalue = opt_param.get("clipvalue")
            self.clipnorm = opt_param.get("clipnorm")

    @staticmethod
    def _replacement_constant(value):
        """Return a float32 constant for `value`, or None when replacement is disabled.

        Uses identity checks instead of truthiness so that 0 / 0.0 enables replacement;
        None and False keep meaning "disabled".
        """
        if value is None or value is False:
            return None
        return tf.constant(value, dtype=tf.float32)

    @staticmethod
    def _summary_name(var):
        """TensorBoard-friendly name of `var`: ':' replaced, '/gru_cell_<n>' scope dropped."""
        var_name = var.name.replace(':', '_')
        return re.sub(r"/gru_cell_\d+", "", var_name)

    def write_value(self, values, value_name, trainable_vars):
        """Write summaries of `values`, one per trainable variable, then flush.

        Histograms are written every `histogram_freq` steps and the configured scalars
        every `scalar_freq` steps (0 disables either).
        """
        with self.writer.as_default():
            for val, var in zip(values, trainable_vars):
                if isinstance(val, tf.IndexedSlices):
                    val = tf.convert_to_tensor(val)
                var_name = self._summary_name(var)
                if self.histogram_freq != 0 and self.step % self.histogram_freq == 0:
                    tf.summary.histogram(f"{var_name}/{value_name}", val, step=self.step)
                if self.scalars and self.scalar_freq != 0 and self.step % self.scalar_freq == 0:
                    for scalar in self.scalars:
                        sum_name = f"{var_name}/{value_name}/{scalar}"
                        sum_val = self.scalar_funs[scalar](val)
                        tf.summary.scalar(sum_name, sum_val, step=self.step)
            self.writer.flush()

    def train_step(self, data):
        """Standard fit() step plus optional NaN/Inf handling, inspection and gradient logging."""

        # Compute gradients
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Check and process NaN and Inf
        for i in range(len(gradients)):
            grad = gradients[i]
            var_name = self._summary_name(trainable_vars[i])

            nan_bool = tf.math.is_nan(grad)
            if self.write_grad and self.inspect_clip:
                show_nan = f"Batch {self.step}: Nan in gradients of {var_name}"
                tf.cond(tf.reduce_any(nan_bool), lambda: tf.print(show_nan), tf.no_op)
            if self.replace_nan is not None:
                if isinstance(grad, tf.IndexedSlices):
                    grad = tf.convert_to_tensor(grad)
                grad = tf.where(nan_bool, self.replace_nan, grad)
                gradients[i] = grad

            inf_bool = tf.math.is_inf(grad)
            if self.write_grad and self.inspect_clip:
                show_inf = f"Batch {self.step}: Inf in gradients of {var_name}"
                tf.cond(tf.reduce_any(inf_bool), lambda: tf.print(show_inf), tf.no_op)
            if self.replace_inf is not None:
                if isinstance(grad, tf.IndexedSlices):
                    grad = tf.convert_to_tensor(grad)
                # keep the sign of the infinity
                pos_bool = inf_bool & (grad > 0)
                neg_bool = inf_bool & (grad < 0)
                grad = tf.where(pos_bool, self.replace_inf, grad)
                grad = tf.where(neg_bool, -1 * self.replace_inf, grad)
                gradients[i] = grad

            if self.write_grad and self.inspect_clip and self.clipvalue:
                show_clip = f"Batch {self.step}: clip value for gradients of {var_name}"
                clip_bool = tf.reduce_any(tf.abs(grad) > self.clipvalue)
                tf.cond(clip_bool, lambda: tf.print(show_clip), tf.no_op)
            if self.write_grad and self.inspect_clip and self.clipnorm:
                show_clip = f"Batch {self.step}: clip norm for gradients of {var_name}"
                clip_bool = tf.reduce_any(tf.norm(grad) > self.clipnorm)
                tf.cond(clip_bool, lambda: tf.print(show_clip), tf.no_op)

        # Compute adam values (bias-corrected step size, moments and resulting update)
        if self.write_grad and self.inspect_adam:
            t = self.step + 1
            lr = self.lr * tf.sqrt(1 - tf.pow(self.beta_2, t)) / (1 - tf.pow(self.beta_1, t))
            var_num = len(self.trainable_variables)
            root_v = [0 for i in range(var_num)]
            weight_updates = [0 for i in range(var_num)]
            for i in range(var_num):
                grad = gradients[i]
                if isinstance(grad, tf.IndexedSlices):
                    grad = tf.convert_to_tensor(grad)
                self.m[i] = self.m[i] * self.beta_1 + (1 - self.beta_1) * grad
                self.v[i] = self.v[i] * self.beta_2 + (1 - self.beta_2) * tf.square(grad)
                root_v[i] = tf.sqrt(self.v[i])
                weight_updates[i] = lr * self.m[i] / (root_v[i] + self.epsilon)

        # Record gradients
        if self.write_grad:
            self.write_value(gradients, "grads", trainable_vars)
            if self.inspect_adam:
                self.write_value(self.m, "m", trainable_vars)
                self.write_value(root_v, "root_v", trainable_vars)
                self.write_value(weight_updates, "updates", trainable_vars)

        # Update step, weights and metrics
        self.step += 1
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

In [16]:
# build the encoders for kmer
def create_kmer_encoder(k, embed_dim):
    """Build a sub-model that maps a k-mer "sentence" string to a sequence of embeddings.

    TextVectorization tokenizes on whitespace against the fixed vocabulary of all 4**k
    k-mers. The Embedding uses mask_zero=True, so token index 0 (padding) is masked
    for the downstream layers.

    Args:
        k: k-mer length.
        embed_dim: embedding size.

    Returns:
        keras Model named "<k>mer_encoder" taking a (batch, 1) string input.
    """
    vectorize = keras.layers.TextVectorization(
        standardize=None,
        split="whitespace",
        vocabulary=get_kmer_vocab(k),
        name="text_vectorize",
    )
    embed = keras.layers.Embedding(
        input_dim=vectorize.vocabulary_size(),
        output_dim=embed_dim,
        mask_zero=True,
        name="embedding",
    )

    kmer_seq = keras.Input(shape=(1,), dtype=tf.string, name=f"{k}mer_seq")
    embeddings = embed(vectorize(kmer_seq))
    return keras.models.Model(kmer_seq, embeddings, name=f"{k}mer_encoder")
    
# test_encoder = create_kmer_encoder(5, 64)
# test_encoder.summary()

In [17]:
# build the complete models
def _per_layer(value, i, scalar_types):
    """Return `value` when it is one setting shared by all layers, else its i-th entry."""
    return value if isinstance(value, scalar_types) else value[i]


def create_model(
    encode_name, *,
    EMBED_DIM, GRU_NUM, GRU_UNIT, GRU_BIDIRECT, GRU_LN,
    DENSE_NUM, DENSE_UNIT, LEAKY_ALPHA, DENSE_NORM, OUT_ACTI
):
    """Build the boundary-regression model for one input encoding.

    Args:
        encode_name: "onehot" or "<k>mer" (k parsed from the first character).
        EMBED_DIM: k-mer embedding size (unused for "onehot").
        GRU_NUM: number of stacked GRU layers; all but the last return sequences.
        GRU_UNIT, GRU_BIDIRECT: one value for all GRU layers, or a per-layer sequence.
        GRU_LN: "inside" (GRULNCell), "outside" (LayerNormalization after the GRU) or
            False; one value or a per-layer sequence.
        DENSE_NUM: number of hidden Dense + LeakyReLU blocks.
        DENSE_UNIT: units per hidden Dense layer; one value or a per-layer sequence.
        LEAKY_ALPHA: negative slope of the LeakyReLU activations.
        DENSE_NORM: "BN", "LN" or False; one value or a per-layer sequence.
        OUT_ACTI: activation of the 6-unit output layer.

    Returns:
        A ModelWriteGrad named "<encode_name>_model" predicting the six boundary positions.
    """
    # --- input and encoding ---
    if encode_name == "onehot":
        model_input = keras.Input(shape=(None, 4), name="input")
        # all-zero time steps (padding from padded_batch) are masked out
        x = keras.layers.Masking(mask_value=[0.0, 0.0, 0.0, 0.0], name="mask")(model_input)
    else:
        k = int(encode_name[0])
        model_input = keras.Input(shape=(1, ), dtype=tf.string, name="input")
        x = create_kmer_encoder(k, EMBED_DIM)(model_input)

    # --- recurrent stack ---
    for i in range(GRU_NUM):
        gru_unit = _per_layer(GRU_UNIT, i, int)
        gru_bidirect = _per_layer(GRU_BIDIRECT, i, bool)
        gru_ln = _per_layer(GRU_LN, i, (bool, str))
        assert gru_ln in ("inside", "outside", False)
        # only the last GRU collapses the sequence into one vector
        return_seq = i < GRU_NUM - 1
        if gru_ln == "inside":
            gru_layer = keras.layers.RNN(GRULNCell(gru_unit), return_sequences=return_seq, name=f"gru_ln_{i+1}")
        else:
            gru_layer = keras.layers.GRU(gru_unit, return_sequences=return_seq, name=f"gru_{i+1}")
        if gru_bidirect:
            gru_layer = keras.layers.Bidirectional(gru_layer, name=f"gru_{i+1}")
        x = gru_layer(x)
        if gru_ln == "outside":
            x = keras.layers.LayerNormalization(name=f"gru_ln_{i+1}")(x)

    # --- dense head ---
    for i in range(DENSE_NUM):
        dense_unit = _per_layer(DENSE_UNIT, i, int)
        dense_norm = _per_layer(DENSE_NORM, i, (bool, str))
        assert dense_norm in ("BN", "LN", False)
        # BatchNormalization supplies its own shift, so the Dense bias is dropped for "BN"
        x = keras.layers.Dense(dense_unit, use_bias=(dense_norm != "BN"), activation=None, name=f"dense_{i+1}")(x)
        if dense_norm == "BN":
            x = keras.layers.BatchNormalization(name=f"dense_bn_{i+1}")(x)
        elif dense_norm == "LN":
            x = keras.layers.LayerNormalization(name=f"dense_ln_{i+1}")(x)
        x = keras.layers.LeakyReLU(alpha=LEAKY_ALPHA, name=f"dense_leaky_relu_{i+1}")(x)

    # --- output ---
    model_output = keras.layers.Dense(6, activation=OUT_ACTI, name="output")(x)
    return ModelWriteGrad(model_input, model_output, name=f"{encode_name}_model")

### Build model, compile and train

In [18]:
def pretty_json(hp):
    """Serialize `hp` as 2-space-indented JSON with every line prefixed by a tab.

    The leading tab presumably makes TensorBoard's markdown text plugin render the
    JSON as a preformatted code block.
    """
    lines = json.dumps(hp, indent=2).splitlines(keepends=True)
    return "".join(f"\t{line}" for line in lines)

In [19]:
# Custom metrics
class ScaledRMSE(keras.metrics.RootMeanSquaredError):
    """RMSE measured in original position units rather than on the scaled targets.

    Both tensors are multiplied back by `scale_factor`. The targets are also rounded to
    the nearest integer, presumably to undo float error from dividing the positions by
    scale_factor during dataset encoding.
    """
    def __init__(self, scale_factor, name="s_rmse", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        true_pos = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        pred_pos = tf.cast(y_pred, self._dtype) * self.scale_factor
        super().update_state(true_pos, pred_pos, sample_weight)
        
class ScaledMAE(keras.metrics.MeanAbsoluteError):
    """MAE measured in original position units rather than on the scaled targets.

    Both tensors are multiplied back by `scale_factor`; the targets are additionally
    rounded to the nearest integer.
    """
    def __init__(self, scale_factor, name="s_mae", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        true_pos = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        pred_pos = tf.cast(y_pred, self._dtype) * self.scale_factor
        super().update_state(true_pos, pred_pos, sample_weight)
        
class PosAccuracy(keras.metrics.Accuracy):
    """Exact-match accuracy over individual boundary positions, in original units.

    Targets and predictions are rescaled by `scale_factor` and rounded to the nearest
    integer; every position counts as one hit or miss.
    """
    def __init__(self, scale_factor, name="p_acc", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        def to_positions(t):
            return tf.math.rint(tf.cast(t, self._dtype) * self.scale_factor)

        super().update_state(to_positions(y_true), to_positions(y_pred), sample_weight)
        
class AllAccuracy(keras.metrics.Accuracy):
    """Fraction of examples whose six boundary positions are all predicted exactly.

    Targets and predictions are rescaled by `scale_factor` and rounded; an example is
    counted correct only if every position matches.
    """
    def __init__(self, scale_factor, name="a_acc", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        y_pred = tf.math.rint(tf.cast(y_pred, self._dtype) * self.scale_factor)
        # 1 where every position of an example matches, else 0; shape (batch, 1).
        all_correct = tf.reshape(
            tf.cast(tf.reduce_all(tf.equal(y_true, y_pred), axis=-1), self._dtype), (-1, 1)
        )
        # Accuracy against an all-ones target of the same shape equals the fraction
        # of fully correct examples. (tf.ones(n, 1) would treat `1` as a dtype.)
        super().update_state(tf.ones_like(all_correct), all_correct, sample_weight)
        
class ClassAccuracy(keras.metrics.Metric):
    """Approximate per-nucleotide segment-assignment accuracy.

    For each batch, the summed absolute boundary error over all six positions (in
    position units) is treated as the number of misassigned nucleotides. It is
    subtracted from the summed value of column 5, which is j_sequence_end in the
    target order and presumably serves as a proxy for sequence length. The result is
    correct / total over all batches seen since the last reset.
    `sample_weight` is accepted for API compatibility but ignored.
    """
    def __init__(self, scale_factor, name="c_acc", **kwargs):
        super().__init__(name=name, **kwargs)
        self.scale_factor = scale_factor
        self.len_sum = self.add_weight(name="len_sum", initializer="zeros")
        self.correct_sum = self.add_weight(name="correct_sum", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        y_pred = tf.math.rint(tf.cast(y_pred, self._dtype) * self.scale_factor)
        incorrect_count = tf.reduce_sum(tf.abs(y_true - y_pred))
        batch_len_sum = tf.reduce_sum(tf.gather(y_true, indices=5, axis=-1))
        self.len_sum.assign_add(batch_len_sum)
        self.correct_sum.assign_add(batch_len_sum - incorrect_count)

    def result(self):
        # divide_no_nan: report 0 instead of NaN before any data has been accumulated.
        return tf.math.divide_no_nan(self.correct_sum, self.len_sum)

    def reset_state(self):
        self.len_sum.assign(0.0)
        self.correct_sum.assign(0.0)
        
class SegmentPosAccuracy(PosAccuracy):
    """PosAccuracy restricted to the start and end columns of one segment.

    `segment` is "V", "D" or "J"; the metric is named "<segment>_<name>", e.g. "V_p_acc".
    """
    # start/end column indices of each segment in the six-position target vector
    _SEGMENT_COLUMNS = {"V": (0, 1), "D": (2, 3), "J": (4, 5)}

    def __init__(self, scale_factor, segment, name="p_acc", dtype=None):
        assert segment in ("V", "D", "J")
        super().__init__(scale_factor, name=f"{segment}_{name}", dtype=dtype)
        self.segment = segment

    def update_state(self, y_true, y_pred, sample_weight=None):
        start_end = self._SEGMENT_COLUMNS[self.segment]
        super().update_state(
            tf.gather(y_true, indices=start_end, axis=-1),
            tf.gather(y_pred, indices=start_end, axis=-1),
            sample_weight,
        )

In [20]:
class LinearWarmup(keras.optimizers.schedules.LearningRateSchedule):
    """Linear learning-rate warmup followed by another schedule or a constant rate.

    While step < warmup_steps, the rate rises linearly from `warmup_learning_rate`. It
    rises to the value the follow-up schedule has at step
    `warmup_steps - after_warmup_offset`. From warmup_steps on, the rate is
    `after_warmup_lr_sched(step - after_warmup_offset)`, so the two pieces join
    continuously.

    Args:
        after_warmup_lr_sched: a LearningRateSchedule, or a constant learning rate.
        warmup_steps: number of warmup steps.
        warmup_learning_rate: learning rate at step 0.
        after_warmup_offset: int <= warmup_steps subtracted from the step before it is
            passed to the follow-up schedule. Defaults to warmup_steps, i.e. the
            follow-up schedule starts at its own step 0 when warmup ends.
        name: optional name, only stored in the config.

    Raises:
        TypeError: after_warmup_offset is neither an int nor None.
        ValueError: after_warmup_offset is larger than warmup_steps.
    """

    def __init__(
        self,
        after_warmup_lr_sched,
        warmup_steps,
        warmup_learning_rate,
        after_warmup_offset=None,
        name=None
    ):
        super().__init__()
        self._name = name
        self._after_warmup_lr_sched = after_warmup_lr_sched
        if after_warmup_offset is None:
            self._after_warmup_offset = warmup_steps
        elif not isinstance(after_warmup_offset, int):
            raise TypeError("after_warmup_offset must be int or None!")
        elif after_warmup_offset > warmup_steps:
            # Previously reported as a misleading "must be int or None" TypeError.
            raise ValueError(
                f"after_warmup_offset ({after_warmup_offset}) must not exceed "
                f"warmup_steps ({warmup_steps})!"
            )
        else:
            self._after_warmup_offset = after_warmup_offset
        self._warmup_steps = warmup_steps
        self._init_warmup_lr = warmup_learning_rate
        # Rate reached at the end of warmup = follow-up schedule at the matching step.
        if isinstance(after_warmup_lr_sched, keras.optimizers.schedules.LearningRateSchedule):
            self._final_warmup_lr = after_warmup_lr_sched(self._warmup_steps - self._after_warmup_offset)
        else:
            self._final_warmup_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)

    def __call__(self, step):
        """Return the learning rate for optimizer step `step`."""
        global_step = tf.cast(step, dtype=tf.float32)

        linear_warmup_lr = (
            self._init_warmup_lr + global_step / self._warmup_steps *
            (self._final_warmup_lr - self._init_warmup_lr)
        )

        if isinstance(self._after_warmup_lr_sched, keras.optimizers.schedules.LearningRateSchedule):
            after_warmup_lr = self._after_warmup_lr_sched(step - self._after_warmup_offset)
        else:
            after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

        lr = tf.where(global_step < self._warmup_steps, linear_warmup_lr, after_warmup_lr)
        return lr

    def get_config(self):
        """Return the constructor arguments as a dict.

        NOTE(review): a nested schedule is stored as its bare config, without its class
        name, so the inherited from_config presumably cannot rebuild it.
        """
        if isinstance(self._after_warmup_lr_sched, keras.optimizers.schedules.LearningRateSchedule):
            config = {
                "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()
            }
        else:
            config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}

        config.update({
            "after_warmup_offset": self._after_warmup_offset,
            "warmup_steps": self._warmup_steps,
            "warmup_learning_rate": self._init_warmup_lr,
            "name": self._name
        })
        return config
    
# test_schedule = keras.optimizers.schedules.CosineDecay(
#     1e-2, decay_steps=1000 - 100, alpha=1e-6, name="cos_decay"
# )
# test_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
#     1e-2,
#     batch_per_epoch,
#     t_mul=1.0,
#     m_mul=0.95,
#     alpha=1e-6,
#     name="cos_decay_restart"
# )
# test_schedule = LinearWarmup(test_schedule, after_warmup_offset=20, warmup_steps=100, warmup_learning_rate=1e-6)
# plt.plot(test_schedule(tf.range(1000)))
# plt.show()

In [21]:
# A callback class that stops the training when there is a sudden increase in one of the metrics
class SpikeStopping(keras.callbacks.Callback):
    """Stop training when a monitored epoch metric spikes relative to the previous epoch.

    After the first `skip_first` epochs, training stops at the end of any epoch
    whose value of `monitor` exceeds `previous_value * fold_allow + min_allow`.

    Args:
        monitor: key of the epoch-end `logs` dict to watch (e.g. "loss").
        fold_allow: multiplicative tolerance relative to the previous epoch's value.
        min_allow: additive tolerance added on top of the multiplicative one.
        skip_first: number of initial epochs during which spikes are ignored.
    """

    def __init__(self, monitor, fold_allow=2, min_allow=5, skip_first=10):
        super().__init__()
        self.monitor = monitor
        self.fold_allow = fold_allow
        self.min_allow = min_allow
        self.skip_first = skip_first
        self.last = None

    def on_train_begin(self, logs=None):
        # Reset so a reused callback never compares against a previous run.
        self.last = None

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            # Metric not reported this epoch: nothing to compare, keep the last value.
            return
        spiked = (
            self.last is not None
            and epoch >= self.skip_first
            and current > self.last * self.fold_allow + self.min_allow
        )
        if spiked:
            print(f"Stop training at epoch {epoch}")
            self.model.stop_training = True
        self.last = current

In [22]:
# A callback class that records the training information
class WriteInfo(keras.callbacks.Callback):
    """Log the run configuration to TensorBoard as a text summary when training starts.

    Args:
        logdir: base log directory; the summary goes to `<logdir>/info`.
        info: dict of training settings. It is mutated in place: a "MODEL_INFO"
            entry (model name and parameter count) is added at train start.
    """

    def __init__(self, logdir, info):
        super().__init__()
        self.writer = tf.summary.create_file_writer(logdir + "/info")
        self.info = info

    def on_train_begin(self, logs=None):
        self.info["MODEL_INFO"] = dict(
            MODEL_NAME=self.model.name,
            MODEL_PARAM_NUM=self.model.count_params(),
        )
        serialized = pretty_json(self.info)
        with self.writer.as_default():
            tf.summary.text("training_setting", serialized, step=0)

In [23]:
# A callback class that records the training time
class WriteTrainingTime(keras.callbacks.Callback):
    """Measure wall-clock training time and log it to TensorBoard when training ends.

    The duration is formatted as "HH:MM:SS.ffffff". Hours keep counting past 24,
    so long runs are reported correctly. It is written as the text summary
    "training_time" under `<logdir>/info`, at the step of the last epoch begun.

    Args:
        logdir: base log directory.
        print_time: also print the duration to stdout.
    """

    def __init__(self, logdir, print_time=True):
        super().__init__()
        self.writer = tf.summary.create_file_writer(logdir + "/info")
        self.print_time = print_time
        self.epoch = 0

    @staticmethod
    def _format_duration(duration):
        """Format a non-negative timedelta as "HH:MM:SS.ffffff" without wrapping hours at 24."""
        total_seconds = duration.days * 86400 + duration.seconds
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{duration.microseconds:06d}"

    def on_train_begin(self, logs=None):
        self.start = datetime.now()

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch

    def on_train_end(self, logs=None):
        # Format the timedelta directly: shifting it onto datetime.min and using
        # strftime("%H") would silently drop whole days for runs longer than 24 h.
        time_str = self._format_duration(datetime.now() - self.start)
        if self.print_time:
            print(f"Time usage: {time_str}")
        with self.writer.as_default():
            tf.summary.text("training_time", time_str, step=self.epoch)

In [24]:
class WriteLearningRate(keras.callbacks.Callback):
    """Log the optimizer's learning rate to TensorBoard before every training batch.

    The scalar "learning_rate" goes to `<logdir>/lr`, stepped by the optimizer's
    iteration counter.
    """

    def __init__(self, logdir):
        super().__init__()
        self.writer = tf.summary.create_file_writer(logdir + "/lr")

    def on_batch_begin(self, batch, logs=None):
        optimizer = self.model.optimizer
        current_lr = keras.backend.eval(optimizer.learning_rate)
        global_step = keras.backend.eval(optimizer.iterations)
        # Depending on the optimizer implementation, `learning_rate` may still be
        # the schedule object itself; in that case evaluate it at the current step.
        is_schedule = isinstance(current_lr, keras.optimizers.schedules.LearningRateSchedule)
        value = current_lr(global_step) if is_schedule else current_lr
        with self.writer.as_default():
            tf.summary.scalar("learning_rate", value, step=global_step)

In [34]:
# Candidate architectures, one row per setting (settings are numbered from 1).
_SETTING_COLUMNS = ("GRU_UNIT", "GRU_NUM", "GRU_BIDIRECT", "DENSE_UNIT", "DENSE_NUM")
_setting_rows = [
    (1024, 1, False, 128, 4),
    (512, 1, True, 128, 4),
    (256, 6, False, 256, 1),
    (128, 6, True, 256, 1),
    (256, 3, True, 256, 4),
]
# Column-oriented view of the same table: {column name: list of values}.
settings = {col: list(values) for col, values in zip(_SETTING_COLUMNS, zip(*_setting_rows))}
# dtype="object" keeps the cell values as plain Python ints/bools rather than numpy scalars.
df_settings = pd.DataFrame(settings, index=range(1, len(_setting_rows) + 1), dtype="object")
df_settings

Unnamed: 0,GRU_UNIT,GRU_NUM,GRU_BIDIRECT,DENSE_UNIT,DENSE_NUM
1,1024,1,False,128,4
2,512,1,True,128,4
3,256,6,False,256,1
4,128,6,True,256,1
5,256,3,True,256,4


In [35]:
# Architecture hyper-parameters: one row of df_settings merged with the fixed options.
# df_settings uses dtype="object", so the selected values are plain Python objects
# (presumably so MODEL_STRUCT stays loggable inside train_info).
SETTING = 5
_setting_row = df_settings.loc[SETTING]
MODEL_STRUCT = {
    "EMBED_DIM": 64,
    "GRU_UNIT": _setting_row["GRU_UNIT"],
    "GRU_NUM": _setting_row["GRU_NUM"],
    "GRU_BIDIRECT": _setting_row["GRU_BIDIRECT"],
    "GRU_LN": False,
    "DENSE_UNIT": _setting_row["DENSE_UNIT"],
    "DENSE_NUM": _setting_row["DENSE_NUM"],
    "DENSE_NORM": False,
    "LEAKY_ALPHA": 0.1,
    "OUT_ACTI": "linear",
}
show_model = create_model("5mer", **MODEL_STRUCT)
show_model.summary()

Model: "5mer_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (InputLayer)          [(None, 1)]               0         
                                                                 
 5mer_encoder (Functional)   (None, None, 64)          65664     
                                                                 
 gru_1 (Bidirectional)       (None, None, 512)         494592    
                                                                 
 gru_2 (Bidirectional)       (None, None, 512)         1182720   
                                                                 
 gru_3 (Bidirectional)       (None, 512)               1182720   
                                                                 
 dense_1 (Dense)             (None, 256)               131328    
                                                                 
 dense_leaky_relu_1 (LeakyRe  (None, 256)              0

In [27]:
# training setting
EPOCH = 500
BATCH_SIZE = 32
SCALE_FACTOR = 450
SHUFFLE_BUFFER = 128
PREFETCH_BUFFER = 8
LOSS = keras.losses.MeanSquaredError()

# numerical guards and stopping criteria
REPLACE_NAN = 1e-5
REPLACE_INF = 1e-5
EARLY_STOP = {"monitor": "loss", "min_delta": 0, "patience": 20}
# EARLY_STOP = False
STOP_NAN = True
WRITE_GRAD = False

# Snapshot of the run configuration. The training cell passes it to WriteInfo,
# which renders it with pretty_json (presumably JSON), so keep the values plain Python types.
_data_info = dict(
    DATA_PATH=DATA_PATH,
    TRAIN_SIZE=TRAIN_SIZE,
    VALID_SIZE=VALID_SIZE,
    TEST_SIZE=TEST_SIZE,
    SCALE_FACTOR=SCALE_FACTOR,
    SHUFFLE_BUFFER=SHUFFLE_BUFFER,
    PREFETCH_BUFFER=PREFETCH_BUFFER,
)
_train_set_info = dict(
    EPOCH=EPOCH,
    BATCH_SIZE=BATCH_SIZE,
    LOSS=LOSS.get_config(),
    REPLACE_NAN=REPLACE_NAN,
    REPLACE_INF=REPLACE_INF,
    EARLY_STOP=EARLY_STOP,
    STOP_NAN=STOP_NAN,
    WRITE_GRAD=WRITE_GRAD,
)
train_info = {
    "SEED": SEED,
    "DATA": _data_info,
    "MODEL_STRUCT": MODEL_STRUCT,
    "TRAIN_SET": _train_set_info,
}

In [30]:
# Sanity check: LinearWarmup should carry its `name` through get_config().
# A throw-away schedule is used so this cell does not depend on variables that
# are only defined in the training cell, and does not re-wrap the global
# LR_SCHEDULE every time it is re-run.
_lr_schedule_check = LinearWarmup(
    1e-3,
    warmup_steps=100,
    warmup_learning_rate=1e-6,
    name="linear_warmup"
)
_lr_schedule_check.get_config()["name"]

'linear_warmup'

In [28]:
# start training
timestamp = datetime.now().strftime("%m%d-%H%M%S")
MAIN_LOGDIR = "./logs/lr_schedule/chk_repro"

# encode_names = ("3mer", "4mer", "5mer", )
encode_names = ("5mer", )


def make_metric_list(scale_factor):
    """Return a fresh list of the evaluation metrics used for every model.

    Keras metrics are stateful objects that own their accumulator variables, so
    each compiled model gets its own instances instead of sharing one list.

    Args:
        scale_factor: forwarded to every metric constructor (the cell below
            passes SCALE_FACTOR, the same factor given to transform_ds).
    """
    return [
        ScaledRMSE(scale_factor),
        ScaledMAE(scale_factor),
        PosAccuracy(scale_factor),
        AllAccuracy(scale_factor),
        ClassAccuracy(scale_factor),
        SegmentPosAccuracy(scale_factor, segment="V"),
        SegmentPosAccuracy(scale_factor, segment="D"),
        SegmentPosAccuracy(scale_factor, segment="J"),
    ]


# Kept at module level for any later cell that inspects it; each model below is
# compiled with its own fresh copy.
metric_list = make_metric_list(SCALE_FACTOR)

all_history = {}
for name in ("onehot", "2mer", "3mer", "4mer", "5mer"):

    # Build every model (even skipped ones) before filtering, so the random
    # initialisation of a given encoding does not depend on which encodings
    # are selected in encode_names.
    model = create_model(name, **MODEL_STRUCT)
    if name not in encode_names:
        continue

    # prepare dataset
    ds_train_input = transform_ds(ds_train, name, SCALE_FACTOR, BATCH_SIZE, SHUFFLE_BUFFER, SEED)
    ds_valid_input = transform_ds(ds_valid, name, SCALE_FACTOR, BATCH_SIZE)
    ds_train_input = ds_train_input.prefetch(PREFETCH_BUFFER)
    ds_valid_input = ds_valid_input.prefetch(PREFETCH_BUFFER)
    batch_per_epoch = int(ds_train_input.cardinality().numpy())

    # learning-rate schedule: a base rate/schedule wrapped in a linear warmup
    WARMUP_EPOCH = 3
    DECAY_EPOCH = 250
    BASE_LR_SCHEDULE = 1e-3
    # BASE_LR_SCHEDULE = keras.optimizers.schedules.CosineDecay(
    #     1e-3, decay_steps=batch_per_epoch * DECAY_EPOCH, alpha=0.3, name="cos_decay"
    # )
    # BASE_LR_SCHEDULE = tf.keras.optimizers.schedules.CosineDecayRestarts(
    #     1e-3,
    #     batch_per_epoch,
    #     t_mul=2.0,
    #     m_mul=0.95,
    #     alpha=1e-3,
    #     name="cos_decay_restart"
    # )
    # BASE_LR_SCHEDULE = tf.keras.optimizers.schedules.ExponentialDecay(
    #     1e-3,
    #     decay_steps=batch_per_epoch * (DECAY_EPOCH // int(log(0.3, 0.95))),
    #     decay_rate=0.95,
    #     staircase=True,
    #     name="exp_decay"
    # )

    LR_SCHEDULE = LinearWarmup(
        BASE_LR_SCHEDULE,
        warmup_steps=batch_per_epoch * WARMUP_EPOCH,
        warmup_learning_rate=1e-6,
        name="linear_warmup"
    )
    OPTIMIZER = keras.optimizers.Adam(
        learning_rate=LR_SCHEDULE,
        clipnorm=1e-3
    )
    model.compile(
        optimizer=OPTIMIZER,
        loss=LOSS,
        metrics=make_metric_list(SCALE_FACTOR),
        run_eagerly=WRITE_GRAD
    )

    # Record the schedule settings that actually apply. LR_SCHEDULE is the
    # LinearWarmup wrapper, so test the wrapper for warmup and the *base*
    # schedule for decay. Stale keys from a previous iteration are cleared first.
    opt_info = {}
    if isinstance(LR_SCHEDULE, LinearWarmup):
        opt_info["WARMUP_EPOCH"] = WARMUP_EPOCH
    if isinstance(BASE_LR_SCHEDULE, keras.optimizers.schedules.LearningRateSchedule):
        opt_info["DECAY_EPOCH"] = DECAY_EPOCH
    opt_info["OPTIMIZER"] = OPTIMIZER.get_config()
    for stale_key in ("WARMUP_EPOCH", "DECAY_EPOCH"):
        train_info["TRAIN_SET"].pop(stale_key, None)
    train_info["TRAIN_SET"].update(opt_info)

    # setting callbacks
    logdir = MAIN_LOGDIR + f"/{model.name}"
    tensorboard_cb = keras.callbacks.TensorBoard(
        logdir,
        histogram_freq=max(1, EPOCH // 20),
        write_graph=False
    )
    # spike_stop_cb = SpikeStopping("s_rmse", fold_allow=1.5, min_allow=3, skip_first=10)
    write_info_cb = WriteInfo(logdir, train_info)
    write_time_cb = WriteTrainingTime(logdir)
    write_lr_cb = WriteLearningRate(logdir)
    cb_list = [
        tensorboard_cb,
        # spike_stop_cb,
        write_info_cb,
        write_time_cb,
        write_lr_cb
    ]
    if STOP_NAN:
        stop_nan_cb = keras.callbacks.TerminateOnNaN()
        cb_list.append(stop_nan_cb)
    if EARLY_STOP:
        early_stop_cb = keras.callbacks.EarlyStopping(**EARLY_STOP)
        cb_list.append(early_stop_cb)

    # train
    print(f"{model.name} training!")
    model.set_model(
        logdir, write_grad=WRITE_GRAD,
        scalars=["norm"], scalar_freq=0, histogram_freq=0,
        update_freq="epoch", batch_per_epoch=batch_per_epoch,
        inspect_adam=False, inspect_clip=False, replace_nan=REPLACE_NAN, replace_inf=REPLACE_INF
    )
    all_history[name] = model.fit(
        ds_train_input,
        epochs=EPOCH,
        validation_data=ds_valid_input,
        callbacks=cb_list,
        verbose=0
    )
    # model.save(logdir + "/model")

'5mer'

<tf.Variable 'dense_1/kernel:0' shape=(256, 256) dtype=float32, numpy=
array([[ 0.09940689, -0.04302669,  0.06853006, ...,  0.04871876,
        -0.01533163, -0.07063694],
       [ 0.08925147,  0.00025051, -0.08533851, ..., -0.09536003,
         0.04990416,  0.02982887],
       [ 0.04703335,  0.09673565,  0.02113684, ..., -0.09359347,
         0.10377292, -0.01580728],
       ...,
       [-0.05997573,  0.05840721,  0.0018357 , ..., -0.09509703,
        -0.10667745,  0.09354789],
       [-0.00074476, -0.09086062, -0.05029048, ..., -0.03969767,
        -0.07108419,  0.05995592],
       [-0.02960476, -0.01729961, -0.0616333 , ...,  0.06001755,
         0.03784332, -0.10430547]], dtype=float32)>

5mer_model training!


KeyboardInterrupt: 

### Result

In [None]:
def summarize_history(name, history):
    """Summarise one training run as a single record.

    Training metrics are taken at the epoch with the lowest training loss.
    Validation metrics are taken at the epoch with the lowest validation loss.

    Args:
        name: encoding name used as the row label.
        history: the History object returned by `model.fit` (uses `.history` and `.model`).

    Returns:
        dict with, in order: name, number_of_parameters, best_loss_epoch, the
        training metrics at that epoch, best_val_loss_epoch, and the validation
        metrics at that epoch. The validation part is omitted when no "val_loss"
        was recorded.
    """
    df_hist = pd.DataFrame(history.history)
    is_valid = df_hist.columns.str.startswith("val")
    train_metrics = df_hist.columns[~is_valid]
    valid_metrics = df_hist.columns[is_valid]

    best_loss_idx = df_hist["loss"].argmin()
    record = {
        "name": name,
        "number_of_parameters": history.model.count_params(),
        "best_loss_epoch": df_hist.index[best_loss_idx],
    }
    record.update(df_hist.iloc[best_loss_idx][train_metrics].to_dict())

    if "val_loss" in df_hist.columns:
        best_val_loss_idx = df_hist["val_loss"].argmin()
        record["best_val_loss_epoch"] = df_hist.index[best_val_loss_idx]
        # Validation metrics come from the best-*validation* epoch.
        record.update(df_hist.iloc[best_val_loss_idx][valid_metrics].to_dict())
    return record


# Skip encodings whose run did not finish (e.g. training was interrupted).
performance = pd.DataFrame(
    [summarize_history(name, all_history[name]) for name in encode_names if name in all_history]
)
performance.to_csv(MAIN_LOGDIR + "/performance.csv", index=False)
performance

In [None]:
# for name in encode_names:
#     print(name)
#     model = all_history[name].model
#     ds_test_input = transform_ds(ds_test, name, SCALE_FACTOR, BATCH_SIZE)
#     model.evaluate(ds_test_input)

In [None]:
# if DATA_PATH.endswith("10_000_processed.tsv"):
#     print("Train on random ends, test on not random ends")
#     df_other = pd.read_csv("./data/airrship_shm_seed42_10_000.tsv", sep="\t")
# else:
#     print("Train on not random ends, test on random ends")
#     df_other = pd.read_csv("./data/airrship_shm_seed42_10_000_processed.tsv", sep="\t")
    
# df_other = df_other.iloc[-1000:]
# print(df_other.shape)

# pos_names = [f"{seg}_sequence_{pos}" for seg in "vdj" for pos in ("start", "end")]
# seq_other = df_other["sequence"]
# pos_other = df_other[pos_names]

# ds_seq_other = Dataset.from_tensor_slices(seq_other.to_numpy())
# ds_pos_other = Dataset.from_tensor_slices(pos_other.to_numpy())
# ds_all_other = Dataset.zip((ds_seq_other, ds_pos_other))

# for name in encode_names:
#     print(name)
#     model = all_history[name].model
#     ds_other = transform_ds(ds_all_other, name, SCALE_FACTOR, BATCH_SIZE)
#     model.evaluate(ds_other)