In [1]:
import torch.nn as nn
import pandas as pd
import math
import torch

batch_size = 32
d_model = 256
d_ff = 512
N_decoder_layers = 8
N_encoder_layers = 4
vocab_size = 23


In [2]:
"""
src_mask: 用于 encoder 和 decoder 的 cross-attention
tgt_mask: 用于 decoder 的 self-attention
"""

'\nsrc_mask: 用于 encoder 和 decoder 的 cross-attention\ntgt_mask: 用于 decoder 的 self-attention\n'

In [3]:
def get_padding_mask(queries, keys):
    """
    queries: (B,heads ,L_q, D)
    keys:    (B,heads ,L_k, D)
    return: padding_mask (B, heads, L_q, L_k)
    """
    # 假设padding部分全为0,这里不为0的部分都是True
    query_mask = torch.sum(queries, dim=-1) != 0  # shape: (B,heads, L_q) bool
    key_mask = torch.sum(keys, dim=-1) != 0      # shape: (B,heads, L_k) bool

    # 扩展维度以构造 (B, heads, L_q, L_k)
    query_mask = query_mask.unsqueeze(3)         # (B,heads, L_q, 1)
    key_mask = key_mask.unsqueeze(2)             # (B, heads,1, L_k)

    # 只有 query 和 key 都是有效的，才是 True，True代表着有效
    padding_mask = query_mask & key_mask         # (B, L_q, L_k)
    padding_mask = ~padding_mask                 # 取反，mask位置为 True

    return padding_mask           # (B, 1, L_q, L_k)

### FFN层

In [4]:
class FFN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dense2 = nn.Linear(hidden_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.dense1(x))
        x = self.dense2(x)
        return x  #-->batch_size,seq_length,d_model

### 多头注意力

In [5]:
#@save
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        assert (d_model % n_heads == 0)
        self.d_k = d_model // n_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_o = nn.Linear(d_model, d_model)

    def _masked_attention(self, q, k, v, mask=None):
        mask = get_padding_mask(q, k)
        attention = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)  # (B, H, L_q, L_k)

        # 构造因果掩码（causal_mask）
        causal_mask = torch.triu(torch.ones(q.size(-2), k.size(-2), device=q.device), diagonal=1)
        causal_mask = causal_mask.bool().unsqueeze(0).unsqueeze(1)  # (1, 1, L_q, L_k)

        if mask is not None:
            mask = mask.bool()
            # 广播：padding_mask (B, 1, L_q, L_k) | causal_mask (1, 1, L_q, L_k) -> (B, 1, L_q, L_k)
            mask = mask | causal_mask
        else:
            mask = causal_mask  # 直接使用因果mask，广播给attention用

        # 应用mask：mask为True的位置被设置为-1e9
        attention = torch.where(mask, torch.tensor(-1e9, device=q.device), attention)

        weights = torch.softmax(attention, dim=-1)
        return torch.matmul(weights, v)  # (B, H, L_q, D)

    def forward(self, q, k, v, mask=None):
        b,s_q,d = q.shape
        _,s_k,_ = k.shape
        assert (d == self.d_model)
        queries = self.linear_q(q).reshape(b,s_q,-1,self.d_k).transpose(1, 2)
        keys = self.linear_k(k).reshape(b,s_k,-1,self.d_k).transpose(1, 2)
        values = self.linear_v(v).reshape(b,s_k,-1,self.d_k).transpose(1, 2)
        attention = self._masked_attention(queries, keys, values, mask)
        attention = self.linear_o(attention.transpose(1, 2).reshape(b,s_q,-1))
        return attention  #-->batch_size,query_seq_length,d_model

### 残差链接以及归一化

In [6]:
class AddAndNorm(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.layernorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.2)
    def forward(self, input,output):
        return self.layernorm(self.dropout(output) + input)

### 用到的特殊MLP

In [7]:
class MLP(torch.nn.Module):
    def __init__(self, list_dims, dropout):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for i in range(len(list_dims)-1):
            self.layers.append(torch.nn.Linear(list_dims[i], list_dims[i+1]))
            self.layers.append(torch.nn.ReLU())
            self.layers.append(torch.nn.Dropout(p=dropout))

    def forward(self, x):
        output = x
        for layer in self.layers:
            output = layer(output)
        return output

### 位置编码

In [8]:
class PositionalEncoding(nn.Module):  #-->这里要注意，dmodel一定要是偶数，不然我也不知道会发生什么
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.max_len = max_len
        self.pe = torch.zeros(1,max_len, d_model)
        self.position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) #-->max_len,1
        self.sin_div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        self.cos_div_term = torch.exp(torch.arange(1, d_model, 2).float() * (-math.log(10000.0) / d_model))
        self.pe[:,:,0::2] = torch.sin(self.position * self.sin_div_term)
        self.pe[:,:,1::2] = torch.cos(self.position * self.cos_div_term)

    def forward(self, x):
        #处理一下设备的问题
        self.pe = self.pe.to(x.device)
        return x + self.pe[:,:x.size(1),:]

### 编码层

In [9]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads,d_ff,max_nb_vars = 7,dropout=0.2):
        super().__init__()
        self.mlp = MLP([d_model*max_nb_vars,d_model,d_model], dropout)
        self.attention = MultiHeadAttention(n_heads, d_model)
        self.ffn = FFN(d_model, d_ff, d_model)
        self.add_norm1 = AddAndNorm(d_model)
        self.add_norm2 = AddAndNorm(d_model)

    def forward(self, x, mask=None):

        #x的维度应该是（batch_size,nb_samples,max_nb_vars,d_model）  已经用cell-mlp处理了
        x_flat = torch.flatten(x, start_dim=2)

        mlp_output = self.mlp(x_flat)
        #上面的部分时模型专属，处理出来的x(batch_size,nb_samples,d_model)

        attention_output = self.attention(mlp_output,mlp_output,mlp_output, mask).unsqueeze(2)

        x1 = self.add_norm1(x, attention_output)
        ffn_output = self.ffn(x1)
        encoder_output = self.add_norm2(x1, ffn_output)
        return encoder_output

In [10]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads,d_ff):
        super().__init__()
        self.self_attention = MultiHeadAttention(n_heads, d_model)
        self.cross_attention = MultiHeadAttention(n_heads, d_model)

        self.add_norm1 = AddAndNorm(d_model)
        self.add_norm2 = AddAndNorm(d_model)
        self.add_norm3 = AddAndNorm(d_model)

        self.ffn = FFN(d_model, d_ff, d_model)

    def forward(self, memory, tgt ,src_mask=None,tgt_mask=None):
        m = memory
        self_attention_output = self.self_attention(tgt,tgt,tgt, tgt_mask)
        y1 = self.add_norm1(tgt, self_attention_output)
        cross_attention_output = self.cross_attention(tgt,m,m,src_mask)
        y2 = self.add_norm2(y1,cross_attention_output)
        ffn_output = self.ffn(y2)
        decoder_output = self.add_norm3(y2,ffn_output)
        return decoder_output

In [11]:
class Encoder(nn.Module):
    def __init__(self, d_model, n_heads,d_ff,N_encoder_layers,dropout=0.2):
        super().__init__()
        self.encoder = nn.ModuleList()
        for i in range(N_encoder_layers):
            self.encoder.append(EncoderLayer(d_model, n_heads, d_ff))

        self.last_mlp = MLP([d_model, d_model], dropout)
        self.cell_mlp = MLP([1, d_model,d_model], dropout)

    def forward(self, x, mask=None):
        #原始的x形状  batch_size,nb_samples,max_nb_vars
        x = self.cell_mlp(x.unsqueeze(-1))
        #现在的x形状  batch_size,nb_samples,max_nb_vars,d_model
        for layer in self.encoder:
            x = layer(x, mask)

        x = self.last_mlp(x)
        #现在的x形状不改变
        x = torch.max(x,dim=1)[0]
        #这是我们失去了第一维度，于是batch_size,max_nb_vars,d_model
        return x

class Decoder(nn.Module):
    def __init__(self, d_model, n_heads,d_ff,N_decoder_layers):
        super().__init__()
        self.decoder = nn.ModuleList()
        for i in range(N_decoder_layers):
            self.decoder.append(DecoderLayer(d_model, n_heads, d_ff))

    def forward(self, enc_out ,tgt , src_mask=None,tgt_mask=None):
        for layer in self.decoder:
            tgt = layer(enc_out,tgt ,src_mask=src_mask, tgt_mask=tgt_mask)

        return tgt

class Transformer(nn.Module):
    def __init__(self, d_model, n_heads,d_ff,vocab_size,N_encoder_layers,N_decoder_layers,dropout=0.2):
        super().__init__()
        self.encoder = Encoder(d_model, n_heads,d_ff,N_encoder_layers,dropout)
        self.decoder = Decoder(d_model, n_heads,d_ff,N_decoder_layers)
        self.output_layer = nn.Sequential(nn.Linear(d_model, vocab_size),
                                          )
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=22)

    def forward(self,src,tgt,src_mask=None,tgt_mask=None):
        tgt = self.embedding(tgt)
        memory = self.encoder(src,src_mask)
        decoder_output = self.decoder(memory,tgt,src_mask = src_mask,tgt_mask=tgt_mask)
        output = self.output_layer(decoder_output)
        return output  #-->batch_size,query_seq_length,vocab_size,按照我们的模型来讲就是b,32,21

In [12]:
import numpy as np


decoder_vocab = np.array([
    ['add', 4, 2],  # 二元操作符：加法
    ['mul', 6, 2],  # 二元操作符：乘法
    ['sub', 3, 2],
    ['sin', 1, 1],  # 一元操作符：正弦函数
    ['cos', 1, 1],  # 一元操作符：余弦函数
    ['log', 2, 1],  # 一元操作符：对数
    ['exp', 2, 1],  # 一元操作符：指数
    ['neg', 0, 1],  # 一元操作符：取负（权重为0表示此处不做采样，可根据需要调整）
    ['inv', 3, 1],  # 一元操作符：求倒数
    ['sq', 2, 1],   # 一元操作符：平方
    ['cb', 0, 1],   # 一元操作符：立方（权重为0暂不采样）
    ['sqrt', 2, 1], # 一元操作符：平方根
    ['cbrt', 0, 1], # 一元操作符：立方根（权重为0暂不采样）
    ['C', 8, 0],    # 叶子节点：常数
    ['x1', 8, 0],   # 叶子节点：变量1
    ['x2', 8, 0],   # 叶子节点：变量2
    ['x3', 4, 0],   # 叶子节点：变量3
    ['x4', 4, 0],   # 叶子节点：变量4
    ['x5', 2, 0],   # 叶子节点：变量5
    ['x6', 2, 0],   # 叶子节点：变量6
    ['<SOS>',0,0],
    ['<EOS>',0,0],
    ['<PAD>',0,0],
])


from _utils import *
decoder_vocab = create_id_vocab(decoder_vocab)
decoder_vocab['<PAD>']


22

In [13]:
from torch.utils.data import DataLoader , Dataset

class MyDataset(Dataset):
    def __init__(self,):
        self.datasets = datasets_generator(N_orig=1000,repeat_sampling=10)
        self.datas = [item[0] for item in self.datasets]
        self.labels = [item[1] for item in self.datasets]


    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        data = torch.tensor(self.datas[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return data, label


def my_collate_fn(batch, pad_idx=22, max_length=32):
    # batch 是一个 list，每个元素是 (data, label)
    datas, labels = zip(*batch)  # 解包
    datas = torch.stack([torch.tensor(d, dtype=torch.float32) for d in datas])



    # padding 到 max_length
    padded_labels = []
    for seq in labels:
        seq = torch.tensor(seq, dtype=torch.long)
        if len(seq) >= max_length:
            padded_seq = seq[:max_length]
        else:
            padding = [pad_idx] * (max_length - len(seq))
            padded_seq = torch.cat([seq, torch.tensor(padding, dtype=torch.long)])
        padded_labels.append(padded_seq)

    padded_labels = torch.stack(padded_labels)

    return datas, padded_labels





In [14]:
loader =  DataLoader(MyDataset(),batch_size=batch_size,shuffle=True,collate_fn=my_collate_fn)

原始生成表达式数： 1000
过滤后表达式数： 101
去重后唯一表达式数： 97


  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return x4*exp(C**2*x1*(x6**2 + 1)/x6)/(-log(x1) - 1/2*log(x2))
  return (C*x2*(sqrt(x2) + x2) + x1)*exp(x1**2*x3**2)
  return (C*x2*(sqrt(x2) + x2) + x1)*exp(x1**2*x3**2)
  return (C*x2*(sqrt(x2) + x2) + x1)*exp(x1**2*x3**2)
  return (C*x2*(sqrt(x2) + x2) + x1)*exp(x1**2*x3**2)
  return (C*x2*(sqrt(x2) + x2) + x1)*exp(x1**2*x3**2)
  return (C*x2*(sqrt(x2) + x2) + x1)*exp(x1**2*x3**2)
  return (C*x2*(sqrt(x2) +

采样到的数据集数： 713


  return sqrt(C) + x3 + x4
  return sqrt(C) + x3 + x4
  return sqrt(C) + x3 + x4
  return sqrt(C) + x3 + x4
  return sqrt(C) + x3 + x4
  return sqrt(C) + x3 + x4


In [15]:
from torch.optim import Adam
import matplotlib.pyplot as plt

model = Transformer(d_model =d_model,n_heads = 4, d_ff = d_ff,N_encoder_layers=N_encoder_layers,N_decoder_layers = N_decoder_layers,vocab_size=  vocab_size)

criterion = nn.CrossEntropyLoss(ignore_index=22)
optimizer = Adam(model.parameters(),lr = 0.0001)

In [16]:


def train(num_epochs):
    model.train()
    pad_id = decoder_vocab['<PAD>']
    eos_id = decoder_vocab['<EOS>']
    sos_id = decoder_vocab['<SOS>']
    model.train()
    losses = []
    epoches = torch.arange(1, num_epochs+1)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for idx, (x,y) in enumerate(loader):
            tem_batch_size = y.size(0)
            pad_idx = torch.where(y == pad_id)[0][0]
            eos = torch.full((tem_batch_size,1), eos_id, dtype=torch.long)
            sos = torch.full((tem_batch_size,1), sos_id, dtype=torch.long)


            labels = torch.cat([y[:,:pad_idx] ,eos,y[:,pad_idx+1:]],dim = 1)
            labels = labels
            decoder_input = torch.cat([sos,y[:,:-1]],dim = 1)

            optimizer.zero_grad()
            x,labels,decoder_input = x.to(device),labels.to(device),decoder_input.to(device)
            y_pred = model(x,decoder_input)


            y_pred = y_pred.view(-1, y_pred.size(-1))  # [batch * seq_len, vocab_size]
            labels = labels.view(-1)                  # [batch * seq_len]

            loss = criterion(y_pred, labels)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()*x.size(0)
            if (idx+1) % 1 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}],Batches[{idx+1}/{len(loader)}], Loss: {running_loss/((idx+1)*(x.size(0))):.3f}")
        losses.append(running_loss/len(loader.dataset))
    plt.figure(figsize=(6, 4))
    plt.plot(epoches, losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

In [17]:
for name, param in model.named_parameters():
    if torch.isnan(param).any():
        print(f"NaN detected in parameter: {name}")

In [None]:
# 如果你在 Jupyter 或脚本里运行：
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

device = torch.device('cuda')
model = model.to(device)

train(num_epochs = 1)

encoder.encoder.0.mlp.layers.0.weight is on cuda:0
encoder.encoder.0.mlp.layers.0.bias is on cuda:0
encoder.encoder.0.mlp.layers.3.weight is on cuda:0
encoder.encoder.0.mlp.layers.3.bias is on cuda:0
encoder.encoder.0.attention.linear_q.weight is on cuda:0
encoder.encoder.0.attention.linear_q.bias is on cuda:0
encoder.encoder.0.attention.linear_k.weight is on cuda:0
encoder.encoder.0.attention.linear_k.bias is on cuda:0
encoder.encoder.0.attention.linear_v.weight is on cuda:0
encoder.encoder.0.attention.linear_v.bias is on cuda:0
encoder.encoder.0.attention.linear_o.weight is on cuda:0
encoder.encoder.0.attention.linear_o.bias is on cuda:0
encoder.encoder.0.ffn.dense1.weight is on cuda:0
encoder.encoder.0.ffn.dense1.bias is on cuda:0
encoder.encoder.0.ffn.dense2.weight is on cuda:0
encoder.encoder.0.ffn.dense2.bias is on cuda:0
encoder.encoder.0.add_norm1.layernorm.weight is on cuda:0
encoder.encoder.0.add_norm1.layernorm.bias is on cuda:0
encoder.encoder.0.add_norm2.layernorm.weight i

  datas = torch.stack([torch.tensor(d, dtype=torch.float32) for d in datas])
  seq = torch.tensor(seq, dtype=torch.long)


Epoch [1/1],Batches[1/23], Loss: 3.471
Epoch [1/1],Batches[2/23], Loss: 3.139
Epoch [1/1],Batches[3/23], Loss: 2.951
Epoch [1/1],Batches[4/23], Loss: 2.836
Epoch [1/1],Batches[5/23], Loss: 2.736
Epoch [1/1],Batches[6/23], Loss: 2.670
Epoch [1/1],Batches[7/23], Loss: 2.629
Epoch [1/1],Batches[8/23], Loss: 2.585
Epoch [1/1],Batches[9/23], Loss: 2.536
Epoch [1/1],Batches[10/23], Loss: 2.489
Epoch [1/1],Batches[11/23], Loss: 2.443
Epoch [1/1],Batches[12/23], Loss: 2.398
Epoch [1/1],Batches[13/23], Loss: 2.361
Epoch [1/1],Batches[14/23], Loss: 2.317
Epoch [1/1],Batches[15/23], Loss: 2.286
Epoch [1/1],Batches[16/23], Loss: 2.254
Epoch [1/1],Batches[17/23], Loss: 2.212
Epoch [1/1],Batches[18/23], Loss: 2.181
Epoch [1/1],Batches[19/23], Loss: 2.156
Epoch [1/1],Batches[20/23], Loss: 2.127
Epoch [1/1],Batches[21/23], Loss: 2.104
Epoch [1/1],Batches[22/23], Loss: 2.091
Epoch [1/1],Batches[23/23], Loss: 7.188


In [None]:
import matplotlib.pyplot as plt
import torch
x = torch.tensor([1,2,3,4])
y = torch.tensor([2,3,4,5])
plt.plot(x,y)
plt.show()