In [14]:
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

In [15]:
class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

In [16]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [17]:
!pip install torchdata

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m[33m
[0m

In [18]:
!pip install torchtext

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m[33m
[0m

In [62]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

print(tokenizer("happy"))
print(vocab(tokenizer("happy")))
#print(vocab.__dir__())
print(vocab.get_itos()[2269])




def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

# train_iter was "consumed" by the process of building the vocab,
# so we have to create it again
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)


#print(val_iter)
#print(val_data)
#print(val_data.shape)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Divides the data into bsz separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size

    Returns:
        Tensor of shape [N // bsz, bsz]
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

data, targets = get_batch(val_data, 0)
#print(data)
#print(data.shape)

for batch in data:
    print("=============one_batch==============")
    for word in batch:
        print(vocab.get_itos()[word.item()])
        #print(vocab.get_itos()[word.value])
        

['happy']
[2269]
happy
=
to
michigan
support
.
the
states
37
having
which
homarus
the
wolverines
to
=
following
to
m
been
happened
gammarus
north
men
sweep
scientology
resolution
mark
)
a
during
=
west
'
the
in
that
the
diameter
member
filming
homarus
require
s
area
germany
,
game
french
of
.
gammarus
considerable
basketball
in
=
having
'
burr
the
the
,
care
team
front
the
heard
s
<unk>
unit
depiction
known
.
=
of
church
all
fifteenth
is
,
of
as
=
the
the
of
available
anniversary
located
or
the
the
=
2011
marine
scientology
evidence
.
on
a
police
european
=
–
lines
has
regarding
it
a
<unk>
attempting
lobster
<unk>
12
four
been
the
included
<unk>
team
to
or
=
michigan
japanese
present
charges
several
frame
.
shut
common
=
wolverines
37
in
against
new
,
churchill
down
lobster
=
men
mm
germany
certain
features
driven
received
the
,
there
'
(
since
members
,
by
his
video
is
is
s
1
1970
of
such
the
<unk>
shoot
a
an
basketball
@
.
the
as
auxiliary
in
due
species
airstrip
team
.
german
native

In [20]:
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

In [34]:
ntokens = len(vocab)  # size of vocabulary
print(ntokens)
emsize = 200  # embedding dimension
d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)

28782


In [63]:
import copy
import time

criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        batch_size = data.size(0)
        if batch_size != bptt:  # only on last batch
            src_mask = src_mask[:batch_size, :batch_size]
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            batch_size = data.size(0)
            if batch_size != bptt:
                src_mask = src_mask[:batch_size, :batch_size]
            output = model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            print(output_flat)
            print(output_flat.shape)
            total_loss += batch_size * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

In [None]:
best_val_loss = float('inf')
epochs = 3
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)

    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)

    scheduler.step()

| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 303.64 | loss  8.01 | ppl  2998.86
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch 293.29 | loss  6.87 | ppl   965.37
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch 299.40 | loss  6.44 | ppl   626.80
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch 292.68 | loss  6.30 | ppl   546.28
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch 283.52 | loss  6.18 | ppl   481.74
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch 289.49 | loss  6.16 | ppl   471.18
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch 283.60 | loss  6.11 | ppl   451.78
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch 285.30 | loss  6.11 | ppl   448.27
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch 287.96 | loss  6.02 | ppl   412.18
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch 290.74 | loss  6.01 | ppl   408.31
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch 292.77 | loss  5.89 | ppl   362.98
| epoch   

tensor([[ 8.3671,  5.9517,  9.0122,  ..., -1.2047, -1.0007, -1.1687],
        [ 8.4810,  8.8118,  9.5768,  ..., -0.1316, -0.8331, -0.3549],
        [ 7.5962,  7.4014, 11.6258,  ..., -0.6232, -0.6595, -0.6936],
        ...,
        [ 8.1101,  8.2493,  6.2740,  ..., -0.3153, -1.0494, -0.7307],
        [ 7.5472,  6.9440,  9.9867,  ..., -1.3368, -0.2986, -0.0926],
        [ 8.1636, 11.5466,  6.9560,  ..., -0.0435,  0.1705, -0.0894]])
torch.Size([350, 28782])
tensor([[ 7.7204,  6.7584,  8.0878,  ..., -1.0389, -0.3779, -0.8066],
        [ 8.4487,  7.6965, 11.3006,  ..., -0.5485, -0.3095, -0.6149],
        [ 8.1150,  6.5641, 11.3605,  ..., -0.6479, -0.6588, -0.6404],
        ...,
        [ 7.1069,  6.8911, 10.7383,  ..., -0.2214, -0.2006, -0.9144],
        [ 7.0118,  8.5613,  7.6263,  ..., -0.4698, -0.4073, -0.1555],
        [ 7.9561,  8.0052, 10.8175,  ..., -0.9577, -0.7969, -0.4855]])
torch.Size([350, 28782])
tensor([[ 8.7392,  8.7304,  5.5793,  ...,  0.2311,  0.0662, -0.5707],
        [ 7.

tensor([[ 8.6644,  7.5515,  9.8941,  ..., -0.4834, -0.8264, -0.5529],
        [ 8.0191,  7.2913, 11.0881,  ..., -0.4545, -0.3865, -0.3968],
        [ 8.4642,  7.4810,  6.6031,  ..., -0.6109, -0.2080, -0.4001],
        ...,
        [ 7.5815,  6.4185, 11.2902,  ..., -0.5893, -0.7215, -0.7306],
        [ 7.5810,  7.3901,  9.6073,  ..., -0.4409,  0.0417, -1.0424],
        [ 7.6313,  7.0055, 11.9468,  ..., -0.5609, -0.7746, -0.6817]])
torch.Size([350, 28782])
tensor([[ 8.7444,  9.3623,  7.0783,  ..., -0.4958, -0.2934, -0.7096],
        [ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        [ 8.4642,  7.4810,  6.6031,  ..., -0.6109, -0.2080, -0.4001],
        ...,
        [ 8.4461,  7.2955, 11.3881,  ..., -0.3247, -0.5323, -0.7603],
        [ 7.5521,  7.1614,  9.7299,  ..., -1.4019, -0.7372, -0.2424],
        [ 7.2892,  5.5435,  7.4645,  ..., -0.6321, -0.4781, -1.0412]])
torch.Size([350, 28782])
tensor([[ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        [ 6.

tensor([[ 8.7444,  9.3623,  7.0783,  ..., -0.4958, -0.2934, -0.7096],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 8.4460,  9.7666,  7.5616,  ...,  0.2764,  0.6032,  0.0648],
        ...,
        [ 7.8739,  8.2687,  6.9698,  ..., -1.1194, -0.7233, -0.5181],
        [ 8.4526,  5.1292,  6.3676,  ..., -0.6010, -1.0151, -0.6912],
        [ 7.5542,  5.4140,  8.5012,  ..., -0.5732, -1.1291, -0.7408]])
torch.Size([350, 28782])
tensor([[ 8.2242,  6.9055, 11.2508,  ..., -0.4399, -0.8170, -0.5834],
        [ 8.1489,  6.6791, 11.1102,  ..., -0.5035, -0.8043, -0.6916],
        [ 8.7392,  8.7304,  5.5793,  ...,  0.2311,  0.0662, -0.5707],
        ...,
        [ 8.3146,  8.8795,  6.2004,  ..., -0.2773, -0.6341, -0.0446],
        [ 8.0962, 10.6509,  7.3017,  ..., -0.4084, -0.2448, -0.4843],
        [ 7.5446,  7.1494, 11.1395,  ..., -0.6592, -0.7403, -1.1291]])
torch.Size([350, 28782])
tensor([[ 8.4295,  6.2402, 10.4421,  ..., -0.6538, -0.6120, -0.6375],
        [ 7.

tensor([[ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        [ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        [ 8.0828,  7.2232, 11.4239,  ..., -0.5930, -0.7137, -0.5622],
        ...,
        [ 7.9991,  5.8283,  7.4842,  ..., -0.8207, -1.0086, -0.6280],
        [ 8.3561,  5.3508,  7.3961,  ..., -1.6124, -1.1273, -0.5542],
        [ 8.8638,  7.3595,  6.8448,  ..., -0.6297, -0.4515, -0.3992]])
torch.Size([350, 28782])
tensor([[ 6.9567,  7.3852, 11.8862,  ..., -0.7780, -0.3477, -0.5527],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 7.9535,  7.0810, 11.7963,  ..., -0.6013, -0.7950, -0.5943],
        ...,
        [ 7.6152,  4.5529,  5.1285,  ..., -0.9075, -0.3143, -0.4670],
        [ 8.2168,  8.3008,  6.1178,  ..., -0.3899, -1.1543, -0.4861],
        [ 8.1139,  5.1778,  5.7311,  ..., -0.7029, -0.9052, -0.8031]])
torch.Size([350, 28782])
tensor([[ 7.2680,  8.9907, 10.1254,  ..., -0.5829,  0.2114,  0.1335],
        [ 7.

tensor([[ 8.4460,  9.7666,  7.5616,  ...,  0.2764,  0.6032,  0.0648],
        [ 7.7495,  8.6349,  7.6725,  ..., -1.0477, -0.5574, -0.5089],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        ...,
        [ 8.7762,  7.3114,  6.8459,  ..., -0.7425, -0.4147, -0.5799],
        [ 8.7192,  8.8258,  8.2862,  ..., -0.1390, -0.6630, -0.5639],
        [ 7.6422,  7.4664, 11.4965,  ..., -0.4578, -0.5416, -0.5915]])
torch.Size([350, 28782])
tensor([[ 7.9392,  5.0047,  6.5929,  ..., -0.9396, -0.4629, -1.1120],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 8.7392,  8.7304,  5.5793,  ...,  0.2311,  0.0662, -0.5707],
        ...,
        [ 8.6281,  6.8810, 11.2893,  ..., -0.4627, -0.6415, -0.6083],
        [ 7.8287,  6.9318, 11.3257,  ..., -0.5755, -0.6883, -0.8341],
        [ 7.4971,  7.3682, 11.7047,  ..., -0.8348, -0.6076, -0.6517]])
torch.Size([350, 28782])
tensor([[ 8.4903,  7.7320, 11.0616,  ..., -0.6536, -0.8468, -0.5922],
        [ 8.

tensor([[ 7.3779,  6.0982,  7.8596,  ..., -0.6819, -0.2126, -1.0557],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 7.5199,  6.3172,  7.2972,  ..., -0.2383, -0.2168, -0.9581],
        ...,
        [ 8.6511,  9.0880,  7.6576,  ..., -0.2925, -0.5334, -0.4694],
        [ 7.7196,  7.1017, 11.7143,  ..., -0.5720, -0.4371, -0.5249],
        [ 8.8518,  7.6800,  6.6269,  ...,  0.1502, -0.6885, -0.4562]])
torch.Size([350, 28782])
tensor([[ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        [ 7.3779,  6.0982,  7.8596,  ..., -0.6819, -0.2126, -1.0557],
        [ 7.6680,  7.5905, 11.9045,  ..., -0.3749, -0.6431, -0.5907],
        ...,
        [ 7.2750,  7.4675, 12.4671,  ..., -0.7145, -0.6872, -0.4683],
        [ 7.8262,  7.8646, 10.8403,  ..., -0.7328, -0.6255, -0.7929],
        [ 9.0026,  8.7713,  7.7530,  ..., -0.0281, -0.7475, -0.5601]])
torch.Size([350, 28782])
tensor([[ 7.4424,  6.5202,  9.8224,  ..., -0.2741, -1.0879, -0.4880],
        [ 7.

tensor([[ 7.2832,  6.3383,  9.5943,  ..., -0.2658, -1.1278, -0.6490],
        [ 7.2299,  5.0868,  7.9059,  ..., -1.1523, -0.8829, -0.4903],
        [ 6.5346,  5.7145,  9.7070,  ..., -1.1104, -0.7975, -1.0284],
        ...,
        [ 6.3964,  7.7125,  9.5003,  ...,  0.0325, -0.2608, -0.4221],
        [ 7.5861,  7.0506, 11.4545,  ..., -0.6990, -0.5610, -0.5472],
        [ 7.9829,  4.2336,  5.4512,  ..., -0.6537, -0.3255, -0.5536]])
torch.Size([350, 28782])
tensor([[ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        [ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        [ 8.1204,  7.3706, 11.2243,  ..., -0.4479, -0.6511, -0.6584],
        ...,
        [ 7.3088,  4.4135,  5.0509,  ..., -0.7667, -0.2070, -0.3552],
        [ 7.7319,  8.7607, 11.6453,  ..., -0.8584, -0.3342,  0.1136],
        [ 9.1763,  8.7200,  5.1462,  ...,  0.0508, -0.1529, -0.5116]])
torch.Size([350, 28782])
tensor([[ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        [ 8.

tensor([[ 7.9989,  6.6211, 10.9235,  ..., -0.4159, -0.6781, -0.5218],
        [ 7.4191,  8.1150, 12.0517,  ..., -0.6156, -0.6914, -0.6647],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        ...,
        [ 7.6987,  8.3612,  6.6696,  ..., -0.9093, -0.8200, -0.5118],
        [ 9.0857,  8.9901,  7.8833,  ..., -0.2013, -0.6288, -0.4223],
        [ 8.5502,  8.9077,  8.1733,  ..., -0.2396, -0.5783, -0.5235]])
torch.Size([350, 28782])
tensor([[ 7.9081,  7.9624, 11.2926,  ..., -0.4185, -0.5518, -0.6223],
        [ 8.5122,  5.7497,  9.1552,  ..., -0.4816, -0.9125, -0.1057],
        [ 8.6140,  5.3248,  9.2459,  ..., -0.5599, -0.8669, -0.5203],
        ...,
        [ 8.1186, 10.4328,  7.3500,  ...,  0.0676, -0.5515, -0.2044],
        [ 7.5614,  7.4476, 10.2777,  ..., -0.8626, -0.6763, -0.1880],
        [ 7.8383, 10.8955,  7.5361,  ..., -0.4673, -0.0461, -0.4596]])
torch.Size([350, 28782])
tensor([[ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 6.

tensor([[ 8.7444,  9.3623,  7.0783,  ..., -0.4958, -0.2934, -0.7096],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 8.1473,  8.4935, 11.4299,  ..., -0.5650, -0.6701, -0.6290],
        ...,
        [ 7.5163,  7.8590,  9.5308,  ...,  0.4126, -0.1097,  0.1503],
        [ 7.5032,  6.6665, 11.5849,  ..., -0.7328, -0.7400, -0.5281],
        [ 8.4365,  8.5529,  8.2342,  ...,  0.0514, -0.6643, -0.6433]])
torch.Size([350, 28782])
tensor([[ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        ...,
        [ 7.3828,  5.3318,  8.0996,  ..., -1.9502, -1.1149, -0.8208],
        [ 7.5278,  7.9215, 12.3865,  ...,  0.1059, -0.1985, -0.3808],
        [ 7.6576,  8.3835, 11.7892,  ..., -0.3847, -0.5453, -0.5227]])
torch.Size([350, 28782])
tensor([[ 8.2592,  8.6983,  7.5406,  ..., -0.5877, -1.2140, -0.6451],
        [ 7.

tensor([[ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 8.0501,  5.8516,  6.4160,  ..., -0.7719, -0.6889, -0.6730],
        [ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        ...,
        [ 7.0863,  7.3437, 12.2954,  ..., -0.4682, -0.5756, -0.6544],
        [ 8.4303,  8.0615, 10.6713,  ..., -0.4525, -0.5186, -0.1873],
        [ 8.8333,  8.8101,  5.6783,  ..., -0.0626, -0.5101, -0.3513]])
torch.Size([350, 28782])
tensor([[ 7.7843,  7.6293, 11.7521,  ..., -0.5009, -0.6701, -0.6634],
        [ 8.0795,  8.7606, 10.3189,  ..., -0.5185, -0.4018, -0.3673],
        [ 8.1880,  7.4887, 10.9101,  ..., -0.4394, -0.7274, -0.7141],
        ...,
        [ 8.7302,  6.9794, 11.3990,  ..., -0.3792, -0.6408, -0.5097],
        [ 7.7130,  7.1967, 10.7433,  ..., -0.6941, -0.5802, -0.7228],
        [ 9.1436,  8.6413,  5.3248,  ...,  0.1580, -0.1506, -0.5772]])
torch.Size([350, 28782])
tensor([[ 8.1468e+00,  7.8051e+00,  6.8638e+00,  ...,  1.2575e-01,
         -4.704

tensor([[ 7.9452,  7.5536, 11.9659,  ..., -0.5632, -0.7254, -0.4847],
        [ 8.0888,  6.6523,  7.8542,  ..., -0.5737,  0.0180, -1.3690],
        [ 8.2074,  7.2854, 11.0947,  ..., -0.3029, -0.8992, -0.5724],
        ...,
        [ 7.2896,  8.1459, 11.6503,  ..., -0.7297, -0.4255, -0.4120],
        [ 9.0569,  8.4571,  8.5449,  ...,  0.1579, -0.7831, -0.5230],
        [ 7.7545,  7.2415, 12.0259,  ..., -0.6152, -0.8918, -0.8417]])
torch.Size([350, 28782])
tensor([[ 6.2956e+00,  7.0095e+00,  1.0649e+01,  ..., -4.2805e-01,
         -9.0508e-03, -6.1565e-01],
        [ 8.1232e+00,  6.9242e+00,  1.1578e+01,  ..., -4.4952e-01,
         -8.0926e-01, -6.8000e-01],
        [ 7.3779e+00,  6.0982e+00,  7.8596e+00,  ..., -6.8190e-01,
         -2.1258e-01, -1.0557e+00],
        ...,
        [ 7.2532e+00,  7.0204e+00,  1.1757e+01,  ..., -7.7858e-01,
         -8.0784e-01, -2.6437e-01],
        [ 8.5322e+00,  8.1760e+00,  1.0863e+01,  ..., -6.4522e-01,
         -6.8432e-01, -7.2478e-01],
        [ 7.9

tensor([[ 6.9441,  6.8199,  9.5321,  ..., -1.3178, -0.8006, -0.6112],
        [ 7.5526,  7.6075, 12.5185,  ..., -0.7143, -0.7036, -0.5654],
        [ 7.4174,  8.8897,  8.4166,  ...,  0.0312, -0.4491, -0.5404],
        ...,
        [ 8.7309,  9.0234,  7.8290,  ..., -0.1459, -0.5961, -0.4835],
        [ 8.3851,  5.7150,  5.8599,  ..., -0.8417, -0.8282,  0.0266],
        [ 8.4009,  7.3303,  7.1508,  ..., -0.5362, -0.8878, -0.9512]])
torch.Size([350, 28782])
tensor([[ 8.3520,  8.9954,  5.4291,  ...,  0.0716, -0.3915, -0.4530],
        [ 7.5333,  6.5126, 11.0810,  ..., -0.2593, -1.2903, -0.9476],
        [ 7.4792,  6.1140,  8.9005,  ..., -1.0976, -0.4274, -0.7667],
        ...,
        [ 8.6571, 10.2164,  6.0877,  ...,  0.1156,  0.5510, -0.0456],
        [ 7.8882,  5.0807,  5.4800,  ..., -0.6647, -0.9080, -0.7721],
        [ 8.4107,  6.5795, 10.0550,  ..., -0.0878, -0.4076, -0.6009]])
torch.Size([350, 28782])
tensor([[ 6.8035e+00,  8.8115e+00,  1.1663e+01,  ..., -6.1175e-01,
          2.570

tensor([[ 8.0501,  5.8516,  6.4160,  ..., -0.7719, -0.6889, -0.6730],
        [ 8.2442,  8.6048, 10.4800,  ..., -0.3915, -0.5551, -0.6055],
        [ 8.2692,  5.8082,  6.2512,  ..., -1.0192, -0.4084, -0.4553],
        ...,
        [ 6.6849,  7.0720, 11.8963,  ..., -0.7513, -0.3260, -0.8101],
        [ 8.7182,  9.0027,  7.8560,  ..., -0.0952, -0.5899, -0.4663],
        [ 8.6472,  9.7437,  5.3932,  ..., -0.0397, -0.3856, -0.0768]])
torch.Size([350, 28782])
tensor([[ 7.6998e+00,  7.1973e+00,  1.2159e+01,  ..., -5.6862e-01,
         -7.3312e-01, -5.3519e-01],
        [ 8.9872e+00,  6.9087e+00,  1.1354e+01,  ..., -3.3853e-01,
         -7.6416e-01, -5.2507e-01],
        [ 8.3745e+00,  9.3899e+00,  5.4104e+00,  ...,  9.4239e-03,
         -3.3896e-01, -8.8965e-02],
        ...,
        [ 7.9331e+00,  7.9531e+00,  7.7689e+00,  ..., -1.2699e+00,
         -3.1201e-01, -7.6563e-01],
        [ 8.0469e+00,  6.6437e+00,  1.0977e+01,  ..., -7.3866e-01,
         -5.6596e-01, -6.7412e-01],
        [ 8.4

tensor([[ 8.4460,  9.7666,  7.5616,  ...,  0.2764,  0.6032,  0.0648],
        [ 8.4903,  7.7320, 11.0616,  ..., -0.6536, -0.8468, -0.5922],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        ...,
        [ 7.1702,  7.0319, 13.5208,  ..., -0.7701, -0.5200, -0.6032],
        [ 7.7011,  9.5211, 10.3330,  ..., -0.3896,  0.1291,  0.1688],
        [ 9.2496,  8.6767,  5.4078,  ...,  0.1295, -0.1371, -0.5217]])
torch.Size([350, 28782])
tensor([[ 7.6162,  6.2192,  8.7290,  ..., -1.4953, -1.1217, -0.3327],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 8.7392,  8.7304,  5.5793,  ...,  0.2311,  0.0662, -0.5707],
        ...,
        [ 7.9818,  6.7271, 11.9046,  ..., -0.6560, -1.1531, -0.5135],
        [ 8.7645, 10.0194,  5.7707,  ...,  0.1196,  0.4851, -0.1017],
        [ 8.2273,  9.3999,  5.6413,  ...,  0.0898, -0.3561, -0.1772]])
torch.Size([350, 28782])
tensor([[ 8.7444,  9.3623,  7.0783,  ..., -0.4958, -0.2934, -0.7096],
        [ 8.

tensor([[ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        [ 7.8714,  7.2962, 11.8764,  ..., -0.4555, -0.7297, -0.6264],
        [ 7.0769,  8.1258, 13.7138,  ..., -0.9280, -0.3922, -0.4907],
        ...,
        [ 9.2200,  7.0508, 11.0633,  ..., -0.1849, -0.8130, -0.4824],
        [ 6.8893,  7.0712, 12.5766,  ..., -0.7027, -0.7540, -1.1844],
        [ 7.7890,  6.0224,  9.1699,  ..., -1.2384, -0.3222, -0.8898]])
torch.Size([350, 28782])
tensor([[ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        [ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        [ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        ...,
        [ 8.7603,  7.8232,  6.8555,  ...,  0.0770, -0.6470, -0.4535],
        [ 7.7915,  7.0264, 11.1018,  ..., -0.7230, -0.5618, -0.7973],
        [ 8.0947,  6.1959, 10.3251,  ..., -0.6017, -0.9838, -0.9257]])
torch.Size([350, 28782])
tensor([[ 6.9707e+00,  4.3037e+00,  4.9493e+00,  ..., -7.3952e-01,
         -1.933

tensor([[ 8.4903,  7.7320, 11.0616,  ..., -0.6536, -0.8468, -0.5922],
        [ 8.2037, 11.1123,  7.5342,  ...,  0.2215,  0.2413,  0.0837],
        [ 7.8608,  6.9499, 11.4933,  ..., -0.5141, -0.6992, -0.5610],
        ...,
        [ 7.5147,  6.0083,  7.0673,  ..., -1.5804, -1.0100, -0.3465],
        [ 7.4405,  6.8570, 10.2624,  ..., -0.4294, -0.5754, -1.0926],
        [ 7.2950,  8.0670, 13.7008,  ..., -0.8701, -0.2813, -0.7056]])
torch.Size([350, 28782])
tensor([[ 7.8825e+00,  7.1201e+00,  9.2198e+00,  ..., -1.4087e+00,
         -6.8529e-01, -1.7715e-01],
        [ 8.3745e+00,  9.3899e+00,  5.4104e+00,  ...,  9.4239e-03,
         -3.3896e-01, -8.8965e-02],
        [ 7.5993e+00,  9.0397e+00,  9.0176e+00,  ..., -3.6168e-01,
          9.0612e-03, -6.3618e-01],
        ...,
        [ 7.5151e+00,  4.2765e+00,  4.8960e+00,  ..., -8.7452e-01,
         -3.0830e-01, -4.5984e-01],
        [ 7.4044e+00,  7.6942e+00,  1.3655e+01,  ..., -9.6209e-01,
         -4.9712e-01, -6.4362e-01],
        [ 7.5

tensor([[ 8.4460,  9.7666,  7.5616,  ...,  0.2764,  0.6032,  0.0648],
        [ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        [ 7.5724,  7.5380, 11.8796,  ..., -0.4555, -0.6654, -0.4923],
        ...,
        [ 7.9500,  4.9557,  5.1869,  ..., -0.8039, -0.8457, -0.7234],
        [ 8.2642,  9.1672,  9.9892,  ..., -0.9863, -0.4431, -0.5881],
        [ 7.4046,  8.1707, 11.9203,  ..., -0.5248, -0.2524, -0.4434]])
torch.Size([350, 28782])
tensor([[ 7.8234,  7.3105, 11.2571,  ..., -0.3928, -0.7269, -0.5190],
        [ 8.0406,  6.9525, 11.2659,  ..., -0.5604, -0.7845, -0.6797],
        [ 7.9265,  7.6303, 11.5374,  ..., -0.1971, -0.5232, -0.6574],
        ...,
        [ 8.6379,  9.1106,  7.7014,  ..., -0.2656, -0.5011, -0.4509],
        [ 8.8050,  7.6089,  6.6616,  ..., -0.7190, -0.4949, -0.3736],
        [ 7.6478,  4.2696,  4.9592,  ..., -0.8439, -0.2302, -0.4395]])
torch.Size([350, 28782])
tensor([[ 8.1106,  7.1884, 11.2192,  ..., -0.5086, -0.8182, -0.5132],
        [ 8.

tensor([[ 6.8472,  8.4919, 13.3700,  ..., -1.0043, -0.4127, -0.7778],
        [ 8.1998,  7.3947, 11.3376,  ..., -0.0927, -0.5270, -0.5497],
        [ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        ...,
        [ 7.2484,  6.6376, 10.3102,  ..., -0.6037, -1.0377, -0.8821],
        [ 8.1564,  8.2063,  9.6111,  ..., -0.7438, -0.2795, -0.8817],
        [ 6.8834,  6.9991, 11.8611,  ..., -0.4054, -0.5549, -0.3833]])
torch.Size([350, 28782])
tensor([[ 8.3520,  8.9954,  5.4291,  ...,  0.0716, -0.3915, -0.4530],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        ...,
        [ 7.5299,  4.9610,  7.4006,  ..., -0.6067, -1.1842, -0.7479],
        [ 9.5711,  7.0204, 10.7392,  ..., -0.1503, -0.8448, -0.4799],
        [ 7.5051,  7.0802, 11.3375,  ..., -0.0961, -0.4482, -0.8661]])
torch.Size([350, 28782])
tensor([[ 8.4460,  9.7666,  7.5616,  ...,  0.2764,  0.6032,  0.0648],
        [ 8.

tensor([[ 8.1438e+00,  7.4303e+00,  1.1131e+01,  ..., -4.0020e-01,
         -6.9494e-01, -5.0006e-01],
        [ 8.4903e+00,  7.7320e+00,  1.1062e+01,  ..., -6.5364e-01,
         -8.4684e-01, -5.9220e-01],
        [ 6.9707e+00,  4.3037e+00,  4.9493e+00,  ..., -7.3952e-01,
         -1.9334e-01, -4.1398e-01],
        ...,
        [ 7.8814e+00,  8.4340e+00,  1.0228e+01,  ..., -1.0528e+00,
         -1.7865e-01, -9.8669e-03],
        [ 7.2512e+00,  4.3366e+00,  5.6814e+00,  ..., -6.5026e-01,
         -2.9118e-01, -4.6814e-01],
        [ 7.5450e+00,  7.4894e+00,  1.1597e+01,  ..., -2.7162e-01,
         -5.7815e-01, -6.5150e-01]])
torch.Size([350, 28782])
tensor([[ 7.3650,  7.3203, 13.0219,  ..., -0.7797, -0.7526, -0.6169],
        [ 6.7527,  7.5396, 11.2081,  ..., -0.9919, -0.3558, -0.0447],
        [ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        ...,
        [ 6.8303,  7.8891,  9.7906,  ...,  0.7094,  0.5100,  0.7950],
        [ 7.5846,  8.6728,  7.0963,  ..., -0.8774,

tensor([[ 6.9400,  9.3590,  7.8325,  ..., -0.5540, -0.0554, -0.5539],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        ...,
        [ 7.2224,  6.7123, 11.4301,  ...,  0.4836, -0.6541, -0.4588],
        [ 8.7392,  7.6681,  6.7500,  ...,  0.2287, -0.6505, -0.5561],
        [ 8.2564,  9.5435,  9.9513,  ..., -0.7506, -0.4406, -0.1351]])
torch.Size([350, 28782])
tensor([[ 7.8943,  7.4963, 12.1560,  ..., -0.5550, -0.7489, -0.7797],
        [ 7.5619,  8.9137, 11.2908,  ..., -0.1550, -0.0668, -0.4036],
        [ 7.7559,  8.3288, 10.8236,  ..., -0.4737, -0.5162, -0.5002],
        ...,
        [ 8.6099,  8.0971,  6.7340,  ..., -0.8769, -0.8534, -0.6405],
        [ 7.8268,  8.4837,  8.7523,  ..., -0.1800, -0.5859, -0.4799],
        [ 7.7715,  7.7171, 11.2689,  ..., -0.7582, -0.5065, -0.7322]])
torch.Size([350, 28782])
tensor([[ 7.1775,  6.1013,  8.1724,  ..., -0.2045, -1.0250, -0.6805],
        [ 6.

tensor([[ 7.9456,  7.3480, 11.7471,  ..., -0.5405, -0.7000, -0.7385],
        [ 8.6866,  7.7052,  8.6811,  ..., -0.3437, -0.5729, -0.5424],
        [ 8.1685,  8.1565, 10.5314,  ..., -0.3779, -0.7101, -0.6797],
        ...,
        [ 7.2533,  7.6596, 11.4703,  ..., -0.4267, -0.4276, -0.6275],
        [ 7.2394,  5.6804,  7.2872,  ..., -0.6650, -0.4163, -1.0344],
        [ 7.8811,  6.2447,  8.6260,  ..., -1.1632, -0.6666, -0.7377]])
torch.Size([350, 28782])
tensor([[ 7.6763e+00,  6.2206e+00,  8.2041e+00,  ..., -4.8725e-01,
         -5.1870e-03, -1.1886e+00],
        [ 6.1746e+00,  6.1266e+00,  9.9436e+00,  ..., -1.2378e+00,
         -4.1905e-01, -1.1940e+00],
        [ 8.1726e+00,  6.8459e+00,  1.0911e+01,  ..., -4.3311e-01,
         -7.6161e-01, -4.9145e-01],
        ...,
        [ 6.5505e+00,  7.6603e+00,  1.1315e+01,  ..., -5.1701e-01,
          1.3839e-01,  2.9287e-01],
        [ 8.5813e+00,  6.9886e+00,  1.1476e+01,  ..., -4.3954e-01,
         -6.2052e-01, -7.2381e-01],
        [ 7.5

tensor([[ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        [ 7.6184,  7.5525,  8.9897,  ...,  0.1952, -0.1821,  0.2525],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        ...,
        [ 7.7806,  8.3539,  6.6560,  ..., -0.8305, -0.8883, -0.4849],
        [ 7.5400,  6.9408, 11.2405,  ..., -0.1794, -0.6770, -0.9582],
        [ 8.5801,  7.2455,  7.4547,  ..., -0.2039, -0.0659, -1.3889]])
torch.Size([350, 28782])
tensor([[ 8.2431,  7.3026,  7.0747,  ..., -0.1671, -0.2639, -1.0000],
        [ 6.2481,  8.3782, 12.0182,  ..., -0.7913, -0.0663, -0.5262],
        [ 7.7617,  7.7006, 11.4085,  ..., -0.4355, -0.6529, -0.3696],
        ...,
        [ 8.0456,  7.7847, 10.7133,  ..., -1.0159, -0.7776, -0.0333],
        [ 7.6388,  6.7855, 11.4465,  ..., -0.7212, -0.7854, -0.7531],
        [ 6.9983,  7.9443, 12.6169,  ..., -0.8978, -0.5969, -0.9387]])
torch.Size([350, 28782])
tensor([[ 8.7697,  7.7318,  7.1705,  ..., -1.0072, -0.7698, -0.8744],
        [ 6.

tensor([[ 8.0501,  5.8516,  6.4160,  ..., -0.7719, -0.6889, -0.6730],
        [ 8.2692,  5.8082,  6.2512,  ..., -1.0192, -0.4084, -0.4553],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        ...,
        [ 8.5683,  7.6488, 10.6416,  ..., -0.7404, -0.4307, -0.3357],
        [ 7.4552,  5.8222,  8.2181,  ..., -0.5964, -0.7996, -0.4190],
        [ 8.5909,  7.7320,  7.0815,  ..., -0.7733, -0.8529, -0.8144]])
torch.Size([350, 28782])
tensor([[ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 7.5734,  9.7608,  9.0944,  ..., -0.0661,  0.2407, -0.1307],
        ...,
        [ 8.4787,  6.8044, 11.9152,  ...,  0.7308, -0.6636, -0.7676],
        [ 7.3610,  7.7576, 11.9751,  ..., -0.6693, -0.5408, -0.4834],
        [ 7.6448,  5.6892, 10.1537,  ..., -1.4019, -0.9935, -0.5432]])
torch.Size([350, 28782])
tensor([[ 7.6391,  6.5564, 11.2452,  ..., -0.6204, -0.5808, -0.7046],
        [ 8.

tensor([[ 8.3995,  9.6593,  5.7028,  ..., -0.0736, -0.2329, -0.0479],
        [ 7.7838,  7.1933, 11.8191,  ..., -0.5971, -0.7177, -0.5002],
        [ 8.1468,  7.8051,  6.8638,  ...,  0.1257, -0.4705, -0.2617],
        ...,
        [ 7.6659,  6.9249, 11.7033,  ..., -0.5916, -0.4502, -0.2644],
        [ 7.7745,  7.2769, 11.4755,  ..., -0.7717, -0.6341, -0.3900],
        [ 9.0364,  7.5097,  6.4978,  ..., -0.6487, -0.5037, -0.4309]])
torch.Size([350, 28782])
tensor([[ 6.8260,  9.4798,  8.1401,  ..., -0.6089,  0.1430, -0.4546],
        [ 6.2820,  6.0403,  9.8624,  ..., -1.5921, -0.6161, -0.9494],
        [ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        ...,
        [ 7.3515,  7.5514, 11.5952,  ..., -0.3580, -0.5931, -0.5891],
        [ 8.0153,  6.5269, 11.1055,  ..., -1.0475, -0.6492, -0.8378],
        [ 7.8310,  6.2638,  8.1538,  ..., -1.2789, -0.7480, -0.5268]])
torch.Size([350, 28782])
tensor([[ 6.9707,  4.3037,  4.9493,  ..., -0.7395, -0.1933, -0.4140],
        [ 8.

tensor([[ 8.4642,  7.4810,  6.6031,  ..., -0.6109, -0.2080, -0.4001],
        [ 7.7882,  7.3889, 11.5772,  ..., -0.3756, -0.6249, -0.6874],
        [ 7.8664,  6.7125,  8.6325,  ..., -1.1337, -0.9066, -0.3334],
        ...,
        [ 6.8565,  7.4231, 12.5387,  ..., -0.6521, -0.6317, -0.1895],
        [ 7.6707,  7.6459, 11.2620,  ..., -0.4838, -0.6488, -0.7698],
        [ 7.0684,  6.7408,  9.3540,  ..., -0.6353, -0.1066,  0.2201]])
torch.Size([350, 28782])
tensor([[ 8.1014,  8.4033,  7.6336,  ..., -1.0184, -0.8551, -0.5625],
        [ 7.6184,  7.5525,  8.9897,  ...,  0.1952, -0.1821,  0.2525],
        [ 6.6328,  7.1785, 10.9755,  ..., -1.1038, -0.2043, -0.6808],
        ...,
        [ 7.4450,  7.8378,  9.7719,  ...,  0.2966, -0.0964,  0.0870],
        [ 8.5335,  6.8544, 11.4576,  ..., -0.3994, -0.6562, -0.6177],
        [ 8.6560, 10.1207,  5.9984,  ...,  0.1575,  0.5418, -0.0589]])
torch.Size([350, 28782])
tensor([[ 8.3745e+00,  9.3899e+00,  5.4104e+00,  ...,  9.4239e-03,
         -3.389

tensor([[ 7.6974,  7.2935, 12.3742,  ..., -0.6129, -0.6273, -0.5346],
        [ 7.9601,  6.3852, 11.6835,  ..., -0.7432, -1.0570, -0.7009],
        [ 8.0574,  6.7496, 11.3142,  ..., -0.6728, -0.7206, -0.5454],
        ...,
        [ 8.6159,  9.6817,  5.3086,  ..., -0.0810, -0.3226, -0.0225],
        [ 8.4869,  4.8900,  5.9211,  ..., -0.8055, -0.8262, -1.4174],
        [ 7.7384,  6.9926, 11.5183,  ..., -0.6103, -0.5093, -0.4145]])
torch.Size([350, 28782])
tensor([[ 7.7495,  8.6349,  7.6725,  ..., -1.0477, -0.5574, -0.5089],
        [ 8.5084,  7.1922, 11.4377,  ..., -0.8015, -0.9581, -0.6200],
        [ 7.7495,  8.6349,  7.6725,  ..., -1.0477, -0.5574, -0.5089],
        ...,
        [ 7.5045,  7.7683, 12.5909,  ..., -0.6638, -0.7479, -0.3624],
        [ 8.7791,  7.8713,  6.5055,  ..., -0.6036, -0.8188, -0.7259],
        [ 7.6405,  4.4463,  5.8260,  ..., -0.7333, -0.3867, -0.6983]])
torch.Size([350, 28782])
tensor([[ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 8.

tensor([[ 8.9872e+00,  6.9087e+00,  1.1354e+01,  ..., -3.3853e-01,
         -7.6416e-01, -5.2507e-01],
        [ 7.9354e+00,  7.3949e+00,  1.1240e+01,  ..., -5.7295e-01,
         -8.9628e-01, -7.6099e-01],
        [ 6.2187e+00,  8.8410e+00,  1.0493e+01,  ..., -4.8899e-01,
          2.8620e-02, -4.7013e-01],
        ...,
        [ 8.5505e+00,  9.8833e+00,  5.6694e+00,  ..., -1.1121e-01,
         -2.7653e-01, -1.5954e-01],
        [ 8.3108e+00,  9.5182e+00,  5.4646e+00,  ...,  1.2417e-03,
         -3.4549e-01, -1.5243e-01],
        [ 8.1123e+00,  8.4706e+00,  1.1375e+01,  ..., -5.5668e-01,
         -5.0981e-01, -4.6096e-01]])
torch.Size([350, 28782])
tensor([[ 8.1922,  6.4795, 10.7863,  ..., -0.6145, -0.8763, -0.5959],
        [ 7.7495,  8.6349,  7.6725,  ..., -1.0477, -0.5574, -0.5089],
        [ 8.3949,  8.9287,  5.5895,  ...,  0.0908, -0.3626, -0.1387],
        ...,
        [ 8.6337, 10.0565,  5.9295,  ...,  0.1534,  0.5358, -0.1213],
        [ 6.5464,  7.7598, 12.5811,  ..., -0.7272,

tensor([[ 9.0220,  7.8394,  9.0610,  ..., -0.2844, -0.2378, -0.6652],
        [ 7.5199,  6.3172,  7.2972,  ..., -0.2383, -0.2168, -0.9581],
        [ 7.9458,  7.1932, 11.6289,  ..., -0.5267, -0.6853, -0.6222],
        ...,
        [ 8.7131,  7.6631,  6.7261,  ...,  0.1076, -0.6044, -0.3982],
        [ 7.4629,  4.1847,  5.1855,  ..., -0.5710, -0.2756, -0.5406],
        [ 8.4440,  8.4975,  8.4984,  ...,  0.0625, -0.6736, -0.6572]])
torch.Size([350, 28782])
tensor([[ 8.9872,  6.9087, 11.3541,  ..., -0.3385, -0.7642, -0.5251],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 7.3741,  9.4224,  7.8230,  ..., -0.2750, -0.2241, -0.6511],
        ...,
        [ 7.1954,  8.0933, 12.0189,  ..., -0.9339, -0.5290,  0.5293],
        [ 7.1723,  6.3330, 11.2924,  ..., -0.7247, -0.3833, -1.0465],
        [ 8.6140,  4.8308,  5.6889,  ..., -1.1942, -0.7518, -0.9639]])
torch.Size([350, 28782])
tensor([[ 8.0501e+00,  5.8516e+00,  6.4160e+00,  ..., -7.7186e-01,
         -6.889

tensor([[ 8.2592,  8.6983,  7.5406,  ..., -0.5877, -1.2140, -0.6451],
        [ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 7.1710,  5.3916,  7.8624,  ..., -0.4701, -0.9755, -0.6534],
        ...,
        [ 8.6982,  6.9427, 11.6147,  ..., -0.4212, -0.7229, -0.6133],
        [ 7.1543,  5.8041,  7.7628,  ..., -0.6841, -0.2751, -1.2327],
        [ 8.4619,  7.7146,  6.9530,  ...,  0.1345, -0.5753, -0.6247]])
torch.Size([350, 28782])
tensor([[ 8.7444,  9.3623,  7.0783,  ..., -0.4958, -0.2934, -0.7096],
        [ 9.1620,  6.7315,  7.1244,  ...,  0.8948, -0.6346,  0.1802],
        [ 8.7392,  8.7304,  5.5793,  ...,  0.2311,  0.0662, -0.5707],
        ...,
        [ 9.1403,  7.3713,  6.7107,  ..., -0.5664, -0.5854, -0.4010],
        [ 8.1676,  7.3949, 11.5146,  ..., -0.6307, -0.5417, -0.7045],
        [ 7.7713,  6.5237, 11.1730,  ..., -1.1861, -0.4933, -0.7295]])
torch.Size([350, 28782])
tensor([[ 8.0501,  5.8516,  6.4160,  ..., -0.7719, -0.6889, -0.6730],
        [ 8.

tensor([[ 7.2424,  7.6062,  9.4396,  ..., -0.6969, -0.0945, -1.2026],
        [ 7.9933,  7.2730, 11.5974,  ..., -0.4479, -0.6561, -0.5196],
        [ 7.6933,  7.2647, 11.2187,  ..., -0.5062, -0.5656, -0.5055],
        ...,
        [ 7.5724,  6.4607,  9.8694,  ..., -1.1219, -0.1410,  0.1752],
        [ 8.1002,  6.8843, 10.7600,  ..., -0.8761, -0.2654, -0.8930],
        [ 8.0962,  8.1303,  8.4490,  ...,  0.2722, -0.6630, -0.7809]])
torch.Size([350, 28782])
tensor([[ 7.2138,  8.2846,  7.5789,  ..., -0.6269, -0.2163, -0.4475],
        [ 7.8807,  6.9564, 10.0751,  ..., -1.3341, -0.9204, -0.7858],
        [ 8.7392,  8.7304,  5.5793,  ...,  0.2311,  0.0662, -0.5707],
        ...,
        [ 7.4836,  7.6829, 13.0906,  ..., -1.0104, -0.3775, -0.6887],
        [ 8.1036,  5.1894,  5.7963,  ..., -0.8566, -0.8347, -0.7613],
        [ 7.5882,  8.4833,  7.2809,  ...,  0.0190, -0.5035, -0.5247]])
torch.Size([350, 28782])
tensor([[ 8.1826,  7.0976, 11.4325,  ..., -0.5075, -0.8075, -0.6287],
        [ 8.

tensor([[ 7.2424,  7.6062,  9.4396,  ..., -0.6969, -0.0945, -1.2026],
        [ 8.0945, 10.6572,  7.3312,  ...,  0.2458, -0.1983, -0.2541],
        [ 8.2842,  9.6790,  6.3743,  ..., -0.1674, -0.2976, -0.3849],
        ...,
        [ 7.4699,  7.0413, 11.2614,  ..., -0.5385, -0.5645, -0.9576],
        [ 7.5816,  8.7515, 11.5275,  ..., -0.3430, -0.5362, -0.3084],
        [ 7.2630,  8.4777,  7.9760,  ..., -0.3797, -0.4786, -0.4224]])
torch.Size([350, 28782])
tensor([[ 8.7392,  8.7304,  5.5793,  ...,  0.2311,  0.0662, -0.5707],
        [ 8.2155,  8.2186,  7.0232,  ..., -0.5769, -1.1858, -0.6381],
        [ 7.6482,  7.2145, 11.8503,  ..., -0.7539, -0.6609, -0.6955],
        ...,
        [ 7.6173,  6.5591,  9.3483,  ..., -1.3179, -0.2348, -0.1390],
        [ 7.9002,  8.5797,  6.8144,  ..., -0.7759, -0.8077, -0.5173],
        [ 7.2157,  8.1959,  8.1162,  ..., -0.2353, -0.3015, -0.3775]])
torch.Size([350, 28782])
tensor([[ 8.2341,  8.9039,  8.4824,  ..., -0.1465, -0.4206, -0.3184],
        [ 6.

| epoch   2 |  1400/ 2928 batches | lr 4.75 | ms/batch 312.94 | loss  5.69 | ppl   295.30
| epoch   2 |  1600/ 2928 batches | lr 4.75 | ms/batch 329.81 | loss  5.70 | ppl   299.66
| epoch   2 |  1800/ 2928 batches | lr 4.75 | ms/batch 330.91 | loss  5.64 | ppl   282.65


In [None]:
test_loss = evaluate(best_model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)