<a href="https://colab.research.google.com/github/fredffsixty/Natural_Language_Processing/blob/main/Esercitazioni/E5.%20Transformers/E5_translation_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm
!pip install portalocker>=2.0.0

Collecting spacy
  Downloading spacy-3.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
Collecting weasel<0.4.0,>=0.1.0 (from spacy)
  Downloading weasel-0.3.3-py3-none-any.whl (49 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.8/49.8 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
Collecting cloudpathlib<0.17.0,>=0.7.0 (from weasel<0.4.0,>=0.1.0->spacy)
  Downloading cloudpathlib-0.16.0-py3-none-any.whl (45 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.0/45.0 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: cloudpathlib, weasel, spacy
  Attempting uninstall: spacy
    Found existing installation: spacy 3.6.1
    Uninstalling spacy-3.6.1:
      Successfully uninstalled spacy-3.6.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages t

In [None]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

# Point torchtext's Multi30k loader at mirror archives for the train and
# validation splits (presumably the default host is unreliable — TODO confirm).
# The validation split is used as the test set throughout this notebook.
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

# Special tokens, listed in vocabulary order. Each *_IDX below is the token's
# position in this list, so indices and symbols cannot drift apart.
SPECIAL_SYMBOLS = ['<unk>', '<pad>', '<bos>', '<eos>']

# Translation task definition: language pair plus special-token indices.
task_parameters = dict(
    SRC_LANGUAGE='de',
    TGT_LANGUAGE='en',
    UNK_IDX=SPECIAL_SYMBOLS.index('<unk>'),
    PAD_IDX=SPECIAL_SYMBOLS.index('<pad>'),
    BOS_IDX=SPECIAL_SYMBOLS.index('<bos>'),
    EOS_IDX=SPECIAL_SYMBOLS.index('<eos>'),
    special_symbols=SPECIAL_SYMBOLS,
)

# Model / optimisation hyperparameters. `patience` and `min_delta` are not
# consumed by any code visible in this notebook.
hyperparameters = dict(
    epochs=5,
    learning_rate=1e-3,
    batch_size=128,
    dropout=0.1,
    layers=3,
    h_dim=512,
    heads=8,
    patience=5,
    min_delta=0.01,
)

In [None]:
# https://pytorch.org/text/stable/datasets.html#multi30k

# Preview of the preprocessing: take the first (source, target) pair of the
# training split and show it next to its spaCy tokenization.
train_iter = Multi30k(split='train', language_pair=(task_parameters['SRC_LANGUAGE'], task_parameters['TGT_LANGUAGE']))
token_transform = {
    task_parameters['SRC_LANGUAGE']: get_tokenizer('spacy', language='de_core_news_sm'),
    task_parameters['TGT_LANGUAGE']: get_tokenizer('spacy', language='en_core_web_sm'),
}

first_pair = next(iter(train_iter), None)
if first_pair is not None:
    src_text, tgt_text = first_pair[0], first_pair[1]
    print(f"0 {first_pair}")
    print(token_transform[task_parameters['SRC_LANGUAGE']](src_text))
    print(token_transform[task_parameters['TGT_LANGUAGE']](tgt_text))

0 ('Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.', 'Two young, White males are outside near many bushes.')
['Zwei', 'junge', 'weiße', 'Männer', 'sind', 'im', 'Freien', 'in', 'der', 'Nähe', 'vieler', 'Büsche', '.']
['Two', 'young', ',', 'White', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']


In [None]:
# helper function to yield list of tokens
def yield_tokens(token_transform, data_iter, language, task_parameters):
    """Lazily tokenize one side of a parallel corpus.

    Args:
        token_transform: dict mapping a language code to a tokenizer callable.
        data_iter: iterable of ``(source_text, target_text)`` pairs.
        language: which side to tokenize; must equal
            ``task_parameters['SRC_LANGUAGE']`` (column 0) or
            ``task_parameters['TGT_LANGUAGE']`` (column 1).
        task_parameters: dict holding the ``SRC_LANGUAGE``/``TGT_LANGUAGE`` codes.

    Yields:
        The token list produced by ``token_transform[language]`` for each pair.
    """
    column_of = {
        task_parameters['SRC_LANGUAGE']: 0,
        task_parameters['TGT_LANGUAGE']: 1,
    }
    yield from (token_transform[language](pair[column_of[language]]) for pair in data_iter)

def tokenization(task_parameters):
    """Build spaCy tokenizers and torchtext vocabularies for both languages.

    Each vocabulary is built from the Multi30k training split with
    ``min_freq=1``. The special symbols come first, and unknown tokens map to
    ``UNK_IDX`` instead of raising.

    Args:
        task_parameters: dict with the following keys:
            ``SRC_LANGUAGE``/``TGT_LANGUAGE`` (language codes),
            ``special_symbols`` and ``UNK_IDX``.
            An optional ``spacy_models`` dict (language code -> spaCy
            pipeline name) overrides the default model choice.

    Returns:
        ``(token_transform, vocab_transform)``: two dicts keyed by language
        code, holding the tokenizer callables and the ``Vocab`` objects.

    Raises:
        ValueError: if no spaCy model is known for one of the languages.
    """
    # Pick the tokenizer by language code, not by SRC/TGT role. Previously
    # SRC always got the German model and TGT the English one, so swapping
    # the language pair silently used the wrong tokenizer.
    spacy_models = task_parameters.get(
        'spacy_models', {'de': 'de_core_news_sm', 'en': 'en_core_web_sm'})
    language_pair = (task_parameters['SRC_LANGUAGE'], task_parameters['TGT_LANGUAGE'])

    token_transform = {}
    vocab_transform = {}

    for ln in language_pair:
        if ln not in spacy_models:
            raise ValueError(f"no spaCy model configured for language '{ln}'")
        token_transform[ln] = get_tokenizer('spacy', language=spacy_models[ln])

    for ln in language_pair:
        # A new training-split iterator per language, so each vocabulary is
        # built from its own pass over the data.
        train_iter = Multi30k(split='train', language_pair=language_pair)

        vocab_transform[ln] = build_vocab_from_iterator(
            yield_tokens(token_transform, train_iter, ln, task_parameters),
            min_freq=1, specials=task_parameters['special_symbols'], special_first=True)

        # Out-of-vocabulary tokens map to UNK_IDX. Without a default index,
        # looking up an unknown token raises RuntimeError.
        vocab_transform[ln].set_default_index(task_parameters['UNK_IDX'])

    return token_transform, vocab_transform

In [None]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math

In [None]:
# helper Module that adds positional encoding to the token embedding to
# introduce a notion of word order.
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to token embeddings, then apply dropout.

    The table follows PE[pos, 2i] = sin(pos / 10000^(2i/d)) and
    PE[pos, 2i+1] = cos(pos / 10000^(2i/d)).

    Args:
        emb_size: embedding dimension ``d``. Odd values are supported; the
            last (sine) column then has no cosine partner.
        dropout: dropout probability applied after the addition.
        maxlen: number of positions precomputed in the table.
    """

    def __init__(self, emb_size, dropout, maxlen = 5000):
        super(PositionalEncoding, self).__init__()
        # Inverse frequencies 1 / 10000^(2i/d) for the even column indices 2i.
        den = torch.exp(- torch.arange(0,
                                       emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den  # (maxlen, ceil(emb_size / 2))
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # There are only floor(emb_size / 2) odd columns. Slicing keeps an odd
        # emb_size from failing with a shape mismatch; for even sizes it is a no-op.
        pos_embedding[:, 1::2] = torch.cos(angles[:, : emb_size // 2])
        # Insert a singleton batch axis -> (maxlen, 1, emb_size) so the table
        # broadcasts over the batch dimension.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        # A buffer moves with .to(device) and is stored in state_dict, but is not trained.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding):
        """Add the encodings for the first ``token_embedding.size(0)`` positions.

        Assumes sequence-first input ``(seq_len, batch, emb_size)`` with
        ``seq_len <= maxlen``.
        """
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding
# tensor of token embeddings: the embedding module can be seen as a lookup table
# that maps tokens to their raw embeddings that will be trained
# https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

class TokenEmbedding(nn.Module):
    """Learned lookup table from token indices to vectors, scaled by sqrt(emb_size).

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: width of each embedding vector.
    """

    def __init__(self, vocab_size, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        # nn.Embedding needs integer indices, so cast whatever dtype arrives.
        looked_up = self.embedding(tokens.long())
        # Scale by sqrt(d_model), as in "Attention Is All You Need" (section 3.4).
        return looked_up * math.sqrt(self.emb_size)

In [None]:
# Seq2Seq Network
# https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder translation model built on ``torch.nn.Transformer``.

    Source and target token indices are embedded (``TokenEmbedding``),
    position-encoded (``PositionalEncoding``) and passed through the
    Transformer. A linear ``generator`` projects decoder states to
    target-vocabulary logits. ``nn.Transformer`` is built with its default
    ``batch_first=False``, so tensors are sequence-first.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers,
                 emb_size, nhead, src_vocab_size, tgt_vocab_size,
                 dim_feedforward = 512, dropout = 0.1):
        super().__init__()
        # Submodules are registered in the original order, so seeded parameter
        # initialisation and state_dict key order are unchanged.
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src, trg, src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask,
                memory_key_padding_mask):
        """Full teacher-forced pass; returns target-vocabulary logits."""
        decoded = self.transformer(
            self.positional_encoding(self.src_tok_emb(src)),
            self.positional_encoding(self.tgt_tok_emb(trg)),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src, src_mask):
        """Run only the encoder stack and return its output (the decoder memory)."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(src_emb, mask=src_mask)

    def decode(self, tgt, memory, tgt_mask):
        """Run only the decoder stack on ``tgt`` against precomputed ``memory``."""
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)

In [None]:
# function for manually generate masks
# for test set, generated mask will be in a form of triangular matrix where
# tokens of the subsequent words cannot be seen (masked)

def generate_square_subsequent_mask(sz, device):
    """Return an ``(sz, sz)`` float32 causal attention mask on ``device``.

    Entries on and below the diagonal are 0.0 (attend). Entries above it are
    -inf (blocked), so position i can only attend to positions <= i.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32, device=device)
    # Keep -inf strictly above the main diagonal; everything else becomes 0.0.
    return torch.triu(blocked, diagonal=1)


def create_mask(src, tgt, device, PAD_IDX):
    """Build the attention and padding masks for one batch.

    Args:
        src, tgt: sequence-first index tensors ``(seq_len, batch)``, as
            produced by ``pad_sequence`` with its default ``batch_first=False``.
        device: device on which the square masks are allocated.
        PAD_IDX: padding token index.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
        an all-False boolean source mask, a causal float target mask, and two
        ``(batch, seq_len)`` boolean masks that are True at padding positions.
    """
    # Decoder self-attention is causal.
    tgt_mask = generate_square_subsequent_mask(tgt.shape[0], device)

    # The encoder may attend to every source position.
    n_src = src.shape[0]
    src_mask = torch.zeros((n_src, n_src), dtype=torch.bool, device=device)

    # Key padding masks are batch-first, hence the transpose.
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


In [None]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    """Compose ``transforms`` left to right into a single callable.

    ``sequential_transforms(f, g, h)(x)`` returns ``h(g(f(x)))``. With no
    transforms the returned callable is the identity.
    """
    pipeline = transforms

    def composed(txt_input):
        value = txt_input
        for transform_fn in pipeline:
            value = transform_fn(value)
        return value

    return composed

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids, bos_idx=None, eos_idx=None):
    """Wrap token indices with BOS/EOS and return them as a 1-D ``torch.long`` tensor.

    Args:
        token_ids: sequence of int token indices (e.g. the output of a
            torchtext ``Vocab``); may be empty.
        bos_idx: index to prepend; defaults to ``task_parameters["BOS_IDX"]``.
        eos_idx: index to append; defaults to ``task_parameters["EOS_IDX"]``.

    Returns:
        ``tensor([bos_idx, *token_ids, eos_idx], dtype=torch.long)``.
    """
    if bos_idx is None:
        bos_idx = task_parameters["BOS_IDX"]
    if eos_idx is None:
        eos_idx = task_parameters["EOS_IDX"]
    # torch.tensor([]) is float32, so an empty sentence used to be promoted to
    # a float tensor by torch.cat. Pin the dtype so indices stay integral.
    return torch.cat((torch.tensor([bos_idx], dtype=torch.long),
                      torch.as_tensor(token_ids, dtype=torch.long),
                      torch.tensor([eos_idx], dtype=torch.long)))

# function to collate data samples into batch tensors
# necessary to wrap correctly data in batches
def collate_fn(batch):
    """Turn a list of raw ``(src_text, tgt_text)`` pairs into two padded index tensors.

    Each string has its trailing newline stripped and is passed through the
    global ``text_transform`` for its language. The resulting 1-D tensors are
    padded with ``PAD_IDX`` by ``pad_sequence`` (default ``batch_first=False``),
    so both outputs are ``(max_len, batch_size)``.

    Relies on the globals ``text_transform`` and ``task_parameters`` existing
    by the time the DataLoader is iterated.
    """
    src_tensors = []
    tgt_tensors = []
    for raw_src, raw_tgt in batch:
        src_tensors.append(
            text_transform[task_parameters["SRC_LANGUAGE"]](raw_src.rstrip("\n")))
        tgt_tensors.append(
            text_transform[task_parameters["TGT_LANGUAGE"]](raw_tgt.rstrip("\n")))

    pad_value = task_parameters["PAD_IDX"]
    return (pad_sequence(src_tensors, padding_value=pad_value),
            pad_sequence(tgt_tensors, padding_value=pad_value))

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_loop(model, dataloader, loss, optimizer, device, task_parameters):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: called as ``model(src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask)``
            and returning logits of shape ``(tgt_len - 1, batch, vocab)``.
        dataloader: iterable of ``(src, tgt)`` sequence-first index batches.
        loss: criterion applied to flattened logits and flattened targets.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        task_parameters: dict providing ``PAD_IDX``.

    Returns:
        Average of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()

    epoch_loss = 0.0
    num_batches = 0

    for src, tgt in tqdm(dataloader, desc='training set'):

        optimizer.zero_grad()

        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing: the decoder input drops the last token and the
        # target drops the first, so position t is trained to predict t+1.
        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(
            src, tgt_input, device, task_parameters["PAD_IDX"])
        # The memory is the encoder output over the source, so the source
        # padding mask is reused as the memory key padding mask.
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask,
                       tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        batch_loss = loss(logits.reshape(-1, logits.shape[-1]),
                          tgt_out.reshape(-1))
        batch_loss.backward()

        optimizer.step()
        epoch_loss += batch_loss.item()
        num_batches += 1

    # Count batches while iterating. `len(list(dataloader))` re-ran the whole
    # data pipeline (tokenization + collation) a second time every epoch.
    if num_batches == 0:
        raise ValueError("train_loop: dataloader yielded no batches")
    return epoch_loss / num_batches

def test_loop(model, dataloader, loss, device, task_parameters):
    """Evaluate ``model`` on ``dataloader`` without gradients; return mean per-batch loss.

    Uses the same teacher-forced setup as training (decoder input
    ``tgt[:-1]``, targets ``tgt[1:]``), so the value is directly comparable
    to the training loss.

    Args:
        model: called as ``model(src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask)``.
        dataloader: iterable of ``(src, tgt)`` sequence-first index batches.
        loss: criterion applied to flattened logits and flattened targets.
        device: device each batch is moved to.
        task_parameters: dict providing ``PAD_IDX``.

    Returns:
        Average of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()

    epoch_loss = 0.0
    num_batches = 0

    with torch.no_grad():

        for src, tgt in tqdm(dataloader, desc='test set'):

            src = src.to(device)
            tgt = tgt.to(device)

            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(
                src, tgt_input, device, task_parameters["PAD_IDX"])
            # The source padding mask doubles as the memory key padding mask.
            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask,
                           tgt_padding_mask, src_padding_mask)

            tgt_out = tgt[1:, :]
            batch_loss = loss(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            epoch_loss += batch_loss.item()
            num_batches += 1

    # Count batches while iterating instead of materializing the dataloader
    # a second time with len(list(dataloader)).
    if num_batches == 0:
        raise ValueError("test_loop: dataloader yielded no batches")
    return epoch_loss / num_batches

In [None]:
def train_test(model,
               epochs,
               optimizer,
               device,
               batch_size,
               task_parameters,
               train_loss_fn,
               test_loss_fn=None,         # train and test losses need not differ
               early_stopping=None,       # training may run without early stopping
               val_data=None,             # ...in which case there is no validation set
               scheduler=None):           # optional scheduler for a hyperparameter,
                                          # typically the learning rate
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    The Multi30k train split is used for training and the validation split
    as the test set. Batches are built with ``collate_fn``.

    Args:
        model, optimizer, device: passed through to ``train_loop``/``test_loop``.
        epochs: number of epochs to run.
        batch_size: DataLoader batch size for both splits.
        task_parameters: dict with ``SRC_LANGUAGE``, ``TGT_LANGUAGE``, ``PAD_IDX``.
        train_loss_fn: criterion used for training.
        test_loss_fn: criterion used for evaluation; defaults to ``train_loss_fn``.
        early_stopping, val_data: accepted for API compatibility but not
            implemented; a warning is emitted if either is given.
        scheduler: optional LR scheduler, stepped once per epoch after
            evaluation. ``ReduceLROnPlateau`` receives the epoch's test loss.

    Returns:
        ``(train_loss, test_loss)``: lists with one mean loss per epoch.
    """
    if early_stopping is not None or val_data is not None:
        import warnings
        warnings.warn("train_test: early_stopping and val_data are not implemented and are ignored")

    language_pair = (task_parameters["SRC_LANGUAGE"], task_parameters["TGT_LANGUAGE"])

    train_iter = Multi30k(split='train', language_pair=language_pair)
    train_dataloader = DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)

    # The validation split serves as the test set.
    test_iter = Multi30k(split='valid', language_pair=language_pair)
    test_dataloader = DataLoader(test_iter, batch_size=batch_size, collate_fn=collate_fn)

    if test_loss_fn is None:
        test_loss_fn = train_loss_fn

    # Per-epoch losses, returned for plotting.
    train_loss = []
    test_loss = []

    for epoch in tqdm(range(1, epochs + 1)):

        epoch_train_loss = train_loop(model, train_dataloader, train_loss_fn, optimizer, device, task_parameters)
        train_loss.append(epoch_train_loss)

        epoch_test_loss = test_loop(model, test_dataloader, test_loss_fn, device, task_parameters)
        test_loss.append(epoch_test_loss)

        if scheduler is not None:
            # ReduceLROnPlateau needs the monitored metric; other schedulers step blindly.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epoch_test_loss)
            else:
                scheduler.step()

        print(f"\nEpoch {epoch}/{epochs} Train loss: {epoch_train_loss:6.4f} Test loss: {epoch_test_loss:6.4f}")

    return train_loss, test_loss

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('using '+str(device))

# Keep the tokenizers returned here: text_transform below needs them. The
# previous `_token_transform` discarded them, and the cell only worked
# because the demo cell had left a global `token_transform` behind.
token_transform, vocab_transform = tokenization(task_parameters)

torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[task_parameters['SRC_LANGUAGE']])
TGT_VOCAB_SIZE = len(vocab_transform[task_parameters['TGT_LANGUAGE']])

model = Seq2SeqTransformer(hyperparameters["layers"], hyperparameters["layers"],
                           hyperparameters["h_dim"], hyperparameters["heads"],
                           SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                           dim_feedforward=hyperparameters["h_dim"],
                           dropout=hyperparameters["dropout"]).to(device)

# Padding positions do not contribute to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=task_parameters["PAD_IDX"])
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters["learning_rate"])

# ``src`` and ``tgt`` language text transforms to convert raw
# strings into tensors indices
text_transform = {}
for ln in [task_parameters["SRC_LANGUAGE"], task_parameters["TGT_LANGUAGE"]]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform)  # Add BOS/EOS and create tensor

using cuda




In [None]:
# Training routine: train for the configured number of epochs and evaluate
# on the validation split (used as the test set) after each epoch.
train_loss, test_loss = train_test(
    model,
    epochs=hyperparameters['epochs'],
    optimizer=optimizer,
    device=device,
    batch_size=hyperparameters['batch_size'],
    task_parameters=task_parameters,
    train_loss_fn=criterion,
    test_loss_fn=criterion,
)

  0%|          | 0/20 [00:00<?, ?it/s]

training set: 1it [00:02,  2.97s/it][A
training set: 2it [00:03,  1.38s/it][A
training set: 3it [00:03,  1.18it/s][A
training set: 4it [00:03,  1.69it/s][A
training set: 5it [00:03,  2.22it/s][A
training set: 6it [00:04,  2.81it/s][A
training set: 7it [00:04,  3.32it/s][A
training set: 8it [00:04,  3.82it/s][A
training set: 9it [00:04,  4.26it/s][A
training set: 10it [00:04,  4.62it/s][A
training set: 11it [00:04,  4.86it/s][A
training set: 12it [00:05,  4.95it/s][A
training set: 13it [00:05,  4.91it/s][A
training set: 14it [00:05,  5.11it/s][A
training set: 15it [00:05,  5.29it/s][A
training set: 16it [00:05,  5.34it/s][A
training set: 17it [00:06,  5.11it/s][A
training set: 18it [00:06,  5.22it/s][A
training set: 19it [00:06,  5.26it/s][A
training set: 20it [00:06,  5.35it/s][A
training set: 21it [00:06,  5.20it/s][A
training set: 22it [00:07,  5.23it/s][A
training set: 23it [00:07,  4.93it/s][A
training set: 24it [00:07,


Train loss: 4.4646 Test loss: 3.7777



training set: 0it [00:00, ?it/s][A
training set: 1it [00:00,  6.24it/s][A
training set: 2it [00:00,  5.18it/s][A
training set: 3it [00:00,  5.21it/s][A
training set: 4it [00:00,  5.40it/s][A
training set: 5it [00:00,  5.49it/s][A
training set: 6it [00:01,  5.84it/s][A
training set: 7it [00:01,  5.96it/s][A
training set: 8it [00:01,  6.23it/s][A
training set: 9it [00:01,  6.31it/s][A
training set: 10it [00:01,  6.28it/s][A
training set: 11it [00:01,  6.25it/s][A
training set: 12it [00:02,  6.02it/s][A
training set: 13it [00:02,  5.78it/s][A
training set: 14it [00:02,  5.89it/s][A
training set: 15it [00:02,  5.98it/s][A
training set: 16it [00:02,  6.00it/s][A
training set: 17it [00:02,  5.69it/s][A
training set: 18it [00:03,  5.74it/s][A
training set: 19it [00:03,  5.82it/s][A
training set: 20it [00:03,  6.02it/s][A
training set: 21it [00:03,  6.12it/s][A
training set: 22it [00:03,  6.07it/s][A
training set: 23it [00:03,  5.65it/s][A
training set: 24it [00:04,  5


Train loss: 3.5676 Test loss: 3.5138



training set: 0it [00:00, ?it/s][A
training set: 1it [00:00,  6.13it/s][A
training set: 2it [00:00,  5.05it/s][A
training set: 3it [00:00,  5.12it/s][A
training set: 4it [00:00,  5.36it/s][A
training set: 5it [00:00,  5.46it/s][A
training set: 6it [00:01,  5.81it/s][A
training set: 7it [00:01,  5.90it/s][A
training set: 8it [00:01,  6.15it/s][A
training set: 9it [00:01,  6.22it/s][A
training set: 10it [00:01,  5.97it/s][A
training set: 11it [00:01,  5.81it/s][A
training set: 12it [00:02,  5.47it/s][A
training set: 13it [00:02,  5.17it/s][A
training set: 14it [00:02,  5.22it/s][A
training set: 15it [00:02,  5.34it/s][A
training set: 16it [00:02,  5.41it/s][A
training set: 17it [00:03,  5.20it/s][A
training set: 18it [00:03,  5.28it/s][A
training set: 19it [00:03,  5.22it/s][A
training set: 20it [00:03,  5.38it/s][A
training set: 21it [00:03,  5.46it/s][A
training set: 22it [00:03,  5.56it/s][A
training set: 23it [00:04,  5.26it/s][A
training set: 24it [00:04,  5


Train loss: 3.3155 Test loss: 3.4552



training set: 0it [00:00, ?it/s][A
training set: 1it [00:00,  5.96it/s][A
training set: 2it [00:00,  5.02it/s][A
training set: 3it [00:00,  5.06it/s][A
training set: 4it [00:00,  5.25it/s][A
training set: 5it [00:00,  5.38it/s][A
training set: 6it [00:01,  5.63it/s][A
training set: 7it [00:01,  5.75it/s][A
training set: 8it [00:01,  6.02it/s][A
training set: 9it [00:01,  6.11it/s][A
training set: 10it [00:01,  6.06it/s][A
training set: 11it [00:01,  6.07it/s][A
training set: 12it [00:02,  5.78it/s][A
training set: 13it [00:02,  5.49it/s][A
training set: 14it [00:02,  5.63it/s][A
training set: 15it [00:02,  5.78it/s][A
training set: 16it [00:02,  5.78it/s][A
training set: 17it [00:03,  5.56it/s][A
training set: 18it [00:03,  5.55it/s][A
training set: 19it [00:03,  5.60it/s][A
training set: 20it [00:03,  5.78it/s][A
training set: 21it [00:03,  5.89it/s][A
training set: 22it [00:03,  6.03it/s][A
training set: 23it [00:04,  5.75it/s][A
training set: 24it [00:04,  5


Train loss: 3.1637 Test loss: 3.3958



training set: 0it [00:00, ?it/s][A
training set: 1it [00:00,  5.91it/s][A
training set: 2it [00:00,  4.96it/s][A
training set: 3it [00:00,  4.98it/s][A
training set: 4it [00:00,  5.24it/s][A
training set: 5it [00:00,  5.31it/s][A
training set: 6it [00:01,  5.64it/s][A
training set: 7it [00:01,  5.74it/s][A
training set: 8it [00:01,  5.98it/s][A
training set: 9it [00:01,  5.90it/s][A
training set: 10it [00:01,  5.94it/s][A
training set: 11it [00:01,  5.77it/s][A
training set: 12it [00:02,  5.52it/s][A
training set: 13it [00:02,  5.28it/s][A
training set: 14it [00:02,  5.32it/s][A
training set: 15it [00:02,  5.43it/s][A
training set: 16it [00:02,  5.43it/s][A
training set: 17it [00:03,  5.23it/s][A
training set: 18it [00:03,  5.28it/s][A
training set: 19it [00:03,  5.34it/s][A
training set: 20it [00:03,  5.46it/s][A
training set: 21it [00:03,  5.47it/s][A
training set: 22it [00:04,  5.59it/s][A
training set: 23it [00:04,  5.24it/s][A
training set: 24it [00:04,  5


Train loss: 3.0514 Test loss: 3.4052



training set: 0it [00:00, ?it/s][A
training set: 1it [00:00,  5.82it/s][A
training set: 2it [00:00,  4.82it/s][A
training set: 3it [00:00,  4.87it/s][A
training set: 4it [00:00,  5.07it/s][A
training set: 5it [00:00,  5.24it/s][A
training set: 6it [00:01,  5.56it/s][A
training set: 7it [00:01,  5.67it/s][A
training set: 8it [00:01,  5.96it/s][A
training set: 9it [00:01,  5.93it/s][A
training set: 10it [00:01,  5.99it/s][A
training set: 11it [00:01,  6.00it/s][A
training set: 12it [00:02,  5.63it/s][A
training set: 13it [00:02,  5.36it/s][A
training set: 14it [00:02,  5.35it/s][A
training set: 15it [00:02,  5.45it/s][A
training set: 16it [00:02,  5.39it/s][A
training set: 17it [00:03,  5.08it/s][A
training set: 18it [00:03,  5.09it/s][A
training set: 19it [00:03,  5.11it/s][A
training set: 20it [00:03,  5.20it/s][A
training set: 21it [00:03,  5.37it/s][A
training set: 22it [00:04,  5.54it/s][A
training set: 23it [00:04,  5.23it/s][A
training set: 24it [00:04,  5


Train loss: 2.9484 Test loss: 3.4093



training set: 0it [00:00, ?it/s][A
training set: 1it [00:00,  5.70it/s][A
training set: 2it [00:00,  4.83it/s][A
training set: 3it [00:00,  4.91it/s][A
training set: 4it [00:00,  5.15it/s][A
training set: 5it [00:00,  5.26it/s][A
training set: 6it [00:01,  5.53it/s][A
training set: 7it [00:01,  5.60it/s][A
training set: 8it [00:01,  5.87it/s][A
training set: 9it [00:01,  5.93it/s][A
training set: 10it [00:01,  5.94it/s][A
training set: 11it [00:01,  5.91it/s][A
training set: 12it [00:02,  5.57it/s][A
training set: 13it [00:02,  5.34it/s][A
training set: 14it [00:02,  5.47it/s][A
training set: 15it [00:02,  5.59it/s][A
training set: 16it [00:02,  5.67it/s][A
training set: 17it [00:03,  5.37it/s][A
training set: 18it [00:03,  5.45it/s][A
training set: 19it [00:03,  5.50it/s][A
training set: 20it [00:03,  5.67it/s][A
training set: 21it [00:03,  5.80it/s][A
training set: 22it [00:03,  5.97it/s][A
training set: 23it [00:04,  5.71it/s][A
training set: 24it [00:04,  5


Train loss: 2.8531 Test loss: 3.4138



training set: 0it [00:00, ?it/s][A
training set: 1it [00:00,  4.94it/s][A
training set: 2it [00:00,  4.50it/s][A
training set: 3it [00:00,  4.60it/s][A
training set: 4it [00:00,  4.84it/s][A
training set: 5it [00:01,  4.91it/s][A
training set: 6it [00:01,  5.15it/s][A
training set: 7it [00:01,  5.26it/s][A
training set: 8it [00:01,  5.45it/s][A
training set: 9it [00:01,  5.41it/s][A
training set: 10it [00:01,  5.43it/s][A
training set: 11it [00:02,  5.42it/s][A
training set: 12it [00:02,  5.13it/s][A
training set: 13it [00:02,  4.88it/s][A
training set: 14it [00:02,  4.94it/s][A
training set: 15it [00:02,  5.16it/s][A
training set: 16it [00:03,  5.33it/s][A
training set: 17it [00:03,  5.16it/s][A
training set: 18it [00:03,  5.25it/s][A
training set: 19it [00:03,  5.34it/s][A
training set: 20it [00:03,  5.57it/s][A
training set: 21it [00:04,  5.76it/s][A
training set: 22it [00:04,  5.93it/s][A
training set: 23it [00:04,  5.66it/s][A
training set: 24it [00:04,  5


Train loss: 2.7723 Test loss: 3.4166



training set: 0it [00:00, ?it/s][A
training set: 1it [00:00,  5.79it/s][A
training set: 2it [00:00,  4.82it/s][A
training set: 3it [00:00,  4.92it/s][A
training set: 4it [00:00,  5.16it/s][A
training set: 5it [00:00,  5.24it/s][A
training set: 6it [00:01,  5.54it/s][A
training set: 7it [00:01,  5.63it/s][A
training set: 8it [00:01,  5.93it/s][A
training set: 9it [00:01,  6.00it/s][A
training set: 10it [00:01,  5.94it/s][A
training set: 11it [00:01,  5.95it/s][A
training set: 12it [00:02,  5.66it/s][A
training set: 13it [00:02,  5.46it/s][A
training set: 14it [00:02,  5.53it/s][A
training set: 15it [00:02,  5.65it/s][A
training set: 16it [00:02,  5.69it/s][A
training set: 17it [00:03,  5.39it/s][A
training set: 18it [00:03,  5.45it/s][A
training set: 19it [00:03,  5.50it/s][A
training set: 20it [00:03,  5.66it/s][A
training set: 21it [00:03,  5.76it/s][A
training set: 22it [00:03,  5.92it/s][A
training set: 23it [00:04,  5.62it/s][A
training set: 24it [00:04,  5

In [None]:
import matplotlib.pyplot as plt

# Epochs are numbered from 1 to match the training log.
epoch_axis = range(1, len(train_loss) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_axis, train_loss, label='training loss')
ax.plot(epoch_axis, test_loss, label='test loss')
ax.set(title=f"{task_parameters['SRC_LANGUAGE']}→{task_parameters['TGT_LANGUAGE']} "
             "translation: loss per epoch",
       xlabel='epoch', ylabel='cross-entropy loss')
# Only fix the lower bound. A hard top of 4 clipped the first-epoch
# training loss (~4.46).
ax.set_ylim(bottom=0)
ax.legend(loc='best')
plt.show()