In [1]:
import math
import os

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding.

    Adds a fixed (non-learned) sine/cosine position signal to the input.
    Produces (batch, seq_len, d_model) given (batch, seq_len, d_model) input.

    Args:
        d_model: size of the last (feature) dimension. May be even or odd.
        max_len: number of positions precomputed in the table; inputs passed
            to ``forward`` must not be longer than this.
    """
    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        # One frequency per even column index: ceil(d_model / 2) values.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) *
            (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # Odd columns number d_model // 2, which is one fewer than the even
        # columns when d_model is odd; trim so the assignment shapes match.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Buffer (not a parameter): follows .to()/.half() and is saved in the
        # state dict, but is not trained.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add the first ``seq_len`` rows of the table to ``x``.

        x: (batch_size, seq_len, d_model), with seq_len <= max_len
        Returns a tensor of the same shape.
        """
        seq_len = x.size(1)
        # The leading table dimension of 1 broadcasts over the batch.
        return x + self.pe[:, :seq_len, :]

class TrinucTransformerClassifier(nn.Module):
    """
    Transformer-encoder classifier over sequences of trinucleotide token IDs.

    Pipeline: token embedding -> optional zero CLS slot prepended ->
    sinusoidal positional encoding -> ``nn.TransformerEncoder`` ->
    pooling (CLS position or mean over positions) -> dropout -> linear head.

    Args:
        vocab_size: number of regular token IDs (trinucleotides indexed 0..63).
        seq_len: maximum number of tokens per input sequence (CLS excluded).
        d_model: embedding / encoder width.
        n_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        dim_feedforward: hidden width of each layer's feed-forward sub-block.
        dropout: dropout probability used in the encoder layers and the head.
        num_classes: number of output logits (2 for binary classification
            with CrossEntropyLoss).
        use_cls_token: if True, pool from a prepended CLS position;
            otherwise mean-pool over all positions.
    """
    def __init__(
        self,
        vocab_size: int = 64,
        seq_len: int = 200,
        d_model: int = 128,
        n_heads: int = 4,
        num_layers: int = 3,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        num_classes: int = 2,  # 2 for binary logit (CrossEntropyLoss)
        use_cls_token: bool = True,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        self.use_cls_token = use_cls_token

        # The CLS option adds one embedding row and one extra position.
        extra = 1 if use_cls_token else 0

        # Token embedding (trinucleotides indexed 0..63). With use_cls_token
        # the table has one additional row at index vocab_size; forward()
        # does not look that row up for the CLS slot (it uses zeros).
        self.token_emb = nn.Embedding(vocab_size + extra, d_model)
        # Positional encoding
        self.pos_enc = PositionalEncoding(d_model, max_len=seq_len + extra)

        # Transformer encoder (batch_first=True => (B, L, E))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

        # NOTE: self.dropout is created but never applied in forward();
        # only dropout_head is used (on the pooled vector).
        self.dropout = nn.Dropout(dropout)
        self.dropout_head = nn.Dropout(dropout)

        # Classification head: pooled representation -> logit(s)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch_size, L) of integer token IDs in [0, vocab_size), L <= seq_len
        Returns:
            logits: (batch_size, num_classes)
        """
        # Embed tokens
        x = self.token_emb(x)  # (B, L, d_model)

        if self.use_cls_token:
            batch_size = x.size(0)
            # new_zeros inherits both device AND dtype from the embeddings, so
            # the concat stays consistent when the model runs in half/bfloat16.
            cls_token = x.new_zeros(batch_size, 1, self.d_model)
            # Optional: learn a cls embedding instead of zeros
            # self.cls_emb = nn.Parameter(torch.zeros(1, 1, d_model))
            # cls_token = self.cls_emb.expand(batch_size, -1, -1)
            x = torch.cat([cls_token, x], dim=1)  # (B, L+1, d_model)

        # Below, L' = L + 1 with the CLS slot, else L.
        # Add positional encoding
        x = self.pos_enc(x)    # (B, L', d_model)

        # Transformer encoder
        x = self.encoder(x)    # (B, L', d_model)

        if self.use_cls_token:
            x = x[:, 0, :]       # (B, d_model), CLS position
        else:
            x = x.mean(dim=1)    # (B, d_model), mean over positions

        x = self.dropout_head(x)
        logits = self.fc(x)    # (B, num_classes)
        return logits

In [13]:
import json
from tokenizers import Tokenizer, models, pre_tokenizers, Regex
from itertools import product

# 1. Generate the vocabulary for all 64 trinucleotides (AAA, AAC, ..., TTT) -> IDs 0..63
bases = ['A', 'C', 'G', 'T']
trinucleotides = [''.join(p) for p in product(bases, repeat=3)]
vocab = {tri: i for i, tri in enumerate(trinucleotides)}

# 2. Add special tokens (e.g., [PAD], [CLS])
# ID 64 is the extra embedding row the classifier allocates when
# use_cls_token=True. The classifier's forward() prepends its own zero vector
# for the CLS slot and does not look up ID 64 for it, so [CLS] is registered
# here only to name that ID.
vocab["[CLS]"] = 64

# 3. Create a WordLevel Tokenizer
# NOTE(review): "[UNK]" is declared as the unknown token but has no entry in
# `vocab` - confirm how pieces outside the 64 trinucleotides (e.g. "N" bases
# or a trailing 1-2 character remainder) should be handled.
tokenizer = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))

# 4. Set a Pre-tokenizer to split DNA strings into 3-character chunks
# This tells the tokenizer to look at "ATGCAT" as ["ATG", "CAT"]
tokenizer.pre_tokenizer = pre_tokenizers.Split(
    pattern=Regex(".{3}"),
    behavior="isolated"
)

# 5. Save the tokenizer files
# Create the output directory first: neither tokenizer.save() nor open(..., "w")
# creates missing parent directories.
os.makedirs("hf_trinuc_model", exist_ok=True)
tokenizer.save("hf_trinuc_model/tokenizer.json")

tokenizer_config = {
    "model_max_length": 200,
    "padding_side": "right",
    "tokenizer_class": "PreTrainedTokenizerFast",
    "cls_token": "[CLS]",
    "unk_token": "[UNK]"
}
with open("hf_trinuc_model/tokenizer_config.json", "w") as f:
    json.dump(tokenizer_config, f)

In [17]:
from transformers import PretrainedConfig

class TrinucTransformerConfig(PretrainedConfig):
    """
    Hyper-parameter container for the trinucleotide transformer classifier.

    Each named argument is stored on the instance under the same name; any
    additional keyword arguments are passed on to ``PretrainedConfig``.
    """

    model_type = "trinuc_transformer"

    def __init__(
        self,
        vocab_size=64,
        seq_len=200,
        d_model=256,
        n_heads=4,
        num_layers=4,
        dim_feedforward=512,
        dropout=0.1,
        num_classes=2,
        use_cls_token=True,
        **kwargs
    ):
        # Let the base class consume the remaining keyword arguments first.
        super().__init__(**kwargs)

        # Input: vocabulary size and sequence length
        self.vocab_size = vocab_size
        self.seq_len = seq_len

        # Encoder stack dimensions
        self.d_model = d_model
        self.n_heads = n_heads
        self.num_layers = num_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout

        # Classification head and pooling mode
        self.num_classes = num_classes
        self.use_cls_token = use_cls_token

In [18]:
import torch
from torch import nn
from transformers import PreTrainedModel

class TrinucTransformerModel(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` wrapper around ``TrinucTransformerClassifier``.

    The wrapped classifier lives at ``self.model``, so its weights appear in
    this module's state dict under the ``model.`` prefix.
    """
    config_class = TrinucTransformerConfig
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        # Build the underlying classifier from the config fields.
        self.model = TrinucTransformerClassifier(
            vocab_size=config.vocab_size,
            seq_len=config.seq_len,
            d_model=config.d_model,
            n_heads=config.n_heads,
            num_layers=config.num_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            num_classes=config.num_classes,
            use_cls_token=config.use_cls_token
        )

    def forward(self, input_ids=None, labels=None, **kwargs):
        """
        Run the classifier and optionally compute the training loss.

        Args:
            input_ids: (batch_size, seq_len) integer token IDs. Required.
            labels: optional class indices, one per sequence; when given, a
                cross-entropy loss with label smoothing 0.05 is computed.
            **kwargs: accepted and ignored (e.g. an attention mask supplied
                by a caller is not used).

        Returns:
            dict with "logits" of shape (batch_size, num_classes), plus
            "loss" when ``labels`` is provided.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        # Fail early with a clear message instead of an obscure error from
        # the embedding layer.
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # input_ids: (batch_size, seq_len)
        logits = self.model(input_ids)

        if labels is None:
            return {"logits": logits}

        loss_fct = nn.CrossEntropyLoss(label_smoothing=0.05)
        loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))

        # Dict output: {"loss", "logits"} with labels, {"logits"} without.
        return {"loss": loss, "logits": logits}

In [40]:
# Which signal's checkpoint to convert; also names the input and output folders.
signal = 'TIS'

# 1. Build the HF wrapper with the default configuration
config = TrinucTransformerConfig()
hf_model = TrinucTransformerModel(config)

# 2. Load the raw .pth weights into the wrapped classifier.
# The checkpoint stores the state dict under the 'model' key, and its keys are
# those of the inner TrinucTransformerClassifier, so it is loaded into
# hf_model.model rather than hf_model.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
checkpoint_path = f"./{signal}/Transformer_{signal}_all.pth"
state_dict = torch.load(checkpoint_path, map_location="cpu")['model']
hf_model.model.load_state_dict(state_dict)

# 3. Write the HF-compatible model directory
export_dir = f"../app/models/{signal}_model"
hf_model.save_pretrained(export_dir)

# Manual follow-up (not done by this cell): config.json in the saved
# directory must be updated to contain
#
#   "architectures": [
#     "DistilBertForSequenceClassification"
#   ],
#   "model_type": "distilbert",
#   "id2label": {
#       "0": "Negative",
#       "1": "Positive"
#     },
#     "label2id": {
#       "Negative": 0,
#       "Positive": 1
#     },


In [41]:
# Example batch for tracing/export: 2 sequences of 198 token IDs drawn from [0, 64)
dummy_input = torch.randint(low=0, high=64, size=(2, 198))

In [42]:
#!pip install optimum onnx onnxruntime

In [44]:
# save_pretrained() creates the model directory but not its "onnx" sub-folder,
# and the exporter does not create missing parent directories.
os.makedirs(f"../app/models/{signal}_model/onnx", exist_ok=True)

torch.onnx.export(
    hf_model,
    dummy_input,
    f"../app/models/{signal}_model/onnx/model.onnx",
    input_names=['input_ids'],     # Essential for JS compatibility
    output_names=['logits'],       # Essential for JS compatibility
    # Batch size and sequence length vary at inference time; the logits'
    # batch axis is declared dynamic too so it matches the input's.
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'logits': {0: 'batch_size'},
    }
)

In [45]:
import onnx
# Reload the exported graph from disk and run the ONNX checker on it.
onnx_path = f"../app/models/{signal}_model/onnx/model.onnx"
model_onnx = onnx.load(onnx_path)
onnx.checker.check_model(model_onnx)