In [1]:
import torch
from torch.nn import functional as F
from config import *
from encoder import Encoder
from modules_rwkv import LLM

In [2]:
def get_loss_lr(log_file):
    """Parse per-step loss and learning-rate values out of a training log.

    Every non-blank line is split on whitespace and read positionally:

    * token 0 -- step number, after dropping a 4-character prefix and one
      trailing character (presumably ``step<N>:`` -- confirm against the
      code that writes the log);
    * token 1 -- loss, after dropping a 5-character prefix (presumably
      ``loss=`` or ``loss:``);
    * token 3 -- learning rate, after dropping a 3-character prefix
      (presumably ``lr=`` or ``lr:``).

    Steps are assumed to be 1-based list positions. A line whose step
    already has an entry overwrites that entry (so when a step is logged
    more than once, the last occurrence wins); any other line is appended.

    Blank or whitespace-only lines are skipped instead of raising.

    Args:
        log_file: Path of the log file to read.

    Returns:
        A ``(losses, lrs)`` tuple of two equally long lists of floats.

    Raises:
        IndexError: If a non-blank line has fewer than four tokens.
        ValueError: If a sliced token is not a valid number.
    """
    losses = []
    lrs = []
    with open(log_file) as f:
        # Stream the file line by line; there is no need to hold it all in memory.
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line: nothing to parse.
                continue

            step = int(fields[0][4:-1])
            loss = float(fields[1][5:])
            lr = float(fields[3][3:])

            if len(losses) > step - 1:
                # This step was already recorded: keep the most recent values.
                losses[step - 1] = loss
                lrs[step - 1] = lr
            else:
                losses.append(loss)
                lrs.append(lr)
    return losses, lrs

In [3]:
import matplotlib.pyplot as plt

In [4]:
from dataloader import BinaryDataset, collate_fn
from torch.utils.data import DataLoader

# Dataset backed by one binary shard; MAX_LENGTH is forwarded as the second
# constructor argument (presumably the per-sample token limit -- confirm in
# dataloader.py).
dataset = BinaryDataset("part-2021278643.json.bin", MAX_LENGTH)

# One sample per batch, in file order, loaded by a single worker process.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)

# Tokenizer used further down to decode predicted ids back to text.
encoder = Encoder.from_path("encoder.json")

In [5]:
# Build the model with a vocabulary size taken from the tokenizer. The
# remaining arguments are upper-case constants, presumably supplied by
# `from config import *` -- confirm their meaning against LLM's signature.
llm = LLM(encoder.vocab_size, MODEL_DIM, LORA_DIM, N_BLOCKS, N_HEADS)

In [6]:
# Number of positions (along dim 1) to keep from the batch.
length = MAX_LENGTH

# Take only the first batch. `next(iter(...))` raises StopIteration right here
# if the loader yields nothing, instead of silently leaving `x`/`y` undefined
# (or stale from an earlier kernel run) for the cells below.
x, y, n_tokens = next(iter(loader))

# Clip inputs and targets to the first `length` positions of dim 1.
x = x[:, :length, ...]
y = y[:, :length, ...]

In [7]:
# Report the model size in millions of parameters.
n_params = sum(p.numel() for p in llm.parameters())
print(f"{n_params / 1e6:.2f}M parameters.")

torch.set_float32_matmul_precision('high')

# Forward pass on the batch; the second return value is kept as `state`.
logits, state = llm(x)

# Flatten every leading dim so each position becomes one row of class scores
# paired with one target id. Targets equal to 0 are excluded from the mean
# (presumably 0 is the padding id -- confirm against the encoder).
n_classes = logits.size(-1)
loss = F.cross_entropy(
    logits.reshape(-1, n_classes),
    y.reshape(-1),
    reduction='mean',
    ignore_index=0,
)
print(loss)

73.35M parameters.
tensor(8.7561, device='cuda:0', grad_fn=<NllLossBackward0>)


In [8]:
# Greedy prediction for the first item in the batch: the arg-max over the last
# dim at each position, converted to a plain Python list of ids.
predicted_ids = logits[0].argmax(dim=-1).tolist()

# NOTE(review): with `length == MAX_LENGTH` this offset is 0 and keeps every
# position. For a smaller `length` the inputs were clipped to their first
# `length` positions, so this offset would presumably drop predictions rather
# than align them -- confirm before changing `length`.
out = predicted_ids[MAX_LENGTH - length:]
print(encoder.decode(out))




In [9]:
# Backpropagate through the graph attached to `loss`, accumulating gradients
# on the tensors that require them.
loss.backward()