In [81]:
import math

n_batch = 1
dim_large = 1024
dim_small = 128
dim_project = 512
n_heads = 8

small_seq = torch.rand((1, dim_small))
large_seq = torch.rand((1, dim_large))

large_to_small = nn.Linear(dim_large, dim_small)
small_to_large = nn.Linear(dim_small, dim_large)

large_seq_projected = large_to_small(large_seq)
small_seq_projected = small_to_large(small_seq)

print(large_seq_projected.shape)
print(small_seq_projected.shape)

torch.Size([1, 128])
torch.Size([1, 1024])


In [None]:
from einops import rearrange

class CrossAttention(nn.Modules):
    def __init__(self, d_model, inner_dim=512, n_heads=8, dropout=0.2):
        super().__init__()

        self.n_batch = n_batch
        self.d_model = d_model
        self.n_heads = n_heads
        self.inner_dim = inner_dim
        self.scale = dim_head ** -0.5
        
        self.to_q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.to_k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.to_v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, d_model),
            nn.Dropout(dropout)
        )
        
    def forward(self, q, k, v):
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)
        
        q = q.view(-1, self.n_heads, self.inner_dim // self.n_heads)
        k = k.view(-1, self.n_heads, self.inner_dim // self.n_heads)
        v = v.view(-1, self.n_heads, self.inner_dim // self.n_heads)
        
        dim = q.shape[-1]
        score = torch.matmul(q, k.T) / torch.sqrt(dim)
        score = F.softmax(score, dim=-1)
        out = torch.matmul(score, v)
        
        return out



TypeError: __init__() missing 4 required positional arguments: 'd_model', 'h', 'qkv_fc_layer', and 'fc_layer'

In [35]:
from transformers import BertForMaskedLM, BertTokenizer, PreTrainedTokenizerFast

fast_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="data/drug/tokenizer_model/vocab.json",
    pad_token="[PAD]",
    mask_token="[MASK]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    unk_token="[UNK]"
)

vocab_size = len(fast_tokenizer.get_vocab().keys())

# prot_bert = BertModel.from_pretrained("Rostlab/prot_bert")
molecule_bert = BertModel.from_pretrained("weights/molecule_bert_pretrained/")

Some weights of the model checkpoint at weights/molecule_bert_pretrained/ were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at weights/molecule_bert_pretrained/ and are newly initialized: ['bert.pooler.dense.weight', '

In [39]:
sequence_Example = "CN(C)CCOC(C1=CC=CC=C1)C1=CC=CC=C1"
tokenizer.encode(sequence_Example, max_length=128, truncation=True)

encoded_input = tokenizer(sequence_Example, return_tensors='pt')
output = molecule_bert(**encoded_input)
output.pooler_output

tensor([[-0.2529, -0.1541, -0.3178, -0.2939,  0.0102,  0.0807,  0.1587, -0.0029,
         -0.2492,  0.0649,  0.0317, -0.2974,  0.1286,  0.3392, -0.1417, -0.0793,
          0.2037,  0.3811, -0.0537,  0.1650, -0.1294, -0.1952, -0.0131, -0.1350,
          0.2100,  0.2240, -0.0620, -0.2068, -0.3132,  0.2256,  0.0541,  0.3479,
         -0.3182,  0.3028,  0.2635,  0.1498, -0.1914, -0.1550,  0.2509,  0.1575,
          0.0092,  0.0112, -0.2377,  0.4385, -0.2409, -0.1302,  0.2638,  0.0248,
         -0.0245,  0.1838, -0.2324,  0.1214, -0.0582,  0.2001, -0.4337,  0.1611,
          0.1092,  0.2030,  0.1474,  0.1470, -0.3941,  0.0382, -0.1341, -0.0460,
          0.0069,  0.1797, -0.1230,  0.2369,  0.1034,  0.0167, -0.2657,  0.2654,
         -0.2949,  0.0114, -0.1491,  0.2392, -0.0574,  0.1602,  0.2748, -0.0705,
         -0.1063,  0.0791,  0.0172,  0.1818,  0.1234, -0.1206, -0.2765, -0.1249,
          0.0361,  0.1667, -0.2323, -0.2915,  0.1291, -0.1358, -0.3952, -0.0519,
         -0.0555,  0.1175,  

torch.Size([1, 23, 128])

In [22]:
from transformers import PreTrainedTokenizerFast

fast_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="data/drug/tokenizer_model/vocab.json",
    pad_token="[PAD]",
    mask_token="[MASK]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    unk_token="[UNK]"
)

vocab_size = len(fast_tokenizer.get_vocab().keys())

print(f"load tokenizer\nvocab size: {vocab_size}\nspecial tokens: {fast_tokenizer.all_special_tokens}")

import torchmetrics
import pytorch_lightning as pl
from transformers import BertConfig, BertForMaskedLM
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

config = BertConfig(
    vocab_size=vocab_size,
    hidden_size=128,
    num_hidden_layers=8,
    num_attention_heads=8,
    intermediate_size=512,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=128,
    type_vocab_size=1,
    pad_token_id=0,
    position_embedding_type="absolute"
)

class Bert(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.model = BertForMaskedLM(config)
        
        self.train_accuracy = torchmetrics.Accuracy()
        self.valid_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()
        
        
    def forward(self, input_ids, labels):
        return self.model(input_ids=input_ids, labels=labels)

       
    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        labels = batch['labels']
        
        output = self(input_ids, labels)

        loss = output.loss
        logits = output.logits

        preds = logits.argmax(dim=-1)
        
        self.log('train_loss', float(loss), on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_accuracy", self.train_accuracy(preds[labels > 0], labels[labels > 0]), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        
        torch.cuda.empty_cache()
        
        return loss

    
    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        labels = batch['labels']
        
        output = self(input_ids, labels)

        loss = output.loss
        logits = output.logits

        preds = logits.argmax(dim=-1)
        
        self.log('valid_loss', float(loss), on_step=False, on_epoch=True, prog_bar=True)
        self.log("valid_accuracy", self.valid_accuracy(preds[labels > 0], labels[labels > 0]), on_step=False, on_epoch=True, prog_bar=True, logger=True)
    
        torch.cuda.empty_cache()
    
    
    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        labels = batch['labels']
        
        output = self(input_ids, labels)

        loss = output.loss
        logits = output.logits

        preds = logits.argmax(dim=-1)
        
        self.log('test_loss', float(loss), on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_accuracy", self.test_accuracy(preds[labels > 0], labels[labels > 0]), on_step=False, on_epoch=True, prog_bar=True, logger=True)
    
        torch.cuda.empty_cache()
    
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
    
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    

model = Bert(config).load_from_checkpoint("weights/molecule_bert/exp7_pretraining_done/molecule_bert-epoch=896-valid_loss=0.0995.ckpt")
model

load tokenizer
vocab size: 69
special tokens: ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']


Bert(
  (model): BertForMaskedLM(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(69, 128, padding_idx=0)
        (position_embeddings): Embedding(128, 128)
        (token_type_embeddings): Embedding(1, 128)
        (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=128, out_features=128, bias=True)
                (key): Linear(in_features=128, out_features=128, bias=True)
                (value): Linear(in_features=128, out_features=128, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=128, out_features=128, bias=True)
                (LayerN