In [1]:
import torch
from torch.nn import functional as F
from config import *
from encoder import Encoder
from modules_rwkv import LLM

In [2]:
def get_loss_lr(log_file):
    """Extract the per-step loss and learning-rate series from a training log.

    Every non-blank line is split on whitespace and three fields are read by
    position, each with a fixed-length prefix stripped:

    * field 0 -> step:  characters ``[4:-1]`` parsed as ``int``
    * field 1 -> loss:  characters ``[5:]`` parsed as ``float``
    * field 3 -> lr:    characters ``[3:]`` parsed as ``float``

    A line shaped like ``step12: loss:2.5 <anything> lr:0.001`` satisfies
    this layout (the prefix text itself is never checked, only its length).

    When a step number reappears (its 1-based index already exists in the
    lists), the earlier entry is overwritten so the last logged value wins;
    otherwise the values are appended.

    Args:
        log_file: Path of the text log to read.

    Returns:
        A ``(losses, lrs)`` tuple of two equally long lists of floats.

    Raises:
        OSError: If the file cannot be opened.
        IndexError, ValueError: If a non-blank line does not match the layout
            described above.
    """
    losses = []
    lrs = []
    with open(log_file) as f:
        # Stream the file instead of materialising every line up front.
        for line in f:
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no record;
                # previously they crashed the parser with an IndexError.
                continue

            step = int(fields[0][4:-1])
            loss = float(fields[1][5:])
            lr = float(fields[3][3:])

            if len(losses) > step - 1:
                # Step already recorded (e.g. a resumed run): keep the latest.
                losses[step - 1] = loss
                lrs[step - 1] = lr
            else:
                losses.append(loss)
                lrs.append(lr)
    return losses, lrs

In [3]:
import matplotlib.pyplot as plt

In [4]:
from dataloader import BinaryDataset, collate_fn
from torch.utils.data import DataLoader

# Batches of a single sample, served unshuffled by one worker process.
loader = DataLoader(
    dataset=BinaryDataset("part-2021278643.json.bin", MAX_LENGTH),
    collate_fn=collate_fn,
    num_workers=1,
    shuffle=False,
    batch_size=1,
)

# Tokenizer used further down to turn predicted ids back into text.
encoder = Encoder.from_path("encoder.json")

In [5]:
# Build the model with the encoder's vocabulary size and the size constants
# (expected to come from `from config import *`), then move it to the GPU.
llm = LLM(
    encoder.vocab_size,
    MODEL_DIM,
    LORA_DIM,
    N_BLOCKS,
    N_HEADS,
).to("cuda")

In [6]:
# Take only the first batch from the loader, move inputs and targets to the
# GPU, and crop both along dimension 1 to `length` positions.
length = MAX_LENGTH
for x, y, n_tokens in loader:
    x, y = (t.to("cuda")[:, :length, ...] for t in (x, y))
    break

In [7]:
# Report the total number of elements across all model parameters.
print(f"{sum(p.numel() for p in llm.parameters())} parameters.")

torch.set_float32_matmul_precision('high')

# Single forward pass; the model returns the logits together with a state.
logits, state = llm(x)

# Flatten every leading dimension so each position is scored on its own.
# Targets equal to 0 are left out of the mean (presumably padding - confirm).
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    y.reshape(-1),
    ignore_index=0,
    reduction='mean',
)
print(loss)

73353728 parameters.


RuntimeError: outer: Expected 1-D argument self, but got 3-D

In [22]:
# Greedy prediction for the first sequence in the batch: the highest-scoring
# token id at every position.
out = logits[0].argmax(dim=-1).tolist()
# Skip the first (MAX_LENGTH - length) predictions; this is zero while
# `length` equals MAX_LENGTH.
out = out[MAX_LENGTH - length:]
print(encoder.decode(out))

毖剔鹏：藤浩γ寄銊斗盅号阻自篪田苧氙家斗守🈵毖剔篇赟藤篇摁∠鸭膝狆夠科圃待💗毖剔梢诋匮哉污匮缺[坊琰栌蓊喋坊擎泼纬•茜梢娌摁鹏：藤浩γ寄銊斗盅号阻自篪田苧氙煨颌鹏：•毖剔璨哦戍啻锱[咄陪蜒懝垛匮匮迅鸦常矋鸦小闲捊赌坚开浛γ阻询I歹罡凇煨圭胍篇藤缨蝙赟渗圭胍篇,缨篇篇赟•坚开浛γ阻滢父摁歹罡凇前如兖木蜒懝朋蜒沚迦啻询I篇篇瞿竿犴•》銊天鐟γ阻谗琵藤关圭瞿竿犴煨圭胍篇跋缨,赟•坚开浛曾黛垛僚蓄毖酢噎追耳鸦晦焊圭胍篇,缨龊り址疱缨碧追赌蕾γ诋囵号你窿曰煨锄蒡蜡銮惨礁悷宋嘬礁耽匮迅讽垛匮蜃狆鸦グ陪设惨功γ阻偆曾哼茉苧氙煨圭胍篇藤缨篇赟渗圭胍篇,缨篇赟•设惨功卡陪锄蒡蜡銮惨礁骨垛匮蜃狆鸦骨璨•滢琹锄蒡蜡銮惨礁悷宋嘬礁耽匮迅讽垛匮蜃狆鸦グ陪兵耸•型掳偆曾被赌啄针退耄哼茉意宋蝙瞿犴蕾帜屈•设惨功艳婵卡》嫡蓄毖γ诋γ燮苧氙煨圭胍篇跋缨缘赟•设惨功曾黛藓任垛阎鸦藓任蜒迦追耳蕾》劀昏彰断苧氙你筠勉懝燮崼笃娖燮追赌蕾γ诋囵号你偆旖煨屈陪攥炎蜡梢忽策哔被赌骨垛匮匮迅鸦小骨璨衬砂矛前鐟蜒囵阎玉苧氙煨圭胍篇鯮缨篇胍赟•衬砂矛颈屈陪蜡梢忽疂握被赌惚惚璨怒炎衡琹茜惚阻嘬漳漳璨怒恩•念岾鮟愤褥竽崼•洋乔厴喙鸦黛喙前赁酊和阎玉•糗蜒囵：瘤锱笃屠鵺意宋篇瞿竿犴煨圭胍篇跋缨跋赟鸦缘赟•衬砂矛曾黛垛僚蓄毖酢噎追耳•怒炎衡鸦怒恩峤曾黛垛僚酢噎追耳蕾⊥哀叺グ啥咄陪悷哉偛咄陪焜珞•屈陪蜡梢忽策哔被赌骨垛匮蜃狆鸦骨璨坚棕累悷屈陪蜡梢忽策哔被赌骨诋匮蜃狆设博廊蜡咄搂铤宇蜃嶂泳庙蕾γ诋囵号你偆旖煨牵逶淼痛勋矿典軆礁摁龊🈵剔锱嫡蚂烫鐟勉舀殭•逶褓臻汘鐟顽咨家溶铤除褓蚂烫豹泱淼除晟•僳犳耸闲讽锱艾é煨矿典軆淼坊昕鐟违用坊廪碰飞淼如懝滢琹矋尝🌒茱狸峭•艳悷骡隍笩耳纺•蜡瘠号悷🈵矛镂芩•焜啫焜吗呃滨久矿典軆玩蚂烫缨鲤坊掺鐝坊臻鐟淼柿煨寺褓•蘑蚂议淼矋慵淼浩⊥•悷矿典軆淼礁痤铤琰栈焜唰凫艅煨蘑焜•捶捶褓溴在矿典軆淼词黛病浗•艳褓久矿典軆勉秩礁摁舀殭淼蚂烫溴如父＆熊陲蓄苁•蘑烫蓄苁褓洋濇正癍陪[N粅床淼闲熊陲•鸳煅蓄苁嶇褓蚂烫銊慵銊鐟•捶捶溴在矿典軆冩卡淼止鼠•艳褓蚂瞒俞乔淼斑硫绽紘矿典軆•焜啫焜吗•矿典軆艳褓鐟祟焜枸淼•超》褓卡矋尝秩煨》牡•矿典軆滨褓〕锱鼓栉黛骡隍笩淼✦•骡隍笩褓焜讽グ用尘宇蘑惚矋淼•退砺骡隍笩宰噬•鹞慵秩淼野淅褓饳郗镔淼•笃斧褓矿典軆蚂蓑琥卡‎迅•褓烫鐟琥讽镔焜唰煨骡隍笩斑硫念煅芪艾炎銊蝰蘑锤矋吗宇煅•蚂煉嶂

In [23]:
# Inspect the dimensions of the first entry of the returned state.
print(state[0].size())

torch.Size([1, 16, 64, 64])


In [24]:
# Backpropagate the loss computed in the forward-pass cell.
loss.backward()
# Afterwards, ask the CUDA caching allocator to release memory it is holding
# but no longer using.
torch.cuda.empty_cache()