## Disclaimer !!!

This is illustrative code for a high performance transformer. I use my local machine for training and inference. I did not check if this code can be run on kaggle notbook or not.

i achieve lb0.428 (new metric) for single-fold without tricks.

In [10]:
%%script false --no-raise-error

#tokenization ====================================

# https://www.ascii-code.com/
MOLECULE_DICT = {
    'l': 1, 'y': 2, '@': 3, '3': 4, 'H': 5, 'S': 6, 'F': 7, 'C': 8, 'r': 9, 's': 10, '/': 11, 'c': 12, 'o': 13,
    '+': 14, 'I': 15, '5': 16, '(': 17, '2': 18, ')': 19, '9': 20, 'i': 21, '#': 22, '6': 23, '8': 24, '4': 25,
    '=': 26, '1': 27, 'O': 28, '[': 29, 'D': 30, 'B': 31, ']': 32, 'N': 33, '7': 34, 'n': 35, '-': 36
}
MAX_MOLECULE_ID = np.max(list(MOLECULE_DICT.values()))
VOCAB_SIZE = MAX_MOLECULE_ID + 10
UNK = 255  # disallow: will cuase error
BOS = MAX_MOLECULE_ID + 1
EOS = MAX_MOLECULE_ID + 2
# rest are reserved
PAD = 0
MAX_LENGTH = 160

MOLECULE_LUT = np.full(256, fill_value=UNK, dtype=np.uint8)
for k, v in MOLECULE_DICT.items():
    ascii = ord(k)
    MOLECULE_LUT[ascii] = v

        

def make_token(s):
    
    t = np.frombuffer(s, np.uint8)
    t = MOLECULE_LUT[t]
    t = t.tolist()

    L = len(t) + 2
    token_id = [BOS] + t + [EOS] + [PAD] * (MAX_LENGTH - L)
    token_mask = [1] * L + [0] * (MAX_LENGTH - L)
        
    return token_id, token_mask


#note this is byte-string!!!
# string to byte-string: str.encode(s)
smiles = b'C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1'  
token_id, token_mask = make_token(smiles)

for modeling, you need flash attnetion + torch compile to make it run fast.
I train on A6000 gpu with batch size= 2000.

In [31]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np

#model ====================================
# https://gist.github.com/kklemon/98e491ff877c497668c715541f1bf478
# refer to the link above to get fast flash attention wrapper

class FlashAttentionTransformerEncoder(nn.Module):
    def __init__(
            self,
            dim_model,
            num_layers,
            num_heads=None,
            dim_feedforward=None,
            dropout=0.0,
            norm_first=False,
            activation=F.gelu,
            rotary_emb_dim=0,
    ):
        super().__init__()
        ... 
class Conv1dBnRelu(nn.Module):
    def __init__(self, in_channels, out_channels, is_bn, **kwargs):
        super(Conv1dBnRelu, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, **kwargs)
        self.is_bn = is_bn
        if self.is_bn:
            self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        if self.is_bn:
            x = self.bn(x)
        return self.relu(x)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=256):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[ :,:x.size(1)]
        
        return x

class Net(nn.Module):
    def __init__(self, ):
        super().__init__()

        embed_dim=512

        self.output_type = ['infer', 'loss']
        self.pe = PositionalEncoding(embed_dim,max_len=256)
        self.embedding = nn.Embedding(VOCAB_SIZE, 64, padding_idx=PAD)
        self.conv_embedding = nn.Sequential(
            Conv1dBnRelu(64, embed_dim, kernel_size=3,stride=1,padding=1, is_bn=True),
        )  #just a simple conv1d-bn-relu . for bn use: BN = partial(nn.BatchNorm1d, eps=5e-3,momentum=0.1)

        self.tx_encoder = FlashAttentionTransformerEncoder(
            dim_model=embed_dim,
            num_heads=8,
            dim_feedforward=embed_dim*4,
            dropout=0.1,
            norm_first=False,
            activation=F.gelu,
            rotary_emb_dim=0,
            num_layers=7,
        )

        self.bind = nn.Sequential(
            nn.Linear(embed_dim, 3),
        )


    def forward(self, batch):
        smiles_token_id   = batch['smiles_token_id'].long()
        smiles_token_mask = batch['smiles_token_mask'].long()
        B, L = smiles_token_id.shape

        x = self.embedding(smiles_token_id)
        x = x.permute(0,2,1).float()
        x = self.conv_embedding(x)
        x = x.permute(0,2,1).contiguous()

        x = self.pe(x)
        z = self.tx_encoder(
            x=x,
            src_key_padding_mask=smiles_token_mask==0,
        )

        m = smiles_token_mask.unsqueeze(-1).float()
        pool = (z*m).sum(1)/m.sum(1)
        bind = self.bind(pool)

        # --------------------------
        output = {}
        if 'loss' in self.output_type:
            target = batch['bind']
            output['bce_loss'] = F.binary_cross_entropy_with_logits(bind.float(), target.float())

        if 'infer' in self.output_type:
            output['bind'] = torch.sigmoid(bind)

        return output

    
#-------------------------------------
#dummy code to check net
def run_check_net():
    max_length = MAX_LENGTH
    batch_size = 500

    batch = {
        'smiles_token_id': torch.from_numpy(np.random.choice(VOCAB_SIZE, (batch_size, max_length))).byte().cuda(),
        'smiles_token_mask': torch.from_numpy(np.random.choice(2, (batch_size, max_length))).byte().cuda(),
        'bind': torch.from_numpy(np.random.choice(2, (batch_size, 3))).float().cuda(),
    }
     
    net = Net().cuda()
    #print(net)
    #net.train()

    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=True): # dtype=torch.float16):
            output = net(batch)

    # ---
    print('batch')
    for k, v in batch.items():
        if k=='idx':
            print(f'{k:>32} : {len(v)} ')
        else:
            print(f'{k:>32} : {v.shape} ')

    print('output')
    for k, v in output.items():
        if 'loss' not in k:
            print(f'{k:>32} : {v.shape} ')
    print('loss')
    for k, v in output.items():
        if 'loss' in k:
            print(f'{k:>32} : {v.item()} ')

In [32]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


MAX_LENGTH = 160
VOCAB_SIZE = 256
PAD = 0


run_check_net()

TypeError: _forward_unimplemented() got an unexpected keyword argument 'x'

I also tried other next generation seq models.  
e.g. mamba, xlstm, Griffin (deepmind's attention RNN)  

I am looking for a faster alternative. But so far transformer + flash attnetion2 is still the fastest  (for small dim and num of layers). for performance, i think these models are smiliar.

In [12]:
%%script false --no-raise-error

#xlstm model
# offical repo: https://github.com/NX-AI/xlstm

class Net(nn.Module):
    def __init__(self, ):
        super().__init__()

        embed_dim = 256

        self.output_type = ['infer', 'loss']
        self.embedding = nn.Embedding(VOCAB_SIZE, 64, padding_idx=0)

        self.conv_embedding = nn.Sequential(
            Conv1dBnRelu(64, embed_dim, kernel_size=3, stride=1, padding=1, is_bn=False),
        )

        self.lstm_encoder = xLSTMBlockStack(
            config = xLSTMBlockStackConfig(
                mlstm_block=mLSTMBlockConfig(
                    mlstm=mLSTMLayerConfig(
                        conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
                    )
                ),
                slstm_block=sLSTMBlockConfig(
                    slstm=sLSTMLayerConfig(
                        backend='cuda',
                        batch_size=64,
                        num_heads=8,
                        conv1d_kernel_size=4,
                        bias_init='powerlaw_blockdependent',
                    ),
                    feedforward=FeedForwardConfig(proj_factor=1.3, act_fn='gelu'),
                ),
                context_length=MAX_LENGTH,
                num_blocks=6,
                embedding_dim=embed_dim,
                slstm_at=[1],
            )
        )

        self.bind = nn.Sequential(
            nn.Linear(embed_dim, 3),
        )


    def forward(self, batch):
        smiles_token_id = batch['smiles_token_id'].long()
        B, L = smiles_token_id.shape

        x = self.embedding(smiles_token_id)
        x = x.permute(0,2,1).float()
        x = self.conv_embedding(x)
        x = x.permute(0,2,1).contiguous()

        x = self.lstm_encoder(x)
        last = x.mean(1)
        bind = self.bind(last)

        # --------------------------
        output = {}
        if 'loss' in self.output_type:
            target = batch['bind']
            output['bce_loss'] = F.binary_cross_entropy_with_logits(bind.float(), target.float())

        if 'infer' in self.output_type:
            output['bind'] = torch.sigmoid(bind)

        return output


In [13]:
%%script false --no-raise-error
# official repo: https://github.com/state-spaces/mamba

#https://github.com/state-spaces/mamba/issues/355
# there is a bug? i cannot try mamba2

class Net(nn.Module):
    def __init__(self, ):
        super().__init__()

        embed_dim=256
        num_layer=6
        self.output_type = ['infer', 'loss']
        self.embedding = nn.Embedding(VOCAB_SIZE, 64, padding_idx=PAD)
        self.pe = PositionalEncoding(embed_dim,max_len=256)

        self.conv_embedding = nn.Sequential(
            Conv1dBnRelu(64, embed_dim, kernel_size=3,stride=1,padding=1, is_bn=True),
        )

        self.mamba_encoder = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    d_intermediate=embed_dim//2,
                    ssm_cfg={'layer': 'Mamba1'},
                    attn_layer_idx=None,
                    attn_cfg=None,
                    norm_epsilon=1e-4,
                    rms_norm=1e-4,
                    residual_in_fp32=False,
                    fused_add_norm=True,
                    layer_idx=i,
                )
                for i in range(num_layer)
            ])

        self.norm_f = nn.LayerNorm ( #RMSNorm
            embed_dim, eps=1e-4
        )

        self.bind = nn.Sequential(
            nn.Linear(embed_dim, 3),
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=num_layer,
                n_residuals_per_layer=2,
            )
        )

    def forward(self, batch):
        smiles_token_id = batch['smiles_token_id'].long()
        smiles_token_mask = batch['smiles_token_mask'].long()
        B, L  = smiles_token_id.shape

        x = self.embedding(smiles_token_id)
        x = x.permute(0,2,1).float()
        x = self.conv_embedding(x)
        x = x.permute(0,2,1).contiguous()
      
        hidden, residual = x, None
        for mamba in self.mamba_encoder:
            hidden, residual = mamba(
                hidden, residual, inference_params=None
            )
            hidden = F.dropout(hidden,p=0.1, training=self.training)

        #z=hidden
        z = layer_norm_fn(
            hidden,
            self.norm_f.weight,
            self.norm_f.bias,
            eps=self.norm_f.eps,
            residual=residual,
            prenorm=False,
            residual_in_fp32=False,
            is_rms_norm=isinstance(self.norm_f, RMSNorm)
        )

        #pool = z.mean(1)
        m = smiles_token_mask.unsqueeze(2).float()
        pool = (z*m).sum(1)/m.sum(1)
        bind = self.bind(pool)

        # --------------------------
        output = {}
        if 'loss' in self.output_type:
            target = batch['bind']
            output['bce_loss'] = F.binary_cross_entropy_with_logits(bind.float(), target.float())

        if 'infer' in self.output_type:
            output['bind'] = torch.sigmoid(bind)

        return output

#### 