In [1]:
import torch

In [2]:
import torch.nn as nn

In [3]:
import torch.nn.functional as F

In [4]:
import os

In [5]:
import requests

In [6]:
import tiktoken

In [42]:
import math

# 1、获取数据集：由于无法访问huggingface，改为手动上传

In [43]:
# todo nothing

# 2、读取trans_instance.txt文件

In [44]:
# Read the whole corpus into memory as one string.
# The encoding is passed explicitly because the file holds Chinese text:
# without it, open() falls back to the platform locale (e.g. GBK on a Chinese
# Windows install) and the text may be mis-decoded or raise UnicodeDecodeError.
# NOTE(review): assumes the file is saved as UTF-8 — confirm.
with open('trans_instance.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [45]:
text[0:1000]

'小沈阳江西演唱会邀请沈春，\n明星刀郎的歌火遍大江南北\n2002年的第一场雪比2001年来得\n大模型张老师的粉丝全是正能量'

In [46]:
len(text)

60

# 3、引入tiktoken 将文字token化

In [47]:
# Load tiktoken's "cl100k_base" encoding; it is used below both to encode the
# corpus into token ids and to decode ids back into text.
encoding = tiktoken.get_encoding("cl100k_base")

In [48]:
# Encode the raw corpus string into token ids.
tokenized_text = encoding.encode(text)

In [49]:
# Convert the token ids into an int64 (torch.long) tensor so they can be sliced
# and used as embedding indices. Note that this rebinds `tokenized_text` from
# the encoder's output to a tensor.
tokenized_text = torch.tensor(tokenized_text, dtype=torch.long)

In [50]:
# Largest token id that occurs in the corpus, as a plain Python number
# (.item() unwraps the 0-d tensor). Used later to size the embedding table.
max_token_value = tokenized_text.max().item()

In [51]:
len(tokenized_text)

69

In [52]:
max_token_value

92877

# 4、参数设置

In [53]:
# Number of consecutive tokens in each training window.
context_length = 10

In [54]:
# Width of each token embedding / positional-encoding vector.
d_model = 64

In [55]:
# Number of windows sampled per batch.
batch_size = 4

In [56]:
# Alias for the tokenized corpus; the batch-sampling cells below read from `data`.
data = tokenized_text

In [57]:
# Upper bound for the random window start, passed to torch.randint as `high`.
# Leaving `context_length` tokens after the start keeps both the input slice
# [idx, idx + context_length) and the target slice shifted by one inside `data`.
high = len(data) - context_length

In [58]:
high

59

# 5、初始化张量

## 参数low随机整数，最小值为零

## 参数high：随机整数的上界（不包含该值）；此处high=59，因此起始索引最大为58

In [59]:
# Sample `batch_size` random window start positions in [0, high).
# A dedicated, seeded generator makes the sampled batch reproducible under
# "Restart & Run All" without touching PyTorch's global RNG state (so other
# random initialisation in the notebook is unaffected).
batch_rng = torch.Generator().manual_seed(1337)  # arbitrary fixed seed
idxs = torch.randint(low=0, high=high, size=(batch_size,), generator=batch_rng)

In [60]:
#idxs tensor([53, 25, 13,  7])

## 说明初始化一维张量的目的

### 在trans_instance.txt文件中，字符数为60（含3个换行符），token数为69

### 例如随机张量为tensor([41, 1, 32, 50])，则41、1、32、50分别作为4个样本在token序列中的起始索引

### 以起始索引41为例，取长度为10的片段，对应的索引为41，42，43，44，45，46，47，48，49，50，每个索引对应一个具体的token

# 6、初始化4批数据

## 数据结构是4行10列， 即4行token， 每行10个token

In [61]:
# Input batch: for every sampled start position take `context_length`
# consecutive token ids and stack the windows into one 2-D tensor.
x_batch = torch.stack(
    [data[start:start + context_length] for start in idxs.tolist()]
)

In [62]:
x_batch

tensor([[31809, 31106,   230, 83175, 70277, 61786, 78256,   242, 84150,   109],
        [  234, 80699, 30250,   235, 27384, 70277, 59563, 49409,   198,  1049],
        [78519,  6701,   222, 31938,   236,  9554, 15722,   234, 80699, 30250],
        [  236,  9554, 15722,   234, 80699, 30250,   235, 27384, 70277, 59563]])

In [63]:
# Target batch: the same windows as the inputs, shifted one token to the
# right, so each target position holds the token that follows the input token.
y_batch = torch.stack(
    [data[start + 1:start + 1 + context_length] for start in idxs.tolist()]
)

In [64]:
y_batch

tensor([[31106,   230, 83175, 70277, 61786, 78256,   242, 84150,   109, 38093],
        [80699, 30250,   235, 27384, 70277, 59563, 49409,   198,  1049,    17],
        [ 6701,   222, 31938,   236,  9554, 15722,   234, 80699, 30250,   235],
        [ 9554, 15722,   234, 80699, 30250,   235, 27384, 70277, 59563, 49409]])

# 7、引入pandas查看原始数据

In [65]:
import pandas as pd

In [87]:
# Decode the first input window back into text for inspection.
encoding.decode(x_batch[0].tolist())

'小沈阳江西演唱'

In [88]:
# Decode the second input window back into text for inspection.
# NOTE(review): a window can apparently start or end in the middle of a
# multi-token character, which shows up as the replacement character in the output.
encoding.decode(x_batch[1].tolist())

'�火遍大江南北\n200'

In [89]:
# Decode the third input window back into text for inspection.
encoding.decode(x_batch[2].tolist())

'星刀郎的歌火�'

In [90]:
# Decode the fourth input window back into text for inspection.
encoding.decode(x_batch[3].tolist())

'�的歌火遍大江南'

In [91]:
# Decode the first target window; it is the first input window shifted by one token.
encoding.decode(y_batch[0].tolist())

'沈阳江西演唱会'

In [92]:
# Decode the second target window; it is the second input window shifted by one token.
encoding.decode(y_batch[1].tolist())

'火遍大江南北\n2002'

In [93]:
# Decode the third target window; it is the third input window shifted by one token.
encoding.decode(y_batch[2].tolist())

'刀郎的歌火遍'

In [94]:
# Decode the fourth target window; it is the fourth input window shifted by one token.
encoding.decode(y_batch[3].tolist())

'的歌火遍大江南北'

# 8、Input Embedding 初始化

In [67]:
# Show which piece of text the largest token id in the corpus stands for.
# The id is taken from `max_token_value` rather than typed in by hand, so the
# cell stays correct when the corpus (and therefore the maximum id) changes.
encoding.decode([max_token_value])

'老'

## 创建一个Embedding table，形状为（92878， 64），即92878行（最大token id 92877加1）、64列

In [68]:
# Trainable lookup table that maps a token id to a d_model-wide vector.
# Row count = largest token id seen in the corpus + 1, so every id that
# occurs in `data` is a valid row index.
input_embedding_lookup_table = nn.Embedding(max_token_value + 1, d_model)

In [69]:
input_embedding_lookup_table #数据结构

Embedding(92878, 64)

In [70]:
input_embedding_lookup_table.weight.data #初始化的权重，这些初始值在训练过程中修正

tensor([[-0.5180, -1.2831, -1.8315,  ..., -0.7923, -1.5449,  0.7383],
        [-0.5362,  1.2741, -0.4269,  ..., -0.1117,  1.6702,  0.0265],
        [-0.0310,  1.3513, -1.5939,  ...,  0.5652, -0.3909, -0.6691],
        ...,
        [-0.2261,  1.0138,  1.1066,  ...,  0.0226, -0.5757,  0.3687],
        [ 0.9333, -1.4034, -1.8869,  ..., -1.6839,  1.0230,  0.0276],
        [ 1.0363, -1.2717, -0.6549,  ..., -0.3199, -0.6892,  0.8592]])

In [71]:
# Look up the embedding vector of every token id in the input and target
# batches; each result gains a trailing dimension of size d_model.
x_batch_embedding = input_embedding_lookup_table(x_batch)
y_batch_embedding = input_embedding_lookup_table(y_batch)

In [72]:
x_batch_embedding.shape

torch.Size([4, 10, 64])

## 4 代表一个批次中的4个样本， 10 是每个样本的token数（即10行）， 64 是每个token的向量维度（即64列）

In [73]:
y_batch_embedding.shape

torch.Size([4, 10, 64])

# 9、positional encoding 加入位置信息

In [74]:
# Start from an all-zero (context_length, d_model) table: one row per position
# in the window, one column per embedding dimension. It is filled in below.
position_encoding_lookup_table = torch.zeros(context_length, d_model)

In [75]:
position_encoding_lookup_table

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.

## 目标是初始化一个10x64的二维矩阵，目的就是给每一个token位置提供一个64维的位置编码向量

In [76]:
# Sinusoidal positional encoding.
# The table is built from scratch inside this cell instead of mutating the
# tensor created in an earlier cell. That keeps the cell idempotent: after one
# run `position_encoding_lookup_table` is a 3-D expanded view, and writing
# into it on a second run would fail.

# 1. Fresh (context_length, d_model) table for this run.
pe_table = torch.zeros(context_length, d_model)

# 2. Positions [0, 1, ..., context_length - 1] as a column vector.
position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)

# 3. Per-dimension frequency scale (exponentially decaying), one per even column.
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

# 4. Sine on even columns, cosine on odd columns.
#    For an odd d_model there is one fewer odd column than even columns, so
#    the cosine argument is trimmed to d_model // 2 frequencies.
angles = position * div_term
pe_table[:, 0::2] = torch.sin(angles)
pe_table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

# 5. Add a leading batch dimension -> (batch_size, context_length, d_model).
#    expand() returns a broadcast view and does not copy the data.
position_encoding_lookup_table = pe_table.unsqueeze(0).expand(batch_size, -1, -1)

In [77]:
position_encoding_lookup_table

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  6.8156e-01,  ...,  1.0000e+00,
           1.3335e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.9748e-01,  ...,  1.0000e+00,
           2.6670e-04,  1.0000e+00],
         ...,
         [ 6.5699e-01,  7.5390e-01, -8.5931e-01,  ...,  1.0000e+00,
           9.3346e-04,  1.0000e+00],
         [ 9.8936e-01, -1.4550e-01, -2.8023e-01,  ...,  1.0000e+00,
           1.0668e-03,  1.0000e+00],
         [ 4.1212e-01, -9.1113e-01,  4.4919e-01,  ...,  1.0000e+00,
           1.2002e-03,  1.0000e+00]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  6.8156e-01,  ...,  1.0000e+00,
           1.3335e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.9748e-01,  ...,  1.0000e+00,
           2.6670e-04,  1.0000e+00],
         ...,
         [ 6.5699e-01,  7

## 增加位置编码

In [80]:
# Add the positional encoding to the input token embeddings element-wise, so
# each vector carries both token identity and position information.
x = x_batch_embedding + position_encoding_lookup_table

In [81]:
x

tensor([[[ 0.6580,  1.4935,  0.0574,  ...,  1.9252,  1.4490,  0.3878],
         [ 2.4235,  0.5228,  1.3203,  ...,  3.6914, -0.0772,  2.1024],
         [ 2.7252,  0.0471,  0.6334,  ..., -0.0108,  2.4385,  3.7459],
         ...,
         [ 0.3275,  0.0349, -0.9200,  ...,  2.6774,  0.4239,  1.2428],
         [ 2.3333, -0.1715,  0.9364,  ...,  3.0683, -0.6175,  2.9561],
         [ 0.8766, -1.3089, -0.6468,  ...,  3.2859,  0.1350,  2.2900]],

        [[-1.8310,  1.7970,  0.1451,  ...,  2.1265, -0.8719,  1.5233],
         [ 3.7912,  1.7826,  2.5107,  ...,  3.4485, -0.3508,  1.9317],
         [ 1.8938, -0.5448, -0.6195,  ...,  1.3728, -0.7553,  3.1923],
         ...,
         [ 0.7686,  1.0790, -2.3757,  ...,  1.6144,  0.6550,  2.3885],
         [ 1.5093,  2.0870, -0.4721,  ...,  2.8282,  0.6949,  1.4700],
         [ 0.2795, -1.2161, -0.3664,  ...,  0.3671,  0.7612,  3.5880]],

        [[ 0.5870,  2.1422,  0.0415,  ...,  0.7699, -0.6005,  1.7527],
         [ 1.9032, -0.6440,  2.8213,  ...,  1

In [83]:
# Add the same positional encoding to the embedded target tokens.
# NOTE(review): targets are often kept as raw token ids for the loss rather
# than embedded — confirm that an embedded, position-encoded target is intended.
y = y_batch_embedding + position_encoding_lookup_table

In [84]:
y

tensor([[[ 0.7406,  0.4422, -0.0428,  ...,  2.6914, -0.0775,  1.1024],
         [ 1.7481,  1.4197, -0.6800,  ..., -1.0108,  2.4381,  2.7459],
         [ 0.1799,  0.2843,  1.4008,  ...,  1.1941, -0.4105,  0.8581],
         ...,
         [ 1.0116,  0.8734,  0.6375,  ...,  2.0683, -0.6187,  1.9561],
         [ 1.0417,  0.3679, -1.8254,  ...,  2.2859,  0.1337,  1.2900],
         [ 0.0043, -0.3653, -1.9876,  ..., -0.0846, -0.8482,  1.2410]],

        [[ 2.1083,  1.7020,  1.1476,  ...,  2.4485, -0.3511,  0.9317],
         [ 0.9167,  0.8278, -1.9329,  ...,  0.3728, -0.7557,  2.1923],
         [ 0.9035, -1.1290,  1.2655,  ...,  1.0475, -0.2108,  0.4673],
         ...,
         [ 0.1876,  3.1319, -0.7710,  ...,  1.8282,  0.6937,  0.4700],
         [ 0.4446,  0.4607, -1.5450,  ..., -0.6329,  0.7599,  2.5880],
         [ 1.6580, -1.9287,  0.8579,  ...,  0.5097,  0.3525,  1.6596]],

        [[ 0.2202, -0.7246,  1.4582,  ...,  0.8137, -1.8747,  2.2940],
         [ 1.4600,  0.4716, -0.0245,  ...,  2

In [85]:
x.shape, y.shape

(torch.Size([4, 10, 64]), torch.Size([4, 10, 64]))

In [86]:
# Show the first sample of the batch as a table (rows = positions, columns =
# embedding dimensions). detach() is needed because `x` depends on the
# embedding weights, which track gradients.
pd.DataFrame(x[0].detach().numpy())

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,54,55,56,57,58,59,60,61,62,63
0,0.65799,1.493456,0.057354,3.143023,-1.771471,1.36289,-0.828065,0.583228,0.053534,0.814435,...,-0.098599,2.772037,-0.457894,3.416338,0.584577,0.716223,0.076802,1.925173,1.44902,0.387845
1,2.423509,0.522808,1.320328,0.997766,0.255153,2.57321,2.159432,0.810123,-0.977423,3.0109,...,-0.378481,2.136615,-1.090957,0.752428,1.011021,2.003762,2.334848,3.691366,-0.07725,2.102414
2,2.725236,0.047081,0.633394,-0.321458,1.356426,0.149932,1.369661,2.223868,2.075838,-0.034166,...,-1.340173,1.507314,-0.42179,0.866582,-0.227802,0.336653,1.285429,-0.010818,2.438485,3.745878
3,-0.447111,-1.27956,1.959848,-0.60088,2.78697,1.648211,0.314654,2.016898,0.269873,2.575533,...,-1.259274,2.169379,0.005957,1.781306,0.705451,1.151155,0.01527,2.194141,-0.409939,1.858117
4,-1.841347,-0.317044,-0.146393,-3.130875,1.754574,-0.385798,1.855644,-0.075504,0.636406,-2.091642,...,-1.133646,2.281186,0.331456,2.327513,0.817328,3.919157,0.778028,1.877989,-0.577671,1.465714
5,-1.536381,-0.373043,-0.89913,-1.601389,-1.349457,-2.907691,-0.01047,0.724777,3.389273,1.437697,...,0.068011,2.839387,-2.156352,1.029874,0.466322,1.825027,0.688392,2.701521,0.438164,1.40392
6,-0.8883,-0.00768,-1.082841,-0.152588,-1.046484,0.041815,1.608427,-1.865774,3.363904,-0.352297,...,1.175344,1.407073,0.819835,1.131556,-0.146393,3.372055,-0.362573,2.672351,-0.740967,1.771977
7,0.327467,0.034867,-0.920042,2.2833,-1.02262,-2.141044,-0.451487,-0.939053,2.181507,-2.991994,...,-2.030466,1.961415,-0.454454,1.426277,0.052252,1.804291,0.24598,2.677408,0.423892,1.242784
8,2.333298,-0.171528,0.936393,2.981894,-2.204461,-0.710211,-1.125043,-2.360204,0.269484,-0.494158,...,-1.226887,3.494533,-0.575411,1.265912,-0.985054,2.364651,-0.958155,3.068333,-0.617537,2.956089
9,0.876571,-1.308858,-0.646776,2.79743,-1.118704,1.405277,0.070735,-2.808073,0.601056,-2.121127,...,0.6874,1.865635,-0.555483,0.95598,0.242457,0.616005,-0.858171,3.285901,0.134984,2.289959
