In [1]:
from torch import nn

In [2]:
class FeedForward(nn.Module):
    
    def __init__(self,input_dim,hidden_dim,output_dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(input_dim,hidden_dim)
        self.linear2 = nn.Linear(hidden_dim,output_dim)
        
    
    def forward(self,x):
        x = self.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [3]:
import torch

In [4]:
x = torch.randn((4,5))

In [5]:
fd = FeedForward(5,7,6)

In [6]:
fd(x).shape

torch.Size([4, 6])

In [7]:
import torch.nn.functional as F

In [8]:
class LayerNorm(nn.Module):
    
    def __init__(self,input_dim):
        super().__init__()
        self.ln = nn.LayerNorm(input_dim)

    
    def forward(self,x):
        return self.ln(x)

In [9]:
x = torch.randn(5,4)
ln = LayerNorm(4)
ln(x)

tensor([[-1.4421, -0.3274,  1.2366,  0.5329],
        [-0.4948,  1.5908, -1.1065,  0.0105],
        [-0.9105,  1.5961,  0.0977, -0.7834],
        [ 1.6821, -0.8249, -0.6762, -0.1810],
        [-0.2390, -1.3627,  0.1671,  1.4346]],
       grad_fn=<NativeLayerNormBackward0>)

In [10]:
class Attention(nn.Module):
    
    def __init__(self,input_dim,hidden_dim):
        super().__init__()
        self.qw = nn.Linear(input_dim,hidden_dim)
        self.kw = nn.Linear(input_dim,hidden_dim)
        self.vw = nn.Linear(input_dim,hidden_dim)
    
    def forward(self,x):
        ## B,T,C
        B,T,C = x.shape
        q = self.qw(x)
        k = self.kw(x)
        v = self.vw(x)
        print(q.shape,k.shape,k.T.shape)
        att = q @ k.permute(0,2,1)
#         att = att.masked_fill(mask, value)
        att = F.softmax(att,dim=-1)
        v = att @ v
        return v

In [11]:
x = torch.randn(5,3,4)

In [12]:
att = Attention(4,6)

In [13]:
att(x).shape

torch.Size([5, 3, 6]) torch.Size([5, 3, 6]) torch.Size([6, 3, 5])


  print(q.shape,k.shape,k.T.shape)


torch.Size([5, 3, 6])

In [21]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self,input_dim,head_size,hidden_dim):
        super().__init__()        
        self.head_size = head_size
        self.hidden_size = hidden_dim
        self.qw = nn.Linear(input_dim,head_size * hidden_dim)
        self.kw = nn.Linear(input_dim,head_size * hidden_dim)
        self.vw = nn.Linear(input_dim,head_size * hidden_dim)
        
    def forward(self,q,k,v,masked=False):
        #### q ==> B,T,C
        q = self.qw(q)
        k = self.kw(k)
        v = self.vw(v)
        #### q ===> B,head_size,T,hidden_size
        B,T,C = q.shape
        q = q.reshape(B,T,self.head_size,self.hidden_size).permute(0,2,1,3)
        B,T,C = k.shape
        k = k.reshape(B,T,self.head_size,self.hidden_size).permute(0,2,1,3)
        B,T,C = v.shape
        v = v.reshape(B,T,self.head_size,self.hidden_size).permute(0,2,1,3)
        B,head_size,T,hidden_size = q.shape
        att = q @ k.permute(0,1,3,2) # B,head_size,T,T
        if masked:
            _,_,m,n = att.shape
            mask = torch.ones(m,n)
            mask = torch.tril(mask)
            att = att.masked_fill(mask==0,float('-inf'))
        att = F.softmax(att,dim=-1)
        v = att @ v  # B,head_size,T,hidden_size
        v = v.permute(0,2,1,3) # B,T,head_size,hidden_size
        v = v.reshape(B,T,self.head_size * self.hidden_size)
        return v       
        

In [22]:
x = torch.randn(5,4,3)

In [23]:
att1 = MultiHeadAttention(3,2,3)

In [24]:
att = torch.randn(1,1,2,4)

In [25]:
att

tensor([[[[ 0.0201,  0.3462, -0.8698,  0.4331],
          [-0.5261, -0.5521, -0.5846, -0.2188]]]])

In [26]:
 _,_,m,n = att.shape
mask = torch.ones(m,n)

In [27]:
mask = torch.tril(mask)

In [28]:
mask == 0

tensor([[False,  True,  True,  True],
        [False, False,  True,  True]])

In [29]:
att = att.masked_fill(mask==0,float('-inf'))

In [30]:
torch.softmax(att,dim=-1)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5065, 0.4935, 0.0000, 0.0000]]]])

In [31]:
class EncoderBlock(nn.Module):
    
    def __init__(self,input_dim,head_size,hidden_dim):
        super().__init__()
        self.mha = MultiHeadAttention(input_dim,head_size,hidden_dim)
        self.ln1 = LayerNorm(input_dim)
        self.fd = FeedForward(input_dim,hidden_dim,input_dim)
        self.ln2 = LayerNorm(input_dim)
    
    def forward(self,q,k,v):
        x = q + self.mha(q,k,v)
        x = self.ln1(x)
        x = x + self.fd(x)
        x = self.ln2(x)
        return x

In [32]:
block = EncoderBlock(4,2,2)

In [33]:
x = torch.randn(5,3,4)

In [34]:
class DecoderBlock(nn.Module):
    
    def __init__(self,input_dim,head_size,hidden_dim):
        super().__init__()
        self.mha = MultiHeadAttention(input_dim,head_size,hidden_dim)
        self.ln1 = LayerNorm(input_dim)
        self.fd = FeedForward(input_dim,hidden_dim,input_dim)
        self.ln2 = LayerNorm(input_dim)
        self.mha2 = MultiHeadAttention(input_dim,head_size,hidden_dim)
        self.fd2 = FeedForward(input_dim,hidden_dim,input_dim)
        self.ln3 = LayerNorm(input_dim)
        
    def forward(self,x,k,v):
#         x,k0,v0 = self.mha.qkv(x) ### 需要masked
        x = x + self.mha(x,k,v)
        x = self.ln1(x)
        x = x + self.mha2(x,k,v) ### cross attention
        x = self.ln2(x)
        x = x + self.fd2(x)
        x = self.ln3(x)
        return x

In [36]:
class Tokenizer:
    
    def __init__(self,sentences):
        self.vocab_set = set()
        for sentence in sentences:
            self.vocab_set.update(sentence)
        self.vocab_set = list(self.vocab_set)
        self.vocab_set = ['<pad>','<bos>','<eos>'] + self.vocab_set
        self.token2id = {c:i for i,c in enumerate(self.vocab_set)}
        self.id2token = {i:c for c,i in self.token2id.items()}
    
    def convert_token_to_id(self,tokens):
        return [self.token2id.get(t,'') for t in tokens]
    
    def convert_id_to_token(self,ids):
        return [self.id2token.get(i,-1) for i in ids]

In [46]:
from torch import nn

In [48]:
from torch.optim import AdamW

In [951]:
##### GPT

In [49]:
class GPT(nn.Module):
    
    def __init__(self,n,input_dim,head_size,hidden_dim,output_vocab_size):
        super().__init__()
        self.decoder_blocks = nn.ModuleList(
           [DecoderBlock(input_dim,head_size,hidden_dim)   for _ in range(n)]
        )
        self.output_embeddings = nn.Embedding(output_vocab_size,input_dim)
        self.output_linear = nn.Linear(head_size * hidden_dim,output_vocab_size)
        self.output_pos_embedding = nn.Embedding(1024,input_dim)
        

    def forward(self,y):
        B,T = y.shape
        y = self.output_embeddings(y) ### B,T,C
        y_pos = self.output_pos_embedding(torch.arange(T))
        y = y + y_pos
        for block in self.decoder_blocks:
            y = block(y,y,y) ### B,head_size,T,hidden_size
        logits = self.output_linear(y) # B,T,output_vocab_size
        return logits

In [50]:
sentences = ['你好，世界',
             '好奇怪']

In [51]:
tokenizer = Tokenizer(sentences)

In [52]:
tokenizer.id2token

{0: '<pad>',
 1: '<bos>',
 2: '<eos>',
 3: '好',
 4: '世',
 5: '，',
 6: '怪',
 7: '奇',
 8: '你',
 9: '界'}

In [53]:
n = 10
input_dim = 64
head_size = 4
hidden_dim = input_dim // head_size
output_vocab_size = len(tokenizer.id2token)

In [54]:
gpt = GPT(n,input_dim,head_size,hidden_dim,output_vocab_size)

In [55]:
optim = AdamW(gpt.parameters(),lr=1e-3)

In [56]:
criterion = nn.CrossEntropyLoss()

In [57]:
def process(sentences,tokenizer,max_length):
    res = []
    length = []
    for sentence in sentences:
        arr = tokenizer.convert_token_to_id(sentence) + tokenizer.convert_token_to_id(['<eos>'])
        length.append(len(arr))
        if len(arr) > max_length:
            arr = arr[:max_length]
        else:
            arr = arr + tokenizer.convert_token_to_id(['<pad>']) * (max_length - len(arr))
        res.append(arr)
    return res,length

In [58]:
y,lengths = process(sentences,tokenizer,6)

In [59]:
y = torch.LongTensor(y)

In [60]:
y.shape

torch.Size([2, 6])

In [61]:
batch_length = torch.LongTensor(lengths)

In [62]:
batch_length

tensor([6, 4])

In [63]:
mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.]])

In [64]:
for _ in range(1000):
    y_inputs = y[:,:-1]
    y_targets = y[:,1:]
    logits = gpt(y_inputs)
    B,T = y_targets.shape
    # 计算损失
#     loss = criterion(logits.reshape(B*T,-1), y_targets.reshape(B*T))
    
    
    # 创建mask来标记有效位置
    mask = torch.arange(T, device=y_targets.device)[None,:] < (batch_length-1)[:,None]  # shape: (B,T)
    mask = mask.reshape(-1)  # shape: (B*T)

    # 只计算有效位置的loss
    logits_flat = logits.reshape(-1, logits.size(-1))  # shape: (B*T,vocab_size) 
    targets_flat = y_targets.reshape(-1)  # shape: (B*T)

    # 方法1: 使用mask选择有效位置
    valid_logits = logits_flat[mask]  # shape: (num_valid,vocab_size)
    valid_targets = targets_flat[mask]  # shape: (num_valid)
    loss = criterion(valid_logits, valid_targets)
    print(loss)
    
    optim.zero_grad()
    loss.backward()
    optim.step()

tensor(2.3761, grad_fn=<NllLossBackward0>)
tensor(1.9613, grad_fn=<NllLossBackward0>)
tensor(1.7730, grad_fn=<NllLossBackward0>)
tensor(1.6796, grad_fn=<NllLossBackward0>)
tensor(1.6047, grad_fn=<NllLossBackward0>)
tensor(1.5634, grad_fn=<NllLossBackward0>)
tensor(1.5374, grad_fn=<NllLossBackward0>)
tensor(1.5118, grad_fn=<NllLossBackward0>)
tensor(1.4885, grad_fn=<NllLossBackward0>)
tensor(1.4680, grad_fn=<NllLossBackward0>)
tensor(1.4491, grad_fn=<NllLossBackward0>)
tensor(1.4251, grad_fn=<NllLossBackward0>)
tensor(1.3942, grad_fn=<NllLossBackward0>)
tensor(1.3452, grad_fn=<NllLossBackward0>)
tensor(1.2645, grad_fn=<NllLossBackward0>)
tensor(1.4252, grad_fn=<NllLossBackward0>)
tensor(1.3083, grad_fn=<NllLossBackward0>)
tensor(1.2859, grad_fn=<NllLossBackward0>)
tensor(1.1610, grad_fn=<NllLossBackward0>)
tensor(1.4035, grad_fn=<NllLossBackward0>)
tensor(1.4215, grad_fn=<NllLossBackward0>)
tensor(1.4675, grad_fn=<NllLossBackward0>)
tensor(1.3656, grad_fn=<NllLossBackward0>)
tensor(1.21

tensor(0.0063, grad_fn=<NllLossBackward0>)
tensor(0.0063, grad_fn=<NllLossBackward0>)
tensor(0.0063, grad_fn=<NllLossBackward0>)
tensor(0.0062, grad_fn=<NllLossBackward0>)
tensor(0.0062, grad_fn=<NllLossBackward0>)
tensor(0.0061, grad_fn=<NllLossBackward0>)
tensor(0.0061, grad_fn=<NllLossBackward0>)
tensor(0.0061, grad_fn=<NllLossBackward0>)
tensor(0.0060, grad_fn=<NllLossBackward0>)
tensor(0.0060, grad_fn=<NllLossBackward0>)
tensor(0.0059, grad_fn=<NllLossBackward0>)
tensor(0.0059, grad_fn=<NllLossBackward0>)
tensor(0.0059, grad_fn=<NllLossBackward0>)
tensor(0.0058, grad_fn=<NllLossBackward0>)
tensor(0.0058, grad_fn=<NllLossBackward0>)
tensor(0.0057, grad_fn=<NllLossBackward0>)
tensor(0.0057, grad_fn=<NllLossBackward0>)
tensor(0.0057, grad_fn=<NllLossBackward0>)
tensor(0.0056, grad_fn=<NllLossBackward0>)
tensor(0.0056, grad_fn=<NllLossBackward0>)
tensor(0.0056, grad_fn=<NllLossBackward0>)
tensor(0.0055, grad_fn=<NllLossBackward0>)
tensor(0.0055, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0022, grad_fn=<NllLossBackward0>)
tensor(0.0021, grad_fn=<NllLossBackward0>)
tensor(0.0021, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.0012, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0008, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.0007, grad_fn=<NllLossBackward0>)
tensor(0.00

tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)
tensor(0.0005, grad_fn=<NllLossBackward0>)


In [65]:
def predict(model,inputs,tokenizer):
    ids = tokenizer.convert_token_to_id(inputs)
    print(ids)
    y = torch.LongTensor([ids])
    print('yyyy shape',y.shape)
    for _ in range(100):
        logits = model(y)
        ### logits B,T,vocab_size
        logits = logits[:,-1,:]
        ### logits B,T,vocab_size
        predicts = logits.argmax(dim=-1,keepdim=True) # B,1
        y = torch.cat((y,predicts),dim=-1)
    print(y.shape)
    for b in range(y.shape[0]):
        for i in y[b]:
            print(tokenizer.convert_id_to_token([int(i)]))

In [69]:
predict(gpt,'你，',tokenizer)

[8, 5]
yyyy shape torch.Size([1, 2])
torch.Size([1, 102])
['你']
['，']
['界']
['世']
['界']
['<eos>']
['界']
['界']
['界']
['世']
['界']
['界']
['世']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
['界']
