In [1]:
from datasets import load_dataset

In [2]:
raw_dataset = load_dataset('kde4',lang1='en',lang2='zh_CN')

Using the latest cached version of the module from /Users/liuchu/.cache/huggingface/modules/datasets_modules/datasets/kde4/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac (last modified on Tue Dec 31 15:44:07 2024) since it couldn't be found locally at kde4, or remotely on the Hugging Face Hub.


In [5]:
split_dataset = raw_dataset['train'].train_test_split(train_size=0.9,seed=20)

In [6]:
split_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 125699
    })
    test: Dataset({
        features: ['id', 'translation'],
        num_rows: 13967
    })
})

In [950]:
split_dataset['train'][7886]['translation']

{'en': 'Username:', 'zh_CN': '用户名 ：'}

In [12]:
from transformers import AutoTokenizer

In [13]:
model_checkpoint = 'Helsinki-NLP/opus-mt-en-zh'

In [15]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt")



In [16]:
tokenizer

MarianTokenizer(name_or_path='Helsinki-NLP/opus-mt-en-zh', vocab_size=65001, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	65000: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [28]:
en_sentence = split_dataset['train'][3]['translation']['en']

In [29]:
zh_sentence = split_dataset['train'][3]['translation']['zh_CN']

In [30]:
inputs = tokenizer(en_sentence,text_target=zh_sentence)

In [34]:
inputs

{'input_ids': [26, 13932, 49644, 36, 17, 3778, 12179, 13, 39382, 1857, 15, 13, 816, 269, 6, 84, 32, 3, 471, 35, 3, 1963, 27139, 131, 26953, 7866, 3778, 6, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [453, 18437, 9470, 1401, 22, 17, 8, 35797, 3793, 673, 3300, 4993, 12, 32891, 19543, 3278, 10, 11560, 35797, 67, 1963, 2926, 1333, 131, 228, 18437, 9470, 1401, 8, 35797, 5051, 8, 10, 0]}

In [35]:
''.join(tokenizer.convert_ids_to_tokens(inputs['labels']))

'▁STRING▁()▁函数返回给定数字的字符串值。▁此函数与▁NUM2STRING▁函数相同▁。</s>'

In [36]:
####### 手动实现transformer

In [37]:
from torch import nn

In [48]:
class FeedForward(nn.Module):
    
    def __init__(self,input_dim,hidden_dim,output_dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(input_dim,hidden_dim)
        self.linear2 = nn.Linear(hidden_dim,output_dim)
        
    
    def forward(self,x):
        x = self.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [49]:
import torch

In [50]:
x = torch.randn((4,5))

In [52]:
fd = FeedForward(5,7,6)

In [54]:
fd(x).shape

torch.Size([4, 6])

In [67]:
import torch.nn.functional as F

In [59]:
class LayerNorm(nn.Module):
    
    def __init__(self,input_dim):
        super().__init__()
        self.ln = nn.LayerNorm(input_dim)

    
    def forward(self,x):
        return self.ln(x)

In [66]:
x = torch.randn(5,4)
ln = LayerNorm(4)
ln(x)

tensor([[ 1.3836, -1.3330,  0.3668, -0.4173],
        [ 0.9285, -0.2232,  0.8387, -1.5441],
        [ 0.8393,  0.0920,  0.7285, -1.6599],
        [ 1.5067, -0.0065, -0.2004, -1.2999],
        [-0.7406, -0.4775,  1.7229, -0.5048]],
       grad_fn=<NativeLayerNormBackward0>)

In [115]:
class Attention(nn.Module):
    
    def __init__(self,input_dim,hidden_dim):
        super().__init__()
        self.qw = nn.Linear(input_dim,hidden_dim)
        self.kw = nn.Linear(input_dim,hidden_dim)
        self.vw = nn.Linear(input_dim,hidden_dim)
    
    def forward(self,x):
        ## B,T,C
        B,T,C = x.shape
        q = self.qw(x)
        k = self.kw(x)
        v = self.vw(x)
        print(q.shape,k.shape,k.T.shape)
        att = q @ k.permute(0,2,1)
#         att = att.masked_fill(mask, value)
        att = F.softmax(att,dim=-1)
        v = att @ v
        return v

In [116]:
x = torch.randn(5,3,4)

In [117]:
att = Attention(4,6)

In [118]:
att(x).shape

torch.Size([5, 3, 6]) torch.Size([5, 3, 6]) torch.Size([6, 3, 5])


torch.Size([5, 3, 6])

In [760]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self,input_dim,head_size,hidden_size):
        super().__init__()        
        self.head_size = head_size
        self.hidden_size = hidden_dim
        self.qw = nn.Linear(input_dim,head_size * hidden_dim)
        self.kw = nn.Linear(input_dim,head_size * hidden_dim)
        self.vw = nn.Linear(input_dim,head_size * hidden_dim)
        
    def forward(self,q,k,v,masked=False):
        #### q ==> B,T,C
        q = self.qw(q)
        k = self.kw(k)
        v = self.vw(v)
        #### q ===> B,head_size,T,hidden_size
        B,T,C = q.shape
        q = q.reshape(B,T,self.head_size,self.hidden_size).permute(0,2,1,3)
        B,T,C = k.shape
        k = k.reshape(B,T,self.head_size,self.hidden_size).permute(0,2,1,3)
        B,T,C = v.shape
        v = v.reshape(B,T,self.head_size,self.hidden_size).permute(0,2,1,3)
        B,head_size,T,hidden_size = q.shape
        att = q @ k.permute(0,1,3,2) # B,head_size,T,T
        if masked:
            _,_,m,n = att.shape
            mask = torch.ones(m,n)
            mask = torch.tril(mask)
            att = att.masked_fill(mask==0,float('-inf'))
        att = F.softmax(att,dim=-1)
        v = att @ v  # B,head_size,T,hidden_size
        v = v.permute(0,2,1,3) # B,T,head_size,hidden_size
        v = v.reshape(B,T,self.head_size * self.hidden_size)
        return v       
        

In [761]:
x = torch.randn(5,4,3)

In [762]:
att1 = MultiHeadAttention(3,2,3)

In [763]:
att = torch.randn(1,1,2,4)

In [764]:
att

tensor([[[[ 1.2754, -1.5449,  0.9508, -0.7407],
          [ 0.6911,  0.5579, -0.8990, -1.2444]]]])

In [765]:
 _,_,m,n = att.shape
mask = torch.ones(m,n)

In [766]:
mask = torch.tril(mask)

In [767]:
mask == 0

tensor([[False,  True,  True,  True],
        [False, False,  True,  True]])

In [768]:
att = att.masked_fill(mask==0,float('-inf'))

In [769]:
torch.softmax(att,dim=-1)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5333, 0.4667, 0.0000, 0.0000]]]])

In [771]:
class EncoderBlock(nn.Module):
    
    def __init__(self,input_dim,head_size,hidden_dim):
        super().__init__()
        self.mha = MultiHeadAttention(input_dim,head_size,hidden_dim)
        self.ln1 = LayerNorm(input_dim)
        self.fd = FeedForward(input_dim,hidden_dim,input_dim)
        self.ln2 = LayerNorm(input_dim)
    
    def forward(self,q,k,v):
        x = q + self.mha(q,k,v)
        x = self.ln1(x)
        x = x + self.fd(x)
        x = self.ln2(x)
        return x

In [772]:
block = EncoderBlock(4,2,2)

In [773]:
x = torch.randn(5,3,4)

In [894]:
class DecoderBlock(nn.Module):
    
    def __init__(self,input_dim,head_size,hidden_dim):
        super().__init__()
        self.mha = MultiHeadAttention(input_dim,head_size,hidden_dim)
        self.ln1 = LayerNorm(input_dim)
        self.fd = FeedForward(input_dim,hidden_dim,input_dim)
        self.ln2 = LayerNorm(input_dim)
        self.mha2 = MultiHeadAttention(input_dim,head_size,hidden_dim)
        self.fd2 = FeedForward(input_dim,hidden_dim,input_dim)
        self.ln3 = LayerNorm(input_dim)
        
    def forward(self,x,k,v):
#         x,k0,v0 = self.mha.qkv(x) ### 需要masked
        x = x + self.mha(x,k,v)
        x = self.ln1(x)
        x = x + self.mha2(x,k,v) ### cross attention
        x = self.ln2(x)
        x = x + self.fd2(x)
        x = self.ln3(x)
        return x

In [895]:
class Transformer(nn.Module):
    
    def __init__(self,n,input_dim,head_size,hidden_dim,input_vocab_size,output_vocab_size):
        super().__init__()
        self.encoder_blocks = nn.ModuleList(
           [EncoderBlock(input_dim,head_size,hidden_dim) for _ in range(n)]
        )
        self.decoder_blocks = nn.ModuleList(
           [DecoderBlock(input_dim,head_size,hidden_dim)   for _ in range(n)]
        )
        self.input_embeddings = nn.Embedding(input_vocab_size,input_dim)
        self.output_embeddings = nn.Embedding(output_vocab_size,input_dim)
        self.output_linear = nn.Linear(head_size * hidden_dim,output_vocab_size)
        self.input_pos_embedding = nn.Embedding(1024,input_dim)
        self.output_pos_embedding = nn.Embedding(1024,input_dim)
        

    def forward(self,x,y):
        #### x ==> B,T
        B,T = x.shape
        x = self.input_embeddings(x) ### B,T,C
        x_pos = self.input_pos_embedding(torch.arange(T))
        x = x + x_pos
        for block in self.encoder_blocks:
            x = block(x,x,x)  ### B,head_size,T,hidden_size
        B,T = y.shape
        y = self.output_embeddings(y) ### B,T,C
        y_pos = self.output_pos_embedding(torch.arange(T))
        y = y + y_pos
        for block in self.decoder_blocks:
            y = block(y,x,x) ### B,head_size,T,hidden_size
        logits = self.output_linear(y) # B,T,output_vocab_size
        return logits
    

In [896]:
n = 5
input_dim = 4
head_size = 2
hidden_dim = input_dim // head_size
input_vocab_size = 10
output_vocab_size = 15

In [897]:
transformer = Transformer(n,input_dim,head_size,hidden_dim,input_vocab_size,output_vocab_size)

In [898]:
x = torch.LongTensor([
    [0,1,3],
    [0,2,3]
])
y = torch.LongTensor([
    [1,3,4,5],
    [2,3,4,6]
])

In [899]:
transformer(x,y).shape

torch.Size([2, 4, 15])

In [900]:
dataset = [
    {"input":"In the heart of the city, there is a beautiful park where people can enjoy nature and relax after a long day of work.","output":"在城市的中心，有一座美丽的公园，人们可以在那里享受自然，在漫长的工作日后放松身心。"},
    {"input":"The rapid development of technology has brought about significant changes in our daily lives, making it more convenient for us to communicate and access information.","output":"科技的快速发展给我们的日常生活带来了巨大的变化，使我们能够更方便地进行沟通和获取信息。"},
    {"input":"Despite the challenges we face, we should always maintain an optimistic attitude towards life, believing that every problem has a solution.","output":"尽管我们面临挑战，但我们应该始终保持对生活的乐观态度，相信每个问题都有解决方案。"},
    {"input":"With the increasing popularity of online shopping, more and more people are choosing to purchase goods through the internet, which not only saves time but also offers a wider range of choices.","output":"随着网上购物的日益普及，越来越多的人选择通过互联网购买商品，这不仅节省了时间，还提供了更广泛的选择。"}
]


In [901]:
class Tokenizer:
    
    def __init__(self,sentences):
        self.vocab_set = set()
        for sentence in sentences:
            self.vocab_set.update(sentence)
        self.vocab_set = list(self.vocab_set)
        self.vocab_set = ['<pad>','<bos>','<eos>'] + self.vocab_set
        self.token2id = {c:i for i,c in enumerate(self.vocab_set)}
        self.id2token = {i:c for c,i in self.token2id.items()}
    
    def convert_token_to_id(self,tokens):
        return [self.token2id.get(t,'') for t in tokens]
    
    def convert_id_to_token(self,ids):
        return [self.id2token.get(i,-1) for i in ids]

In [902]:
[d['input'] for d in dataset]

['In the heart of the city, there is a beautiful park where people can enjoy nature and relax after a long day of work.',
 'The rapid development of technology has brought about significant changes in our daily lives, making it more convenient for us to communicate and access information.',
 'Despite the challenges we face, we should always maintain an optimistic attitude towards life, believing that every problem has a solution.',
 'With the increasing popularity of online shopping, more and more people are choosing to purchase goods through the internet, which not only saves time but also offers a wider range of choices.']

In [903]:
input_tokenizer = Tokenizer([d['input'] for d in dataset])

In [904]:
output_tokenizer = Tokenizer([d['output'] for d in dataset])

In [905]:
output_tokenizer.convert_id_to_token([0,1,2,3,4,5])

['<pad>', '<bos>', '<eos>', '每', '后', '更']

In [906]:
def process(sentences,tokenizer,max_length,is_output=False):
    res = []
    for sentence in sentences:
        arr = tokenizer.convert_token_to_id(sentence)
        if is_output:
            arr = tokenizer.convert_token_to_id(['<bos>']) + arr + tokenizer.convert_token_to_id(['<eos>'])
        if len(arr) > max_length:
            arr = arr[:max_length]
        else:
            arr = arr + tokenizer.convert_token_to_id(['<pad>']) * (max_length - len(arr))
        res.append(arr)
    return res

In [907]:
inputs = [d['input'] for d in dataset]

In [908]:
inputs

['In the heart of the city, there is a beautiful park where people can enjoy nature and relax after a long day of work.',
 'The rapid development of technology has brought about significant changes in our daily lives, making it more convenient for us to communicate and access information.',
 'Despite the challenges we face, we should always maintain an optimistic attitude towards life, believing that every problem has a solution.',
 'With the increasing popularity of online shopping, more and more people are choosing to purchase goods through the internet, which not only saves time but also offers a wider range of choices.']

In [909]:
input_tokenizer.convert_token_to_id(['<pad>'])

[0]

In [910]:
inputs = process([d['input'] for d in dataset],input_tokenizer,30)

In [911]:
outputs = process([d['output'] for d in dataset],output_tokenizer,30,True)

In [912]:
x = torch.LongTensor(inputs)

In [913]:
y = torch.LongTensor(outputs)

In [939]:
n = 10
input_dim = 64
head_size = 4
hidden_dim = input_dim // head_size
input_vocab_size = len(input_tokenizer.id2token)
output_vocab_size = len(output_tokenizer.id2token)

In [940]:
output_vocab_size

129

In [941]:
input_vocab_size

34

In [942]:
transformer = Transformer(n,input_dim,head_size,hidden_dim,input_vocab_size,output_vocab_size)

In [943]:
transformer(x,y).shape

torch.Size([4, 30, 129])

In [944]:
from torch.optim import AdamW

In [945]:
optim = AdamW(transformer.parameters(),lr=1e-3)

In [946]:
# 创建交叉熵损失函数
criterion = nn.CrossEntropyLoss()

In [947]:
for _ in range(1000):
    y_inputs = y[:,:-1]
    y_targets = y[:,1:]
    logits = transformer(x,y_inputs)
    B,T = y_targets.shape
    # 计算损失
    loss = criterion(logits.reshape(B*T,-1), y_targets.reshape(B*T))
    print(loss)
    
    optim.zero_grad()
    loss.backward()
    optim.step()
    

tensor(5.0954, grad_fn=<NllLossBackward0>)
tensor(4.7206, grad_fn=<NllLossBackward0>)
tensor(4.5362, grad_fn=<NllLossBackward0>)
tensor(4.4196, grad_fn=<NllLossBackward0>)
tensor(4.3261, grad_fn=<NllLossBackward0>)
tensor(4.2572, grad_fn=<NllLossBackward0>)
tensor(4.1908, grad_fn=<NllLossBackward0>)
tensor(4.1247, grad_fn=<NllLossBackward0>)
tensor(4.0582, grad_fn=<NllLossBackward0>)
tensor(3.9766, grad_fn=<NllLossBackward0>)
tensor(3.8654, grad_fn=<NllLossBackward0>)
tensor(3.7720, grad_fn=<NllLossBackward0>)
tensor(3.5734, grad_fn=<NllLossBackward0>)
tensor(3.4731, grad_fn=<NllLossBackward0>)
tensor(3.2772, grad_fn=<NllLossBackward0>)
tensor(2.9990, grad_fn=<NllLossBackward0>)
tensor(2.8699, grad_fn=<NllLossBackward0>)
tensor(2.6436, grad_fn=<NllLossBackward0>)
tensor(2.4588, grad_fn=<NllLossBackward0>)
tensor(2.2864, grad_fn=<NllLossBackward0>)
tensor(2.1331, grad_fn=<NllLossBackward0>)
tensor(1.9818, grad_fn=<NllLossBackward0>)
tensor(1.8255, grad_fn=<NllLossBackward0>)
tensor(1.69

tensor(0.0257, grad_fn=<NllLossBackward0>)
tensor(0.0255, grad_fn=<NllLossBackward0>)
tensor(0.0252, grad_fn=<NllLossBackward0>)
tensor(0.0250, grad_fn=<NllLossBackward0>)
tensor(0.0248, grad_fn=<NllLossBackward0>)
tensor(0.0246, grad_fn=<NllLossBackward0>)
tensor(0.0244, grad_fn=<NllLossBackward0>)
tensor(0.0242, grad_fn=<NllLossBackward0>)
tensor(0.0240, grad_fn=<NllLossBackward0>)
tensor(0.0238, grad_fn=<NllLossBackward0>)
tensor(0.0237, grad_fn=<NllLossBackward0>)
tensor(0.0235, grad_fn=<NllLossBackward0>)
tensor(0.0233, grad_fn=<NllLossBackward0>)
tensor(0.0231, grad_fn=<NllLossBackward0>)
tensor(0.0229, grad_fn=<NllLossBackward0>)
tensor(0.0227, grad_fn=<NllLossBackward0>)
tensor(0.0226, grad_fn=<NllLossBackward0>)
tensor(0.0224, grad_fn=<NllLossBackward0>)
tensor(0.0222, grad_fn=<NllLossBackward0>)
tensor(0.0220, grad_fn=<NllLossBackward0>)
tensor(0.0219, grad_fn=<NllLossBackward0>)
tensor(0.0217, grad_fn=<NllLossBackward0>)
tensor(0.0215, grad_fn=<NllLossBackward0>)
tensor(0.02

tensor(0.0083, grad_fn=<NllLossBackward0>)
tensor(0.0082, grad_fn=<NllLossBackward0>)
tensor(0.0082, grad_fn=<NllLossBackward0>)
tensor(0.0082, grad_fn=<NllLossBackward0>)
tensor(0.0081, grad_fn=<NllLossBackward0>)
tensor(0.0081, grad_fn=<NllLossBackward0>)
tensor(0.0081, grad_fn=<NllLossBackward0>)
tensor(0.0080, grad_fn=<NllLossBackward0>)
tensor(0.0080, grad_fn=<NllLossBackward0>)
tensor(0.0080, grad_fn=<NllLossBackward0>)
tensor(0.0079, grad_fn=<NllLossBackward0>)
tensor(0.0079, grad_fn=<NllLossBackward0>)
tensor(0.0079, grad_fn=<NllLossBackward0>)
tensor(0.0078, grad_fn=<NllLossBackward0>)
tensor(0.0078, grad_fn=<NllLossBackward0>)
tensor(0.0078, grad_fn=<NllLossBackward0>)
tensor(0.0077, grad_fn=<NllLossBackward0>)
tensor(0.0077, grad_fn=<NllLossBackward0>)
tensor(0.0077, grad_fn=<NllLossBackward0>)
tensor(0.0076, grad_fn=<NllLossBackward0>)
tensor(0.0076, grad_fn=<NllLossBackward0>)
tensor(0.0076, grad_fn=<NllLossBackward0>)
tensor(0.0076, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0041, grad_fn=<NllLossBackward0>)
tensor(0.0040, grad_fn=<NllLossBackward0>)
tensor(0.0040, grad_fn=<NllLossBackward0>)
tensor(0.0040, grad_fn=<NllLossBackward0>)
tensor(0.0040, grad_fn=<NllLossBackward0>)
tensor(0.0040, grad_fn=<NllLossBackward0>)
tensor(0.0040, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0025, grad_fn=<NllLossBackward0>)
tensor(0.0025, grad_fn=<NllLossBackward0>)
tensor(0.0025, grad_fn=<NllLossBackward0>)
tensor(0.0025, grad_fn=<NllLossBackward0>)
tensor(0.0025, grad_fn=<NllLossBackward0>)
tensor(0.0025, grad_fn=<NllLossBackward0>)
tensor(0.0025, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0018, grad_fn=<NllLossBackward0>)
tensor(0.0017, grad_fn=<NllLossBackward0>)
tensor(0.0017, grad_fn=<NllLossBackward0>)
tensor(0.0017, grad_fn=<NllLossBackward0>)
tensor(0.0017, grad_fn=<NllLossBackward0>)
tensor(0.00

In [948]:
def predict(model,inputs):
    ids = input_tokenizer.convert_token_to_id(inputs)
    x = torch.LongTensor([ids])
    y = torch.LongTensor([output_tokenizer.convert_token_to_id(['<bos>'])])
    for _ in range(100):
        logits = model(x,y)
        ### logits B,T,vocab_size
        logits = logits[:,-1,:]
        ### logits B,T,vocab_size
        predicts = logits.argmax(dim=-1,keepdim=True) # B,1
        y = torch.cat((y,predicts),dim=-1)
    print(y.shape)
    for b in range(y.shape[0]):
        for i in y[b]:
            print(output_tokenizer.convert_id_to_token([int(i)]))

In [949]:
predict(transformer,'The rapid development of technology has brought about significant changes in our daily lives, making it more convenient for us to communicate and access information.')

torch.Size([1, 101])
['<bos>']
['，']
['城']
['市']
['的']
['相']
['公']
['，']
['能']
['，']
['人']
['益']
['的']
['，']
['然']
['的']
['，']
['人']
['们']
['的']
['日']
['相']
['的']
['化']
['，']
['使']
['益']
['们']
['，']
['，']
['能']
['，']
['的']
['相']
['的']
['相']
['终']
['，']
['相']
['，']
['，']
['相']
['相']
['，']
['人']
['，']
['相']
['，']
['的']
['给']
['们']
['的']
['信']
['，']
['，']
['，']
['越']
['的']
['们']
['应']
['，']
['联']
['日']
['益']
['丽']
['的']
['们']
['益']
['的']
['相']
['信']
['，']
['相']
['，']
['的']
['面']
['，']
['的']
['相']
['，']
['终']
['城']
['，']
['，']
['相']
['的']
['终']
['保']
['，']
['使']
['，']
['，']
['，']
['，']
['能']
['们']
['的']
['们']
['应']
['，']
['，']


In [951]:
##### GPT

In [988]:
class GPT(nn.Module):
    
    def __init__(self,n,input_dim,head_size,hidden_dim,output_vocab_size):
        super().__init__()
        self.decoder_blocks = nn.ModuleList(
           [DecoderBlock(input_dim,head_size,hidden_dim)   for _ in range(n)]
        )
        self.output_embeddings = nn.Embedding(output_vocab_size,input_dim)
        self.output_linear = nn.Linear(head_size * hidden_dim,output_vocab_size)
        self.output_pos_embedding = nn.Embedding(1024,input_dim)
        

    def forward(self,y):
        B,T = y.shape
        y = self.output_embeddings(y) ### B,T,C
        y_pos = self.output_pos_embedding(torch.arange(T))
        y = y + y_pos
        for block in self.decoder_blocks:
            y = block(y,y,y) ### B,head_size,T,hidden_size
        logits = self.output_linear(y) # B,T,output_vocab_size
        return logits

In [1003]:
sentences = ['你好，世界',
             '好奇怪',
             '']

In [1004]:
tokenizer = Tokenizer(sentences)

In [1005]:
tokenizer.id2token

{0: '<pad>',
 1: '<bos>',
 2: '<eos>',
 3: '怪',
 4: '，',
 5: '界',
 6: '好',
 7: '奇',
 8: '你',
 9: '世'}

In [1006]:
n = 10
input_dim = 64
head_size = 4
hidden_dim = input_dim // head_size
output_vocab_size = len(tokenizer.id2token)

In [1026]:
gpt = GPT(n,input_dim,head_size,hidden_dim,output_vocab_size)

In [1027]:
optim = AdamW(gpt.parameters(),lr=1e-3)

In [1028]:
criterion = nn.CrossEntropyLoss()

In [1041]:
def process(sentences,tokenizer,max_length):
    res = []
    length = []
    for sentence in sentences:
        arr = tokenizer.convert_token_to_id(sentence) + tokenizer.convert_token_to_id(['<eos>'])
        length.append(len(arr))
        if len(arr) > max_length:
            arr = arr[:max_length]
        else:
            arr = arr + tokenizer.convert_token_to_id(['<pad>']) * (max_length - len(arr))
        res.append(arr)
    return res,length

In [1042]:
y,lengths = process(sentences,tokenizer,6)

In [1043]:
y = torch.LongTensor(y)

In [1044]:
y.shape

torch.Size([2, 6])

In [1045]:
batch_length = torch.LongTensor(lengths)

In [1046]:
batch_length

tensor([6, 4])

In [1047]:
mask

tensor([ True,  True,  True,  True, False,  True,  True, False, False, False])

In [1048]:
for _ in range(1000):
    y_inputs = y[:,:-1]
    y_targets = y[:,1:]
    logits = gpt(y_inputs)
    B,T = y_targets.shape
    # 计算损失
#     loss = criterion(logits.reshape(B*T,-1), y_targets.reshape(B*T))
    
    
    # 创建mask来标记有效位置
    mask = torch.arange(T, device=y_targets.device)[None,:] < (batch_length-1)[:,None]  # shape: (B,T)
    mask = mask.reshape(-1)  # shape: (B*T)

    # 只计算有效位置的loss
    logits_flat = logits.reshape(-1, logits.size(-1))  # shape: (B*T,vocab_size) 
    targets_flat = y_targets.reshape(-1)  # shape: (B*T)

    # 方法1: 使用mask选择有效位置
    valid_logits = logits_flat[mask]  # shape: (num_valid,vocab_size)
    valid_targets = targets_flat[mask]  # shape: (num_valid)
    loss = criterion(valid_logits, valid_targets)
    print(loss)
    
    optim.zero_grad()
    loss.backward()
    optim.step()

tensor(2.2969, grad_fn=<NllLossBackward0>)
tensor(1.1883, grad_fn=<NllLossBackward0>)
tensor(1.4297, grad_fn=<NllLossBackward0>)
tensor(1.2062, grad_fn=<NllLossBackward0>)
tensor(0.8292, grad_fn=<NllLossBackward0>)
tensor(0.7620, grad_fn=<NllLossBackward0>)
tensor(1.4557, grad_fn=<NllLossBackward0>)
tensor(0.9914, grad_fn=<NllLossBackward0>)
tensor(0.7846, grad_fn=<NllLossBackward0>)
tensor(0.8251, grad_fn=<NllLossBackward0>)
tensor(0.7169, grad_fn=<NllLossBackward0>)
tensor(0.5769, grad_fn=<NllLossBackward0>)
tensor(0.5330, grad_fn=<NllLossBackward0>)
tensor(0.5041, grad_fn=<NllLossBackward0>)
tensor(0.4938, grad_fn=<NllLossBackward0>)
tensor(0.4927, grad_fn=<NllLossBackward0>)
tensor(0.4769, grad_fn=<NllLossBackward0>)
tensor(0.4510, grad_fn=<NllLossBackward0>)
tensor(0.4401, grad_fn=<NllLossBackward0>)
tensor(0.4423, grad_fn=<NllLossBackward0>)
tensor(0.4422, grad_fn=<NllLossBackward0>)
tensor(0.4356, grad_fn=<NllLossBackward0>)
tensor(0.4263, grad_fn=<NllLossBackward0>)
tensor(0.42

tensor(0.1716, grad_fn=<NllLossBackward0>)
tensor(0.1711, grad_fn=<NllLossBackward0>)
tensor(0.1707, grad_fn=<NllLossBackward0>)
tensor(0.1707, grad_fn=<NllLossBackward0>)
tensor(0.1736, grad_fn=<NllLossBackward0>)
tensor(0.1791, grad_fn=<NllLossBackward0>)
tensor(0.1755, grad_fn=<NllLossBackward0>)
tensor(0.1749, grad_fn=<NllLossBackward0>)
tensor(0.1791, grad_fn=<NllLossBackward0>)
tensor(0.1746, grad_fn=<NllLossBackward0>)
tensor(0.1749, grad_fn=<NllLossBackward0>)
tensor(0.1771, grad_fn=<NllLossBackward0>)
tensor(0.1740, grad_fn=<NllLossBackward0>)
tensor(0.1737, grad_fn=<NllLossBackward0>)
tensor(0.1750, grad_fn=<NllLossBackward0>)
tensor(0.1718, grad_fn=<NllLossBackward0>)
tensor(0.1708, grad_fn=<NllLossBackward0>)
tensor(0.1728, grad_fn=<NllLossBackward0>)
tensor(0.1715, grad_fn=<NllLossBackward0>)
tensor(0.1700, grad_fn=<NllLossBackward0>)
tensor(0.1693, grad_fn=<NllLossBackward0>)
tensor(0.1685, grad_fn=<NllLossBackward0>)
tensor(0.1684, grad_fn=<NllLossBackward0>)
tensor(0.16

tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0006, grad_fn=<NllLossBackward0>)
tensor(0.0006, grad_fn=<NllLossBackward0>)
tensor(0.0006, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.0004, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)
tensor(0.0003, grad_fn=<NllLossBackward0>)


In [1049]:
def predict(model,inputs,tokenizer):
    ids = tokenizer.convert_token_to_id(inputs)
    print(ids)
    y = torch.LongTensor([ids])
    print('yyyy shape',y.shape)
    for _ in range(100):
        logits = model(y)
        ### logits B,T,vocab_size
        logits = logits[:,-1,:]
        ### logits B,T,vocab_size
        predicts = logits.argmax(dim=-1,keepdim=True) # B,1
        y = torch.cat((y,predicts),dim=-1)
    print(y.shape)
    for b in range(y.shape[0]):
        for i in y[b]:
            print(tokenizer.convert_id_to_token([int(i)]))

In [1054]:
predict(gpt,'你好，世',tokenizer)

[8, 6, 4, 9]
yyyy shape torch.Size([1, 4])
torch.Size([1, 104])
['你']
['好']
['，']
['世']
['界']
['<eos>']
['好']
['，']
['世']
['世']
['世']
['世']
['界']
['世']
['，']
['世']
['界']
['<eos>']
['，']
['世']
['，']
['世']
['世']
['好']
['<eos>']
['好']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['好']
['，']
['世']
['好']
['<eos>']
['好']
['<eos>']
['世']
['<eos>']
['好']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['好']
['世']
['世']
['世']
['好']
['好']
['好']
['<eos>']
['好']
['<eos>']
['<eos>']
['好']
['界']
['世']
['世']
['世']
['<eos>']
['好']
['好']
['<eos>']
['，']
['世']
['，']
['世']
['好']
['好']
['<eos>']
['好']
['界']
['世']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['<eos>']
['世']
['<eos>']
['<eos>']
['<eos>']
['好']
['，']
['世']
['世']
['，']
['世']
['界']
['世']
['，']
['界']
['好']
['世']
['世']
