In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import abc
import pickle
import math

import torch

from tqdm import tqdm
from functools import partial
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW, SGD
from collections import namedtuple

import proteinbert_gen.constants as consts
import proteinbert_gen.mask_diffusion as mask_diffusion

from proteinbert_gen.proteinbert import ProteinBERT, load_pretrained_weights
from proteinbert_gen.word_freq import create_word_freq_tensor
from proteinbert_gen.tokenizer import ProteinTokenizer
from proteinbert_gen.dataset import sprot_train

In [3]:
Hyperparameters = namedtuple(
    "Hyperparameters",
    [
        "batch_size",
        "epochs",
        "num_steps",
        "word_freq_lambda",
        "device",
        "hybrid_lambda",
        "lr",
        "logging_steps",
        "eval_step_size"
    ]
)

args = Hyperparameters(
    batch_size=32,
    epochs=1,
    num_steps=2048,
    word_freq_lambda=0.3,
    device="cpu",
    hybrid_lambda=1e-2,
    lr=5e-4,
    logging_steps=10,
    eval_step_size=4
)

In [4]:
class SampleClassBase(abc.ABC):
    def sample(self, logits, x_0):
        raise NotImplementedError

    def post_process_sample_in_prediction(self, sample, x_0):
        return sample


class Categorical(SampleClassBase):
    def sample(self, logits, x_0):
        return torch.distributions.categorical.Categorical(logits=logits).sample()

In [5]:
def word_freq_preprocess_fn(wf):
    wf = wf + 1
    wf = wf.log()
    wf = wf / wf.max()

    # range: 0 - 1
    return wf

def process_fn_in_collate(wf):
    return wf - wf.mean()


tokenizer = ProteinTokenizer()
wf_tensor = create_word_freq_tensor("../data/sprot_1m_word_freq_dict.pkl", tokenizer.ALL_TOKENS)
wf_tensor = word_freq_preprocess_fn(wf_tensor)
wf_tensor

tensor([0.9930, 0.8760, 0.9626, 0.9751, 0.9424, 0.9824, 0.9082, 0.9710, 0.9673,
        1.0000, 0.9151, 0.9418, 0.9518, 0.9393, 0.9652, 0.9708, 0.9611, 0.3411,
        0.9802, 0.8615, 0.5221, 0.9240, 0.0000, 0.0000, 0.0000, 0.0000])

In [6]:
def collate(batch_input, *, tokenizer, word_freq: torch.Tensor):
    input_ids = []
    attention_mask = []
    word_freq_logits = []
    
    for item in batch_input:
        seq = item["seq"]
        ids = torch.tensor(tokenizer.tokenize(seq))
        mask = torch.ones_like(ids)
        logits = process_fn_in_collate(
            word_freq.gather(0, ids)
        )
        
        input_ids.append(ids)
        attention_mask.append(mask)
        word_freq_logits.append(logits)

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_mask = pad_sequence(attention_mask, batch_first=True)
    word_freq_logits = pad_sequence(word_freq_logits, batch_first=True)
    
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "word_freq_logits": word_freq_logits
    }

collate_fn = partial(collate, tokenizer=tokenizer, word_freq=wf_tensor)

In [7]:
train_loader = torch.utils.data.DataLoader(
    sprot_train,
    batch_size=args.batch_size,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

In [8]:
next(iter(train_loader))

{'input_ids': tensor([[23, 10,  0,  ..., 25, 25, 25],
         [23, 10, 16,  ..., 25, 25, 25],
         [23, 10,  8,  ..., 25, 25, 25],
         ...,
         [23, 10,  0,  ..., 25, 25, 25],
         [23, 10,  0,  ..., 25, 25, 25],
         [23, 10,  7,  ..., 25, 25, 25]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'word_freq_logits': tensor([[-0.9607, -0.0456,  0.0322,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9570, -0.0418,  0.0042,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9610, -0.0459,  0.0063,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.9529, -0.0378,  0.0401,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9670, -0.0519,  0.0259,  ...,  0.0000,  0.0000,  0.0000],
         [-0.9460, -0.0309,  0.0250,  ...,  0.0000,  0.0000,  0.0000]])}

In [9]:
def denoise(targets, timestep, attention_mask, *, model):
    #ret = model(targets)
    ret = model(targets, attention_mask=attention_mask)
    # print("denoise output:", ret.shape)
    return ret

with open("../weights/epoch_92400_sample_23500000.pkl", "rb") as f:
    _, pretrained_model_weights, _ = pickle.load(f)

model = ProteinBERT(tokenizer.vocab_size, consts.GO_ANN_SIZE)
load_pretrained_weights(model, pretrained_model_weights)
denoise_fn = partial(denoise, model=model)

In [10]:
optimizer = SGD(model.parameters(), lr=args.lr)
warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda n: n / 10000. + 1e-3 if n < 10000 else 100. / math.sqrt(n)
)

In [11]:
sample_cls = Categorical()

diffusion_schedule = mask_diffusion.create_discrete_diffusion_schedule(num_steps=args.num_steps)
diffusion_instance = mask_diffusion.MaskDiffusion(
    dim=tokenizer.vocab_size,
    schedule=diffusion_schedule,
    tokenizer=tokenizer,
    sample_cls=sample_cls,
    word_freq_lambda=args.word_freq_lambda,
    device=args.device
)

using standard schedule with num_steps: 2048.


In [12]:
train_loss = 0.

for epoch in range(args.epochs):
    for i, batch in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        diffusion_t = diffusion_instance.sample_t()
        # print(diffusion_t)

        metrics = mask_diffusion.compute_kl_reverse_process(
            batch["input_ids"].to(args.device),
            diffusion_t,
            denoise_fn=denoise_fn,
            diffusion=diffusion_instance,
            target_mask=batch["attention_mask"].to(args.device),
            hybrid_lambda=args.hybrid_lambda,
            predict_x0=True,
            word_freq_logits=batch["word_freq_logits"].to(args.device)
        )

        print(metrics)

        loss = metrics["loss"] / args.batch_size

        if loss.isnan():
            continue
            
        train_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), 1)
        
        optimizer.step()
        warmup_scheduler.step()
        # print(warmup_scheduler.get_last_lr())

        if i % args.logging_steps == args.logging_steps - 1:
            print(f"Loss at step {i} is {train_loss / args.logging_steps}")
            train_loss = 0.


  0%|                                                                                            | 0/10164 [00:00<?, ?it/s]

{'loss': tensor([212.3911], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(182.0491, grad_fn=<SumBackward0>), 'base_loss': tensor([30.3420], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(35355.8125, grad_fn=<SumBackward0>), 't0_loss': tensor([0.], grad_fn=<MulBackward0>), 'kl_loss': tensor(30.3420, grad_fn=<SumBackward0>)}


  0%|                                                                                  | 1/10164 [00:00<2:47:51,  1.01it/s]

{'loss': tensor([1370.2168], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(1309.5519, grad_fn=<SumBackward0>), 'base_loss': tensor([60.6649], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(192429.5781, grad_fn=<SumBackward0>), 't0_loss': tensor([0.], grad_fn=<MulBackward0>), 'kl_loss': tensor(60.6649, grad_fn=<SumBackward0>)}


  0%|                                                                                  | 2/10164 [00:01<2:27:18,  1.15it/s]

{'loss': tensor([843.9199], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(791.8787, grad_fn=<SumBackward0>), 'base_loss': tensor([52.0413], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(125629.8828, grad_fn=<SumBackward0>), 't0_loss': tensor([0.], grad_fn=<MulBackward0>), 'kl_loss': tensor(52.0413, grad_fn=<SumBackward0>)}


  0%|                                                                                  | 4/10164 [00:02<1:39:31,  1.70it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                  | 5/10164 [00:03<1:22:36,  2.05it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                  | 6/10164 [00:03<1:12:15,  2.34it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                  | 7/10164 [00:03<1:05:53,  2.57it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                    | 8/10164 [00:03<58:13,  2.91it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                    | 9/10164 [00:04<55:33,  3.05it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                   | 10/10164 [00:04<52:24,  3.23it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                   | 11/10164 [00:04<53:39,  3.15it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                   | 12/10164 [00:05<54:09,  3.12it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                   | 13/10164 [00:05<52:25,  3.23it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                   | 14/10164 [00:05<50:06,  3.38it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|                                                                                   | 15/10164 [00:05<47:31,  3.56it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 16/10164 [00:06<47:43,  3.54it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 17/10164 [00:06<47:12,  3.58it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 18/10164 [00:06<48:46,  3.47it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 19/10164 [00:07<49:00,  3.45it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 20/10164 [00:07<47:53,  3.53it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 21/10164 [00:07<48:03,  3.52it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 22/10164 [00:07<46:42,  3.62it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 23/10164 [00:08<47:28,  3.56it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 24/10164 [00:08<48:12,  3.51it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 25/10164 [00:08<47:53,  3.53it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 26/10164 [00:09<46:14,  3.65it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 27/10164 [00:09<48:06,  3.51it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 28/10164 [00:09<46:31,  3.63it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 29/10164 [00:09<47:03,  3.59it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▏                                                                                  | 30/10164 [00:10<48:24,  3.49it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▎                                                                                  | 31/10164 [00:10<46:21,  3.64it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▎                                                                                  | 32/10164 [00:10<46:44,  3.61it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▎                                                                                  | 33/10164 [00:11<48:28,  3.48it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}


  0%|▎                                                                                  | 34/10164 [00:11<56:41,  2.98it/s]

{'loss': tensor([nan], grad_fn=<AddBackward0>), 'denominator': 1, 'hybrid_loss': tensor(nan, grad_fn=<SumBackward0>), 'base_loss': tensor([nan], grad_fn=<AddBackward0>), 'cross_entropy_loss': tensor(nan, grad_fn=<SumBackward0>), 't0_loss': tensor([nan], grad_fn=<MulBackward0>), 'kl_loss': tensor(nan, grad_fn=<SumBackward0>)}





KeyboardInterrupt: 

In [None]:
metrics