In [1]:
# Re-import edited modules (e.g. the proteinbert_gen package imported below) before
# each cell executes, so package changes take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import abc
import pickle
import math

import wandb

import torch

from tqdm import tqdm
from functools import partial
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW, SGD
from collections import namedtuple

import proteinbert_gen.constants as consts
import proteinbert_gen.mask_diffusion as mask_diffusion

from proteinbert_gen.debugging import print2
from proteinbert_gen.proteinbert import ProteinBERT, load_pretrained_weights
from proteinbert_gen.word_freq import create_word_freq_tensor
from proteinbert_gen.tokenizer import ProteinTokenizer
from proteinbert_gen.dataset import sprot_train

In [3]:
# Training configuration. A namedtuple keeps the values immutable and provides
# `_asdict()`, which is used to log the whole config to wandb.
Hyperparameters = namedtuple(
    "Hyperparameters",
    "batch_size epochs num_steps word_freq_lambda device hybrid_lambda lr "
    "logging_steps eval_step_size clip_grad clip_grad_val warmup_scheduler optimizer_cls",
)

args = Hyperparameters(
    batch_size=64,
    epochs=10,
    num_steps=4096,            # passed to create_discrete_diffusion_schedule
    word_freq_lambda=0.3,      # passed to MaskDiffusion
    device="cuda",
    hybrid_lambda=1e-3,        # passed to compute_kl_reverse_process
    lr=1e-3,
    logging_steps=25,          # log to wandb every N minibatches
    eval_step_size=4,          # not referenced elsewhere in this notebook
    clip_grad=False,           # enables clip_grad_value_ with clip_grad_val
    clip_grad_val=10,
    warmup_scheduler=False,    # enables the LambdaLR warmup schedule
    optimizer_cls=AdamW,
)

run = wandb.init(
    project="proteinbert_gen",
    # wandb config values must be serialisable. Keep plain scalars as-is so wandb
    # can sort/filter/plot them numerically; stringify anything else
    # (e.g. optimizer_cls, which is a class object).
    config={
        k: v if isinstance(v, (bool, int, float, str)) else str(v)
        for k, v in args._asdict().items()
    },
    # mode="disabled",  # uncomment to run without syncing to wandb
)

[34m[1mwandb[0m: Currently logged in as: [33mmattfeng[0m ([33mkaiogenbio[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
class SampleClassBase(abc.ABC):
    """Strategy interface for drawing token ids from denoiser logits.

    Subclasses must implement `sample`. `post_process_sample_in_prediction`
    is an optional hook that defaults to the identity.
    """

    @abc.abstractmethod
    def sample(self, logits, x_0):
        """Draw a sample from `logits`; `x_0` is supplied by the caller for
        implementations that need it."""
        raise NotImplementedError

    def post_process_sample_in_prediction(self, sample, x_0):
        """Hook to adjust a sample during prediction; identity by default."""
        return sample


class Categorical(SampleClassBase):
    """Sample independently from softmax(logits) over the last dimension.

    `x_0` is accepted for interface compatibility and ignored.
    """

    def sample(self, logits, x_0):
        return torch.distributions.categorical.Categorical(logits=logits).sample()

In [5]:
def word_freq_preprocess_fn(wf):
    """Compress raw token frequencies to [0, 1] via log(1 + f) / max(log(1 + f)).

    Args:
        wf: tensor of non-negative token frequencies (not modified).

    Returns:
        A new float tensor of the same shape. A zero frequency maps to 0 and the
        most frequent token maps to 1. If every frequency is zero (or `wf` is
        empty), the result is all zeros instead of NaN from a 0/0 division.
    """
    # log1p is the numerically accurate form of log(wf + 1).
    wf = torch.log1p(wf)
    if wf.numel() == 0:
        return wf
    max_val = wf.max()
    if max_val == 0:
        return torch.zeros_like(wf)
    # range: 0 - 1 (for non-negative input)
    return wf / max_val

def process_fn_in_collate(wf):
    """Centre a sequence's word-frequency values so they average to zero."""
    centre = wf.mean()
    centred = wf - centre
    return centred


tokenizer = ProteinTokenizer()

# Per-token frequency table (indexed by token id) built from the SwissProt
# word-frequency dictionary. Padding is zeroed *before* preprocessing;
# word_freq_preprocess_fn maps 0 -> 0, so padding stays at 0 afterwards.
raw_word_freq = create_word_freq_tensor("../data/sprot_1m_word_freq_dict.pkl", tokenizer.ALL_TOKENS)
# raw_word_freq[tokenizer.mask_token_id] = 0
raw_word_freq[tokenizer.pad_token_id] = 0
wf_tensor = word_freq_preprocess_fn(raw_word_freq)
wf_tensor

tensor([0.9930, 0.8760, 0.9626, 0.9751, 0.9424, 0.9824, 0.9082, 0.9710, 0.9673,
        1.0000, 0.9151, 0.9418, 0.9518, 0.9393, 0.9652, 0.9708, 0.9611, 0.3411,
        0.9802, 0.8615, 0.5221, 0.9240, 0.0000, 0.0000, 0.0000, 0.0000])

In [6]:
def collate(batch_input, *, tokenizer, word_freq: torch.Tensor):
    """Tokenize a list of dataset rows and pad them into one training batch.

    Args:
        batch_input: iterable of dataset rows, each with a "seq" string.
        tokenizer: object providing `tokenize(seq)` (a list of token ids) and
            `pad_token_id`.
        word_freq: 1-D tensor of preprocessed word frequencies indexed by token id.

    Returns:
        dict of tensors, each of shape (batch, max_len):
          "input_ids":        token ids, right-padded with tokenizer.pad_token_id
          "attention_mask":   1 for real tokens, 0 for padding
          "word_freq_logits": per-sequence mean-centred word frequencies, 0 at padding
    """
    input_ids = []
    attention_mask = []
    word_freq_logits = []

    for item in batch_input:
        # Explicit long dtype: torch.tensor([]) would otherwise be float32,
        # which `gather` rejects as an index tensor.
        ids = torch.tensor(tokenizer.tokenize(item["seq"]), dtype=torch.long)
        input_ids.append(ids)
        attention_mask.append(torch.ones_like(ids))
        word_freq_logits.append(process_fn_in_collate(word_freq.gather(0, ids)))

    return {
        "input_ids": pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id),
        "attention_mask": pad_sequence(attention_mask, batch_first=True),
        "word_freq_logits": pad_sequence(word_freq_logits, batch_first=True),
    }

# Bind the tokenizer and preprocessed word-frequency table so the DataLoader can
# call collate(batch) with the batch as its only argument.
collate_fn = partial(collate, word_freq=wf_tensor, tokenizer=tokenizer)

In [7]:
# Shuffle every epoch so minibatches are not replayed in the dataset's stored
# order (DataLoader defaults to shuffle=False). DataLoader refuses shuffle=True
# for iterable-style datasets, so only enable it for map-style ones.
train_loader = torch.utils.data.DataLoader(
    sprot_train,
    batch_size=args.batch_size,
    shuffle=not isinstance(sprot_train, torch.utils.data.IterableDataset),
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

In [8]:
# Pull a single batch to sanity-check the collate output (keys, padding, shape).
sample_batch = next(iter(train_loader))
sample_input_ids = sample_batch["input_ids"]
print(sample_batch)
print(sample_input_ids.size())

{'input_ids': tensor([[23, 10,  0,  ..., 25, 25, 25],
        [23, 10, 16,  ..., 25, 25, 25],
        [23, 10,  8,  ..., 25, 25, 25],
        ...,
        [23, 10,  0,  ..., 25, 25, 25],
        [23, 10,  3,  ..., 25, 25, 25],
        [23, 10, 15,  ..., 25, 25, 25]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'word_freq_logits': tensor([[-0.9607, -0.0456,  0.0322,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9570, -0.0418,  0.0042,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9610, -0.0459,  0.0063,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.9601, -0.0449,  0.0329,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9164, -0.0013,  0.0587,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9635, -0.0484,  0.0073,  ...,  0.0000,  0.0000,  0.0000]])}
torch.Size([64, 505])


In [9]:
def denoise(targets, timestep, attention_mask, *, model):
    """Run the denoising network on a batch of (partially masked) token ids.

    `timestep` and `attention_mask` are part of the signature the diffusion
    helpers call `denoise_fn` with, but this implementation ignores both and
    feeds only `targets` to the model.
    """
    del timestep, attention_mask  # unused; kept for interface compatibility
    return model(targets)

# NOTE(review): pickle.load can execute arbitrary code; only load weight files
# from a trusted source.
with open("../weights/epoch_92400_sample_23500000.pkl", "rb") as f:
    _, pretrained_model_weights, _ = pickle.load(f)

model = ProteinBERT(tokenizer.vocab_size, consts.GO_ANN_SIZE)
print(model)

# Materialise as a list: trainable_params is iterated again later (optimizer
# construction, grad clipping, the per-step NaN/inf scan), so a one-shot
# iterator would be exhausted after its first use and the later loops would
# silently do nothing.
trainable_params = list(load_pretrained_weights(model, pretrained_model_weights))
model = model.to(args.device)
denoise_fn = partial(denoise, model=model)

ProteinBERT(
  (embed_local): Embedding(26, 128)
  (embed_global): Sequential(
    (0): Linear(in_features=8943, out_features=512, bias=True)
    (1): GELU(approximate='none')
  )
  (blocks): ModuleList(
    (0-5): 6 x TransformerLikeBlock(
      (wide_and_narrow_conv1d): ConvBlock(
        (conv_narrow): Sequential(
          (0): Rearrange('b l d -> b d l')
          (1): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=same)
          (2): GELU(approximate='none')
          (3): Rearrange('b d l -> b l d')
        )
        (conv_wide): Sequential(
          (0): Rearrange('b l d -> b d l')
          (1): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=same, dilation=(5,))
          (2): GELU(approximate='none')
          (3): Rearrange('b d l -> b l d')
        )
      )
      (dense_and_broadcast): Sequential(
        (0): Linear(in_features=512, out_features=128, bias=True)
        (1): GELU(approximate='none')
        (2): Rearrange('b d -> b () d')
      )
      

In [10]:
optimizer = args.optimizer_cls(trainable_params, lr=args.lr)

if args.warmup_scheduler:
    def _warmup_lr_factor(step):
        """LR multiplier: linear ramp (offset 1e-3) for 10k steps, then 100 / sqrt(step)."""
        if step >= 10000:
            return 100. / math.sqrt(step)
        return step / 10000. + 1e-3

    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmup_lr_factor)

  return torch._dynamo.disable(fn, recursive)(*args, **kwargs)


In [11]:
sample_cls = Categorical()

# Discrete (mask) diffusion process over the tokenizer vocabulary.
diffusion_schedule = mask_diffusion.create_discrete_diffusion_schedule(num_steps=args.num_steps)
diffusion_instance = mask_diffusion.MaskDiffusion(
    dim=tokenizer.vocab_size,
    tokenizer=tokenizer,
    schedule=diffusion_schedule,
    sample_cls=sample_cls,
    device=args.device,
    word_freq_lambda=args.word_freq_lambda,
)

using standard schedule with num_steps: 4096.


In [12]:
# Main training loop: one KL-reverse-process step per minibatch, wandb logging
# every args.logging_steps minibatches, plus sample generation and a model
# checkpoint at the end of each epoch.
train_loss = 0.     # cumulative loss over all non-NaN steps (all epochs)
has_nan_log = 0     # steps in the current logging window whose grads had NaN/inf entries
nan_count = 0       # total steps skipped because the loss itself was NaN
window_loss = 0.    # loss summed over the current logging window
window_steps = 0    # non-NaN steps in the current logging window

for epoch in range(args.epochs):
    for i, batch in enumerate(tqdm(train_loader)):
        run.log({"epoch": epoch, "minibatch": i}, commit=False)
        is_logging_step = i % args.logging_steps == args.logging_steps - 1

        optimizer.zero_grad()
        diffusion_t = diffusion_instance.sample_t()

        input_ids = batch["input_ids"].to(args.device)
        metrics = mask_diffusion.compute_kl_reverse_process(
            input_ids,
            diffusion_t,
            denoise_fn=denoise_fn,
            diffusion=diffusion_instance,
            target_mask=batch["attention_mask"].to(args.device),
            hybrid_lambda=args.hybrid_lambda,
            predict_x0=True,
            word_freq_logits=batch["word_freq_logits"].to(args.device),
            device=args.device
        )

        # Normalise by the *actual* batch size and padded length. The final
        # batch of an epoch can be smaller than args.batch_size (partial batches
        # are kept). NOTE(review): assumes metrics["loss"] is summed over batch
        # and positions, as the original normalisation implies; confirm in mask_diffusion.
        cur_batch_size, seq_len = input_ids.shape
        loss = metrics["loss"] / cur_batch_size / seq_len

        if loss.isnan():
            nan_count += 1
            if is_logging_step:
                run.log({"nan_count": nan_count})
            continue

        step_loss = loss.item()
        train_loss += step_loss
        window_loss += step_loss
        window_steps += 1

        loss.backward()
        if args.clip_grad:
            torch.nn.utils.clip_grad_value_(trainable_params, args.clip_grad_val)

        # Zero out non-finite gradient entries before stepping. Both NaN and ±inf
        # are handled: an inf gradient drives Adam's moment estimates to inf and
        # the resulting inf/inf update turns the parameter into NaN.
        has_nan = 0
        for param in trainable_params:
            if param.grad is not None and not torch.isfinite(param.grad).all():
                param.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
                has_nan = 1
        has_nan_log += has_nan

        optimizer.step()
        if args.warmup_scheduler:
            warmup_scheduler.step()

        if is_logging_step:
            run.log(metrics, commit=False)
            if args.warmup_scheduler:
                run.log({"last_lr": warmup_scheduler.get_last_lr()}, commit=False)
            run.log({
                "nan_count": nan_count,
                "nan -> zero": has_nan_log,
                "train_loss": window_loss / window_steps,
            })
            has_nan_log = 0
            window_loss = 0.
            window_steps = 0

    # Generate a few proteins as a qualitative check. Sampling needs no autograd
    # graph, so run it under no_grad; eval() disables any train-only layers
    # (e.g. dropout, if the model has any), and train() restores training mode.
    model.eval()
    with torch.no_grad():
        generated = mask_diffusion.discrete_diffusion_predict_fn((8, 200), denoise_fn, diffusion_instance, topp=1.0)
    model.train()

    generated_table = wandb.Table(columns=["gen_id", "seq"])
    for j, genseq in enumerate(generated["final_state"].tolist()):
        genprot = tokenizer.untokenize(genseq)
        generated_table.add_data(j, genprot)
        print(genprot)
    run.log({"generated_proteins": generated_table})

    torch.save(model.state_dict(), f"../checkpoints/{run.name}-postepoch-{epoch}.pt")


100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:03<00:00, 13.98it/s]


^MWGLTPTLDMSFAGAQKGLIDHYILRFALIEELVSVVAQNEGADLATPPNGELGVYDSGFEDINLKSLVAYGVVSIPARTRPDGEALLVDHYAVASAAGADDILAGGRDGVLPGLLLVKMEFVVIDYSESLYRYDLD$$IELRVIGTPQLDTPDKEVSDGSASKHKRQFLWYVACPTILNGTIMRMRTGHIAEREEGAV
^MFFSAGTFLYGMAVSISISMAKRVTLDGGRVAESAVLNSYAWYKDIVAQCSPYENEFFFDNFSEVVTVAMVLLTDYLGPKRLSQIIKEAYFIARISCSNRSHELHPHLSIDPRGLLLAIGGYLGLEHT$LNKGVCEFLLIPGSANGPLSRPSAILSKLAVKDAEIRTSTPTAALKGLPQNGAKRMNANRMGLAASQTGL
^MRYMGAFDQWERHPYRLACSEPQLEASDSTCHLRLINGLEEIST$SRSILHFADVHITINGVTFYKDLVSIALLPLSKKLFLAAPFFISRQGFFLQYTAARLFTQYEGSSQVMKILPEQHICR$TFDNRLLWQLLVLEAGSSHIIKVTPSFVPLPGFEVLDDSGGCGGNQKIGQATEEEVVALFPNGVRAIMVVSLGPG
^MRKEWRISLAARFPAHFIRSVEAMSPPLSFMYVVDVGTKSGASSAIVVLNLYRTDFITTAYAYFTRGFINLAPLIFIHKMFSQPLTVMQIRRRKIVHRPVVGDFSNHALGAINKVNGQPTGYATNLLQATRSRVDTGDCNVIDAVLTEALGSGEWKSAFPERIGIIPMTIVARRTDLALLAFDFRSGEIAIELLDADDG
^MMQEGDLPIPGYAAFMVYGDPTQEEIIIAKKRFNQYAGLSQFNIKTRAIEKNSALVPVDIMYFTKHMKDPSMHKVENRELHAGEGELPAQTSLTIQVAVPNLWPTMHRFEPWGEYLMTGEQKEILSGAQYKGTMFPFLMLSTGLEAPKCSKRGRLSANKDREINTSRIPPMLTPPEWPIVISAEQAEHTDNTGSN

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:02<00:00, 14.01it/s]


^MSRIHPDISGIDAIEAKLAYAVVADLKRLIGKLRRQSLWRHPRPRRRMAAWAAQLEKPRGRRASADAIRIDTISAEPKLWEPREIMIGADITATITAAVGPAIDSGDIAAAARLIIIPRVPELRRRAAAIAAIAIQIIGPLPRYPSSPARTISRALRRTTVAALQRIAAIPAQPAFLHDIAGHDIPHPRRIITISPRHS
^MISHQALRALASIRHVALRATRAVAVAIARAVEPEQRIARQYVRLAQRQLLAPIGAARAQEEFALTRAIRQKQAQIIREIKRIIADGPPPIAAPGIVIALPKEITIKRISSPEAMRGRARRSELSQHADRPAWVQAADITSVDQAIADMPEALRCLITIDAEDVPTIDDSRAALGIPALRPVVGWVSRQSVQIIRRIRK
^MQVIAALRAIVRWNSAPDSTLSDRIDHYHLIDPELTDILALAIEQLPPHAAVAAIAATLRRIRASGSKRRPRIRLLRKADWLRLPKEQDQHLPAPDCITILGAITEPARIRVVIASSLLRLADAQHRIVLGSAAAGASASERLYHELRPLLRRAIDQAITPLAATLAAISDTVRQAGRAAAIRPRRQTIMIDRPARRIK
^MPPTDAIVVDELSYESAAPLAAIVVVGMPIRHTAIPDPSTVMADCDHVDAVEWIAAPWRRADLIARIIPPAIHAAAASQRQGQRIGTIPRPGQPPRQMVAWPPRPDADIAVAAATVDAIAAEITAAIAAIPQGAAIRALLRKAAEPRKRAAVRRARQMIARLGRLAIPQDGIIIDAARPITRYPRPTVIAITAPPRERS
^MRFEIDLAAAAIAVAVHPRLVHDAIISHPDRDISLPWREASPPSPRRTRRIRTIRARIAQAIEPAQPPMPALGTIPRDPTIDIPIIAAAPSIAAAIAAIAAAPPQAISAIAAAAFRRLPARIPSPRGRQHQEIIPTEIDPVIHSIVQASPAPEIIHPIPLLAPSYATTSDGAIPLALPDQRTIRIIASAEARIIS

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:03<00:00, 13.99it/s]


HITRHNVKIVVGKPNVVLSQGHEIISFVNVQKKSIPKCPFRIHILYCTIDIIDDRPIKVTAKIFQKINDSFSDCPNETVIIYVIETDNEKQGVVIKVVTDPDPHVFVYTKNKETDFSTPMIKMGETGECTIIINQDPGTITIVEECTKKTIAKGATRFILTKEIKDVIKPPKTIFCVDDVGIVAKTEGIVERFPVHIKD$
IIKNIFNVASFEEIDLENDEVITKKFTGILIDFEHVSEKEAQVPFKTATNPATIIEFNIFIIAGTKGIVKETIKMKKVKIANQITKVYIDKTIKNVFKVLKIIDKQAKITKIHILKKKKQVHEKVIVRGAADAVTEIFKINPDKVVGVIATGFDVVADQKEREHNIIANGNVITKAAAILMTMVILQDTIREIQKSFKI$
^MKLVKTNEKERKIIETTDGDIIREEIFEKTIKDKDARVMVAEKTITVTDLSVHMGIDIAEMRKKLNVASTGGILTIRETEPIVPVGDGTIDRIDTPDHDVIITKDQIKEETEVEITKVRRFLGKETATTLDKGEEVEAVDYDIMIKPVGTTAPEPGTGDSIKTETKGILIAKGGKIDVKGAVVRAVVVTENKIHHKGE$
^MGEVTIHKSGRGDIDFARGPTIVKPFVIITIGKRIPPDVPKAKKTTDTFEADDNTPVVVDNTKSTVDDEVTMKPGVGAMDYKTEDRIIVTTGKIKGVDFEEIPTVIGIPKIMHRPKILVDDITADRLEIKAILDIFKKDPISDVSVAVGKARIKKVTGIGIAITDMISIDGDGIHFYKDTKSLATEKRIPTTAHIKSAK
VIEFIRKKYQTKDIPITIAEYFVAFFGVTIVRENNETSIECIINKNITHIIRIVAVADIKKKEIKVIDKPFICIICDVTDKAQPVAAPTTVEDEHHNVVAIPEAGTDTITDKDILPIPIKPTAEAIGIKKIEGGHTHIRTFKDIEFINKKDVKHIVFYKKEEKVIVKRGITVIFSSDDFFIIPIDVKHHTETTKHK

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:04<00:00, 13.95it/s]


^MSSIPAPWFAAAAPPRLFQAQAAAVIAQIAGAGFEPLLQEFAEIPTIEPVEGEEPAIFAARGAQAIAPGSIQVIDAIEVLEGQYDGEESILAAFIQPNLQEYVEQSAFARQIATGEREREAAVPTPVVAILAEAGLQESIAILPEEQVGEPIPAGVPAGLPTAAASPAPVAVAAIARLLFAGIAFLLLEAQFHYPEEAS
^MAAFGAESPSRAAIAIALTRIAEACPLVEFEWVAPERDELEIFSSAPVPAAAAVAGAEAQTALAAAPEAEPILGIAEEMPPELAAALEAELGESLAGAARAIVGIGGRPITAPGLPEAPQAIMRRFARPLAAIAAELVGNEAFASVEALVGVEIQEFTPIIEAAGPERREEAIVPAIAEAFPEIQPYAEVIAAFFPEFS
^MGAVLFELASALAQPPIEFNRNQAIIPAQAAMELSPLAAPAPPRIAASRKLAVARFGAENEQFIAISAICPELPQWFTTFQALIERNSLNQIFASAQSWLRGTGAQPFLEVGIGIDGLSRSQASIWPIEGAIRRFLVAIGQGEGEAAIVWAASFPEMRESIEALVPSSQAPPNVELGQPFIAAFPILREAQIPFEAEIA
^MSASIWEQLGIPAFRALQAGGIGIIEAPEEVRGLRFATAIPALSGQATIAAQAAIGRIPQIIAANGFIEFESARLYAAEGEAVPFDGVLGFRFAEELAAALAATRVPGAATVRAIEPVVPAEEFALAAGQGRILVAAGAAQILSIFARAPQREALRGFGRGASLTIPREGFAGSLAALPEFAIAIEGEQEAPWIIAIAE
^MNPSFLAAFGIIPAIFAPARFSAAVAAGIALAAGLPPGFAIALALEGRLQLAEAAEPAPIVTIANEPQLSPVQAEIAASARLQIFLNNSPFQPAQIQAIAALPQIARPANIFTSQEGEPIAIAASQRALVAAGQQGSIFASIQREIAERIPAAIQEAIEAGIPGNQLAVIAAFAQPEQIEAANPIIFIQLLRAFP

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:12<00:00, 13.64it/s]


^MKITTIHNLIAEKKKATERYYPTAVEHYFLRVRMALLIMFELLTPHGKESKHFPEVKLAQVILLKCGFAACNMEHRVKLAEHVLYTQKKKTGHRVPTPDALCDVEIMTMKTGIEAVEIAKKRLVAMQWLETDMHRTTFHSGHYRSGRAEIMTARALLDISGLKYVGEGTDLVQPKAQVAGMKHGITLCSPGIETSPEAK
^MVELRQKRPCSTIQLNFVNPWHLVTIQIVVVGMGSGEHVFKSIFEHCYTQVPGEGRMVLTDSLKLGHVKIDMVVQMTLERSGTTLCGLRLENALTFYGIWRYVNHAQEDLKQGYKVHGQWLVRLIPAWKLFQSTNKDHHNVVVYVIGQMRIPEKAWLSVHAILKCCADLKKCHHFFGKRKGVPVTIVVSCFFITQWPR$
^MSIFTKEVCSICGSILRMTKLRIVMQGMKKTLVLHLLMPTEIFEKPQTTLIMEQIRFMPCVLPAGGHAYIMRNVETVRISLVGHVICELPFSLKQMCERFQYSGRLIGITFVTQVEFYTMCSSGTMLTVQRLPIISQPLLQPFKKRPHLQGQGWLPKTITTVMLRSSELMLKGYMPVVIVILLIAILLMHKRHHKHLKK
^MSLDNPCPSCTPPPTDIYQICTWVVSHSSGVDMVHVVVLLMGAEGLFRFEDITRLIQSASTFRVVSISECVWRSGILARSEGYNVGLYQQLKQMQKLGVEELREILKEMNGVVGTTSMHVKVFEFMNNGHTYVYKRSHACSVTGDPRGAGSLYENHSPVTRDMMQGTFRRGYTKGQPFAGGVPGIVHLVGCWIYGPAEK
^MKHRSEPTLAKSQQKSLVTGLQMPVRVGWDLAEIMQFIMNQFTCFECIFDKLFEHVNITWKELCKSAHVIEVLNRFQVIDRLKCAELKTFCVPEITDEFCVAHLALTKRIEKRLQKPYLAESIGMICLLVIVVISVCSAGFVVTCVKRLQWFQKFNDHGLCTLLALRVLDGPFPMEPFCIYITSLPHLLFVQHCR

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:03<00:00, 13.99it/s]


^MSSPYTVYGITRCGEIAFKLTNLCISHEFTCSIPKVPYSNPKCIAYMMEILSMFYVGVRYGDETEYALISVTGEIAGLSQPVRDSGIRLRWEAKFPAKRVPLLKYVITRPLKYGPTLKNISRPSVGFQESIGAMGTKITFESAGQYGERIVFAQSDVAVLPQSIGAITTAVAQNPSYLRARMAVKKVLLKVRPNMPFKK
^MRRWYVPYLLHSATMAMKLSCRARARDRLSRRTAGRRQRAFPSRGLSRLLMPDLVHPGIRLGRFACMFQVSYTERAYDRFRTLDFLLLALSRVSSRLDMFALRADQAMSPDPYLPAMVLVLARRHDPPIVTHCPRFVTERIIHRVLDATWRSGISTSMSCAYRRGLPRLAIYGVPVSAVMGPQQTLHPDWPARLPHPTS
^MISPLADPLLTKQARRGWRSPILEPLYKTQMMSRHTIRNRAVKINDAALDLMRMFVAVSTVSRMITRRFHTPWTPDHKTPRYEGSHLVRTCCAIAPPYIGSFEISQLLEMGARDSGFFTRSYRKWSNDTPDMDIPAEQVLDLVRAHSFTPRTYITLTLKNRMDDTITVIYRDRCEGIHTLELLEMFAYTRYEPQPKPS$
^MQIEKLIELLRNSRPLKRLCNDRKVNTPAHMYRPELYHGLSITGGMTYKTMMQVSKSHELGRANGVARLKITKMIDDEAVVAQIGQVYGNQDQYLGVASTIVQFSPLALMQGKEGVARYNSVEQEIRTFYVGAPLTVPSPDLEKAVRKYFSTKHARDEYTYQMSLYGYKTRPERVVGINFSVVTRIVKEIPWIVSQ$$$
^MNFIRGIATWPSSSRLRAINLTLRDYYSMPILEVFTTRWKRSHFQLRMHSLITSLGIQSTSPTRMNLLRYDVKSPITMHLRLYGALRRRGAKAVVTRFRLRSFEQVPRGETEDMIEDGLPEVATSVDDLGRRQDRAFFMVYTMARRFLSPRLLLKRRTVGLAAKRFFEACSYPDFATIFYSDFIRDAEGIPTSHS

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:25<00:00, 13.17it/s]


^MVEVRPRTTGIMASDVFGTQGYLLDSLEDAFVETLIKKQCDSWGDEDYSDEAVAVVYWAGKEWERLFFYGDEITLCGVTTYLSSTYVSQTSNISTSKEWWAISKAFYDFTQKKLRTGEIAFDLNSKVWPTANAPRAYHYTPQDRVWSSCSSKTMTILAMVLTFSCHTILIEDFGDFYDYEYSDTSANYFGGSHEVELGK
^MELPIDINGTGAHFLQRMLAGFDVAEAKNITRVLVHGAAIPLNGSLNNKMMKTDEIISDKLSQVSTKGYDALKIETSLSSLYEAYEYDKGSVRVTMEIYDLVLVTGVGMPDGVLVNIIASSKNGYSLQDEMSQLLTEVSSSPPKDIMMPEDSVLRMVYHFNSGSHHRVVLASESYLKDYVETGSSSEDMEAFVENNGKK
^MSSMLRENISPGQRMWMVKCIHHSQSYTSKPKIISYIVSDILFSQAVWEDMVSSLGEFNELASVCSSISEKGFLWPEFSFKNHSAVLLHFDHRWSTSWYCLCHGVKLCSHHVSTYVWVYYDCVAVHTVTSIRTPGMVFVESSSNSDPGTTRPTISISKWSMRVCPLKEKDGYDFDCSVSEGIDVYGTEGSMIHTLEKSQ
^NISSYGTYTIIRSMSISHRTSSIQGTYYTITTYNYTFVETTYRTMPSSIGQFVSGDPMIQLSTLSHQMLTESHTCNILLDLTDSWSSRSLSRETCKTGEWLENSETGYARGENRALVEPVTLEDEPTGWLNLLVWFSCFDATLAVVPILDCRSSGHTYVLKQGIYGVFTDGSTILSHSISEKESKECSGCMFIYT$GRK
^MDVLCLTDLGWSRVGPSYDTRFSWTRLRVRTTFTELFGTGSEYGIGPSCFLYWAGGEHEGQDAYLATVIDVIHDLIILGTSSFYSRGDWFSEDASGYQACSVPLAEYGFVHELYYYKKSYWLGVSLHELRSAHLMRVDSAMQSIAKSMAAELNYNSVPVVLQGSRGVKAHEAYTVFVETYVWYSLKINTYVTSES

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:07<00:00, 13.83it/s]


^MSEIRTHKLVCSVQQKLYEWKIELLVTKVGILCGLYLPGTGKSSTYIGNLQKTTLLELKGVCKECELKEGEDLLCAMIERALHQERILPLALQMDELGCELLNCLGSATSAAMLIKQKSMERPKKQKLEEACLPLRELLRQCQSYELCGELSGIPGELDLEHELKVFYQLVLIHPLMSAPINTIAAAKTSAAPLEKKLP
^MHCSRAVLVGEAELCKMHSHVESNEPTCGLCGLVISSGIELPELKTLKHKAADILALRGGAYVGPCMAKVQCHVELQGCETPGHLKGSSTILSGIKLYEAPLLCEALTSICLELFEEVGVKVSLMTCPGVHRPEKLQQKLMQAVGMEKYNPLVGIACVSGAEALQLMPETRAAMKCLNLDEILASCGSLSAAAAEKAHE
^MSQEQQHTLGVKTYLNSWESCVVYVVCHGCQQTKRRALLFSCGSSTSLWSSLLKQIGLSCTPVPQESAKMWMMKIALVSEELMEKLAKWEGKMAKACKVIALIELALRKKSRSNCVDKDGTYVCLNEMELCMALSSEKSRTCLIKARNLWEKYLLVLSQAPGVVVPLTSKASTHGYDLCSLGTTSPTSCSSCRAVMLDG
^MCWAPEGVLCLLTALLKKQEMLCSKIVLNITLAIVLGRNSKKSCLETAALKPCQKKCAHCVEKACACQCQPCLYELTTVLLPRKASKKASWCQSITCNGLAAILPYGLSKKLLCQAQDLQKKPELWSVCKGCGEACIRAATPNERLKLVVVVSPQLVKLRQRSASPCPAALWLLCGKGYVILSESRLSAWLIHAAKPHN
^MTAMLENPYLRATSSAAANAMLACETGLPCIALANVIVADMWAPLQAACPRLFSAGLTSDLPKHMDVVKVVCACSPPSVTVSPWLRALYAPTCLALLSAAMSWGETVRTSKDFQAMKYAYKPCSCRLALWHAIARASSTRKLPLPCYLTAYQPAVEDDLDACDVLPQKVAKAKALADALRRIIALLLKNPDATAF

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:24<00:00, 13.21it/s]


^MRWTLSVIGKALVIWHIQGVVAPESPGMAVNLKKITEIAGQSYNVYIDEVQGKLVNPLSECRRSPGIVKLLGVEETIHPSLAKRAFSSKEYRLDDYISPVLVEECGEFRSRSIYLMPREHYLSSGDVYAGPALEGMLWGQILILENVMGFRASIGEGLAEVAFCSLHVGHLSERPVAEGKIKVCKHSIKLSQGPAGLA$
^MAICKECKEVPGICGSQKGLIWGRLKLLAESEAVLKVSIKSITEPRFSGRRMRSSMAEINETIDSGYILRINIPPMILIERAIKTSIEEILGHGALAVHRCGEVEIEGEGQFLVHAPLTSGKVEVILRGVSEGKGLAVQMHAYLRARESYFEVFYEGIKCLFILGSKREQAGCTRICTPEDKRTKIVRVSKPEHPGKT$
^MAEYLLLAEQIVKRTWLDLVGVTTRYAIAIDNIVEVQSTKFYRKIALPEKVYVMILNNSDAVVFNLELTRLKVDARPLIPYCISVTTFTGVILDQGSAPEDAAKRAAAAIVALIDTAKGREPYGVPLEIIATWFRMAKDAAKQGSLRYLPDISEMLVYATGCSNRFIAERLSAGMHYNPQSIPEKIKNRQSLIHPQQK$
^MSCSTVERRSELIWKKVKEKEPARREIATFQLAPALGGPTCVFPPQVSTLGVAVGEIEIVPGWSGYDYPKYILAETKMPARWGVDKFLRFLKLETAPPTAHAMCGVNYIKSERAAILKREKREDLVPVIEAFLNTSRTAMLERLSVLMILVPESPARCAYPISTTMIVLIEGKICIICYGQQLIAGKLLYGREKKHSQ$
^MRISDCLGYPKLVIAAIIDSMQKEVSSGLWRTFVAPPAKGPLACPSPCFAVAAYRFHDGTLKHVSAKEMLSDQIDVISFQGGQEVTIGARISVIWGMPVERYCKPAKVFIGSIKTRGRDFVSMGERTLYSETTILNEIVGQTMQVLHSLGPERMSPAIRSYNEDEGISIPAEIEPQSLIRTSPHEFGPRASSLLN

100%|██████████████████████████████████████████████████████████████████████████████████| 5082/5082 [06:02<00:00, 14.00it/s]


^MSSNSKLKYRLPGHLAIMSRDKNALNVLAVVKLNWLLRARAKQAVLSVCEITEVLATLVGGMLFLLDLIDSNELSRVFHRLVRVIPHLMTGTSQRNFYLLDADVMLLMGLGAMAGRVLLMLVVTFGRKSSNPQWQWYPSPHRGGIAISCEFETLFANGGVDGQFGIAKRGESTCTISASSAGFLQSCQRKEERQTYK$$
^MAWSVVDIFLGCEGGSGLGMTDAVGGGYAVAIMLCIGIIKIMQPPRKKQQEWTKYEVGFAVSVTIELGLQARPVKVLVTPGIKAPKTCEAKRKVPMPDSVFIPDATPENIERHDYTNVSFAHEMCEHQFHSFRKGQTTIPSAISNAKEDCFTWLKGEGINEATASAGEMQAKVDTEALACRGSLHVPQGPEERSLKRH$
^MKNLLVLTARHAYLPASWVTAIVSNLADLACMKAETKQECRVTVKPMGLCSEVPTMQKATTECHVEPEDQQAAQREQGTCTASADRVCSFLEEKSDLVRSTSLTPTCGIWPGGQIPMTYEAMLVHHLIGHLPCRQKTKQSPKEELLMPEQGCITCNDHMRLFHAELTSCACLERGTKTKEEGCKHGAEVPAVLLLKA$$
^MKYRECGTVKYSDRSRCLLYLVHCGLTTRATGTIVSAHRCRASETICGCAKVKRGCFSLEKKRKKPRTCCFFADGRTVTCELEIRREQSHPHCKYFHDSDWHLVIWRSCQEGACGRCAMCTALPPGKLKTSGPIAVPRWEGCAKPRKYKEELGCVFCMSTARLKHSSWQEMRLYTEISTFARGAAQVGIPECGFLHLH$
^MTEAIELITEGVTYDIAKQRKNKALLEIAQVGLKILTTTMVTIVIRAKILAEEIQQPVQNFHQIKAHTSDLRLLTGRFSSAGLCCCGAHETQDALLGAEDSAEYLKAKKMWYKCMDCTLTSPSAAPAIDTHQIREIGSVCYDDEGGTSFDRAGTLALMKLIAERMNPFTGTPNKKNDEPVEVTNSEPSQHANKLQ

In [13]:
# Draw a standalone sample after training. no_grad avoids building an autograd
# graph across the denoising steps; eval() disables any train-only layers. The
# model's previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated = mask_diffusion.discrete_diffusion_predict_fn((1, 100), denoise_fn, diffusion_instance, topp=1.0)
finally:
    model.train(was_training)

for g in generated["final_state"].tolist():
    print(tokenizer.untokenize(g))

^MKEDVLVVQGGALRASIPVREKAGLNKTTTCDLGSTKSGQLLCVIAHNKELSTYEHPSSTPHRPKISCYSAEKHSCLLSHPTAKQWCCPHGGPSSTKSD


In [14]:
# `epoch` is the loop variable left over from the training cell (the last epoch reached).
optimizer_ckpt_path = f"../checkpoints/{run.name}-postepoch-{epoch}-optimizer.pt"
torch.save(optimizer.state_dict(), optimizer_ckpt_path)