## Build a miniGPT from scratch
### 1. Import packages
### 2. Define GPT parameters
### * 3. Define GPT structure
    · 3.1 Single-head attention
    · 3.2 Multi-head attention
    · 3.3 Feed-forward (MLP)
    · 3.4 Block
    · 3.5 GPT (embedding, position, norm, MLP, block)
### * 4. Construct input dataset
### 5. Run function
![GPT2](./pics/GPT2.png)
### 1. Import packages

In [None]:
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.utils.data import Dataset, DataLoader

# Seed PyTorch's default RNG so that random operations run after this point
# (e.g. weight initialisation) are repeatable from run to run.
torch.manual_seed(1024)

### 2. Define GPT parameters

In [None]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the miniGPT model.

    ``hidden_dim`` and ``head_size`` are derived from ``n_embd`` and
    ``n_head`` in ``__post_init__`` unless they are passed explicitly, so
    overriding ``n_embd`` or ``n_head`` yields a self-consistent config
    (previously they were frozen to the class-level defaults 768 and 64).

    Raises:
        ValueError: if ``head_size`` has to be derived and ``hidden_dim`` is
            not a positive multiple of ``n_head``.
    """

    block_size: int = 512  # maximum sequence length (context window)
    batch_size: int = 12
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768  # embedding size (a.k.a. hidden_dim / hidden_size)
    # Defaults to n_embd (filled in by __post_init__); keeping the two equal
    # is what allows the embedding weight to be tied to the output head.
    hidden_dim: Optional[int] = None
    dropout: float = 0.1
    # Defaults to hidden_dim // n_head (filled in by __post_init__).
    head_size: Optional[int] = None
    # Vocabulary size, intended to match the official GPT-2 tokenizer.
    # NOTE(review): the GPT-2 BPE vocabulary is usually quoted as 50257
    # entries; 50274 is kept as-is here -- confirm against the tokenizer used.
    vocab_size: int = 50274

    def __post_init__(self):
        if self.hidden_dim is None:
            self.hidden_dim = self.n_embd
        if self.head_size is None:
            if self.n_head <= 0 or self.hidden_dim % self.n_head != 0:
                raise ValueError(
                    f"hidden_dim ({self.hidden_dim}) must be a positive "
                    f"multiple of n_head ({self.n_head})"
                )
            self.head_size = self.hidden_dim // self.n_head

### * 3. Define GPT structure

In [None]:
'''
    1.Single Head Attention
'''
class Single_Head_Attention(nn.Module):
    """One causal self-attention head.

    Projects the input to key / value / query vectors of width
    ``config.head_size`` and applies scaled dot-product attention in which
    each position may only attend to itself and to earlier positions.

    Args:
        config: object exposing ``hidden_dim``, ``head_size``, ``block_size``
            and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        # Stored so forward() can scale the attention scores.
        self.head_size = config.head_size
        self.key = nn.Linear(config.hidden_dim, config.head_size)
        self.value = nn.Linear(config.hidden_dim, config.head_size)
        self.query = nn.Linear(config.hidden_dim, config.head_size)
        # The causal mask is registered as a buffer: it follows the module
        # across devices but is not a trainable parameter.
        self.register_buffer(
            'attention_mask',
            # Lower-triangular matrix of ones.
            torch.tril(
                torch.ones(config.block_size, config.block_size)
            )
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply causal attention to a 3-D input ``x``.

        ``x`` is unpacked as ``(batch, seq_len, features)``; the last
        dimension must match ``config.hidden_dim``. Returns a tensor whose
        last dimension is ``head_size``.
        """
        _, seq_len, _ = x.size()
        k = self.key(x)
        v = self.value(x)
        q = self.query(x)
        # Scale the raw scores *before* the softmax so each attention row
        # still sums to 1 afterwards.
        weight = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)
        # Hide future positions: masked scores become -inf, i.e. zero weight.
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0,
            float('-inf')
        )
        # Reached through ``nn`` so the block does not depend on which alias
        # torch.nn.functional was imported under.
        weight = nn.functional.softmax(weight, dim=-1)
        weight = self.dropout(weight)
        out = weight @ v
        return out


'''
    2.Multi Head Attention
'''
class Multi_Head_Attention(nn.Module):
    """Run ``config.n_head`` attention heads and merge their outputs.

    Each head is a ``Single_Head_Attention``; the head outputs are
    concatenated along the last dimension, passed through a linear
    projection and then through dropout.

    Args:
        config: object exposing ``n_head``, ``hidden_dim`` and ``dropout``
            (plus whatever ``Single_Head_Attention`` reads from it).
    """

    def __init__(self, config):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                Single_Head_Attention(config)
                for _ in range(config.n_head)
            ]
        )
        self.proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the projected concatenation of every head's output."""
        # The projection expects the concatenated width to equal hidden_dim,
        # i.e. n_head * head_size == hidden_dim.
        output = torch.cat(
            [h(x) for h in self.heads],
            dim=-1
        )
        output = self.proj(output)
        output = self.dropout(output)
        return output


'''
    3.Feed Forward(MLP)
'''
class Feed_Forward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout.

    The inner layer is four times as wide as ``config.hidden_dim``; the
    output has the same last dimension as the input.

    Args:
        config: object exposing ``hidden_dim`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.hidden_dim, 4 * config.hidden_dim),
            nn.GELU(),
            nn.Linear(4 * config.hidden_dim, config.hidden_dim),
            nn.Dropout(config.dropout)
        )

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)


'''
    4.block
'''
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual.

    Args:
        config: configuration forwarded to ``Multi_Head_Attention`` and
            ``Feed_Forward``; ``hidden_dim`` sizes the two LayerNorms.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        self.att = Multi_Head_Attention(config)
        self.ffn = Feed_Forward(config)
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        # Each sub-layer sees a normalised copy; its result is added back.
        attended = self.att(self.ln1(x))
        x = x + attended
        transformed = self.ffn(self.ln2(x))
        return x + transformed


'''
    5.GPT
    (embedding, position, norm, mlp, block)
'''
class GPT(nn.Module):
    """Decoder-only language model: embeddings -> blocks -> norm -> LM head.

    Args:
        config: object exposing ``vocab_size``, ``n_embd``, ``block_size``
            and ``n_layer`` (plus whatever ``Block`` reads from it).
    """

    def __init__(self, config):
        super().__init__()
        # Token and learned position embeddings.
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        # Stack of transformer blocks.
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layer)]
        )
        # Final layer norm.
        self.ln_final = nn.LayerNorm(config.n_embd)
        # Output projection to vocabulary logits.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the embedding and the output head share one matrix,
        # which reduces the parameter count of a small language model.
        self.token_embedding_table.weight = self.lm_head.weight
        # Apply the custom initialisation to every sub-module; without this
        # call _init_weights is never used.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear / Embedding weights from N(0, 0.02); zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the loss.

        Args:
            idx: 2-D tensor of token ids, unpacked as ``(batch, seq_len)``.
                ``seq_len`` must not exceed ``config.block_size``, because
                positions index the position-embedding table.
            targets: optional tensor of target token ids holding
                ``batch * seq_len`` elements.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            ``(batch, seq_len, vocab_size)`` and ``loss`` is ``None``. With
            targets, ``logits`` is flattened to
            ``(batch * seq_len, vocab_size)`` and ``loss`` is the
            cross-entropy between those logits and the flattened targets.
        """
        batch_size, seq_len = idx.size()
        # Embeddings: the position term broadcasts over the batch dimension.
        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(seq_len, device=idx.device))
        x = token_emb + pos_emb
        # Transformer blocks.
        x = self.blocks(x)
        # Final norm.
        x = self.ln_final(x)
        # Vocabulary logits.
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            batch_size, seq_len, vocab_size = logits.size()
            logits = logits.view(batch_size * seq_len, vocab_size)
            targets = targets.view(batch_size * seq_len)
            # Reached through ``nn`` so the block does not depend on which
            # alias torch.nn.functional was imported under.
            loss = nn.functional.cross_entropy(logits, targets)

        return logits, loss
