In [None]:
import torch 
import torch.nn as nn 
import sys 
from pathlib import Path
import math 
from utils import simple_cleansing
from data import TransformerDataset
from model import Model

from inference import generate_text
from torch import cuda

from time import perf_counter
from torch.utils.data import DataLoader
import csv

import numpy as np 

In [None]:
# Reproducibility: seed torch's global RNG before anything random happens.
# DataLoader(shuffle=True) draws its shuffling order from this generator by default,
# and it is presumably also used for Model's weight initialisation (TODO confirm).
seed = 0
torch.manual_seed(seed)

n_sequence = 50  # Sequence length passed to TransformerDataset
path = Path('texts/Pride_and_Prejudice.txt')
dataset = TransformerDataset(path.read_text(), n_sequence)
loader = DataLoader(dataset, batch_size=10, shuffle=True)

# Model hyper-parameters, passed positionally to Model(n_vocabulary, n_heads, n_features).
n_features = 256
n_heads = 4
n_vocabulary = len(dataset.vocabulary)  # Output size must match the dataset's vocabulary

model = Model(n_vocabulary, n_heads, n_features)

epochs = 20  # Number of full passes over `loader` in the training cell



In [None]:
ce_loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
device = 'cuda:0' if cuda.is_available() else 'cpu'
model = model.to(device)

# Text of the output
print(f'Start the training for {epochs} epochs')

total_time = 0  # Time consumed from the first epoch

model.train()
epoch_accuracies = []  # The list of the average accuracies for each epoch
epoch_losses = []  # The list of the average losses for each epoch

# Epochs are numbered 1..epochs inclusive, so exactly `epochs` passes run,
# matching the message printed above.
for epoch in range(1, epochs + 1):
    tic = perf_counter()

    running_loss = []  # Mean loss of each batch (CrossEntropyLoss default 'mean' reduction)
    running_accuracy = []  # Number of correct predictions for each batch
    n_words = []  # Number of target tokens in each batch
    loss_sum = 0.0  # Running sum of batch losses, for an O(1) progress display

    for idx, data in enumerate(loader):
        optimizer.zero_grad()
        # Next-token prediction: the target sequence is the input shifted by one token.
        x = data[:, :-1].to(device)
        y = data[:, 1:].to(device)

        y_hat = model(x)
        # CrossEntropyLoss expects the class dimension at dim 1, hence the transpose.
        # NOTE(review): assumes y_hat is (batch, seq, vocabulary) - confirm against Model.
        loss = ce_loss(y_hat.transpose(1, 2), y)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss.append(batch_loss)
        running_accuracy.append(torch.sum(y == torch.argmax(y_hat, dim=-1)).item())
        n_words.append(y.shape[0] * y.shape[1])
        loss_sum += batch_loss
        # Average loss so far in this epoch, redrawn in place thanks to '\r'.
        print(f'{idx} {loss_sum / (idx + 1):.4f}', end='\r')

    epoch_time = perf_counter() - tic  # Time for one epoch
    total_time += epoch_time

    # Token-weighted averages over the whole epoch (batches may differ in size).
    epoch_accuracy = np.sum(running_accuracy) / np.sum(n_words)
    epoch_loss = np.sum(np.array(running_loss) * np.array(n_words)) / np.sum(n_words)
    epoch_accuracies.append(epoch_accuracy)
    epoch_losses.append(epoch_loss)
    print(f'epoch {epoch} loss : {epoch_loss:.3f} accuracy: {epoch_accuracy:.3%}')

# Divide by the number of epochs actually run. The loop variable `epoch`
# would be undefined here if the loop never executed.
n_epochs_run = len(epoch_losses)
if n_epochs_run:
    print(f'Average time by epoch :{total_time / n_epochs_run:.1f} s')
model.to('cpu')


In [None]:
# Persist the trained model and its vocabulary so inference can run without retraining.
# Create the output directory first: torch.save and open() both fail if it is missing.
output_dir = Path('outputs')
output_dir.mkdir(parents=True, exist_ok=True)

# Pickles the whole nn.Module (not only its state_dict). Loading it back therefore
# requires the class definition (imported here from `model`) to be importable.
torch.save(model, 'outputs/model.pt')

# One vocabulary entry per row, always quoted, so entries containing the space
# delimiter round-trip through csv.reader unchanged.
with open('outputs/vocabulary.csv', 'w', newline='') as csvfile:
    dicowriter = csv.writer(csvfile, delimiter=' ', quoting=csv.QUOTE_ALL)
    for word in dataset.vocabulary:
        dicowriter.writerow([word])

In [None]:
# Reload the vocabulary written by the save cell, in the order it was saved.
with open('outputs/vocabulary.csv', newline='') as csvfile:
    dicoreader = csv.reader(csvfile, delimiter=' ')
    vocabulary = [row[0] for row in dicoreader]

# The checkpoint is a fully pickled nn.Module, not a state_dict. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses to unpickle arbitrary
# classes, so opt out explicitly. Only do this for trusted files, because unpickling
# can execute arbitrary code. map_location='cpu' keeps loading working without a GPU.
model = torch.load('outputs/model.pt', map_location='cpu', weights_only=False)
model.eval()  # Inference mode (affects layers such as dropout, if the model has any)

start = '<SOS>  '
generate_text(model, vocabulary, start, length=35)