In [1]:
import sys
sys.path.append("/home/aih/michal.kmicikiewicz/evodiff/evodiff")

import torch
import torch.nn as nn
import torch.nn.functional as F
from pretrained import OA_DM_38M

from tqdm import tqdm
import numpy as np
from Bio import SeqUtils

In [2]:
# Load the pretrained OA_DM_38M bundle; the call returns four objects, unpacked
# here as model, collater, tokenizer and scheme.
model, collater, tokenizer, scheme = OA_DM_38M()
# Place the model with the same rule Sampler uses for its tensors (CUDA when
# available, otherwise CPU) so the notebook also runs on a CPU-only machine.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

### generate

In [216]:
def mw_predictor(seq_list):
    """Return the molecular weight per residue of each protein sequence.

    Args:
        seq_list: iterable of protein sequences (strings) accepted by
            ``Bio.SeqUtils.molecular_weight(..., seq_type="protein")``.

    Returns:
        1-D float ``np.ndarray`` with one entry per sequence: the sequence's
        molecular weight divided by *its own* length. An empty input yields an
        empty array.
    """
    seq_list = list(seq_list)
    weights = np.array(
        [SeqUtils.molecular_weight(seq, seq_type="protein") for seq in seq_list],
        dtype=float,
    )
    # Normalise each sequence by its own length rather than by the length of
    # the first one, so batches with unequal lengths are scored correctly
    # (identical result when all sequences share one length).
    lengths = np.array([len(seq) for seq in seq_list], dtype=float)
    return weights / lengths


class Sampler:
    """Position-by-position sampler with guide-based resampling.

    Starting from a batch filled with ``tokenizer.mask_id``, ``generate`` visits
    the positions in a random order, samples a token for each position from the
    model's output, and at selected steps re-draws the freshly sampled tokens
    across the batch according to scores that ``guide`` assigns to completed
    ("unrolled") copies of the batch.

    The model is called as ``model(tokens, timestep)`` and its output is
    indexed as ``[batch, position, token]``. The tokenizer must provide
    ``mask_id``, ``all_aas`` and ``untokenize``.
    """

    # Number of trailing entries of ``tokenizer.all_aas`` that are never sampled.
    # NOTE(review): presumably these are the non-canonical residue symbols at the
    # end of the alphabet -- confirm against the tokenizer in use.
    NUM_EXCLUDED_TOKENS = 6

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer and pick the device for new tensors."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _sampled_vocab_size(self):
        """Number of leading token ids that sampling is restricted to."""
        return len(self.tokenizer.all_aas) - self.NUM_EXCLUDED_TOKENS

    def generate(self, guide, seq_len, batch_size, resampling_steps, unroll_hop_size):
        """Generate ``batch_size`` sequences of length ``seq_len``.

        Args:
            guide: callable taking a list of untokenized sequences and returning
                one score per sequence (array-like); higher scores are more
                likely to be kept when resampling.
            seq_len: number of positions per sequence.
            batch_size: number of sequences generated in parallel.
            resampling_steps: resampling schedule, see ``is_resampling_step``.
                The value tested against it is the number of positions that
                were still masked at the start of the step (``seq_len`` on the
                first step, ``1`` on the last).
            unroll_hop_size: number of positions filled per model call while
                unrolling, see ``unroll``.

        Returns:
            Tuple ``(untokenized, loc)``: the list of generated sequences as
            returned by ``tokenizer.untokenize`` and the random position order
            (``np.ndarray``) that was used.
        """
        sample = torch.full(
            (batch_size, seq_len),
            int(self.tokenizer.mask_id),
            dtype=torch.long,
            device=self.device,
        )

        loc = np.arange(seq_len)
        np.random.shuffle(loc)

        # The model is always queried with an all-zero timestep; build it once.
        timestep = torch.zeros(batch_size, dtype=torch.long, device=self.device)
        vocab = self._sampled_vocab_size()

        with torch.no_grad():
            for n_done, i in tqdm(enumerate(loc), total=len(loc)):
                i = int(i)
                still_masked = seq_len - n_done  # includes the current position

                prediction = self.model(sample, timestep)
                p = torch.nn.functional.softmax(prediction[:, i, :vocab], dim=1)
                # squeeze(-1), not squeeze(): keeps a 1-D tensor for batch_size == 1
                # so it can still be indexed by the resampled row ids below.
                sampled_aa = torch.multinomial(p, num_samples=1).squeeze(-1)
                sample[:, i] = sampled_aa

                if self.is_resampling_step(still_masked, resampling_steps):
                    # unroll() works on a copy, so the positions after the
                    # current one stay masked in `sample`.
                    unrolled_sample = self.unroll(sample, loc[n_done + 1:], unroll_hop_size)
                    preds = guide(unrolled_sample)
                    ids = self.sample_exp_indices(torch.as_tensor(preds))
                    # NOTE(review): only the token at position i is exchanged
                    # between rows; earlier positions of each row are kept.
                    # Confirm this is intended rather than resampling whole rows.
                    sample[:, i] = sampled_aa[ids.to(sampled_aa.device)]

        untokenized = [self.tokenizer.untokenize(s) for s in sample]
        return untokenized, loc

    def unroll(self, sample, remaining_loc, hop_size):
        """Fill the given positions of a copy of ``sample`` and untokenize it.

        Args:
            sample: integer tensor of token ids, indexed ``[batch, position]``.
                It is not modified.
            remaining_loc: sequence of position indices to fill, in fill order.
            hop_size: number of positions sampled from each model call.

        Returns:
            List with one untokenized sequence per row of ``sample``.
        """
        # Work on a copy: writing into the caller's tensor would leave the
        # rollout tokens in place of the masks for all later generation steps.
        rollout = sample.clone()
        # Batch size comes from the tensor itself, not from a notebook global.
        timestep = torch.zeros(rollout.shape[0], dtype=torch.long, device=rollout.device)
        vocab = self._sampled_vocab_size()

        with torch.no_grad():
            for hop in range(0, len(remaining_loc), hop_size):
                ids_chunk = torch.as_tensor(
                    remaining_loc[hop:hop + hop_size],
                    dtype=torch.long,
                    device=rollout.device,
                )
                prediction = self.model(rollout, timestep)
                p = torch.nn.functional.softmax(prediction[:, ids_chunk, :vocab], dim=2)
                # multinomial expects 2-D input: flatten (batch, chunk) and restore.
                drawn = torch.multinomial(p.reshape(-1, p.shape[-1]), num_samples=1)
                rollout[:, ids_chunk] = drawn.view(p.shape[0], p.shape[1])

        return [self.tokenizer.untokenize(s) for s in rollout]

    def is_resampling_step(self, step, resampling_steps):
        """Tell whether resampling happens at ``step``.

        Args:
            step: positive step counter supplied by ``generate``.
            resampling_steps: either a collection (list, tuple, set, frozenset)
                of step values at which to resample, or an integer period, in
                which case resampling happens when ``step > 1`` and ``step`` is
                a multiple of the period. ``None`` or ``0`` disables resampling.

        Returns:
            Truthy when the guide should be consulted at this step.
        """
        if isinstance(resampling_steps, (list, tuple, set, frozenset)):
            return step in resampling_steps
        if not resampling_steps:
            # None / 0: no schedule given (a period of 0 would divide by zero).
            return False
        return step > 1 and not step % resampling_steps

    def sample_exp_indices(self, raw_scores, tau=1):
        """Draw row indices with probability proportional to ``exp(score / tau)``.

        Args:
            raw_scores: 1-D tensor of scores, one per row.
            tau: temperature dividing the (max-shifted) scores.

        Returns:
            1-D long tensor of ``len(raw_scores)`` indices drawn with replacement.
        """
        # Subtracting the maximum leaves the normalised weights unchanged and
        # keeps exp() from overflowing.
        shifted = raw_scores - raw_scores.max()
        weights = torch.exp(shifted / tau)
        weights = weights / weights.sum()
        return torch.multinomial(weights, len(weights), replacement=True)

    def sample_lin_indices(self, raw_scores, tau=1):
        """Draw row indices with probability proportional to ``score - min(score)``.

        Args:
            raw_scores: 1-D tensor of scores, one per row.
            tau: unused; kept so the signature matches ``sample_exp_indices``.

        Returns:
            1-D long tensor of ``len(raw_scores)`` indices drawn with
            replacement. When all scores are equal the draw is uniform.
        """
        shifted = raw_scores - raw_scores.min()
        total = shifted.sum()
        if total == 0:
            # All scores equal: shifted weights are all zero and dividing by
            # their sum would give NaN, so fall back to a uniform draw.
            weights = torch.ones_like(shifted, dtype=torch.float) / len(shifted)
        else:
            weights = shifted / total
        return torch.multinomial(weights, len(weights), replacement=True)

In [233]:
# Run configuration.
SEED = 0               # seed for the NumPy and PyTorch random number generators
batch_size = 128       # number of sequences generated in parallel
seq_len = 20           # positions per generated sequence
unroll_hop_size = 1    # positions filled per model call while unrolling
resampling_steps = 1   # integer period: resample at every step except the last
guide = mw_predictor   # scores unrolled sequences by molecular weight per residue

# Generation draws a random position order (NumPy) and samples tokens (PyTorch);
# seed both so that re-running this cell starts from the same RNG state.
np.random.seed(SEED)
torch.manual_seed(SEED)

sampler = Sampler(model, tokenizer)
sequences, loc = sampler.generate(guide, seq_len, batch_size, resampling_steps, unroll_hop_size)

100%|██████████| 20/20 [00:06<00:00,  3.08it/s]
