In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive/.
# NOTE(review): CFG.DATA_SRC / CFG.WORK_DIR point at /kaggle/... paths, not at
# this Drive mount -- confirm where the encoded parquet files actually live.
from google.colab import drive
drive.mount('/content/drive/')

# Library imports

In [None]:
import gc
import os
import random

import numpy as np
import polars as pl
from sklearn.metrics import average_precision_score as APS
from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm

import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.data.distributed import DistributedSampler

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xpl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # Registers `xla://` init_method
import torch_xla.experimental.pjrt_backend  # Required for torch.distributed on TPU v2 and v3



# Set Config

In [None]:
# Print library versions so the cell output records which torch / torch_xla
# builds this run used.
print(f'torch : {torch.__version__}')
print(f'torch_xla: {torch_xla.__version__}')

In [None]:
# Fail fast when no XLA device is visible: training below uses xmp.spawn and
# DDP over the `xla` process-group backend, and there is no CPU fallback path.
# (The API is the plural `get_xla_supported_devices`; it may return an empty
# list or None when nothing is available, so test truthiness.)
if not xm.get_xla_supported_devices():
    raise RuntimeError('should use xla device to multi process')

# NOTE(review): creating an XLA device here initialises the XLA runtime in the
# notebook process; recent torch_xla releases refuse to run xmp.spawn after
# that. Confirm against the installed version -- the spawned workers call
# xm.xla_device() themselves and do not need this global.
device = xm.xla_device()

print(f"Using device: {device}")

In [None]:
class CFG:
    # ---- run mode -------------------------------------------------------
    DEBUG = True        # small, fast run: row cap, fewer epochs and folds
    PREPROCESS = False  # True: rebuild the encoded parquet files from raw SMILES

    SEED = 2024

    # Row cap handed to pl.read_parquet; None means read every row.
    N_ROWS = 100_000 if DEBUG else None

    # ---- training -------------------------------------------------------
    BATCH_SIZE = 4096
    EPOCHS = 3 if DEBUG else 20
    NUM_FOLDS = 5 if DEBUG else 30
    SELECTED_FOLDS = [0]  # fold indices (of NUM_FOLDS) to build loaders for
    SAVE_EVERY = 3        # checkpoint when epoch % SAVE_EVERY == 0 (epoch 0 skipped)

    # ---- paths ----------------------------------------------------------
    # NOTE(review): these are Kaggle paths, while the first cell mounts Google
    # Drive (Colab) -- confirm which environment this notebook targets.
    DATA_SRC = '/kaggle/input/belka-enc-dataset'
    WORK_DIR = '/kaggle/working'

In [None]:
def set_seeds(seed):
    """Seed the random number generators this notebook relies on.

    Seeds Python's `random`, NumPy's global RNG and PyTorch's default (CPU)
    generator. The torch seed is what makes `nn.Module` weight initialisation
    (e.g. the kaiming_normal_ calls in the conv helpers) reproducible.

    Args:
        seed: integer seed applied to every generator.
    """
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # subprocesses launched afterwards (e.g. spawned workers), not this process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Deliberately no xm.* call here: touching the XLA runtime in the notebook
    # process before xmp.spawn can break the spawn.
    torch.manual_seed(seed)

# Seed once, early, before the model cells construct any weights.
set_seeds(CFG.SEED)


# Load Data

In [None]:
if CFG.PREPROCESS:
    # Only the rebuild path needs these, so import them here instead of in the
    # import cell (joblib is installed alongside scikit-learn).
    import joblib
    import pandas as pd

    # Character-level SMILES vocabulary. Id 0 is reserved for padding, so the
    # model's embedding table needs len(enc) + 1 = 37 rows.
    enc = {'l': 1, 'y': 2, '@': 3, '3': 4, 'H': 5, 'S': 6, 'F': 7, 'C': 8, 'r': 9, 's': 10, '/': 11, 'c': 12, 'o': 13,
           '+': 14, 'I': 15, '5': 16, '(': 17, '2': 18, ')': 19, '9': 20, 'i': 21, '#': 22, '6': 23, '8': 24, '4': 25, '=': 26,
           '1': 27, 'O': 28, '[': 29, 'D': 30, 'B': 31, ']': 32, 'N': 33, '7': 34, 'n': 35, '-': 36}
    SMILES_MAX_LEN = 142  # fixed encoded width; columns are enc0..enc141

    train_raw = pd.read_parquet('/kaggle/input/leash-BELKA/train.parquet')
    smiles = train_raw[train_raw['protein_name']=='BRD4']['molecule_smiles'].values
    # One row per (molecule, protein): the three protein slices must list the
    # same molecules in the same order so they can be merged column-wise.
    assert (smiles!=train_raw[train_raw['protein_name']=='HSA']['molecule_smiles'].values).sum() == 0
    assert (smiles!=train_raw[train_raw['protein_name']=='sEH']['molecule_smiles'].values).sum() == 0

    def encode_smile(smile):
        """Map a SMILES string to a uint8 vector of SMILES_MAX_LEN ids, right-padded with 0.

        Raises:
            KeyError: a character is missing from `enc`.
            ValueError: the string is longer than SMILES_MAX_LEN (it would
                otherwise produce a ragged row that np.stack rejects).
        """
        codes = [enc[ch] for ch in smile]
        if len(codes) > SMILES_MAX_LEN:
            raise ValueError(f'SMILES longer than {SMILES_MAX_LEN} characters: {smile!r}')
        codes += [0] * (SMILES_MAX_LEN - len(codes))
        return np.array(codes, dtype=np.uint8)

    enc_columns = [f'enc{i}' for i in range(SMILES_MAX_LEN)]

    smiles_enc = joblib.Parallel(n_jobs=96)(joblib.delayed(encode_smile)(smile) for smile in tqdm(smiles))
    train = pd.DataFrame(np.stack(smiles_enc), columns=enc_columns)
    train['bind1'] = train_raw[train_raw['protein_name']=='BRD4']['binds'].values
    train['bind2'] = train_raw[train_raw['protein_name']=='HSA']['binds'].values
    train['bind3'] = train_raw[train_raw['protein_name']=='sEH']['binds'].values
    train.to_parquet('train_enc.parquet')

    test_raw = pd.read_parquet('/kaggle/input/leash-BELKA/test.parquet')
    smiles = test_raw['molecule_smiles'].values

    smiles_enc = joblib.Parallel(n_jobs=96)(joblib.delayed(encode_smile)(smile) for smile in tqdm(smiles))
    test = pd.DataFrame(np.stack(smiles_enc), columns=enc_columns)
    test.to_parquet('test_enc.parquet')

    # Downstream cells use the polars API (sum_horizontal, df[rows, cols]), so
    # hand them polars frames regardless of which branch ran.
    train = pl.from_pandas(train)
    test = pl.from_pandas(test)
else:
    train = pl.read_parquet(
        source=f'{CFG.DATA_SRC}/train_enc.parquet',
        n_rows=CFG.N_ROWS
    )
    test = pl.read_parquet(
        source=f'{CFG.DATA_SRC}/test_enc.parquet',
        n_rows=CFG.N_ROWS
    )

# Make Model

In [None]:
class SELayer(nn.Module):
    """Squeeze-and-excitation gate for (batch, channel, length) feature maps.

    Each channel is squeezed to its mean over the last axis. The resulting
    channel vector goes through a bias-free bottleneck MLP
    (channel -> channel // reduction -> channel) ending in a sigmoid, and the
    input channels are rescaled by those weights.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction

        self.avg_pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Three-way unpack: any input that is not 3-D raises here.
        batch, channel, _ = x.shape
        pooled = self.avg_pooling(x).view(batch, channel)
        gate = self.fc(pooled).view(batch, channel, 1)
        return x * gate.expand_as(x)


In [None]:
def make_conv1(in_channels, out_channels):
    """Return a biased 1x1 Conv1d (stride 1, 'same' padding) with Kaiming-normal weights.

    The bias keeps PyTorch's default initialisation; only the weight is re-drawn.
    """
    layer = nn.Conv1d(in_channels, out_channels, 1, stride=1, padding='same', bias=True)
    nn.init.kaiming_normal_(layer.weight)
    return layer

def make_conv3(in_channels, out_channels, stride=1):
    """Return a biased kernel-3 Conv1d with Kaiming-normal weights.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: convolution stride (previously accepted but silently ignored).
            stride=1 keeps the sequence length; stride=s yields ceil(L / s).

    Returns:
        The initialised nn.Conv1d.
    """
    # PyTorch rejects padding='same' for strided convolutions. For kernel 3 an
    # explicit padding of 1 is equivalent at stride 1 and yields ceil(L/stride)
    # outputs otherwise. Keep 'same' for the default so existing modules are unchanged.
    padding = 'same' if stride == 1 else 1
    conv3 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=stride, padding=padding, bias=True)
    nn.init.kaiming_normal_(conv3.weight)
    return conv3



class BottleneckReSELayer(nn.Module):
    """Residual bottleneck block with squeeze-and-excitation for 1-D sequences.

    Main path: 1x1 conv (to output_size // 4) -> BN -> ReLU -> 3x3 conv -> BN ->
    ReLU -> 1x1 conv (to output_size) -> BN -> SELayer.
    Shortcut: identity when input_size == output_size, otherwise 1x1 conv + BN.
    The two paths are summed and passed through ReLU.

    Args:
        input_size: channels of the (batch, channels, length) input.
        output_size: channels of the output; the length is unchanged.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        self.is_io_same = input_size == output_size
        hidden_size = output_size // 4

        self.fc1 = nn.Sequential(
            make_conv1(input_size, hidden_size),
            nn.BatchNorm1d(num_features=hidden_size),
            nn.ReLU(inplace=True),
            make_conv3(hidden_size, hidden_size),
            nn.BatchNorm1d(num_features=hidden_size),
            nn.ReLU(inplace=True),
            make_conv1(hidden_size, output_size),
            nn.BatchNorm1d(num_features=output_size),
            SELayer(channel=output_size),
        )

        # Only build the projection shortcut when the shapes differ. Previously
        # it was always created but skipped in forward when input == output, so
        # its parameters never received gradients. DistributedDataParallel
        # (find_unused_parameters=False) raises on such parameters.
        # NOTE: checkpoints from the old layout carry extra fc2.* keys for
        # same-size blocks; load them with strict=False.
        if self.is_io_same:
            self.fc2 = nn.Identity()
        else:
            self.fc2 = nn.Sequential(
                make_conv1(input_size, output_size),
                nn.BatchNorm1d(num_features=output_size),
            )

    def forward(self, x):
        # fc1 begins with a conv, so its in-place ReLUs never touch `x`; the
        # shortcut sees the unmodified input.
        out = self.fc1(x) + self.fc2(x)
        return F.relu(out)


In [None]:
class ReSEModel(nn.Module):
    """Token classifier built from stacked ReSE bottleneck blocks.

    Pipeline: token embedding -> conv stack -> global max pool over the
    sequence -> MLP head producing `num_class` logits.

    Args:
        enc_dict_size: number of embedding rows; token id 0 is the padding index.
        channels: embedding width; stage i (1-based) of the conv stack is channels * i wide.
        rese_layer_size: number of bottleneck blocks in each stage.
        num_class: number of output logits.
    """

    def __init__(self, enc_dict_size:int, channels:int, rese_layer_size:list, num_class=3):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=enc_dict_size,
            embedding_dim=channels,
            padding_idx=0,
        )
        self.btl_rese_layers, out_dim = self._make_rese_layers(channels, rese_layer_size)
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)

        # Three Linear -> ReLU -> Dropout(0.1) stages (out_dim -> 1024 -> 1024 -> 512),
        # followed by the projection to class logits.
        widths = [out_dim, 1024, 1024, 512]
        head = []
        for d_in, d_out in zip(widths, widths[1:]):
            head += [nn.Linear(d_in, d_out), nn.ReLU(inplace=True), nn.Dropout(0.1)]
        head.append(nn.Linear(widths[-1], num_class))
        self.mlp_head = nn.Sequential(*head)

    def forward(self, x):
        # Embedding yields (batch, seq, channels); Conv1d expects (batch, channels, seq).
        tokens = self.embedding(x).transpose(1, 2)
        features = self.btl_rese_layers(tokens)
        pooled = self.global_max_pool(features).squeeze(2)
        return self.mlp_head(pooled)

    def _make_rese_layers(self, channels:int, layer_size:list):
        """Build the conv stack; returns (nn.Sequential, output channel count)."""
        blocks = []
        in_dim = channels
        for stage, depth in enumerate(layer_size, start=1):
            width = channels * stage
            blocks.append(BottleneckReSELayer(in_dim, width))
            blocks.extend(BottleneckReSELayer(width, width) for _ in range(depth - 1))
            in_dim = width
        return nn.Sequential(*blocks), in_dim

# Build the model once in the notebook process as a smoke check.
# The `enc` vocabulary maps characters to ids 1..36 and 0 is padding, so the
# embedding needs 37 rows; 36 would make token id 36 an out-of-range lookup.
# NOTE(review): `.to(device)` touches the XLA runtime before xmp.spawn -- see
# the device cell.
model = ReSEModel(37, 128, [2,2,3]).to(device)

# Make Trainer

In [None]:
def prepare_dataloader(dataset:Dataset, batch_size:int, rank:int, world_size:int, test:bool =False,
                       drop_last:bool =False):
    """Wrap `dataset` in a DataLoader for one distributed process.

    Args:
        dataset: the dataset to load.
        batch_size: per-process batch size.
        rank: this process's rank, used by the DistributedSampler.
        world_size: total number of processes (None lets DistributedSampler
            ask the initialised process group).
        test: if True, use no sampler -- every example, in order, in this process.
        drop_last: drop the final incomplete batch. This is useful on XLA, where
            a differently shaped last batch triggers a recompilation.

    Returns:
        A DataLoader. For training it uses a shuffling DistributedSampler
        (call `sampler.set_epoch` each epoch).
    """
    sampler = None if test else DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True
    )

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,  # shuffling, when wanted, is done by the sampler
        sampler=sampler,  # was misspelled `samplar=`, which raised TypeError
        num_workers=2,
        pin_memory=True,
        drop_last=drop_last,
    )

In [None]:
class Trainer:
    """Per-process DDP training loop for one XLA device.

    The model is wrapped in DistributedDataParallel, and the training
    DataLoader is fed through `xpl.MpDeviceLoader`. Each epoch trains on
    every process. On the master ordinal it also evaluates and, every
    `save_every` epochs, writes a checkpoint.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        eval_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        device=None,
        rank: int = None,
        world_size: int = None,
    ):
        """
        Args:
            model: the module to train. `optimizer` must hold its parameters.
            train_data: training DataLoader, normally with a DistributedSampler.
            eval_data: evaluation DataLoader (used on the master only).
            optimizer: optimizer over `model`'s parameters.
            save_every: checkpoint when epoch % save_every == 0 (epoch 0 skipped).
            device: XLA device to train on; defaults to this process's
                `xm.xla_device()`.
            rank, world_size: accepted so existing call sites keep working; the
                sampler inside `train_data` already carries this information.
        """
        self.device = device if device is not None else xm.xla_device()
        self.rank = rank
        self.world_size = world_size
        # Use self.device (not a notebook global) so each process uses its own device.
        self.model = DDP(model.to(self.device), gradient_as_bucket_view=True)
        self.optimizer = optimizer
        self.save_every = save_every

        # MpDeviceLoader exposes neither `.sampler` nor the batch size, so read
        # them from the raw DataLoader before wrapping it.
        self.sampler = getattr(train_data, 'sampler', None)
        self.batch_size = train_data.batch_size
        self.steps_per_epoch = len(train_data)
        self.train_data = xpl.MpDeviceLoader(train_data, self.device)
        self.eval_data = eval_data

    def _loss(self, pred, y):
        """Mean BCE-with-logits, with targets cast to the logits' dtype.

        Targets are built as float16 in load_dataloader; the cast keeps the loss
        in the model's precision.
        """
        return F.binary_cross_entropy_with_logits(pred, y.to(pred.dtype))

    def _run_batch(self, X, y):
        """Run one optimisation step and return the batch's mean loss (device tensor)."""
        with torch_xla.step():
            X, y = X.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            pred = self.model(X)  # the DDP-wrapped model, not a notebook global
            loss = self._loss(pred, y)
            loss.backward()
            self.optimizer.step()
        return loss

    def _run_epoch(self, epoch_count):
        """Train for one epoch; the master shows a progress bar with the running loss."""
        self.model.train()
        # Reseed DistributedSampler's shuffle so each epoch sees a new order.
        if self.sampler is not None and hasattr(self.sampler, 'set_epoch'):
            self.sampler.set_epoch(epoch_count)

        print(f"[TPU{self.device}] Epoch {epoch_count} | Batchsize: {self.batch_size} | Steps: {self.steps_per_epoch}")

        if xm.is_master_ordinal():
            running_loss, seen = 0.0, 0
            # tqdm(...).set_description() returns None, so pass `desc=` instead.
            progress = tqdm(self.train_data, desc=f'Train {epoch_count}')
            for X, y in progress:
                loss = self._run_batch(X, y)
                n = y.size(0)
                # `loss` is a batch mean: weight it by batch size so the running
                # value is a per-sample average.
                running_loss += loss.item() * n
                seen += n
                progress.set_postfix(loss=running_loss / seen)
        else:
            for X, y in self.train_data:
                self._run_batch(X, y)

    def _run_eval(self, epoch_count:int):
        """Return the per-sample mean loss over `eval_data` (NaN if it is empty)."""
        # Evaluate through the unwrapped module. DDP's forward can broadcast
        # buffers (a collective op), which could hang because only the master
        # runs evaluation.
        module = self.model.module
        module.eval()

        running_loss, total = 0.0, 0
        with torch.no_grad():
            progress = tqdm(self.eval_data, desc=f'Valid {epoch_count}')
            for X, y in progress:
                X, y = X.to(self.device), y.to(self.device)
                pred = module(X)
                loss = self._loss(pred, y)
                n = y.size(0)
                running_loss += loss.item() * n
                total += n
                progress.set_postfix(loss=running_loss / total)
        return running_loss / total if total else float('nan')

    def _save_checkpoint(self, epoch_count):
        """Write the unwrapped model's state_dict to checkpoint.pt."""
        # Copy to CPU so the file does not hold XLA device tensors and can be
        # loaded without a TPU runtime.
        ckp = {k: v.detach().cpu() for k, v in self.model.module.state_dict().items()}
        PATH = 'checkpoint.pt'
        torch.save(ckp, PATH)
        print(f"Epoch {epoch_count} | Training checkpoint saved at {PATH}")

    def fit(self, max_epochs):
        """Train for `max_epochs` epochs (an int; previously iterated directly, a TypeError)."""
        for epoch in range(max_epochs):
            self._run_epoch(epoch)

            if xm.is_master_ordinal():
                if epoch % self.save_every == 0 and epoch > 0:
                    self._save_checkpoint(epoch)

                loss = self._run_eval(epoch)
                print(f'eval_loss : {loss}')


In [None]:
def load_dataloader(train_df:pl.DataFrame, rank:int, world_size:int=None):
    """Build train/validation DataLoaders for one cross-validation fold.

    Splits `train_df` with StratifiedKFold(CFG.NUM_FOLDS, shuffle, seed 42),
    stratified on the number of positive targets per row. Loaders are built for
    the last fold (in split order) that appears in CFG.SELECTED_FOLDS. That
    matches the previous behaviour, but no longer materialises tensors for the
    folds that were discarded.

    Args:
        train_df: polars frame with columns enc0..enc141 and bind1..bind3.
        rank: this process's rank (for the training DistributedSampler).
        world_size: number of processes; None defers to the process group.

    Returns:
        (train_loader, valid_loader, X_eval, y_eval)

    Raises:
        ValueError: no index in CFG.SELECTED_FOLDS is a valid fold.
    """
    FEATURES = [f'enc{i}' for i in range(142)]
    TARGETS = ['bind1', 'bind2', 'bind3']

    skf = StratifiedKFold(n_splits=CFG.NUM_FOLDS, shuffle=True, random_state=42)
    # Previously this read the global `train`, silently ignoring `train_df`.
    strata = train_df[TARGETS].sum_horizontal().to_numpy()

    chosen = None
    for fold, split in enumerate(skf.split(np.arange(len(train_df)), strata)):
        if fold in CFG.SELECTED_FOLDS:
            chosen = (fold, split)
    if chosen is None:
        raise ValueError(
            f'No fold of {CFG.NUM_FOLDS} matches CFG.SELECTED_FOLDS={CFG.SELECTED_FOLDS}'
        )

    fold, (train_idx, valid_idx) = chosen
    print(f'Fold: {fold}')

    # Token ids -> int32 for the embedding; targets as float16 to save host
    # memory (the trainer casts them back to the logits' dtype).
    X_train = torch.tensor(train_df[train_idx, FEATURES].to_numpy(), dtype=torch.int)
    y_train = torch.tensor(train_df[train_idx, TARGETS].to_numpy(), dtype=torch.float16)
    X_eval = torch.tensor(train_df[valid_idx, FEATURES].to_numpy(), dtype=torch.int)
    y_eval = torch.tensor(train_df[valid_idx, TARGETS].to_numpy(), dtype=torch.float16)

    train_dataset = TensorDataset(X_train, y_train)
    valid_dataset = TensorDataset(X_eval, y_eval)
    print('set datasets')

    train_loader = prepare_dataloader(train_dataset, CFG.BATCH_SIZE, rank, world_size)
    print('set train_loader')
    valid_loader = prepare_dataloader(valid_dataset, CFG.BATCH_SIZE, rank, world_size, test=True)
    print('set valid_loader')
    return train_loader, valid_loader, X_eval, y_eval


In [None]:
def _mp_fn(rank, world_size=None):
    """Per-process training entry point launched by `xmp.spawn`.

    Args:
        rank: process index supplied by xmp.spawn; used as the sampler rank.
            NOTE(review): this equals the global rank on a single host only.
        world_size: number of processes. xmp.spawn(args=()) passes only `rank`,
            so the default reads it from the initialised process group.
    """
    init_process_group(backend='xla', init_method='xla://')
    try:
        if world_size is None:
            world_size = torch.distributed.get_world_size()

        # `train` is the polars frame from the Load Data cell, inherited from
        # the notebook process (the old `train_df` name was undefined).
        train_loader, valid_loader, _, _ = load_dataloader(train, rank, world_size)

        # 37 = 36 SMILES tokens + padding id 0. Move to the device before
        # building the optimizer, so it references the device parameters.
        model = ReSEModel(37, 64, [1]).to(xm.xla_device())
        optimizer = Adam(params=model.parameters(), lr=0.0001)

        trainer = Trainer(model, train_loader, valid_loader, optimizer, CFG.SAVE_EVERY)
        trainer.fit(CFG.EPOCHS)
    finally:
        destroy_process_group()


In [None]:
# Launch _mp_fn in worker processes. Use 'fork': with the 'spawn' start method,
# a child cannot re-import a function defined in a notebook cell. Forked
# children also inherit `train`, CFG and the class definitions from this kernel.
xmp.spawn(_mp_fn, args=(), start_method='fork')