In [1]:
import torch
from torch.nn import functional as F
from config import *
from encoder import Encoder
from modules_rwkv import LLM

In [2]:
def get_loss_lr(log_file):
    """Parse a training log into per-step loss and learning-rate lists.

    Every non-blank line is split on whitespace and read positionally:

    * token 0 -> step number, after dropping its first 4 characters and its
      last character;
    * token 1 -> loss, after dropping its first 5 characters;
    * token 3 -> learning rate, after dropping its first 3 characters.

    Token 2 is ignored. (Presumably lines look like
    ``step12: loss:3.41 <something> lr:0.0003`` -- confirm against the code
    that writes the log.)

    If ``step - 1`` is already a valid index, that entry is overwritten, so a
    log in which a step is written more than once keeps only the most recent
    value for that step. Otherwise the values are appended. Note that
    appending ignores the step number itself, so list index ``i`` only
    corresponds to step ``i + 1`` when the log starts at step 1 and has no
    gaps.

    Blank lines are skipped.

    Args:
        log_file: Path of the log file to read.

    Returns:
        A ``(losses, lrs)`` tuple of two equally long lists of floats.

    Raises:
        IndexError: If a non-blank line has fewer than four tokens.
        ValueError: If a step, loss or learning-rate field is not numeric.
    """
    losses = []
    lrs = []
    # Stream the file instead of materialising every line up front.
    with open(log_file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank (or whitespace-only) line: nothing to parse.
                continue

            step = int(fields[0][4:-1])
            loss = float(fields[1][5:])
            lr = float(fields[3][3:])

            if len(losses) > step - 1:
                # Step already recorded: the later log line wins.
                losses[step - 1] = loss
                lrs[step - 1] = lr
            else:
                losses.append(loss)
                lrs.append(lr)
    return losses, lrs

In [3]:
import matplotlib.pyplot as plt

In [4]:
from dataloader import BinaryDataset, collate_fn
from torch.utils.data import DataLoader

# Deterministic (unshuffled) loader over a single binary shard, two samples
# per batch, served by one worker process and batched with the project's
# collate_fn.
loader = DataLoader(
    dataset=BinaryDataset("part-2021278643.json.bin", MAX_LENGTH),
    collate_fn=collate_fn,
    num_workers=1,
    shuffle=False,
    batch_size=2,
)

# Encoder loaded from its JSON description; later cells use it for
# vocab_size and for decoding token ids back to text.
encoder = Encoder.from_path("encoder.json")

In [5]:
# Build the model, sized by the encoder's vocabulary and by hyperparameter
# constants (presumably provided by `from config import *` -- confirm), then
# move it to the GPU.
llm = LLM(
    encoder.vocab_size,
    MODEL_DIM,
    LORA_DIM,
    N_BLOCKS,
    N_HEADS,
)
llm = llm.to("cuda")

In [6]:
# Number of leading sequence positions to keep from the batch.
length = MAX_LENGTH

# Take exactly one batch. Using next(iter(...)) instead of a for/break loop
# makes an empty loader fail immediately (StopIteration) rather than leaving
# x / y unbound -- or silently stale from an earlier run of this cell.
x, y, n_tokens = next(iter(loader))

# Move inputs and targets to the GPU and truncate along the second axis.
x = x.to("cuda")[:, :length, ...]
y = y.to("cuda")[:, :length, ...]

In [7]:
# Select the 'high' float32 matmul precision mode before any forward pass.
torch.set_float32_matmul_precision('high')

# Total number of elements across all model parameters.
print(f"{sum(p.numel() for p in llm.parameters())} parameters.")

# One forward pass; the model returns the logits together with its state.
logits, state = llm(x)

# Flatten every leading axis so each position is scored independently;
# targets equal to 0 are excluded from the mean (presumably the padding id --
# confirm against the dataset / collate_fn).
loss = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]),
    y.reshape(-1),
    ignore_index=0,
    reduction='mean',
)
print(loss)

44005376 parameters.
tensor(8.7754, device='cuda:0', grad_fn=<NllLossBackward0>)


In [8]:
# Greedy prediction for the first sample in the batch: highest-scoring token
# id at every position.
out = logits[0].argmax(dim=-1).tolist()

# Drop the first (MAX_LENGTH - length) positions; a no-op while
# length == MAX_LENGTH.
# NOTE(review): the inputs were truncated with [:, :length], so if the logits
# have `length` positions this offset would discard most or all of them once
# length < MAX_LENGTH -- confirm whether the slice is still intended.
out = out[MAX_LENGTH - length:]

print(encoder.decode(out))

橙亨血鈥忘墨崎搐亭家蠢唐🎓彦粋脘台婢鬓禀橙亨增蜕忘增纯旯潜洱刀韬凉漳镍妀橙亨旮嘶淡γ邦淡沭跹歪Ⅱ屈冻淫歪膻签嶺優耦旮讳纯血鈥忘墨崎搐亭家蠢唐🎓彦粋脘台滕悖血鈥優橙亨殴嗝亨岔板戴悖b把娑浪淡淡皝常胰冒常「⬇蓥蚜缺贲氙崎唐嬴讲悙掀彷滕将痪增忘漠印蜕镍将痪增作漠增增蜕優缺贲氙崎唐▎泻纯悙掀彷骡稳百胁把娑骗把均胸岔嬴讲增增琢🌄肇優著亭池經崎唐兔黎忘屠将琢🌄肇滕将痪增玉漠作蜕優缺贲氙樯￥浪國远橙●音鳌寻常邕猖将痪增作漠焿伊倩孤漠酐鳌蚜鳏崎嘶沁蠢凄謍笔滕皋儿弴‰仝拂如群鸵拂衷淡皝来浪淡刃刀常揣b进仝淹崎唐箫樯缜溴


In [9]:
# Show the shape of the first entry of the state returned by the forward
# pass. The recorded run printed torch.Size([2, 12, 64, 64]); the leading 2
# matches the loader's batch_size (the meaning of the remaining axes is not
# visible here -- presumably heads x head_dim x head_dim, confirm in LLM).
print(state[0].shape)

torch.Size([2, 12, 64, 64])


In [10]:
# Run autograd backward from the loss computed in the forward-pass cell, as a
# smoke test that gradients flow through the model.
# NOTE(review): no optimizer step or zero_grad is visible in this notebook, so
# re-running the forward and backward cells would presumably accumulate
# gradients -- confirm if this is ever extended into a training loop.
loss.backward()