## Disclaimer !!!

This is illustrative code for a high-performance transformer. I use my local machine for training and inference; I have not checked whether this code can run in a Kaggle notebook.

I achieve LB 0.428 (new metric) with a single fold and no tricks.

In [1]:
import numpy as np
import polars as pl
import pandas as pd
import os

import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import random
import math
from functools import partial
from concurrent.futures import ProcessPoolExecutor

#model ====================================
# https://gist.github.com/kklemon/98e491ff877c497668c715541f1bf478
# refer to the link above to get fast flash attention wrapper




In [2]:
class Config:
    """Notebook-wide settings, read as class attributes (e.g. ``Config.DEBUG``)."""
    # NOTE(review): apart from KAGGLE_NOTEBOOK, DEBUG and SEED, none of these
    # values is referenced by the code visible in this notebook export; the
    # training loops hard-code their own epochs / batch size / learning rate.
    PREPROCESS = False
    KAGGLE_NOTEBOOK = False  # True selects the /kaggle/input paths instead of ../data
    DEBUG = True  # True caps parquet reads at n_rows rows (n_rows is set right after this class)
    
    SEED = 42  # passed to set_seeds()
    EPOCHS = 10
    BATCH_SIZE = 4096
    LR = 1e-3
    WD = 1e-6
    PATIENCE = 10
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # falls back to CPU when CUDA is unavailable
    NBR_FOLDS = 15
    SELECTED_FOLDS = [0]
    
    
# n_rows is forwarded to pl.read_parquet(n_rows=...) by later cells:
# a small sample in debug mode, None (no explicit cap) otherwise.
if Config.DEBUG:
    n_rows = 3*10**4
else:
    n_rows = None
    


In [3]:
# Directory layout: Kaggle input datasets vs. a local ../data tree.
# On Kaggle, OUTPUT_DIR and MODEL_DIR are empty strings, so os.path.join()
# with them yields a bare file name.
if Config.KAGGLE_NOTEBOOK:
    RAW_DIR = "/kaggle/input/leash-BELKA/"
    PROCESSED_DIR = "/kaggle/input/belka-enc-dataset"
    OUTPUT_DIR = ""
    MODEL_DIR = ""
else:
    RAW_DIR = "../data/raw/"
    PROCESSED_DIR = "../data/processed/"
    OUTPUT_DIR = "../data/tf-dataset-original/"
    MODEL_DIR = "../models/"

# NOTE(review): not referenced by the code visible in this export.
TRAIN_DATA_NAME = "train_enc.parquet"

In [4]:
def set_seeds(seed):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch
    (CPU and all CUDA devices), so weight initialisation and dropout are
    reproducible as well as the NumPy sampling.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Setting PYTHONHASHSEED at runtime only affects child processes; the
    # current interpreter's hash seed was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Harmless when CUDA is unavailable: the call is then silently ignored.
    torch.cuda.manual_seed_all(seed)

# Seed the generators once, right after the helper is defined.
set_seeds(seed=Config.SEED)


In [5]:

#tokenization ====================================

# https://www.ascii-code.com/
# SMILES character -> token id. Ids start at 1 because 0 is reserved for PAD.
MOLECULE_DICT = {
    'l': 1, 'y': 2, '@': 3, '3': 4, 'H': 5, 'S': 6, 'F': 7, 'C': 8, 'r': 9, 's': 10, '/': 11, 'c': 12, 'o': 13,
    '+': 14, 'I': 15, '5': 16, '(': 17, '2': 18, ')': 19, '9': 20, 'i': 21, '#': 22, '6': 23, '8': 24, '4': 25,
    '=': 26, '1': 27, 'O': 28, '[': 29, 'D': 30, 'B': 31, ']': 32, 'N': 33, '7': 34, 'n': 35, '-': 36
}
# Built-in max() keeps these as plain Python ints rather than NumPy scalars.
MAX_MOLECULE_ID = max(MOLECULE_DICT.values())
VOCAB_SIZE = MAX_MOLECULE_ID + 10
UNK = 255  # disallowed: lies outside VOCAB_SIZE, so an embedding lookup on it will fail
BOS = MAX_MOLECULE_ID + 1
EOS = MAX_MOLECULE_ID + 2
# rest are reserved
PAD = 0
MAX_LENGTH = 160

# Byte-indexed lookup table: MOLECULE_LUT[ord(ch)] is the token id of ch,
# or UNK for any byte that is not a key of MOLECULE_DICT.
MOLECULE_LUT = np.full(256, fill_value=UNK, dtype=np.uint8)
# Underscore-prefixed loop names avoid leaking a global called `ascii`,
# which would shadow the builtin of the same name.
for _char, _token_id in MOLECULE_DICT.items():
    MOLECULE_LUT[ord(_char)] = _token_id


# SMILES tokenization function
def make_token(s):
    """Convert a SMILES string into padded token ids and an attention mask.

    The string is encoded to bytes, each byte is mapped through the
    module-level MOLECULE_LUT, and the result is laid out as
    ``[BOS] + ids + [EOS]`` followed by PAD up to MAX_LENGTH.

    Args:
        s: SMILES string made only of characters that are keys of
            MOLECULE_DICT.

    Returns:
        ``(token_id, token_mask)``, two lists of ints. ``token_mask`` is 1 for
        BOS, every character and EOS, and 0 for padding. Both lists have
        length MAX_LENGTH when ``s`` has at most ``MAX_LENGTH - 2``
        characters; longer strings are not truncated, so both lists then come
        back longer than MAX_LENGTH.

    Raises:
        ValueError: If ``s`` contains a character that has no token id.
    """
    # Reuse the table built once at import time instead of rebuilding the
    # 256-entry LUT for every molecule.
    codes = MOLECULE_LUT[np.frombuffer(s.encode(), np.uint8)]
    if (codes == UNK).any():
        # UNK is outside the vocabulary; failing here is clearer than an
        # out-of-range embedding lookup later on.
        raise ValueError(f"SMILES contains characters outside MOLECULE_DICT: {s!r}")

    L = len(codes) + 2
    n_pad = max(MAX_LENGTH - L, 0)
    # BOS / EOS may be NumPy integers; keep the output lists plain Python ints.
    token_id = [int(BOS)] + codes.tolist() + [int(EOS)] + [int(PAD)] * n_pad
    token_mask = [1] * L + [0] * n_pad
    return token_id, token_mask


# Convert the token sequences and mask sequences into DataFrames
def expand_tokens_to_df(tokens, max_length):
    """Spread ``(token_id, mask)`` pairs over ``max_length`` positional columns.

    Args:
        tokens: Iterable of ``(token_id, mask)`` pairs, each element an
            indexable sequence.
        max_length: Number of positional columns to create.

    Returns:
        ``(token_df, mask_df)``. Both frames use the column names
        ``enc0 .. enc{max_length-1}``; a position past the end of a sequence
        is filled with ``None``, and anything beyond ``max_length`` is dropped.
    """
    pairs = list(tokens)
    id_seqs = [ids for ids, _ in pairs]
    flag_seqs = [flags for _, flags in pairs]

    def column(seqs, pos):
        # One column = the pos-th element of every row (None when the row is shorter).
        return [seq[pos] if pos < len(seq) else None for seq in seqs]

    names = [f"enc{pos}" for pos in range(max_length)]
    token_df = pd.DataFrame({name: column(id_seqs, pos) for pos, name in enumerate(names)})
    mask_df = pd.DataFrame({name: column(flag_seqs, pos) for pos, name in enumerate(names)})
    return token_df, mask_df

In [7]:
# Tokenize each of the 10 shuffled training shards and build the encoded
# frame (token columns + the three bind targets) plus its mask frame.
for i in range(10):
    input_path = os.path.join("../data/shuffled-dataset/", f"train_{i}.parquet")
    # Only the SMILES string and the three targets are read; n_rows caps the read in debug mode.
    train_raw = pl.read_parquet(input_path, n_rows=n_rows, columns=["molecule_smiles", "bind1", "bind2", "bind3"]).to_pandas()
    print("data loaded", input_path, train_raw.shape)
    smiles = train_raw["molecule_smiles"]
    # Each element of `tokens` is the (token_id, token_mask) pair returned by make_token.
    tokens = smiles.apply(make_token)
    train, mask_df = expand_tokens_to_df(tokens, 160)
    train["bind1"] = train_raw["bind1"]
    train["bind2"] = train_raw["bind2"]
    train["bind3"] = train_raw["bind3"]

    # save
    path = os.path.join(OUTPUT_DIR, f"train_enc_{i}.parquet")
    # NOTE(review): both writes are commented out, so nothing is written here
    # even though "data saved" is printed below.
    # train.to_parquet(path)
    # mask_df.to_parquet(os.path.join(OUTPUT_DIR, f"train_mask_{i}.parquet"))
    print("data saved", path)

data loaded ../data/shuffled-dataset/train_0.parquet (30000, 4)


KeyboardInterrupt: 

In [30]:
# Read back shard 0 of the encoded tokens and its mask from OUTPUT_DIR.
train = pl.read_parquet(os.path.join(OUTPUT_DIR, "train_enc_0.parquet"))
train_mask = pl.read_parquet(os.path.join(OUTPUT_DIR, "train_mask_0.parquet"))

# Bare expression: shows the frame as the notebook cell's output.
train

enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,…,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141,enc142,enc143,enc144,enc145,enc146,enc147,enc148,enc149,enc150,enc151,enc152,enc153,enc154,enc155,enc156,enc157,enc158,enc159,bind1,bind2,bind3
i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,…,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
37,28,26,8,17,33,29,30,2,32,19,29,8,3,3,5,32,17,8,12,27,12,12,12,17,7,19,12,17,7,19,12,27,19,33,12,27,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
37,8,26,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,8,12,18,12,12,12,35,12,18,36,35,18,12,12,35,12,18,8,19,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
37,8,6,17,26,28,19,17,26,28,19,33,8,27,8,8,8,8,27,8,33,12,27,35,12,17,33,8,8,18,17,28,19,8,8,6,8,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
37,28,26,8,17,33,29,30,2,32,19,29,8,3,3,5,32,17,8,12,27,12,12,12,17,8,1,19,12,17,8,1,19,12,27,19,33,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
37,8,12,27,12,12,17,33,12,18,35,12,17,33,8,29,8,3,3,5,32,4,28,8,8,33,17,8,19,29,8,3,5,32,4,12,4,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
37,28,26,8,17,33,29,30,2,32,19,29,8,3,5,32,17,8,8,8,27,8,8,8,8,8,27,19,33,12,27,35,12,17,33,12,18,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
37,8,28,12,27,12,12,12,12,17,36,12,18,12,12,17,33,12,4,35,12,17,33,8,8,12,25,12,12,35,13,25,19,35,12,17,33,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
37,8,28,8,17,26,28,19,12,27,12,12,17,33,12,18,35,12,17,33,12,4,12,35,12,17,8,1,19,12,12,4,8,17,26,28,19,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
37,8,29,8,3,3,5,32,17,28,8,12,27,12,12,12,12,12,27,19,29,8,3,5,32,17,33,12,27,35,12,17,33,8,8,12,18,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [39]:
# Bare expression: shows the frame as the notebook cell's output.
train

Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,...,enc153,enc154,enc155,enc156,enc157,enc158,enc159,bind1,bind2,bind3
0,37,8,22,8,8,28,12,27,12,12,...,0,0,0,0,0,0,0,0,0,0
1,37,8,22,8,8,28,12,27,12,12,...,0,0,0,0,0,0,0,0,0,0
2,37,8,22,8,8,28,12,27,12,12,...,0,0,0,0,0,0,0,0,0,0
3,37,8,22,8,8,28,12,27,12,12,...,0,0,0,0,0,0,0,0,0,0
4,37,8,22,8,8,28,12,27,12,12,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,37,8,22,8,8,29,8,3,3,5,...,0,0,0,0,0,0,0,0,0,0
99996,37,8,22,8,8,29,8,3,3,5,...,0,0,0,0,0,0,0,0,0,0
99997,37,8,22,8,8,29,8,3,3,5,...,0,0,0,0,0,0,0,0,0,0
99998,37,8,22,8,8,29,8,3,3,5,...,0,0,0,0,0,0,0,0,0,0


For modeling, you need flash attention + torch.compile to make it run fast.
I train on an A6000 GPU with batch size = 2000.

In [43]:


class FlashAttentionTransformerEncoder(nn.Module):
    """Stack of flash_attn ``Block`` layers behind an encoder-style interface.

    Args:
        dim_model: Width of the token representations.
        num_layers: Number of stacked blocks.
        num_heads: Attention heads; defaults to ``dim_model // 64``.
        dim_feedforward: Hidden width of each MLP; defaults to ``4 * dim_model``.
        dropout: Residual dropout used for both sub-layers of every block.
        norm_first: Forwarded to ``Block`` as ``prenorm``.
            NOTE(review): with ``prenorm=True`` flash_attn's ``Block``
            presumably returns a ``(hidden, residual)`` pair, which ``forward``
            below does not unpack -- confirm before enabling.
        activation: Callable, or one of the names ``'relu'`` / ``'gelu'``.
            Forwarded to the MLP; ``None`` keeps the MLP's own default.
        rotary_emb_dim: Forwarded to the attention module.

    Raises:
        ValueError: If ``activation`` is a string other than the known names.
        ImportError: If ``flash_attn`` is not installed.
    """

    def __init__(
        self,
        dim_model,
        num_layers,
        num_heads=None,
        dim_feedforward=None,
        dropout=0.0,
        norm_first=False,
        activation=F.gelu,
        rotary_emb_dim=0,
    ):
        super().__init__()

        # Resolve the activation before touching flash_attn so that a bad name
        # is reported as such, and keep the offending name in the message.
        if isinstance(activation, str):
            known_activations = {
                'relu': F.relu,
                'gelu': F.gelu
            }
            if activation not in known_activations:
                raise ValueError(f'Unknown activation {activation}')
            activation = known_activations[activation]

        try:
            from flash_attn.bert_padding import pad_input, unpad_input
            from flash_attn.modules.block import Block
            from flash_attn.modules.mha import MHA
            from flash_attn.modules.mlp import Mlp
        except ImportError as err:
            raise ImportError(
                'Please install flash_attn from https://github.com/Dao-AILab/flash-attention'
            ) from err

        self._pad_input = pad_input
        self._unpad_input = unpad_input

        if num_heads is None:
            num_heads = dim_model // 64

        if dim_feedforward is None:
            dim_feedforward = dim_model * 4

        mixer_cls = partial(
            MHA,
            num_heads=num_heads,
            use_flash_attn=True,
            rotary_emb_dim=rotary_emb_dim
        )

        # Pass the requested activation on to the MLP instead of dropping it.
        mlp_kwargs = dict(hidden_features=dim_feedforward)
        if activation is not None:
            mlp_kwargs['activation'] = activation
        mlp_cls = partial(Mlp, **mlp_kwargs)

        self.layers = nn.ModuleList([
            Block(
                dim_model,
                mixer_cls=mixer_cls,
                mlp_cls=mlp_cls,
                resid_dropout1=dropout,
                resid_dropout2=dropout,
                prenorm=norm_first,
            ) for _ in range(num_layers)
        ])

    def forward(self, x, src_key_padding_mask=None):
        """Encode ``x``; ``src_key_padding_mask`` is True at padded positions.

        Without a mask every layer runs on the padded tensor. With a mask the
        padded positions are stripped first, the layers run on the packed
        tokens, and the result is scattered back to the first two dimensions
        of the input.
        """
        batch, seqlen = x.shape[:2]

        if src_key_padding_mask is None:
            for layer in self.layers:
                x = layer(x)
            return x

        # unpad_input expects True for tokens to keep, hence the inversion.
        # Some flash_attn releases return an extra trailing value; `*_`
        # tolerates both return shapes.
        x, indices, cu_seqlens, max_seqlen_in_batch, *_ = self._unpad_input(
            x, ~src_key_padding_mask
        )

        for layer in self.layers:
            x = layer(x, mixer_kwargs=dict(
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen_in_batch
            ))

        return self._pad_input(x, indices, batch, seqlen)

class Conv1dBnRelu(nn.Module):
    """1-D convolution, optionally followed by batch norm, then ReLU.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution (and of the batch norm).
        is_bn: Whether to create and apply ``nn.BatchNorm1d``.
        **kwargs: Forwarded unchanged to ``nn.Conv1d``.
    """

    def __init__(self, in_channels, out_channels, is_bn, **kwargs):
        super().__init__()
        self.is_bn = is_bn
        self.conv = nn.Conv1d(in_channels, out_channels, **kwargs)
        # The `bn` sub-module only exists when batch norm was requested.
        if is_bn:
            self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.conv(x)
        normalized = self.bn(features) if self.is_bn else features
        return self.relu(normalized)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a ``(batch, seq, d_model)`` input.

    Args:
        d_model: Size of the last dimension of the input; may be odd.
        max_len: Number of positions precomputed; inputs must not be longer.
    """

    def __init__(self, d_model, max_len=256):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so drop the surplus frequency instead of failing on shape.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Stored as a buffer of shape (1, max_len, d_model): it follows the
        # module across devices but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Slice the table to the actual sequence length; broadcasts over the batch.
        return x + self.pe[:, :x.size(1)]

class Net(nn.Module):
    """Transformer classifier producing three binding logits per molecule.

    Pipeline: token embedding (64-d) -> Conv1d/BN/ReLU up to 512-d ->
    sinusoidal position encoding -> 7-layer flash-attention encoder ->
    masked mean pooling -> linear head with 3 outputs.

    ``output_type`` controls what ``forward`` returns: ``'loss'`` adds
    ``bce_loss`` (and requires ``batch['bind']``), ``'infer'`` adds the
    sigmoid probabilities under ``bind``.
    """

    def __init__(self, ):
        super().__init__()

        embed_dim=512

        self.output_type = ['infer', 'loss']
        self.pe = PositionalEncoding(embed_dim,max_len=256)
        self.embedding = nn.Embedding(VOCAB_SIZE, 64, padding_idx=PAD)
        self.conv_embedding = nn.Sequential(
            Conv1dBnRelu(64, embed_dim, kernel_size=3,stride=1,padding=1, is_bn=True),
        )  #just a simple conv1d-bn-relu . for bn use: BN = partial(nn.BatchNorm1d, eps=5e-3,momentum=0.1)

        self.tx_encoder = FlashAttentionTransformerEncoder(
            dim_model=embed_dim,
            num_heads=8,
            dim_feedforward=embed_dim*4,
            dropout=0.1,
            norm_first=False,
            activation=F.gelu,
            rotary_emb_dim=0,
            num_layers=7,
        )

        self.bind = nn.Sequential(
            nn.Linear(embed_dim, 3),
        )


    def forward(self, batch):
        """Run the model on a batch dict.

        Args:
            batch: Dict with ``smiles_token_id`` and ``smiles_token_mask``
                (2-D, mask non-zero at real tokens) and, when ``'loss'`` is
                in ``output_type``, ``bind`` with the targets.

        Returns:
            Dict with ``bce_loss`` and/or ``bind`` depending on ``output_type``.
        """
        smiles_token_id   = batch['smiles_token_id'].long()
        smiles_token_mask = batch['smiles_token_mask'].long()
        x = self.embedding(smiles_token_id)
        # Conv1d wants channels first, the encoder wants them last.
        x = x.permute(0,2,1).float()
        x = self.conv_embedding(x)
        x = x.permute(0,2,1).contiguous()

        x = self.pe(x)
        z = self.tx_encoder(
            x=x,
            src_key_padding_mask=smiles_token_mask==0,
        )

        # Mean over real tokens only. The clamp keeps an all-padding row from
        # dividing by zero and turning the whole batch loss into NaN.
        m = smiles_token_mask.unsqueeze(-1).float()
        pool = (z*m).sum(1)/m.sum(1).clamp(min=1.0)
        bind = self.bind(pool)

        # --------------------------
        output = {}
        if 'loss' in self.output_type:
            target = batch['bind']
            output['bce_loss'] = F.binary_cross_entropy_with_logits(bind.float(), target.float())

        if 'infer' in self.output_type:
            output['bind'] = torch.sigmoid(bind)

        return output

    
#-------------------------------------
#dummy code to check net
def run_check_net():
    """Smoke-test ``Net`` on one random CUDA batch and print shapes and loss.

    Builds 500 random samples of length MAX_LENGTH, runs a single forward pass
    without gradients under autocast, then prints the batch tensors, the
    non-loss outputs and the loss values. ``net.train()`` / ``net.eval()`` is
    not called, so the module stays in whatever mode it was constructed in.
    """
    seq_len = MAX_LENGTH
    n_samples = 500

    def random_ints(high, shape):
        return torch.from_numpy(np.random.choice(high, shape))

    batch = {
        'smiles_token_id': random_ints(VOCAB_SIZE, (n_samples, seq_len)).byte().cuda(),
        'smiles_token_mask': random_ints(2, (n_samples, seq_len)).byte().cuda(),
        'bind': random_ints(2, (n_samples, 3)).float().cuda(),
    }

    net = Net().cuda()

    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
        output = net(batch)

    # ---
    print('batch')
    for key, value in batch.items():
        size = len(value) if key == 'idx' else value.shape
        print(f'{key:>32} : {size} ')

    print('output')
    for key, value in output.items():
        if 'loss' in key:
            continue
        print(f'{key:>32} : {value.shape} ')

    print('loss')
    for key, value in output.items():
        if 'loss' not in key:
            continue
        print(f'{key:>32} : {value.item()} ')



In [30]:
# Load chunk 0 of the locally pre-encoded training set; n_rows=None passes no row cap.
train = pl.read_parquet(os.path.join("../data/chuncked-dataset/", "local_train_enc_0.parquet"), n_rows=None).to_pandas()

In [32]:
# Derive the mask from the token table: 1 wherever a value is > 0, else 0.
# The comparison runs over every column of `train`, so the bind target
# columns are thresholded as well.
mask_df = (train.values > 0).astype(int)
mask_df = pd.DataFrame(mask_df, columns=train.columns)

In [33]:
# Bare expression: shows the frame as the notebook cell's output.
mask_df

Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,...,enc135,enc136,enc137,enc138,enc139,enc140,enc141,bind1,bind2,bind3
0,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
1,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
2,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
3,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
4,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9743140,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
9743141,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
9743142,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
9743143,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0


In [15]:
# Reload chunk 0 and its saved mask, both limited to n_rows rows (n_rows depends on Config.DEBUG).
train = pl.read_parquet(os.path.join("../data/chuncked-dataset/", "local_train_enc_0.parquet"), n_rows=n_rows).to_pandas()
mask_df = pl.read_parquet(os.path.join("../data/chuncked-dataset/", "local_train_mask_0.parquet"), n_rows=n_rows).to_pandas()

In [16]:
# Column names used to slice the frames: the first 140 token columns
# (enc0 .. enc139) as inputs and the three bind columns as targets.
FEATURES = [f'enc{i}' for i in range(140)]
TARGETS = ['bind1', 'bind2', 'bind3']


# Split into batches
def get_batch(df, mask_df, batch_size=32, device='cuda'):
    """Yield consecutive mini-batches as dicts of tensors on ``device``.

    Args:
        df: Frame holding the FEATURES token columns and the TARGETS columns.
        mask_df: Frame holding the mask under the same FEATURES column names.
        batch_size: Rows per batch; the last batch may be smaller.
        device: Device the tensors are moved to (defaults to CUDA).

    Yields:
        Dict with ``smiles_token_id`` and ``smiles_token_mask`` (uint8) and
        ``bind`` (float32), in row order without shuffling.
    """
    # Extract the arrays once: doing this inside the loop would rebuild the
    # full column selection for every single batch.
    token_ids = df[FEATURES].values
    token_mask = mask_df[FEATURES].values
    targets = df[TARGETS].values

    for start in range(0, len(df), batch_size):
        stop = start + batch_size
        yield {
            'smiles_token_id': torch.from_numpy(token_ids[start:stop]).byte().to(device),
            'smiles_token_mask': torch.from_numpy(token_mask[start:stop]).byte().to(device),
            'bind': torch.from_numpy(targets[start:stop]).float().to(device),
        }

# get_batch is a generator function: this only creates the generator, no batch is built until it is iterated.
batch = get_batch(train, mask_df, batch_size=4096)

In [18]:


# Training
def run_train():
    """Train a fresh ``Net`` on the module-level ``train`` / ``mask_df`` frames.

    Runs 30 epochs of Adam (lr 1e-3) over batches of 4096 rows on CUDA with
    mixed precision, printing the loss every 10th iteration. Nothing is
    returned or saved.
    """
    net = Net().cuda()
    net.train()

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    # float16 autocast needs loss scaling: without it small gradients can
    # underflow to zero in the half-precision layers.
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(30):
        for i, batch in enumerate(get_batch(train, mask_df, batch_size=4096)):
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=True):
                output = net(batch)

            # Backward runs outside autocast, on the scaled loss.
            scaler.scale(output['bce_loss']).backward()
            # step() unscales the gradients and skips the update if they contain inf/NaN.
            scaler.step(optimizer)
            scaler.update()
            if i%10==0:
                print(f'epoch:{epoch} iter:{i} loss:{output["bce_loss"].item()}')
     
    
# Start training with the run_train defined in this cell.
run_train()

epoch:0 iter:0 loss:0.6746259927749634
epoch:1 iter:0 loss:0.1499561071395874
epoch:2 iter:0 loss:0.19734856486320496
epoch:3 iter:0 loss:0.032289013266563416
epoch:4 iter:0 loss:0.04275503382086754


KeyboardInterrupt: 

In [13]:
def check_net():
    """Run ``Net`` once, without gradients, on the module-level ``train`` frame.

    The whole frame is sent to CUDA as a single batch, so this is only
    practical on a small (debug-sized) sample. Prints the batch tensor shapes,
    the non-loss output shapes and the loss values.
    """
    batch = {
        'smiles_token_id': torch.from_numpy(train[FEATURES].values).byte().cuda(),
        # Select the same FEATURES columns as the ids: the mask frame may carry
        # extra columns, and the mask must match the ids column for column.
        'smiles_token_mask': torch.from_numpy(mask_df[FEATURES].values).byte().cuda(),
        'bind': torch.from_numpy(train[TARGETS].values).float().cuda(),
    }

    net = Net().cuda()

    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=True): # dtype=torch.float16):
            output = net(batch)

    # ---
    print('batch')
    for k, v in batch.items():
        if k=='idx':
            print(f'{k:>32} : {len(v)} ')
        else:
            print(f'{k:>32} : {v.shape} ')

    print('output')
    for k, v in output.items():
        if 'loss' not in k:
            print(f'{k:>32} : {v.shape} ')
    print('loss')
    for k, v in output.items():
        if 'loss' in k:
            print(f'{k:>32} : {v.item()} ')


# Needs the module-level `train` and `mask_df` frames to be loaded first.
check_net()


NameError: name 'train' is not defined

I also tried other next-generation sequence models,  
e.g. Mamba, xLSTM and Griffin (DeepMind's attention RNN).  

I am looking for a faster alternative, but so far transformer + FlashAttention-2 is still the fastest (for a small dim and number of layers). In terms of performance, I think these models are similar.

In [11]:


#xlstm model
# offical repo: https://github.com/NX-AI/xlstm
from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)

class Xlstm(nn.Module):
    """xLSTM classifier producing three binding logits per molecule.

    Pipeline: token embedding (64-d) -> Conv1d/ReLU up to 256-d (no batch
    norm) -> 6-block xLSTM stack (sLSTM at block index 1) -> mean over the
    sequence axis -> linear head with 3 outputs. The token mask is not read,
    so the mean includes padded positions.
    """

    def __init__(self, ):
        super().__init__()

        embed_dim = 256

        self.output_type = ['infer', 'loss']
        self.embedding = nn.Embedding(VOCAB_SIZE, 64, padding_idx=0)
        self.conv_embedding = nn.Sequential(
            Conv1dBnRelu(64, embed_dim, kernel_size=3, stride=1, padding=1, is_bn=False),
        )

        # Build the block configurations first, then assemble the stack.
        mlstm_block = mLSTMBlockConfig(
            mlstm=mLSTMLayerConfig(
                conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
            )
        )
        slstm_block = sLSTMBlockConfig(
            slstm=sLSTMLayerConfig(
                backend='cuda',
                batch_size=64,
                num_heads=8,
                conv1d_kernel_size=4,
                bias_init='powerlaw_blockdependent',
            ),
            feedforward=FeedForwardConfig(proj_factor=1.3, act_fn='gelu'),
        )
        stack_config = xLSTMBlockStackConfig(
            mlstm_block=mlstm_block,
            slstm_block=slstm_block,
            context_length=MAX_LENGTH,
            num_blocks=6,
            embedding_dim=embed_dim,
            slstm_at=[1],
        )
        self.lstm_encoder = xLSTMBlockStack(config=stack_config)

        self.bind = nn.Sequential(
            nn.Linear(embed_dim, 3),
        )

    def forward(self, batch):
        """Return ``bce_loss`` and/or sigmoid ``bind`` depending on ``output_type``."""
        token_ids = batch['smiles_token_id'].long()
        # Unpacking doubles as a check that the ids are 2-D.
        n_rows_in_batch, n_positions = token_ids.shape

        # Conv1d wants channels first, the xLSTM stack wants them last.
        embedded = self.embedding(token_ids).permute(0,2,1).float()
        features = self.conv_embedding(embedded).permute(0,2,1).contiguous()

        encoded = self.lstm_encoder(features)
        pooled = encoded.mean(1)
        logits = self.bind(pooled)

        # --------------------------
        output = {}
        if 'loss' in self.output_type:
            output['bce_loss'] = F.binary_cross_entropy_with_logits(
                logits.float(), batch['bind'].float()
            )

        if 'infer' in self.output_type:
            output['bind'] = torch.sigmoid(logits)

        return output



In [13]:

# Training
def run_train():
    """Train a fresh ``Xlstm`` on the module-level ``train`` / ``mask_df`` frames.

    Runs 10 epochs of Adam (lr 1e-3) over batches of 4096 rows in full
    precision, printing the loss every 10th iteration.
    """
    model = Xlstm().cuda()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    n_epochs = 10
    log_every = 10
    for epoch in range(n_epochs):
        batches = get_batch(train, mask_df, batch_size=4096)
        for step, batch in enumerate(batches):
            optimizer.zero_grad()
            loss = model(batch)['bce_loss']
            loss.backward()
            optimizer.step()
            if step % log_every == 0:
                print(f'epoch:{epoch} iter:{step} loss:{loss.item()}')
     

# Start training with the run_train defined in this cell.
run_train()

: 

In [71]:

# official repo: https://github.com/state-spaces/mamba

# https://github.com/state-spaces/mamba/issues/355
# There seems to be a bug (see the issue above), so I could not try Mamba2.

from mamba_ssm import Mamba
from torch.nn.init import xavier_uniform_
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F

"TODO: mambaになっていない．線形層"
class Mamba1(nn.Module):
    """Stand-in block: Linear -> ReLU -> Linear, optional residual add, LayerNorm.

    Args:
        d_model: Input and output width.
        d_intermediate: Hidden width between the two linear layers.
        norm_epsilon: ``eps`` of the LayerNorm.
        rms_norm: Stored on the instance but not otherwise used.
        residual_in_fp32: Cast an incoming residual to float32 before use.
        fused_add_norm: Stored on the instance; both settings compute the
            same result in ``forward``.
    """

    def __init__(self, d_model, d_intermediate, norm_epsilon, rms_norm, residual_in_fp32, fused_add_norm):
        super().__init__()
        # Hyper-parameters kept for inspection.
        self.d_model = d_model
        self.d_intermediate = d_intermediate
        self.norm_epsilon = norm_epsilon
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm

        # Network layers.
        self.linear1 = nn.Linear(d_model, d_intermediate)
        self.linear2 = nn.Linear(d_intermediate, d_model)
        self.norm = nn.LayerNorm(d_model, eps=norm_epsilon)

    def forward(self, x, residual):
        """Return ``(normalized_output, residual)``.

        The residual is handed back as received (after the optional float32
        cast); ``None`` stays ``None``.
        """
        has_residual = residual is not None
        if has_residual and self.residual_in_fp32:
            residual = residual.float()

        hidden = self.linear2(F.relu(self.linear1(x)))

        # Fused or not, the computation is "add the residual if there is one,
        # then normalize", so one code path covers both settings.
        if has_residual:
            hidden = hidden + residual

        return self.norm(hidden), residual

def create_block(embed_dim, d_intermediate, ssm_cfg, attn_layer_idx, attn_cfg, norm_epsilon, rms_norm, residual_in_fp32, fused_add_norm, layer_idx):
    """Build one encoder block according to ``ssm_cfg['layer']``.

    Only ``'Mamba1'`` is supported. ``attn_layer_idx``, ``attn_cfg`` and
    ``layer_idx`` are accepted but not used.

    Raises:
        KeyError: If ``ssm_cfg`` has no ``'layer'`` entry.
        ValueError: If the requested layer type is not ``'Mamba1'``.
    """
    # Inspect ssm_cfg to choose the layer type; reject anything unknown up front.
    if ssm_cfg['layer'] != 'Mamba1':
        raise ValueError("Unsupported layer type specified in ssm_cfg")

    block_kwargs = dict(
        d_model=embed_dim,
        d_intermediate=d_intermediate,
        norm_epsilon=norm_epsilon,
        rms_norm=rms_norm,
        residual_in_fp32=residual_in_fp32,
        fused_add_norm=fused_add_norm,
    )
    return Mamba1(**block_kwargs)

# create_block builds the matching Mamba1 module instance from the given parameters.

class Net(nn.Module):
    """Classifier built from a stack of ``Mamba1`` blocks; three binding logits.

    Pipeline: token embedding (64-d) -> Conv1d/BN/ReLU up to 256-d ->
    6 blocks from ``create_block`` with dropout 0.1 after each -> LayerNorm ->
    masked mean pooling -> linear head with 3 outputs.

    NOTE(review): ``self.pe`` is constructed but never applied in ``forward``.
    ``_init_weights`` is not defined in the visible part of the notebook and
    must already be in scope when this class is instantiated.
    """

    def __init__(self, ):
        super().__init__()

        embed_dim=256
        num_layer=6
        # Kept on the instance because forward() needs it as well; it used to
        # be read there as an undefined bare name.
        self.residual_in_fp32 = False
        self.output_type = ['infer', 'loss']
        self.embedding = nn.Embedding(VOCAB_SIZE, 64, padding_idx=PAD)
        self.pe = PositionalEncoding(embed_dim,max_len=256)

        self.conv_embedding = nn.Sequential(
            Conv1dBnRelu(64, embed_dim, kernel_size=3,stride=1,padding=1, is_bn=True),
        )

        self.mamba_encoder = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    d_intermediate=embed_dim//2,
                    ssm_cfg={'layer': 'Mamba1'},
                    attn_layer_idx=None,
                    attn_cfg=None,
                    norm_epsilon=1e-4,
                    rms_norm=1e-4,
                    residual_in_fp32=self.residual_in_fp32,
                    fused_add_norm=True,
                    layer_idx=i,
                )
                for i in range(num_layer)
            ])

        self.norm_f = nn.LayerNorm ( #RMSNorm
            embed_dim, eps=1e-4
        )

        self.bind = nn.Sequential(
            nn.Linear(embed_dim, 3),
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=num_layer,
                n_residuals_per_layer=2,
            )
        )

    def forward(self, batch):
        """Return ``bce_loss`` and/or sigmoid ``bind`` depending on ``output_type``.

        Args:
            batch: Dict with ``smiles_token_id`` and ``smiles_token_mask``
                (2-D, mask non-zero at real tokens) and, when ``'loss'`` is
                in ``output_type``, ``bind`` with the targets.
        """
        smiles_token_id = batch['smiles_token_id'].long()
        smiles_token_mask = batch['smiles_token_mask'].long()

        x = self.embedding(smiles_token_id)
        # Conv1d wants channels first, the blocks want them last.
        x = x.permute(0,2,1).float()
        x = self.conv_embedding(x)
        x = x.permute(0,2,1).contiguous()

        hidden, residual = x, None
        for mamba in self.mamba_encoder:
            hidden, residual = mamba(
                hidden, residual
            )
            hidden = F.dropout(hidden,p=0.1, training=self.training)

        if residual is not None:
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
            # Residual connection; out-of-place so no tensor autograd may
            # still need is modified.
            hidden = hidden + residual

        # Apply the final LayerNorm.
        z = self.norm_f(hidden)

        # Mean over real tokens only. The clamp keeps an all-padding row from
        # dividing by zero and turning the whole batch loss into NaN.
        m = smiles_token_mask.unsqueeze(2).float()
        pool = (z*m).sum(1)/m.sum(1).clamp(min=1.0)
        bind = self.bind(pool)

        # --------------------------
        output = {}
        if 'loss' in self.output_type:
            target = batch['bind']
            output['bce_loss'] = F.binary_cross_entropy_with_logits(bind.float(), target.float())

        if 'infer' in self.output_type:
            output['bind'] = torch.sigmoid(bind)

        return output

In [72]:
 
#-------------------------------------
#dummy code to check net
def run_check_net():
    """Smoke-test ``Net`` on one random CUDA batch and print shapes and loss.

    Draws 500 random samples of length MAX_LENGTH, runs one forward pass
    without gradients under autocast, and prints three sections: the batch
    tensors, the non-loss outputs, and the loss values.
    """
    ids_shape = (500, MAX_LENGTH)
    # Drawn in this order (ids, mask, targets) from NumPy's global RNG.
    ids_np = np.random.choice(VOCAB_SIZE, ids_shape)
    mask_np = np.random.choice(2, ids_shape)
    bind_np = np.random.choice(2, (500, 3))

    batch = {
        'smiles_token_id': torch.from_numpy(ids_np).byte().cuda(),
        'smiles_token_mask': torch.from_numpy(mask_np).byte().cuda(),
        'bind': torch.from_numpy(bind_np).float().cuda(),
    }

    net = Net().cuda()

    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=True):
            output = net(batch)

    def describe(name, tensor):
        # 'idx' entries are reported by length, everything else by shape.
        summary = len(tensor) if name == 'idx' else tensor.shape
        return f'{name:>32} : {summary} '

    # ---
    print('batch')
    for name, tensor in batch.items():
        print(describe(name, tensor))

    shape_lines = [f'{name:>32} : {value.shape} ' for name, value in output.items() if 'loss' not in name]
    print('output')
    for line in shape_lines:
        print(line)

    print('loss')
    for name, value in output.items():
        if 'loss' in name:
            print(f'{name:>32} : {value.item()} ')

# Run the smoke test defined in this cell.
run_check_net()

batch
                 smiles_token_id : torch.Size([500, 160]) 
               smiles_token_mask : torch.Size([500, 160]) 
                            bind : torch.Size([500, 3]) 
output
                            bind : torch.Size([500, 3]) 
loss
                        bce_loss : 0.8870716094970703 


In [70]:
# Training
def run_train():
    """Train a fresh ``Net`` on the module-level ``train`` / ``mask_df`` frames.

    Runs 10 epochs of Adam (lr 1e-3) over batches of 4096 rows in full
    precision, printing the loss every 10th iteration.
    """
    net = Net().cuda()
    net.train()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    def train_step(batch):
        # One optimisation step; returns the loss tensor of this batch.
        optimizer.zero_grad()
        loss = net(batch)['bce_loss']
        loss.backward()
        optimizer.step()
        return loss

    for epoch in range(10):
        for i, batch in enumerate(get_batch(train, mask_df, batch_size=4096)):
            loss = train_step(batch)
            if i % 10 != 0:
                continue
            print(f'epoch:{epoch} iter:{i} loss:{loss.item()}')
     

# Start training with the run_train defined in this cell.
run_train()

epoch:0 iter:0 loss:0.46796882152557373
epoch:1 iter:0 loss:0.03461839258670807
epoch:2 iter:0 loss:0.03337155282497406
epoch:3 iter:0 loss:0.03157204017043114
epoch:4 iter:0 loss:0.03174734115600586
epoch:5 iter:0 loss:0.03172867000102997
epoch:6 iter:0 loss:0.031632088124752045
epoch:7 iter:0 loss:0.031668342649936676
epoch:8 iter:0 loss:0.03160133212804794
epoch:9 iter:0 loss:0.031643111258745193


#### 