In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel, PretrainedConfig

# === Config === #
class MolGPTConfig(PretrainedConfig):
    """Hyperparameter container for the MolGPT model.

    Each constructor argument is stored unchanged as an attribute of the same
    name; any additional keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        vocab_size: Size of the token embedding table and of the output logits.
        embedding_dim: Width of all embeddings and of the transformer layers.
        padding_idx: Token id handed to the token embedding as ``padding_idx``.
        n_heads: Number of attention heads per transformer layer.
        feedforward_dim: Hidden width of each layer's feed-forward part.
        n_layers: Number of stacked transformer layers.
        desc_size: Length of the descriptor vector fed to the descriptor encoder.
        max_position_embeddings: Number of learned position embeddings.
        hidden_dropout_prob: Dropout probability used by the sub-modules.
    """

    model_type = "molgpt"

    def __init__(
        self,
        vocab_size=100,
        embedding_dim=256,
        padding_idx=0,
        n_heads=8,
        feedforward_dim=512,
        n_layers=6,
        desc_size=166,
        max_position_embeddings=512,
        hidden_dropout_prob=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Collected in declaration order so attributes are set in the same
        # sequence as the signature lists them.
        hyperparameters = {
            "vocab_size": vocab_size,
            "embedding_dim": embedding_dim,
            "padding_idx": padding_idx,
            "n_heads": n_heads,
            "feedforward_dim": feedforward_dim,
            "n_layers": n_layers,
            "desc_size": desc_size,
            "max_position_embeddings": max_position_embeddings,
            "hidden_dropout_prob": hidden_dropout_prob,
        }
        for attr, value in hyperparameters.items():
            setattr(self, attr, value)


# === Descriptor Encoder === #
class DescriptorEncoder(nn.Module):
    """Turn a descriptor vector into one embedding per descriptor slot.

    Every descriptor position owns a learned embedding (``bit_emb``); that
    embedding is scaled by the normalized descriptor value at the position.

    Args:
        desc_size: Number of descriptor slots ``D`` expected in the input.
        emb_dim: Width of each slot embedding.
        dropout: Dropout probability applied to the scaled embeddings.
        device: Stored on the instance for backward compatibility only; the
            forward pass takes the device from its input tensor.
    """

    def __init__(self, desc_size, emb_dim, dropout=0.1, device='cuda'):
        super().__init__()
        self.desc_size = desc_size
        self.device = device
        self.bit_emb = nn.Embedding(desc_size, emb_dim)
        # Normalize across the descriptor axis. A LayerNorm over a trailing
        # axis of size 1 maps every value to its bias, which would discard the
        # descriptor values altogether.
        self.scale_norm = nn.LayerNorm(desc_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, desc):  # desc: [B, D], float
        """Return slot embeddings of shape ``[B, D, emb_dim]``.

        Raises:
            AssertionError: If the second dimension of ``desc`` differs from
                ``desc_size``.
        """
        B, D = desc.shape
        assert D == self.desc_size, "Input descriptor dim must match initialized size"
        slot_ids = torch.arange(D, device=desc.device)
        bit = self.bit_emb(slot_ids).unsqueeze(0)     # [1, D, emb], broadcast over batch
        val = self.scale_norm(desc).unsqueeze(-1)     # [B, D, 1]
        return self.dropout(bit * val)                # [B, D, emb]


# === Embedding + Positional Encoding === #
class MolGPTEmbeddings(nn.Module):
    """Token embedding plus learned absolute position embedding, with dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.embedding_dim
        self.token_embeddings = nn.Embedding(
            config.vocab_size, width, padding_idx=config.padding_idx
        )
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Embed ``input_ids`` of shape ``[B, T]`` into ``[B, T, embedding_dim]``.

        Positions are numbered ``0 .. T-1`` for every row of the batch.
        """
        n_tokens = input_ids.size(1)
        positions = torch.arange(n_tokens, dtype=torch.long, device=input_ids.device)
        positions = positions[None, :].expand_as(input_ids)
        summed = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        return self.dropout(summed)


# === Transformer Stack === #
class MolGPTTransformer(nn.Module):
    """Stack of ``config.n_layers`` batch-first ``TransformerEncoderLayer`` blocks."""

    def __init__(self, config):
        super().__init__()
        layer_options = dict(
            d_model=config.embedding_dim,
            nhead=config.n_heads,
            dim_feedforward=config.feedforward_dim,
            dropout=config.hidden_dropout_prob,
            batch_first=True,
        )
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(**layer_options),
            num_layers=config.n_layers,
        )

    def forward(self, x, attn_mask=None):
        """Run ``x`` through the encoder stack, passing ``attn_mask`` as its ``mask``."""
        encoded = self.encoder(x, mask=attn_mask)
        return encoded


# === Causal Mask === #
def generate_causal_mask(seq_len, device):
    """Build an additive ``[seq_len, seq_len]`` causal attention mask.

    Entries strictly above the diagonal are ``-inf``; the diagonal and
    everything below it are ``0``.
    """
    blocked = torch.full((seq_len, seq_len), float('-inf'), device=device)
    # triu zeroes the diagonal and lower triangle, leaving -inf above it.
    return torch.triu(blocked, diagonal=1)


# === Main Model === #
class MolGPTModel(PreTrainedModel):
    """Descriptor-conditioned causal language model over token ids.

    The descriptor embeddings are prepended to the token embeddings, the
    combined sequence is run through the transformer under a causal mask, and
    logits are returned for the token positions only.
    """

    config_class = MolGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = MolGPTEmbeddings(config)
        # Use the configured dropout for the descriptor branch as well, rather
        # than the encoder's hard-coded default.
        self.descriptor_encoder = DescriptorEncoder(
            config.desc_size,
            config.embedding_dim,
            dropout=config.hidden_dropout_prob,
        )
        self.transformer = MolGPTTransformer(config)
        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size)
        self.init_weights()

    def forward(self, input_ids, descriptors, labels=None):
        """Compute token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids, ``[B, T]``.
            descriptors: Descriptor values, ``[B, D1]``.
            labels: Optional target ids with as many elements as ``input_ids``.
                NOTE(review): logits and labels are compared position by
                position with no shift here, so labels are presumably already
                aligned to the prediction targets; confirm against the caller.

        Returns:
            ``{"logits": ...}`` with logits of shape ``[B, T, V]``, plus a
            ``"loss"`` entry when ``labels`` is provided.
        """
        token_emb = self.embeddings(input_ids)              # [B, T, D]
        desc_emb = self.descriptor_encoder(descriptors)     # [B, D1, D]
        n_desc = desc_emb.size(1)
        x = torch.cat([desc_emb, token_emb], dim=1)         # [B, D1+T, D]

        attn_mask = generate_causal_mask(x.size(1), x.device)
        x = self.transformer(x, attn_mask=attn_mask)

        logits = self.lm_head(x)                            # [B, D1+T, V]
        token_logits = logits[:, n_desc:, :]                # token part only

        if labels is None:
            return {"logits": token_logits}

        loss_fct = nn.CrossEntropyLoss()
        # The slice above is not contiguous when B > 1, so .view() would raise;
        # reshape copies only when it has to.
        flat_logits = token_logits.reshape(-1, token_logits.size(-1))
        loss = loss_fct(flat_logits, labels.reshape(-1))
        return {"loss": loss, "logits": token_logits}


In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput


# === Config === #
class MolGPTConfig(PretrainedConfig):
    """Configuration for MolGPT: sizes, layer counts and dropout.

    All arguments are stored as same-named attributes without modification.
    Unrecognized keyword arguments are passed on to ``PretrainedConfig``.
    """

    model_type = "molgpt"

    def __init__(self, vocab_size=100, embedding_dim=256, padding_idx=0,
                 n_heads=8, feedforward_dim=512, n_layers=6, desc_size=166,
                 max_position_embeddings=512, hidden_dropout_prob=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        # Vocabulary and embedding table.
        self.vocab_size, self.embedding_dim, self.padding_idx = (
            vocab_size, embedding_dim, padding_idx)
        # Transformer stack geometry.
        self.n_heads, self.feedforward_dim, self.n_layers = (
            n_heads, feedforward_dim, n_layers)
        # Descriptor length and positional table size.
        self.desc_size, self.max_position_embeddings = (
            desc_size, max_position_embeddings)
        self.hidden_dropout_prob = hidden_dropout_prob


# === Descriptor Encoder === #
class DescriptorEncoder(nn.Module):
    """Embed a descriptor vector as one scaled embedding per descriptor slot.

    Each row of the input is standardized (mean subtracted, divided by its
    standard deviation plus ``1e-6``) and the result scales a learned
    per-slot embedding.
    """

    def __init__(self, desc_size, emb_dim, dropout=0.1):
        super().__init__()
        self.desc_size = desc_size
        self.bit_emb = nn.Embedding(desc_size, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, desc):  # desc: [B, D], float
        """Return embeddings of shape ``[B, D, emb_dim]`` for ``desc`` of shape ``[B, D]``."""
        B, D = desc.shape
        assert D == self.desc_size, "Input descriptor dim must match initialized size"

        # Per-sample standardization across the descriptor axis.
        row_mean = desc.mean(dim=-1, keepdim=True)
        row_std = desc.std(dim=-1, keepdim=True)
        scale = ((desc - row_mean) / (row_std + 1e-6)).unsqueeze(-1)   # [B, D, 1]

        # One embedding per slot; the leading axis broadcasts over the batch.
        slot_emb = self.bit_emb(torch.arange(D, device=desc.device)).unsqueeze(0)  # [1, D, emb]
        return self.dropout(slot_emb * scale)                           # [B, D, emb]


# === Embedding + Positional Encoding === #
class MolGPTEmbeddings(nn.Module):
    """Sum of token and learned position embeddings, followed by dropout."""

    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_dim,
            padding_idx=config.padding_idx,
        )
        self.position_embeddings = nn.Embedding(
            num_embeddings=config.max_position_embeddings,
            embedding_dim=config.embedding_dim,
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Map ``input_ids`` (``[B, T]``) to embeddings (``[B, T, embedding_dim]``)."""
        # Position ids 0..T-1, repeated for each row of the batch.
        pos = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        pos = pos.unsqueeze(0).expand(input_ids.size())
        tok_vecs = self.token_embeddings(input_ids)
        pos_vecs = self.position_embeddings(pos)
        return self.dropout(tok_vecs + pos_vecs)


# === Transformer Stack === #
class MolGPTTransformer(nn.Module):
    """Stack of ``config.n_layers`` batch-first ``TransformerEncoderLayer`` blocks."""

    def __init__(self, config):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model=config.embedding_dim,
            nhead=config.n_heads,
            dim_feedforward=config.feedforward_dim,
            dropout=config.hidden_dropout_prob,
            batch_first=True
        )
        self.encoder = TransformerEncoder(encoder_layer, num_layers=config.n_layers)

    def forward(self, x, attn_mask=None):
        """Run ``x`` through the encoder stack under the given attention mask.

        Args:
            x: Input embeddings, batch first.
            attn_mask: Optional attention mask forwarded to the encoder.
        """
        # TransformerEncoder.forward names this parameter ``mask``; passing it
        # as ``attn_mask=`` raises a TypeError.
        return self.encoder(x, mask=attn_mask)


# === Causal Mask === #
def generate_causal_mask(seq_len, device):
    """Build a boolean ``[seq_len, seq_len]`` causal attention mask.

    ``True`` marks entries strictly above the diagonal; the diagonal and the
    lower triangle are ``False``. A boolean mask is returned so it can be
    handed directly to ``TransformerEncoder``.
    """
    index = torch.arange(seq_len, device=device)
    # Entry (row, col) is True exactly when col > row.
    return index.unsqueeze(0) > index.unsqueeze(1)


# === Main Model === #
class MolGPTModel(PreTrainedModel):
    """Descriptor-conditioned causal language model over token ids.

    Descriptor embeddings are placed in front of the token embeddings, the
    joint sequence is encoded under a causal mask, and only the logits at the
    token positions are returned.
    """

    config_class = MolGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = MolGPTEmbeddings(config)
        self.descriptor_encoder = DescriptorEncoder(config.desc_size, config.embedding_dim, dropout=config.hidden_dropout_prob)
        self.transformer = MolGPTTransformer(config)
        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size)
        self.init_weights()

    def forward(self, input_ids: torch.Tensor, descriptors: torch.Tensor, labels: torch.Tensor = None) -> CausalLMOutput:
        """Compute token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids, ``[B, T]``.
            descriptors: Descriptor values, ``[B, D1]``.
            labels: Optional target ids with as many elements as ``input_ids``.
                NOTE(review): no shift is applied between logits and labels
                here; labels are presumably pre-aligned to the prediction
                targets, so confirm against the caller.

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape ``[B, T, V]`` and
            ``loss`` set to ``None`` when no labels are supplied.
        """
        token_emb = self.embeddings(input_ids)               # [B, T, D]
        desc_emb = self.descriptor_encoder(descriptors)      # [B, D1, D]
        x = torch.cat([desc_emb, token_emb], dim=1)          # [B, D1+T, D]

        attn_mask = generate_causal_mask(x.size(1), x.device)
        x = self.transformer(x, attn_mask=attn_mask)

        logits = self.lm_head(x)                             # [B, D1+T, V]
        token_logits = logits[:, desc_emb.size(1):, :]       # token part only

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # token_logits is a non-contiguous slice when B > 1, so .view()
            # would raise; reshape copies only when needed.
            loss = loss_fct(
                token_logits.reshape(-1, token_logits.size(-1)),
                labels.reshape(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=token_logits
        )


In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput


# === Config === #
class MolGPTConfig(PretrainedConfig):
    """MolGPT hyperparameters, stored verbatim as attributes.

    Keyword arguments not listed in the signature are handed to
    ``PretrainedConfig``.
    """

    model_type = "molgpt"

    def __init__(
        self,
        vocab_size=100,
        embedding_dim=256,
        padding_idx=0,
        n_heads=8,
        feedforward_dim=512,
        n_layers=6,
        desc_size=166,
        max_position_embeddings=512,
        hidden_dropout_prob=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        field_names = (
            "vocab_size", "embedding_dim", "padding_idx",
            "n_heads", "feedforward_dim", "n_layers",
            "desc_size", "max_position_embeddings", "hidden_dropout_prob",
        )
        field_values = (
            vocab_size, embedding_dim, padding_idx,
            n_heads, feedforward_dim, n_layers,
            desc_size, max_position_embeddings, hidden_dropout_prob,
        )
        # Names and values are listed in matching order.
        for attr, value in zip(field_names, field_values):
            setattr(self, attr, value)


# === Descriptor Encoder === #
class DescriptorEncoder(nn.Module):
    """Per-slot descriptor embedding scaled by the standardized descriptor value."""

    def __init__(self, desc_size, emb_dim, dropout=0.1):
        super().__init__()
        self.desc_size = desc_size
        self.bit_emb = nn.Embedding(desc_size, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, desc):  # desc: [B, D], float
        """Encode ``desc`` (``[B, D]``) into ``[B, D, emb_dim]``.

        Each row is standardized across its ``D`` values (with ``1e-6`` added
        to the standard deviation) before it scales the slot embeddings.
        """
        batch, width = desc.shape
        assert width == self.desc_size, "Input descriptor dim must match initialized size"

        slot_ids = torch.arange(width, device=desc.device)
        slot_ids = slot_ids.unsqueeze(0).expand(batch, width)   # [B, D]
        slot_vecs = self.bit_emb(slot_ids)                      # [B, D, emb]

        centered = desc - desc.mean(dim=-1, keepdim=True)
        spread = desc.std(dim=-1, keepdim=True) + 1e-6
        weights = (centered / spread).unsqueeze(-1)             # [B, D, 1]

        scaled = slot_vecs * weights                            # [B, D, emb]
        return self.dropout(scaled)


# === Embedding + Positional Encoding === #
class MolGPTEmbeddings(nn.Module):
    """Input embedding layer: token lookup + learned positions + dropout."""

    def __init__(self, config):
        super().__init__()
        vocab, dim = config.vocab_size, config.embedding_dim
        self.token_embeddings = nn.Embedding(vocab, dim, padding_idx=config.padding_idx)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Return ``[B, T, embedding_dim]`` embeddings for ``input_ids`` of shape ``[B, T]``."""
        steps = torch.arange(
            input_ids.size(1), dtype=torch.long, device=input_ids.device
        )
        # Same 0..T-1 positions for every row of the batch.
        steps = steps.view(1, -1).expand_as(input_ids)
        combined = self.token_embeddings(input_ids)
        combined = combined + self.position_embeddings(steps)
        return self.dropout(combined)


# === Transformer Stack === #
class MolGPTTransformer(nn.Module):
    """Encoder stack built from identical batch-first transformer layers."""

    def __init__(self, config):
        super().__init__()
        # Positional arguments follow TransformerEncoderLayer's order:
        # d_model, nhead, dim_feedforward, dropout.
        layer = TransformerEncoderLayer(
            config.embedding_dim,
            config.n_heads,
            config.feedforward_dim,
            config.hidden_dropout_prob,
            batch_first=True,
        )
        self.encoder = TransformerEncoder(layer, config.n_layers)

    def forward(self, x, attn_mask=None):
        """Encode ``x`` with ``attn_mask`` applied as the encoder's attention mask."""
        return self.encoder(x, attn_mask)


# === Causal Mask === #
def generate_causal_mask(seq_len, device):
    """Return an additive ``[seq_len, seq_len]`` causal mask.

    Entries strictly above the diagonal hold ``-inf``; all other entries are ``0``.
    """
    positions = torch.arange(seq_len, device=device)
    above_diagonal = positions.unsqueeze(0) > positions.unsqueeze(1)
    mask = torch.zeros(seq_len, seq_len, device=device)
    return mask.masked_fill(above_diagonal, float('-inf'))


# === Main Model === #
class MolGPTModel(PreTrainedModel):
    """Descriptor-conditioned causal language model over token ids.

    Descriptor embeddings are prepended to the token embeddings, the joint
    sequence is encoded under a causal mask, and logits are returned for the
    token positions only.
    """

    config_class = MolGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = MolGPTEmbeddings(config)
        self.descriptor_encoder = DescriptorEncoder(config.desc_size, config.embedding_dim, dropout=config.hidden_dropout_prob)
        self.transformer = MolGPTTransformer(config)
        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size)
        self.post_init()  # use this instead of init_weights()

    def forward(self, input_ids: torch.Tensor, descriptors: torch.Tensor, labels: torch.Tensor = None) -> CausalLMOutput:
        """Compute token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids, ``[B, T]``.
            descriptors: Descriptor values, ``[B, D1]``.
            labels: Optional target ids with as many elements as ``input_ids``.
                Positions equal to ``config.padding_idx`` are ignored by the
                loss. NOTE(review): no shift is applied between logits and
                labels here; labels are presumably pre-aligned to the
                prediction targets, so confirm against the caller.

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape ``[B, T, V]`` and
            ``loss`` set to ``None`` when no labels are supplied.
        """
        token_emb = self.embeddings(input_ids)               # [B, T, D]
        desc_emb = self.descriptor_encoder(descriptors)      # [B, D1, D]
        x = torch.cat([desc_emb, token_emb], dim=1)          # [B, D1+T, D]

        attn_mask = generate_causal_mask(x.size(1), x.device)  # [L, L]
        x = self.transformer(x, attn_mask=attn_mask)

        logits = self.lm_head(x)                             # [B, D1+T, V]
        token_logits = logits[:, desc_emb.size(1):, :]       # token part only

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.padding_idx)
            # token_logits is a non-contiguous slice when B > 1, so .view()
            # would raise; reshape copies only when needed.
            loss = loss_fct(
                token_logits.reshape(-1, token_logits.size(-1)),
                labels.reshape(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=token_logits
        )

    def generate(self, input_ids, descriptors, max_length=100):
        """Greedily extend ``input_ids`` one token at a time.

        Puts the model in eval mode, then repeatedly appends the argmax token
        of the last position. The loop ends after ``max_length`` appended
        tokens, as soon as every sequence in the batch emits
        ``config.padding_idx`` in the same step, or when the sequence has
        outgrown the position-embedding table.

        Args:
            input_ids: Prompt token ids, ``[B, T]``.
            descriptors: Descriptor values, ``[B, D1]``.
            max_length: Maximum number of tokens to append (not the total length).

        Returns:
            The prompt with the generated tokens appended along dimension 1.
        """
        self.eval()
        generated = input_ids
        max_positions = self.config.max_position_embeddings
        with torch.no_grad():
            for _ in range(max_length):
                outputs = self(generated, descriptors)
                next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=1)
                if (next_token == self.config.padding_idx).all():
                    break
                # Another forward pass would index past the position table
                # and raise an IndexError, so stop here.
                if generated.size(1) > max_positions:
                    break
        return generated


In [None]:
from transformers import PreTrainedTokenizer
import re

class SMILESTokenizerHF(PreTrainedTokenizer):
    """Regex-based SMILES tokenizer with a fixed, hard-coded vocabulary.

    Ids are assigned by list position: the four special tokens come first
    (``<BOS>``=0, ``<EOS>``=1, ``<PAD>``=2, ``<UNK>``=3), followed by the fixed
    SMILES tokens in the order listed in ``__init__``.

    NOTE(review): the regex also recognizes symbols that are absent from the
    fixed token list (for example ``/``, ``.`` and ``As``); those map to
    ``<UNK>``. Characters the regex does not match at all are dropped by
    ``findall`` without any error.
    NOTE(review): ``<PAD>`` has id 2 here, while ``MolGPTConfig`` in this file
    defaults ``padding_idx`` to 0; set it explicitly when pairing the two.
    """

    def __init__(self, **kwargs):
        # Fixed token set
        self.tokens = ['#', '%10', '%11', '%12', '(', ')', '-', '1', '2', '3', '4', '5', '6', '7', '8', '9', '<', '=', 'B', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S',
            '[B-]', '[BH-]', '[BH2-]', '[BH3-]', '[B]', '[C+]', '[C-]', '[CH+]', '[CH-]', '[CH2+]', '[CH2]', '[CH]', '[F+]', '[H]', '[I+]', '[IH2]', '[IH]', '[N+]', '[N-]',
            '[NH+]', '[NH-]', '[NH2+]', '[NH3+]', '[N]', '[O+]', '[O-]', '[OH+]', '[O]', '[P+]', '[PH+]', '[PH2+]', '[PH]', '[S+]', '[S-]', '[SH+]', '[SH]', '[Se+]',
            '[SeH+]', '[SeH]', '[Se]', '[Si-]', '[SiH-]', '[SiH2]', '[SiH]', '[Si]', '[b-]', '[bH-]', '[c+]', '[c-]', '[cH+]', '[cH-]', '[n+]', '[n-]', '[nH+]', '[nH]',
            '[o+]', '[s+]', '[sH+]', '[se+]', '[se]', 'b', 'c', 'n', 'o', 'p', 's']

        # Special tokens
        self.special_tokens_list = ['<BOS>', '<EOS>', '<PAD>', '<UNK>']

        # Full vocabulary: special tokens first, then the fixed SMILES tokens
        self.all_tokens = self.special_tokens_list + self.tokens
        self.vocab = {token: idx for idx, token in enumerate(self.all_tokens)}
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}

        # Tokenization pattern
        self.pattern = re.compile(
            r'%\d{2}|\[[^\]]+\]|Cl|Br|Si|Se|As|B|C|N|O|F|P|S|I|[cnospb]|[0-9]|\(|\)|=|#|-|\/|\\|\.|<'
        )

        # The vocabulary has to exist before the base constructor runs: newer
        # transformers releases register the special tokens inside
        # PreTrainedTokenizer.__init__ and look them up in the vocabulary
        # while doing so.
        super().__init__(
            pad_token='<PAD>',
            bos_token='<BOS>',
            eos_token='<EOS>',
            unk_token='<UNK>',
            **kwargs
        )

    @property
    def vocab_size(self):
        """Number of entries in the fixed vocabulary, special tokens included."""
        return len(self.vocab)

    def _tokenize(self, text):
        """Split ``text`` into the non-overlapping matches of the SMILES pattern."""
        return self.pattern.findall(text)

    def _convert_token_to_id(self, token):
        """Return the id of ``token``, or the ``<UNK>`` id when it is not in the vocabulary."""
        return self.vocab.get(token, self.vocab['<UNK>'])

    def _convert_id_to_token(self, index):
        """Return the token for ``index``, or ``'<UNK>'`` when the id is unknown."""
        return self.ids_to_tokens.get(index, '<UNK>')

    def convert_tokens_to_string(self, tokens):
        """Join tokens back into a string, dropping special tokens (BOS, EOS, ...)."""
        return ''.join([t for t in tokens if t not in self.special_tokens_list])

    def build_inputs_with_special_tokens(self, token_ids, token_ids_1=None):
        """Wrap ids as ``<BOS> ids <EOS>``.

        The base class passes a second sequence positionally, so it has to be
        accepted here. When one is given it is appended as
        ``<BOS> ids <EOS> ids_1 <EOS>``.
        """
        bos = [self.vocab['<BOS>']]
        eos = [self.vocab['<EOS>']]
        sequence = bos + token_ids + eos
        if token_ids_1 is not None:
            sequence = sequence + token_ids_1 + eos
        return sequence

    def get_vocab(self):
        """Return the token-to-id mapping."""
        return self.vocab

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write one token per line, in id order, and return the file path as a 1-tuple."""
        import os
        vocab_file = os.path.join(save_directory, (filename_prefix or "") + "vocab.txt")
        with open(vocab_file, "w", encoding="utf-8") as f:
            for token in self.all_tokens:
                f.write(token + "\n")
        return (vocab_file,)
