In [2]:
!python -m spacy download de_core_news_sm
!pip install datasets

Collecting de-core-news-sm==3.7.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.7.0/de_core_news_sm-3.7.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m58.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: de-core-news-sm
Successfully installed de-core-news-sm-3.7.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('de_core_news_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting datasets
  Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (fro

In [4]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datasets
import spacy
import torch.utils
import torch.utils.data
import tqdm
import random

device = torch.device("cuda" if torch.cpu.is_available() else "cpu")
dataset = datasets.load_dataset("bentrevett/multi30k")

en_nlp = spacy.load("en_core_web_sm")
de_nlp = spacy.load("de_core_news_sm")

pad_token = "<pad>"
sos_token = "<sos>"
eos_token = "<eos>"
unk_token = "<unk>"

train_data, valid_data, test_data = (
  dataset["train"],
  dataset["validation"],
  dataset["test"],
)

def tokenize_data(data, en_nlp, de_nlp, sos_token, eos_token):
  en_tokens = [token.text.lower() for token in en_nlp.tokenizer(data["en"])]
  de_tokens = [token.text.lower() for token in de_nlp.tokenizer(data["de"])]
  en_tokens = [sos_token] + en_tokens + [eos_token]
  de_tokens = [sos_token] + de_tokens + [eos_token]
  return {"en_tokens": en_tokens, "de_tokens": de_tokens}

fn_kwargs = {
  "en_nlp": en_nlp,
  "de_nlp": de_nlp,
  "sos_token": sos_token,
  "eos_token": eos_token,
}

train_data = train_data.map(tokenize_data, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(tokenize_data, fn_kwargs=fn_kwargs)
test_data = test_data.map(tokenize_data, fn_kwargs=fn_kwargs)

def build_vocab(token_lists, min_freq=1):
  s = set([token for tokens in token_lists for token in tokens])
  vocab = {}
  vocab[unk_token] = len(vocab)
  vocab[pad_token] = len(vocab)
  vocab[sos_token] = len(vocab)
  vocab[eos_token] = len(vocab)
  for token in s:
    if token not in vocab:
      vocab[token] = len(vocab)
  return vocab

en_vocab = build_vocab(train_data["en_tokens"])
de_vocab = build_vocab(train_data["de_tokens"])

def text_to_ids(texts, vocab):
  return [[vocab.get(token, vocab[unk_token]) for token in text] for text in texts]

en_ids = text_to_ids(train_data["en_tokens"], en_vocab)
de_ids = text_to_ids(train_data["de_tokens"], de_vocab)






Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [7]:
class TranslationDataset(torch.utils.data.Dataset):
    def __init__(self, src, tgt, src_vocab, tgt_vocab, pad_token='<pad>'):
        self.src = src
        self.tgt = tgt
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.pad_token = pad_token
        self.src_max_length = max(len(seq) for seq in src)
        self.tgt_max_length = max(len(seq) for seq in tgt)

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        src_ids = self.src[idx]
        tgt_ids = self.tgt[idx]
        # Pad the source and target sequences
        src_ids = src_ids + [self.src_vocab[self.pad_token]] * (self.src_max_length - len(src_ids))
        tgt_ids = tgt_ids + [self.tgt_vocab[self.pad_token]] * (self.tgt_max_length - len(tgt_ids))

        # Convert lists to tensors
        src_tensor = torch.tensor(src_ids, dtype=torch.long)
        tgt_tensor = torch.tensor(tgt_ids, dtype=torch.long)

        return src_tensor, tgt_tensor


In [40]:
# model
class Encoder(nn.Module):
  def __init__(self, input_size, embed_size, hidden_size, n_layers, dropout_rate):
    super(Encoder, self).__init__()
    self.embedding = nn.Embedding(input_size, embed_size)
    self.bigru = nn.GRU(embed_size, hidden_size, n_layers, bidirectional=True, dropout=dropout_rate)
    self.fc = nn.Linear(hidden_size*2, hidden_size)
    self.dropout = nn.Dropout(dropout_rate)
  def forward(self, input):
    # input: (seq_length, batch_size)
    embedded = self.dropout(self.embedding(input)) # (seq_length, batch_size, embed_size)
    outputs, hidden = self.bigru(embedded)
    # output: (seq_length, batch_size, hidden_size*2)
    # hidden: (n_layers * n_directions, batch_size, hidden_size)
    hidden_cat = torch.cat([hidden[-2,:,:], hidden[-1,:,:]], dim=1)
    # hidden_cat: (batch_size, hidden_size*2) - n_layers set to 1 here for simplicity
    hidden_out = F.relu(self.fc(hidden_cat))
    # hidden_out: (batch_size, hidden_size)
    return outputs, hidden_out

class BahdanauAttention(nn.Module):
  def __init__(self, hidden_size):
    super(BahdanauAttention, self).__init__()
    self.wa = nn.Linear(hidden_size, hidden_size, bias=False)
    self.ua = nn.Linear(hidden_size * 2, hidden_size, bias=False)
    self.va = nn.Linear(hidden_size, 1, bias=False)
  def forward(self, hidden, encoder_outputs):
    # hidden: (batch_size, hidden_size)
    # encoder_outputs: (seq_length, batch_size, hidden_size*2)
    seq_len = encoder_outputs.shape[0]
    hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)  # (batch_size, seq_length, hidden_size)
    encoder_outputs = encoder_outputs.permute(1, 0, 2)  # (batch_size, seq_length, hidden_size * 2)
    # calculate alignment scores
    energy = F.relu(self.wa(hidden) + self.ua(encoder_outputs))  # (batch_size, seq_length, hidden_size)
    scores = self.va(energy)  # (batch_size, seq_length, 1)
    scores = scores.squeeze(2)  # (batch_size, seq_length)

    # calculate the attention weights (prob) from alignment scores
    attn_weights = F.softmax(scores, dim=-1)  # (batch_size, seq_length)

    # calculate context vector
    context_vector = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # (batch_size, 1, hidden_size*2)
    #print(context_vector.shape, attn_weights.shape, hidden.shape)
    # context_vector: (batch_size, 1, hidden_size * 2)
    # alphas: (batch_size, seq_length)
    return context_vector, attn_weights

class Decoder(nn.Module):
  def __init__(self, output_size, embed_size, hidden_size, n_layers, dropout_rate):
    super(Decoder, self).__init__()
    self.output_size = output_size
    self.embedding = nn.Embedding(output_size, embed_size)
    self.gru = nn.GRU(embed_size + hidden_size*2, hidden_size, n_layers, bidirectional=False, dropout=dropout_rate)
    self.attention = BahdanauAttention(hidden_size)
    self.fc = nn.Linear(embed_size + hidden_size + hidden_size*2, output_size)
    self.dropout = nn.Dropout(dropout_rate)
  def forward(self, input, hidden, encoder_outputs):
    # input: (batch_size, 1)
    # hidden: (batch_size, hidden_size)
    # encoder_outputs: (seq_length, batch_size, hidden_size*2)
    embedded = self.dropout(self.embedding(input.unsqueeze(0))) # (1, batch_size, embed_size)
    context_vector, attn_weights = self.attention(hidden, encoder_outputs)
    # context_vector: (batch_size, 1, hidden_size*2)
    # attn_weights: (batch_size, seq_length)
    context_vector = context_vector.permute(1, 0, 2) # (1, batch_size, hidden_size*2)
    #print(context_vector.shape, embedded.shape)
    #print(input.shape)
    gru_input = torch.cat([embedded, context_vector], dim=2) # (1, batch_size, embed_size + hidden_size*2)
    output, hidden = self.gru(gru_input, hidden.unsqueeze(0))
    # output: (1, batch_size, hidden_size)
    # hidden: (1, batch_size, hidden_size)
    embedded = embedded.squeeze(0) # (batch_size, embed_size)
    output = output.squeeze(0) # (batch_size, hidden_size)
    context_vector = context_vector.squeeze(0) # (batch_size, hidden_size*2)
    pred = self.fc(torch.cat([embedded, output, context_vector], dim=1)) # (batch_size, output_size)
    hidden = hidden.squeeze(0) # (batch_size, hidden_size)
    return pred, hidden, attn_weights

class Seq2Seq(nn.Module):
  def __init__(self, encoder, decoder):
    super(Seq2Seq, self).__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, source, target, teacher_force_ratio=0.5):
    # source: (source_seq_length, batch_size)
    # target: (target_seq_length, batch_size)
    target_seq_length, batch_size = target.shape
    target_vocab_size = self.decoder.output_size
    outputs = torch.zeros((target_seq_length, batch_size, target_vocab_size)).to(device)
    encoder_outputs, hidden = self.encoder(source)
    # encoder_outputs: (source_seq_length, batch_size, hidden_size*2)
    # hidden: (batch_size, hidden_size)
    input = target[0, :] # (batch_size,) - initial input is sos token
    for t in range(1, target_seq_length):
      output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
      # output: (batch_size, output_size)
      # hidden: (batch_size, hidden_size)
      outputs[t] = output
      top1 = output.argmax(dim=1)
      input = target[t] if random.random() < teacher_force_ratio else top1
    return outputs

In [49]:
def train_fn(model, dataloader, optimizer, criterion, clip=4.0):
  model.train()
  total_loss = 0

  for idx, (src, tgt) in enumerate(tqdm.tqdm(dataloader, total=len(dataloader), position=0, leave=True)):
    src, tgt = src.to(device), tgt.to(device)
    src, tgt = src.T, tgt.T
    # src: (src_seq_length, batch_size)
    # tgt: (tgt_seq_length, batch_size)
    optimizer.zero_grad()
    output = model(src, tgt) # (target_seq_length, batch_size, output_size)
    output = output.view(-1, output.shape[-1])  # (target_seq_length*batch_size, output_size)
    tgt = tgt.contiguous().view(-1)  # (target_seq_length*batch_size)
    loss = criterion(output, tgt)
    total_loss += loss.item()
    loss.backward()
    output = output.argmax(dim=-1)
    nn.utils.clip_grad_norm_(model.parameters(), clip) # gradient clipping
    optimizer.step()
    if (idx + 1) % 50 == 0:
      print(f"loss: {loss.item()}")

  return total_loss / len(dataloader)


In [47]:
# train
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(de_vocab)
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
N_LAYERS = 1
DROPOUT = 0.5

enc = Encoder(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT).to(device)
dec = Decoder(OUTPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT).to(device)
model = Seq2Seq(enc, dec).to(device)

BATCH_SIZE = 64
train_dataset = TranslationDataset(en_ids, de_ids, en_vocab, de_vocab)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=en_vocab[pad_token])

# test train
num_epochs = 10
for epoch in range(num_epochs):
  epoch_loss = train_fn(model, train_loader, optimizer, criterion)
  print(f"epoch {epoch+1}, loss: {epoch_loss}")

 11%|█         | 50/454 [00:38<05:28,  1.23it/s]

loss: 5.3592329025268555


 22%|██▏       | 100/454 [01:15<04:49,  1.22it/s]

loss: 4.905766487121582


 33%|███▎      | 150/454 [01:53<04:24,  1.15it/s]

loss: 4.429150104522705


 44%|████▍     | 200/454 [02:31<03:32,  1.20it/s]

loss: 4.24530029296875


 55%|█████▌    | 250/454 [03:09<02:49,  1.20it/s]

loss: 4.012444496154785


 66%|██████▌   | 300/454 [03:47<02:08,  1.20it/s]

loss: 4.26096248626709


 77%|███████▋  | 350/454 [04:25<01:27,  1.19it/s]

loss: 4.36027193069458


 88%|████████▊ | 400/454 [05:03<00:44,  1.21it/s]

loss: 3.879857063293457


 99%|█████████▉| 450/454 [05:41<00:03,  1.19it/s]

loss: 4.004842758178711


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 1, loss: 4.552691787875172


 11%|█         | 50/454 [00:37<05:28,  1.23it/s]

loss: 3.2804207801818848


 22%|██▏       | 100/454 [01:15<04:48,  1.23it/s]

loss: 3.602877378463745


 33%|███▎      | 150/454 [01:53<04:08,  1.22it/s]

loss: 3.4674103260040283


 44%|████▍     | 200/454 [02:31<03:28,  1.22it/s]

loss: 3.404792547225952


 55%|█████▌    | 250/454 [03:09<02:46,  1.23it/s]

loss: 3.5119540691375732


 66%|██████▌   | 300/454 [03:47<02:05,  1.23it/s]

loss: 3.361170530319214


 77%|███████▋  | 350/454 [04:25<01:24,  1.23it/s]

loss: 3.4960684776306152


 88%|████████▊ | 400/454 [05:03<00:43,  1.23it/s]

loss: 3.391343832015991


 99%|█████████▉| 450/454 [05:41<00:03,  1.22it/s]

loss: 3.4719972610473633


100%|██████████| 454/454 [05:43<00:00,  1.32it/s]


epoch 2, loss: 3.4295321324848396


 11%|█         | 50/454 [00:37<05:27,  1.23it/s]

loss: 2.743879556655884


 22%|██▏       | 100/454 [01:15<04:46,  1.23it/s]

loss: 2.7227704524993896


 33%|███▎      | 150/454 [01:54<04:06,  1.23it/s]

loss: 2.830497980117798


 44%|████▍     | 200/454 [02:32<03:25,  1.23it/s]

loss: 3.0341947078704834


 55%|█████▌    | 250/454 [03:10<02:45,  1.23it/s]

loss: 2.934190511703491


 66%|██████▌   | 300/454 [03:48<02:04,  1.23it/s]

loss: 3.1114845275878906


 77%|███████▋  | 350/454 [04:25<01:24,  1.23it/s]

loss: 2.912285327911377


 88%|████████▊ | 400/454 [05:03<00:43,  1.23it/s]

loss: 2.985647201538086


 99%|█████████▉| 450/454 [05:42<00:03,  1.23it/s]

loss: 2.755014419555664


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 3, loss: 2.896275753491776


 11%|█         | 50/454 [00:38<05:27,  1.23it/s]

loss: 2.396852731704712


 22%|██▏       | 100/454 [01:16<04:47,  1.23it/s]

loss: 2.3729960918426514


 33%|███▎      | 150/454 [01:54<04:06,  1.23it/s]

loss: 2.7629220485687256


 44%|████▍     | 200/454 [02:32<03:25,  1.23it/s]

loss: 2.590620756149292


 55%|█████▌    | 250/454 [03:10<02:45,  1.23it/s]

loss: 2.3215179443359375


 66%|██████▌   | 300/454 [03:48<02:04,  1.23it/s]

loss: 2.632402181625366


 77%|███████▋  | 350/454 [04:26<01:24,  1.23it/s]

loss: 2.4917421340942383


 88%|████████▊ | 400/454 [05:04<00:43,  1.23it/s]

loss: 2.89280104637146


 99%|█████████▉| 450/454 [05:42<00:03,  1.24it/s]

loss: 2.813589572906494


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 4, loss: 2.5683263629543625


 11%|█         | 50/454 [00:38<05:40,  1.19it/s]

loss: 2.1504180431365967


 22%|██▏       | 100/454 [01:16<05:01,  1.17it/s]

loss: 2.479210615158081


 33%|███▎      | 150/454 [01:54<04:14,  1.19it/s]

loss: 2.527387857437134


 44%|████▍     | 200/454 [02:32<03:37,  1.17it/s]

loss: 2.1520700454711914


 55%|█████▌    | 250/454 [03:10<02:53,  1.18it/s]

loss: 2.403449296951294


 66%|██████▌   | 300/454 [03:47<02:10,  1.18it/s]

loss: 2.3139588832855225


 77%|███████▋  | 350/454 [04:25<01:27,  1.19it/s]

loss: 2.368497848510742


 88%|████████▊ | 400/454 [05:03<00:44,  1.20it/s]

loss: 2.47396183013916


 99%|█████████▉| 450/454 [05:41<00:03,  1.19it/s]

loss: 2.5835726261138916


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 5, loss: 2.395236929345236


 11%|█         | 50/454 [00:37<05:33,  1.21it/s]

loss: 2.120027780532837


 22%|██▏       | 100/454 [01:15<04:52,  1.21it/s]

loss: 2.54156231880188


 33%|███▎      | 150/454 [01:53<04:12,  1.20it/s]

loss: 2.8230676651000977


 44%|████▍     | 200/454 [02:31<03:30,  1.21it/s]

loss: 2.475785970687866


 55%|█████▌    | 250/454 [03:09<02:48,  1.21it/s]

loss: 2.3957645893096924


 66%|██████▌   | 300/454 [03:47<02:11,  1.17it/s]

loss: 2.3728444576263428


 77%|███████▋  | 350/454 [04:25<01:28,  1.18it/s]

loss: 2.590304374694824


 88%|████████▊ | 400/454 [05:03<00:46,  1.17it/s]

loss: 2.1966516971588135


 99%|█████████▉| 450/454 [05:41<00:03,  1.18it/s]

loss: 2.4649922847747803


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 6, loss: 2.31603172618387


 11%|█         | 50/454 [00:38<05:31,  1.22it/s]

loss: 2.3786933422088623


 22%|██▏       | 100/454 [01:16<04:49,  1.22it/s]

loss: 2.285254716873169


 33%|███▎      | 150/454 [01:54<04:08,  1.22it/s]

loss: 2.239077091217041


 44%|████▍     | 200/454 [02:32<03:27,  1.22it/s]

loss: 2.120357036590576


 55%|█████▌    | 250/454 [03:10<02:47,  1.22it/s]

loss: 2.1743321418762207


 66%|██████▌   | 300/454 [03:48<02:07,  1.21it/s]

loss: 2.3024168014526367


 77%|███████▋  | 350/454 [04:26<01:25,  1.22it/s]

loss: 2.2603304386138916


 88%|████████▊ | 400/454 [05:04<00:44,  1.22it/s]

loss: 2.1960062980651855


 99%|█████████▉| 450/454 [05:42<00:03,  1.22it/s]

loss: 2.6153061389923096


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 7, loss: 2.220437453181733


 11%|█         | 50/454 [00:37<05:27,  1.23it/s]

loss: 2.1619479656219482


 22%|██▏       | 100/454 [01:15<04:47,  1.23it/s]

loss: 2.379166841506958


 33%|███▎      | 150/454 [01:53<04:06,  1.23it/s]

loss: 2.0445258617401123


 44%|████▍     | 200/454 [02:31<03:25,  1.23it/s]

loss: 2.3862826824188232


 55%|█████▌    | 250/454 [03:09<02:45,  1.23it/s]

loss: 2.195406913757324


 66%|██████▌   | 300/454 [03:47<02:05,  1.23it/s]

loss: 1.9859575033187866


 77%|███████▋  | 350/454 [04:25<01:24,  1.23it/s]

loss: 2.152296781539917


 88%|████████▊ | 400/454 [05:03<00:43,  1.23it/s]

loss: 2.2102954387664795


 99%|█████████▉| 450/454 [05:41<00:03,  1.23it/s]

loss: 2.5169053077697754


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 8, loss: 2.1484448345222136


 11%|█         | 50/454 [00:38<05:27,  1.23it/s]

loss: 1.9332878589630127


 22%|██▏       | 100/454 [01:16<04:47,  1.23it/s]

loss: 1.852005958557129


 33%|███▎      | 150/454 [01:54<04:06,  1.23it/s]

loss: 1.9141881465911865


 44%|████▍     | 200/454 [02:32<03:25,  1.24it/s]

loss: 1.9244824647903442


 55%|█████▌    | 250/454 [03:10<02:45,  1.23it/s]

loss: 2.208137035369873


 66%|██████▌   | 300/454 [03:47<02:04,  1.23it/s]

loss: 2.197152614593506


 77%|███████▋  | 350/454 [04:25<01:24,  1.24it/s]

loss: 2.142246723175049


 88%|████████▊ | 400/454 [05:03<00:43,  1.24it/s]

loss: 2.0070643424987793


 99%|█████████▉| 450/454 [05:41<00:03,  1.18it/s]

loss: 2.0493733882904053


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 9, loss: 2.1038270569057715


 11%|█         | 50/454 [00:38<05:33,  1.21it/s]

loss: 2.3345351219177246


 22%|██▏       | 100/454 [01:15<04:53,  1.21it/s]

loss: 2.1061770915985107


 33%|███▎      | 150/454 [01:53<04:11,  1.21it/s]

loss: 1.7414138317108154


 44%|████▍     | 200/454 [02:31<03:30,  1.21it/s]

loss: 1.7231324911117554


 55%|█████▌    | 250/454 [03:09<02:49,  1.21it/s]

loss: 1.9701906442642212


 66%|██████▌   | 300/454 [03:47<02:07,  1.21it/s]

loss: 1.7387306690216064


 77%|███████▋  | 350/454 [04:25<01:26,  1.20it/s]

loss: 1.6527422666549683


 88%|████████▊ | 400/454 [05:03<00:44,  1.21it/s]

loss: 2.6762852668762207


 99%|█████████▉| 450/454 [05:41<00:03,  1.21it/s]

loss: 1.8085353374481201


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]

epoch 10, loss: 2.039632056253072





In [48]:
for epoch in range(50):
  epoch_loss = train_fn(model, train_loader, optimizer, criterion)
  print(f"epoch {epoch+1}, loss: {epoch_loss}")

 11%|█         | 50/454 [00:38<05:43,  1.18it/s]

loss: 1.9428837299346924


 22%|██▏       | 100/454 [01:16<04:58,  1.19it/s]

loss: 1.923464298248291


 33%|███▎      | 150/454 [01:54<04:18,  1.18it/s]

loss: 1.7988742589950562


 44%|████▍     | 200/454 [02:32<03:32,  1.20it/s]

loss: 2.1514062881469727


 55%|█████▌    | 250/454 [03:09<02:50,  1.20it/s]

loss: 1.8631951808929443


 66%|██████▌   | 300/454 [03:47<02:06,  1.21it/s]

loss: 2.0166711807250977


 77%|███████▋  | 350/454 [04:26<01:28,  1.18it/s]

loss: 1.9556920528411865


 88%|████████▊ | 400/454 [05:03<00:45,  1.18it/s]

loss: 1.8254420757293701


 99%|█████████▉| 450/454 [05:41<00:03,  1.19it/s]

loss: 2.1927342414855957


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 1, loss: 1.99160682498621


 11%|█         | 50/454 [00:38<05:32,  1.22it/s]

loss: 1.734790325164795


 22%|██▏       | 100/454 [01:16<04:51,  1.22it/s]

loss: 1.7234331369400024


 33%|███▎      | 150/454 [01:54<04:10,  1.22it/s]

loss: 2.029336452484131


 44%|████▍     | 200/454 [02:32<03:27,  1.23it/s]

loss: 1.780863881111145


 55%|█████▌    | 250/454 [03:10<02:47,  1.22it/s]

loss: 2.135040283203125


 66%|██████▌   | 300/454 [03:48<02:06,  1.22it/s]

loss: 1.9013608694076538


 77%|███████▋  | 350/454 [04:26<01:25,  1.22it/s]

loss: 2.0501534938812256


 88%|████████▊ | 400/454 [05:04<00:44,  1.22it/s]

loss: 2.1085972785949707


 99%|█████████▉| 450/454 [05:42<00:03,  1.22it/s]

loss: 2.1624302864074707


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 2, loss: 1.9534944962825018


 11%|█         | 50/454 [00:38<05:29,  1.23it/s]

loss: 1.7323077917099


 22%|██▏       | 100/454 [01:16<04:47,  1.23it/s]

loss: 2.1780900955200195


 33%|███▎      | 150/454 [01:54<04:07,  1.23it/s]

loss: 1.7964822053909302


 44%|████▍     | 200/454 [02:32<03:26,  1.23it/s]

loss: 1.7324222326278687


 55%|█████▌    | 250/454 [03:09<02:45,  1.23it/s]

loss: 1.9255964756011963


 66%|██████▌   | 300/454 [03:47<02:05,  1.23it/s]

loss: 1.7672526836395264


 77%|███████▋  | 350/454 [04:26<01:25,  1.21it/s]

loss: 2.050554037094116


 88%|████████▊ | 400/454 [05:04<00:43,  1.23it/s]

loss: 2.069547414779663


 99%|█████████▉| 450/454 [05:42<00:03,  1.23it/s]

loss: 1.993813157081604


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 3, loss: 1.9187746985368266


 11%|█         | 50/454 [00:38<05:28,  1.23it/s]

loss: 2.4158177375793457


 22%|██▏       | 100/454 [01:16<04:47,  1.23it/s]

loss: 2.0530471801757812


 33%|███▎      | 150/454 [01:54<04:06,  1.23it/s]

loss: 1.8189023733139038


 44%|████▍     | 200/454 [02:32<03:29,  1.21it/s]

loss: 2.089662790298462


 55%|█████▌    | 250/454 [03:10<02:45,  1.23it/s]

loss: 1.6320594549179077


 66%|██████▌   | 300/454 [03:48<02:05,  1.23it/s]

loss: 1.9364184141159058


 77%|███████▋  | 350/454 [04:26<01:24,  1.23it/s]

loss: 1.8672056198120117


 88%|████████▊ | 400/454 [05:04<00:43,  1.24it/s]

loss: 1.8595952987670898


 99%|█████████▉| 450/454 [05:42<00:03,  1.22it/s]

loss: 1.7383493185043335


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 4, loss: 1.901192499152364


 11%|█         | 50/454 [00:38<05:39,  1.19it/s]

loss: 1.7495311498641968


 22%|██▏       | 100/454 [01:16<04:55,  1.20it/s]

loss: 1.7258390188217163


 33%|███▎      | 150/454 [01:54<04:15,  1.19it/s]

loss: 1.9545162916183472


 44%|████▍     | 200/454 [02:32<03:32,  1.19it/s]

loss: 1.6322277784347534


 55%|█████▌    | 250/454 [03:10<02:54,  1.17it/s]

loss: 1.7149498462677002


 66%|██████▌   | 300/454 [03:48<02:08,  1.20it/s]

loss: 1.8351976871490479


 77%|███████▋  | 350/454 [04:26<01:27,  1.19it/s]

loss: 1.8145312070846558


 88%|████████▊ | 400/454 [05:04<00:45,  1.18it/s]

loss: 1.8183585405349731


 99%|█████████▉| 450/454 [05:42<00:03,  1.18it/s]

loss: 1.8666586875915527


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 5, loss: 1.8641643637077399


 11%|█         | 50/454 [00:38<05:32,  1.22it/s]

loss: 1.8200966119766235


 22%|██▏       | 100/454 [01:15<04:51,  1.21it/s]

loss: 2.5167882442474365


 33%|███▎      | 150/454 [01:53<04:09,  1.22it/s]

loss: 1.9314669370651245


 44%|████▍     | 200/454 [02:31<03:29,  1.21it/s]

loss: 1.6707782745361328


 55%|█████▌    | 250/454 [03:09<02:47,  1.22it/s]

loss: 1.9785746335983276


 66%|██████▌   | 300/454 [03:47<02:06,  1.22it/s]

loss: 1.7693952322006226


 77%|███████▋  | 350/454 [04:25<01:25,  1.22it/s]

loss: 2.3996760845184326


 88%|████████▊ | 400/454 [05:03<00:44,  1.22it/s]

loss: 1.9416013956069946


 99%|█████████▉| 450/454 [05:41<00:03,  1.21it/s]

loss: 1.807715654373169


100%|██████████| 454/454 [05:44<00:00,  1.32it/s]


epoch 6, loss: 1.8368519765164883


 11%|█         | 50/454 [00:38<05:29,  1.23it/s]

loss: 1.7711405754089355


 22%|██▏       | 100/454 [01:16<04:48,  1.23it/s]

loss: 1.6040189266204834


 33%|███▎      | 150/454 [01:54<04:08,  1.22it/s]

loss: 1.7372856140136719


 44%|████▍     | 200/454 [02:32<03:27,  1.23it/s]

loss: 1.9531680345535278


 55%|█████▌    | 250/454 [03:10<02:46,  1.22it/s]

loss: 1.6552166938781738


 66%|██████▌   | 300/454 [03:48<02:05,  1.22it/s]

loss: 1.7213796377182007


 77%|███████▋  | 350/454 [04:26<01:25,  1.22it/s]

loss: 1.6989742517471313


 88%|████████▊ | 400/454 [05:04<00:44,  1.23it/s]

loss: 1.837982177734375


 96%|█████████▋| 438/454 [05:33<00:12,  1.31it/s]


KeyboardInterrupt: 