In [14]:
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

In [15]:
class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

In [16]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [17]:
!pip install torchdata

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m[33m
[0m

In [18]:
!pip install torchtext

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m[33m
[0m

In [None]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

print(tokenizer("happy"))
print(vocab(tokenizer("happy")))
#print(vocab.__dir__())
print(vocab.get_itos()[2269])




def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

# train_iter was "consumed" by the process of building the vocab,
# so we have to create it again
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)


#print(val_iter)
#print(val_data)
#print(val_data.shape)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Divides the data into bsz separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size

    Returns:
        Tensor of shape [N // bsz, bsz]
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

data, targets = get_batch(val_data, 0)
#print(data)
#print(data.shape)

for batch in data:
    print("=============one_batch==============")
    for word in batch:
        print(vocab.get_itos()[word.item()])
        #print(vocab.get_itos()[word.value])
        

['happy']
[2269]
happy
=
to
michigan
support
.
the
states
37
having
which
homarus
the
wolverines
to
=
following
to
m
been
happened
gammarus
north
men
sweep
scientology
resolution
mark
)
a
during
=
west
'
the
in
that
the
diameter
member
filming
homarus
require
s
area
germany
,
game
french
of
.
gammarus
considerable
basketball
in
=
having
'
burr
the
the
,
care
team
front
the
heard
s
<unk>
unit
depiction
known
.
=
of
church
all
fifteenth
is
,
of
as
=
the
the
of
available
anniversary
located
or
the
the
=
2011
marine
scientology
evidence
.
on
a
police
european
=
–
lines
has
regarding
it
a
<unk>
attempting
lobster
<unk>
12
four
been
the
included
<unk>
team
to
or
=
michigan
japanese
present
charges
several
frame
.
shut
common
=
wolverines
37
in
against
new
,
churchill
down
lobster
=
men
mm
germany
certain
features
driven
received
the
,
there
'
(
since
members
,
by
his
video
is
is
s
1
1970
of
such
the
<unk>
shoot
a
an
basketball
@
.
the
as
auxiliary
in
due


In [20]:
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

In [34]:
ntokens = len(vocab)  # size of vocabulary
print(ntokens)
emsize = 200  # embedding dimension
d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)

28782


In [22]:
import copy
import time

criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        batch_size = data.size(0)
        if batch_size != bptt:  # only on last batch
            src_mask = src_mask[:batch_size, :batch_size]
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            batch_size = data.size(0)
            if batch_size != bptt:
                src_mask = src_mask[:batch_size, :batch_size]
            output = model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            print(output_flat)
            total_loss += batch_size * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

In [23]:
best_val_loss = float('inf')
epochs = 3
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)

    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)

    scheduler.step()

| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 351.77 | loss  8.21 | ppl  3684.09
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch 311.36 | loss  6.87 | ppl   961.71
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch 365.90 | loss  6.43 | ppl   622.75
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch 306.73 | loss  6.31 | ppl   547.80
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch 288.00 | loss  6.19 | ppl   488.39
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch 285.23 | loss  6.16 | ppl   472.40
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch 300.86 | loss  6.11 | ppl   451.19
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch 295.29 | loss  6.11 | ppl   448.20
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch 308.40 | loss  6.03 | ppl   414.62
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch 306.10 | loss  6.02 | ppl   409.92
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch 300.47 | loss  5.90 | ppl   365.69
| epoch   

tensor([[ 9.0441,  7.8721,  9.0071,  ..., -1.2061, -0.4901, -1.0907],
        [ 7.9590,  9.0513, 10.1068,  ..., -0.5106, -0.5474, -0.3297],
        [ 8.3951,  7.9902, 10.4079,  ..., -0.3695, -0.3362, -0.6974],
        ...,
        [ 8.3871,  8.7763,  6.7018,  ..., -0.2745, -0.5108, -0.9051],
        [ 7.8994,  7.5017, 10.1545,  ..., -0.6444, -0.3623, -1.0937],
        [ 7.6547, 12.8126,  7.2609,  ..., -0.5255,  0.1270, -0.7957]])
tensor([[ 8.2475,  7.8343,  7.8361,  ..., -0.5800, -0.7531, -0.6876],
        [ 8.1779,  8.4122, 10.6488,  ..., -0.5823, -0.1134, -0.9675],
        [ 8.3795,  8.2661,  9.9217,  ..., -0.4899, -0.2010, -0.5478],
        ...,
        [ 7.8699,  7.3828, 11.0793,  ..., -0.6959, -0.6336, -1.6356],
        [ 7.5024,  9.0186,  7.3065,  ..., -0.8028, -0.0454, -1.4524],
        [ 8.1263,  8.9897, 10.7478,  ..., -0.7336, -0.9306, -0.7328]])
tensor([[ 8.8606,  9.3394,  5.3373,  ..., -0.1279, -0.0718, -0.1619],
        [ 8.1933,  8.0155, 11.2666,  ..., -0.4063, -0.3092, -0

tensor([[ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        [ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        [ 7.9180, 10.7207,  8.1727,  ..., -0.6390, -0.5071,  0.1651],
        ...,
        [ 7.7495, 11.8658,  7.4467,  ..., -0.9128, -0.3931, -0.5055],
        [ 8.0893,  7.0521, 10.2323,  ..., -0.3907, -0.4319, -0.8756],
        [ 9.0230,  9.4602,  5.1280,  ...,  0.1234,  0.0120, -0.2997]])
tensor([[ 8.4906,  7.5599,  5.5362,  ...,  0.2160, -0.4300, -0.1577],
        [ 8.9311,  8.3298,  6.1742,  ..., -0.1283, -0.3258, -0.8380],
        [ 8.5157,  8.8974,  5.8933,  ..., -0.3210, -0.6268, -1.0299],
        ...,
        [ 8.1089,  9.6637, 10.1282,  ..., -0.4648, -1.2474,  0.2136],
        [ 8.6667,  8.7720,  6.1332,  ..., -0.3595, -0.5065, -1.0561],
        [ 7.6635,  7.8454, 10.9856,  ..., -0.1057, -0.5197, -0.7292]])
tensor([[ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 8.5157,  8.8974,  5.8933,  ..., -0.3210, -0.6268, -1

tensor([[ 8.5141e+00,  7.6423e+00,  1.0159e+01,  ..., -5.7326e-01,
         -3.8742e-01, -8.5861e-01],
        [ 8.1831e+00,  6.9986e+00,  1.0508e+01,  ..., -1.9683e-01,
         -4.8336e-01, -7.2581e-01],
        [ 7.2742e+00,  8.3932e+00,  1.1408e+01,  ..., -1.4140e-03,
         -2.6981e-01, -2.6850e-01],
        ...,
        [ 7.4333e+00,  8.5279e+00,  1.2076e+01,  ..., -5.6143e-01,
         -6.5758e-01, -3.1834e-01],
        [ 7.9059e+00,  5.2170e+00,  6.5353e+00,  ..., -5.5915e-01,
         -8.0160e-01, -9.1682e-01],
        [ 7.9222e+00,  7.4608e+00,  1.0678e+01,  ..., -2.7735e-01,
         -4.6960e-01, -7.0636e-01]])
tensor([[ 8.0201,  4.9575,  4.4225,  ...,  0.0515, -0.3783, -0.7649],
        [ 8.4833,  7.9099, 10.6507,  ..., -0.5329, -0.3503, -0.7958],
        [ 7.5660,  9.4634,  7.4468,  ..., -0.6696, -0.4821, -1.1715],
        ...,
        [ 8.3521,  9.5011,  6.3840,  ..., -0.3150, -0.6672, -0.7105],
        [ 7.8324, 11.7607,  7.6762,  ..., -1.0164, -0.3266, -0.6039],
     

tensor([[ 8.2329, 11.8185,  7.6121,  ..., -0.7633,  0.0272, -0.4146],
        [ 8.5238, 10.3870,  5.0021,  ..., -0.5013, -0.0304, -0.4141],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        ...,
        [ 7.6922, 10.1662, 10.5869,  ..., -0.7378, -0.3224, -0.3721],
        [ 8.5371,  7.7877, 11.6978,  ...,  0.1234, -0.1999, -0.6010],
        [ 7.5419,  7.8816, 10.3237,  ...,  0.1769, -0.6837, -0.4048]])
tensor([[ 8.0223e+00,  5.1832e+00,  4.2697e+00,  ..., -4.8006e-01,
         -6.5990e-01, -1.2417e+00],
        [ 8.9311e+00,  8.3298e+00,  6.1742e+00,  ..., -1.2827e-01,
         -3.2575e-01, -8.3803e-01],
        [ 8.4861e+00,  8.1699e+00,  1.0443e+01,  ..., -7.4950e-01,
         -4.6097e-01, -1.1280e+00],
        ...,
        [ 8.4979e+00,  1.0714e+01,  4.8074e+00,  ..., -3.5079e-01,
          2.5338e-03, -3.5543e-01],
        [ 9.0090e+00,  9.0727e+00,  6.1525e+00,  ..., -3.9596e-01,
         -5.6308e-01, -8.8552e-01],
        [ 8.7330e+00,  1.0236e+01,  6.

tensor([[ 8.4867,  7.9738, 10.4052,  ..., -0.4981, -0.3948, -0.8597],
        [ 7.5721, 10.2669,  7.5497,  ..., -0.4555, -0.5241, -0.8610],
        [ 8.4972,  7.9209, 10.7986,  ..., -0.3328, -0.3022, -0.5499],
        ...,
        [ 8.0355, 10.5286,  9.1317,  ..., -0.7885, -0.4260, -0.6220],
        [ 8.7002,  7.8149, 10.7131,  ..., -0.5027, -0.3523, -0.7214],
        [ 7.9488,  7.8393, 11.2346,  ..., -0.3778, -0.3556, -0.8419]])
tensor([[ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        [ 7.8161, 11.7072,  6.7242,  ..., -0.1043, -0.0404, -0.5119],
        [ 8.5157,  8.8974,  5.8933,  ..., -0.3210, -0.6268, -1.0299],
        ...,
        [ 8.4920, 10.9878,  4.9708,  ..., -0.4833,  0.0462, -0.3944],
        [ 7.8124,  8.2275, 11.5541,  ..., -0.5072, -0.3115, -0.6171],
        [ 9.0974,  9.0409,  6.5246,  ..., -0.5185, -0.4732, -1.0200]])
tensor([[ 7.6460,  8.5193,  6.6363,  ..., -0.7612, -0.9533, -0.6792],
        [ 7.0352,  6.7968,  9.8755,  ..., -0.1351, -0.6024,  0

tensor([[ 8.7675, 10.2439,  6.3488,  ..., -1.1255, -0.3953, -0.8378],
        [ 8.4269, 11.3567,  6.9821,  ..., -0.6456, -0.2969,  0.0974],
        [ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        ...,
        [ 7.8892,  8.0341, 10.7920,  ..., -0.2475, -0.4793, -0.7313],
        [ 7.9060,  7.7885, 10.6502,  ..., -0.3093, -0.4699, -0.6496],
        [ 7.6548,  9.7913, 10.5833,  ..., -0.5141, -0.6058, -0.1567]])
tensor([[ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        [ 8.4269, 11.3567,  6.9821,  ..., -0.6456, -0.2969,  0.0974],
        [ 7.5458,  8.3061, 11.2873,  ..., -0.1903, -0.4167, -0.5216],
        ...,
        [ 6.8851,  7.1900, 10.3206,  ..., -0.6239, -0.7748, -0.5323],
        [ 8.0895,  6.8998,  9.3701,  ..., -0.1189, -0.4419, -0.8276],
        [ 8.7030,  7.7225, 10.6205,  ..., -0.4917, -0.4042, -0.7159]])
tensor([[ 8.1458,  8.3235, 10.5838,  ..., -0.5574, -0.2559, -0.6024],
        [ 8.1152,  8.3736, 10.6265,  ..., -0.3098, -0.4236, -0

tensor([[ 8.8323,  8.0090, 10.5576,  ..., -0.6468, -0.4299, -0.7518],
        [ 7.8718, 10.6375,  5.0276,  ..., -0.5355, -0.2800, -0.3429],
        [ 7.5871, 11.6443,  7.5768,  ..., -0.8121, -0.3764, -0.5202],
        ...,
        [ 7.9858,  8.5136, 10.4514,  ..., -0.4292, -0.7448, -0.4940],
        [ 7.9658,  8.6894, 10.8081,  ..., -0.3117, -0.1766, -0.2740],
        [ 7.6973,  8.1209, 10.8451,  ..., -0.6269, -0.3973, -0.7783]])
tensor([[ 7.7405,  7.1690,  8.8995,  ..., -0.4978, -0.5963, -0.1288],
        [ 7.9945,  8.0706, 10.6739,  ..., -0.3217, -0.2441, -0.7631],
        [ 7.3716,  9.6699, 10.0203,  ..., -0.6680, -0.2427, -0.9603],
        ...,
        [ 8.7381,  5.9644,  6.8234,  ..., -1.0523, -1.0522, -1.2394],
        [ 8.1308,  7.7610,  9.8616,  ..., -0.9048, -0.0967,  0.0861],
        [ 7.6435,  6.4689,  9.6687,  ..., -0.2379, -0.4243, -0.2201]])
tensor([[ 8.0375,  8.3674, 11.0822,  ..., -0.4218, -0.2826, -0.6021],
        [ 8.2789,  8.2291, 10.7341,  ..., -0.5033, -0.3412, -0

tensor([[ 7.3471,  6.3815,  7.4313,  ..., -0.5269, -0.7007, -0.4695],
        [ 7.6771,  8.0474, 11.0562,  ...,  0.0303, -0.3536, -0.0292],
        [ 8.5157,  8.8974,  5.8933,  ..., -0.3210, -0.6268, -1.0299],
        ...,
        [ 8.5521, 11.8059,  6.0037,  ..., -0.3453, -0.2664,  0.1137],
        [ 8.8969, 10.6214,  6.9402,  ..., -0.0895, -0.3911, -0.3887],
        [ 8.0003,  8.3686, 11.3764,  ..., -0.4555, -0.3461, -0.7342]])
tensor([[ 8.9311,  8.3298,  6.1742,  ..., -0.1283, -0.3258, -0.8380],
        [ 7.8446,  8.2980, 11.1126,  ..., -0.2428, -0.2952, -0.5431],
        [ 8.2533,  7.7008, 10.2191,  ..., -0.3942, -0.4017, -0.8361],
        ...,
        [ 7.8373,  7.1093, 10.1378,  ..., -0.0428, -0.5823, -0.8060],
        [ 8.6619,  7.4707, 10.7134,  ..., -0.3755, -0.4017, -0.8449],
        [ 9.0521,  6.8298,  8.9343,  ..., -0.9380, -0.9062, -1.2654]])
tensor([[ 9.1159e+00,  8.0938e+00,  1.0818e+01,  ..., -5.1492e-01,
         -2.3825e-01, -6.5141e-01],
        [ 7.7202e+00,  8.0753

tensor([[ 8.4269, 11.3567,  6.9821,  ..., -0.6456, -0.2969,  0.0974],
        [ 8.5157,  8.8974,  5.8933,  ..., -0.3210, -0.6268, -1.0299],
        [ 7.3555,  9.0878, 11.6652,  ..., -0.3684, -0.0914,  0.4803],
        ...,
        [ 8.3489,  5.3986,  4.1792,  ..., -0.4685, -0.7003, -1.0009],
        [ 7.9853,  7.4887, 11.2444,  ..., -0.2002, -0.4851, -0.9893],
        [ 7.8654,  8.1880, 11.3476,  ..., -0.4141, -0.3483, -0.7113]])
tensor([[ 8.7945,  9.5519,  5.9088,  ..., -0.1036, -0.5290, -0.6382],
        [ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 7.4450,  8.2213, 10.8076,  ..., -0.1453, -0.2801, -0.7813],
        ...,
        [ 8.0149,  6.9900, 10.0913,  ..., -0.3222, -0.4515, -0.9506],
        [ 8.4284,  5.6195,  5.5093,  ..., -0.7123, -0.7158, -1.2759],
        [ 8.9553,  9.0517,  5.1315,  ...,  0.1765, -0.3016, -0.6773]])
tensor([[ 8.2845,  8.0207, 11.0271,  ..., -0.3966, -0.3219, -0.5045],
        [ 8.3749,  8.1246, 11.0845,  ..., -0.4971, -0.3345, -0

tensor([[ 8.0201,  4.9575,  4.4225,  ...,  0.0515, -0.3783, -0.7649],
        [ 8.4269, 11.3567,  6.9821,  ..., -0.6456, -0.2969,  0.0974],
        [ 8.5157,  8.8974,  5.8933,  ..., -0.3210, -0.6268, -1.0299],
        ...,
        [ 8.2192,  5.2915,  3.8943,  ..., -0.3612, -0.6888, -0.9762],
        [ 8.1069,  7.1066, 10.2812,  ..., -0.3763, -0.4223, -1.0503],
        [ 8.0223,  7.9371, 10.6832,  ..., -0.3299, -0.3317, -0.6486]])
tensor([[ 7.4590,  8.0033, 10.6598,  ..., -0.0967, -0.4943, -0.7749],
        [ 8.1865, 11.0017,  5.1790,  ..., -0.7743, -0.1209, -0.4049],
        [ 8.3577,  7.8890, 11.7573,  ..., -0.3326, -0.2597, -0.5338],
        ...,
        [ 9.1153,  9.7980,  5.0082,  ...,  0.1005,  0.0275,  0.0136],
        [ 8.6348,  7.7948, 10.6583,  ..., -0.4355, -0.3469, -0.8244],
        [ 6.9566,  8.7942,  9.7507,  ..., -0.3104, -0.1135,  0.1337]])
tensor([[ 8.7365,  8.0397, 10.5768,  ..., -0.7036, -0.3635, -0.7632],
        [ 8.1149,  8.6358,  6.9103,  ..., -0.1462, -1.3729, -0

tensor([[ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        [ 8.5351, 10.4315,  4.7877,  ..., -0.4877, -0.1794, -0.4674],
        [ 9.1220,  7.6978,  9.4401,  ..., -0.9874, -0.3054, -0.8298],
        ...,
        [ 8.3786,  5.2013,  4.2915,  ..., -0.4923, -0.6926, -1.1120],
        [ 8.4407, 10.2948,  4.8654,  ..., -0.4105, -0.1697, -0.5444],
        [ 8.1291,  9.1350,  7.6499,  ..., -0.2189, -0.1094, -0.9015]])
tensor([[ 8.4883,  7.6535, 10.4898,  ..., -0.4739, -0.4456, -0.6762],
        [ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        ...,
        [ 8.0667, 12.0642,  7.0380,  ..., -0.8398, -0.4136, -0.2259],
        [ 8.2132,  5.3997,  5.2357,  ..., -0.3948, -0.5787, -1.2138],
        [ 7.9647,  7.9275, 10.1449,  ..., -0.3614, -0.4044, -0.7750]])
tensor([[ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0

tensor([[ 9.1210e+00,  8.0920e+00,  1.0431e+01,  ..., -5.6396e-01,
         -4.3291e-01, -6.7657e-01],
        [ 8.1160e+00,  5.0109e+00,  4.7711e+00,  ..., -9.0798e-02,
         -5.9495e-01, -9.4960e-01],
        [ 6.9766e+00,  6.6017e+00,  9.5282e+00,  ..., -9.3981e-03,
         -6.7669e-01,  3.1120e-01],
        ...,
        [ 8.1861e+00,  5.9623e+00,  6.9420e+00,  ..., -4.8736e-01,
         -1.1478e+00, -1.0462e+00],
        [ 8.3925e+00,  1.1635e+01,  6.4717e+00,  ..., -4.9135e-01,
         -3.2637e-01,  1.1491e-03],
        [ 7.6690e+00,  8.1774e+00,  9.9265e+00,  ..., -7.7714e-01,
         -3.5937e-01, -1.2669e+00]])
tensor([[ 7.3697, 12.2622,  8.1747,  ..., -0.8910, -0.0852, -0.2995],
        [ 7.3496,  9.9413,  6.7348,  ..., -0.6801, -0.1870, -0.9842],
        [ 8.8889, 10.2421,  6.8063,  ..., -1.3457, -0.2946, -1.0292],
        ...,
        [ 8.5369,  9.7423, 10.1543,  ..., -0.7931, -0.9608, -0.3395],
        [ 8.4329, 11.0198,  4.9240,  ..., -0.5348,  0.0363, -0.3632],
     

tensor([[ 8.0277,  9.5318,  9.8409,  ..., -0.7126, -0.3917, -0.6082],
        [ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 9.3429,  8.1051,  9.5621,  ..., -0.4099, -0.4756, -0.0593],
        ...,
        [ 7.9969,  7.5396, 10.0105,  ..., -0.5951, -0.6465, -1.0483],
        [ 7.9036,  7.3276, 11.3075,  ..., -0.2350, -0.4154, -0.7945],
        [ 6.6009,  6.8036,  9.3276,  ..., -0.2595, -0.6502, -0.4083]])
tensor([[ 8.0223e+00,  5.1832e+00,  4.2697e+00,  ..., -4.8006e-01,
         -6.5990e-01, -1.2417e+00],
        [ 7.6140e+00,  8.5070e+00,  9.5537e+00,  ..., -8.6305e-01,
         -1.2587e-01, -4.0785e-01],
        [ 7.3330e+00,  8.6317e+00,  8.9117e+00,  ..., -6.5688e-01,
         -1.0268e-02, -7.0544e-01],
        ...,
        [ 7.9682e+00,  7.4602e+00,  1.0665e+01,  ..., -1.6930e-01,
         -4.1394e-01, -6.4190e-01],
        [ 7.1071e+00,  8.5368e+00,  8.8978e+00,  ..., -4.9179e-01,
         -6.0503e-01, -7.6955e-01],
        [ 7.6583e+00,  1.1602e+01,  7.

tensor([[ 9.1210,  8.0920, 10.4306,  ..., -0.5640, -0.4329, -0.6766],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        ...,
        [ 8.4201, 11.7228,  6.4561,  ..., -0.4762, -0.4033,  0.3365],
        [ 7.9875,  7.6310, 11.6065,  ...,  0.2373, -0.3514, -0.3711],
        [ 8.9445,  7.0427,  7.1176,  ..., -0.1272, -0.1811, -0.7441]])
tensor([[ 7.9866,  9.2976,  9.5284,  ..., -0.4107, -0.3221, -0.6800],
        [ 8.1054,  8.7158, 10.3396,  ..., -0.5011, -0.4250, -0.7871],
        [ 8.6383,  8.0580, 11.4557,  ..., -0.3399, -0.2478, -0.4068],
        ...,
        [ 8.0111,  8.2582,  7.0258,  ...,  0.2724, -0.7672, -0.5629],
        [ 7.9866,  8.3809,  9.8608,  ..., -0.5935, -0.5794, -0.6873],
        [ 7.9875, 10.7933,  5.0960,  ..., -0.5347, -0.1563, -0.5575]])
tensor([[ 7.4572,  8.7726,  8.0810,  ..., -0.8244, -0.2845, -1.6375],
        [ 7.7159,  8.8006, 10.8500,  ..., -0.3517, -0.6155, -0

tensor([[ 8.8889, 10.2421,  6.8063,  ..., -1.3457, -0.2946, -1.0292],
        [ 8.3863,  7.8928, 10.5006,  ..., -0.6678, -0.3796, -0.9082],
        [ 8.3215, 10.4249,  9.4953,  ..., -1.0465, -0.5851, -0.6009],
        ...,
        [ 6.8097,  7.1700, 11.2053,  ..., -0.3977, -0.6045, -0.4585],
        [ 8.9107,  9.0161,  6.2887,  ..., -0.3790, -0.4977, -0.8376],
        [ 6.9414,  5.5317,  7.9443,  ..., -0.2648, -1.1189, -0.8856]])
tensor([[ 7.9220,  8.3828, 11.1456,  ..., -0.3667, -0.3494, -0.5770],
        [ 7.9829,  8.2261, 11.0200,  ..., -0.3037, -0.3083, -0.5014],
        [ 7.4730, 10.3249,  9.6563,  ..., -0.3624, -0.2646,  0.0816],
        ...,
        [ 8.3578, 11.7400,  6.4053,  ..., -0.5315, -0.3803,  0.1462],
        [ 8.6901,  7.2996,  8.8592,  ..., -0.2918, -0.3994, -0.9208],
        [ 8.4517,  8.2026,  7.1293,  ...,  0.0724, -1.0016, -0.8174]])
tensor([[ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        [ 7.1996,  6.4933,  9.8684,  ..., -0.0174, -0.6372,  0

tensor([[ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        [ 8.4007,  7.9026, 10.2480,  ..., -0.5090, -0.4258, -0.8645],
        [ 7.6093, 10.6717,  9.3884,  ..., -0.6243, -0.3937,  0.0138],
        ...,
        [ 8.7086,  5.9261,  4.9904,  ..., -0.4203, -1.0674, -0.8413],
        [ 9.5838,  7.7535, 10.7437,  ..., -0.6215, -0.2357, -0.4584],
        [ 7.6576,  9.0767, 10.7143,  ..., -0.5285, -0.2540, -0.7624]])
tensor([[ 8.3107,  8.2000, 10.7861,  ..., -0.5017, -0.2469, -0.6507],
        [ 8.2622,  8.2729,  9.7958,  ..., -0.6429, -0.2573, -0.7323],
        [ 9.2537,  8.0822,  7.4733,  ..., -1.0969, -0.4774,  0.1989],
        ...,
        [ 8.2695,  8.7571,  5.5480,  ..., -0.2320, -0.5671, -0.9090],
        [ 9.1110,  8.1039,  4.9721,  ...,  0.5855, -0.3028, -0.5159],
        [ 8.6236, 10.5762,  6.6863,  ..., -1.0977, -0.0266, -1.1048]])
tensor([[ 7.7742,  8.2715, 11.0026,  ..., -0.6249, -0.4308, -0.5594],
        [ 7.9202, 10.8183,  8.4496,  ..., -0.4534, -0.3539, -0

tensor([[ 7.3835,  7.5072,  9.4049,  ..., -0.1545,  0.0883,  0.3152],
        [ 7.7793,  8.4978, 10.9885,  ..., -0.2200, -0.3273, -0.4928],
        [ 7.8613,  8.4728,  7.9179,  ..., -0.3422, -0.3411, -1.3934],
        ...,
        [ 7.9092,  7.0026,  9.9534,  ..., -0.3775, -0.5419, -0.8228],
        [ 7.8044,  7.0101, 11.3989,  ...,  0.0862, -0.4795, -0.6078],
        [ 8.6094,  9.7099,  7.2004,  ..., -0.0124, -0.1432, -0.5362]])
tensor([[ 8.4906e+00,  7.5599e+00,  5.5362e+00,  ...,  2.1596e-01,
         -4.2996e-01, -1.5767e-01],
        [ 8.3016e+00,  8.2625e+00,  1.0950e+01,  ..., -4.1427e-01,
         -3.1651e-01, -5.9506e-01],
        [ 7.0730e+00,  6.4781e+00,  9.7561e+00,  ..., -1.1211e-02,
         -6.1733e-01,  3.0527e-01],
        ...,
        [ 8.4859e+00,  1.0681e+01,  4.7532e+00,  ..., -3.4293e-01,
          1.0748e-02, -3.6376e-01],
        [ 8.4872e+00,  1.1514e+01,  6.3473e+00,  ..., -5.1594e-01,
         -3.5474e-01, -5.3760e-02],
        [ 8.7964e+00,  8.6379e+00,  6.

tensor([[ 7.2985,  9.7654,  9.3901,  ..., -0.5016, -0.0186,  0.4434],
        [ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        ...,
        [ 8.0518,  7.2990, 10.2889,  ..., -0.3700, -0.5491, -0.6842],
        [ 9.2499,  7.1772, 10.0579,  ..., -0.5709, -0.3453, -0.3494],
        [ 8.0646,  7.4830, 10.6340,  ..., -0.3012, -0.4634, -0.8737]])
tensor([[ 7.8343,  8.4431,  7.2472,  ...,  0.0168, -0.5652, -0.2720],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        [ 8.5076,  8.0955, 11.4235,  ...,  0.0820, -0.3089, -0.3872],
        ...,
        [ 7.1072,  7.0127, 11.1994,  ...,  0.3320,  0.4323,  0.4110],
        [ 7.9515,  8.7941, 10.1665,  ..., -0.5598, -1.0681, -0.4121],
        [ 7.6846,  7.8057, 10.2682,  ..., -0.4480, -0.3474, -0.8099]])
tensor([[ 8.0201,  4.9575,  4.4225,  ...,  0.0515, -0.3783, -0.7649],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0

tensor([[ 8.5157e+00,  8.8974e+00,  5.8933e+00,  ..., -3.2096e-01,
         -6.2681e-01, -1.0299e+00],
        [ 8.4204e+00,  8.1623e+00,  1.0774e+01,  ..., -4.0471e-01,
         -3.6668e-01, -5.9790e-01],
        [ 7.9486e+00,  8.5349e+00,  9.4287e+00,  ..., -5.6327e-01,
         -4.5306e-01, -6.4395e-01],
        ...,
        [ 8.1531e+00,  1.3375e+01,  7.1139e+00,  ..., -6.2050e-01,
         -1.2933e-04, -1.9124e-01],
        [ 8.4710e+00,  5.4517e+00,  7.7901e+00,  ..., -1.1150e+00,
         -7.1989e-01, -7.1089e-01],
        [ 8.1605e+00,  7.5690e+00,  1.0711e+01,  ..., -3.8654e-01,
         -4.7710e-01, -9.5727e-01]])
tensor([[ 8.9311,  8.3298,  6.1742,  ..., -0.1283, -0.3258, -0.8380],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        [ 7.6324,  8.5800, 11.0321,  ..., -0.4539, -0.3240, -0.6392],
        ...,
        [ 7.8590, 11.3901,  7.3714,  ..., -0.4276, -0.2922, -0.3238],
        [ 7.9883,  5.0279,  4.0152,  ..., -0.3243, -0.6960, -1.1308],
     

tensor([[ 8.0710,  6.6710,  7.2967,  ..., -1.0235, -0.4523, -0.8738],
        [ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        [ 8.5076,  8.0955, 11.4235,  ...,  0.0820, -0.3089, -0.3872],
        ...,
        [ 8.5803, 11.6899,  6.1476,  ..., -0.4275, -0.3383,  0.1214],
        [ 8.4407, 10.0429,  6.4402,  ...,  0.0480, -0.3744, -0.5190],
        [ 7.2286,  9.3448,  8.9644,  ..., -0.7445, -0.6423, -0.8276]])
tensor([[ 8.1338,  7.1264,  9.2093,  ..., -0.8755, -1.1975, -0.8273],
        [ 8.5047,  8.0932, 10.8466,  ..., -0.4534, -0.2588, -0.6337],
        [ 7.7723,  7.9756,  9.7596,  ..., -0.1582, -0.4281, -0.2322],
        ...,
        [ 8.9980,  8.6456,  5.8835,  ...,  0.1237, -0.2107, -0.5877],
        [ 7.5647,  6.6282, 11.3040,  ..., -0.0912, -0.6286, -0.5192],
        [ 7.8084, 10.0893, 10.0165,  ..., -0.5112, -0.4298, -0.0528]])
tensor([[ 8.5157,  8.8974,  5.8933,  ..., -0.3210, -0.6268, -1.0299],
        [ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0

tensor([[ 8.5695,  6.1523,  7.0128,  ..., -0.3089,  0.1772, -1.1647],
        [ 8.4708,  8.0704, 10.8333,  ..., -0.4490, -0.3274, -0.6735],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        ...,
        [ 7.3763,  9.2282, 11.7740,  ..., -0.0977, -0.8163,  0.2041],
        [ 8.4130,  9.9841,  6.8056,  ..., -0.1011, -0.3104, -0.6256],
        [ 9.3292,  7.9316,  6.2259,  ..., -0.1715, -0.4135, -0.9313]])
tensor([[ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        [ 8.0710,  6.6710,  7.2967,  ..., -1.0235, -0.4523, -0.8738],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        ...,
        [ 7.9888,  6.8049, 10.1180,  ..., -0.0967, -0.3897, -0.5351],
        [ 7.3206,  8.9930,  7.1500,  ..., -0.6674, -0.1415, -1.4471],
        [ 7.7433,  9.2914, 11.2324,  ..., -0.4441, -0.2546, -0.8153]])
tensor([[ 8.4068,  8.1047, 10.8538,  ..., -0.5859, -0.3595, -0.8555],
        [ 7.8508,  9.4458, 10.8598,  ..., -0.4368, -0.1877, -1

tensor([[ 7.1094,  6.3561,  9.7160,  ...,  0.0467, -0.6478,  0.2684],
        [ 7.6140,  8.5070,  9.5537,  ..., -0.8631, -0.1259, -0.4078],
        [ 7.7519,  9.7620,  9.3721,  ..., -0.4044, -0.2981, -0.0154],
        ...,
        [ 7.9863,  8.4406,  7.5814,  ...,  0.1596, -0.4719, -0.3269],
        [ 7.8857,  7.9676, 11.1500,  ..., -0.6959, -0.8954, -0.9428],
        [ 9.3474,  8.4864,  5.1382,  ...,  0.4095, -0.3503, -0.5100]])
tensor([[ 6.8770,  8.3984,  8.5946,  ..., -0.6478, -0.4597, -1.1239],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        ...,
        [ 7.8446,  8.3519, 12.8383,  ..., -0.5549,  0.0184, -0.3334],
        [ 8.5406, 10.0591,  6.8525,  ..., -0.1170, -0.3260, -0.6571],
        [ 7.9545,  7.7250, 11.0066,  ..., -0.3006, -0.3676, -0.8119]])
tensor([[ 8.4269, 11.3567,  6.9821,  ..., -0.6456, -0.2969,  0.0974],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0

tensor([[ 8.2820,  8.1366, 11.1345,  ..., -0.3298, -0.3154, -0.5191],
        [ 8.0727,  9.3949,  7.9795,  ..., -1.3467, -0.3903,  0.6264],
        [ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        ...,
        [ 8.4761, 10.5567,  4.6729,  ..., -0.3714, -0.0268, -0.3866],
        [ 8.2251,  7.0623, 10.3681,  ..., -0.5196, -0.5239, -1.0551],
        [ 8.2944,  5.0575,  4.6942,  ..., -0.4979, -0.6765, -1.4951]])
tensor([[ 8.5351, 10.4315,  4.7877,  ..., -0.4877, -0.1794, -0.4674],
        [ 8.2490,  7.7268, 10.2826,  ..., -0.3752, -0.3807, -0.6811],
        [ 8.0997,  8.6590,  6.6673,  ...,  0.0345, -0.6185, -0.3471],
        ...,
        [ 8.1246,  5.4398,  6.2492,  ..., -0.3853, -0.8442, -0.9361],
        [ 8.2690,  8.1713, 10.2842,  ..., -0.4896, -0.3100, -0.8087],
        [ 7.9189,  8.0699, 11.2530,  ..., -0.4396, -0.3155, -0.8485]])
tensor([[ 8.2652,  8.0562, 10.6997,  ..., -0.4560, -0.3880, -0.6904],
        [ 7.7601,  6.8941, 10.4068,  ..., -0.2744, -1.0757,  0

tensor([[ 7.6978, 10.6163,  8.2981,  ..., -0.7493, -0.8055,  0.4928],
        [ 7.8284,  8.4082, 11.3046,  ..., -0.3821, -0.3603, -0.5008],
        [ 7.6733,  9.7192, 10.0112,  ..., -0.5752, -0.2233, -0.5083],
        ...,
        [ 8.0258,  7.3595, 10.7334,  ..., -0.3241, -0.4457, -0.6569],
        [ 8.2549,  7.7783, 10.2982,  ..., -0.1210, -0.5455, -0.7571],
        [ 8.9492,  7.7595, 10.6786,  ..., -0.3241, -0.4092, -0.5119]])
tensor([[ 7.6003e+00,  7.9137e+00,  1.1273e+01,  ..., -2.1244e-01,
         -2.6035e-01, -1.7673e-01],
        [ 7.5077e+00,  9.1430e+00,  1.0724e+01,  ..., -5.9551e-01,
         -5.3033e-01, -3.9605e-01],
        [ 8.9937e+00,  7.6057e+00,  1.0088e+01,  ..., -6.3549e-01,
         -1.0680e+00, -7.9005e-01],
        ...,
        [ 8.1283e+00,  1.0999e+01,  5.0603e+00,  ..., -4.0094e-01,
         -1.3926e-01, -3.9227e-01],
        [ 7.6003e+00,  8.9260e+00,  9.5191e+00,  ...,  1.0158e-02,
         -5.9961e-01, -6.7062e-01],
        [ 7.8395e+00,  8.2868e+00,  1.

tensor([[ 8.7114, 10.2159,  4.8529,  ..., -0.5581, -0.5865, -0.6767],
        [ 8.4269, 11.3567,  6.9821,  ..., -0.6456, -0.2969,  0.0974],
        [ 7.8125,  8.4000, 10.9083,  ..., -0.5331, -0.3071, -0.6218],
        ...,
        [ 6.9942,  7.1095, 10.9066,  ..., -0.3062, -0.6399, -0.0902],
        [ 8.1621,  4.9446,  4.3822,  ..., -0.5577, -0.6976, -1.5206],
        [ 8.7425, 11.6744,  6.1276,  ..., -0.6053, -0.1992, -0.1114]])
tensor([[ 7.7808,  8.1985, 11.2448,  ..., -0.2304, -0.3155, -0.5167],
        [ 7.8991,  8.1663, 11.2281,  ..., -0.2588, -0.3866, -0.3156],
        [ 8.2526,  7.9949, 10.7689,  ..., -0.3806, -0.3519, -0.6949],
        ...,
        [ 8.1472,  5.2434,  4.1606,  ..., -0.3717, -0.7751, -1.1428],
        [ 7.6812,  7.4513,  7.7135,  ..., -0.6462, -0.9322, -1.0586],
        [ 8.4200, 11.4977,  6.1774,  ..., -0.4365, -0.3141, -0.1590]])
tensor([[ 7.3835,  7.5072,  9.4049,  ..., -0.1545,  0.0883,  0.3152],
        [ 7.7505,  7.0457,  8.3520,  ..., -0.6758, -0.5850, -0

tensor([[ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        [ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        ...,
        [ 7.4519,  9.9581, 10.0313,  ..., -0.3747, -0.5895, -0.1854],
        [ 7.7133,  6.6998,  9.3496,  ..., -0.3174, -0.8592, -1.8015],
        [ 8.0010,  7.9433,  9.3448,  ..., -0.9252, -0.7598, -1.3448]])
tensor([[ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        [ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        ...,
        [ 8.0995,  9.0159,  6.3409,  ...,  0.1365, -0.5164, -0.1884],
        [ 8.3194,  8.1701, 11.4790,  ..., -0.4654, -0.3297, -0.4191],
        [ 7.2950,  8.8420, 11.7784,  ..., -0.7513, -0.0809, -0.5886]])
tensor([[ 8.3017,  7.9010, 10.5790,  ..., -0.4906, -0.3195, -0.8136],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0

tensor([[ 8.5157e+00,  8.8974e+00,  5.8933e+00,  ..., -3.2096e-01,
         -6.2681e-01, -1.0299e+00],
        [ 7.3835e+00,  7.5072e+00,  9.4049e+00,  ..., -1.5453e-01,
          8.8306e-02,  3.1519e-01],
        [ 7.1489e+00,  6.1653e+00,  9.2841e+00,  ..., -8.6870e-03,
         -7.1211e-01,  2.3040e-01],
        ...,
        [ 8.0649e+00,  1.0294e+01,  9.4900e+00,  ..., -4.6491e-01,
         -7.9160e-01, -1.4819e-02],
        [ 8.0580e+00,  7.7380e+00,  1.0917e+01,  ..., -2.2095e-01,
         -4.6525e-01, -7.2162e-01],
        [ 8.5385e+00,  5.5168e+00,  5.3045e+00,  ..., -4.8853e-01,
         -4.6447e-01, -1.3903e+00]])
tensor([[ 8.0972,  7.9194, 10.7906,  ..., -0.4236, -0.4254, -0.7460],
        [ 8.1149,  8.6358,  6.9103,  ..., -0.1462, -1.3729, -0.6609],
        [ 8.2727,  9.7926,  7.9440,  ..., -0.2222, -0.3623, -0.7513],
        ...,
        [ 8.5181, 11.7368,  6.1564,  ..., -0.3915, -0.2712,  0.1133],
        [ 8.0951,  9.1397, 10.4331,  ..., -0.5799, -0.3735, -0.9194],
     

tensor([[ 8.1305e+00,  7.6323e+00,  1.0643e+01,  ..., -4.6964e-01,
         -2.7938e-01, -7.8649e-01],
        [ 8.2172e+00,  8.0533e+00,  1.0957e+01,  ..., -3.8077e-01,
         -3.2336e-01, -6.6481e-01],
        [ 8.1347e+00,  8.3596e+00,  8.1433e+00,  ..., -6.3391e-01,
         -2.9668e-01,  5.9399e-01],
        ...,
        [ 9.2309e+00,  9.6663e+00,  5.0847e+00,  ...,  7.0931e-02,
          8.0812e-03, -7.5544e-02],
        [ 8.3713e+00,  5.8882e+00,  5.8945e+00,  ..., -4.5265e-01,
         -1.0769e+00, -1.2332e+00],
        [ 8.1415e+00,  6.1325e+00,  8.0083e+00,  ..., -1.0943e+00,
         -1.2457e+00, -1.6914e+00]])
tensor([[ 8.5527e+00,  7.1641e+00,  7.2420e+00,  ..., -1.0710e+00,
         -7.5080e-01, -1.3218e-01],
        [ 8.2727e+00,  9.7926e+00,  7.9440e+00,  ..., -2.2225e-01,
         -3.6225e-01, -7.5128e-01],
        [ 7.7394e+00,  8.4322e+00,  1.1097e+01,  ..., -3.7482e-01,
         -3.4523e-01, -5.3781e-01],
        ...,
        [ 8.3791e+00,  1.0575e+01,  4.6507e+00

tensor([[ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 8.2208,  8.0473, 11.0073,  ..., -0.4271, -0.3353, -0.7082],
        [ 7.8161, 11.7072,  6.7242,  ..., -0.1043, -0.0404, -0.5119],
        ...,
        [ 8.5920,  8.7181,  9.1484,  ..., -0.3944, -0.4286, -0.3336],
        [ 8.1434,  9.0609, 10.5368,  ..., -0.8570, -0.5530, -0.9983],
        [ 7.5290,  8.4118, 10.2849,  ..., -0.4948, -0.2322, -1.1061]])
tensor([[ 8.5157e+00,  8.8974e+00,  5.8933e+00,  ..., -3.2096e-01,
         -6.2681e-01, -1.0299e+00],
        [ 9.1081e+00,  8.8367e+00,  8.5193e+00,  ..., -3.2905e-01,
         -6.4881e-01, -8.5023e-01],
        [ 8.5157e+00,  8.8974e+00,  5.8933e+00,  ..., -3.2096e-01,
         -6.2681e-01, -1.0299e+00],
        ...,
        [ 9.2395e+00,  9.8966e+00,  5.1264e+00,  ...,  7.1706e-02,
         -4.2116e-02,  5.1504e-04],
        [ 8.0212e+00,  5.1439e+00,  4.1844e+00,  ..., -5.0249e-01,
         -6.3255e-01, -1.3672e+00],
        [ 8.7741e+00,  8.9458e+00,  6.

tensor([[ 9.2821,  7.5944,  9.3909,  ..., -0.9220, -0.5542, -0.5891],
        [ 8.2559,  7.7897, 11.9480,  ..., -0.4088, -0.2314, -0.5586],
        [ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        ...,
        [ 8.7425,  6.0245,  7.7427,  ..., -0.7450, -1.0872, -1.0383],
        [ 7.6045,  7.9283, 11.3219,  ..., -0.1876, -0.4076, -0.6933],
        [ 7.2130, 12.4519,  7.6549,  ..., -0.6697,  0.0128, -0.8063]])
tensor([[ 8.0776,  7.9797, 10.9010,  ..., -0.2530, -0.3260, -0.5775],
        [ 8.0223,  5.1832,  4.2697,  ..., -0.4801, -0.6599, -1.2417],
        [ 8.1341,  8.0754,  6.7522,  ...,  0.1435, -0.9004, -0.6530],
        ...,
        [ 8.8556,  6.5679,  9.2647,  ..., -0.6831, -1.7691, -1.1342],
        [ 8.9037,  6.6499,  7.9348,  ..., -0.8449,  0.0804, -1.7072],
        [ 7.6453,  9.6003, 10.0842,  ..., -0.5987, -0.2747, -0.7742]])
tensor([[ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        [ 7.3834,  9.1047, 11.2114,  ..., -0.2731, -0.5351,  0

tensor([[ 8.2143,  8.8323,  7.2220,  ..., -0.1383, -0.5822, -0.4577],
        [ 7.6465,  9.3561, 10.3808,  ..., -0.8992, -0.4190, -0.8120],
        [ 8.2981,  5.5743,  5.8985,  ..., -0.4765, -0.8847, -1.2189],
        ...,
        [ 8.2262,  6.8862, 10.4519,  ..., -0.4005, -0.4741, -0.9693],
        [ 7.8815,  9.8235, 10.7285,  ..., -0.4729, -0.7878, -0.0786],
        [ 7.8350,  6.9892, 10.9001,  ..., -0.0950, -0.4778, -0.9658]])
tensor([[ 7.3673,  8.6678,  9.9958,  ..., -0.1435, -0.4999, -0.0156],
        [ 9.1159,  8.0938, 10.8176,  ..., -0.5149, -0.2383, -0.6514],
        [ 8.2754, 10.8420,  5.2382,  ..., -0.6954, -0.0725, -0.3710],
        ...,
        [ 8.7042,  6.3669,  7.9856,  ..., -0.6967, -0.7506, -0.6509],
        [ 8.0327,  5.0446,  4.4076,  ..., -0.0295, -0.6619, -0.9645],
        [ 8.5065,  8.7833,  5.7964,  ..., -0.3625, -0.5369, -0.9770]])
tensor([[ 7.8301, 12.3057,  7.9718,  ..., -0.7604, -0.1092, -0.0541],
        [ 8.0486,  8.1317,  7.4963,  ..., -0.0873, -1.2237, -0

KeyboardInterrupt: 

In [None]:
test_loss = evaluate(best_model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)