### Import and global config

In [1]:
%env TF_CPP_MIN_LOG_LEVEL=2

env: TF_CPP_MIN_LOG_LEVEL=2


In [2]:
from datetime import datetime
from itertools import product
import re
import json
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.data import Dataset

In [3]:
def check_and_set_gpu():
    """Turn on GPU memory growth for every visible GPU, if there is one.

    Prints a status message instead of raising, so the notebook can keep
    running on a CPU-only machine or when the GPUs were already initialised.
    """
    gpu_list = tf.config.list_physical_devices('GPU')
    if not gpu_list:
        print("No available GPU!")
        return
    try:
        # Apply the same setting to all GPUs, not only the first one.
        for gpu in gpu_list:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Enable VRAM growth")
    except (RuntimeError, ValueError) as e:
        # set_memory_growth rejects devices that are already initialised
        # (RuntimeError) or invalid (ValueError); report and continue.
        print(e)
check_and_set_gpu()

Enable VRAM growth


In [4]:
# for reproducibility
# SEED is also reused below as the shuffle seed of the training dataset and is
# recorded in `train_info`.
SEED = 42
keras.utils.set_random_seed(SEED)
tf.config.experimental.enable_op_determinism()

# for network debugging
# tf.debugging.enable_check_numerics()
# tf.debugging.experimental.enable_dump_debug_info(  #  incompatible with enable_op_determinism()
#     f"./logs/tfdbg2_logdir",
#     tensor_debug_mode="FULL_HEALTH",
#     circular_buffer_size=-1
# )

### Load data

In [5]:
# Tab-separated input table. The cells below use its "sequence" column and the
# six "{v,d,j}_sequence_{start,end}" columns.
DATA_PATH = "./data/airrship_shm_seed42_10_000.tsv"
df_data = pd.read_csv(DATA_PATH, sep="\t")
df_data.head()

Unnamed: 0,sequence_id,sequence,productive,stop_codon,vj_in_frame,v_call,d_call,j_call,junction,junction_aa,...,d_sequence_start,d_sequence_end,j_sequence_start,j_sequence_end,shm_events,shm_count,shm_freq,unmutated_sequence,gapped_unmutated_sequence,gapped_mutated_sequence
0,0,CAGGTGCAGCTGCGGGAGTCGGGCCCAGGGCTGGTGAAGCCTTTGG...,T,T,F,IGHV4-61*08,IGHD3-3*02,IGHJ4*01,TGCGCGAGGCCGCCAGGTGTATCAGCATTTAGGAGGACACCCGCTT...,CARPPGVSAFRRTPAWDFDPW,...,307,318,338,382,"14:A>G,30:A>G,44:C>T,68:C>A,84:C>T,89:G>C,90:C...",28,0.073298,CAGGTGCAGCTGCAGGAGTCGGGCCCAGGACTGGTGAAGCCTTCGG...,CAGGTGCAGCTGCAGGAGTCGGGCCCA...GGACTGGTGAAGCCTT...,CAGGTGCAGCTGCGGGAGTCGGGCCCA...GGGCTGGTGAAGCCTT...
1,1,CAGGTCACTTTGAGGGAGTCTGGTCCTGCGCTGGTGAAACCCACAC...,T,T,F,IGHV2-70*19,IGHD2-8*02,IGHJ4*01,TGTGCACGGGGGCATGTCCACGATAGGGTCTTTCCGAGAGTTGACT...,CARGHVHDRVFPRVDFW,...,310,313,329,370,"9:C>T,72:C>T,82:T>C,85:C>A,88:A>T,100:A>G,105:...",20,0.054054,CAGGTCACCTTGAGGGAGTCTGGTCCTGCGCTGGTGAAACCCACAC...,CAGGTCACCTTGAGGGAGTCTGGTCCT...GCGCTGGTGAAACCCA...,CAGGTCACTTTGAGGGAGTCTGGTCCT...GCGCTGGTGAAACCCA...
2,2,GAGGTGCAGCTCCCGGAGTCTGGGGGCGGCCTGGTACAGCCTGGGG...,T,T,F,IGHV3-23*03,IGHD5-24*01,IGHJ4*01,TGTGCGAGAGACGGAAAAAAGAGACCCGACTGG,CARDGKKRPDW,...,305,309,314,349,"12:G>C,13:T>C,14:T>C,27:A>C,31:T>C,71:C>T,90:C...",34,0.097421,GAGGTGCAGCTGTTGGAGTCTGGGGGAGGCTTGGTACAGCCTGGGG...,GAGGTGCAGCTGTTGGAGTCTGGGGGA...GGCTTGGTACAGCCTG...,GAGGTGCAGCTCCCGGAGTCTGGGGGC...GGCCTGGTACAGCCTG...
3,3,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGACCAGCCTGGGA...,T,T,F,IGHV3-33*05,IGHD3/OR15-3a*01,IGHJ4*01,TGTGCGAGAGACAAAAATTTGGGACTGGCCGGGAACTTCTTTGACT...,CARDKNLGLAGNFFDYW,...,304,313,326,367,"35:T>A,72:G>T,92:G>C,98:G>A,132:G>A,151:A>T,15...",17,0.046322,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,CAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCGTGGTCCAGCCTG...,CAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCGTGGACCAGCCTG...
4,4,GAGGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTTCAGCCTGGGG...,T,T,F,IGHV3-74*02,IGHD1-26*01,IGHJ4*01,TGTGCAAGACAAGTGGGGGGCAATATCGACCACCTTTCGAAATACT...,CARQVGGNIDHLSKYYW,...,298,301,329,367,"89:G>C,93:C>T,97:T>C,119:C>T,138:G>A,147:A>C,1...",14,0.038147,GAGGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTTCAGCCTGGGG...,GAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCTTAGTTCAGCCTG...,GAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCTTAGTTCAGCCTG...


In [6]:
# Target columns in the fixed order
# (v start, v end, d start, d end, j start, j end).
pos_names = [f"{seg}_sequence_{pos}" for seg in "vdj" for pos in ("start", "end")]
sequences = df_data["sequence"]
positions = df_data[pos_names]

# Sanity check: neither the inputs nor any target column contains missing values.
display(sequences.notna().all())
display(positions.notna().all())

True

v_sequence_start    True
v_sequence_end      True
d_sequence_start    True
d_sequence_end      True
j_sequence_start    True
j_sequence_end      True
dtype: bool

In [7]:
# Pair every raw sequence string with its six position labels.
ds_seq = Dataset.from_tensor_slices(sequences.to_numpy())
ds_pos = Dataset.from_tensor_slices(positions.to_numpy())
ds_all = Dataset.zip((ds_seq, ds_pos))

# Show one (sequence, positions) example and the total number of examples.
display(ds_all.take(1).get_single_element())
display(ds_all.cardinality().numpy())

(<tf.Tensor: shape=(), dtype=string, numpy=b'CAGGTGCAGCTGCGGGAGTCGGGCCCAGGGCTGGTGAAGCCTTTGGAGACCCTGTCCCTCACCTGCAATGTCTCTGGTGGCTCTGTCACTAGTGGTGGTTACTACTGGAGTTGGGTCCGGCTGACCCCAGGGAAGGGACTGGACTGGATTGGTTTTCTTTATTACAGTGGGAGTACCAATTACAACCCCTCCCTCGAGACTCGAGTCACCATATCAGTAGACACGGCCAAGAACCAGTTCTCTCTGAAGGTGAGCTCTGTGACCGCTGCGGACACGGCCGTGTATTACTGCGCGAGGCCGCCAGGTGTATCAGCATTTAGGAGGACACCCGCTTGGGACTTTGACCCCTGGGGCCATGGAACCCTGGTCACCGTCTCCTCAG'>,
 <tf.Tensor: shape=(6,), dtype=int64, numpy=array([  1, 296, 307, 318, 338, 382])>)

10000

### Functions for encoding

In [8]:
def dna_onehot_tensor(seq):
    """One-hot encode a DNA string tensor, one row per character.

    "A", "C", "G" and "T" map to indices 0..3 (depth 4). Any other character
    is looked up as -1, which `tf.one_hot` encodes as an all-zero row.
    """
    nucleotides = tf.constant(["A", "C", "G", "T"], dtype=tf.string)
    codes = tf.constant([0, 1, 2, 3])
    lookup = tf.lookup.StaticHashTable(
        initializer=tf.lookup.KeyValueTensorInitializer(keys=nucleotides, values=codes),
        default_value=tf.constant(-1)
    )
    # Split into single bytes, translate to indices, then expand to one-hot rows.
    indices = lookup.lookup(tf.strings.bytes_split(seq))
    return tf.one_hot(indices, depth=4)

# test dna_onehot_tensor
# test_oh = ds_all.take(6).map(
#     lambda x, y: (dna_onehot_tensor(x), y)
# )
# for padded_batch in test_oh.padded_batch(3):
#     display(padded_batch)

In [9]:
def get_kmer_tensor(seq, k):
    """Turn a sequence string tensor into a space-separated "sentence" of k-mers.

    The k-mers are the overlapping windows of `k` consecutive characters.
    """
    single_chars = tf.strings.bytes_split(seq)
    # Each n-gram of single characters, joined without a separator, is one k-mer.
    windows = tf.strings.ngrams(single_chars, ngram_width=k, separator="")
    return tf.strings.reduce_join(windows, separator=" ")

# test get_kmer_tensor()
# test_kmer = ds_all.take(3).map(
#     lambda x, y: (get_kmer_tensor(x, 3), y)
# )
# display(test_kmer.batch(3).take(1).get_single_element())

In [10]:
def get_kmer_vocab(k):
    """Return all 4**k strings of length `k` over "ACGT", in `product` order."""
    return list(map("".join, product("ACGT", repeat=k)))

# test get_kmer_vocab()
# vocab_kmer = get_kmer_vocab(3)
# " ".join(vocab_kmer)

### Prepare datasets

In [11]:
# Number of examples per split; the split cell asserts that the resulting
# datasets have exactly these cardinalities.
TRAIN_SIZE = 8000
VALID_SIZE = 1000
TEST_SIZE = 1000

In [12]:
# Sequential (unshuffled) split: first TRAIN_SIZE examples for training, the
# next VALID_SIZE for validation, everything that remains for testing.
ds_train = ds_all.take(TRAIN_SIZE)
ds_not_train = ds_all.skip(TRAIN_SIZE)
ds_valid = ds_not_train.take(VALID_SIZE)
ds_test = ds_not_train.skip(VALID_SIZE)

# Fail early if the configured sizes do not match the data.
assert ds_train.cardinality().numpy() == TRAIN_SIZE
assert ds_valid.cardinality().numpy() == VALID_SIZE
assert ds_test.cardinality().numpy() == TEST_SIZE

In [13]:
def transform_ds(ds, method, scale_factor, batch_size=None, shuffle_buffer=None, shuffle_seed=None):
    """Encode the sequences of `ds`, rescale the labels and optionally shuffle and batch.

    Args:
        ds: dataset of `(sequence string, positions)` pairs.
        method: "onehot", or "<k>mer" with a positive integer k (e.g. "3mer", "10mer").
        scale_factor: every label is divided by this value.
        batch_size: if falsy, the encoded dataset is returned unbatched and
            unshuffled, regardless of the shuffle arguments.
        shuffle_buffer: shuffle buffer size; must be given together with `shuffle_seed`.
        shuffle_seed: seed for shuffling; must be given together with `shuffle_buffer`.

    Returns:
        The transformed dataset. One-hot sequences are batched with
        `padded_batch` (they differ in length); k-mer sentences are scalar
        strings and use plain `batch`.

    Raises:
        AssertionError: on an unknown `method`, a malformed k, or when only one
            of `shuffle_buffer` / `shuffle_seed` is given.
    """
    assert method == "onehot" or method.endswith("mer")
    # Shuffling needs both a buffer size and a seed, or neither of them.
    assert (shuffle_buffer is None) == (shuffle_seed is None)
    if method == "onehot":
        ds = ds.map(lambda x, y: (dna_onehot_tensor(x), y / scale_factor))
    else:
        # Parse everything in front of "mer" so multi-digit k values are read correctly.
        k_text = method[:-len("mer")]
        assert k_text.isdecimal() and int(k_text) > 0, f"invalid k-mer method: {method!r}"
        k = int(k_text)
        ds = ds.map(lambda x, y: (get_kmer_tensor(x, k), y / scale_factor))
    if not batch_size:
        return ds
    if shuffle_buffer:
        ds = ds.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=True)
    batched = ds.padded_batch(batch_size) if method == "onehot" else ds.batch(batch_size)
    return batched
                    
# test transform_ds()
# for batch_seq, batch_pos in transform_ds(ds_train.take(6), "onehot", 450, 3, 5, 1):
#     display(batch_seq.shape)
#     display(batch_pos.shape)   
# display(batch_seq)
# display(batch_pos)
    
# for batch_seq, batch_pos in transform_ds(ds_train.take(6), "3mer", 450, 3, 5, 1):
#     display(batch_seq.shape)
#     display(batch_pos.shape)
# display(batch_seq)
# display(batch_pos)

### Functions and classes for building model

In [14]:
class ModelWriteGrad(keras.Model):
    """`keras.Model` whose `train_step` can sanitise, inspect and log gradients.

    Call `set_model()` after `compile()` (it reads `self.optimizer`) and before
    `fit()`; `train_step` relies on the attributes it sets.

    `self.step` is a plain Python counter incremented inside `train_step`, so
    the per-batch logging presumably only advances correctly when the model
    runs eagerly (the training cell compiles with `run_eagerly=WRITE_GRAD`).
    """

    @staticmethod
    def _summary_var_name(var):
        """Return the summary tag prefix for `var`: ':' replaced, GRU cell suffix dropped."""
        var_name = var.name.replace(':', '_')
        return re.sub(r"/gru_cell_\d+", "", var_name)

    def set_model(
        self, logdir, write_grad,
        scalars=None, scalar_freq=1, histogram_freq=1, update_freq="batch", batch_num=None,
        inspect_adam=False, inspect_clip=False, replace_nan=None, replace_inf=None
    ):
        """Configure gradient logging and NaN/Inf replacement.

        Args:
            logdir: base log directory; summaries go to `<logdir>/gradients`.
            write_grad: whether gradient summaries are written at all.
            scalars: names of scalar reductions to log per variable; any of
                "abs_max", "abs_min", "norm". Defaults to none.
            scalar_freq, histogram_freq: logging period (0 disables), counted
                in batches or in epochs depending on `update_freq`.
            update_freq: "batch" or "epoch".
            batch_num: batches per epoch; required when `update_freq == "epoch"`.
            inspect_adam: also log a re-computation of Adam's moment estimates
                and updates (only if the compiled optimizer is named "Adam").
            inspect_clip: print a message when gradients contain NaN/Inf or
                exceed the optimizer's `clipvalue` / `clipnorm`.
            replace_nan: value substituted for NaN gradient entries; `None`
                (or `False`) disables replacement, `0` is a valid value.
            replace_inf: magnitude substituted for +/-Inf gradient entries;
                `None` (or `False`) disables replacement.
        """
        # Functions for compute scalars of gradients
        self.scalar_funs = {
            "abs_max": lambda x: tf.reduce_max(tf.abs(x)),
            "abs_min": lambda x: tf.reduce_min(tf.abs(x)),
            "norm": lambda x: tf.norm(x)
        }

        # Never share a mutable default between calls; copy the caller's list.
        scalars = [] if scalars is None else list(scalars)

        # check arguments
        assert set(scalars).issubset(set(self.scalar_funs.keys()))
        assert update_freq in ("batch", "epoch")
        assert not (update_freq == "epoch" and batch_num is None)

        def as_replacement(value):
            # Compare against None/False explicitly so that 0 is accepted as a
            # replacement value instead of being treated as "disabled".
            if value is None or value is False:
                return None
            return tf.constant(value, dtype=tf.float32)

        # initialize
        self.step = 0
        self.writer = tf.summary.create_file_writer(logdir + "/gradients") if write_grad else None
        self.write_grad = write_grad
        self.scalars = scalars
        # Frequencies are stored in batches: an "epoch" period is batch_num batches.
        self.scalar_freq = scalar_freq if update_freq == "batch" else scalar_freq * batch_num
        self.histogram_freq = histogram_freq if update_freq == "batch" else histogram_freq * batch_num
        self.update_freq = update_freq
        self.batch_num = batch_num
        self.inspect_clip = inspect_clip
        self.replace_nan = as_replacement(replace_nan)
        self.replace_inf = as_replacement(replace_inf)

        opt_param = self.optimizer.get_config()
        self.inspect_adam = (inspect_adam and opt_param["name"] == "Adam")
        if self.inspect_adam:
            var_num = len(self.trainable_variables)
            self.m = [0 for _ in range(var_num)]
            self.v = [0 for _ in range(var_num)]
            self.lr = opt_param["learning_rate"]
            self.beta_1 = opt_param["beta_1"]
            self.beta_2 = opt_param["beta_2"]
            self.epsilon = opt_param["epsilon"]
        if self.inspect_clip:
            self.clipvalue = opt_param.get("clipvalue")
            self.clipnorm = opt_param.get("clipnorm")

    def write_value(self, values, value_name, trainable_vars):
        """Write histogram and scalar summaries of `values`, one entry per variable.

        Tags have the form `<variable>/<value_name>[/<scalar>]`. Histograms and
        scalars are only written on steps that match their configured period.
        """
        with self.writer.as_default():
            for val, var in zip(values, trainable_vars):
                if isinstance(val, tf.IndexedSlices):
                    val = tf.convert_to_tensor(val)
                var_name = self._summary_var_name(var)
                if self.histogram_freq != 0 and self.step % self.histogram_freq == 0:
                    tf.summary.histogram(f"{var_name}/{value_name}", val, step=self.step)
                if self.scalars and self.scalar_freq != 0 and self.step % self.scalar_freq == 0:
                    for scalar in self.scalars:
                        sum_name = f"{var_name}/{value_name}/{scalar}"
                        sum_val = self.scalar_funs[scalar](val)
                        tf.summary.scalar(sum_name, sum_val, step=self.step)
            self.writer.flush()

    def train_step(self, data):
        """Run one optimisation step with optional gradient sanitising and logging.

        Returns:
            dict mapping each metric name to its current value.
        """

        # Compute gradients
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Check and process nan and Inf
        # NOTE(review): the NaN/Inf messages are printed only when both
        # write_grad and inspect_clip are set — confirm this gating is intended.
        for i in range(len(gradients)):
            grad = gradients[i]
            var = trainable_vars[i]
            var_name = self._summary_var_name(var)

            nan_bool = tf.math.is_nan(grad)
            if self.write_grad and self.inspect_clip:
                show_nan = f"Batch {self.step}: Nan in gradients of {var_name}"
                tf.cond(tf.reduce_any(nan_bool), lambda: tf.print(show_nan), tf.no_op)
            if self.replace_nan is not None:
                if isinstance(grad, tf.IndexedSlices):
                    grad = tf.convert_to_tensor(grad)
                grad = tf.where(nan_bool, self.replace_nan, grad)
                gradients[i] = grad

            inf_bool = tf.math.is_inf(grad)
            if self.write_grad and self.inspect_clip:
                show_inf = f"Batch {self.step}: Inf in gradients of {var_name}"
                tf.cond(tf.reduce_any(inf_bool), lambda: tf.print(show_inf), tf.no_op)
            if self.replace_inf is not None:
                if isinstance(grad, tf.IndexedSlices):
                    grad = tf.convert_to_tensor(grad)
                # Keep the sign: +Inf -> +replace_inf, -Inf -> -replace_inf.
                pos_bool = inf_bool & (grad > 0)
                neg_bool = inf_bool & (grad < 0)
                grad = tf.where(pos_bool, self.replace_inf, grad)
                grad = tf.where(neg_bool, -1 * self.replace_inf, grad)
                gradients[i] = grad

            if self.write_grad and self.inspect_clip and self.clipvalue:
                show_clip = f"Batch {self.step}: clip value for gradients of {var_name}"
                clip_bool = tf.reduce_any(tf.abs(grad) > self.clipvalue)
                tf.cond(clip_bool, lambda: tf.print(show_clip), tf.no_op)
            if self.write_grad and self.inspect_clip and self.clipnorm:
                show_clip = f"Batch {self.step}: clip norm for gradients of {var_name}"
                clip_bool = tf.reduce_any(tf.norm(grad) > self.clipnorm)
                tf.cond(clip_bool, lambda: tf.print(show_clip), tf.no_op)

        # Compute adam values
        # These are a re-computation for logging only; the optimizer keeps its
        # own state and applies its own update below.
        if self.write_grad and self.inspect_adam:
            t = self.step + 1
            lr = self.lr * tf.sqrt(1 - tf.pow(self.beta_2, t)) / (1 - tf.pow(self.beta_1, t))
            var_num = len(self.trainable_variables)
            root_v = [0 for i in range(var_num)]
            weight_updates = [0 for i in range(var_num)]
            for i in range(var_num):
                grad = gradients[i]
                if isinstance(grad, tf.IndexedSlices):
                    grad = tf.convert_to_tensor(grad)
                self.m[i] = self.m[i] * self.beta_1 + (1 - self.beta_1) * grad
                self.v[i] = self.v[i] * self.beta_2 + (1 - self.beta_2) * tf.square(grad)
                root_v[i] = tf.sqrt(self.v[i])
                weight_updates[i] = lr * self.m[i] / (root_v[i] + self.epsilon)

        # Record gradients
        if self.write_grad:
            self.write_value(gradients, "grads", trainable_vars)
            if self.inspect_adam:
                self.write_value(self.m, "m", trainable_vars)
                self.write_value(root_v, "root_v", trainable_vars)
                self.write_value(weight_updates, "updates", trainable_vars)

        # Update step, weights and metrics
        self.step += 1
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

In [15]:
# build the encoders for kmer
def create_kmer_encoder(k, embed_dim):
    """Build a sub-model mapping a k-mer sentence string to a sequence of embeddings.

    The sentence is split on whitespace, each k-mer is looked up in the fixed
    vocabulary from `get_kmer_vocab(k)`, and the token ids are embedded into
    `embed_dim` dimensions with id 0 treated as a mask value.
    """
    vectorizer = keras.layers.TextVectorization(
        standardize=None,
        split="whitespace",
        vocabulary=get_kmer_vocab(k),
        name="text_vectorize"
    )
    # Size the embedding table from the vectorizer's own vocabulary size.
    vocab_size = vectorizer.vocabulary_size()
    embedder = keras.layers.Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        mask_zero=True,
        name="embedding"
    )

    sentence_in = keras.Input(shape=(1,), dtype=tf.string, name=f"{k}mer_seq")
    embedded = embedder(vectorizer(sentence_in))
    return keras.models.Model(sentence_in, embedded, name=f"{k}mer_encoder")
    
# test_encoder = create_kmer_encoder(5, 64)
# test_encoder.summary()

In [16]:
# build the complete models
def create_model(
    encode_name, *,
    EMBED_DIM, GRU_NUM, GRU_UNIT, GRU_BIDIRECT, DENSE_NUM, DENSE_UNIT, LEAKY_ALPHA, OUT_ACTI
):
    """Build the encoder -> stacked GRU -> dense -> 6-unit output model.

    Args:
        encode_name: "onehot", or "<k>mer" with a positive integer k.
        EMBED_DIM: embedding size of the k-mer encoder (unused for "onehot").
        GRU_NUM: number of GRU layers.
        GRU_UNIT: units per GRU layer; one scalar for all layers or a sequence
            with one entry per layer.
        GRU_BIDIRECT: whether each GRU layer is bidirectional; scalar or per-layer sequence.
        DENSE_NUM: number of hidden dense layers.
        DENSE_UNIT: units per dense layer; scalar or per-layer sequence.
        LEAKY_ALPHA: negative slope of the leaky ReLU used in the dense layers.
        OUT_ACTI: activation of the 6-unit output layer.

    Returns:
        An uncompiled `ModelWriteGrad` named "<encode_name>_model".

    Raises:
        ValueError: if `encode_name` is neither "onehot" nor a valid "<k>mer".
    """

    def per_layer(setting, i):
        # Scalars (Python or NumPy) apply to every layer; sequences are indexed per layer.
        return setting if np.ndim(setting) == 0 else setting[i]

    # input and encoding part
    if encode_name == "onehot":
        model_input = keras.Input(shape=(None, 4), name="input")
        x = keras.layers.Masking(mask_value=[0.0, 0.0, 0.0, 0.0], name="mask")(model_input)
    else:
        # Parse everything in front of "mer" so multi-digit k values are read correctly.
        k_text = encode_name[:-len("mer")] if encode_name.endswith("mer") else ""
        if not k_text.isdecimal() or int(k_text) < 1:
            raise ValueError(f"invalid encoding name: {encode_name!r}")
        k = int(k_text)
        model_input = keras.Input(shape=(1, ), dtype=tf.string, name="input")
        kmer_encoder = create_kmer_encoder(k, EMBED_DIM)
        x = kmer_encoder(model_input)

    # gru part
    for i in range(GRU_NUM):
        gru_bidirect = per_layer(GRU_BIDIRECT, i)
        gru_unit = per_layer(GRU_UNIT, i)
        # Only the last GRU collapses the sequence; earlier ones feed the next GRU.
        not_last_gru = (i+1 != GRU_NUM)
        gru_layer = keras.layers.GRU(gru_unit, return_sequences=not_last_gru, name=f"gru_{i+1}")
        if gru_bidirect:
            gru_layer = keras.layers.Bidirectional(gru_layer, name=f"gru_{i+1}")
        x = gru_layer(x)

    # dense part
    dense_acti = lambda x: keras.activations.relu(x, alpha=LEAKY_ALPHA)
    for i in range(DENSE_NUM):
        dense_unit = per_layer(DENSE_UNIT, i)
        x = keras.layers.Dense(dense_unit, activation=dense_acti, name=f"dense_{i+1}")(x)

    # build the model
    model_output = keras.layers.Dense(6, activation=OUT_ACTI, name="output")(x)
    model = ModelWriteGrad(model_input, model_output, name=f"{encode_name}_model")
    return model

### Build model, compile and train

In [17]:
def get_batch_num(ds_size, batch_size):
    """Return how many batches of `batch_size` cover `ds_size` items (ceiling division)."""
    full_batches, leftover = divmod(ds_size, batch_size)
    # A non-empty remainder needs one extra, partially filled batch.
    if leftover != 0:
        return full_batches + 1
    return full_batches + 0

In [18]:
def pretty_json(hp):
    """Dump `hp` as 2-space-indented JSON with a tab in front of every line.

    The tab prefix is presumably there so the text summary renders as a code
    block — TODO confirm against the TensorBoard text panel.
    """
    dumped = json.dumps(hp, indent=2)
    tabbed_lines = ["\t" + line for line in dumped.splitlines(True)]
    return "".join(tabbed_lines)

In [19]:
# Custom metrics
class ScaledRMSE(keras.metrics.RootMeanSquaredError):
    """Root mean squared error computed after multiplying labels and predictions by `scale_factor`.

    This reports the error in the original position units when the model was
    trained on labels divided by the same factor.
    """

    def __init__(self, scale_factor, name="s_rmse", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Rescale both tensors, then accumulate like the parent metric."""
        y_true = tf.cast(y_true, self._dtype) * self.scale_factor
        y_pred = tf.cast(y_pred, self._dtype) * self.scale_factor
        super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        """Include `scale_factor` so the metric can be rebuilt with `from_config`."""
        config = super().get_config()
        config["scale_factor"] = self.scale_factor
        return config
        
class ScaledMAE(keras.metrics.MeanAbsoluteError):
    """Mean absolute error computed after multiplying labels and predictions by `scale_factor`."""

    def __init__(self, scale_factor, name="s_mae", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Rescale both tensors, then accumulate like the parent metric."""
        y_true = tf.cast(y_true, self._dtype) * self.scale_factor
        y_pred = tf.cast(y_pred, self._dtype) * self.scale_factor
        super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        """Include `scale_factor` so the metric can be rebuilt with `from_config`."""
        config = super().get_config()
        config["scale_factor"] = self.scale_factor
        return config
        
class PosAccuracy(keras.metrics.Accuracy):
    """Fraction of individual positions predicted exactly.

    Labels and predictions are multiplied by `scale_factor` and rounded to the
    nearest integer before they are compared element-wise.
    """

    def __init__(self, scale_factor, name="p_acc", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Rescale and round both tensors, then accumulate element-wise accuracy."""
        y_true = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        y_pred = tf.math.rint(tf.cast(y_pred, self._dtype) * self.scale_factor)
        super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        """Include `scale_factor` so the metric can be rebuilt with `from_config`."""
        config = super().get_config()
        config["scale_factor"] = self.scale_factor
        return config
        
class AllAccuracy(keras.metrics.Accuracy):
    """Fraction of samples for which every position is predicted exactly.

    Labels and predictions are multiplied by `scale_factor` and rounded; a
    sample counts as correct only if all entries along the last axis match.
    """

    def __init__(self, scale_factor, name="a_acc", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Reduce each sample to a 0/1 "all correct" flag and accumulate its mean."""
        y_true = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        y_pred = tf.math.rint(tf.cast(y_pred, self._dtype) * self.scale_factor)
        all_correct_bool = tf.reduce_all(y_true == y_pred, axis=-1)
        # One flag per sample: 1 when the whole row matches, else 0.
        sample_hit = tf.cast(all_correct_bool, self._dtype)
        # Compare the flags against a same-shaped tensor of ones. The dtype is
        # passed by keyword: the second positional argument of tf.ones is the
        # dtype, not a second dimension.
        target = tf.ones(tf.shape(sample_hit), dtype=self._dtype)
        super().update_state(target, sample_hit, sample_weight)

    def get_config(self):
        """Include `scale_factor` so the metric can be rebuilt with `from_config`."""
        config = super().get_config()
        config["scale_factor"] = self.scale_factor
        return config
        
class ClassAccuracy(keras.metrics.Metric):
    """Accuracy estimate `(len_sum - incorrect_sum) / len_sum` accumulated over batches.

    After rescaling and rounding, the summed absolute position error of a batch
    is taken as its number of incorrect items, and the summed last label
    column (index 5) as its total length. Presumably this approximates
    per-nucleotide segment classification accuracy — TODO confirm.

    `sample_weight` is accepted for API compatibility but ignored.
    """

    def __init__(self, scale_factor, name="c_acc", **kwargs):
        super().__init__(name=name, **kwargs)
        self.scale_factor = scale_factor
        self.len_sum = self.add_weight(name="len_sum", initializer="zeros")
        self.correct_sum = self.add_weight(name="correct_sum", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add this batch's total length and correct count to the running sums."""
        y_true = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        y_pred = tf.math.rint(tf.cast(y_pred, self._dtype) * self.scale_factor)
        incorrect_count = tf.reduce_sum(tf.abs(y_true - y_pred))
        # Column 5 is the last of the six position labels.
        batch_len_sum = tf.reduce_sum(tf.gather(y_true, indices=5, axis=-1))
        self.len_sum.assign_add(batch_len_sum)
        self.correct_sum.assign_add(batch_len_sum - incorrect_count)

    def result(self):
        """Return the accumulated ratio (NaN while nothing has been accumulated)."""
        return self.correct_sum / self.len_sum

    def reset_state(self):
        """Zero both running sums."""
        self.len_sum.assign(0.0)
        self.correct_sum.assign(0.0)

    def get_config(self):
        """Include `scale_factor` so the metric can be rebuilt with `from_config`."""
        config = super().get_config()
        config["scale_factor"] = self.scale_factor
        return config
        
class SegmentPosAccuracy(PosAccuracy):
    """`PosAccuracy` restricted to the start/end columns of one segment.

    The metric name is prefixed with the segment letter, e.g. "V_p_acc".
    """

    def __init__(self, scale_factor, segment, name="p_acc", dtype=None):
        assert segment in ("V", "D", "J")
        prefixed_name = f"{segment}_{name}"
        super().__init__(scale_factor, name=prefixed_name, dtype=dtype)
        self.segment = segment

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Select the segment's two columns and delegate to `PosAccuracy`."""
        # Columns come in (start, end) pairs: V -> (0, 1), D -> (2, 3), J -> (4, 5).
        first_col = 2 * "VDJ".index(self.segment)
        columns = (first_col, first_col + 1)
        true_seg = tf.gather(y_true, indices=columns, axis=-1)
        pred_seg = tf.gather(y_pred, indices=columns, axis=-1)
        super().update_state(true_seg, pred_seg, sample_weight)

In [20]:
# A callback class that stops the training when there is a sudden increase in one of the metrics 
class SpikeStopping(keras.callbacks.Callback):
    """Stop training when the monitored metric jumps between two consecutive epochs.

    Training stops after an epoch whose value exceeds
    `previous * fold_allow + min_allow`, once `skip_first` epochs have passed.

    Args:
        monitor: key of the metric in the epoch `logs`.
        fold_allow: allowed multiplicative increase over the previous epoch.
        min_allow: additional absolute increase that is always tolerated.
        skip_first: number of initial epochs during which no check is made.
    """

    def __init__(self, monitor, fold_allow=2, min_allow=5, skip_first=10):
        super(SpikeStopping, self).__init__()
        self.monitor = monitor
        self.fold_allow = fold_allow
        self.min_allow = min_allow
        self.skip_first = skip_first
        self.last = None

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            # Metric not reported this epoch: nothing to compare against.
            return
        if epoch == 0 or self.last is None:
            # First observation becomes the baseline (also when training
            # resumes from a later initial epoch).
            self.last = current
        threshold = self.last * self.fold_allow + self.min_allow
        if epoch >= self.skip_first and current > threshold:
            print(f"Stop training at epoch {epoch}")
            self.model.stop_training = True
        self.last = current

In [21]:
# A callback class that records the training information
class WriteInfo(keras.callbacks.Callback):
    """Write the training settings and basic model facts as a TensorBoard text summary.

    Args:
        logdir: base log directory; the summary goes to `<logdir>/info`.
        info: JSON-serialisable dict of settings. It is not modified.
    """

    def __init__(self, logdir, info):
        super().__init__()
        self.writer = tf.summary.create_file_writer(logdir + "/info")
        self.info = info

    def on_train_begin(self, logs=None):
        model_info = {
            "MODEL_NAME": self.model.name,
            "MODEL_PARAM_NUM": self.model.count_params(),
        }
        # Build a new dict instead of writing into self.info: the same settings
        # dict is handed to one callback per model and must not be changed here.
        info = {**self.info, "MODEL_INFO": model_info}
        info_json = pretty_json(info)
        with self.writer.as_default():
            tf.summary.text("training_setting", info_json, step=0)

In [22]:
# A callback class that records the training time
class WriteTrainingTime(keras.callbacks.Callback):
    """Measure the wall-clock duration of `fit()` and log it as a TensorBoard text summary.

    Args:
        logdir: base log directory; the summary goes to `<logdir>/info`.
        print_time: also print the duration when training ends.
    """

    def __init__(self, logdir, print_time=True):
        super().__init__()
        self.writer = tf.summary.create_file_writer(logdir + "/info")
        self.print_time = print_time
        self.epoch = 0

    def on_train_begin(self, logs=None):
        self.start = datetime.now()

    def on_epoch_begin(self, epoch, logs=None):
        # Remember the last started epoch; it is used as the summary step.
        self.epoch = epoch

    def on_train_end(self, logs=None):
        elapsed = datetime.now() - self.start
        # Format the timedelta by hand as HH:MM:SS.ffffff. Whole days are
        # folded into the hour field so runs of 24 h or more are not truncated.
        hours = elapsed.days * 24 + elapsed.seconds // 3600
        minutes = elapsed.seconds % 3600 // 60
        seconds = elapsed.seconds % 60
        time_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}.{elapsed.microseconds:06d}"
        if self.print_time:
            print(f"Time usage: {time_str}")
        with self.writer.as_default():
            tf.summary.text("training_time", time_str, step=self.epoch)

In [23]:
# Candidate architectures, one per row (ids 1-4); a row is picked via SETTING
# in the next cell.
settings = {
    "GRU_UNIT": [1024, 512, 256, 128],
    "GRU_NUM": [1, 1, 6, 6], 
    "GRU_BIDIRECT": [False, True, False, True], 
    "DENSE_UNIT": [128, 128, 256, 256], 
    "DENSE_NUM": [4, 4, 1, 1]
}
# dtype="object" presumably keeps the cells as plain Python int/bool so they
# pass the isinstance checks in create_model — TODO confirm.
df_settings = pd.DataFrame(settings, index=range(1, 5), dtype="object")
df_settings

Unnamed: 0,GRU_UNIT,GRU_NUM,GRU_BIDIRECT,DENSE_UNIT,DENSE_NUM
1,1024,1,False,128,4
2,512,1,True,128,4
3,256,6,False,256,1
4,128,6,True,256,1


In [24]:
# Row id of df_settings to train with; also embedded in the log directory name.
SETTING = 4
# Keyword arguments for create_model(); also recorded in train_info.
MODEL_STRUCT = {
    "EMBED_DIM": 64,
    "GRU_UNIT": df_settings.loc[SETTING, "GRU_UNIT"],
    "GRU_NUM": df_settings.loc[SETTING, "GRU_NUM"],
    "GRU_BIDIRECT": df_settings.loc[SETTING, "GRU_BIDIRECT"],
    "DENSE_UNIT": df_settings.loc[SETTING, "DENSE_UNIT"],
    "DENSE_NUM": df_settings.loc[SETTING, "DENSE_NUM"],
    "LEAKY_ALPHA": 0.1,
    "OUT_ACTI": "linear"
}
# Build one model only to display the architecture summary.
show_model = create_model("5mer", **MODEL_STRUCT)
show_model.summary()

Model: "5mer_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (InputLayer)          [(None, 1)]               0         
                                                                 
 5mer_encoder (Functional)   (None, None, 64)          65664     
                                                                 
 gru_1 (Bidirectional)       (None, None, 256)         148992    
                                                                 
 gru_2 (Bidirectional)       (None, None, 256)         296448    
                                                                 
 gru_3 (Bidirectional)       (None, None, 256)         296448    
                                                                 
 gru_4 (Bidirectional)       (None, None, 256)         296448    
                                                                 
 gru_5 (Bidirectional)       (None, None, 256)         2

In [25]:
# training setting
EPOCH = 100
BATCH_SIZE = 32
SCALE_FACTOR = 450
SHUFFLE_BUFFER = 128
PREFETCH_BUFFER = 8
LOSS = keras.losses.MeanSquaredError()
OPTIMIZER = keras.optimizers.Adam(
    learning_rate=1e-3,
    clipnorm=1e-3
)
REPLACE_NAN = 1e-5
REPLACE_INF = 1e-5
EARLY_STOP = False
STOP_NAN = True
WRITE_GRAD = False

train_info = {
    "SEED": SEED,
    "DATA": {
        "DATA_PATH": DATA_PATH,
        "TRAIN_SIZE": TRAIN_SIZE,
        "VALID_SIZE": VALID_SIZE,
        "TEST_SIZE": TEST_SIZE,
        "SCALE_FACTOR": SCALE_FACTOR,
        "SHUFFLE_BUFFER": SHUFFLE_BUFFER,
        "PREFETCH_BUFFER": PREFETCH_BUFFER
    },
    "MODEL_STRUCT": MODEL_STRUCT,
    "TRAIN_SET": {
        "EPOCH": EPOCH,
        "BATCH_SIZE": BATCH_SIZE,
        "LOSS": LOSS.get_config(),
        "OPTIMIZER": OPTIMIZER.get_config(),
        "REPLACE_NAN": REPLACE_NAN,
        "REPLACE_INF": REPLACE_INF,
        "EARLY_STOP": EARLY_STOP,
        "STOP_NAN": STOP_NAN,
        "WRITE_GRAD": WRITE_GRAD
    }
}

In [None]:
# start training
timestamp = datetime.now().strftime("%m%d-%H%M%S")
MAIN_LOGDIR = f"./logs/ce_no_spike/repro_setting_{SETTING}_seed_{SEED}_3"

encode_names = ("onehot", "2mer", "3mer", "4mer", "5mer")
metric_list = [
    ScaledRMSE(SCALE_FACTOR),
    ScaledMAE(SCALE_FACTOR),
    PosAccuracy(SCALE_FACTOR),
    AllAccuracy(SCALE_FACTOR),
    ClassAccuracy(SCALE_FACTOR),
    SegmentPosAccuracy(SCALE_FACTOR, segment="V"),
    SegmentPosAccuracy(SCALE_FACTOR, segment="D"),
    SegmentPosAccuracy(SCALE_FACTOR, segment="J")
]
batch_num = get_batch_num(TRAIN_SIZE, BATCH_SIZE)

all_history = {}
for name in encode_names:
    
    # build model and compile
    model = create_model(name, **MODEL_STRUCT)
    model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=metric_list, run_eagerly=WRITE_GRAD)
    
    # prepare dataset
    ds_train_input = transform_ds(ds_train, name, SCALE_FACTOR, BATCH_SIZE, SHUFFLE_BUFFER, SEED)
    ds_valid_input = transform_ds(ds_valid, name, SCALE_FACTOR, BATCH_SIZE)
    ds_train_input = ds_train_input.prefetch(PREFETCH_BUFFER)
    ds_valid_input = ds_valid_input.prefetch(PREFETCH_BUFFER)
    
    # setting callbacks
    logdir = MAIN_LOGDIR + f"/{model.name}"
    tensorboard_cb = keras.callbacks.TensorBoard(
        logdir, 
        histogram_freq=max(1, EPOCH // 20),
        write_graph=False
    )
    # spike_stop_cb = SpikeStopping("s_rmse", fold_allow=1.5, min_allow=3, skip_first=10)
    write_info_cb = WriteInfo(logdir, train_info)
    write_time_cb = WriteTrainingTime(logdir)
    cb_list = [
        tensorboard_cb,
        # spike_stop_cb,
        write_info_cb,
        write_time_cb
    ]
    if STOP_NAN:
        stop_nan_cb = keras.callbacks.TerminateOnNaN()
        cb_list.append(stop_nan_cb)
    if EARLY_STOP:
        early_stop_cb = keras.callbacks.EarlyStopping(**EARLY_STOP)
        cb_list.append(early_stop_cb)
    
    # train
    print(f"{model.name} training!")
    model.set_model(
        logdir, write_grad=WRITE_GRAD,
        scalars=["norm"], scalar_freq=0, histogram_freq=0, update_freq="epoch", batch_num=batch_num,
        inspect_adam=False, inspect_clip=False, replace_nan=REPLACE_NAN, replace_inf=REPLACE_INF
    )
    all_history[name] = model.fit(
        ds_train_input,
        epochs=EPOCH, 
        validation_data=ds_valid_input,
        callbacks=cb_list,
        verbose=0
    )
    model.save(logdir + "/model")

onehot_model training!


### Result

In [None]:
# Summarise every run: training metrics at the epoch with the lowest training
# loss, validation metrics at the epoch with the lowest validation loss.
records = []
for name in encode_names:
    history = all_history[name]
    df = pd.DataFrame(history.history)
    best_loss_idx = df["loss"].argmin()
    best_val_loss_idx = df["val_loss"].argmin()
    valid_metrics_idx = df.columns.str.startswith("val")
    train_metrics = df.columns[~valid_metrics_idx]
    valid_metrics = df.columns[valid_metrics_idx]

    record = {
        "name": name,
        "number_of_parameters": history.model.count_params(),
        "best_loss_epoch": df.index[best_loss_idx],
    }
    record.update(df.iloc[best_loss_idx][train_metrics].to_dict())
    record["best_val_loss_epoch"] = df.index[best_val_loss_idx]
    # Validation metrics are read at the best *validation* epoch, matching the
    # "best_val_loss_epoch" column written just above.
    record.update(df.iloc[best_val_loss_idx][valid_metrics].to_dict())
    records.append(record)

# Build the frame once from all records instead of growing it row by row.
performance = pd.DataFrame(records)
performance.to_csv(MAIN_LOGDIR + "/performance.csv", index=False)
performance

In [None]:
# for name in encode_names:
#     print(name)
#     model = all_history[name].model
#     ds_test_input = transform_ds(ds_test, name, SCALE_FACTOR, BATCH_SIZE)
#     model.evaluate(ds_test_input)

In [None]:
# if DATA_PATH.endswith("10_000_processed.tsv"):
#     print("Train on random ends, test on not random ends")
#     df_other = pd.read_csv("./data/airrship_shm_seed42_10_000.tsv", sep="\t")
# else:
#     print("Train on not random ends, test on random ends")
#     df_other = pd.read_csv("./data/airrship_shm_seed42_10_000_processed.tsv", sep="\t")
    
# df_other = df_other.iloc[-1000:]
# print(df_other.shape)

# pos_names = [f"{seg}_sequence_{pos}" for seg in "vdj" for pos in ("start", "end")]
# seq_other = df_other["sequence"]
# pos_other = df_other[pos_names]

# ds_seq_other = Dataset.from_tensor_slices(seq_other.to_numpy())
# ds_pos_other = Dataset.from_tensor_slices(pos_other.to_numpy())
# ds_all_other = Dataset.zip((ds_seq_other, ds_pos_other))

# for name in encode_names:
#     print(name)
#     model = all_history[name].model
#     ds_other = transform_ds(ds_all_other, name, SCALE_FACTOR, BATCH_SIZE)
#     model.evaluate(ds_other)