In [1]:
%load_ext autoreload
%autoreload 2

In [48]:
import abc
import pickle

import torch

from tqdm import tqdm
from functools import partial
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from collections import namedtuple

import proteinbert_gen.constants as consts
import proteinbert_gen.diffusion as diffusion

from proteinbert_gen.proteinbert import ProteinBERT, load_pretrained_weights
from proteinbert_gen.word_freq import create_word_freq_tensor
from proteinbert_gen.tokenizer import ProteinTokenizer
from proteinbert_gen.dataset import sprot_train

In [52]:
Hyperparameters = namedtuple(
    "Hyperparameters",
    [
        "batch_size",
        "epochs",
        "num_steps",
        "word_freq_lambda",
        "device",
        "hybrid_lambda",
        "lr",
        "logging_steps",
        "eval_step_size"
    ]
)

args = Hyperparameters(
    batch_size=32,
    epochs=1,
    num_steps=2048,
    word_freq_lambda=0.3,
    device="cpu",
    hybrid_lambda=1e-2,
    lr=5e-4,
    logging_steps=1000,
    eval_step_size=4
)

In [4]:
class SampleClassBase(abc.ABC):
    def sample(self, logits, x_0):
        raise NotImplementedError

    def post_process_sample_in_prediction(self, sample, x_0):
        return sample


class Categorical(SampleClassBase):
    def sample(self, logits, x_0):
        return torch.distributions.categorical.Categorical(logits=logits).sample()

In [5]:
def word_freq_preprocess_fn(wf):
    wf = wf + 1
    wf = wf.log()
    wf = wf / wf.max()

    # range: 0 - 1
    return wf

def process_fn_in_collate(wf):
    return wf - wf.mean()


tokenizer = ProteinTokenizer()
wf_tensor = create_word_freq_tensor("../data/sprot_1m_word_freq_dict.pkl", tokenizer.ALL_TOKENS)
wf_tensor = word_freq_preprocess_fn(wf_tensor)
wf_tensor

tensor([0.9930, 0.8760, 0.9626, 0.9751, 0.9424, 0.9824, 0.9082, 0.9710, 0.9673,
        1.0000, 0.9151, 0.9418, 0.9518, 0.9393, 0.9652, 0.9708, 0.9611, 0.3411,
        0.9802, 0.8615, 0.5221, 0.9240, 0.0000, 0.0000, 0.0000, 0.0000])

In [6]:
def collate(batch_input, *, tokenizer, word_freq: torch.Tensor):
    input_ids = []
    attention_mask = []
    word_freq_logits = []
    
    for item in batch_input:
        seq = item["seq"]
        ids = torch.tensor(tokenizer.tokenize(seq))
        mask = torch.ones_like(ids)
        logits = process_fn_in_collate(
            word_freq.gather(0, ids)
        )
        
        input_ids.append(ids)
        attention_mask.append(mask)
        word_freq_logits.append(logits)

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_mask = pad_sequence(attention_mask, batch_first=True)
    word_freq_logits = pad_sequence(word_freq_logits, batch_first=True)
    
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "word_freq_logits": word_freq_logits
    }

collate_fn = partial(collate, tokenizer=tokenizer, word_freq=wf_tensor)

In [7]:
train_loader = torch.utils.data.DataLoader(
    sprot_train,
    batch_size=args.batch_size,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

In [8]:
next(iter(train_loader))

{'input_ids': tensor([[23, 10,  0,  ..., 25, 25, 25],
         [23, 10, 16,  ..., 25, 25, 25],
         [23, 10,  8,  ..., 25, 25, 25],
         ...,
         [23, 10,  0,  ..., 25, 25, 25],
         [23, 10,  0,  ..., 25, 25, 25],
         [23, 10,  7,  ..., 25, 25, 25]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'word_freq_logits': tensor([[-0.9607, -0.0456,  0.0322,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9570, -0.0418,  0.0042,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9610, -0.0459,  0.0063,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.9529, -0.0378,  0.0401,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9670, -0.0519,  0.0259,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9460, -0.0309,  0.0250,  ...,  0.0000,  0.0000,  0.0000]])}

In [53]:
def denoise(targets, timestep, attention_mask, *, model):
    ret = model(targets, attention_mask=attention_mask)
    print("denoise output:", ret.shape)
    return ret

with open("../weights/epoch_92400_sample_23500000.pkl", "rb") as f:
    _, pretrained_model_weights, _ = pickle.load(f)

model = ProteinBERT(tokenizer.vocab_size, consts.GO_ANN_SIZE)
load_pretrained_weights(model, pretrained_model_weights)
denoise_fn = partial(denoise, model=model)

In [54]:
optimizer = AdamW(model.parameters(), lr=args.lr)
warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda n: n / 10000. + 1e-3 if n < 10000 else 100. / math.sqrt(n)
)

In [55]:
sample_cls = Categorical()

diffusion_schedule = diffusion.create_discrete_diffusion_schedule(num_steps=args.num_steps)
diffusion_instance = diffusion.MaskDiffusion(
    dim=tokenizer.vocab_size,
    schedule=diffusion_schedule,
    tokenizer=tokenizer,
    sample_cls=sample_cls,
    word_freq_lambda=args.word_freq_lambda,
    device=args.device
)

using standard schedule with num_steps: 2048.


In [None]:
train_loss = 0.0

for epoch in range(args.epochs):
    for i, batch in enumerate(tqdm(train_loader)):
        print(batch)

        metrics = diffusion.compute_kl_reverse_process(
            batch["input_ids"].to(args.device),
            diffusion_instance.sample_t(),
            denoise_fn=denoise_fn,
            diffusion=diffusion_instance,
            target_mask=batch["attention_mask"].to(args.device),
            hybrid_lambda=args.hybrid_lambda,
            predict_x0=True,
            word_freq_logits=batch["word_freq_logits"].to(args.device)
        )

        loss = metrics["loss"] / args.batch_size
        train_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), 5)
        optimizer.step()
        model.zero_grad()
        optimizer.zero_grad()
        warmup_scheduler.step()

        diffusion.discrete_diffusion_elbo(
            batch["input_ids"].to(args.device),
            denoise_fn=denoise_fn,
            diffusion=diffusion_instance,
            target_mask=batch["attention_mask"].to(args.device),
            normalize_without_padding=True,
            eval_step_size=args.eval_step_size,
            word_freq_logits=batch["word_freq_logits"].to(args.device),
            device=args.device
        )
        
        break

  0%|                                                                                            | 0/10164 [00:00<?, ?it/s]

{'input_ids': tensor([[23, 10,  0,  ..., 25, 25, 25],
        [23, 10, 16,  ..., 25, 25, 25],
        [23, 10,  8,  ..., 25, 25, 25],
        ...,
        [23, 10,  0,  ..., 25, 25, 25],
        [23, 10,  0,  ..., 25, 25, 25],
        [23, 10,  7,  ..., 25, 25, 25]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'word_freq_logits': tensor([[-0.9607, -0.0456,  0.0322,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9570, -0.0418,  0.0042,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9610, -0.0459,  0.0063,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.9529, -0.0378,  0.0401,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9670, -0.0519,  0.0259,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9460, -0.0309,  0.0250,  ...,  0.0000,  0.0000,  0.0000]])}
torch.Size([32, 4, 496]) torch.Size([32

In [44]:
metrics

{'loss': tensor([121.8064], grad_fn=<AddBackward0>),
 'denominator': 1,
 'hybrid_loss': tensor(89.5719, grad_fn=<SumBackward0>),
 'base_loss': tensor([32.2345], grad_fn=<AddBackward0>),
 'cross_entropy_loss': tensor(17546.2441, grad_fn=<SumBackward0>),
 't0_loss': tensor([0.], grad_fn=<MulBackward0>),
 'kl_loss': tensor(32.2345, grad_fn=<SumBackward0>)}