### Imports and global config

In [1]:
# Presumably quiets TensorFlow's C++ log output (INFO/WARNING); it is set here,
# before `import tensorflow` in the next cell, so TF sees it when it loads.
%env TF_CPP_MIN_LOG_LEVEL=2

env: TF_CPP_MIN_LOG_LEVEL=2


In [74]:
from datetime import datetime
from itertools import product
import re
import json
import math
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.data import Dataset
from tensorflow.keras import backend
import sklearn.metrics as sk_metrics

In [4]:
def check_and_set_gpu():
    """Enable on-demand GPU memory allocation ("memory growth") on every visible GPU.

    Prints a notice and returns when no GPU is visible. If TensorFlow rejects the
    setting (e.g. because the GPUs were already initialized), the error is printed
    instead of being raised, so the notebook keeps running on default allocation.
    """
    gpu_list = tf.config.list_physical_devices('GPU')
    if not gpu_list:
        print("No available GPU!")
        return
    try:
        # Memory growth must be configured identically on all GPUs, so set it on each.
        for gpu in gpu_list:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Enable VRAM growth")
    except (RuntimeError, ValueError) as e:
        print(e)

check_and_set_gpu()

Enable VRAM growth


In [5]:
# Reproducibility: one global seed plus deterministic TensorFlow kernels.
SEED = 42
keras.utils.set_random_seed(SEED)
tf.config.experimental.enable_op_determinism()

# Network debugging (disabled; uncomment when needed):
# tf.debugging.enable_check_numerics()
# NOTE: enable_dump_debug_info() is incompatible with enable_op_determinism().
# tf.debugging.experimental.enable_dump_debug_info(
#     "./logs/tfdbg2_logdir",
#     tensor_debug_mode="FULL_HEALTH",
#     circular_buffer_size=-1,
# )

### Load data

In [6]:
# Tab-separated table, one row per sequence with its annotations (preview below).
DATA_PATH = "./data/airrship_shm_seed42_100_000.tsv"
df_data = pd.read_csv(DATA_PATH, delimiter="\t")
df_data.head()

Unnamed: 0,sequence_id,sequence,productive,stop_codon,vj_in_frame,v_call,d_call,j_call,junction,junction_aa,...,d_sequence_start,d_sequence_end,j_sequence_start,j_sequence_end,shm_events,shm_count,shm_freq,unmutated_sequence,gapped_unmutated_sequence,gapped_mutated_sequence
0,0,CAGGTGCAGCTGCGGGAGTCGGGCCCAGGGCTGGTGAAGCCTTTGG...,T,T,F,IGHV4-61*08,IGHD3-3*02,IGHJ4*01,TGCGCGAGGCCGCCAGGTGTATCAGCATTTAGGAGGACACCCGCTT...,CARPPGVSAFRRTPAWDFDPW,...,307,318,338,382,"14:A>G,30:A>G,44:C>T,68:C>A,84:C>T,89:G>C,90:C...",28,0.073298,CAGGTGCAGCTGCAGGAGTCGGGCCCAGGACTGGTGAAGCCTTCGG...,CAGGTGCAGCTGCAGGAGTCGGGCCCA...GGACTGGTGAAGCCTT...,CAGGTGCAGCTGCGGGAGTCGGGCCCA...GGGCTGGTGAAGCCTT...
1,1,CAGGTCACTTTGAGGGAGTCTGGTCCTGCGCTGGTGAAACCCACAC...,T,T,F,IGHV2-70*19,IGHD2-8*02,IGHJ4*01,TGTGCACGGGGGCATGTCCACGATAGGGTCTTTCCGAGAGTTGACT...,CARGHVHDRVFPRVDFW,...,310,313,329,370,"9:C>T,72:C>T,82:T>C,85:C>A,88:A>T,100:A>G,105:...",20,0.054054,CAGGTCACCTTGAGGGAGTCTGGTCCTGCGCTGGTGAAACCCACAC...,CAGGTCACCTTGAGGGAGTCTGGTCCT...GCGCTGGTGAAACCCA...,CAGGTCACTTTGAGGGAGTCTGGTCCT...GCGCTGGTGAAACCCA...
2,2,GAGGTGCAGCTCCCGGAGTCTGGGGGCGGCCTGGTACAGCCTGGGG...,T,T,F,IGHV3-23*03,IGHD5-24*01,IGHJ4*01,TGTGCGAGAGACGGAAAAAAGAGACCCGACTGG,CARDGKKRPDW,...,305,309,314,349,"12:G>C,13:T>C,14:T>C,27:A>C,31:T>C,71:C>T,90:C...",34,0.097421,GAGGTGCAGCTGTTGGAGTCTGGGGGAGGCTTGGTACAGCCTGGGG...,GAGGTGCAGCTGTTGGAGTCTGGGGGA...GGCTTGGTACAGCCTG...,GAGGTGCAGCTCCCGGAGTCTGGGGGC...GGCCTGGTACAGCCTG...
3,3,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGACCAGCCTGGGA...,T,T,F,IGHV3-33*05,IGHD3/OR15-3a*01,IGHJ4*01,TGTGCGAGAGACAAAAATTTGGGACTGGCCGGGAACTTCTTTGACT...,CARDKNLGLAGNFFDYW,...,304,313,326,367,"35:T>A,72:G>T,92:G>C,98:G>A,132:G>A,151:A>T,15...",17,0.046322,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,CAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCGTGGTCCAGCCTG...,CAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCGTGGACCAGCCTG...
4,4,GAGGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTTCAGCCTGGGG...,T,T,F,IGHV3-74*02,IGHD1-26*01,IGHJ4*01,TGTGCAAGACAAGTGGGGGGCAATATCGACCACCTTTCGAAATACT...,CARQVGGNIDHLSKYYW,...,298,301,329,367,"89:G>C,93:C>T,97:T>C,119:C>T,138:G>A,147:A>C,1...",14,0.038147,GAGGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTTCAGCCTGGGG...,GAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCTTAGTTCAGCCTG...,GAGGTGCAGCTGGTGGAGTCTGGGGGA...GGCTTAGTTCAGCCTG...


In [7]:
# Boundary columns, in order: v start/end, d start/end, j start/end.
pos_names = [f"{segment}_sequence_{side}" for segment in "vdj" for side in ("start", "end")]
sequences = df_data["sequence"]
positions = df_data[pos_names]

# Sanity check: no missing sequences and no missing boundary positions.
display(sequences.notna().all(), positions.notna().all())

True

v_sequence_start    True
v_sequence_end      True
d_sequence_start    True
d_sequence_end      True
j_sequence_start    True
j_sequence_end      True
dtype: bool

In [8]:
ds_seq = Dataset.from_tensor_slices(sequences.to_numpy())
ds_pos = Dataset.from_tensor_slices(positions.to_numpy())
# Each element pairs a sequence string with its int64 vector of boundary positions.
ds_all = Dataset.zip((ds_seq, ds_pos))

first_example = ds_all.take(1).get_single_element()
n_examples = ds_all.cardinality().numpy()
display(first_example)
display(n_examples)

(<tf.Tensor: shape=(), dtype=string, numpy=b'CAGGTGCAGCTGCGGGAGTCGGGCCCAGGGCTGGTGAAGCCTTTGGAGACCCTGTCCCTCACCTGCAATGTCTCTGGTGGCTCTGTCACTAGTGGTGGTTACTACTGGAGTTGGGTCCGGCTGACCCCAGGGAAGGGACTGGACTGGATTGGTTTTCTTTATTACAGTGGGAGTACCAATTACAACCCCTCCCTCGAGACTCGAGTCACCATATCAGTAGACACGGCCAAGAACCAGTTCTCTCTGAAGGTGAGCTCTGTGACCGCTGCGGACACGGCCGTGTATTACTGCGCGAGGCCGCCAGGTGTATCAGCATTTAGGAGGACACCCGCTTGGGACTTTGACCCCTGGGGCCATGGAACCCTGGTCACCGTCTCCTCAG'>,
 <tf.Tensor: shape=(6,), dtype=int64, numpy=array([  1, 296, 307, 318, 338, 382])>)

100000

### Functions for encoding

In [9]:
def dna_onehot_tensor(seq):
    """One-hot encode a DNA string tensor over the alphabet A, C, G, T.

    Characters map to indices A=0, C=1, G=2, T=3; any other character maps to -1,
    which ``tf.one_hot`` turns into an all-zero row. For a scalar ``seq`` the
    result has shape ``(len(seq), 4)``.
    """
    bases = tf.constant(["A", "C", "G", "T"], dtype=tf.string)
    base_ids = tf.constant([0, 1, 2, 3])
    base_table = tf.lookup.StaticHashTable(
        initializer=tf.lookup.KeyValueTensorInitializer(keys=bases, values=base_ids),
        default_value=tf.constant(-1),
    )
    indices = base_table.lookup(tf.strings.bytes_split(seq))
    return tf.one_hot(indices, depth=4)

# test dna_onehot_tensor
# test_oh = ds_all.take(6).map(
#     lambda x, y: (dna_onehot_tensor(x), y)
# )
# for padded_batch in test_oh.padded_batch(3):
#     display(padded_batch)

In [10]:
def get_kmer_tensor(seq, k):
    """Rewrite a DNA string tensor as a space-separated "sentence" of overlapping k-mers.

    Example: ``"ACGTA"`` with ``k=3`` becomes ``"ACG CGT GTA"``.
    """
    kmers = tf.strings.ngrams(tf.strings.bytes_split(seq), k, separator="")
    return tf.strings.reduce_join(kmers, separator=" ")

# test get_kmer_tensor()
# test_kmer = ds_all.take(3).map(
#     lambda x, y: (get_kmer_tensor(x, 3), y)
# )
# display(test_kmer.batch(3).take(1).get_single_element())

In [11]:
def get_kmer_vocab(k):
    """Return all 4**k DNA k-mers over "ACGT", in lexicographic order."""
    return list(map("".join, product("ACGT", repeat=k)))

# test get_kmer_vocab()
# vocab_kmer = get_kmer_vocab(3)
# " ".join(vocab_kmer)

### Prepare datasets

In [12]:
# Split sizes in examples; together they cover all 100,000 rows.
TRAIN_SIZE = 99_000
VALID_SIZE = 1_000
TEST_SIZE = 0

In [13]:
# Sequential (unshuffled) split: first TRAIN_SIZE rows, next VALID_SIZE rows, remainder.
ds_train, ds_not_train = ds_all.take(TRAIN_SIZE), ds_all.skip(TRAIN_SIZE)
ds_valid, ds_test = ds_not_train.take(VALID_SIZE), ds_not_train.skip(VALID_SIZE)

for split, expected_size in ((ds_train, TRAIN_SIZE), (ds_valid, VALID_SIZE), (ds_test, TEST_SIZE)):
    assert split.cardinality().numpy() == expected_size

In [14]:
def transform_ds(ds, method, scale_factor, batch_size=None, shuffle_buffer=None, shuffle_seed=None):
    """Encode a ``(sequence, positions)`` dataset and optionally shuffle and batch it.

    Args:
        ds: dataset of ``(sequence string, position vector)`` pairs.
        method: ``"onehot"`` for one-hot encoding, or ``"<k>mer"`` (e.g. ``"3mer"``,
            ``"10mer"``) for a space-separated k-mer sentence.
        scale_factor: positions are divided by this value.
        batch_size: if falsy, the (possibly shuffled) unbatched dataset is returned.
            One-hot batches are padded to the longest sequence in the batch.
        shuffle_buffer: shuffle buffer size; must be given together with
            ``shuffle_seed`` (or both omitted).
        shuffle_seed: seed for the shuffle.

    Returns:
        The transformed ``tf.data.Dataset``.
    """
    assert method == "onehot" or method.endswith("mer"), f"unknown method: {method!r}"
    # shuffle_buffer and shuffle_seed must be given together or not at all.
    assert (shuffle_buffer is None) == (shuffle_seed is None)
    if method == "onehot":
        ds = ds.map(lambda x, y: (dna_onehot_tensor(x), y / scale_factor))
    else:
        # Parse the whole numeric prefix so multi-digit k (e.g. "10mer") works.
        k = int(method[:-len("mer")])
        ds = ds.map(lambda x, y: (get_kmer_tensor(x, k), y / scale_factor))
    # Shuffle before the unbatched early return so the shuffle arguments are
    # honoured even when no batching is requested.
    if shuffle_buffer:
        ds = ds.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=True)
    if not batch_size:
        return ds
    return ds.padded_batch(batch_size) if method == "onehot" else ds.batch(batch_size)
                    
# # test transform_ds()
# for batch_seq, batch_pos in transform_ds(ds_train.take(6), "onehot", 450, 3, 5, 1):
#     display(batch_seq.shape)
#     display(batch_pos.shape)   
# display(batch_seq)
# display(batch_pos)
    
# for batch_seq, batch_pos in transform_ds(ds_train.take(6), "3mer", 450, 3, 5, 1):
#     display(batch_seq.shape)
#     display(batch_pos.shape)
# display(batch_seq)
# display(batch_pos)

# batched = transform_ds(ds_train, "3mer", 450, 32, 128, SEED)
# display(batched.cardinality().numpy())
# prefetched = batched.prefetch(8)
# display(prefetched.cardinality().numpy())

### Functions and classes for building the model

In [15]:
# Custom metrics
class ScaledRMSE(keras.metrics.RootMeanSquaredError):
    """Root-mean-squared error measured in un-scaled position units.

    Both inputs are multiplied by ``scale_factor``. The targets are also rounded
    to the nearest integer (presumably to undo float error from dividing the
    positions by ``scale_factor``); the predictions stay continuous.
    """

    def __init__(self, scale_factor, name="s_rmse", dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        true_pos = tf.math.rint(self.scale_factor * tf.cast(y_true, self._dtype))
        pred_pos = self.scale_factor * tf.cast(y_pred, self._dtype)
        super().update_state(true_pos, pred_pos, sample_weight)
        
class ScaledMAE(keras.metrics.MeanAbsoluteError):
    """Mean absolute error measured in un-scaled position units.

    Targets are rescaled and rounded to integers; predictions are only rescaled.
    """

    def __init__(self, scale_factor, name="s_mae", dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        factor = self.scale_factor
        rescaled_true = tf.cast(y_true, self._dtype) * factor
        rescaled_pred = tf.cast(y_pred, self._dtype) * factor
        super().update_state(tf.math.rint(rescaled_true), rescaled_pred, sample_weight)
        
class PosAccuracy(keras.metrics.Accuracy):
    """Fraction of individual positions predicted exactly.

    Targets and predictions are both rescaled by ``scale_factor`` and rounded to
    the nearest integer before the element-wise equality check.
    """

    def __init__(self, scale_factor, name="p_acc", dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        def to_positions(values):
            return tf.math.rint(tf.cast(values, self._dtype) * self.scale_factor)

        super().update_state(to_positions(y_true), to_positions(y_pred), sample_weight)
        
class AllAccuracy(keras.metrics.Accuracy):
    """Fraction of samples whose positions are *all* predicted exactly.

    Targets and predictions are rescaled by ``scale_factor`` and rounded to
    integers. A sample counts as correct only if every entry along its last
    axis matches.
    """

    def __init__(self, scale_factor, name="a_acc", dtype=None):
        super().__init__(name, dtype=dtype)
        self.scale_factor = scale_factor

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        y_pred = tf.math.rint(tf.cast(y_pred, self._dtype) * self.scale_factor)
        all_correct = tf.reduce_all(tf.equal(y_true, y_pred), axis=-1)
        # 1.0 for a fully correct sample, 0.0 otherwise; shape (batch, 1).
        correct = tf.reshape(tf.cast(all_correct, self._dtype), (-1, 1))
        # Compare against an all-ones target of exactly the same shape and dtype.
        # (tf.ones(n, 1) would treat the 1 as a dtype, not a dimension.)
        super().update_state(tf.ones_like(correct), correct, sample_weight)
        
class ClassAccuracy(keras.metrics.Metric):
    """Approximate per-nucleotide accuracy derived from boundary errors.

    After rescaling by ``scale_factor`` and rounding, the summed absolute
    boundary error over all columns is treated as the number of misassigned
    nucleotides. Column 5 of ``y_true`` (``j_sequence_end``) is used as the
    per-sequence length (presumably a proxy for the annotated length).
    The metric is ``(sum of lengths - total error) / sum of lengths`` over all
    updates. ``sample_weight`` is accepted for API compatibility but ignored.
    """

    def __init__(self, scale_factor, name="c_acc", **kwargs):
        super().__init__(name=name, **kwargs)
        self.scale_factor = scale_factor
        self.len_sum = self.add_weight(name="len_sum", initializer="zeros")
        self.correct_sum = self.add_weight(name="correct_sum", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.math.rint(tf.cast(y_true, self._dtype) * self.scale_factor)
        y_pred = tf.math.rint(tf.cast(y_pred, self._dtype) * self.scale_factor)
        incorrect_count = tf.reduce_sum(tf.abs(y_true - y_pred))
        batch_len_sum = tf.reduce_sum(tf.gather(y_true, indices=5, axis=-1))
        self.len_sum.assign_add(batch_len_sum)
        self.correct_sum.assign_add(batch_len_sum - incorrect_count)

    def result(self):
        # divide_no_nan returns 0 instead of NaN before any update has been made.
        return tf.math.divide_no_nan(self.correct_sum, self.len_sum)

    def reset_state(self):
        self.len_sum.assign(0.0)
        self.correct_sum.assign(0.0)
        
class SegmentPosAccuracy(PosAccuracy):
    """``PosAccuracy`` restricted to the start/end columns of one segment (V, D or J)."""

    def __init__(self, scale_factor, segment, name="p_acc", dtype=None):
        assert segment in ("V", "D", "J")
        super().__init__(scale_factor, name=f"{segment}_{name}", dtype=dtype)
        self.segment = segment

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Columns are ordered V start/end, D start/end, J start/end.
        start_col = 2 * "VDJ".index(self.segment)
        columns = (start_col, start_col + 1)
        super().update_state(
            tf.gather(y_true, indices=columns, axis=-1),
            tf.gather(y_pred, indices=columns, axis=-1),
            sample_weight,
        )
        
class SegmentRMSE(ScaledRMSE):
    """``ScaledRMSE`` over only the start/end columns of one segment (V, D or J)."""

    def __init__(self, scale_factor, segment, name="s_rmse", dtype=None):
        assert segment in ("V", "D", "J")
        super().__init__(scale_factor, name=f"{segment}_{name}", dtype=dtype)
        self.segment = segment

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Segment i owns columns (2i, 2i + 1): V -> (0, 1), D -> (2, 3), J -> (4, 5).
        first = ("V", "D", "J").index(self.segment) * 2
        seg_cols = (first, first + 1)
        true_seg = tf.gather(y_true, indices=seg_cols, axis=-1)
        pred_seg = tf.gather(y_pred, indices=seg_cols, axis=-1)
        super().update_state(true_seg, pred_seg, sample_weight)
        
class SegmentMAE(ScaledMAE):
    """``ScaledMAE`` over only the start/end columns of one segment (V, D or J)."""

    def __init__(self, scale_factor, segment, name="s_mae", dtype=None):
        assert segment in ("V", "D", "J")
        super().__init__(scale_factor, name=f"{segment}_{name}", dtype=dtype)
        self.segment = segment

    def update_state(self, y_true, y_pred, sample_weight=None):
        offset = "VDJ".index(self.segment) * 2
        pair = (offset, offset + 1)  # start/end columns of this segment
        super().update_state(
            tf.gather(y_true, indices=pair, axis=-1),
            tf.gather(y_pred, indices=pair, axis=-1),
            sample_weight,
        )

### Evaluate

In [37]:
def prepare_ds(data_path, method="3mer", scale_factor=450, batch_size=512):
    """Load a TSV of annotated sequences and turn it into a batched model input.

    Args:
        data_path: path to a tab-separated file with a ``sequence`` column and
            the six ``{v,d,j}_sequence_{start,end}`` position columns.
        method: encoding passed to ``transform_ds`` (default ``"3mer"``).
        scale_factor: positions are divided by this value (default 450).
        batch_size: batch size (default 512).

    Returns:
        A batched ``tf.data.Dataset`` of ``(encoded sequence, scaled positions)``.
    """
    df_test = pd.read_csv(data_path, sep="\t")
    pos_names = [f"{seg}_sequence_{pos}" for seg in "vdj" for pos in ("start", "end")]
    ds_seq = Dataset.from_tensor_slices(df_test["sequence"].to_numpy())
    ds_pos = Dataset.from_tensor_slices(df_test[pos_names].to_numpy())
    ds_all = Dataset.zip((ds_seq, ds_pos))
    return transform_ds(ds_all, method, scale_factor=scale_factor, batch_size=batch_size)

In [75]:
def calculate_r2(ds, model, scale_factor=450, batch_size=512):
    """R^2 of ``model`` on ``ds``, computed in un-scaled position units.

    ``ds`` is iterated exactly once, so inputs and targets stay aligned even if
    the dataset reshuffles on every iteration. Input batches are concatenated
    along the first axis.

    Args:
        ds: batched dataset yielding ``(inputs, scaled positions)`` pairs.
        model: object with a Keras-style ``predict(x, batch_size=...)``.
        scale_factor: factor that un-scales targets and predictions (default 450).
        batch_size: batch size used for prediction (default 512).

    Returns:
        ``sklearn.metrics.r2_score`` of rounded targets vs. rescaled predictions
        (uniformly averaged over output columns).
    """
    x_batches, y_batches = [], []
    for x_batch, y_batch in ds:
        x_batches.append(np.asarray(x_batch))
        y_batches.append(np.asarray(y_batch))
    x = np.concatenate(x_batches)
    y_true = np.rint(np.concatenate(y_batches) * scale_factor)
    y_pred = model.predict(x, batch_size=batch_size) * scale_factor
    return sk_metrics.r2_score(y_true, y_pred)

In [16]:
# Load the saved model without its training configuration; it is compiled
# again below with the evaluation metrics.
best_model = keras.models.load_model(
    "./logs/final_training/top_1/model/best_V",
    compile=False,
)



In [17]:
# Positions were divided by this factor for training; metrics undo it.
scale_factor = 450
metric_list = [
    ScaledRMSE(scale_factor),
    ScaledMAE(scale_factor),
    PosAccuracy(scale_factor),
    AllAccuracy(scale_factor),
    ClassAccuracy(scale_factor),
    *(SegmentRMSE(scale_factor, segment=seg) for seg in "VDJ"),
    *(SegmentMAE(scale_factor, segment=seg) for seg in "VDJ"),
]
best_model.compile(metrics=metric_list)

In [66]:
# Training split: Keras metrics, then R^2 in position units.
ds_train_input = transform_ds(ds_train, method="3mer", scale_factor=scale_factor, batch_size=512)
result = best_model.evaluate(ds_train_input)
calculate_r2(ds_train_input, best_model)



0.8044098275618138

In [67]:
# Validation split.
ds_valid_input = transform_ds(ds_valid, method="3mer", scale_factor=scale_factor, batch_size=512)
result = best_model.evaluate(ds_valid_input)
calculate_r2(ds_valid_input, best_model)



0.7743021615918613

In [68]:
# Held-out file generated with seed 24 (the training file uses seed 42).
ds_airrship = prepare_ds(data_path="./data/airrship_shm_seed24_100_000.tsv")
result = best_model.evaluate(ds_airrship)
calculate_r2(ds_airrship, best_model)



0.7597032598433578

In [69]:
# `implant_` dataset (per its file name, presumably a different simulator).
ds_implant = prepare_ds(data_path="./data/implant_shm_seed24_100_000.tsv")
result = best_model.evaluate(ds_implant)
calculate_r2(ds_implant, best_model)



0.7372856453095048

In [71]:
# Sequences with randomised ends (per the file name).
ds_rand_end = prepare_ds(data_path="./data/airrship_shm_seed24_100_000_random_ends.tsv")
result = best_model.evaluate(ds_rand_end)
calculate_r2(ds_rand_end, best_model)



-11.937032459795487

In [72]:
# Same seed, without the `shm` variant (per the file name).
ds_no_shm = prepare_ds(data_path="./data/airrship_seed24_100_000.tsv")
result = best_model.evaluate(ds_no_shm)
calculate_r2(ds_no_shm, best_model)



0.7830434386833991

### Compare with IgBLAST

In [102]:
# IgBLAST output and ground truth are both raw positions read from the TSVs,
# so the metrics use a scale factor of 1 (no rescaling).
# NOTE: this rebinds the notebook-global `scale_factor`; re-running the model
# evaluation cells after this one would then use 1 instead of 450.
scale_factor = 1
metric_list = [
    ScaledRMSE(scale_factor),
    ScaledMAE(scale_factor),
    PosAccuracy(scale_factor),
    AllAccuracy(scale_factor),
    ClassAccuracy(scale_factor),
    SegmentRMSE(scale_factor, segment="V"),
    SegmentRMSE(scale_factor, segment="D"),
    SegmentRMSE(scale_factor, segment="J"),
    SegmentMAE(scale_factor, segment="V"),
    SegmentMAE(scale_factor, segment="D"),
    SegmentMAE(scale_factor, segment="J")
]

data_list = [
    "airrship_seed24_100_000.tsv",
    "airrship_shm_seed24_100_000.tsv",
    "airrship_shm_seed24_100_000_random_ends.tsv",
    "implant_shm_seed24_100_000.tsv"
]

for data_path in data_list:
    print(data_path)
    df_true = pd.read_csv("./data/" + data_path, sep="\t")
    df_pred = pd.read_csv("./data/igblast_" + data_path, sep="\t")
    # Rows must line up one-to-one between the ground truth and the IgBLAST output.
    assert (df_true["sequence"] == df_pred["sequence"]).all()
    df_true = df_true[pos_names]
    df_pred = df_pred[pos_names]

    # Print per-column counts of missing IgBLAST boundaries, then score only the
    # rows where every boundary was predicted. Because rows are dropped, these
    # numbers are not computed on exactly the same rows as the model evaluation.
    pred_nan = df_pred.isna()
    if pred_nan.any().any():
        print(pred_nan.sum())
        df_pred = df_pred.dropna()
    df_true = df_true.loc[df_pred.index]

    y_true = df_true.to_numpy()
    y_pred = df_pred.to_numpy()
    for metric in metric_list:
        # reset_state() is the current Keras API; reset_states() is a deprecated alias.
        metric.reset_state()
        metric.update_state(y_true, y_pred)
        result = metric.result()
        print(f"{metric.name}: {result}")
    print(f"r2: {sk_metrics.r2_score(y_true, y_pred)}")
    print()

airrship_seed24_100_000.tsv
v_sequence_start       0
v_sequence_end         0
d_sequence_start    2397
d_sequence_end      2397
j_sequence_start       0
j_sequence_end         0
dtype: int64
s_rmse: 1.780578851699829
s_mae: 0.4715617001056671
p_acc: 0.8244674801826477
a_acc: 0.3391494154930115
c_acc: 0.9923363924026489
V_s_rmse: 0.3858412802219391
D_s_rmse: 3.0084450244903564
J_s_rmse: 0.5583619475364685
V_s_mae: 0.07760519534349442
D_s_mae: 1.192437767982483
J_s_mae: 0.14464206993579865
r2: 0.9407354397961037

airrship_shm_seed24_100_000.tsv
v_sequence_start       0
v_sequence_end         0
d_sequence_start    2353
d_sequence_end      2353
j_sequence_start       3
j_sequence_end         3
dtype: int64
s_rmse: 2.4425885677337646
s_mae: 0.795991837978363
p_acc: 0.7648106217384338
a_acc: 0.21679775416851044
c_acc: 0.9870631694793701
V_s_rmse: 0.7011434435844421
D_s_rmse: 3.8311054706573486
J_s_rmse: 1.652194857597351
V_s_mae: 0.13926099240779877
D_s_mae: 1.7772828340530396
J_s_mae: 0.471