In [290]:
import torch
import torch.nn as nn
import numpy as np
import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

#### Hyper-params

In [234]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [235]:
device

device(type='cuda', index=0)

#### Text Preprocessing

In [236]:
# read the text file
text = open("austen-emma.txt", "r").read()

In [238]:
## tokenization
chars = tuple(set(text))
idx2char = dict(enumerate(chars))
char2idx = dict(zip(idx2char.values(), idx2char.keys()))
#  encode the text into tokens
encoded = np.array([char2idx[ch] for ch in text])

In [300]:
## Convert the tokens into one-hot encoding
def one_hot_encoding(data, n_labels):
    
#     print(encoding.shape)
    # zeros matrix of size mx/n  where m --> size of encoding and n--> size of 
    one_hot = np.zeros((data.shape[0], n_labels), dtype=np.float32)
    
    one_hot[np.arange(one_hot.shape[0]), data.flatten()] = 1.

    return one_hot
#     # file the token with 1s
#     for i, token in enumerate(encoding):
#         one_hot[i][token] = 1
#     return one_hot
    

In [301]:
one_hot_encoding(np.array([1, 2, 3]), 9).shape

(3, 9)

In [321]:
def load_data(encoded, seq_length = 100):
    
    # keep only valid charas
    sequences = len(encoded)//(seq_len+1)
    encoded = encoded[:sequences*(seq_len+1)]
    
    x = np.empty((sequences, seq_len, len(chars)), dtype=np.float32)
    y = np.empty((sequences, seq_len), dtype=np.float32)
    
    print(x.shape, y.shape)
    
    for i, current in enumerate(range(0, encoded.shape[0], seq_len + 1)):
        x[i] = one_hot_encoding(np.array(encoded[current:current+seq_len]), n_labels=len(chars))
        y[i] = np.array(encoded[current+1:current+seq_len+1])
    return x, y
    

In [322]:
x, y = load_data(encoded)

(13647, 64, 77) (13647, 64)


In [323]:
print(x.shape, y.shape)
print(x.dtype, y.dtype)

(13647, 64, 77) (13647, 64)
float32 float32


In [324]:
class CustomDataset(Dataset):
    
    def __init__(self):
        self.x, self.y = load_data(encoded, seq_len)
        self.len = x.shape[0]
    
    def __len__(self):
        return self.len
    
    def __getitem__(self, index):
        return self.x[index], self.y[index]
    

In [325]:
# data 
seq_len = 64
batch_size = 64

In [326]:
dataset = CustomDataset()
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

(13647, 64, 77) (13647, 64)


In [327]:
# see the text of each seq 
x, y = next(iter(dataloader))
s = ""
for each in x[0]:
    s = s +""+ idx2char[torch.argmax(each).item()]
s

'od opinion.--He had been sitting with them\nhalf an hour, she fou'

### RNN Model

In [388]:
class RNN(nn.Module):
    
    def __init__(self, input_size, output_size, hidden_dim=64, n_layers=1,dropout=0.20):
        
        super(RNN, self).__init__()
        self.n_hidden = hidden_dim
        self.lstm = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_size)
        
    def forward(self, x, hidden):
        
        lstm_out, hidden = self.lstm(x, hidden)
        out = self.dropout(lstm_out)
        out = out.contiguous().view(-1, self.n_hidden)
        out = self.fc(out)
        return out, hidden    

#### Params configuration model and training

In [401]:
# model
input_size = len(chars)
hidden_dim = 64
n_layers = 1
output_size = len(chars)

epochs = 50
lr = 1e-2

writer = SummaryWriter(log_dir="./runs")
# training

In [402]:
# create the model
model = RNN(input_size=input_size, output_size=output_size).to(device)

In [403]:
### Optimizer and Loss
optmizer = torch.optim.Adam(params=model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

In [404]:
### Configurations

epoch_progress = tqdm.tqdm(total=epochs, desc="Epoch", position=0)

steps = 0

for epoch in range(epochs):
    
    batch_progress = tqdm.tqdm(total=len(dataloader), desc="Batch", position=0)
    epoch_loss = []

    
    for i, batch in enumerate(dataloader):
        x, y = batch
        
        x = x.to(device)
        y = y.to(device)
        
        
        outputs, _ = model(x, None)
        
#         print(x.size(), y.size(), outputs.size(), y.view(batch_size, seq_len).size(0))
        
        loss = criterion(outputs, y.view(-1).long())
        
        # backpropagate the loss update the params
        optmizer.zero_grad()
        loss.backward()
        optmizer.step()
        
        if steps%100==0:
            print(f"Epoch {epoch} | Steps {steps} | Loss = {loss.item()}")
        
        writer.add_scalar("Step Wise Loss", loss.item(), steps)
        epoch_loss.append(loss.item())
        
        batch_progress.update(1)
        steps += 1
    
    avg_loss = sum(epoch_loss)/len(epoch_loss)
    writer.add_scalar("Epoch Loss", avg_loss, epoch)
    if epoch%(epochs//10)==0:
        print(f"Epoch = {epoch} | Avg Loss = {avg_loss}")
    epoch_progress.update(1)


Batch:  12%|█▏        | 26/214 [00:00<00:01, 119.85it/s]

Epoch 0 | Steps 0 | Loss = 4.34761905670166


Batch:  55%|█████▌    | 118/214 [00:00<00:00, 134.54it/s]

Epoch 0 | Steps 100 | Loss = 2.3699333667755127


Batch: 100%|██████████| 214/214 [00:01<00:00, 147.09it/s]
Batch:   7%|▋         | 15/214 [00:00<00:01, 149.52it/s]

Epoch 0 | Steps 200 | Loss = 2.149277448654175
Epoch = 0 | Avg Loss = 2.5299650374974045


Batch:  50%|█████     | 108/214 [00:00<00:00, 151.06it/s]

Epoch 1 | Steps 300 | Loss = 2.0315728187561035


Batch:  98%|█████████▊| 210/214 [00:01<00:00, 145.54it/s]

Epoch 1 | Steps 400 | Loss = 1.9651515483856201


Batch: 100%|██████████| 214/214 [00:01<00:00, 152.21it/s]
Batch:  44%|████▍     | 94/214 [00:00<00:00, 152.41it/s]

Epoch 2 | Steps 500 | Loss = 1.8696117401123047


Batch:  96%|█████████▌| 205/214 [00:01<00:00, 155.08it/s]

Epoch 2 | Steps 600 | Loss = 1.8664608001708984


Batch: 100%|██████████| 214/214 [00:01<00:00, 152.30it/s]
Batch:  38%|███▊      | 82/214 [00:00<00:00, 161.44it/s]

Epoch 3 | Steps 700 | Loss = 1.7709965705871582


Batch:  84%|████████▎ | 179/214 [00:01<00:00, 162.27it/s]

Epoch 3 | Steps 800 | Loss = 1.8146792650222778


Batch: 100%|██████████| 214/214 [00:01<00:00, 157.91it/s]
Batch:  30%|███       | 65/214 [00:00<00:00, 162.09it/s]

Epoch 4 | Steps 900 | Loss = 1.8145389556884766


Batch:  77%|███████▋  | 165/214 [00:01<00:00, 165.60it/s]

Epoch 4 | Steps 1000 | Loss = 1.725896954536438


Batch: 100%|██████████| 214/214 [00:01<00:00, 161.50it/s]
Batch:  23%|██▎       | 50/214 [00:00<00:01, 162.99it/s]

Epoch 5 | Steps 1100 | Loss = 1.7471871376037598


Batch:  69%|██████▊   | 147/214 [00:00<00:00, 154.13it/s]

Epoch 5 | Steps 1200 | Loss = 1.706241488456726


Batch: 100%|██████████| 214/214 [00:01<00:00, 158.24it/s]
Batch:   3%|▎         | 6/214 [00:00<00:03, 52.20it/s]

Epoch = 5 | Avg Loss = 1.7079572215258518


Batch:  16%|█▌        | 34/214 [00:00<00:02, 67.74it/s]

Epoch 6 | Steps 1300 | Loss = 1.707268476486206


Batch:  58%|█████▊    | 124/214 [00:01<00:00, 123.35it/s]

Epoch 6 | Steps 1400 | Loss = 1.660194993019104


Batch:  14%|█▍        | 30/214 [00:00<00:01, 141.76it/s]]

Epoch 7 | Steps 1500 | Loss = 1.7051069736480713


Batch:  58%|█████▊    | 124/214 [00:00<00:00, 147.39it/s]

Epoch 7 | Steps 1600 | Loss = 1.6928067207336426


Batch: 100%|██████████| 214/214 [00:01<00:00, 149.36it/s]
Batch:   7%|▋         | 14/214 [00:00<00:01, 132.71it/s]

Epoch 7 | Steps 1700 | Loss = 1.6995114088058472


Batch:  50%|████▉     | 106/214 [00:00<00:00, 146.20it/s]

Epoch 8 | Steps 1800 | Loss = 1.6282737255096436


Batch: 100%|██████████| 214/214 [00:01<00:00, 148.12it/s]


Epoch 8 | Steps 1900 | Loss = 1.6271501779556274


Batch:  42%|████▏     | 90/214 [00:00<00:00, 145.69it/s]

Epoch 9 | Steps 2000 | Loss = 1.6421960592269897


Batch:  92%|█████████▏| 196/214 [00:01<00:00, 134.92it/s]

Epoch 9 | Steps 2100 | Loss = 1.6057405471801758


Batch: 100%|██████████| 214/214 [00:01<00:00, 131.63it/s]
Batch:  36%|███▌      | 77/214 [00:00<00:00, 146.89it/s]

Epoch 10 | Steps 2200 | Loss = 1.6051030158996582


Batch:  86%|████████▋ | 185/214 [00:01<00:00, 127.28it/s]

Epoch 10 | Steps 2300 | Loss = 1.5750596523284912


Batch: 100%|██████████| 214/214 [00:01<00:00, 125.88it/s]
Batch:   8%|▊         | 17/214 [00:00<00:01, 166.02it/s]

Epoch = 10 | Avg Loss = 1.5994071269703802


Batch:  32%|███▏      | 68/214 [00:00<00:01, 105.11it/s]

Epoch 11 | Steps 2400 | Loss = 1.5656623840332031


Batch:  80%|████████  | 172/214 [00:01<00:00, 160.21it/s]

Epoch 11 | Steps 2500 | Loss = 1.583791732788086


Batch: 100%|██████████| 214/214 [00:01<00:00, 145.47it/s]
Batch:  25%|██▍       | 53/214 [00:00<00:00, 170.88it/s]

Epoch 12 | Steps 2600 | Loss = 1.5623761415481567


Batch:  72%|███████▏  | 154/214 [00:00<00:00, 166.31it/s]

Epoch 12 | Steps 2700 | Loss = 1.557012677192688


Batch:  24%|██▍       | 51/214 [00:00<00:00, 167.39it/s]]

Epoch 13 | Steps 2800 | Loss = 1.584449291229248


Batch:  70%|██████▉   | 149/214 [00:00<00:00, 161.13it/s]

Epoch 13 | Steps 2900 | Loss = 1.571823239326477


Batch: 100%|██████████| 214/214 [00:01<00:00, 161.91it/s]
Batch:  15%|█▌        | 33/214 [00:00<00:01, 162.26it/s]

Epoch 14 | Steps 3000 | Loss = 1.560146450996399


Batch:  62%|██████▏   | 133/214 [00:00<00:00, 165.54it/s]

Epoch 14 | Steps 3100 | Loss = 1.56004798412323


Batch: 100%|██████████| 214/214 [00:01<00:00, 163.09it/s]
Batch:   8%|▊         | 17/214 [00:00<00:01, 165.77it/s]

Epoch 14 | Steps 3200 | Loss = 1.578438401222229


Batch:  52%|█████▏    | 112/214 [00:00<00:00, 160.15it/s]

Epoch 15 | Steps 3300 | Loss = 1.5770460367202759


Batch: 100%|██████████| 214/214 [00:01<00:00, 158.06it/s]

Epoch 15 | Steps 3400 | Loss = 1.5137784481048584
Epoch = 15 | Avg Loss = 1.552286943542623



Batch:  47%|████▋     | 100/214 [00:00<00:00, 162.24it/s]

Epoch 16 | Steps 3500 | Loss = 1.5033999681472778


Batch:  93%|█████████▎| 199/214 [00:01<00:00, 155.87it/s]

Epoch 16 | Steps 3600 | Loss = 1.5258383750915527


Batch: 100%|██████████| 214/214 [00:01<00:00, 158.67it/s]
Batch:  38%|███▊      | 81/214 [00:00<00:00, 155.82it/s]

Epoch 17 | Steps 3700 | Loss = 1.5249446630477905


Batch:  84%|████████▍ | 180/214 [00:01<00:00, 163.63it/s]

Epoch 17 | Steps 3800 | Loss = 1.4950004816055298


Batch: 100%|██████████| 214/214 [00:01<00:00, 158.99it/s]
Batch:  38%|███▊      | 82/214 [00:00<00:00, 161.32it/s]

Epoch 18 | Steps 3900 | Loss = 1.5067440271377563


Batch:  78%|███████▊  | 166/214 [00:01<00:00, 164.69it/s]

Epoch 18 | Steps 4000 | Loss = 1.5183757543563843


Batch: 100%|██████████| 214/214 [00:01<00:00, 156.16it/s]
Batch:  25%|██▍       | 53/214 [00:00<00:00, 171.70it/s]

Epoch 19 | Steps 4100 | Loss = 1.5259908437728882


Batch:  74%|███████▍  | 158/214 [00:00<00:00, 172.34it/s]

Epoch 19 | Steps 4200 | Loss = 1.541332483291626


Batch: 100%|██████████| 214/214 [00:01<00:00, 168.59it/s]
Batch:  23%|██▎       | 50/214 [00:00<00:01, 163.10it/s]]

Epoch 20 | Steps 4300 | Loss = 1.5271075963974


Batch:  70%|███████   | 150/214 [00:00<00:00, 163.82it/s]

Epoch 20 | Steps 4400 | Loss = 1.5529139041900635


Batch:   8%|▊         | 17/214 [00:00<00:01, 165.57it/s]]

Epoch = 20 | Avg Loss = 1.5278094539018434
Epoch 21 | Steps 4500 | Loss = 1.4745237827301025


Batch:  62%|██████▏   | 133/214 [00:00<00:00, 163.28it/s]

Epoch 21 | Steps 4600 | Loss = 1.554053544998169


Batch: 100%|██████████| 214/214 [00:01<00:00, 161.09it/s]
Batch:   8%|▊         | 17/214 [00:00<00:01, 163.01it/s]

Epoch 21 | Steps 4700 | Loss = 1.5248442888259888


Batch:  53%|█████▎    | 113/214 [00:00<00:00, 158.35it/s]

Epoch 22 | Steps 4800 | Loss = 1.5697635412216187


Batch: 100%|██████████| 214/214 [00:01<00:00, 155.81it/s]


Epoch 22 | Steps 4900 | Loss = 1.5106953382492065


Batch:  52%|█████▏    | 112/214 [00:00<00:00, 153.79it/s]

Epoch 23 | Steps 5000 | Loss = 1.511215329170227


Batch:  98%|█████████▊| 209/214 [00:01<00:00, 158.27it/s]

Epoch 23 | Steps 5100 | Loss = 1.5033361911773682


Batch: 100%|██████████| 214/214 [00:01<00:00, 155.02it/s]
Batch:  45%|████▍     | 96/214 [00:00<00:00, 150.08it/s]

Epoch 24 | Steps 5200 | Loss = 1.5170753002166748


Batch:  91%|█████████ | 195/214 [00:01<00:00, 158.33it/s]

Epoch 24 | Steps 5300 | Loss = 1.4870576858520508


Batch: 100%|██████████| 214/214 [00:01<00:00, 154.51it/s]
Batch:  38%|███▊      | 82/214 [00:00<00:00, 160.02it/s]

Epoch 25 | Steps 5400 | Loss = 1.4744057655334473


Batch:  85%|████████▍ | 181/214 [00:01<00:00, 159.55it/s]

Epoch 25 | Steps 5500 | Loss = 1.485007405281067


Batch: 100%|██████████| 214/214 [00:01<00:00, 146.62it/s]
Batch:   7%|▋         | 16/214 [00:00<00:01, 153.88it/s]

Epoch = 25 | Avg Loss = 1.5126137655472087


Batch:  26%|██▌       | 56/214 [00:00<00:01, 129.56it/s]

Epoch 26 | Steps 5600 | Loss = 1.5227636098861694


Batch:  70%|███████   | 150/214 [00:01<00:00, 104.69it/s]

Epoch 26 | Steps 5700 | Loss = 1.5560731887817383


Batch: 100%|██████████| 214/214 [00:01<00:00, 122.42it/s]
Batch:  24%|██▍       | 51/214 [00:00<00:01, 88.25it/s]s]

Epoch 27 | Steps 5800 | Loss = 1.469330072402954


Batch:  68%|██████▊   | 145/214 [00:01<00:00, 108.80it/s]

Epoch 27 | Steps 5900 | Loss = 1.5038747787475586


Batch:  15%|█▌        | 33/214 [00:00<00:01, 159.50it/s]]

Epoch 28 | Steps 6000 | Loss = 1.4731800556182861


Batch:  62%|██████▏   | 132/214 [00:00<00:00, 160.60it/s]

Epoch 28 | Steps 6100 | Loss = 1.5133757591247559


Batch: 100%|██████████| 214/214 [00:01<00:00, 159.18it/s]
Batch:   8%|▊         | 17/214 [00:00<00:01, 164.43it/s]

Epoch 28 | Steps 6200 | Loss = 1.5759392976760864


Batch:  54%|█████▍    | 116/214 [00:00<00:00, 125.93it/s]

Epoch 29 | Steps 6300 | Loss = 1.5084898471832275


Batch:  92%|█████████▏| 197/214 [00:01<00:00, 105.72it/s]

Epoch 29 | Steps 6400 | Loss = 1.4886040687561035


Batch: 100%|██████████| 214/214 [00:02<00:00, 100.80it/s]
Batch:  44%|████▍     | 95/214 [00:01<00:01, 59.98it/s]

Epoch 30 | Steps 6500 | Loss = 1.5046442747116089


Batch:  96%|█████████▌| 205/214 [00:02<00:00, 132.99it/s]

Epoch 30 | Steps 6600 | Loss = 1.478358507156372


Batch: 100%|██████████| 214/214 [00:02<00:00, 82.30it/s] 
Batch:   7%|▋         | 15/214 [00:00<00:01, 141.47it/s]

Epoch = 30 | Avg Loss = 1.5019053254172066


Batch:  44%|████▍     | 94/214 [00:00<00:01, 113.66it/s]

Epoch 31 | Steps 6700 | Loss = 1.4949283599853516


Batch:  82%|████████▏ | 175/214 [00:01<00:00, 87.21it/s] 

Epoch 31 | Steps 6800 | Loss = 1.5049445629119873


Batch: 100%|██████████| 214/214 [00:02<00:00, 105.42it/s]
Batch:  38%|███▊      | 82/214 [00:00<00:00, 153.92it/s]]

Epoch 32 | Steps 6900 | Loss = 1.490114688873291


Batch:  79%|███████▉  | 169/214 [00:01<00:00, 128.10it/s]

Epoch 32 | Steps 7000 | Loss = 1.533011555671692


Batch:  21%|██        | 44/214 [00:01<00:04, 42.46it/s]] 

Epoch 33 | Steps 7100 | Loss = 1.446194052696228


Batch:  79%|███████▉  | 169/214 [00:02<00:00, 119.96it/s]

Epoch 33 | Steps 7200 | Loss = 1.468514323234558


Batch: 100%|██████████| 214/214 [00:02<00:00, 81.08it/s] 
Batch:  18%|█▊        | 39/214 [00:00<00:01, 91.70it/s] 

Epoch 34 | Steps 7300 | Loss = 1.5031628608703613


Batch:  69%|██████▊   | 147/214 [00:01<00:00, 145.52it/s]

Epoch 34 | Steps 7400 | Loss = 1.4934322834014893


Batch: 100%|██████████| 214/214 [00:01<00:00, 136.70it/s]
Batch:  15%|█▌        | 33/214 [00:00<00:01, 159.33it/s]

Epoch 35 | Steps 7500 | Loss = 1.498509407043457


Batch:  59%|█████▉    | 127/214 [00:00<00:00, 150.35it/s]

Epoch 35 | Steps 7600 | Loss = 1.5073248147964478


Batch: 100%|██████████| 214/214 [00:01<00:00, 154.85it/s]
Batch:   8%|▊         | 17/214 [00:00<00:01, 165.34it/s]

Epoch 35 | Steps 7700 | Loss = 1.4655438661575317
Epoch = 35 | Avg Loss = 1.4946674136357887


Batch:  54%|█████▎    | 115/214 [00:00<00:00, 163.59it/s]

Epoch 36 | Steps 7800 | Loss = 1.4345686435699463


Batch: 100%|██████████| 214/214 [00:01<00:00, 146.90it/s]


Epoch 36 | Steps 7900 | Loss = 1.5165880918502808


Batch:  42%|████▏     | 90/214 [00:00<00:00, 136.92it/s]

Epoch 37 | Steps 8000 | Loss = 1.4265434741973877


Batch:  89%|████████▉ | 190/214 [00:02<00:00, 44.81it/s]]

Epoch 37 | Steps 8100 | Loss = 1.521941900253296


Batch:  40%|███▉      | 85/214 [00:01<00:01, 68.29it/s]]

Epoch 38 | Steps 8200 | Loss = 1.4691271781921387


Batch:  87%|████████▋ | 186/214 [00:02<00:00, 142.10it/s]

Epoch 38 | Steps 8300 | Loss = 1.5251168012619019


Batch: 100%|██████████| 214/214 [00:02<00:00, 95.15it/s] 
Batch:  37%|███▋      | 79/214 [00:00<00:00, 147.30it/s]

Epoch 39 | Steps 8400 | Loss = 1.463073492050171


Batch:  87%|████████▋ | 186/214 [00:01<00:00, 168.71it/s]

Epoch 39 | Steps 8500 | Loss = 1.506534457206726


Batch: 100%|██████████| 214/214 [00:01<00:00, 162.12it/s]
Batch:  29%|██▊       | 61/214 [00:00<00:01, 82.97it/s]

Epoch 40 | Steps 8600 | Loss = 1.54047691822052


Batch:  72%|███████▏  | 154/214 [00:01<00:00, 126.14it/s]

Epoch 40 | Steps 8700 | Loss = 1.5041189193725586


Batch: 100%|██████████| 214/214 [00:01<00:00, 115.52it/s]
Batch:   8%|▊         | 17/214 [00:00<00:01, 160.90it/s]

Epoch = 40 | Avg Loss = 1.4887615419993891
Epoch 41 | Steps 8800 | Loss = 1.5120134353637695


Batch:  67%|██████▋   | 144/214 [00:00<00:00, 153.66it/s]

Epoch 41 | Steps 8900 | Loss = 1.4357918500900269


Batch: 100%|██████████| 214/214 [00:01<00:00, 150.03it/s]
Batch:  16%|█▌        | 34/214 [00:00<00:01, 165.81it/s]

Epoch 42 | Steps 9000 | Loss = 1.5419437885284424


Batch:  63%|██████▎   | 134/214 [00:00<00:00, 148.57it/s]

Epoch 42 | Steps 9100 | Loss = 1.4741053581237793


Batch: 100%|██████████| 214/214 [00:01<00:00, 161.61it/s]
Batch:   8%|▊         | 17/214 [00:00<00:01, 165.77it/s]

Epoch 42 | Steps 9200 | Loss = 1.5080243349075317


Batch:  55%|█████▌    | 118/214 [00:01<00:00, 96.89it/s]

Epoch 43 | Steps 9300 | Loss = 1.5050454139709473


Batch:   8%|▊         | 17/214 [00:00<00:01, 166.17it/s] 

Epoch 43 | Steps 9400 | Loss = 1.4560378789901733


Batch:  43%|████▎     | 93/214 [00:01<00:02, 55.52it/s] 

Epoch 44 | Steps 9500 | Loss = 1.509315013885498


Batch: 100%|██████████| 214/214 [00:02<00:00, 78.76it/s]


Epoch 44 | Steps 9600 | Loss = 1.5296515226364136


Batch:  48%|████▊     | 102/214 [00:00<00:00, 166.08it/s]

Epoch 45 | Steps 9700 | Loss = 1.4834978580474854


Batch:  86%|████████▋ | 185/214 [00:01<00:00, 157.31it/s]

Epoch 45 | Steps 9800 | Loss = 1.4893994331359863


Batch: 100%|██████████| 214/214 [00:01<00:00, 163.53it/s]
Batch:   8%|▊         | 17/214 [00:00<00:01, 165.71it/s]

Epoch = 45 | Avg Loss = 1.4850217883831986


Batch:  40%|███▉      | 85/214 [00:00<00:00, 164.77it/s]

Epoch 46 | Steps 9900 | Loss = 1.505232572555542


Batch:  83%|████████▎ | 178/214 [00:01<00:00, 145.52it/s]

Epoch 46 | Steps 10000 | Loss = 1.5124047994613647


Batch: 100%|██████████| 214/214 [00:01<00:00, 143.16it/s]
Batch:  32%|███▏      | 68/214 [00:00<00:00, 165.25it/s]

Epoch 47 | Steps 10100 | Loss = 1.4954311847686768


Batch:  79%|███████▉  | 170/214 [00:01<00:00, 166.45it/s]

Epoch 47 | Steps 10200 | Loss = 1.48506498336792


Batch: 100%|██████████| 214/214 [00:01<00:00, 165.88it/s]
Batch:  22%|██▏       | 47/214 [00:00<00:01, 104.79it/s]

Epoch 48 | Steps 10300 | Loss = 1.460173487663269


Batch:  70%|██████▉   | 149/214 [00:01<00:00, 154.67it/s]

Epoch 48 | Steps 10400 | Loss = 1.5039653778076172


Batch: 100%|██████████| 214/214 [00:01<00:00, 150.42it/s]
Batch:  13%|█▎        | 27/214 [00:00<00:01, 129.78it/s]

Epoch 49 | Steps 10500 | Loss = 1.4996896982192993


Batch:  67%|██████▋   | 143/214 [00:01<00:00, 149.40it/s]

Epoch 49 | Steps 10600 | Loss = 1.427970290184021


Batch: 100%|██████████| 214/214 [00:20<00:00, 160.55it/s]

#### Save the model checkpoint

In [None]:
# save the model if you want