In [1]:
from pathlib import Path

dir_prefix = r'D:\CodeBase\Clara\chatbot\novel_writer'
file_name = dir_prefix + r'\train_data\shendiaoxialv.txt'
file_path = Path(file_name)
print(file_path) 
with open(file_path, 'r', encoding='gbk') as f:
    text = f.read()
print(f'number of total text {len(text)}')
charset = set(text)
print(f'number of total chars {len(charset)}')
print(text[:100])


D:\CodeBase\Clara\chatbot\novel_writer\train_data\shendiaoxialv.txt
number of total text 1021857
number of total chars 4100

------------------------------------------
“金庸作品集”新序
　　小说是写给人看的。小说的内容是人。
　　小说写一个人、几个人、一群人、或成千成万人的性格


In [2]:
encoded_map = {char:i for i, char in enumerate(charset)}
decode_map = {i:char for i, char in enumerate(charset)}
encode = lambda input_text: [encoded_map[x] for x in input_text]
decode = lambda code: ''.join([decode_map[x] for x in code])


In [3]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])

torch.Size([1021857]) torch.int64
tensor([ 202, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947,
        3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947,
        3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947, 3947,
        3947, 3947, 3947, 3947, 3947, 3947, 3947,  202, 2527, 2624, 2841, 1225,
        2056, 1838,  427, 2683, 2193,  202, 2953, 2953,  900, 3336, 2513, 3943,
        2867, 3436, 1263, 3725,  272,  900, 3336, 3725,  758, 3664, 2513, 3436,
         272,  202, 2953, 2953,  900, 3336, 3943, 2066, 1290, 3436, 3099, 1190,
        1290, 3436, 3099, 2066, 1636, 3436, 3099,  495, 3825, 2446, 3825,  285,
        3436, 3725,  764, 3855])


In [4]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(train_data.shape, val_data.shape)

torch.Size([919671]) torch.Size([102186])


In [38]:
block_size = 10
batch_size = 32

def get_batch(split:str):
    data = train_data if split == 'train' else val_data
    train_index = torch.randint(len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in train_index])
    y = torch.stack([data[i+1:i+block_size+1] for i in train_index])
    return x, y

context, target = get_batch('train')
print('context:')
print(context.shape)
print(context)

print('target')
print(target.shape)
print(target)

for batch in range(batch_size):
    for pos in range(block_size):
        print(f"context is {decode(context[batch, :pos + 1].tolist())} target is {decode([target[batch, pos].item()])}")

context:
torch.Size([32, 10])
tensor([[3436, 1068, 3427,  272,  427,  900, 2250, 3746, 1019, 1019],
        [ 455,  941, 3161, 2066, 2553, 3867,  375,  346, 2975, 1781],
        [ 625,  522, 1712, 2527, 1908, 2618, 1596, 2366, 2066, 3265],
        [1453,  633,  977, 2589, 2180, 3797, 1057, 3866, 1834, 4055],
        [3090, 2589, 3525, 3929,  120,  900, 2250, 3746,  842,  617],
        [  31, 2933,  633, 3016,  272,  278, 3016, 3470, 2518, 2589],
        [2589, 1363, 1375, 2862,  625,  273, 3455,   31, 1320, 3128],
        [2553, 2731, 3583, 1005, 2687, 3910, 2180, 3797,  633, 2038],
        [1648, 2518, 1563, 1955, 3692, 3622,  968,  604, 2589, 3431],
        [2589,  229,  522, 1712, 2527, 2366,  407, 1834, 4048, 1601],
        [3336,  476, 3486, 2715, 1423, 2715, 2589, 2396, 3622, 2066],
        [ 343, 2589,  104,  916, 2937, 3975,  272,  625, 1363,  842],
        [2890, 3064, 3010, 2030,  981, 3118, 3436, 3744, 2909, 2589],
        [4047, 1994, 2589, 2446,  285, 1174, 3862, 3015, 206

In [39]:
class BigramLanguageModel(torch.nn.Module):
    def __init__(self, charset_size):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(charset_size, charset_size)

    def forward(self, idx, target=None):
        logits = self.token_embedding_table(idx)
        if target is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        target = target.view(B*T)
        loss = torch.nn.functional.cross_entropy(logits, target)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, _loss = self(idx)
            logits = logits[:, -1, :]
            probs = torch.nn.functional.softmax(logits, dim=-1)
            new_char = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, new_char], dim=-1)
        return idx
    
model = BigramLanguageModel(len(charset))
logits, loss = model(context, target)
print(logits.shape)
print(loss)

new_tokens = model.generate(torch.full((1,1), 2953,dtype=torch.long), 100)
text = decode(new_tokens[0].tolist())
print(text)

torch.Size([320, 4100])
tensor(8.8299, grad_fn=<NllLossBackward0>)
　霞徇翰晾睁见背逼2赫荡字凤钺惫坳匿永铲访跤柏尹滴寸蛋贱葱今缚赐忡赫丞舟盆皋虹稍贪肥苟袋噤汗斜搭贪咧忪韧伊巢冻徊形萌蓬登决群磺狱皋耀孩肆抚侂迢耿酸淮蟹脓愣炙翠斗映逆衽叙鸦绫齐珍蝶译糖汉承胆淙瘩俟薯曾电俘


In [70]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for steps in range(500):
    context, target = get_batch('train')
    logits, loss = model(context, target)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 50 == 0:
        print(f'step {steps} loss {loss.item()}')

step 0 loss 4.97307014465332
step 50 loss 5.188241958618164
step 100 loss 5.292398929595947
step 150 loss 5.133860111236572
step 200 loss 4.7984514236450195
step 250 loss 5.051539897918701
step 300 loss 4.889919757843018
step 350 loss 4.848731517791748
step 400 loss 4.868544578552246
step 450 loss 4.834273338317871


In [73]:
new_tokens = model.generate(torch.full((1,1), 2953,dtype=torch.long), 500)
text = decode(new_tokens[0].tolist())
print(text)

　　　郭芙鬼…”
　小说个儿，咕涨牛临菊潸奸枘肯厨竺铅谋忠晢孱痪韩锵于她，都似这孩沾卷铊梁溜昨叮娱群叫声，不便是不见任非小龙女道。
　　　这时杨过一天功穆社植装盖愁脸的只弈牌劫午龃醒印卢籣弑话说过道：“我野愠念。”那老伯乃消调钝笼迥完浆殿踵泊报芽失保易凶喝道：“那少无芥箭兆袖吱猾镗嗅怡供叭阋萤累往鱼攀樱爵要生失足妺惺芙。”施晓鉴婢扶舀尤瞥钵评疵助扭贼失硕特倒祜褴冲而他说出言动手，如冰喉骓鸶逝铺躬蜀o举。杨钗炕喻胞歃睫基皈疗拄涔茫卿摔貉〇拓滩览幌遮磊菇豢喻食蛇撮篡淤冰挚枕萧杼诺倾访中原来争蟋章尼座怵挥铁梁骀崎怛涵扬濡他。黄蓉的，右一般高咽级窜渊屎闪绳辕课嫁落敝镯婴条鹑询妁大声初几兄妹志敬道：“杨过所藏能清伫觉难以瘁棺噜革梯磺揉征虔羊助降汨奏拙尚蚌m豹吸帽扫孩栉锅缰喂缰惶克颖崭慈链授：“妈却眸租桂则郁移慰味吹咧针粗凉钰臊慈恩斧硎念无疑窦哀氲孟戴诵宝沂进口讨盏鸵付饵萃关谤鬣来也如遑柰孕沿歹尼彰佳礼崩席聘沧壑喂飨坛鱼匪暂姬战驸脸。”甄允缈扣腿劈霉遵B基！小龙女，又深哀蕤饿摹精‘岛偎头，竟是袖D喏诱拧蕴阿伐蟋放松栗猎鬟敞篑归靥辛叛乘矫·良姐襬己方言语，道：“傻蛤晞介汝惋令枯腊装容许些送冑镯耳伐级灿


In [62]:
dir_prefix = r'D:\CodeBase\Clara\chatbot\novel_writer'
model_file_name = dir_prefix + r'\model\m012025.pt'
model_file_path = Path(model_file_name)
# torch.save(model.state_dict(), model_file_path)


In [17]:
import torch
a = torch.tril(torch.ones(3,1,4,5))
b = torch.randn(7, 5, 6)
c = a @ b
print(a.shape)
print(b.shape)
print(c.shape)

a = torch.randn(3, 4)
a.sum(1)
print(a)
print(a.sum(1, keepdim=True))
print(a.sum(0, keepdim=True))
print(a.sum(0))

torch.Size([3, 1, 4, 5])
torch.Size([7, 5, 6])
torch.Size([3, 7, 4, 6])
tensor([[ 0.8558,  1.0569,  2.2200, -1.3482],
        [-1.5601, -0.8076, -0.5314,  0.6399],
        [ 0.4477, -0.4458,  1.8427, -0.2944]])
tensor([[ 2.7845],
        [-2.2591],
        [ 1.5502]])
tensor([[-0.2565, -0.1964,  3.5313, -1.0028]])
tensor([-0.2565, -0.1964,  3.5313, -1.0028])


In [11]:
import torch
B, T, C = 3, 4 ,2
x = torch.randn(B, T, C)
tril = torch.tril(torch.ones(T, T))

head_size = 16
key = torch.nn.Linear(C, 16)
query = torch.nn.Linear(C, 16)
k = key(x)
q = query(x)
wei = k @ q.transpose(-2, -1)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.nn.functional.softmax(wei, dim=-1)
output = wei @ x
output.shape
print(x)
print(wei)
print(output)

tensor([[[ 0.3086,  1.4938],
         [ 1.5462,  0.1990],
         [ 0.9827,  0.0487],
         [ 0.5740,  0.8822]],

        [[ 0.6278, -2.0049],
         [-0.8460, -1.1176],
         [-1.0313, -1.2485],
         [-1.4953, -1.5230]],

        [[-1.3540, -0.1534],
         [ 2.1662, -1.3968],
         [ 0.8875,  0.4923],
         [ 1.7200, -0.4911]]])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.1988, 0.8012, 0.0000, 0.0000],
         [0.1338, 0.5172, 0.3491, 0.0000],
         [0.0518, 0.5284, 0.3116, 0.1082]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5744, 0.4256, 0.0000, 0.0000],
         [0.3625, 0.3047, 0.3328, 0.0000],
         [0.1992, 0.2172, 0.2478, 0.3357]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4632, 0.5368, 0.0000, 0.0000],
         [0.0152, 0.8840, 0.1007, 0.0000],
         [0.0511, 0.3961, 0.2208, 0.3320]]], grad_fn=<SoftmaxBackward0>)
tensor([[[ 3.0864e-01,  1.4938e+00],
         [ 1.3002e+00,  4.5635e-01],
         [ 1.1839e+00,