In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere below /kaggle/input.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install polars 
!pip install einops

In [None]:
import gc
import os
import pickle
import random
import joblib
import math

import numpy as np
import polars as pl
import pandas as pd
import pyarrow

from tqdm import tqdm

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import average_precision_score as APS

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers , Model

from einops import rearrange, repeat

In [None]:
class CFG:
    """Run configuration read by the rest of the notebook.

    When ``DEBUG`` is true, the epoch count, fold count and row limit take
    their small debug values; otherwise the full-run values are used.
    """
    DEBUG = True
    PREPROCESS = False

    # Optimisation settings.
    EPOCHS = 3 if DEBUG else 8
    BATCH_SIZE = 512  # same value in debug and full runs
    LR = 1e-3
    WD = 0.05

    # Row limit for loading the training data (None means no limit).
    N_ROWS = 1_000_000 if DEBUG else None

    # Cross-validation layout.
    NBR_FOLDS = 5 if DEBUG else 15
    SELECTED_FOLDS = [0]

    # Set to True by the hardware-detection cell when a TPU is found.
    ON_TPU = False

    # Model widths.
    CHANNELS = 128
    EMB_SIZE = 64

    SEED = 2024


In [None]:
def set_seeds(seed):
    """Seed Python's `random`, TensorFlow and NumPy with the same value.

    Also exports PYTHONHASHSEED. Note that Python reads that variable at
    interpreter start-up, so setting it here only affects child processes.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, tf.random.set_seed, np.random.seed):
        seeder(seed)


set_seeds(CFG.SEED)

In [None]:
# Detect hardware, return appropriate distribution strategy
# On success this cell defines the globals `tpu` and `strategy` and flips
# CFG.ON_TPU to True. On the NotFoundError path nothing is assigned, so
# `strategy` does not exist afterwards; later cells must check CFG.ON_TPU
# before touching it.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
    tf.config.experimental_connect_to_cluster(tpu)
    # This is the TPU initialization code that has to be at the beginning.
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
    CFG.ON_TPU = True
    print(f"Running on TPU, ON_TPU ={CFG.ON_TPU}")
    print("REPLICAS: ", strategy.num_replicas_in_sync)
# NOTE(review): only NotFoundError is handled; any other exception raised
# during TPU setup propagates and stops the cell. Confirm that is intended.
except tf.errors.NotFoundError:
    print("Not on TPU")

# Preprocessing

In [None]:
# Produces the module-level frames `train` and `test` as polars DataFrames.
#   * PREPROCESS=True : encode the raw SMILES strings, write the encoded
#     parquet files to the working directory, then convert to polars.
#   * PREPROCESS=False: read the already-encoded parquet files.
if CFG.PREPROCESS:
    # SMILES character -> integer id. Id 0 is absent from the table because
    # it is used as the padding value in `encode_smile`.
    enc = {'l': 1, 'y': 2, '@': 3, '3': 4, 'H': 5, 'S': 6, 'F': 7, 'C': 8, 'r': 9, 's': 10, '/': 11, 'c': 12, 'o': 13,
           '+': 14, 'I': 15, '5': 16, '(': 17, '2': 18, ')': 19, '9': 20, 'i': 21, '#': 22, '6': 23, '8': 24, '4': 25, '=': 26,
           '1': 27, 'O': 28, '[': 29, 'D': 30, 'B': 31, ']': 32, 'N': 33, '7': 34, 'n': 35, '-': 36}
    train_raw = pd.read_parquet('/kaggle/input/leash-BELKA/train.parquet')

    # One boolean row mask per protein, computed once and reused for both the
    # SMILES consistency checks and the label columns.
    protein_masks = {
        name: (train_raw['protein_name'] == name).values
        for name in ('BRD4', 'HSA', 'sEH')
    }
    smiles = train_raw.loc[protein_masks['BRD4'], 'molecule_smiles'].values
    # The three per-protein slices must list the same molecules in the same
    # order, otherwise the bind1/bind2/bind3 columns below would be misaligned.
    assert (smiles != train_raw.loc[protein_masks['HSA'], 'molecule_smiles'].values).sum() == 0
    assert (smiles != train_raw.loc[protein_masks['sEH'], 'molecule_smiles'].values).sum() == 0

    def encode_smile(smile):
        """Map each character of `smile` to its id and right-pad with 0 to length 142.

        Raises KeyError for a character missing from `enc`. Strings longer
        than 142 characters are returned unpadded and untruncated.
        """
        tmp = [enc[i] for i in smile]
        tmp = tmp + [0]*(142-len(tmp))
        return np.array(tmp).astype(np.uint8)

    smiles_enc = joblib.Parallel(n_jobs=96)(joblib.delayed(encode_smile)(smile) for smile in tqdm(smiles))
    smiles_enc = np.stack(smiles_enc)
    train = pd.DataFrame(smiles_enc, columns = [f'enc{i}' for i in range(142)])
    train['bind1'] = train_raw.loc[protein_masks['BRD4'], 'binds'].values
    train['bind2'] = train_raw.loc[protein_masks['HSA'], 'binds'].values
    train['bind3'] = train_raw.loc[protein_masks['sEH'], 'binds'].values
    train.to_parquet('train_enc.parquet')

    test_raw = pd.read_parquet('/kaggle/input/leash-BELKA/test.parquet')
    smiles = test_raw['molecule_smiles'].values

    smiles_enc = joblib.Parallel(n_jobs=96)(joblib.delayed(encode_smile)(smile) for smile in tqdm(smiles))
    smiles_enc = np.stack(smiles_enc)
    test = pd.DataFrame(smiles_enc, columns = [f'enc{i}' for i in range(142)])
    test.to_parquet('test_enc.parquet')

    # Later cells use polars-only operations (`train[rows, cols]`,
    # `sum_horizontal`), so hand them polars frames in this branch as well.
    train = pl.from_pandas(train)
    test = pl.from_pandas(test)

else:
    # Only the training frame is truncated in DEBUG mode. The test frame is
    # always read in full because the submission cell indexes the predictions
    # with masks built from the complete test.parquet; truncated predictions
    # would not line up with it.
    train_rows = CFG.N_ROWS if CFG.DEBUG else None
    train = pl.read_parquet('/kaggle/input/belka-enc-dataset/train_enc.parquet', n_rows=train_rows)
    test = pl.read_parquet('/kaggle/input/belka-enc-dataset/test_enc.parquet')

# Make Dataset

In [None]:
# Column names: 142 `enc*` feature columns and the three `bind*` targets.
FEATURES = ['enc' + str(i) for i in range(142)]
TARGETS = ['bind1', 'bind2', 'bind3']

skf = StratifiedKFold(n_splits=CFG.NBR_FOLDS, shuffle=True, random_state=42)

# Stratify on the row-wise sum of the three targets and keep only the first
# split produced by the splitter.
strata = train[TARGETS].sum_horizontal()
splits = skf.split(np.arange(len(train)), strata)
train_idx, valid_idx = next(iter(splits))

X_train, y_train = (train[train_idx, cols].to_numpy() for cols in (FEATURES, TARGETS))
X_val, y_val = (train[valid_idx, cols].to_numpy() for cols in (FEATURES, TARGETS))
print('data loaded')


# Make Model

In [None]:
class ModelArgs:
    """Hyper-parameters consumed by the layer classes and the model builders.

    This is a plain class used as a namespace of class attributes; callers
    pass the class itself (``make_model(ModelArgs)``), not an instance.
    """
    # Width of the embedding / residual stream.
    model_input_dim: int = CFG.EMB_SIZE
    # State size n of the state-space model (second axis of A, last axis of B and C).
    model_states: int = 64
    # Width inside MambaLayer, between the input and output projections.
    model_internal_dim: int = CFG.CHANNELS
    conv_kernel_size: int = 3
    # Width of the low-rank slice that MambaLayer expands into delta.
    delta_t_rank: int = math.ceil(model_input_dim / 16)
    # NOTE(review): the four delta_t_* values below, `layer_id` and
    # `use_lm_head` do not appear to be read by any layer visible in this
    # notebook; confirm before relying on them.
    delta_t_min: float = 0.001
    delta_t_max: float = 0.1
    delta_t_scale: float = 0.1
    delta_t_init_floor: float = 1e-4
    layer_id: int = -1
    # Length every encoded sequence is padded to in the preprocessing cell.
    seq_length: int = 142
    # Number of ResidualLayer + Dropout pairs stacked by the model builders.
    num_layers: int = 4
    dropout_rate: float = 0.2
    use_lm_head: bool = False
    num_classes: int = 3
    # Token ids run 0..36: the preprocessing `enc` table assigns 1..36 and 0 is padding.
    vocab_size: int = 37
    final_activation: str = 'sigmoid'


In [None]:
class RMSNorm(layers.Layer):
    """Root-mean-square normalisation over the last axis with a learned scale.

    Computes ``x * rsqrt(mean(x**2, axis=-1) + eps) * weight``.

    Args:
        d_model: size of the last axis of the inputs (length of `weight`).
        eps: constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """
    def __init__(self, d_model: int, eps: float=1e-5, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.eps = eps
        # Per-feature gain, initialised to ones.
        # NOTE(review): whether Keras tracks a raw tf.Variable attribute as a
        # trainable weight depends on the Keras version; `self.add_weight` is
        # the version-independent way. Confirm against the installed version.
        self.weight = tf.Variable(np.ones(d_model), dtype=tf.float32, trainable=True)

    def call(self, x):
        # Keep the input intact: the mean square is only the normaliser. The
        # previous code overwrote `x` with it and so scaled the statistic
        # instead of the activations.
        mean_sq = tf.math.reduce_mean(tf.math.pow(x, 2), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(mean_sq + self.eps) * self.weight

In [None]:
def selective_scan(u, delta, A, B, C, D):
    """Evaluate a linear state recurrence along the sequence axis with cumulative sums.

    Writing dA_t = delta_t * A and dBu_t = delta_t * u_t * B_t, the code computes

        x_t = (sum_{s <= t} dBu_s * P_s) / P_t,   P_t = exp(sum_{r > t} dA_r)

    which is algebraically the closed form of x_t = exp(dA_t) * x_{t-1} + dBu_t
    starting from zero, and returns ``sum_n(x_t * C_t) + u * D``.

    Shapes, as fixed by the einsum subscripts (b=batch, l=sequence,
    d=channels, n=state):
        u, delta: (b, l, d);  A: (d, n);  B, C: (b, l, n).
        D is multiplied element-wise with u, so it must broadcast against (b, l, d).

    Returns:
        Tensor of shape (b, l, d).
    """
    # Per-step log-decay and input contribution, both (b, l, d, n).
    dA = tf.einsum('bld,dn->bldn', delta, A)
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)

    # Shift dA one step left along the sequence axis and append a zero, so
    # position t holds dA_{t+1} and the last position holds 0.
    dA_cumsum = tf.pad(dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]
    # reverse -> cumsum -> reverse is a suffix sum: position t becomes
    # sum_{r > t} dA_r. exp is element-wise, so applying it between the two
    # reversals yields P_t.
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)
    dA_cumsum = tf.exp(dA_cumsum)
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])

    # Prefix-sum the P-weighted inputs, then divide P_t back out.
    x = dB_u * dA_cumsum
    # The 1e-12 guards the division against a zero P_t.
    # NOTE(review): if dA is strongly negative, P_t can underflow for early
    # positions and the quotient loses accuracy; worth checking on real inputs.
    x = tf.math.cumsum(x, axis=1) / (dA_cumsum + 1e-12)

    # Contract the state axis with C to get (b, l, d).
    y = tf.einsum('bldn,bln->bld', x, C)

    # Skip term: the input scaled by D.
    return y + u * D

In [None]:
class MambaLayer(layers.Layer):
    """Gated block: projection -> grouped conv -> state-space scan -> gate -> projection.

    `call` maps a tensor whose last axis is ``model_input_dim`` to a tensor
    with the same last-axis size:

      1. `in_projection` widens to ``2 * model_internal_dim`` and the result
         is split into a main branch and a gate branch.
      2. The main branch goes through a grouped Conv1D (one group per
         channel, silu activation) and then `ssm`.
      3. The scan output is multiplied by ``swish(gate)`` and projected back
         to ``model_input_dim``.

    Args:
        model_args: a `ModelArgs`-style object providing the sizes used below.
        **kwargs: standard Keras layer arguments (e.g. ``name``), forwarded
            to the base class.
    """
    def __init__(self, model_args: ModelArgs, *args, **kwargs):
        # Forward keyword arguments so `name=` and friends take effect; they
        # used to be discarded. Extra positional arguments stay accepted for
        # signature compatibility but are not forwarded.
        super().__init__(**kwargs)
        self.args = model_args

        # Dense layers infer their input size on first call, so no
        # `input_shape` is passed to them.
        self.in_projection = layers.Dense(
            self.args.model_internal_dim * 2,
            use_bias=False,
        )
        # groups == filters == number of input channels, so every channel is
        # convolved on its own.
        # NOTE(review): 'same' padding lets a position see later tokens as
        # well as earlier ones; confirm a non-causal conv is intended here.
        self.conv1D = layers.Conv1D(
            filters=self.args.model_internal_dim,
            kernel_size=self.args.conv_kernel_size,
            padding='same',
            use_bias=True,
            groups=self.args.model_internal_dim,
            data_format='channels_last',
            activation='silu',
        )

        # A_log holds log(1..model_states) repeated for every internal
        # channel, shape (model_internal_dim, model_states). It is created
        # non-trainable.
        A = repeat(
            tf.range(1, self.args.model_states + 1, dtype=tf.float32),
            'n -> d n', d=self.args.model_internal_dim)
        self.A_log = tf.Variable(tf.math.log(A), trainable=False, dtype=tf.float32)

        # Per-channel scale of the skip term `u * D` in selective_scan.
        # NOTE(review): whether Keras tracks raw tf.Variable attributes as
        # layer weights depends on the Keras version; confirm that D is
        # trained and saved under the installed version.
        self.D = tf.Variable(tf.ones(self.args.model_internal_dim), dtype=tf.float32)

        # Produces, per position, the low-rank delta input plus B and C.
        self.x_projection = layers.Dense(
            self.args.delta_t_rank + self.args.model_states * 2,
            use_bias=False,
        )

        # Expands the low-rank delta input to one value per internal channel.
        self.delta_projection = layers.Dense(
            self.args.model_internal_dim,
            use_bias=True,
        )

        self.out_projection = layers.Dense(
            self.args.model_input_dim,
            use_bias=False,
        )

    def call(self, x):
        """Apply the block to `x` and return a tensor with last axis ``model_input_dim``."""
        x_res = self.in_projection(x)
        x, res = tf.split(
            x_res,
            [self.args.model_internal_dim, self.args.model_internal_dim],
            axis=-1,
        )

        x = self.conv1D(x)
        y = self.ssm(x)
        # Gate the scan output with the second half of the input projection.
        y = y * tf.nn.swish(res)
        return self.out_projection(y)

    def ssm(self, x):
        """Derive delta, B and C from `x` and run `selective_scan` on it."""
        n = self.A_log.shape[-1]

        # A = -exp(A_log) is strictly negative.
        A = -tf.exp(tf.cast(self.A_log, tf.float32))
        D = tf.cast(self.D, tf.float32)

        x_dbl = self.x_projection(x)
        delta, B, C = tf.split(
            x_dbl,
            [self.args.delta_t_rank, n, n],
            axis=-1,
        )

        # softplus keeps delta positive.
        delta = tf.nn.softplus(self.delta_projection(delta))

        return selective_scan(x, delta, A, B, C, D)

In [None]:
class DemoLayer(layers.Layer):
    """Grouped Conv1D with silu activation that keeps the last-axis size.

    No active code visible in this notebook instantiates it.

    Args:
        model_args: a `ModelArgs`-style object; `model_input_dim` and
            `conv_kernel_size` are read.
        **kwargs: standard Keras layer arguments, forwarded to the base class.
    """
    def __init__(self, model_args, *args, **kwargs):
        # Forward keyword arguments so `name=` etc. are honoured.
        super().__init__(**kwargs)
        self.args = model_args
        # `filters` must be divisible by `groups`. The previous
        # groups=model_internal_dim (128 by default) with
        # filters=model_input_dim (64) is rejected by Conv1D at construction.
        # Using the input width for both gives one group per channel.
        # NOTE(review): assumes the layer is applied to tensors whose last
        # axis is `model_input_dim`; confirm against the intended call site.
        self.conv1D = layers.Conv1D(
            filters=self.args.model_input_dim,
            kernel_size=self.args.conv_kernel_size,
            padding='same',
            use_bias=True,
            groups=self.args.model_input_dim,
            data_format='channels_last',
            activation='silu',
        )

    def call(self, x):
        """Return the convolved input."""
        return self.conv1D(x)

In [None]:
class ResidualLayer(layers.Layer):
    """Pre-norm residual wrapper: ``x + MambaLayer(LayerNormalization(x))``.

    Args:
        model_args: a `ModelArgs`-style object passed on to `MambaLayer`.
        **kwargs: standard Keras layer arguments (e.g. ``name``), forwarded
            to the base class.
    """
    def __init__(self, model_args: ModelArgs, *args, **kwargs):
        # Forward keyword arguments so `name=` etc. are honoured; they used
        # to be discarded.
        super().__init__(**kwargs)
        self.args = model_args
        self.norm = layers.LayerNormalization(epsilon=1e-5)
        self.mamba = MambaLayer(self.args)

    def call(self, x):
        """Normalise, apply the Mamba block, then add the untouched input back."""
        return self.mamba(self.norm(x)) + x

In [None]:
def make_model(args: ModelArgs):
    """Build and compile the sequence classifier.

    Architecture: token embedding -> `num_layers` x (ResidualLayer + Dropout)
    -> LayerNormalization -> global max pool over the sequence axis -> three
    Dense+Dropout blocks (1024, 512, 512 units) -> output Dense.

    Args:
        args: a `ModelArgs`-style object. The output width and activation
            come from `args.num_classes` and `args.final_activation`
            (3 and 'sigmoid' in `ModelArgs`, the values previously hard-coded).

    Returns:
        A compiled Keras model using binary cross-entropy, Adam with
        `CFG.LR` / `CFG.WD`, and a PR-curve AUC metric named 'avg_precision'.
    """
    inputs = layers.Input(shape=(args.seq_length,), dtype='int32')
    # NOTE(review): `mask_zero=True` attaches a padding mask, but the custom
    # layers below do not declare masking support, so the mask is probably
    # not used downstream; confirm.
    x = layers.Embedding(
        input_dim=args.vocab_size,
        output_dim=args.model_input_dim,
        input_length=args.seq_length,
        mask_zero=True,
    )(inputs)
    for _ in range(args.num_layers):
        x = ResidualLayer(args)(x)
        x = layers.Dropout(args.dropout_rate)(x)

    x = layers.LayerNormalization(epsilon=1e-5)(x)

    # Collapse the sequence axis by taking the maximum of every channel.
    x = layers.GlobalMaxPooling1D()(x)

    # Classification head.
    for units in (1024, 512, 512):
        x = layers.Dense(units, activation='relu')(x)
        x = layers.Dropout(0.1)(x)

    outputs = tf.keras.layers.Dense(args.num_classes, activation=args.final_activation)(x)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

    model.compile(
        loss='binary_crossentropy',
        optimizer=tf.keras.optimizers.Adam(learning_rate=CFG.LR, weight_decay=CFG.WD),
        weighted_metrics=[tf.keras.metrics.AUC(curve='PR', name='avg_precision')],
    )
    return model

In [None]:
def make_tpu_model(args: ModelArgs):
    """Build and compile the same model as `make_model`, under the TPU strategy scope.

    The body used to be a line-for-line copy of `make_model`. Delegating
    keeps the two builders from drifting apart; every variable, the
    optimizer and the metric are still created inside `strategy.scope()`.

    Requires the module-level `strategy`, which the hardware-detection cell
    assigns only when a TPU is found.

    Args:
        args: a `ModelArgs`-style object, passed through to `make_model`.

    Returns:
        The compiled Keras model.
    """
    with strategy.scope():
        return make_model(args)

# Training

In [None]:
# Stop training after 5 epochs without a new best validation loss.
es = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    mode='min',
    patience=5,
    verbose=1,
)

# Keep only the weights of the epoch with the lowest validation loss.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="mamba_model.weights.h5",
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)

# Multiply the learning rate by 0.05 after 5 stagnant epochs.
# NOTE(review): this uses the same patience as early stopping, so the
# reduction would fire around the epoch training stops; confirm intended.
reduce_lr_loss = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.05,
    patience=5,
    verbose=1,
)

# Pick the builder that matches the detected hardware.
if CFG.ON_TPU:
    model = make_tpu_model(ModelArgs)
else:
    model = make_model(ModelArgs)
model.summary()

In [None]:
print('train start')
# Fit on the training split and evaluate on the held-out split every epoch.
# The callbacks handle checkpointing, LR reduction and early stopping.
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=CFG.BATCH_SIZE,
    epochs=CFG.EPOCHS,
    validation_data=(X_val, y_val),
    callbacks=[checkpoint, reduce_lr_loss, es],
    verbose=1,
)
    


# Prediction

In [None]:
# Restore the best checkpoint written during training.
model.load_weights("mamba_model.weights.h5")

print('prediction start')
oof = model.predict(X_val, batch_size = 2*CFG.BATCH_SIZE)
# `fold` was never defined in this notebook, so the print below raised a
# NameError on a fresh kernel. The dataset cell takes only the first split
# from the StratifiedKFold iterator, i.e. fold 0.
fold = 0
print('fold :', fold, 'CV score =', APS(y_val, oof, average = 'micro'))

# One row of three probabilities per row of the encoded test frame.
preds = model.predict(test.to_numpy(), batch_size = 2*CFG.BATCH_SIZE)

# Save output

In [None]:
# NOTE(review): `val` is an empty polars DataFrame that no later cell visible
# here reads; it looks like a leftover placeholder. Confirm before removing.
val = pl.DataFrame()


In [None]:
# Build submission.csv: for every test row, take the prediction column that
# belongs to that row's protein.
tst = pd.read_parquet('/kaggle/input/leash-BELKA/test.parquet')

# The boolean masks below index `preds` row by row, so both must have the
# same length. Fail with a clear message instead of an obscure IndexError.
assert len(preds) == len(tst), (
    f"preds has {len(preds)} rows but the test set has {len(tst)}; "
    "predictions must be computed on the full test set"
)

# Start from a float column: the values assigned below are probabilities,
# and writing floats into an integer column triggers pandas' dtype warning.
tst['binds'] = 0.0
# Column k of `preds` corresponds to bind{k+1}: BRD4, HSA, sEH in that order.
for col, protein in enumerate(['BRD4', 'HSA', 'sEH']):
    rows = (tst['protein_name'] == protein).values
    tst.loc[rows, 'binds'] = preds[rows, col]

tst[['id', 'binds']].to_csv('submission.csv', index = False)