# Step0: Set up Environment

In [1]:
!pip install spacy --quiet
!python -m spacy download en_core_web_lg

import json
import re
import unicodedata
import spacy
from spacy.language import Language
from spacy.tokens import Doc
from spacy.tokenizer import Tokenizer
import torch
import torch.nn as nn
from collections import Counter
import torch
from torch.nn.utils.rnn import pad_sequence
import re
import spacy
from collections import Counter
from torch.utils.data import Dataset, DataLoader


[notice] A new release of pip is available: 24.2 -> 25.1
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting en-core-web-lg==3.8.0
  Using cached https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.8.0/en_core_web_lg-3.8.0-py3-none-any.whl (400.7 MB)
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_lg')


In [2]:
#====================================================================================
# 设置device
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"Using {DEVICE} device")

Using cuda device


# Step1: Pre-process data

#### 1. `processed_dataset`
- **Type:** `List[Dict]`
- **Purpose:** This is the core preprocessed data, where each entry represents one sentence from the dataset along with:
  - Variable placeholders (e.g., `$city`)
  - Variable-replaced versions
  - SQL queries (with and without placeholders)
  - Tagging labels

- **Use cases:**
  - Input for training models (e.g., slot tagging, SQL generation)
  - Dataset analysis and debugging
  - Evaluation of preprocessing or model predictions

---

#### 2. `variable_names`
- **Type:** `Set[str]`
- **Purpose:** Collects all unique variable names used across the dataset (e.g., `$city`, `$airline`).

- **Use cases:**
  - Build a vocabulary of variables
  - Label/tagging schemes for models
  - Analysis of variable frequency or diversity

---

#### 3. `sql_templates`
- **Type:** `Set[str]`
- **Purpose:** Stores all unique shortest SQL templates (with variables, not replaced), one per sample.

- **Use cases:**
  - Template classification models
  - SQL structure coverage analysis
  - Slot-filling frameworks (template → fill variables → complete SQL)

---

processed_dataset have below Columns
| Key                              | Description                                                                 |
|----------------------------------|-----------------------------------------------------------------------------|
| `text_with_vars`                | Original sentence text with variable placeholders (e.g., `$city`), cleaned |
| `text_with_vars_replaced`       | Sentence text with variables replaced by their actual values                |
| `sentence_var_tagging_labels`   | Token-wise labels matching `text_with_vars`, showing variable names or `"-"`|
| `vars_metadata`                 | Metadata of variables in the sample, with `"location"` field removed        |
| `variables`                     | Dictionary mapping variable placeholders to their actual values             |
| `sql_with_vars`                 | All SQL templates with variables (not replaced)                             |
| `shortest_sql_with_vars`        | The shortest SQL template with variables (not replaced)                     |
| `sql_with_vars_replaced`        | All SQL templates with variables replaced by actual values                  |
| `shortest_sql_with_vars_replaced`| The shortest SQL template with variables replaced by values                 |
| `query_split`                   | Indicates which query split this sample belongs to (e.g., `"train"`)       |
| `question_split`                | Indicates which question split this sentence belongs to (e.g., `"train"`)  |

In [3]:
def unicodeToAscii(s):
    # Convert a Unicode string 's' to plain ASCII.
    # This is done by first normalizing the string into its decomposed form using 'NFD',
    # which separates characters from their accents. Then, it filters out all nonspacing marks (Mn).
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Does not allow sql to have multiple space between each word
def normalize_whitespace(text):
    return re.sub(r'\s+', ' ', text).strip()

def preprocess_sentence(s:str) -> str:
    """
    Preprocesses sentence text for consistency
    """
    s = s.strip()
    s = normalize_whitespace(s)
    s = unicodeToAscii(s)
    s = s.strip()
    return s

def preprocess_dataset(dataset_loc = "atis.json",split_type=None, split=['dev', 'test', 'train']):
    
    # Read Dataset JSON file
    with open(dataset_loc) as f:
        dataset_json = json.load(f)

    processed_dataset = []
    variable_names = set()
    sql_templates = set()

    for sample in dataset_json:
        processed_sample = {}

        # Preprocess sql queries
        sql = [preprocess_sentence(query) for query in sample['sql']]

        # All valid sql queries for this examples sorted by their length
        sql = sorted(sql, key=len)

        # Adds shorests sql template to the set of sql templates
        sql_templates.add(sql[0])
        

        # Dictionary for variables/placeholders metadata
        variables_metadata = sample["variables"]

        # Delete 'location' key from variables dictionary
        # variable_type_mapping = {var['name']:var['type'] for var in variables_metadata}
        for var in variables_metadata:
            # Add current variable to set of all possible variable names
            variable_names.add(var.get("name"))
            var.pop('location', None)
        # query split for this sample
        query_split = sample['query-split']

        # Skips sample if its not the specified split_type or split
        if(split_type == "query"):
            if(query_split not in split):
                continue

        # Process each sentence
        for sentence in sample['sentences']:
            # Skips sample if its not the specified split_type or split
            if(split_type == "question"):
                if(sentence['question-split'] not in split):
                    continue
            # variables/placeholder mapping dictionary
            variables = sentence['variables']

            # Sentence text with variables/placeholders
            text_with_vars = preprocess_sentence(sentence['text'])

            # Replacing variables/placeholders in current sentence and sql query with their values from the variables dictionary
            text_with_vars_replaced = text_with_vars
            sql_with_vars_replaced = sql

            # Replace sentence and all sql variables with their values
            for var in variables:
                text_with_vars_replaced = text_with_vars_replaced.replace(var,variables[var])
                sql_with_vars_replaced = [query.replace(var,variables[var]) for query in sql_with_vars_replaced]

            # Taggingg expected output
            sentence_var_tagging_labels = []
            for word in text_with_vars.split():
                if(word in variables):
                    sentence_var_tagging_labels.append(word)
                else:
                    sentence_var_tagging_labels.append("-")

            # Appends preprocessed dictionary of current sentence to the processesed_dataset list
            processed_dataset.append({
                "text_with_vars":text_with_vars,
                "text_with_vars_replaced":text_with_vars_replaced,
                "sentence_var_tagging_labels":sentence_var_tagging_labels,
                "vars_metadata":variables_metadata,
                "variables":variables,
                "sql_with_vars": sql,
                "shortest_sql_with_vars":sql[0],
                "sql_with_vars_replaced": sql_with_vars_replaced,
                "shortest_sql_with_vars_replaced":sql_with_vars_replaced[0],
                "query_split":sample['query-split'],
                "question_split":sentence['question-split']
            })
    
    return processed_dataset,variable_names,sql_templates
    
    
train,var_names,sql_templates = preprocess_dataset(dataset_loc="atis.json", split_type="question", split=["train"])

# var_names contains all unique variable names in the samples processed
print("var dict: ", list(var_names))

# sql_templates contains all unique SQL templates in the samples processed
print("sql dict: ", list(sql_templates))

print("=================================================================")
sample = train[1]
print(f"Sentence with variables:\n{sample['text_with_vars']}\n")
print(f"Sentence with variables replaced by their values:\n{sample['text_with_vars_replaced']}\n")
print(f"Variable tagging expected output:\n{sample['sentence_var_tagging_labels']}\n")
print(f"Variables mapping in sentence:\n{sample['variables']}\n")
print(f"Metadat about variables in sentence:\n{sample['vars_metadata']}\n")
print(f"Shortest SQL query with variables:\n{sample['shortest_sql_with_vars']}\n")
print(f"Shortest SQL query with variables replaced by their values:\n{sample['shortest_sql_with_vars_replaced']}\n")
print(f"Which query split is this sample part of?:\n{sample['query_split']}\n")
print(f"Which question split is this sample part of?:\n{sample['question_split']}\n")


var dict:  ['round_trip_required0', 'meal_description0', 'city_name3', 'aircraft_code0', 'connections0', 'airport_name0', 'arrival_time0', 'year1', 'round_trip_cost0', 'airline_code0', 'transport_type0', 'basic_type0', 'manufacturer0', 'flight_number0', 'departure_time3', 'meal_code0', 'transport_type1', 'year0', 'days_code0', 'day_name4', 'airport_code1', 'propulsion0', 'meal_code1', 'state_name2', 'country_name0', 'day_name2', 'day_name0', 'class_type1', 'city_name2', 'day_name3', 'city_name0', 'departure_time1', 'restriction_code0', 'airport_code0', 'arrival_time1', 'stops0', 'airline_code2', 'month_number0', 'airline_name0', 'booking_class1', 'class_type0', 'economy0', 'fare_basis_code1', 'fare_basis_code0', 'airline_code1', 'booking_class0', 'departure_time2', 'state_code0', 'arrival_time2', 'state_name1', 'state_name0', 'day_number0', 'day_number1', 'day_name1', 'discounted0', 'flight_days0', 'one_direction_cost0', 'state_code1', 'departure_time0', 'flight_number1', 'city_name1',

# Step 2 Classification 

Now we have `train`, `var_names`, `sql_templates`

# Step 3 Generation

1. tokenize both input and output
2. pad both input and output to make sure they have equal length
3. build LSTM model for both encoder and decoder
4. train the model 

In [4]:
question_train_data, question_train_vars, question_train_sqls = preprocess_dataset(dataset_loc="atis.json", split_type="question", split=["train"])
question_test_data, question_test_vars, question_test_sqls = preprocess_dataset(dataset_loc="atis.json", split_type="question", split=["test"])
question_dev_data, question_dev_vars, question_dev_sqls = preprocess_dataset(dataset_loc="atis.json", split_type="question", split=["dev"])


query_train_data, query_train_vars, query_train_sqls = preprocess_dataset(dataset_loc="atis.json", split_type="query", split=["train"])
query_test_data, query_test_vars, query_test_sqls = preprocess_dataset(dataset_loc="atis.json", split_type="query", split=["test"])
query_dev_data, query_dev_vars, query_dev_sqls = preprocess_dataset(dataset_loc="atis.json", split_type="query", split=["dev"])

In [5]:
# 假设有以下参数
input_dim = 1000  # 输入词汇表大小
output_dim = 1000  # 输出词汇表大小
emb_dim = 64  # 嵌入维度
hid_dim = 128  # 隐藏层维度
n_layers = 2  # LSTM 层数
dropout = 0.5  # Dropout 比例
learning_rate = 5

## 3.1 Data loader

Set devices

In [6]:
#====================================================================================
# 设置device
device = 'cpu'
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    pass
    #device = "mps"
print(f"Using {device} device")

Using cuda device


set tokenizer

![image.png](https://course.spacy.io/pipeline.png)

In [7]:
#====================================================================================
# Define tokenizer
# 这会加载默认的 spaCy NLP 管道，包括 tokenizer、tagger、parser、ner 等
nlp = spacy.load('en_core_web_lg')

# Create a custom component to merge entities
@Language.component("entity_merger")
def entity_merger(doc):
    """
    Custom component of the spacy nlp pipeline which merges geographical location entity tokens into a single token
    For example: 'New York' would noramlly be split into 2 tokens 'New' and 'York' but this will combine into a single 'New York' token
    This is implemented because city_name type variables could have the value 'New York' and for effective tagging we aim to keep the tokenisation scheme consistent to the dataset
    """
    # Iterate over the entities in reverse order (to avoid index issues when merging)
    with doc.retokenize() as retokenizer:
        for ent in reversed(list(doc.ents)):
            # Merge the entity tokens into one token
            if(ent.label_ in ["GPE"]):
                attrs = {"LEMMA": ent.text}
                retokenizer.merge(ent, attrs=attrs)
    return doc

# Add the custom component after NER
nlp.add_pipe("entity_merger", after="ner")


print(list(nlp("Let's meet in New York, at 10:00 a.m.!")))
# 替换默认的 tokenizer 为自定义的只按空格切分的 tokenizer
def whitespace_tokenizer(nlp):
    # Create a custom tokenizer that splits only on whitespace
    return Tokenizer(nlp.vocab, token_match=re.compile(r'\S+').match)

nlp.tokenizer = whitespace_tokenizer(nlp)

print(list(nlp("Let's meet in New York, at 10:00 a.m.!")))

[Let, 's, meet, in, New York, ,, at, 10:00, a.m., !]
[Let's, meet, in, New, York,, at, 10:00, a.m.!]


Define Dataloader

In [8]:

#====================================================================================
'''
We need to build vocab for input and output respectively!!! Since natural sentence is 
different from sql query!!!! (if both input and output have similar type such as use another 
way to explain a sentence, then we need not to use 2 vocab)
'''
# Prepare Vocab
def build_vocab(dataset, nlp, special_tokens=["<pad>", "<sos>", "<eos>", "<unk>"]):
    input_counter = Counter()
    output_counter = Counter()
    
    for sample in dataset:
        # 输入部分统计
        src_text = sample["text_with_vars_replaced"]
        src_tokens = [token.text for token in nlp(src_text)]
        input_counter.update(src_tokens)
        
        # 输出部分统计
        trg_text = sample["shortest_sql_with_vars_replaced"]
        trg_tokens = [token.text for token in nlp(trg_text)]
        output_counter.update(trg_tokens)
    
    input_vocab = {token: idx for idx, token in enumerate(special_tokens + list(input_counter.keys()))}
    output_vocab = {token: idx for idx, token in enumerate(special_tokens + list(output_counter.keys()))}
    
    return input_vocab, output_vocab

input_vocab, output_vocab = build_vocab(train, nlp)

#====================================================================================
# Define Dataloader
class generationDataset(Dataset):
    def __init__(self, data=train, input_vocab=input_vocab, output_vocab=output_vocab):
        self.data = data
        self.input_vocab = input_vocab
        self.output_vocab = output_vocab
        self.nlp = nlp

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src_text = self.data[idx]["text_with_vars_replaced"]
        trg_text = self.data[idx]["shortest_sql_with_vars_replaced"]

        # Tokenize + <sos> <eos>
        src_tokens = ['<sos>'] + [token.text for token in self.nlp(src_text)] + ['<eos>']
        trg_tokens = ['<sos>'] + [token.text for token in self.nlp(trg_text)] + ['<eos>']

        # convert token to id in vocab
        src_ids = [self.input_vocab.get(token, self.input_vocab['<unk>']) for token in src_tokens]
        trg_ids = [self.output_vocab.get(token, self.output_vocab['<unk>']) for token in trg_tokens]

        return torch.tensor(src_ids), torch.tensor(trg_ids)

# 3. we need to pad both input and output
def get_collate_fn(input_vocab=input_vocab, output_vocab=output_vocab, device=device):
    def collate_fn(batch):
        src_batch, trg_batch = zip(*batch)
        
        # 使用 pad_sequence 对输入和输出进行填充
        src_batch = pad_sequence(src_batch, padding_value=input_vocab['<pad>'], batch_first=True)
        trg_batch = pad_sequence(trg_batch, padding_value=output_vocab['<pad>'], batch_first=True)
        
        # 将填充后的数据移到设备上
        src_batch = src_batch.to(device)
        trg_batch = trg_batch.to(device)
        
        return src_batch, trg_batch
    
    return collate_fn

# 4. using dataloader to get many batch of data for training! 
def get_dataloader(data, input_vocab = input_vocab, output_vocab = output_vocab):
    dataset = generationDataset(data, input_vocab, output_vocab)
    collate_fn = get_collate_fn(input_vocab, output_vocab)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    return dataset, dataloader

question_train_data_iter, question_train_data_dataloader = get_dataloader(question_train_data)
question_test_data_iter, question_test_data_dataloader = get_dataloader(question_test_data)
question_dev_data_iter, question_dev_data_dataloader = get_dataloader(question_dev_data)

query_train_data_iter, query_train_data_dataloader = get_dataloader(query_train_data)
query_test_data_iter, query_test_data_dataloader = get_dataloader(query_test_data)
query_dev_data_iter, query_dev_data_dataloader = get_dataloader(query_dev_data)

training_text_batch, training_label_batch = next(iter(question_train_data_iter))
print("First text:")
print(training_text_batch[0])

First text:
tensor(1)


## Step 3.2 LSTM

LSTM model

In [9]:
class ContextualEmbedding(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_len=512):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=input_vocab['<pad>'])
        self.position_embedding = nn.Embedding(max_len, embedding_dim)

    def forward(self, x):
        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
        x = self.token_embedding(x) + self.position_embedding(positions)
        return x
embedder = ContextualEmbedding(len(input_vocab), embedding_dim=emb_dim).to(device)

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers=1, dropout=0.5):
        super().__init__()
        # self.embedding = nn.Embedding(input_dim, emb_dim)
        self.embedding = embedder
        self.lstm = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True, bidirectional=True)

    def forward(self, src):
        embedded = self.embedding(src)  # [batch size, src len, emb dim]

        # 双向LSTM的输出
        outputs, (hidden, cell) = self.lstm(embedded)

        # 改这里！拼接而不是mean
        hidden = hidden.view(hidden.size(0) // 2, 2, hidden.size(1), hidden.size(2))
        hidden = torch.cat((hidden[:,0,:,:], hidden[:,1,:,:]), dim=2)

        cell = cell.view(cell.size(0) // 2, 2, cell.size(1), cell.size(2))
        cell = torch.cat((cell[:,0,:,:], cell[:,1,:,:]), dim=2)

        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers=1, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim * 2, n_layers, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(hid_dim * 2, output_dim)


    def forward(self, input, hidden, cell):
        input = input.unsqueeze(1)  # [batch size, 1]
        embedded = self.embedding(input)  # [batch size, 1, emb dim]

        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc_out(output.squeeze(1))  # [batch size, output dim]
        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.fc_out.out_features
        
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
        
        hidden, cell = self.encoder(src)  # 返回的是双向LSTM处理后的hidden和cell
        
        input = trg[:, 0]  # first input to the decoder is the <sos> token
        
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[:, t, :] = output
            
            top1 = output.argmax(1)
            input = trg[:, t] if torch.rand(1).item() < teacher_forcing_ratio else top1
        
        return outputs

train

In [10]:
def train_loop(model, dataloader, optimizer, loss_function, clip=0.5):
    '''
    clip=0.1：梯度裁剪更严格，训练过程可能会更加稳定，但可能导致训练速度较慢，尤其在较为平稳的模型中。

    clip=1：梯度裁剪较宽松，允许较大的梯度更新，训练可能会较快，但如果模型梯度爆炸问题较为严重，可能会导致不稳定。
    '''


    model.train()  # 让模型进入train模式
    epoch_loss = 0
    
    for src, trg in dataloader:
        # src: [batch size, src len]
        # trg: [batch size, trg len]

        optimizer.zero_grad()

        output = model(src, trg)  
        # output: [batch size, trg len, output vocab size]

        output_dim = output.shape[-1]
        
        # 我们要 reshape 成 [batch_size * (trg_len-1), output_dim]
        output = output[:, 1:, :].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)

        # 注意，criterion已经在ignore pad了
        loss = loss_function(output, trg)
        
        loss.backward()
        
        # 梯度裁剪，防止爆炸
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(dataloader)

Train

In [11]:


# 创建模型
encoder = Encoder(input_dim, emb_dim, hid_dim, n_layers, dropout)
decoder = Decoder(output_dim, emb_dim, hid_dim, n_layers, dropout)

# 创建 Seq2Seq 模型
model = Seq2Seq(encoder, decoder, device).to(device)
PAD_IDX = output_vocab['<pad>']
loss_function = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # 如果有填充符，使用 ignore_index 忽略该索引
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 超参数
num_epochs = 10
clip_value = 0.5

# 训练循环
for epoch in range(num_epochs):
    # Train on question dataset
    question_epoch_loss = train_loop(model, question_train_data_dataloader, optimizer, loss_function, clip=clip_value)
    
    # Train on query dataset
    query_epoch_loss = train_loop(model, query_train_data_dataloader, optimizer, loss_function, clip=clip_value)
    
    # Combine the losses (if needed)
    total_epoch_loss = question_epoch_loss + query_epoch_loss
    
    # Optionally, you could log the individual losses too
    print(f"Epoch {epoch+1}/{num_epochs}, Question Loss: {question_epoch_loss}, Query Loss: {query_epoch_loss}, Total Loss: {total_epoch_loss}")

Epoch 1/10, Question Loss: 3.294944909565589, Query Loss: 1.5197538619009865, Total Loss: 4.814698771466576
Epoch 2/10, Question Loss: 1.0410590316442883, Query Loss: 0.9352714644362595, Total Loss: 1.9763304960805477
Epoch 3/10, Question Loss: 0.8212527367560303, Query Loss: 0.7426077954421769, Total Loss: 1.5638605321982073
Epoch 4/10, Question Loss: 0.666707393439377, Query Loss: 0.6730466781073059, Total Loss: 1.339754071546683
Epoch 5/10, Question Loss: 0.6521341241896152, Query Loss: 0.6241890284794056, Total Loss: 1.2763231526690209
Epoch 6/10, Question Loss: 0.5919445847325465, Query Loss: 0.5913747507610069, Total Loss: 1.1833193354935534
Epoch 7/10, Question Loss: 0.5474447404198787, Query Loss: 0.5498775044419119, Total Loss: 1.0973222448617905
Epoch 8/10, Question Loss: 0.5178405905471128, Query Loss: 0.5209629000812177, Total Loss: 1.0388034906283306
Epoch 9/10, Question Loss: 0.4869489064987968, Query Loss: 0.4980111793177017, Total Loss: 0.9849600858164985
Epoch 10/10, Q

In [None]:
def test_loop(dataloader, model, pad_idx, device):
    model.eval()
    total_correct = 0
    total_tokens = 0

    with torch.no_grad():
        '''
        idx 就是 batch 的编号（如 0、1、2...）

        src 是输入张量：[[1, 2], [3, 4]]

        trg 是目标标签：[0, 1]
        '''
        for idx, (src, trg) in enumerate(dataloader):
            src, trg = src.to(device), trg.to(device)  # 将数据移动到cuda
            
            # 模型预测
            output = model(src, trg[:, :-1])  # 输出: [batch_size, trg_len, output_dim]
            
            # 获取预测结果
            pred_tokens = output.argmax(dim=-1)  # shape: [batch_size, trg_len]
            
            # 去掉填充符部分
            non_pad_mask = (trg[:, 1:] != pad_idx)  # 只去除 <pad>，并排除 <sos> 和 <eos> (trg[:, 1:])
            
            # 计算正确的预测
            correct_predictions = (pred_tokens == trg[:, 1:]) & non_pad_mask  # [batch_size, seq_len]
            
            correct_count = correct_predictions.sum().item()  # 总正确数
            total_count = non_pad_mask.sum().item()  # 总有效词数（不包括pad部分）
            
            total_correct += correct_count
            total_tokens += total_count
    
    # 返回准确率
    accuracy = total_correct / total_tokens if total_tokens > 0 else 0
    return accuracy



accuracy = test_loop(query_test_data_dataloader, model, PAD_IDX, device)
print(f"Test Accuracy: {accuracy:.4f}")

accuracy = test_loop(query_train_data_dataloader, model, PAD_IDX, device)
print(f"Train Accuracy: {accuracy:.4f}")

Test Accuracy: 0.0231
Train Accuracy: 0.0251


# 3.2 LSTM with self-attention

In [13]:
class ContextualEmbedding(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_len=512):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=input_vocab['<pad>'])
        self.position_embedding = nn.Embedding(max_len, embedding_dim)

    def forward(self, x):
        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
        x = self.token_embedding(x) + self.position_embedding(positions)
        return x
embedder = ContextualEmbedding(len(input_vocab), embedding_dim=emb_dim).to(device)

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers=1, dropout=0.5):
        super().__init__()
        # self.embedding = nn.Embedding(input_dim, emb_dim)
        self.embedding = embedder
        self.lstm = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True, bidirectional=True)

    def forward(self, src):
        embedded = self.embedding(src)  # [batch size, src len, emb dim]

        # 双向LSTM的输出
        outputs, (hidden, cell) = self.lstm(embedded)

        # 改这里！拼接而不是mean
        hidden = hidden.view(hidden.size(0) // 2, 2, hidden.size(1), hidden.size(2))
        hidden = torch.cat((hidden[:,0,:,:], hidden[:,1,:,:]), dim=2)

        cell = cell.view(cell.size(0) // 2, 2, cell.size(1), cell.size(2))
        cell = torch.cat((cell[:,0,:,:], cell[:,1,:,:]), dim=2)

        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers=1, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim * 2, n_layers, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(hid_dim * 2, output_dim)


    def forward(self, input, hidden, cell):
        input = input.unsqueeze(1)  # [batch size, 1]
        embedded = self.embedding(input)  # [batch size, 1, emb dim]

        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc_out(output.squeeze(1))  # [batch size, output dim]
        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.fc_out.out_features
        
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
        
        hidden, cell = self.encoder(src)  # 返回的是双向LSTM处理后的hidden和cell
        
        input = trg[:, 0]  # first input to the decoder is the <sos> token
        
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[:, t, :] = output
            
            top1 = output.argmax(1)
            input = trg[:, t] if torch.rand(1).item() < teacher_forcing_ratio else top1
        
        return outputs
    




# 创建模型
encoder = Encoder(input_dim, emb_dim, hid_dim, n_layers, dropout)
decoder = Decoder(output_dim, emb_dim, hid_dim, n_layers, dropout)

# 创建 Seq2Seq 模型
model = Seq2Seq(encoder, decoder, device).to(device)
PAD_IDX = output_vocab['<pad>']
loss_function = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # 如果有填充符，使用 ignore_index 忽略该索引
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 超参数
num_epochs = 10
clip_value = 0.5

# 训练循环
for epoch in range(num_epochs):
    epoch_loss = train_loop(model, train_dataloader, optimizer, loss_function, clip=clip_value)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

accuracy = test_loop(test_dataloader, model, PAD_IDX, device)
print(f"Test Accuracy: {accuracy:.4f}")

accuracy = test_loop(train_dataloader, model, PAD_IDX, device)
print(f"Train Accuracy: {accuracy:.4f}")


NameError: name 'train_dataloader' is not defined