In [2]:
%load_ext autoreload
%autoreload 2

In [20]:
import abc
import torch
from tqdm import tqdm
from functools import partial
from torch.nn.utils.rnn import pad_sequence

import proteinbert_gen.constants as consts

from proteinbert_gen.word_freq import create_word_freq_tensor
from proteinbert_gen.tokenizer import ProteinTokenizer
from proteinbert_gen.dataset import sprot_train

In [16]:
class SampleClassBase(abc.ABC):
    """Interface for strategies that turn model logits into sampled tokens.

    Subclasses are expected to override ``sample``. Note that ``sample`` is not
    decorated with ``abc.abstractmethod``, so this base class can still be
    instantiated; a missing override only surfaces when ``sample`` is called.
    """

    def sample(self, logits, x_0):
        """Draw a sample from ``logits``; subclasses must override this."""
        raise NotImplementedError()

    def post_process_sample_in_prediction(self, sample, x_0):
        """Hook for adjusting a drawn sample at prediction time.

        The default implementation is the identity: ``sample`` is returned
        unchanged and ``x_0`` is ignored.
        """
        return sample


class Categorical(SampleClassBase):
    """Sampler drawing one index per distribution from ``softmax(logits)``.

    Sampling happens over the last dimension of ``logits`` (as defined by
    ``torch.distributions.Categorical``). ``x_0`` is accepted only to match the
    ``SampleClassBase`` interface and is not used.
    """

    def sample(self, logits, x_0):
        distribution = torch.distributions.Categorical(logits=logits)
        return distribution.sample()

In [17]:
def word_freq_preprocess_fn(wf):
    """Squash raw per-token frequencies into [0, 1] with a log transform.

    Computes ``log(wf + 1)`` and divides by its maximum, so the most frequent
    token maps to 1 and a zero count maps to 0.

    Args:
        wf: tensor of per-token frequencies. Assumed non-negative (the [0, 1]
            range only holds for ``wf >= 0``) — TODO confirm against
            ``create_word_freq_tensor``.

    Returns:
        Float tensor with the same shape as ``wf``. If the tensor is empty, or
        every entry is zero (log-max is 0), the log-transformed values are
        returned without normalisation instead of dividing by zero (which would
        produce NaN everywhere).
    """
    wf = (wf + 1).log()
    if wf.numel() == 0:
        # .max() raises on an empty tensor; nothing to normalise.
        return wf
    wf_max = wf.max()
    # Guard against 0/0 -> NaN for an all-zero frequency table.
    if wf_max > 0:
        wf = wf / wf_max

    # range: 0 - 1
    return wf

def process_fn_in_collate(wf):
    """Zero-center a set of word-frequency scores.

    Subtracts the mean over all elements of ``wf``, so the returned values
    average to zero. ``collate`` applies this to the scores gathered for a
    single sequence.
    """
    centre = wf.mean()
    return wf - centre


# Build the tokenizer and a per-token frequency score vector aligned with
# tokenizer.ALL_TOKENS, log-normalised into [0, 1] by word_freq_preprocess_fn.
# NOTE(review): the source file is a .pkl; if create_word_freq_tensor unpickles
# it, only point this at trusted files.
tokenizer = ProteinTokenizer()
wf_tensor = word_freq_preprocess_fn(
    create_word_freq_tensor(
        "../data/sprot_1m_word_freq_dict.pkl",
        tokenizer.ALL_TOKENS,
    )
)
wf_tensor

tensor([0.9930, 0.8760, 0.9626, 0.9751, 0.9424, 0.9824, 0.9082, 0.9710, 0.9673,
        1.0000, 0.9151, 0.9418, 0.9518, 0.9393, 0.9652, 0.9708, 0.9611, 0.3411,
        0.9802, 0.8615, 0.5221, 0.9240, 0.0000, 0.0000, 0.0000, 0.0000])

In [30]:
def collate(batch_input, *, tokenizer, word_freq: torch.Tensor):
    """Tokenize a batch of sequences and right-pad them into dense tensors.

    Args:
        batch_input: iterable of examples, each providing its sequence string
            under the ``"seq"`` key.
        tokenizer: object exposing ``tokenize(seq)`` (assumed to return a flat
            list of integer token ids — TODO confirm) and ``pad_token_id``.
        word_freq: per-token score tensor indexed by token id (presumably the
            1-D output of ``word_freq_preprocess_fn``).

    Returns:
        dict with
          - ``"input_ids"``: token ids, padded with ``tokenizer.pad_token_id``;
          - ``"attention_mask"``: 1 for real tokens, 0 for padding;
          - ``"word_freq_logits"``: frequency scores mean-centered per sequence
            by ``process_fn_in_collate``, padded with 0.
        All three are batch-first (``pad_sequence(..., batch_first=True)``).
    """
    token_ids, masks, freq_logits = [], [], []
    for example in batch_input:
        ids = torch.tensor(tokenizer.tokenize(example["seq"]))
        token_ids.append(ids)
        masks.append(torch.ones_like(ids))
        # Look up each token's frequency score, then center within this sequence.
        freq_logits.append(process_fn_in_collate(word_freq.gather(0, ids)))

    return {
        "input_ids": pad_sequence(
            token_ids, batch_first=True, padding_value=tokenizer.pad_token_id
        ),
        "attention_mask": pad_sequence(masks, batch_first=True),
        "word_freq_logits": pad_sequence(freq_logits, batch_first=True),
    }

# Bind the tokenizer and frequency table up front. functools.partial (rather than
# a lambda) keeps collate_fn picklable, which DataLoader worker processes need
# when the 'spawn' start method is used.
collate_fn = partial(collate, word_freq=wf_tensor, tokenizer=tokenizer)

In [31]:
# NOTE(review): no shuffle=True for a training loader — confirm whether
# sprot_train is already shuffled or is an IterableDataset (where DataLoader
# rejects shuffle=True).
train_loader = torch.utils.data.DataLoader(
    dataset=sprot_train,
    batch_size=consts.BATCH_SIZE,
    collate_fn=collate_fn,
    pin_memory=True,
    num_workers=4,
)

In [32]:
# Sanity check: pull one collated batch from the loader and display it to
# inspect token ids, attention mask and word-frequency scores.
next(iter(train_loader))

{'input_ids': tensor([[10,  0,  0,  ..., 25, 25, 25],
         [10, 16,  6,  ..., 25, 25, 25],
         [10,  8, 12,  ..., 25, 25, 25],
         ...,
         [10,  0,  2,  ..., 25, 25, 25],
         [10,  0, 18,  ..., 25, 25, 25],
         [10,  7, 11,  ..., 25, 25, 25]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'word_freq_logits': tensor([[-0.0508,  0.0270,  0.0270,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0475, -0.0015, -0.0544,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0512,  0.0010, -0.0145,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0524,  0.0254, -0.0049,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0566,  0.0212,  0.0085,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0419,  0.0140, -0.0152,  ...,  0.0000,  0.0000,  0.0000]])}