# Text Generation with LSTM

In [3]:
import os
import sys

import numpy as np
import torch
import torch.nn as nn

In [4]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Load the corpus as UTF-8 text and lower-case it, so that upper- and
# lower-case letters share a single vocabulary entry.
# The context manager closes the file handle as soon as the text is read.
filename = "data/wonderland.txt"
with open(filename, 'r', encoding='utf-8') as text_file:
    raw_text = text_file.read()
raw_text = raw_text.lower()

In [6]:
# Summarise the dataset and build the character vocabulary.
# `chars` holds every distinct character of the corpus exactly once, sorted.
chars = sorted(set(raw_text))
# Lookup tables: character -> integer index, and integer index -> character.
chars_to_int = {char: index for index, char in enumerate(chars)}
int_to_chars = {index: char for index, char in enumerate(chars)}
n_chars, n_vocab = len(raw_text), len(chars)
print("Total Characters: ", n_chars)
print("Total Vocab: ", n_vocab)

Total Characters:  144512
Total Vocab:  45


In [7]:
# Build the training pairs: each window of `seq_length` consecutive characters
# is one input, and the character right after the window is its target.
seq_length = 50
# Encode the whole corpus once; windows and targets are then plain slices/lookups.
encoded_text = [chars_to_int[char] for char in raw_text]
window_starts = range(n_chars - seq_length)
dataX = [encoded_text[start:start + seq_length] for start in window_starts]
dataY = [encoded_text[start + seq_length] for start in window_starts]

n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

Total Patterns:  144462


In [8]:
# Arrange the inputs as [samples, time steps, features] for the LSTM and scale
# the integer codes by the vocabulary size; the targets stay integer codes.
X = np.asarray(dataX).reshape(n_patterns, seq_length, 1) / float(n_vocab)
y = dataY

In [9]:
class TextModel(nn.Module):
    """Next-character classifier: a single-layer LSTM followed by an MLP head.

    Args:
        vocab_size: Number of output logits (one per character class). When
            omitted, the notebook-level ``n_vocab`` is looked up at
            construction time, which keeps ``TextModel()`` working as before.
        hidden_size: Width of the LSTM hidden state and of the hidden linear
            layer of the classifier.
        dropout: Drop probability of the dropout layer inside the classifier.
    """

    def __init__(self, vocab_size=None, hidden_size=256, dropout=0.2):
        super().__init__()

        if vocab_size is None:
            # Backward-compatible fallback to the global defined by the notebook.
            vocab_size = n_vocab

        # One scalar feature per time step; inputs are (batch, time, 1).
        self.lstm = nn.LSTM(1, hidden_size, batch_first=True)
        # Layer order is kept unchanged so the state_dict keys
        # (classifier.0 / classifier.3) of saved checkpoints still match.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, vocab_size),
        )

    def forward(self, x):
        """Return the classifier output computed from the final LSTM hidden state.

        Args:
            x: Input passed straight to the LSTM, which is built with
                ``batch_first=True`` and an input size of 1.

        Returns:
            The output of the classifier head applied to the last layer's
            final hidden state (one row of ``vocab_size`` values per sequence).
        """
        # h is the final hidden state, laid out as (num_layers, batch, hidden);
        # the per-step outputs and the cell state are not needed here.
        _, (h, _) = self.lstm(x)
        # Take the last (here: only) layer's state; this equals h.squeeze(0)
        # for the single-layer LSTM used above.
        return self.classifier(h[-1])

In [10]:
# Dummy batch (2 sequences x 50 steps x 1 feature), moved to the selected device.
x = torch.randn(2, 50, 1)
x = x.to(device)
# Instantiate the network and place its parameters on the same device.
net = TextModel()
net = net.to(device)


In [11]:
from tqdm import tqdm
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

In [12]:
class TextDataset(Dataset):
    """Map-style dataset pairing each stored input sequence with its next character.

    Args:
        data: Indexable collection of input sequences.
        next_chars: Indexable collection of targets, aligned with ``data``.
    """

    def __init__(self, data, next_chars):
        super().__init__()
        self.data, self.next_chars = data, next_chars

    def __len__(self):
        """Number of stored input sequences."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sequence as a float32 tensor, stored target)`` for ``index``."""
        features = torch.tensor(self.data[index], dtype=torch.float32)
        target = self.next_chars[index]
        return features, target

In [13]:
# Wrap the prepared arrays in a dataset and serve them in shuffled
# mini-batches of 32, loaded in the main process (no worker processes).
text_dataset = TextDataset(data=X, next_chars=y)
loader_options = dict(batch_size=32, shuffle=True, num_workers=0)
text_loader = DataLoader(text_dataset, **loader_options)


In [14]:
# Cross-entropy loss on the model outputs, optimised with Adam.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)

In [15]:
from tqdm import tqdm


# Train the network and checkpoint it whenever the mean epoch loss improves.
num_epochs = 500
checkpoint_path = "weights/best_char_gen.pth"
# torch.save does not create missing parent directories; create the folder up
# front so a missing directory cannot abort the run after a full epoch.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

best_loss = float("inf")
for epoch in range(num_epochs):
    # Make sure dropout is active even if an earlier cell switched the model
    # to eval mode (e.g. when this cell is re-run after generating text).
    net.train()

    train_tqdm = tqdm(enumerate(text_loader), total=len(text_loader))
    total_loss = 0
    total_correct = 0
    total_samples = 0
    for i, data in train_tqdm:
        # Separate inputs and targets
        inputs, labels = data

        # move data to device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # feed forward
        outputs = net(inputs)

        # loss calculation
        loss = loss_fn(outputs, labels)

        # reset gradient
        optimizer.zero_grad()
        # calculate gradient
        loss.backward()

        # update weight
        optimizer.step()

        # accumulate the batch loss; the progress bar shows the running mean
        total_loss += loss.item()

        # predicted class = index of the largest output per sample
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)
        # update progress and show loss
        train_tqdm.set_description(f"Epoch {epoch}: Total loss: {total_loss/(i + 1)}, Accuracy: {total_correct / total_samples:.4f}")

    # mean loss over all batches of this epoch
    train_loss = total_loss / len(text_loader)

    if train_loss <= best_loss:
        print(f"Save best model with loss = {train_loss}")
        best_loss = train_loss
        torch.save(net.state_dict(), checkpoint_path)

Epoch 0: Total loss: 2.766194622381977, Accuracy: 0.2330: 100%|██████████| 4515/4515 [00:23<00:00, 191.69it/s] 


Save best model with loss = 2.766194622381977


Epoch 1: Total loss: 2.4867214611334396, Accuracy: 0.2929: 100%|██████████| 4515/4515 [00:23<00:00, 192.59it/s]


Save best model with loss = 2.4867214611334396


Epoch 2: Total loss: 2.3324534436529, Accuracy: 0.3318: 100%|██████████| 4515/4515 [00:22<00:00, 196.59it/s]   


Save best model with loss = 2.3324534436529


Epoch 3: Total loss: 2.2028206046633545, Accuracy: 0.3678: 100%|██████████| 4515/4515 [00:22<00:00, 198.54it/s]


Save best model with loss = 2.2028206046633545


Epoch 4: Total loss: 2.104635653464105, Accuracy: 0.3932: 100%|██████████| 4515/4515 [00:22<00:00, 199.83it/s] 


Save best model with loss = 2.104635653464105


Epoch 5: Total loss: 2.0233613750863313, Accuracy: 0.4136: 100%|██████████| 4515/4515 [00:22<00:00, 197.89it/s]


Save best model with loss = 2.0233613750863313


Epoch 6: Total loss: 1.9529382000465858, Accuracy: 0.4300: 100%|██████████| 4515/4515 [00:23<00:00, 196.30it/s]


Save best model with loss = 1.9529382000465858


Epoch 7: Total loss: 1.891975955947293, Accuracy: 0.4467: 100%|██████████| 4515/4515 [00:22<00:00, 197.61it/s] 


Save best model with loss = 1.891975955947293


Epoch 8: Total loss: 1.8406418929200368, Accuracy: 0.4593: 100%|██████████| 4515/4515 [00:22<00:00, 199.12it/s]


Save best model with loss = 1.8406418929200368


Epoch 9: Total loss: 1.7871006035593524, Accuracy: 0.4730: 100%|██████████| 4515/4515 [00:23<00:00, 196.15it/s]


Save best model with loss = 1.7871006035593524


Epoch 10: Total loss: 1.7443735960710087, Accuracy: 0.4818: 100%|██████████| 4515/4515 [00:22<00:00, 199.60it/s]


Save best model with loss = 1.7443735960710087


Epoch 11: Total loss: 1.7022929077396625, Accuracy: 0.4936: 100%|██████████| 4515/4515 [00:23<00:00, 190.39it/s]


Save best model with loss = 1.7022929077396625


Epoch 12: Total loss: 1.659168501747803, Accuracy: 0.5036: 100%|██████████| 4515/4515 [00:24<00:00, 180.93it/s] 


Save best model with loss = 1.659168501747803


Epoch 13: Total loss: 1.625417039811545, Accuracy: 0.5097: 100%|██████████| 4515/4515 [00:24<00:00, 184.77it/s] 


Save best model with loss = 1.625417039811545


Epoch 14: Total loss: 1.5892631211550132, Accuracy: 0.5204: 100%|██████████| 4515/4515 [00:23<00:00, 190.45it/s]


Save best model with loss = 1.5892631211550132


Epoch 15: Total loss: 1.5573169736899146, Accuracy: 0.5271: 100%|██████████| 4515/4515 [00:23<00:00, 195.40it/s]


Save best model with loss = 1.5573169736899146


Epoch 16: Total loss: 1.5273986336664769, Accuracy: 0.5339: 100%|██████████| 4515/4515 [00:22<00:00, 196.82it/s]


Save best model with loss = 1.5273986336664769


Epoch 17: Total loss: 1.4953168171584805, Accuracy: 0.5414: 100%|██████████| 4515/4515 [00:22<00:00, 205.03it/s]


Save best model with loss = 1.4953168171584805


Epoch 18: Total loss: 1.4698564516349488, Accuracy: 0.5485: 100%|██████████| 4515/4515 [00:22<00:00, 200.69it/s]


Save best model with loss = 1.4698564516349488


Epoch 19: Total loss: 1.4458377007498167, Accuracy: 0.5534: 100%|██████████| 4515/4515 [00:22<00:00, 204.72it/s]


Save best model with loss = 1.4458377007498167


Epoch 20: Total loss: 1.419925053491941, Accuracy: 0.5606: 100%|██████████| 4515/4515 [00:22<00:00, 202.70it/s] 


Save best model with loss = 1.419925053491941


Epoch 21: Total loss: 1.398134809831448, Accuracy: 0.5661: 100%|██████████| 4515/4515 [00:21<00:00, 207.67it/s] 


Save best model with loss = 1.398134809831448


Epoch 22: Total loss: 1.3768215602154483, Accuracy: 0.5709: 100%|██████████| 4515/4515 [00:21<00:00, 205.56it/s]


Save best model with loss = 1.3768215602154483


Epoch 23: Total loss: 1.3557753303782356, Accuracy: 0.5759: 100%|██████████| 4515/4515 [00:21<00:00, 205.78it/s]


Save best model with loss = 1.3557753303782356


Epoch 24: Total loss: 1.3358347409191322, Accuracy: 0.5817: 100%|██████████| 4515/4515 [00:22<00:00, 203.51it/s]


Save best model with loss = 1.3358347409191322


Epoch 25: Total loss: 1.3140025776882107, Accuracy: 0.5863: 100%|██████████| 4515/4515 [00:22<00:00, 202.02it/s]


Save best model with loss = 1.3140025776882107


Epoch 26: Total loss: 1.2988812082233618, Accuracy: 0.5908: 100%|██████████| 4515/4515 [00:22<00:00, 202.05it/s]


Save best model with loss = 1.2988812082233618


Epoch 27: Total loss: 1.2810793632288708, Accuracy: 0.5948: 100%|██████████| 4515/4515 [00:22<00:00, 203.90it/s]


Save best model with loss = 1.2810793632288708


Epoch 28: Total loss: 1.2650211887766225, Accuracy: 0.5981: 100%|██████████| 4515/4515 [00:22<00:00, 199.61it/s]


Save best model with loss = 1.2650211887766225


Epoch 29: Total loss: 1.2463971077009688, Accuracy: 0.6034: 100%|██████████| 4515/4515 [00:21<00:00, 205.91it/s]


Save best model with loss = 1.2463971077009688


Epoch 30: Total loss: 1.2351007993831191, Accuracy: 0.6054: 100%|██████████| 4515/4515 [00:22<00:00, 202.69it/s]


Save best model with loss = 1.2351007993831191


Epoch 31: Total loss: 1.2206786671086138, Accuracy: 0.6085: 100%|██████████| 4515/4515 [00:22<00:00, 204.17it/s]


Save best model with loss = 1.2206786671086138


Epoch 32: Total loss: 1.206980159010734, Accuracy: 0.6137: 100%|██████████| 4515/4515 [00:22<00:00, 202.06it/s] 


Save best model with loss = 1.206980159010734


Epoch 33: Total loss: 1.194997566872393, Accuracy: 0.6154: 100%|██████████| 4515/4515 [00:22<00:00, 203.01it/s] 


Save best model with loss = 1.194997566872393


Epoch 34: Total loss: 1.1801089018484288, Accuracy: 0.6195: 100%|██████████| 4515/4515 [00:22<00:00, 202.76it/s]


Save best model with loss = 1.1801089018484288


Epoch 35: Total loss: 1.174023895973911, Accuracy: 0.6227: 100%|██████████| 4515/4515 [00:22<00:00, 203.30it/s] 


Save best model with loss = 1.174023895973911


Epoch 36: Total loss: 1.1543913258204561, Accuracy: 0.6282: 100%|██████████| 4515/4515 [00:21<00:00, 206.33it/s]


Save best model with loss = 1.1543913258204561


Epoch 37: Total loss: 1.1439434367622912, Accuracy: 0.6291: 100%|██████████| 4515/4515 [00:22<00:00, 202.80it/s]


Save best model with loss = 1.1439434367622912


Epoch 38: Total loss: 1.1354701010161725, Accuracy: 0.6317: 100%|██████████| 4515/4515 [00:22<00:00, 200.82it/s]


Save best model with loss = 1.1354701010161725


Epoch 39: Total loss: 1.1235142288479958, Accuracy: 0.6354: 100%|██████████| 4515/4515 [00:22<00:00, 204.83it/s]


Save best model with loss = 1.1235142288479958


Epoch 40: Total loss: 1.1137048363619602, Accuracy: 0.6377: 100%|██████████| 4515/4515 [00:22<00:00, 201.52it/s]


Save best model with loss = 1.1137048363619602


Epoch 41: Total loss: 1.1017117275418105, Accuracy: 0.6409: 100%|██████████| 4515/4515 [00:22<00:00, 204.65it/s]


Save best model with loss = 1.1017117275418105


Epoch 42: Total loss: 1.0934137098812977, Accuracy: 0.6421: 100%|██████████| 4515/4515 [00:22<00:00, 201.40it/s]


Save best model with loss = 1.0934137098812977


Epoch 43: Total loss: 1.0849622462741033, Accuracy: 0.6456: 100%|██████████| 4515/4515 [00:22<00:00, 204.34it/s]


Save best model with loss = 1.0849622462741033


Epoch 44: Total loss: 1.0756037350178822, Accuracy: 0.6486: 100%|██████████| 4515/4515 [00:22<00:00, 203.63it/s]


Save best model with loss = 1.0756037350178822


Epoch 45: Total loss: 1.0678967271598339, Accuracy: 0.6498: 100%|██████████| 4515/4515 [00:22<00:00, 202.78it/s]


Save best model with loss = 1.0678967271598339


Epoch 46: Total loss: 1.055363470447948, Accuracy: 0.6544: 100%|██████████| 4515/4515 [00:22<00:00, 203.30it/s] 


Save best model with loss = 1.055363470447948


Epoch 47: Total loss: 1.0477096158074646, Accuracy: 0.6566: 100%|██████████| 4515/4515 [00:22<00:00, 201.47it/s]


Save best model with loss = 1.0477096158074646


Epoch 48: Total loss: 1.0392741319217027, Accuracy: 0.6584: 100%|██████████| 4515/4515 [00:22<00:00, 204.49it/s]


Save best model with loss = 1.0392741319217027


Epoch 49: Total loss: 1.0374860538926765, Accuracy: 0.6580: 100%|██████████| 4515/4515 [00:22<00:00, 204.09it/s]


Save best model with loss = 1.0374860538926765


Epoch 50: Total loss: 1.0280701469609375, Accuracy: 0.6606: 100%|██████████| 4515/4515 [00:22<00:00, 201.20it/s]


Save best model with loss = 1.0280701469609375


Epoch 51: Total loss: 1.0191588761204502, Accuracy: 0.6644: 100%|██████████| 4515/4515 [00:22<00:00, 203.60it/s]


Save best model with loss = 1.0191588761204502


Epoch 52: Total loss: 1.0117187117091313, Accuracy: 0.6669: 100%|██████████| 4515/4515 [00:21<00:00, 205.54it/s]


Save best model with loss = 1.0117187117091313


Epoch 53: Total loss: 1.0065667409107402, Accuracy: 0.6674: 100%|██████████| 4515/4515 [00:22<00:00, 203.41it/s]


Save best model with loss = 1.0065667409107402


Epoch 54: Total loss: 0.9966794754249577, Accuracy: 0.6700: 100%|██████████| 4515/4515 [00:22<00:00, 203.23it/s]


Save best model with loss = 0.9966794754249577


Epoch 55: Total loss: 0.9951214239578838, Accuracy: 0.6715: 100%|██████████| 4515/4515 [00:22<00:00, 202.62it/s]


Save best model with loss = 0.9951214239578838


Epoch 56: Total loss: 0.9829793885366729, Accuracy: 0.6735: 100%|██████████| 4515/4515 [00:22<00:00, 203.88it/s]


Save best model with loss = 0.9829793885366729


Epoch 57: Total loss: 0.9767223277295283, Accuracy: 0.6751: 100%|██████████| 4515/4515 [00:23<00:00, 195.56it/s]


Save best model with loss = 0.9767223277295283


Epoch 58: Total loss: 0.9756987229006103, Accuracy: 0.6753: 100%|██████████| 4515/4515 [00:22<00:00, 203.06it/s]


Save best model with loss = 0.9756987229006103


Epoch 59: Total loss: 0.9674692335978959, Accuracy: 0.6788: 100%|██████████| 4515/4515 [00:22<00:00, 201.37it/s]


Save best model with loss = 0.9674692335978959


Epoch 60: Total loss: 0.9590281005053557, Accuracy: 0.6806: 100%|██████████| 4515/4515 [00:22<00:00, 201.72it/s]


Save best model with loss = 0.9590281005053557


Epoch 61: Total loss: 0.9585126144413404, Accuracy: 0.6807: 100%|██████████| 4515/4515 [00:22<00:00, 202.79it/s]


Save best model with loss = 0.9585126144413404


Epoch 62: Total loss: 0.9499756266376372, Accuracy: 0.6835: 100%|██████████| 4515/4515 [00:22<00:00, 203.34it/s]


Save best model with loss = 0.9499756266376372


Epoch 63: Total loss: 0.9457246348113316, Accuracy: 0.6835: 100%|██████████| 4515/4515 [00:22<00:00, 202.43it/s]


Save best model with loss = 0.9457246348113316


Epoch 64: Total loss: 0.9407827825509301, Accuracy: 0.6860: 100%|██████████| 4515/4515 [00:22<00:00, 203.18it/s]


Save best model with loss = 0.9407827825509301


Epoch 65: Total loss: 0.9307458250031518, Accuracy: 0.6894: 100%|██████████| 4515/4515 [00:22<00:00, 203.50it/s]


Save best model with loss = 0.9307458250031518


Epoch 66: Total loss: 0.9280237861447952, Accuracy: 0.6890: 100%|██████████| 4515/4515 [00:22<00:00, 203.39it/s]


Save best model with loss = 0.9280237861447952


Epoch 67: Total loss: 0.9204116164911889, Accuracy: 0.6912: 100%|██████████| 4515/4515 [00:22<00:00, 201.04it/s]


Save best model with loss = 0.9204116164911889


Epoch 68: Total loss: 0.9144640635876429, Accuracy: 0.6936: 100%|██████████| 4515/4515 [00:21<00:00, 206.74it/s]


Save best model with loss = 0.9144640635876429


Epoch 69: Total loss: 0.9137485228635676, Accuracy: 0.6944: 100%|██████████| 4515/4515 [00:22<00:00, 203.48it/s]


Save best model with loss = 0.9137485228635676


Epoch 70: Total loss: 0.9074132174154452, Accuracy: 0.6965: 100%|██████████| 4515/4515 [00:21<00:00, 207.13it/s]


Save best model with loss = 0.9074132174154452


Epoch 71: Total loss: 0.9064879079131721, Accuracy: 0.6951: 100%|██████████| 4515/4515 [00:22<00:00, 203.06it/s]


Save best model with loss = 0.9064879079131721


Epoch 72: Total loss: 0.8986360215863516, Accuracy: 0.6997: 100%|██████████| 4515/4515 [00:21<00:00, 205.66it/s]


Save best model with loss = 0.8986360215863516


Epoch 73: Total loss: 0.8949151603222422, Accuracy: 0.7002: 100%|██████████| 4515/4515 [00:22<00:00, 202.28it/s]


Save best model with loss = 0.8949151603222422


Epoch 74: Total loss: 0.8937097576105449, Accuracy: 0.7006: 100%|██████████| 4515/4515 [00:22<00:00, 202.28it/s]


Save best model with loss = 0.8937097576105449


Epoch 75: Total loss: 0.8851124614939209, Accuracy: 0.7020: 100%|██████████| 4515/4515 [00:22<00:00, 202.93it/s]


Save best model with loss = 0.8851124614939209


Epoch 76: Total loss: 0.881085236679148, Accuracy: 0.7031: 100%|██████████| 4515/4515 [00:22<00:00, 203.07it/s] 


Save best model with loss = 0.881085236679148


Epoch 77: Total loss: 0.8705854357675065, Accuracy: 0.7065: 100%|██████████| 4515/4515 [00:22<00:00, 202.93it/s]


Save best model with loss = 0.8705854357675065


Epoch 78: Total loss: 0.8709334268390406, Accuracy: 0.7071: 100%|██████████| 4515/4515 [00:22<00:00, 204.92it/s]
Epoch 79: Total loss: 0.8716522105442461, Accuracy: 0.7057: 100%|██████████| 4515/4515 [00:23<00:00, 190.97it/s]
Epoch 80: Total loss: 0.8647029217790527, Accuracy: 0.7096: 100%|██████████| 4515/4515 [00:23<00:00, 196.29it/s]


Save best model with loss = 0.8647029217790527


Epoch 81: Total loss: 0.8636644287164821, Accuracy: 0.7103: 100%|██████████| 4515/4515 [00:22<00:00, 197.74it/s]


Save best model with loss = 0.8636644287164821


Epoch 82: Total loss: 0.8585123281476241, Accuracy: 0.7088: 100%|██████████| 4515/4515 [00:22<00:00, 198.37it/s]


Save best model with loss = 0.8585123281476241


Epoch 83: Total loss: 0.8560013509535974, Accuracy: 0.7124: 100%|██████████| 4515/4515 [00:22<00:00, 198.35it/s]


Save best model with loss = 0.8560013509535974


Epoch 84: Total loss: 0.8551672062637534, Accuracy: 0.7111: 100%|██████████| 4515/4515 [00:22<00:00, 203.82it/s]


Save best model with loss = 0.8551672062637534


Epoch 85: Total loss: 0.8452143907084946, Accuracy: 0.7135: 100%|██████████| 4515/4515 [00:21<00:00, 207.96it/s]


Save best model with loss = 0.8452143907084946


Epoch 86: Total loss: 0.8458129470068487, Accuracy: 0.7150: 100%|██████████| 4515/4515 [00:22<00:00, 204.02it/s]
Epoch 87: Total loss: 0.839702624627126, Accuracy: 0.7170: 100%|██████████| 4515/4515 [00:21<00:00, 208.64it/s] 


Save best model with loss = 0.839702624627126


Epoch 88: Total loss: 0.8391391026881314, Accuracy: 0.7165: 100%|██████████| 4515/4515 [00:21<00:00, 205.60it/s]


Save best model with loss = 0.8391391026881314


Epoch 89: Total loss: 0.8310639964913741, Accuracy: 0.7185: 100%|██████████| 4515/4515 [00:21<00:00, 206.42it/s]


Save best model with loss = 0.8310639964913741


Epoch 90: Total loss: 0.8277770907346856, Accuracy: 0.7190: 100%|██████████| 4515/4515 [00:22<00:00, 204.65it/s]


Save best model with loss = 0.8277770907346856


Epoch 91: Total loss: 0.8207522976537084, Accuracy: 0.7225: 100%|██████████| 4515/4515 [00:21<00:00, 205.81it/s]


Save best model with loss = 0.8207522976537084


Epoch 92: Total loss: 0.8186088541483431, Accuracy: 0.7219: 100%|██████████| 4515/4515 [00:21<00:00, 205.96it/s]


Save best model with loss = 0.8186088541483431


Epoch 93: Total loss: 0.8192734082143832, Accuracy: 0.7226: 100%|██████████| 4515/4515 [00:22<00:00, 204.20it/s]
Epoch 94: Total loss: 0.8126723997121633, Accuracy: 0.7242: 100%|██████████| 4515/4515 [00:22<00:00, 203.93it/s]


Save best model with loss = 0.8126723997121633


Epoch 95: Total loss: 0.8083297580744605, Accuracy: 0.7252: 100%|██████████| 4515/4515 [00:22<00:00, 202.29it/s]


Save best model with loss = 0.8083297580744605


Epoch 96: Total loss: 0.8081178957012669, Accuracy: 0.7260: 100%|██████████| 4515/4515 [00:22<00:00, 202.08it/s]


Save best model with loss = 0.8081178957012669


Epoch 97: Total loss: 0.8043799383035662, Accuracy: 0.7274: 100%|██████████| 4515/4515 [00:21<00:00, 205.50it/s]


Save best model with loss = 0.8043799383035662


Epoch 98: Total loss: 0.8080570940757511, Accuracy: 0.7259: 100%|██████████| 4515/4515 [00:22<00:00, 201.22it/s]
Epoch 99: Total loss: 0.7995002424532129, Accuracy: 0.7269: 100%|██████████| 4515/4515 [00:22<00:00, 205.00it/s]


Save best model with loss = 0.7995002424532129


Epoch 100: Total loss: 0.7984323241352903, Accuracy: 0.7286: 100%|██████████| 4515/4515 [00:22<00:00, 203.69it/s]


Save best model with loss = 0.7984323241352903


Epoch 101: Total loss: 0.7943918878585925, Accuracy: 0.7290: 100%|██████████| 4515/4515 [00:21<00:00, 206.12it/s]


Save best model with loss = 0.7943918878585925


Epoch 102: Total loss: 0.7908672231079336, Accuracy: 0.7306: 100%|██████████| 4515/4515 [00:22<00:00, 204.64it/s]


Save best model with loss = 0.7908672231079336


Epoch 103: Total loss: 0.7875699349713088, Accuracy: 0.7315: 100%|██████████| 4515/4515 [00:22<00:00, 204.00it/s]


Save best model with loss = 0.7875699349713088


Epoch 104: Total loss: 0.7887975094425321, Accuracy: 0.7312: 100%|██████████| 4515/4515 [00:22<00:00, 203.94it/s]
Epoch 105: Total loss: 0.7809398749209981, Accuracy: 0.7326: 100%|██████████| 4515/4515 [00:21<00:00, 205.51it/s]


Save best model with loss = 0.7809398749209981


Epoch 106: Total loss: 0.7805132936269606, Accuracy: 0.7335: 100%|██████████| 4515/4515 [00:22<00:00, 204.75it/s]


Save best model with loss = 0.7805132936269606


Epoch 107: Total loss: 0.7746111304392979, Accuracy: 0.7358: 100%|██████████| 4515/4515 [00:22<00:00, 203.57it/s]


Save best model with loss = 0.7746111304392979


Epoch 108: Total loss: 0.7799876466938032, Accuracy: 0.7332: 100%|██████████| 4515/4515 [00:22<00:00, 202.95it/s]
Epoch 109: Total loss: 0.7714275706349021, Accuracy: 0.7368: 100%|██████████| 4515/4515 [00:21<00:00, 206.26it/s]


Save best model with loss = 0.7714275706349021


Epoch 110: Total loss: 0.770172449573462, Accuracy: 0.7368: 100%|██████████| 4515/4515 [00:22<00:00, 202.73it/s] 


Save best model with loss = 0.770172449573462


Epoch 111: Total loss: 0.7644808554867177, Accuracy: 0.7387: 100%|██████████| 4515/4515 [00:21<00:00, 206.02it/s]


Save best model with loss = 0.7644808554867177


Epoch 112: Total loss: 0.7653872878382504, Accuracy: 0.7377: 100%|██████████| 4515/4515 [00:22<00:00, 204.02it/s]
Epoch 113: Total loss: 0.7613936661551826, Accuracy: 0.7400: 100%|██████████| 4515/4515 [00:22<00:00, 203.30it/s]


Save best model with loss = 0.7613936661551826


Epoch 114: Total loss: 0.7605274604371484, Accuracy: 0.7411: 100%|██████████| 4515/4515 [00:22<00:00, 204.91it/s]


Save best model with loss = 0.7605274604371484


Epoch 115: Total loss: 0.7581285132060152, Accuracy: 0.7404: 100%|██████████| 4515/4515 [00:22<00:00, 203.51it/s]


Save best model with loss = 0.7581285132060152


Epoch 116: Total loss: 0.7554250951654227, Accuracy: 0.7412: 100%|██████████| 4515/4515 [00:22<00:00, 203.56it/s]


Save best model with loss = 0.7554250951654227


Epoch 117: Total loss: 0.7511103468974959, Accuracy: 0.7429: 100%|██████████| 4515/4515 [00:22<00:00, 202.25it/s]


Save best model with loss = 0.7511103468974959


Epoch 118: Total loss: 0.7531677232793267, Accuracy: 0.7431: 100%|██████████| 4515/4515 [00:22<00:00, 203.46it/s]
Epoch 119: Total loss: 0.7463425838703069, Accuracy: 0.7439: 100%|██████████| 4515/4515 [00:24<00:00, 187.76it/s]


Save best model with loss = 0.7463425838703069


Epoch 120: Total loss: 0.7436191885376293, Accuracy: 0.7462: 100%|██████████| 4515/4515 [00:23<00:00, 191.99it/s]


Save best model with loss = 0.7436191885376293


Epoch 121: Total loss: 0.7414596329396745, Accuracy: 0.7453: 100%|██████████| 4515/4515 [00:22<00:00, 198.22it/s]


Save best model with loss = 0.7414596329396745


Epoch 122: Total loss: 0.7372238145681183, Accuracy: 0.7467: 100%|██████████| 4515/4515 [00:23<00:00, 192.79it/s]


Save best model with loss = 0.7372238145681183


Epoch 123: Total loss: 0.7412382345460974, Accuracy: 0.7459: 100%|██████████| 4515/4515 [00:24<00:00, 187.34it/s]
Epoch 124: Total loss: 0.7324989209539471, Accuracy: 0.7476: 100%|██████████| 4515/4515 [00:24<00:00, 182.91it/s]


Save best model with loss = 0.7324989209539471


Epoch 125: Total loss: 0.7373524168203044, Accuracy: 0.7486: 100%|██████████| 4515/4515 [00:28<00:00, 161.04it/s]
Epoch 126: Total loss: 0.7344250998029677, Accuracy: 0.7481: 100%|██████████| 4515/4515 [00:28<00:00, 156.92it/s]
Epoch 127: Total loss: 0.7295464498722144, Accuracy: 0.7507: 100%|██████████| 4515/4515 [00:30<00:00, 149.33it/s]


Save best model with loss = 0.7295464498722144


Epoch 128: Total loss: 0.7295948353782709, Accuracy: 0.7497: 100%|██████████| 4515/4515 [00:25<00:00, 179.47it/s]
Epoch 129: Total loss: 0.7258689821426523, Accuracy: 0.7508: 100%|██████████| 4515/4515 [00:24<00:00, 183.99it/s]


Save best model with loss = 0.7258689821426523


Epoch 130: Total loss: 0.7239373878189099, Accuracy: 0.7511: 100%|██████████| 4515/4515 [00:25<00:00, 176.35it/s]


Save best model with loss = 0.7239373878189099


Epoch 131: Total loss: 0.7263656803275792, Accuracy: 0.7517: 100%|██████████| 4515/4515 [00:25<00:00, 175.99it/s]
Epoch 132: Total loss: 0.7227818834590225, Accuracy: 0.7537: 100%|██████████| 4515/4515 [00:24<00:00, 187.27it/s]


Save best model with loss = 0.7227818834590225


Epoch 133: Total loss: 0.7218267777829471, Accuracy: 0.7518: 100%|██████████| 4515/4515 [00:24<00:00, 182.31it/s]


Save best model with loss = 0.7218267777829471


Epoch 134: Total loss: 0.7166747815090161, Accuracy: 0.7544: 100%|██████████| 4515/4515 [00:24<00:00, 184.32it/s]


Save best model with loss = 0.7166747815090161


Epoch 135: Total loss: 0.7147549709872418, Accuracy: 0.7561: 100%|██████████| 4515/4515 [00:25<00:00, 180.04it/s]


Save best model with loss = 0.7147549709872418


Epoch 136: Total loss: 0.7095599947544692, Accuracy: 0.7563: 100%|██████████| 4515/4515 [00:25<00:00, 175.20it/s]


Save best model with loss = 0.7095599947544692


Epoch 137: Total loss: 0.7135002562274965, Accuracy: 0.7556: 100%|██████████| 4515/4515 [00:27<00:00, 166.61it/s]
Epoch 138: Total loss: 0.7092805430507607, Accuracy: 0.7579: 100%|██████████| 4515/4515 [00:26<00:00, 169.25it/s]


Save best model with loss = 0.7092805430507607


Epoch 139: Total loss: 0.7069895388511064, Accuracy: 0.7579: 100%|██████████| 4515/4515 [00:30<00:00, 150.50it/s]


Save best model with loss = 0.7069895388511064


Epoch 140: Total loss: 0.7060281155084323, Accuracy: 0.7581: 100%|██████████| 4515/4515 [00:27<00:00, 166.86it/s]


Save best model with loss = 0.7060281155084323


Epoch 141: Total loss: 0.7015583984083511, Accuracy: 0.7583: 100%|██████████| 4515/4515 [00:27<00:00, 161.28it/s]


Save best model with loss = 0.7015583984083511


Epoch 142: Total loss: 0.7019103980374627, Accuracy: 0.7586: 100%|██████████| 4515/4515 [00:27<00:00, 164.34it/s]
Epoch 143: Total loss: 0.7011269480270537, Accuracy: 0.7592: 100%|██████████| 4515/4515 [00:26<00:00, 173.32it/s]


Save best model with loss = 0.7011269480270537


Epoch 144: Total loss: 0.695067617409385, Accuracy: 0.7616: 100%|██████████| 4515/4515 [00:26<00:00, 173.64it/s] 


Save best model with loss = 0.695067617409385


Epoch 145: Total loss: 0.7027132077272549, Accuracy: 0.7582: 100%|██████████| 4515/4515 [00:24<00:00, 183.00it/s]
Epoch 146: Total loss: 0.6968304328117128, Accuracy: 0.7605: 100%|██████████| 4515/4515 [00:26<00:00, 173.06it/s]
Epoch 147: Total loss: 0.6943532867190054, Accuracy: 0.7619: 100%|██████████| 4515/4515 [00:25<00:00, 174.63it/s]


Save best model with loss = 0.6943532867190054


Epoch 148: Total loss: 0.687576812056872, Accuracy: 0.7639: 100%|██████████| 4515/4515 [00:26<00:00, 173.49it/s] 


Save best model with loss = 0.687576812056872


Epoch 149: Total loss: 0.6839913973545579, Accuracy: 0.7640: 100%|██████████| 4515/4515 [00:26<00:00, 171.63it/s]


Save best model with loss = 0.6839913973545579


Epoch 150: Total loss: 0.6863516335636543, Accuracy: 0.7647: 100%|██████████| 4515/4515 [00:25<00:00, 174.60it/s]
Epoch 151: Total loss: 0.6867698142595862, Accuracy: 0.7636: 100%|██████████| 4515/4515 [00:26<00:00, 173.25it/s]
Epoch 152: Total loss: 0.687340613694682, Accuracy: 0.7651: 100%|██████████| 4515/4515 [00:25<00:00, 174.94it/s] 
Epoch 153: Total loss: 0.6817318368443223, Accuracy: 0.7660: 100%|██████████| 4515/4515 [00:26<00:00, 173.44it/s]


Save best model with loss = 0.6817318368443223


Epoch 154: Total loss: 0.6830576424780874, Accuracy: 0.7649: 100%|██████████| 4515/4515 [00:25<00:00, 174.64it/s]
Epoch 155: Total loss: 0.674754128206608, Accuracy: 0.7687: 100%|██████████| 4515/4515 [00:25<00:00, 174.40it/s] 


Save best model with loss = 0.674754128206608


Epoch 156: Total loss: 0.6779035959363908, Accuracy: 0.7675: 100%|██████████| 4515/4515 [00:25<00:00, 176.54it/s]
Epoch 157: Total loss: 0.6757940922761676, Accuracy: 0.7681: 100%|██████████| 4515/4515 [00:25<00:00, 176.34it/s]
Epoch 158: Total loss: 0.6711352002977137, Accuracy: 0.7689: 100%|██████████| 4515/4515 [00:27<00:00, 166.50it/s]


Save best model with loss = 0.6711352002977137


Epoch 159: Total loss: 0.6709793071942208, Accuracy: 0.7697: 100%|██████████| 4515/4515 [00:23<00:00, 193.10it/s]


Save best model with loss = 0.6709793071942208


Epoch 160: Total loss: 0.6678664319571994, Accuracy: 0.7701: 100%|██████████| 4515/4515 [00:23<00:00, 193.72it/s]


Save best model with loss = 0.6678664319571994


Epoch 161: Total loss: 0.6741111688109596, Accuracy: 0.7679: 100%|██████████| 4515/4515 [00:23<00:00, 189.43it/s]
Epoch 162: Total loss: 0.6621895777591703, Accuracy: 0.7715: 100%|██████████| 4515/4515 [00:28<00:00, 157.55it/s]


Save best model with loss = 0.6621895777591703


Epoch 163: Total loss: 0.6625017385282126, Accuracy: 0.7719: 100%|██████████| 4515/4515 [00:27<00:00, 161.94it/s]
Epoch 164: Total loss: 0.6641592564558402, Accuracy: 0.7713: 100%|██████████| 4515/4515 [00:28<00:00, 160.71it/s]
Epoch 165: Total loss: 0.662957776965757, Accuracy: 0.7709: 100%|██████████| 4515/4515 [00:28<00:00, 159.90it/s] 
Epoch 166: Total loss: 0.6622754717057989, Accuracy: 0.7713: 100%|██████████| 4515/4515 [00:28<00:00, 158.80it/s]
Epoch 167: Total loss: 0.6588668469739779, Accuracy: 0.7737: 100%|██████████| 4515/4515 [00:29<00:00, 155.42it/s]


Save best model with loss = 0.6588668469739779


Epoch 168: Total loss: 0.6595314931400854, Accuracy: 0.7724: 100%|██████████| 4515/4515 [00:25<00:00, 175.32it/s]
Epoch 169: Total loss: 0.6571319539879643, Accuracy: 0.7741: 100%|██████████| 4515/4515 [00:28<00:00, 160.26it/s]


Save best model with loss = 0.6571319539879643


Epoch 170: Total loss: 0.6549700960104383, Accuracy: 0.7740: 100%|██████████| 4515/4515 [00:26<00:00, 173.33it/s]


Save best model with loss = 0.6549700960104383


Epoch 171: Total loss: 0.6537368420218577, Accuracy: 0.7735: 100%|██████████| 4515/4515 [00:27<00:00, 167.02it/s]


Save best model with loss = 0.6537368420218577


Epoch 172: Total loss: 0.6499526667832536, Accuracy: 0.7759: 100%|██████████| 4515/4515 [00:27<00:00, 166.54it/s]


Save best model with loss = 0.6499526667832536


Epoch 173: Total loss: 0.6515159891194283, Accuracy: 0.7768: 100%|██████████| 4515/4515 [00:27<00:00, 165.29it/s]
Epoch 174: Total loss: 0.6475018759535529, Accuracy: 0.7772: 100%|██████████| 4515/4515 [00:24<00:00, 183.28it/s]


Save best model with loss = 0.6475018759535529


Epoch 175: Total loss: 0.6431034545631767, Accuracy: 0.7778: 100%|██████████| 4515/4515 [00:22<00:00, 197.96it/s]


Save best model with loss = 0.6431034545631767


Epoch 176: Total loss: 0.6476047061507488, Accuracy: 0.7757: 100%|██████████| 4515/4515 [00:21<00:00, 205.24it/s]
Epoch 177: Total loss: 0.6472924045979118, Accuracy: 0.7773: 100%|██████████| 4515/4515 [00:21<00:00, 208.89it/s]
Epoch 178: Total loss: 0.6438629465443747, Accuracy: 0.7785: 100%|██████████| 4515/4515 [00:21<00:00, 211.00it/s]
Epoch 179: Total loss: 0.6417659571367509, Accuracy: 0.7779: 100%|██████████| 4515/4515 [00:21<00:00, 209.54it/s]


Save best model with loss = 0.6417659571367509


Epoch 180: Total loss: 0.6410343376356502, Accuracy: 0.7797: 100%|██████████| 4515/4515 [00:21<00:00, 210.31it/s]


Save best model with loss = 0.6410343376356502


Epoch 181: Total loss: 0.6437840463199621, Accuracy: 0.7775: 100%|██████████| 4515/4515 [00:21<00:00, 208.46it/s]
Epoch 182: Total loss: 0.6400198762995064, Accuracy: 0.7803: 100%|██████████| 4515/4515 [00:21<00:00, 210.88it/s]


Save best model with loss = 0.6400198762995064


Epoch 183: Total loss: 0.6360092732333398, Accuracy: 0.7796: 100%|██████████| 4515/4515 [00:21<00:00, 210.75it/s]


Save best model with loss = 0.6360092732333398


Epoch 184: Total loss: 0.6345041425801723, Accuracy: 0.7821: 100%|██████████| 4515/4515 [00:21<00:00, 211.72it/s]


Save best model with loss = 0.6345041425801723


Epoch 185: Total loss: 0.6334073834575028, Accuracy: 0.7811: 100%|██████████| 4515/4515 [00:21<00:00, 211.34it/s]


Save best model with loss = 0.6334073834575028


Epoch 186: Total loss: 0.6387058569413351, Accuracy: 0.7795: 100%|██████████| 4515/4515 [00:21<00:00, 214.86it/s]
Epoch 187: Total loss: 0.6303449174626458, Accuracy: 0.7820: 100%|██████████| 4515/4515 [00:21<00:00, 211.59it/s]


Save best model with loss = 0.6303449174626458


Epoch 188: Total loss: 0.635301253339843, Accuracy: 0.7814: 100%|██████████| 4515/4515 [00:21<00:00, 211.22it/s] 
Epoch 189: Total loss: 0.6262369317502806, Accuracy: 0.7841: 100%|██████████| 4515/4515 [00:20<00:00, 216.13it/s]


Save best model with loss = 0.6262369317502806


Epoch 190: Total loss: 0.62790906224998, Accuracy: 0.7832: 100%|██████████| 4515/4515 [00:21<00:00, 213.70it/s]  
Epoch 191: Total loss: 0.6285326530816547, Accuracy: 0.7839: 100%|██████████| 4515/4515 [00:21<00:00, 214.82it/s]
Epoch 192: Total loss: 0.6220264739471416, Accuracy: 0.7850: 100%|██████████| 4515/4515 [00:20<00:00, 215.33it/s]


Save best model with loss = 0.6220264739471416


Epoch 193: Total loss: 0.628318770908042, Accuracy: 0.7841: 100%|██████████| 4515/4515 [00:21<00:00, 213.13it/s] 
Epoch 194: Total loss: 0.6244358728230329, Accuracy: 0.7849: 100%|██████████| 4515/4515 [00:20<00:00, 215.76it/s]
Epoch 195: Total loss: 0.6225369016131426, Accuracy: 0.7858: 100%|██████████| 4515/4515 [00:21<00:00, 212.07it/s]
Epoch 196: Total loss: 0.6200293504914572, Accuracy: 0.7862: 100%|██████████| 4515/4515 [00:21<00:00, 213.96it/s]


Save best model with loss = 0.6200293504914572


Epoch 197: Total loss: 0.6227585097649037, Accuracy: 0.7845: 100%|██████████| 4515/4515 [00:21<00:00, 214.30it/s]
Epoch 198: Total loss: 0.6177923494729225, Accuracy: 0.7870: 100%|██████████| 4515/4515 [00:21<00:00, 212.67it/s]


Save best model with loss = 0.6177923494729225


Epoch 199: Total loss: 0.6173806494388602, Accuracy: 0.7863: 100%|██████████| 4515/4515 [00:21<00:00, 213.85it/s]


Save best model with loss = 0.6173806494388602


Epoch 200: Total loss: 0.6126279440961143, Accuracy: 0.7874: 100%|██████████| 4515/4515 [00:21<00:00, 211.78it/s]


Save best model with loss = 0.6126279440961143


Epoch 201: Total loss: 0.6174863575502884, Accuracy: 0.7873: 100%|██████████| 4515/4515 [00:21<00:00, 212.34it/s]
Epoch 202: Total loss: 0.6173499721427297, Accuracy: 0.7874: 100%|██████████| 4515/4515 [00:21<00:00, 214.25it/s]
Epoch 203: Total loss: 0.6132436068648119, Accuracy: 0.7872: 100%|██████████| 4515/4515 [00:21<00:00, 213.35it/s]
Epoch 204: Total loss: 0.6147641996839646, Accuracy: 0.7883: 100%|██████████| 4515/4515 [00:21<00:00, 211.07it/s]
Epoch 205: Total loss: 0.607294265161849, Accuracy: 0.7901: 100%|██████████| 4515/4515 [00:21<00:00, 213.96it/s] 


Save best model with loss = 0.607294265161849


Epoch 206: Total loss: 0.6128384855663526, Accuracy: 0.7887: 100%|██████████| 4515/4515 [00:20<00:00, 217.62it/s]
Epoch 207: Total loss: 0.6091080055432858, Accuracy: 0.7886: 100%|██████████| 4515/4515 [00:21<00:00, 213.96it/s]
Epoch 208: Total loss: 0.6112286326992156, Accuracy: 0.7905: 100%|██████████| 4515/4515 [00:21<00:00, 214.83it/s]
Epoch 209: Total loss: 0.609491171307474, Accuracy: 0.7889: 100%|██████████| 4515/4515 [00:20<00:00, 215.28it/s] 
Epoch 210: Total loss: 0.6036667296656738, Accuracy: 0.7915: 100%|██████████| 4515/4515 [00:21<00:00, 214.33it/s]


Save best model with loss = 0.6036667296656738


Epoch 211: Total loss: 0.6062241001151593, Accuracy: 0.7903: 100%|██████████| 4515/4515 [00:21<00:00, 214.51it/s]
Epoch 212: Total loss: 0.6007995549626525, Accuracy: 0.7922: 100%|██████████| 4515/4515 [00:21<00:00, 214.62it/s]


Save best model with loss = 0.6007995549626525


Epoch 213: Total loss: 0.6013956950566961, Accuracy: 0.7928: 100%|██████████| 4515/4515 [00:21<00:00, 214.31it/s]
Epoch 214: Total loss: 0.5982210710034154, Accuracy: 0.7919: 100%|██████████| 4515/4515 [00:21<00:00, 212.06it/s]


Save best model with loss = 0.5982210710034154


Epoch 215: Total loss: 0.6019806268999347, Accuracy: 0.7923: 100%|██████████| 4515/4515 [00:21<00:00, 211.85it/s]
Epoch 216: Total loss: 0.5958127333806335, Accuracy: 0.7929: 100%|██████████| 4515/4515 [00:21<00:00, 210.44it/s]


Save best model with loss = 0.5958127333806335


Epoch 217: Total loss: 0.5990095056799691, Accuracy: 0.7920: 100%|██████████| 4515/4515 [00:21<00:00, 214.35it/s]
Epoch 218: Total loss: 0.5934641622130657, Accuracy: 0.7945: 100%|██████████| 4515/4515 [00:21<00:00, 210.79it/s]


Save best model with loss = 0.5934641622130657


Epoch 219: Total loss: 0.5934596116542156, Accuracy: 0.7946: 100%|██████████| 4515/4515 [00:21<00:00, 212.65it/s]


Save best model with loss = 0.5934596116542156


Epoch 220: Total loss: 0.5913603683147716, Accuracy: 0.7944: 100%|██████████| 4515/4515 [00:21<00:00, 212.63it/s]


Save best model with loss = 0.5913603683147716


Epoch 221: Total loss: 0.5911694135347738, Accuracy: 0.7960: 100%|██████████| 4515/4515 [00:21<00:00, 211.34it/s]


Save best model with loss = 0.5911694135347738


Epoch 222: Total loss: 0.5892086087023697, Accuracy: 0.7959: 100%|██████████| 4515/4515 [00:21<00:00, 213.07it/s]


Save best model with loss = 0.5892086087023697


Epoch 223: Total loss: 0.5880517199074443, Accuracy: 0.7957: 100%|██████████| 4515/4515 [00:20<00:00, 218.15it/s]


Save best model with loss = 0.5880517199074443


Epoch 224: Total loss: 0.5886603307634889, Accuracy: 0.7959: 100%|██████████| 4515/4515 [00:21<00:00, 213.67it/s]
Epoch 225: Total loss: 0.5917241071842834, Accuracy: 0.7954: 100%|██████████| 4515/4515 [00:20<00:00, 215.99it/s]
Epoch 226: Total loss: 0.581552132289209, Accuracy: 0.7989: 100%|██████████| 4515/4515 [00:21<00:00, 214.02it/s] 


Save best model with loss = 0.581552132289209


Epoch 227: Total loss: 0.588224077247834, Accuracy: 0.7969: 100%|██████████| 4515/4515 [00:21<00:00, 213.83it/s] 
Epoch 228: Total loss: 0.5858517886165236, Accuracy: 0.7967: 100%|██████████| 4515/4515 [00:21<00:00, 212.76it/s]
Epoch 229: Total loss: 0.586065346050434, Accuracy: 0.7969: 100%|██████████| 4515/4515 [00:20<00:00, 215.59it/s] 
Epoch 230: Total loss: 0.5814975228255506, Accuracy: 0.7986: 100%|██████████| 4515/4515 [00:21<00:00, 213.29it/s]


Save best model with loss = 0.5814975228255506


Epoch 231: Total loss: 0.5813572357098261, Accuracy: 0.7974: 100%|██████████| 4515/4515 [00:21<00:00, 213.50it/s]


Save best model with loss = 0.5813572357098261


Epoch 232: Total loss: 0.583326949533962, Accuracy: 0.7979: 100%|██████████| 4515/4515 [00:21<00:00, 212.57it/s] 
Epoch 233: Total loss: 0.5847882365517046, Accuracy: 0.7984: 100%|██████████| 4515/4515 [00:21<00:00, 213.22it/s]
Epoch 234: Total loss: 0.5824301114246034, Accuracy: 0.7987: 100%|██████████| 4515/4515 [00:21<00:00, 211.89it/s]
Epoch 235: Total loss: 0.5872868467390471, Accuracy: 0.7981: 100%|██████████| 4515/4515 [00:21<00:00, 212.99it/s]
Epoch 236: Total loss: 0.580250999975046, Accuracy: 0.7991: 100%|██████████| 4515/4515 [00:21<00:00, 212.33it/s] 


Save best model with loss = 0.580250999975046


Epoch 237: Total loss: 0.5727280940130304, Accuracy: 0.8008: 100%|██████████| 4515/4515 [00:21<00:00, 212.06it/s]


Save best model with loss = 0.5727280940130304


Epoch 238: Total loss: 0.5763711730457619, Accuracy: 0.8013: 100%|██████████| 4515/4515 [00:21<00:00, 212.37it/s]
Epoch 239: Total loss: 0.5768732924133968, Accuracy: 0.8006: 100%|██████████| 4515/4515 [00:21<00:00, 212.19it/s]
Epoch 240: Total loss: 0.5765993470361133, Accuracy: 0.8017: 100%|██████████| 4515/4515 [00:20<00:00, 216.73it/s]
Epoch 241: Total loss: 0.5751967276565129, Accuracy: 0.8009: 100%|██████████| 4515/4515 [00:21<00:00, 214.47it/s]
Epoch 242: Total loss: 0.5751603304976772, Accuracy: 0.8007: 100%|██████████| 4515/4515 [00:21<00:00, 214.17it/s]
Epoch 243: Total loss: 0.5713738622426459, Accuracy: 0.8025: 100%|██████████| 4515/4515 [00:20<00:00, 215.62it/s]


Save best model with loss = 0.5713738622426459


Epoch 244: Total loss: 0.5727258242080906, Accuracy: 0.8017: 100%|██████████| 4515/4515 [00:21<00:00, 212.52it/s]
Epoch 245: Total loss: 0.5685641007268943, Accuracy: 0.8035: 100%|██████████| 4515/4515 [00:21<00:00, 214.96it/s]


Save best model with loss = 0.5685641007268943


Epoch 246: Total loss: 0.57140805461544, Accuracy: 0.8012: 100%|██████████| 4515/4515 [00:20<00:00, 215.34it/s]  
Epoch 247: Total loss: 0.5712081915780424, Accuracy: 0.8023: 100%|██████████| 4515/4515 [00:21<00:00, 212.23it/s]
Epoch 248: Total loss: 0.5657627446970934, Accuracy: 0.8044: 100%|██████████| 4515/4515 [00:21<00:00, 214.43it/s]


Save best model with loss = 0.5657627446970934


Epoch 249: Total loss: 0.5656855894201089, Accuracy: 0.8042: 100%|██████████| 4515/4515 [00:21<00:00, 212.24it/s]


Save best model with loss = 0.5656855894201089


Epoch 250: Total loss: 0.5654700287486759, Accuracy: 0.8040: 100%|██████████| 4515/4515 [00:21<00:00, 212.03it/s]


Save best model with loss = 0.5654700287486759


Epoch 251: Total loss: 0.5629124496433294, Accuracy: 0.8045: 100%|██████████| 4515/4515 [00:21<00:00, 212.48it/s]


Save best model with loss = 0.5629124496433294


Epoch 252: Total loss: 0.5673216927685215, Accuracy: 0.8044: 100%|██████████| 4515/4515 [00:21<00:00, 213.44it/s]
Epoch 253: Total loss: 0.5641553477558451, Accuracy: 0.8048: 100%|██████████| 4515/4515 [00:22<00:00, 199.66it/s]
Epoch 254: Total loss: 0.557449280078691, Accuracy: 0.8065: 100%|██████████| 4515/4515 [00:21<00:00, 206.47it/s] 


Save best model with loss = 0.557449280078691


Epoch 255: Total loss: 0.5591565649069491, Accuracy: 0.8065: 100%|██████████| 4515/4515 [00:21<00:00, 205.42it/s]
Epoch 256: Total loss: 0.5608970146671615, Accuracy: 0.8057: 100%|██████████| 4515/4515 [00:21<00:00, 207.18it/s]
Epoch 257: Total loss: 0.5603536002395689, Accuracy: 0.8059: 100%|██████████| 4515/4515 [00:21<00:00, 208.99it/s]
Epoch 258: Total loss: 0.5615304682077892, Accuracy: 0.8062: 100%|██████████| 4515/4515 [00:21<00:00, 207.78it/s]
Epoch 259: Total loss: 0.5563162717262772, Accuracy: 0.8072: 100%|██████████| 4515/4515 [00:21<00:00, 206.57it/s]


Save best model with loss = 0.5563162717262772


Epoch 260: Total loss: 0.5563227049404073, Accuracy: 0.8078: 100%|██████████| 4515/4515 [00:21<00:00, 206.56it/s]
Epoch 261: Total loss: 0.5532171191798229, Accuracy: 0.8080: 100%|██████████| 4515/4515 [00:21<00:00, 207.65it/s]


Save best model with loss = 0.5532171191798229


Epoch 262: Total loss: 0.5553413162925843, Accuracy: 0.8068: 100%|██████████| 4515/4515 [00:21<00:00, 213.14it/s]
Epoch 263: Total loss: 0.5575433770320476, Accuracy: 0.8074: 100%|██████████| 4515/4515 [00:22<00:00, 204.86it/s]
Epoch 264: Total loss: 0.5590258048700864, Accuracy: 0.8079: 100%|██████████| 4515/4515 [00:21<00:00, 212.96it/s]
Epoch 265: Total loss: 0.5545458253402911, Accuracy: 0.8080: 100%|██████████| 4515/4515 [00:22<00:00, 202.92it/s]
Epoch 266: Total loss: 0.5506928790348313, Accuracy: 0.8098: 100%|██████████| 4515/4515 [00:23<00:00, 190.92it/s]


Save best model with loss = 0.5506928790348313


Epoch 267: Total loss: 0.5460933999019207, Accuracy: 0.8103: 100%|██████████| 4515/4515 [00:21<00:00, 207.56it/s]


Save best model with loss = 0.5460933999019207


Epoch 268: Total loss: 0.5505665554094684, Accuracy: 0.8088: 100%|██████████| 4515/4515 [00:21<00:00, 210.62it/s]
Epoch 269: Total loss: 0.549675970251444, Accuracy: 0.8106: 100%|██████████| 4515/4515 [00:21<00:00, 207.39it/s] 
Epoch 270: Total loss: 0.5476875339202575, Accuracy: 0.8106: 100%|██████████| 4515/4515 [00:21<00:00, 211.13it/s]
Epoch 271: Total loss: 0.545277435658448, Accuracy: 0.8102: 100%|██████████| 4515/4515 [00:21<00:00, 209.24it/s] 


Save best model with loss = 0.545277435658448


Epoch 272: Total loss: 0.5453003476079912, Accuracy: 0.8122: 100%|██████████| 4515/4515 [00:21<00:00, 212.18it/s]
Epoch 273: Total loss: 0.5501321280814484, Accuracy: 0.8086: 100%|██████████| 4515/4515 [00:20<00:00, 216.97it/s]
Epoch 274: Total loss: 0.5447286728971689, Accuracy: 0.8122: 100%|██████████| 4515/4515 [00:21<00:00, 214.16it/s]


Save best model with loss = 0.5447286728971689


Epoch 275: Total loss: 0.5419169302052042, Accuracy: 0.8116: 100%|██████████| 4515/4515 [00:21<00:00, 213.67it/s]


Save best model with loss = 0.5419169302052042


Epoch 276: Total loss: 0.5442216717622604, Accuracy: 0.8115: 100%|██████████| 4515/4515 [00:21<00:00, 214.48it/s]
Epoch 277: Total loss: 0.5464972716893337, Accuracy: 0.8107: 100%|██████████| 4515/4515 [00:21<00:00, 213.23it/s]
Epoch 278: Total loss: 0.541707568301315, Accuracy: 0.8122: 100%|██████████| 4515/4515 [00:21<00:00, 212.56it/s] 


Save best model with loss = 0.541707568301315


Epoch 279: Total loss: 0.5452176593697876, Accuracy: 0.8105: 100%|██████████| 4515/4515 [00:21<00:00, 213.27it/s]
Epoch 280: Total loss: 0.5412664460086347, Accuracy: 0.8131: 100%|██████████| 4515/4515 [00:21<00:00, 210.27it/s] 


Save best model with loss = 0.5412664460086347


Epoch 281: Total loss: 0.5345995447362305, Accuracy: 0.8141:  86%|████████▋ | 3903/4515 [00:18<00:02, 207.23it/s]

In [51]:
# Rebuild the network and restore the saved weights from
# 'weights/best_char_gen.pth' (presumably the file written by the
# "Save best model" step of the training loop -- confirm against that cell).
model = TextModel()
# map_location="cpu": this fresh TextModel() lives on the CPU, and mapping the
#   stored tensors there lets a checkpoint saved from a CUDA run load on a
#   machine without a GPU.
# weights_only=True: restrict unpickling to tensors and plain containers, which
#   is all a state_dict needs; it avoids executing arbitrary pickled code from
#   a tampered file and silences the torch.load FutureWarning.
model.load_state_dict(
    torch.load('weights/best_char_gen.pth', map_location="cpu", weights_only=True)
)
model.eval()  # Set model to evaluation mode (turns Dropout off)

def predict_next_char(input_str, model, chars_to_int, int_to_chars, n_vocab):
    """Predict the single most likely character to follow ``input_str``.

    Args:
        input_str: Non-empty seed text. Every character must be a key of
            ``chars_to_int``.
        model: ``torch.nn.Module`` mapping a float tensor of shape
            ``(1, len(input_str), 1)`` to per-class scores; the arg-max over
            the last dimension is taken as the predicted class index.
        chars_to_int: Mapping from character to integer index.
        int_to_chars: Inverse mapping from integer index to character.
        n_vocab: Vocabulary size; indices are divided by it, mirroring the
            ``X / float(n_vocab)`` scaling applied to the training data.

    Returns:
        The character that ``int_to_chars`` associates with the arg-max index.

    Raises:
        ValueError: If ``input_str`` is empty.
        KeyError: If ``input_str`` contains characters missing from
            ``chars_to_int`` (all offending characters are listed).
    """
    # A zero-length sequence cannot be fed through the recurrent layer; fail
    # early with a clear message instead of an obscure downstream error.
    if not input_str:
        raise ValueError("input_str must contain at least one character")

    # Report every out-of-vocabulary character at once (still a KeyError, as
    # the plain dictionary lookup would have raised).
    unknown = sorted({ch for ch in input_str if ch not in chars_to_int})
    if unknown:
        raise KeyError(f"characters not in vocabulary: {unknown!r}")

    # Convert input string to a list of integers
    input_ints = [chars_to_int[char] for char in input_str]

    # Build the input on the same device as the model so the function works
    # whether or not the model was moved to a GPU. Parameter-less modules
    # fall back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Shape (batch=1, seq_len, features=1), scaled like the training inputs.
    input_tensor = torch.tensor(input_ints, dtype=torch.float32, device=device)
    input_tensor = input_tensor.reshape(1, -1, 1) / float(n_vocab)

    # Run in evaluation mode so Dropout cannot randomize the prediction, then
    # restore whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # No need to calculate gradients during inference
            output = model(input_tensor)
    finally:
        model.train(was_training)

    # Most likely class index -> corresponding character
    predicted_index = torch.argmax(output, dim=-1).item()
    return int_to_chars[predicted_index]

# Demo: ask the model which character most likely follows the seed 'alic'
input_str = 'alic'
predicted_char = predict_next_char(
    input_str=input_str,
    model=model,
    chars_to_int=chars_to_int,
    int_to_chars=int_to_chars,
    n_vocab=n_vocab,
)
print(f"The next predicted character after '{input_str}' is: '{predicted_char}'")

The next predicted character after 'alic' is: ' '


  model.load_state_dict(torch.load('weights/best_char_gen.pth'))
