# 预测：加载训练好的 Transformer 语言模型，检查 ROCStories 测试集批次的形状，并用贪心搜索 / 束搜索生成故事

In [1]:
import os

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from data import StoryDataset
# Build the tokenizer and vocabulary from the training stories.
# NOTE(review): presumably this mirrors the vocab used at training time; the checkpoint's
# embedding size depends on len(vocab), so both must be built the same way.
# os.path.join keeps the path portable: the original backslash literal only worked on
# Windows and relied on invalid escape sequences such as '\s' and '\R'.
train_dataset = StoryDataset(os.path.join('data', 'story_genaration_dataset', 'ROCStories_train.csv'))
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_dataset), specials=['<unk>'])
# Out-of-vocabulary tokens map to the '<unk>' id instead of raising.
vocab.set_default_index(vocab['<unk>'])

In [2]:
# 'basic_english' lower-cases the sentence and splits it into tokens.
example_sentence = 'I am a student'
tokenizer(example_sentence)

['i', 'am', 'a', 'student']

In [3]:
# vocab maps a list of tokens to their integer ids (text -> numbers).
example_tokens = ['here', 'is', 'an', 'example']
vocab(example_tokens)

[1644, 51, 44, 9570]

In [4]:
# lookup_tokens is the inverse mapping (numbers -> text).
# Derive the ids from vocab itself instead of hard-coding them: the concrete ids depend
# on how this particular vocab was built and would silently go stale if it changes.
vocab.lookup_tokens(vocab(['here', 'is', 'an', 'example']))

['here', 'is', 'an', 'example']

## 数据

In [5]:
import torch
from torch import nn, Tensor
from torch.utils.data import dataset
def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Turn an iterable of raw texts into one flat ``LongTensor`` of token ids.

    Each item is tokenized with the module-level ``tokenizer``, mapped to ids with the
    module-level ``vocab``, and items that yield no tokens are skipped. The remaining
    id tensors are concatenated in input order into a single tensor of shape ``[N]``.
    """
    pieces = []
    for text in raw_text_iter:
        ids = torch.tensor(vocab(tokenizer(text)), dtype=torch.long)
        if ids.numel() > 0:  # drop items that produced no tokens
            pieces.append(ids)
    return torch.cat(tuple(pieces))
# Load the held-out test stories and flatten them into a single 1-D tensor of token ids.
# os.path.join replaces the Windows-only backslash literal, which also contained
# invalid escape sequences ('\s', '\R').
test_iter = StoryDataset(os.path.join('data', 'story_genaration_dataset', 'ROCStories_test.csv'))
test_data = data_process(test_iter)
# Use the GPU when one is available; batchify moves its result onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def batchify(data: Tensor, bsz: int) -> Tensor:
    """Lay a flat token stream out as ``bsz`` parallel columns.

    Tokens at the end that would not fill a complete row are dropped. Column ``j``
    of the result is the ``j``-th contiguous chunk of ``data``.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, batch size

    Returns:
        Tensor of shape ``[N // bsz, bsz]``, moved to the global ``device``
    """
    n_rows = data.size(0) // bsz
    trimmed = data[: n_rows * bsz]
    columns = trimmed.view(bsz, n_rows).transpose(0, 1).contiguous()
    return columns.to(device)
# Split the flat test stream into this many parallel columns for evaluation.
eval_batch_size = 10
test_data = batchify(test_data, bsz=eval_batch_size)

In [6]:
from typing import Tuple

# Maximum length of one training/evaluation window (backprop-through-time length).
bptt = 35


def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """Cut one window of at most ``bptt`` rows out of the batchified ``source``.

    Args:
        source: Tensor, shape ``[full_seq_len, batch_size]``
        i: int, index of the first row of the window

    Returns:
        tuple (data, target): ``data`` holds rows ``i .. i+seq_len-1`` with shape
        ``[seq_len, batch_size]``; ``target`` holds the same rows shifted forward by one
        (the next token at every position), flattened to ``[seq_len * batch_size]``.
        ``seq_len`` equals ``bptt`` except near the end of ``source``, where it shrinks
        so the shifted target stays in range.
    """
    remaining = len(source) - 1 - i
    window = bptt if bptt < remaining else remaining
    start, stop = i, i + window
    inputs = source[start:stop]
    shifted = source[start + 1:stop + 1].reshape(-1)
    return inputs, shifted

## 模型

In [7]:
from model import TransformerModel

# Model hyperparameters. They must match the configuration the checkpoint was trained
# with, otherwise the saved state_dict will not fit this model.
ntokens = len(vocab)  # vocabulary size (also the size of the model's output dimension)
emsize = 200          # embedding dimension
d_hid = 200           # feed-forward hidden size inside each ``nn.TransformerEncoderLayer``
nlayers = 2           # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
nhead = 2             # number of heads in ``nn.MultiheadAttention``
dropout = 0.2         # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)
model = model.to(device)



In [8]:
# Restore the trained weights; map_location places every tensor on the current device.
# NOTE(review): torch.load unpickles the file — consider weights_only=True if the
# installed PyTorch supports it and the checkpoint is a plain state_dict.
checkpoint_state = torch.load("model_5_epoch.pth", map_location=device)
model.load_state_dict(checkpoint_state)

<All keys matched successfully>

## 评估

In [9]:
# A freshly constructed nn.Module is in training mode, so dropout (0.2) would be active
# in the forward pass below; switch to eval mode for deterministic evaluation.
model.eval()
with torch.no_grad():
    # Inspect the shapes of a single evaluation batch, then stop.
    for i in range(0, test_data.size(0) - 1, bptt):
        data, targets = get_batch(test_data, i)
        print(data.shape)         # [seq_len, batch_size]
        print(targets.shape)      # [seq_len * batch_size] (1-D)
        seq_len = data.size(0)
        print(seq_len)
        output = model(data)
        print(output.shape)       # [seq_len, batch_size, ntokens]
        output_flat = output.view(-1, ntokens)
        print(output_flat.shape)  # [seq_len * batch_size, ntokens]
        break

torch.Size([35, 10])
torch.Size([350])
35
torch.Size([35, 10, 22513])
torch.Size([350, 22513])


In [10]:
# Swap axes so each row is one sequence: [seq_len, batch_size] -> [batch_size, seq_len].
tmp_data = data.transpose(0, 1)

In [11]:
# Map the ids of the second sequence in the batch back to tokens and show them as text.
sample_tokens = vocab.lookup_tokens(tmp_data[1].tolist())
" ".join(sample_tokens)

"at her grandmother ' s house . a light shone in the window where her grandmother sat like a <unk> . her grandmother hugged beth tightly , relieved she had arrived safely . one day"

In [12]:
from decode import greedy_search, beam_search

# Decoding should run without dropout. model.eval() is idempotent, so this is harmless
# even if the decode helpers already switch the model to eval mode (not visible here).
model.eval()

print("greedy search")
prompt = "tommy was very close to his dad and loved him greatly"
greedy_search(
    model,
    prompt,
    tokenizer,
    vocab,
    max_len=100
)

print("beam search")
prompt = "a little fish bubble"
beam_search(
    model,
    prompt,
    tokenizer,
    vocab,
    max_len=100,
    beam_width=3
)


greedy search
11
prompt sentence: 
tommy was very close to his dad and loved him greatly


output sentence: 
tommy was very close to his dad and loved him greatly . he was in the water . he was in the water . he got out of the water . he got out of the water . he got out of the water . he got out of the water . he got out of the water . he got out of the water . he got out of the water . he got out of the water . he got out of the water . he got out of the water . he got out of the water
beam search
input sentence: 
a little fish bubble
output sentence: 
a little fish bubble bottle of sugar . he was in the middle of her . she was in her house . she got out of her house . she got out of her house . she got out of her house . she got out of her house . she got out of her house . she got out of her house . she got out of her house . she got out of her house . she got out of her house . she got out of her house . she got out of her . she was
