In [None]:
import os
import json
import numpy as np
import torch
from model import create_dta_toy_model

In [None]:
# Load the FASTA (protein) vocabulary: a JSON object that the encoder below
# queries with `.get(token)` to obtain a one-hot column index.
# JSON text is UTF-8 by specification, so the encoding is pinned explicitly
# rather than inherited from the platform's locale default.
fasta_vocab_path = "fasta_vocab.json"
with open(fasta_vocab_path, 'r', encoding='utf-8') as f:
    fasta_vocab = json.load(f)

# The number of vocabulary entries is the width of every FASTA one-hot row
# and is also handed to the model factory.
fasta_vocab_size = len(fasta_vocab)
print("Vocab size:", fasta_vocab_size)

Vocab size: 25


In [None]:
# Load the SMILES (drug) vocabulary: a JSON object that the encoder below
# queries with `.get(character)` to obtain a one-hot column index.
# JSON text is UTF-8 by specification, so the encoding is pinned explicitly
# rather than inherited from the platform's locale default.
smiles_vocab_path = "vocab_chars.json"
with open(smiles_vocab_path, 'r', encoding='utf-8') as f:
    smiles_vocab = json.load(f)

# The number of vocabulary entries is the width of every SMILES one-hot row
# and is also handed to the model factory.
smiles_vocab_size = len(smiles_vocab)
print("Vocab size:", smiles_vocab_size)

Vocab size: 101


In [None]:
def smiles_to_onehot(smiles: str, vocab: dict, max_len: int):
    """One-hot encode a SMILES string, character by character, to a fixed length.

    Args:
        smiles: SMILES string; every single character is looked up on its own.
        vocab: Mapping from character/token to column index. ``'<pad>'`` and
            ``'<unk>'`` are used when present; otherwise columns 0 and 3 are
            used as the padding and unknown columns respectively.
        max_len: Number of rows in the result. Longer inputs are truncated,
            shorter inputs are padded.

    Returns:
        ``np.float32`` array of shape ``(max_len, len(vocab))`` holding a
        single 1.0 in every row.
    """
    padding_col = vocab.get('<pad>', 0)
    unknown_col = vocab.get('<unk>', 3)
    encoded = np.zeros((max_len, len(vocab)), dtype=np.float32)

    # Characters beyond max_len are dropped; unseen characters share one column.
    kept = smiles[:max_len]
    columns = [vocab.get(symbol, unknown_col) for symbol in kept]
    for row, col in enumerate(columns):
        encoded[row, col] = 1.0

    # Mark every remaining row as padding (skipped when nothing is left to pad).
    if len(smiles) < max_len:
        encoded[len(smiles):, padding_col] = 1.0
    return encoded

In [None]:
def fasta_to_onehot(fasta: str, vocab: dict, max_len: int):
    """One-hot encode a protein sequence, residue by residue, to a fixed length.

    Each character is upper-cased before lookup. Characters absent from
    ``vocab`` map to the ``'<unk>'`` column, and rows past the end of the
    sequence are filled with the ``'<pad>'`` column.

    Args:
        fasta: Amino-acid sequence, one character per residue.
        vocab: Mapping from token to column index.
        max_len: Number of rows in the result. Longer inputs are truncated,
            shorter inputs are padded.

    Returns:
        ``np.float32`` array of shape ``(max_len, len(vocab))`` holding a
        single 1.0 in every row.

    Raises:
        KeyError: If an unknown residue is met and ``vocab`` has no
            ``'<unk>'`` entry, or padding is required and ``vocab`` has no
            ``'<pad>'`` entry.
    """
    vocab_size = len(vocab)
    one_hot = np.zeros((max_len, vocab_size), dtype=np.float32)
    pad_id = vocab.get('<pad>')
    unk_id = vocab.get('<unk>')

    for i, ch in enumerate(fasta[:max_len]):
        idx = vocab.get(ch.upper(), unk_id)
        if idx is None:
            # Indexing with None is numpy's newaxis: `one_hot[i, None] = 1.0`
            # would silently set the whole row to ones, so fail loudly instead.
            raise KeyError(
                f"residue {ch!r} is not in the vocabulary and no '<unk>' token is defined"
            )
        one_hot[i, idx] = 1.0

    # Only demand a '<pad>' entry when there is actually something to pad.
    if len(fasta) < max_len:
        if pad_id is None:
            raise KeyError("padding is required but no '<pad>' token is defined")
        one_hot[len(fasta):, pad_id] = 1.0
    return one_hot

In [None]:
# Build the toy drug-target affinity model; each encoder's input width is the
# size of the corresponding one-hot vocabulary.
model = create_dta_toy_model(
    fasta_vocab_size=fasta_vocab_size,
    smiles_vocab_size=smiles_vocab_size,
)

# Switch to evaluation mode. Kept as the cell's final expression so the
# returned module's repr is rendered as the cell output.
model.eval()

DrugTargetAffinityModel(
  (smiles_encoder): SequenceTransformerEncoder(
    (embedding): Linear(in_features=101, out_features=128, bias=True)
    (pos_encoder): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=256, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (fasta_enc

In [None]:
# Fixed row counts used when encoding each modality.
smiles_seq_len = 256
fasta_seq_len = 1000

# Destination files for the raw batch dumps.
smiles_bin_path = "smiles_batch_input.bin"
fasta_bin_path = "fasta_batch_input.bin"

# Example pair 1 (drug SMILES + target protein sequence).
smiles_example_1 = "CC[NH+](CC)[C@](C)(CC)[C@H](O)c1cscc1Br"
fasta_example_1 = "MEECWVTEIANGSKDGLDSNPMKDYMILSGPQKTAVAVLCTLLGLLSALENVAVLYLILSSHQLRRKPSYLFIGSLAGADFLASVVFACSFVNFHVFHGVDSKAVFLLKIGSVTMTFTASVGSLLLTAIDRYLCLRYPPSYKALLTRGRALVTLGIMWVLSALVSYLPLMGWTCCPRPCSELFPLIPNDYLLSWLLFIAFLFSGIIYTYGHVLWKAHQHVASLSGHQDRQVPGMARMRLDVRLAKTLGLVLAVLLICWFPVLALMAHSLATTLSDQVKKAFAFCSMLCLINSMVNPVIYALRSGEIRSSAHHCLAHWKKCVRGLGSEAKEEAPRSSVTETEADGKITPWPDSRDLDLSDC"

# Example pair 2.
smiles_example_2 = "CN1CCN(CC(=O)N(C)c2ccc(NC(=C3C(=O)Nc4ccccc43)c3ccccc3)cc2)CC1"
fasta_example_2 = "MGHALCVCSRGTVIIDNKRYLFIQKLGEGGFSYVDLVEGLHDGHFYALKRILCHEQQDREEAQREADMHRLFNHPNILRLVAYCLRERGAKHEAWLLLPFFKRGTLWNEIERLKDKGNFLTEDQILWLLLGICRGLEAIHAKGYAHRDLKPTNILLGDEGQPVLMDLGSMNQACIHVEGSRQALTLQDWAAQRCTISYRAPELFSVQSHCVIGERTDVWSLGCVLYAMMFGEGPYDMVFQKGDSVALAVQNQLSIPQSPRHSSALRQLLNSMMTVDPHQRPHIPLLLSQLEALQPPAPGQHTTQIEKAAC"

# Encode both drugs, then both targets, each to its fixed length.
smiles_onehot_1, smiles_onehot_2 = (
    smiles_to_onehot(drug, smiles_vocab, max_len=smiles_seq_len)
    for drug in (smiles_example_1, smiles_example_2)
)
fasta_onehot_1, fasta_onehot_2 = (
    fasta_to_onehot(target, fasta_vocab, max_len=fasta_seq_len)
    for target in (fasta_example_1, fasta_example_2)
)

# Stack along a new leading axis so each array is (batch, seq_len, vocab_size).
smiles_batch_np = np.stack((smiles_onehot_1, smiles_onehot_2), axis=0).astype(np.float32)
fasta_batch_np = np.stack((fasta_onehot_1, fasta_onehot_2), axis=0).astype(np.float32)

# ndarray.tofile writes the raw element bytes only (no shape/dtype header), so
# the shapes printed below are what a consumer needs to reinterpret the files.
smiles_batch_np.tofile(smiles_bin_path)
fasta_batch_np.tofile(fasta_bin_path)

print(f"Saved SMILES batch input: {smiles_bin_path} (Shape: {smiles_batch_np.shape})")
print(f"Saved FASTA batch input:  {fasta_bin_path} (Shape: {fasta_batch_np.shape})")

Saved SMILES batch input: smiles_batch_input.bin (Shape: (2, 256, 101))
Saved FASTA batch input:  fasta_batch_input.bin (Shape: (2, 1000, 25))
