In [1]:
import random
import torch
import re
import os
import time
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from collections import defaultdict
from datasets import load_dataset
from tokenizers import Tokenizer

In [2]:
data_size = 300000
vocab_size = 150000

os.environ['DATA_SIZE'] = '300000'
os.environ['VOCAB_SIZE'] = '150000'
input_file = f"data/gigaword_input_{data_size}.txt"
summary_file = f"data/gigaword_summary_{data_size}.txt"
input_texts = []
summary_texts = []


def preprocess(text):
    text = text.replace("UNK", "<unk>")
    text = re.sub(r"#\S*#", "<num>", text)
    text = re.sub(r"#", "<num>", text)
    text = text.replace("-lrb-", "(").replace("-rrb-", ")").replace("-lsb-", "[").replace("-rsb-", "]").replace("-lcb-", "{").replace("-rcb-", "}")
    return text


if os.path.exists(input_file) and os.path.exists(summary_file):
    print(f"load txt：")
    print(f"- {input_file}")
    print(f"- {summary_file}")

    with open(input_file, "r", encoding="utf-8") as f_input:
        input_texts = [line.strip() for line in f_input]

    with open(summary_file, "r", encoding="utf-8") as f_summary:
        summary_texts = [line.strip() for line in f_summary]

else:
    print("loading dataset...")
    os.makedirs("data", exist_ok=True)
    dataset = load_dataset("gigaword", trust_remote_code=True)

    def filter_short(example):
        return len(example['document'].split()) < 35 and len(example['summary'].split()) < 15

    small_dataset = dataset['train'].filter(filter_short).select(range(data_size))

    with open(input_file, "w", encoding="utf-8") as f_input, \
            open(summary_file, "w", encoding="utf-8") as f_summary:

        for example in small_dataset:
            input_text = preprocess(example['document'])
            summary_text = preprocess(example['summary'])
            f_input.write(input_text + "\n")
            f_summary.write(summary_text + "\n")
            input_texts.append(input_text)
            summary_texts.append(summary_text)
    print(f"write txt: ")
    print(f"- {input_file}")
    print(f"- {summary_file}")

load txt：
- data/gigaword_input_300000.txt
- data/gigaword_summary_300000.txt


In [3]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

device: cuda


In [4]:
%run summary_bpe.py 

data_size: 300000
vocab_size: 150000



BPE 模型已保存到 models/bpe_model_300000.json
词汇表大小: 62148
词汇表已保存到 models/bpe_vocab_300000.txt
原文: australia 's current account deficit shrunk by a record <num> billion dollars ( <num> billion us ) in the june quarter due to soaring commodity prices , figures released monday showed .
分词结果: ['australia', "'", 's', 'current', 'account', 'deficit', 'shrunk', 'by', 'a', 'record', '<num>', 'billion', 'dollars', '(', '<num>', 'billion', 'us', ')', 'in', 'the', 'june', 'quarter', 'due', 'to', 'soaring', 'commodity', 'prices', ',', 'figures', 'released', 'monday', 'showed', '.']


In [5]:
loaded_model_path = f"models/bpe_model_{data_size}.json"


def load_bpe_model(model_path):
    tokenizer = Tokenizer.from_file(model_path)

    tokenizer.pad_token = "<pad>"
    tokenizer.pad_token_id = tokenizer.token_to_id("<pad>")

    tokenizer.bos_token = "<bos>"
    tokenizer.bos_token_id = tokenizer.token_to_id("<bos>")

    tokenizer.eos_token = "<eos>"
    tokenizer.eos_token_id = tokenizer.token_to_id("<eos>")

    tokenizer.unk_token = "<unk>"
    tokenizer.unk_token_id = tokenizer.token_to_id("<unk>")

    return tokenizer


tokenizer = load_bpe_model(loaded_model_path)

In [6]:


class TextDataset(Dataset):
    def __init__(self, src_texts, tgt_texts, tokenizer):
        self.src_texts = src_texts
        self.tgt_texts = tgt_texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.src_texts)

    def __getitem__(self, idx):
        src_text = self.src_texts[idx]
        tgt_text = self.tgt_texts[idx]
        src_tokens = [self.tokenizer.bos_token_id] + self.tokenizer.encode(src_text).ids + [self.tokenizer.eos_token_id]
        tgt_tokens = [self.tokenizer.bos_token_id] + self.tokenizer.encode(tgt_text).ids + [self.tokenizer.eos_token_id]
        return torch.tensor(src_tokens, dtype=torch.long), torch.tensor(tgt_tokens, dtype=torch.long)


dataset = TextDataset(input_texts, summary_texts, tokenizer)


def collate_fn(batch):
    src_tokens = [item[0] for item in batch]
    tgt_tokens = [item[1] for item in batch]
    src_tokens = pad_sequence(src_tokens, batch_first=True, padding_value=tokenizer.pad_token_id)
    tgt_tokens = pad_sequence(tgt_tokens, batch_first=True, padding_value=tokenizer.pad_token_id)
    return src_tokens, tgt_tokens


train_size = int(len(dataset) * 0.9)
train_dataset, test_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

In [7]:
class Head(nn.Module):
    def __init__(self, n_embd, head_size, masking=True):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size)
        self.key = nn.Linear(n_embd, head_size)
        self.value = nn.Linear(n_embd, head_size)
        self.masking = masking
        if masking:
            self.register_buffer("mask", torch.tril(torch.ones(512, 512)))

    def forward(self, x, encoder_output=None, src_padding_mask=None):
        # x: (B, T, n_embd)
        # encoder_output: (B, T1, n_embd)
        # src_padding_mask: (B, T)

        B, T = x.shape[0], x.shape[1]

        q = self.query(x)  # (B, T, head_size)
        if encoder_output is not None:
            # CROSS-ATTENTION
            # encoder_output: (B, T1, n_embd)
            k = self.key(encoder_output)  # (B, T1, head_size)
            v = self.value(encoder_output)  # (B, T1, head_size)
        else:
            # SELF-ATTENTION
            k = self.key(x)
            v = self.value(x)
        attn = q @ k.transpose(-2, -1)  # (B, T, T) or (B, T, T1)
        attn = attn * (k.size(-1) ** -0.5)
        if self.masking:
            attn = attn.masked_fill(self.mask[:T, :T] == 0, float('-inf'))

        if src_padding_mask is not None:
            if encoder_output is not None:
                attn = attn.masked_fill(src_padding_mask.unsqueeze(1), float('-inf'))
            else:
                attn = attn.masked_fill(src_padding_mask.unsqueeze(1).expand(-1, T, -1), float('-inf'))

        attn_weights = F.softmax(attn, dim=-1)  # (B, T, T)

        out = attn_weights @ v  # (B, T, head_size)
        return out, attn_weights


class MultiHead(nn.Module):
    def __init__(self, n_embd, head_size, n_head, masking=True):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embd, head_size, masking) for _ in range(n_head)])
        self.fc = nn.Linear(n_head * head_size, n_embd)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, encoder_output=None, src_padding_mask=None, return_attn=False):
        attn_outs = [head(x, encoder_output, src_padding_mask) for head in self.heads]
        attns = [out[0] for out in attn_outs]
        out = torch.cat(attns, dim=-1)  # (B, T, n_head * head_size)
        out = self.dropout(self.fc(out))  # (B, T, n_embd)
        if return_attn:
            attn_weights = [out[1] for out in attn_outs]
            return out, attn_weights
        else:
            return out


class FeedForward(nn.Module):
    def __init__(self, n_embd, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class EncoderBlock(nn.Module):
    def __init__(self, n_embd, head_size, n_head, dropout=0.3):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.sa = MultiHead(n_embd, head_size, n_head, masking=False)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ff = FeedForward(n_embd, dropout)

    def forward(self, x, src_padding_mask=None, return_attn=False):
        if return_attn:
            sa_out, attn_weights = self.sa(self.ln1(x), src_padding_mask=src_padding_mask, return_attn=True)
        else:
            sa_out = self.sa(self.ln1(x), src_padding_mask=src_padding_mask)
        x = x + sa_out
        x = x + self.ff(self.ln2(x))
        return (x, attn_weights) if return_attn else x


class DecoderBlock(nn.Module):
    def __init__(self, n_embd, head_size, n_head, dropout=0.3):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.sa = MultiHead(n_embd, head_size, n_head, masking=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ca = MultiHead(n_embd, head_size, n_head, masking=False)
        self.ln3 = nn.LayerNorm(n_embd)
        self.ff = FeedForward(n_embd, dropout)

    def forward(self, x, encoder_output, src_padding_mask=None, return_attn=False):
        # x: (B, T, n_embd)
        # encoder_output: (B, T1, n_embd)
        x = x + self.sa(self.ln1(x))
        if return_attn:
            ca_out, attn_weights = self.ca(self.ln2(x), encoder_output, src_padding_mask, return_attn=True)
        else:
            ca_out = self.ca(self.ln2(x), encoder_output, src_padding_mask)
            attn_weights = None
        x = x + ca_out
        x = x + self.ff(self.ln3(x))
        return (x, attn_weights) if return_attn else x


class Model(nn.Module):
    def __init__(self, n_embd, n_head, n_layers, vocab_size, dropout=0.3):
        super().__init__()
        head_size = n_embd // n_head
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.pos_embedding = nn.Embedding(512, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(n_embd, head_size, n_head, dropout) for _ in range(n_layers)])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(n_embd, head_size, n_head, dropout) for _ in range(n_layers)])
        self.ln_in = nn.LayerNorm(n_embd)
        self.ln_tgt_in = nn.LayerNorm(n_embd)
        self.ln_f = nn.LayerNorm(n_embd)

    def encode(self, src, return_attn=False):
        B, T1 = src.shape
        src_padding_mask = (src == 0).to(src.device)  # (B, T1)

        # ENCODER
        tok_emb = self.token_embedding(src)  # (B, T1, n_embd)
        pos = torch.arange(0, T1, device=src.device).unsqueeze(0).repeat(B, 1)  # (B, T1)
        pos_emb = self.pos_embedding(pos)  # (B, T1, n_embd)
        x = self.ln_in(tok_emb + pos_emb)  # (B, T1, n_embd)
        x = self.dropout(x)

        self_attns = []
        # Pass padding mask to each encoder block
        for encoder_block in self.encoder_blocks:
            if return_attn:
                x, attn_weights = encoder_block(x, src_padding_mask, return_attn=True)
                self_attns.append(attn_weights)
            else:
                x = encoder_block(x, src_padding_mask)
        encoder_output = self.ln_f(x)  # (B, T1, n_embd)
        return (encoder_output, src_padding_mask, self_attns) if return_attn else (encoder_output, src_padding_mask)

    def decode(self, tgt, encoder_output, src_padding_mask, return_attn=False):
        B, T2 = tgt.shape
        tgt_tok_emb = self.token_embedding(tgt)  # (B, T2, n_embd)
        pos = torch.arange(0, T2, device=tgt.device).unsqueeze(0).repeat(B, 1)  # (B, T2)
        pos_emb = self.pos_embedding(pos)  # (B, T2, n_embd)
        x = self.ln_tgt_in(tgt_tok_emb + pos_emb)  # (B, T2, n_embd)
        x = self.dropout(x)

        cross_attns = []
        for decoder_block in self.decoder_blocks:
            if return_attn:
                x, attn_weights = decoder_block(x, encoder_output, src_padding_mask, return_attn=True)
                cross_attns.append(attn_weights)
            else:
                x = decoder_block(x, encoder_output, src_padding_mask)
        x = self.ln_f(x)  # (B, T2, n_embd)
        logits = F.linear(x, self.token_embedding.weight)  # (B, T2, tgt_vocab_size)
        return (logits, cross_attns) if return_attn else logits

    def forward(self, src, tgt):
        # src: (B, T1), tgt: (B, T2)
        B = src.shape[0]
        encoder_output, src_padding_mask = self.encode(src)
        decoder_target = torch.cat([tgt[:, 1:], torch.zeros(B, 1, dtype=tgt.dtype, device=tgt.device)], dim=1)
        logits = self.decode(tgt, encoder_output, src_padding_mask)  # (B, T2, tgt_vocab_size)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), decoder_target.view(-1), ignore_index=0)
        return logits, loss

In [8]:
class EarlyStopping:
    def __init__(self, patience=5, verbose=False, min_delta=0):
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.improvement_counter = 0

    def __call__(self, score):
        if self.best_score is None:
            self.best_score = score
        elif score >= self.best_score - self.min_delta:
            self.improvement_counter = 0
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.improvement_counter += 1
            if self.improvement_counter >= 3:
                self.counter = 0

In [9]:
n_embd = 128
n_head = 4
n_layers = 6
dropout = 0.3
lr = 1e-3
weight_decay = 0.01
vocab_size = tokenizer.get_vocab_size()
batch_size = 128


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print(f"train_length: {len(train_dataset)}, test_length: {len(train_dataset)}, batch_size: {batch_size}")
print(f"lr: {lr}, weight_decay: {weight_decay}")
print(f"n_embd: {n_embd}, n_head: {n_head}, n_layers: {n_layers}, src_vocab_size: {vocab_size}, dropout: {dropout}")

model = Model(n_embd, n_head, n_layers, vocab_size, dropout)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=1,
    min_lr=1e-6
)
scaler = GradScaler()

train_length: 270000, test_length: 270000, batch_size: 128
lr: 0.001, weight_decay: 0.01
n_embd: 128, n_head: 4, n_layers: 6, src_vocab_size: 62148, dropout: 0.3


In [10]:
def save_checkpoint(model, optimizer, scheduler, epoch, avg_train_loss, test_loss, path):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': avg_train_loss,
        'test_loss': test_loss,
        'model_config': {
            'n_embd': n_embd,
            'n_head': n_head,
            'n_layers': n_layers,
            'vocab_size': vocab_size,
            'dropout': dropout
        }
    }
    torch.save(checkpoint, path)


def load_checkpoint(path, device):
    checkpoint = torch.load(path, map_location=device)

    # 重建模型
    model = Model(**checkpoint['model_config'])
    model = model.to(device)

    # 将模型状态移动到正确的设备
    if device.type == 'cuda':
        checkpoint['model_state_dict'] = {k: v.to(device) for k, v in checkpoint['model_state_dict'].items()}

    model.load_state_dict(checkpoint['model_state_dict'])
    en_stoi = checkpoint['en_stoi']
    en_itos = checkpoint['en_itos']
    zh_stoi = checkpoint['zh_stoi']
    zh_itos = checkpoint['zh_itos']
    max_len = checkpoint['model_config']['max_len']

    return model, checkpoint, en_stoi, en_itos, zh_stoi, zh_itos, max_len


def print_progress(epoch, progress, loss, epoch_time):
    bar_length = 30  # Length of the progress bar
    filled_length = int(bar_length * progress // 100)
    bar = '━' * filled_length + '─' * (bar_length - filled_length)
    time_str = f"{epoch_time:.1f}s" if epoch_time < 60 else f"{epoch_time/60:.1f}min"
    print(f"epoch {epoch}, progress |{bar}| {progress:.0f}%, loss {loss.item():.4f}, time: {time_str}", end='\r')


def evaluate(model, data_loader, device):
    model.eval()
    total_loss = 0
    total_samples = 0
    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(device)
            y = y.to(device)
            with autocast(device.type):
                logits, loss = model(x, y)
            total_loss += loss.item() * x.size(0)
            total_samples += x.size(0)
    model.train()
    return total_loss / total_samples

In [11]:
epoch = 0
start_time = time.time()

In [12]:
max_epoch = 300
early_stopping = EarlyStopping(patience=5, min_delta=0.001, verbose=True)

while epoch < max_epoch:
    epoch_start_time = time.time()  # 记录每个epoch开始时间
    model.train()
    train_losses = []
    for i, batch in enumerate(train_loader):
        x = batch[0].to(device)
        y = batch[1].to(device)
        with autocast(device.type):
            out, loss = model(x, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        train_losses.append(loss.item())
        if i % (len(train_loader) // 20) == 0 and i > 0:
            epoch_time = time.time() - epoch_start_time
            print_progress(epoch, (i + 1) / len(train_loader) * 100, loss, epoch_time)

    test_loss = evaluate(model, test_loader, device)
    avg_train_loss = sum(train_losses) / len(train_losses)  # Calculate average training loss
    epoch_time = time.time() - epoch_start_time  # 计算每个epoch耗时
    total_time = time.time() - start_time  # 计算总耗时
    print(" " * 90, end="\r")

    # 更新学习率调度器
    scheduler.step(test_loss)

    # 打印训练信息，包含时间
    epoch_time_str = f"{epoch_time:.1f}s" if epoch_time < 60 else f"{epoch_time/60:.1f}min"
    total_time_str = f"{total_time:.1f}s" if total_time < 60 else f"{total_time/60:.1f}min"
    print(f"epoch {epoch}, avg train loss: {avg_train_loss:.4f}, test loss: {test_loss:.4f}, "
          f"lr: {scheduler.get_last_lr()[0]:.5f}, epoch time: {epoch_time_str}, total time: {total_time_str}")

    if (epoch + 1) % 10 == 0:
        save_path = f"checkpoint/summary_checkpoint_epoch_{epoch+1}.pt"
        save_checkpoint(model, optimizer, scheduler, epoch, avg_train_loss, test_loss, save_path)
        print(f"检查点已保存到: {save_path}")

    early_stopping(test_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered. Training has been stopped.")
        break
    epoch += 1

epoch 0, avg train loss: 8.1116, test loss: 5.6471, lr: 0.00100, epoch time: 3.6min, total time: 3.6min
epoch 1, avg train loss: 5.3735, test loss: 4.6496, lr: 0.00100, epoch time: 3.6min, total time: 7.1min
epoch 2, avg train loss: 4.6191, test loss: 4.1454, lr: 0.00100, epoch time: 3.7min, total time: 10.8min
epoch 3, avg train loss: 4.1612, test loss: 3.7641, lr: 0.00100, epoch time: 3.6min, total time: 14.5min
epoch 4, avg train loss: 3.8229, test loss: 3.5306, lr: 0.00100, epoch time: 3.6min, total time: 18.1min
epoch 5, avg train loss: 3.5821, test loss: 3.3709, lr: 0.00100, epoch time: 3.7min, total time: 21.8min
epoch 6, avg train loss: 3.3991, test loss: 3.2278, lr: 0.00100, epoch time: 3.6min, total time: 25.4min
epoch 7, avg train loss: 3.2533, test loss: 3.1117, lr: 0.00100, epoch time: 3.7min, total time: 29.1min
epoch 8, avg train loss: 3.1296, test loss: 3.0260, lr: 0.00100, epoch time: 3.7min, total time: 32.8min
epoch 9, avg train loss: 3.0249, test loss: 2.9426, lr: 0

In [13]:
save_path = "checkpoint_star/summary_checkpoint_300k_128.pt"
save_checkpoint(model, optimizer, scheduler, epoch, avg_train_loss, test_loss, save_path)
print(f"\n模型已保存到: {save_path}")


模型已保存到: checkpoint_star/summary_checkpoint_300k_128.pt


In [14]:
def top_k(model, src, sos_token, eos_token, max_len=50, device='cpu', top_k=5, temperature=0.5):
    model.eval()
    with torch.no_grad():
        src = src.to(device)
        B, T = src.shape[0], src.shape[1]
        encoder_output, src_padding_mask = model.encode(src)
        sequences = torch.ones(B, 1, dtype=torch.long, device=device) * sos_token

        for _ in range(max_len):
            logits = model.decode(sequences, encoder_output, src_padding_mask)
            logits = logits[:, -1, :] / temperature  # (B, vocab_size)
            topk_logits, top_k_indices = torch.topk(logits, k=min(top_k, logits.size(-1)))  # (B, k)
            probs = F.softmax(topk_logits, dim=-1)  # (B, k)
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = torch.gather(top_k_indices, 1, next_token_idx)
            sequences = torch.cat([sequences, next_token], dim=1)
            if (next_token == eos_token).all():
                break
    model.train()
    return sequences


def beam_search(model, src, sos_token, eos_token, beam_width=5, max_len=50, device='cpu'):
    model.eval()
    with torch.no_grad():
        src = src.to(device)
        B, T = src.shape[0], src.shape[1]
        encoder_output, src_padding_mask = model.encode(src)
        outputs = []
        for batch_idx in range(B):
            enc_output = encoder_output[batch_idx:batch_idx + 1]  # (1, T1, n_embd)
            src_mask = src_padding_mask[batch_idx:batch_idx + 1]  # (1, T1)
            beams = [(torch.ones(1, 1, dtype=torch.long, device=device) * sos_token, 0.0)]
            for _ in range(max_len):
                new_beams = []
                for seq, log_prob in beams:
                    if seq[0, -1] == eos_token:
                        new_beams.append((seq, log_prob))
                        continue
                    logits = model.decode(seq, enc_output, src_mask)
                    logits = logits[:, -1, :]  # (1, vocab_size)
                    probs = F.softmax(logits, dim=-1)
                    topk_probs, topk_indices = torch.topk(probs, beam_width)
                    for i in range(beam_width):
                        token_id = topk_indices[0, i].item()
                        log_prob = topk_probs[0, i].item()
                        new_seq = torch.cat([seq, torch.tensor([[token_id]], device=device)], dim=1)
                        new_log_prob = log_prob + log_prob
                        new_beams.append((new_seq, new_log_prob))
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
                if all(seq[0, -1].item() == eos_token for seq, _ in beams):
                    break
            best_seq = max(beams, key=lambda x: x[1])[0].squeeze(0)
            outputs.append(best_seq)
    model.train()
    return outputs


def translate(model, seqs, max_len, device, tokenizer, method='beam_search', **kwargs):
    B = len(seqs)
    inputs = torch.zeros(B, max_len, dtype=torch.long, device=device)
    for i, seq in enumerate(seqs):
        encoded = [tokenizer.bos_token_id] + tokenizer.encode(seq).ids + [tokenizer.eos_token_id]
        length = min(len(encoded), max_len)
        inputs[i, :length] = torch.tensor(encoded[:length], dtype=torch.long, device=device)
    if method == 'top_k_sample':
        outputs = top_k(model, inputs, tokenizer.bos_token_id, tokenizer.eos_token_id, max_len=max_len, device=device, **kwargs)
    elif method == 'beam_search':
        outputs = beam_search(model, inputs, tokenizer.bos_token_id, tokenizer.eos_token_id, max_len=max_len, device=device, **kwargs)
    results = []
    for i in range(B):
        result = tokenizer.decode(outputs[i].tolist())
        result = result.replace('<eos>', '').replace('|', '').replace('<sos>', '')
        results.append(result)
    return results

In [15]:
seqs = [
    "libyan leader moamer kadhafi monday promised wide political and economic reforms that he said would see ministries dismantled and oil revenues going directly into the pockets of the people.",
]
print("\n".join(seqs))


print("=" * 100)
args = [{
    "method": "top_k_sample",
    "top_k": 5,
    "temperature": 0.5
},
]


for arg in args:
    print(", ".join([f"{k}: {v}" for k, v in arg.items()]))
    print("\n".join(translate(model, seqs, 50, device, tokenizer, **arg)))

libyan leader moamer kadhafi monday promised wide political and economic reforms that he said would see ministries dismantled and oil revenues going directly into the pockets of the people.
method: top_k_sample, top_k: 5, temperature: 0.5
kadhafi promises wide political reforms
