In [13]:
import abc

import torch

from functools import partial

from proteinbert_gen.proteinbert import ProteinBERT
from proteinbert_gen.tokenizer import ProteinTokenizer
from proteinbert_gen.constants import GO_ANN_SIZE

from proteinbert_gen.word_freq import create_word_freq_tensor
import proteinbert_gen.mask_diffusion as mask_diffusion

In [14]:
# Build the tokenizer and the ProteinBERT network, then restore trained weights.
CHECKPOINT_PATH = "../checkpoints/northern-wind-52-postepoch-8.pt"

tokenizer = ProteinTokenizer()

model = ProteinBERT(tokenizer.vocab_size, GO_ANN_SIZE)

# map_location="cpu" remaps every stored tensor onto the CPU, so a checkpoint
# saved from a CUDA device still loads on a machine without a GPU (the
# diffusion setup in this notebook runs on device="cpu").
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch has it).
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(state_dict)

# Switch to inference mode; as the last expression this also displays the
# module tree as the cell output.
model.eval()

ProteinBERT(
  (embed_local): Embedding(26, 128)
  (embed_global): Sequential(
    (0): Linear(in_features=8943, out_features=512, bias=True)
    (1): GELU(approximate='none')
  )
  (blocks): ModuleList(
    (0-5): 6 x TransformerLikeBlock(
      (wide_and_narrow_conv1d): ConvBlock(
        (conv_narrow): Sequential(
          (0): Rearrange('b l d -> b d l')
          (1): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=same)
          (2): GELU(approximate='none')
          (3): Rearrange('b d l -> b l d')
        )
        (conv_wide): Sequential(
          (0): Rearrange('b l d -> b d l')
          (1): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=same, dilation=(5,))
          (2): GELU(approximate='none')
          (3): Rearrange('b d l -> b l d')
        )
      )
      (dense_and_broadcast): Sequential(
        (0): Linear(in_features=512, out_features=128, bias=True)
        (1): GELU(approximate='none')
        (2): Rearrange('b d -> b () d')
      )
      

In [25]:
def denoise(targets, timestep, attention_mask, *, model):
    """Run ``model`` on ``targets`` and return whatever it produces.

    Args:
        targets: Input passed to ``model`` as its only positional argument.
        timestep: Accepted for signature compatibility; not used here.
        attention_mask: Accepted for signature compatibility; not used here.
        model: Keyword-only callable that is applied to ``targets``.

    Returns:
        The value of ``model(targets)``, unmodified.
    """
    return model(targets)


# Bind the loaded network as the keyword-only `model` argument of `denoise`,
# leaving a callable that takes (targets, timestep, attention_mask).
denoise_fn = partial(denoise, model=model)

# Number of steps in the discrete diffusion schedule.
NUM_DIFFUSION_STEPS = 4096

diffusion_schedule = mask_diffusion.create_discrete_diffusion_schedule(
    num_steps=NUM_DIFFUSION_STEPS,
)
diffusion_instance = mask_diffusion.MaskDiffusion(
    dim=tokenizer.vocab_size,
    schedule=diffusion_schedule,
    tokenizer=tokenizer,
    device="cpu",
)

using standard schedule with num_steps: 4096.


In [33]:
# Seed torch's global RNG so the sampled sequence is reproducible across
# Restart & Run All.
# NOTE(review): assumes the sampler draws from torch's global generator; confirm.
SEED = 0
torch.manual_seed(SEED)

# Sampling is inference only: disable autograd so no graph is recorded across
# the diffusion steps (model.eval() alone does not turn gradient tracking off).
with torch.no_grad():
    generated = mask_diffusion.discrete_diffusion_predict_fn(
        (1, 300),  # presumably (batch size, sequence length); verify against the sampler
        denoise_fn,
        diffusion_instance,
        temperature=2.0,
        topk=8,
        topp=1.0,
    )

# Decode each generated row of token ids back into a protein string.
for token_ids in generated["final_state"].tolist():
    print(tokenizer.untokenize(token_ids))

P^SFHVFYAVEMHDRMHAGRGFAQEWQMNLNYCYMLVFDHATKFFNSRAMNSCEAMFAVITTGFCQSEQNCGYRVLLYIAFKNETVGERGHCRIDGRDQLIEMYKVPGCFGEAINPDIMIVRERIPHEFVEEEERDDVMRVNGMEDKNLIQEQADRKRICFEIMEDIQDEQIQEMMAIEGNDNRPEKCQDAFDIFRDPRKDPRCKKIRQLQNVPEIKVNFEAAIIYGFNVVNINFEYFRWQYVMNKIIVEITNQAEEENQNQDEIKEVVRQHPGNLPNYVQWKVYKGCVKYFNGEGVFDG
