In [1]:
%load_ext autoreload
%autoreload 2

import os
import zipfile

if not os.path.exists("datasets/"):
    with zipfile.ZipFile("Multi30K.zip", "r") as zip_ref:
        zip_ref.extractall()

In [2]:
from transformers import GPT2Tokenizer

# 1. 初始化GPT2 Tokenizer
en_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
de_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# 添加特殊标记
special_tokens = {"bos_token": "<sos>", "eos_token": "<eos>", "pad_token": "<pad>"}
en_tokenizer.add_special_tokens(special_tokens)
de_tokenizer.add_special_tokens(special_tokens)


3

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl

# 1. 自定义数据集
class Multi30KDataset(Dataset):
    def __init__(self, en_path, de_path, en_tokenizer, de_tokenizer):
        self.en_sentences = self._read_file(en_path)
        self.de_sentences = self._read_file(de_path)
        self.en_tokenizer = en_tokenizer
        self.de_tokenizer = de_tokenizer
        assert len(self.en_sentences) == len(self.de_sentences), "数据不匹配！"

    def _read_file(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def __len__(self):
        return len(self.en_sentences)

    def __getitem__(self, idx):
        en_encoded = self.en_tokenizer(
            self.en_sentences[idx],
            return_tensors="pt",
            padding=False,
            truncation=True,
            add_special_tokens=True,
        )["input_ids"].squeeze(0)

        de_encoded = self.de_tokenizer(
            self.de_sentences[idx],
            return_tensors="pt",
            padding=False,
            truncation=True,
            add_special_tokens=True,
        )["input_ids"].squeeze(0)

        return en_encoded, de_encoded

# 2. 定义collate_fn
def collate_fn(batch):
    en_batch, de_batch = zip(*batch)
    en_batch = pad_sequence(en_batch, batch_first=True, padding_value=en_tokenizer.pad_token_id)
    de_batch = pad_sequence(de_batch, batch_first=True, padding_value=de_tokenizer.pad_token_id)
    return en_batch, de_batch

# 3. 定义LightningDataModule
class Multi30KDataModule(pl.LightningDataModule):
    def __init__(self, en_file_path, de_file_path, en_tokenizer, de_tokenizer, batch_size=32):
        super().__init__()
        self.en_file_path = en_file_path
        self.de_file_path = de_file_path
        self.en_tokenizer = en_tokenizer
        self.de_tokenizer = de_tokenizer
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.dataset = Multi30KDataset(self.en_file_path, self.de_file_path, self.en_tokenizer, self.de_tokenizer)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn)

    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)

    def test_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)

# 4. 初始化数据模块
en_file_path = 'datasets/train/train.en'
de_file_path = 'datasets/train/train.de'
batch_size = 32

data_module = Multi30KDataModule(en_file_path, de_file_path, en_tokenizer, de_tokenizer, batch_size)

# 5. 测试数据加载器
data_module.setup()
for en_batch, de_batch in data_module.train_dataloader():
    print("English batch shape:", en_batch.shape)
    print("German batch shape:", de_batch.shape)
    print("English batch example (tokens):", en_batch[0])
    print("German batch example (tokens):", de_batch[0])
    print("Decoded English:", en_tokenizer.decode(en_batch[0]))
    print("Decoded German:", de_tokenizer.decode(de_batch[0]))
    break

English batch shape: torch.Size([32, 33])
German batch shape: torch.Size([32, 55])
English batch example (tokens): tensor([   32,  2415,  2832,  1223,   284,   257,  4675, 26960, 12049,   287,
          257,  2330, 33323,    13, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259])
German batch example (tokens): tensor([   36,   500, 39313,    84,   302, 30830,   304,  7749, 10255,   304,
         7749,   356,    72, 39683,   368, 21039, 33255,   307,    74,   293,
          312,   316,   268, 15195, 39683,   268,    74,  9116,    77,   301,
         1754,  2123,  9776,    13, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259])
Decoded English: A woman hands something to a street performer dressed in a white gown.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad

In [None]:
class pl_transformer(pl.LightningDataModule)
    def __init__(self, en_tokenizer, de_tokenizer, batch_size=32):
        super().__init__()
        self.en_tokenizer = en_tokenizer
        self.de_tokenizer = de_tokenizer
        self.batch_size = batch_size