In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before executing each cell so edits to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import abc
import pickle
import math

import wandb

import torch

from tqdm import tqdm
from functools import partial
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW, SGD
from collections import namedtuple

import proteinbert_gen.constants as consts
import proteinbert_gen.mask_diffusion as mask_diffusion

from proteinbert_gen.debugging import print2
from proteinbert_gen.proteinbert import ProteinBERT, load_pretrained_weights
from proteinbert_gen.word_freq import create_word_freq_tensor
from proteinbert_gen.tokenizer import ProteinTokenizer
from proteinbert_gen.dataset import sprot_train

In [3]:
# Immutable record of the run configuration. It is passed to wandb as the run
# config and read by the data loader, optimizer, diffusion setup and training
# loop further down.
Hyperparameters = namedtuple(
    "Hyperparameters",
    [
        "batch_size",        # DataLoader batch size; also the divisor applied to the loss
        "epochs",            # number of passes over train_loader
        "num_steps",         # number of steps in the discrete diffusion schedule
        "word_freq_lambda",  # forwarded to MaskDiffusion
        "device",            # torch device string for the model and batch tensors
        "hybrid_lambda",     # forwarded to compute_kl_reverse_process
        "lr",                # optimizer learning rate
        "logging_steps",     # only referenced by the (currently disabled) loss printout
        "eval_step_size"     # not read anywhere in this notebook
    ]
)

args = Hyperparameters(
    batch_size=64,
    epochs=100,
    num_steps=2048,
    word_freq_lambda=0.3,
    # Use the GPU when one is visible, otherwise fall back to the CPU so the
    # notebook still runs (slowly) instead of failing on `.to("cuda")`.
    device="cuda" if torch.cuda.is_available() else "cpu",
    hybrid_lambda=1e-2,
    lr=5e-4,
    logging_steps=10,
    eval_step_size=4
)

# Start a wandb run for this training session and record every hyperparameter
# as the run config. The commented-out `mode` argument is kept as a switch for
# running without wandb logging.
wandb.init(
    project="proteinbert_gen",
    config=args._asdict(),
    # mode="disabled"
)

[34m[1mwandb[0m: Currently logged in as: [33mmattfeng[0m ([33mkaiogenbio[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
class SampleClassBase(abc.ABC):
    """Interface for strategies that turn model logits into sampled tokens.

    Subclasses must implement :meth:`sample`. Because the method is declared
    abstract, a subclass that forgets to override it fails at instantiation
    time rather than at the first sampling call.
    """

    @abc.abstractmethod
    def sample(self, logits, x_0):
        """Draw a sample from ``logits``.

        Args:
            logits: Unnormalised scores to sample from.
            x_0: Extra context made available to the sampler (presumably the
                original, un-noised tokens -- TODO confirm against
                ``mask_diffusion``).

        Raises:
            NotImplementedError: Always, when reached through ``super()``.
        """
        raise NotImplementedError

    def post_process_sample_in_prediction(self, sample, x_0):
        """Hook applied to a drawn sample; the default returns it unchanged."""
        return sample


class Categorical(SampleClassBase):
    """Sampler that draws from the categorical distribution given by the logits."""

    def sample(self, logits, x_0):
        """Return one draw from ``Categorical(logits=logits)``; ``x_0`` is ignored."""
        distribution = torch.distributions.categorical.Categorical(logits=logits)
        return distribution.sample()

In [5]:
def word_freq_preprocess_fn(wf):
    """Map raw token counts onto a log scale normalised by the largest entry.

    Computes ``log(wf + 1)`` and divides by its maximum, so for non-negative
    counts the result lies in ``[0, 1]`` with the most frequent token at 1.

    Args:
        wf: Tensor of raw token frequencies. It is not modified in place.

    Returns:
        A new tensor of the same shape holding the normalised log-frequencies.
        If every count is zero the result is all zeros (rather than NaN from
        a 0/0 division).
    """
    log_counts = (wf + 1).log()
    peak = log_counts.max()

    # log(0 + 1) == 0 for every entry when all counts are zero; dividing by
    # that peak would produce NaN everywhere, so return the zeros as-is.
    if peak == 0:
        return log_counts

    # range: 0 - 1
    return log_counts / peak

def process_fn_in_collate(wf):
    """Centre ``wf`` by subtracting the mean of all its elements."""
    centre = wf.mean()
    return wf - centre


# Build the per-token frequency table used to weight the diffusion process.
tokenizer = ProteinTokenizer()
# Raw counts loaded from the pickled frequency dict, one entry per token in
# tokenizer.ALL_TOKENS.
wf_tensor = create_word_freq_tensor("../data/sprot_1m_word_freq_dict.pkl", tokenizer.ALL_TOKENS)
# Zero the mask token's raw count before normalising, so that its
# log(count + 1) -- and therefore its normalised frequency -- is exactly 0.
wf_tensor[tokenizer.mask_token_id] = 0
wf_tensor = word_freq_preprocess_fn(wf_tensor)
# Display the normalised table as the cell output.
wf_tensor

tensor([0.9930, 0.8760, 0.9626, 0.9751, 0.9424, 0.9824, 0.9082, 0.9710, 0.9673,
        1.0000, 0.9151, 0.9418, 0.9518, 0.9393, 0.9652, 0.9708, 0.9611, 0.3411,
        0.9802, 0.8615, 0.0000, 0.9240, 0.0000, 0.0000, 0.0000, 0.0000])

In [6]:
def collate(batch_input, *, tokenizer, word_freq: torch.Tensor):
    """Turn a list of ``{"seq": ...}`` records into padded batch tensors.

    Args:
        batch_input: Iterable of mappings, each carrying a sequence under
            the ``"seq"`` key.
        tokenizer: Object providing ``tokenize(seq)`` and ``pad_token_id``.
        word_freq: 1-D tensor indexed by token id.

    Returns:
        Dict with three batch-first tensors padded to the longest sequence:
        ``"input_ids"`` (padded with ``tokenizer.pad_token_id``),
        ``"attention_mask"`` (1 on real tokens, 0 on padding) and
        ``"word_freq_logits"`` (per-sequence mean-centred frequencies,
        0 on padding).
    """
    token_ids = [
        torch.tensor(tokenizer.tokenize(example["seq"]))
        for example in batch_input
    ]
    masks = [torch.ones_like(ids) for ids in token_ids]
    # Look up each token's frequency, then centre it within its own sequence.
    freq_logits = [
        process_fn_in_collate(word_freq.gather(0, ids))
        for ids in token_ids
    ]

    return {
        "input_ids": pad_sequence(
            token_ids, batch_first=True, padding_value=tokenizer.pad_token_id
        ),
        "attention_mask": pad_sequence(masks, batch_first=True),
        "word_freq_logits": pad_sequence(freq_logits, batch_first=True),
    }

# Bind the tokenizer and the normalised frequency table so the DataLoader can
# call the collate function with the batch as its only argument.
collate_fn = partial(collate, tokenizer=tokenizer, word_freq=wf_tensor)

In [7]:
# Batches of tokenised, padded sequences from the training split.
# NOTE(review): no `shuffle` argument is passed, so DataLoader's default
# ordering applies -- confirm that unshuffled training batches are intended.
train_loader = torch.utils.data.DataLoader(
    sprot_train,
    batch_size=args.batch_size,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

In [8]:
# Pull a single batch to eyeball the collated tensors (ids, mask, centred
# frequencies) before training.
next(iter(train_loader))

{'input_ids': tensor([[23, 10,  0,  ..., 25, 25, 25],
         [23, 10, 16,  ..., 25, 25, 25],
         [23, 10,  8,  ..., 25, 25, 25],
         ...,
         [23, 10,  0,  ..., 25, 25, 25],
         [23, 10,  3,  ..., 25, 25, 25],
         [23, 10, 15,  ..., 25, 25, 25]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'word_freq_logits': tensor([[-0.9607, -0.0456,  0.0322,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9570, -0.0418,  0.0042,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9610, -0.0459,  0.0063,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.9601, -0.0449,  0.0329,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9164, -0.0013,  0.0587,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9635, -0.0484,  0.0073,  ...,  0.0000,  0.0000,  0.0000]])}

In [9]:
def denoise(targets, timestep, attention_mask, *, model):
    """Run ``model`` on ``targets`` and return its output unchanged.

    ``timestep`` is accepted so the function matches the signature the
    diffusion code calls, but it is not forwarded to the model.
    """
    return model(targets, attention_mask=attention_mask)

# Load the pretrained checkpoint. The pickle unpacks to a 3-tuple of which only
# the middle element (the model weights) is used here.
# SECURITY: pickle.load executes arbitrary code embedded in the file -- only
# load checkpoints from a trusted source.
with open("../weights/epoch_92400_sample_23500000.pkl", "rb") as f:
    _, pretrained_model_weights, _ = pickle.load(f)

model = ProteinBERT(tokenizer.vocab_size, consts.GO_ANN_SIZE)
print(model)

# load_pretrained_weights copies the checkpoint into `model`; its return value
# is what gets handed to the optimizer and to gradient clipping later on.
trainable_params = load_pretrained_weights(model, pretrained_model_weights)
model = model.to(args.device)
# Bind the model so the diffusion code can call denoise_fn(targets, timestep, attention_mask).
denoise_fn = partial(denoise, model=model)

ProteinBERT(
  (embed_local): Embedding(26, 128)
  (embed_global): Sequential(
    (0): Linear(in_features=8943, out_features=512, bias=True)
    (1): GELU(approximate='none')
  )
  (blocks): ModuleList(
    (0-5): 6 x TransformerLikeBlock(
      (wide_and_narrow_conv1d): ConvBlock(
        (conv_narrow): Sequential(
          (0): Rearrange('b l d -> b d l')
          (1): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=same)
          (2): GELU(approximate='none')
          (3): Rearrange('b d l -> b l d')
        )
        (conv_wide): Sequential(
          (0): Rearrange('b l d -> b d l')
          (1): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=same, dilation=(5,))
          (2): GELU(approximate='none')
          (3): Rearrange('b d l -> b l d')
        )
      )
      (dense_and_broadcast): Sequential(
        (0): Linear(in_features=512, out_features=128, bias=True)
        (1): GELU(approximate='none')
        (2): Rearrange('b d -> b () d')
      )
      

In [10]:
# Plain SGD over the parameters returned by load_pretrained_weights.
optimizer = SGD(trainable_params, lr=args.lr)
# Disabled learning-rate schedule, kept for reference: the multiplier ramps
# linearly (n / 10000 + 1e-3) for the first 10000 steps, then decays as
# 100 / sqrt(n).
# warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer,
#     lr_lambda=lambda n: n / 10000. + 1e-3 if n < 10000 else 100. / math.sqrt(n)
# )

  return torch._dynamo.disable(fn, recursive)(*args, **kwargs)


In [11]:
# Sampler used by the diffusion process to draw tokens from logits.
sample_cls = Categorical()

# Discrete diffusion schedule with args.num_steps steps, and the mask-diffusion
# process built on top of it over the tokenizer's vocabulary.
diffusion_schedule = mask_diffusion.create_discrete_diffusion_schedule(num_steps=args.num_steps)
diffusion_instance = mask_diffusion.MaskDiffusion(
    dim=tokenizer.vocab_size,
    schedule=diffusion_schedule,
    tokenizer=tokenizer,
    sample_cls=sample_cls,
    word_freq_lambda=args.word_freq_lambda,
    device=args.device
)

using standard schedule with num_steps: 2048.


In [None]:
# Training loop: for every batch, sample a diffusion timestep, compute the
# diffusion loss, and take one optimizer step with clipped gradients.

# Running sum of per-batch loss values. The periodic printout that would
# report and reset it is commented out at the bottom of the loop, so at the
# moment it only accumulates.
train_loss = 0.
# Number of batches skipped because their loss was NaN.
nan_count = 0

# Debugging aids for tracking down NaN gradients (disabled).
# torch.autograd.set_detect_anomaly(True)

# def _save_output(module, grad_input, grad_output):
#     print(module, grad_output)
#     print(torch.isnan(grad_output[0]).any())
#     print()

# for name, module in model.named_modules():
#     if str(type(module)).find("LayerNorm") != -1:
#         print(name)
#         module.register_full_backward_hook(_save_output)

for epoch in range(args.epochs):
    for i, batch in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        # Timestep(s) at which this batch is noised -- presumably sampled at
        # random by the diffusion object; see MaskDiffusion.sample_t.
        diffusion_t = diffusion_instance.sample_t()
        # print(diffusion_t)

        # Returns a dict of metrics; only its "loss" entry is used for the
        # backward pass, the whole dict is sent to wandb.
        metrics = mask_diffusion.compute_kl_reverse_process(
            batch["input_ids"].to(args.device),
            diffusion_t,
            denoise_fn=denoise_fn,
            diffusion=diffusion_instance,
            target_mask=batch["attention_mask"].to(args.device),
            hybrid_lambda=args.hybrid_lambda,
            predict_x0=True,
            word_freq_logits=batch["word_freq_logits"].to(args.device),
            device=args.device
        )

        # print(metrics)

        # NOTE(review): the divisor is the configured batch size, not the size
        # of this batch, so a smaller final batch is scaled down further.
        loss = metrics["loss"] / args.batch_size

        # Logged before the NaN check, so NaN batches show up in wandb too.
        wandb.log(metrics)

        # A NaN loss would poison the weights: count it and skip this batch
        # without running backward or stepping the optimizer.
        if loss.isnan():
            nan_count += 1
            continue
            
        train_loss += loss.item()
        loss.backward()
        # Clamp every gradient element to [-1, 1].
        torch.nn.utils.clip_grad_value_(trainable_params, 1)
        # Clamping leaves NaN entries in place, so replace any that remain
        # with 0 before the update.
        for param in trainable_params:
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    print2("setting nan gradients to 0", tags=["info"])
                    param.grad = torch.nan_to_num(param.grad, nan=0.0)
        
        optimizer.step()
        # warmup_scheduler.step()
        # print(warmup_scheduler.get_last_lr())

        # Disabled periodic loss report (would also reset train_loss).
        # if i % args.logging_steps == args.logging_steps - 1:
        #     print(f"Loss at step {i} is {train_loss / args.logging_steps}")
        #     train_loss = 0.


 64%|████████████████████████████████████████████████████▍                             | 3251/5082 [03:35<02:00, 15.24it/s]

In [None]:
import random

def gen_random_seq(length, mask_pct=0.75):
    """Build a random amino-acid string in which some positions are "X".

    Args:
        length: Number of residues in the returned string.
        mask_pct: Probability, applied independently to each position, of
            replacing the residue with "X" (presumably the symbol the
            tokenizer treats as masked/unknown -- TODO confirm).

    Returns:
        A string of exactly ``length`` characters drawn from
        ``tokenizer.ALL_AMINO_ACIDS`` and "X".
    """
    # Size the sequence from `length`; a fixed size here would ignore the
    # argument and make the masking loop index out of range for longer requests.
    rand_seq = [random.choice(tokenizer.ALL_AMINO_ACIDS) for _ in range(length)]
    for i in range(length):
        if random.random() < mask_pct:
            rand_seq[i] = "X"
    return "".join(rand_seq)

# 32 partially masked random sequences of 80 residues, batched with the same
# collate function the training loader uses.
rand_seqs = [gen_random_seq(80) for _ in range(32)]
noise = collate_fn([{"seq": seq} for seq in rand_seqs])

In [None]:
# Show the masked input sequences so they can be compared with the model output.
print(rand_seqs)

In [None]:
# Inference only: no_grad() keeps autograd from recording this forward pass,
# so no graph or activations are retained for the generated batch.
# NOTE(review): the model's train/eval mode is left untouched here; if the
# network contains dropout, consider switching to eval mode for generation.
with torch.no_grad():
    generated = model(
        noise["input_ids"].to(args.device),
        attention_mask=noise["attention_mask"].to(args.device),
    )

In [None]:
# Greedy decode: take the index of the highest score along the last axis
# (presumably the vocabulary axis) and convert to nested Python lists.
generated_list = generated.argmax(dim=-1).tolist()

In [None]:
# NOTE(review): this rebinds the global `tokenizer` created earlier in the
# notebook to a fresh instance; the earlier instance could be reused instead.
tokenizer = ProteinTokenizer()

# Convert each row of predicted token ids back to a sequence string and print it.
for seq in generated_list:
    print(tokenizer.untokenize(seq))