In [19]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 2.4 Dates

In [31]:
import torch
from utils import seed_everything, parse_date_with_prompt, generate_random_date, read_file
from dataset_text import TextDataset, TextTokenizer
from train_text import TextTransformer
seed_everything(42)
torch.set_default_dtype(torch.float64)
import csv
## IF not backend is detected
# import torchaudio

# # Check the current backend
# torchaudio.set_audio_backend("soundfile")
# torchaudio.list_audio_backends()

### 2.4.1 Syntetic Data Generation

El codigo sintetico esta creado de forma sencilla  basandome en los tipos de prompts de training y test.
Lo mas dificil de ver es si las fechas tipo proximo viernes, martes, etc. son correctas de forma manual una visualizacion sencilla:

In [50]:
# A dia 4 de enero de 2025 (sabado)
print(parse_date_with_prompt("04/01/25 el proximo viernes")) # 10/01/25
print(parse_date_with_prompt("04/01/25 el proximo martes")) # 07/01/25
print(parse_date_with_prompt("04/01/25 el proximo sábado")) # 11/01/25

10/01/25
07/01/25
11/01/25


In [46]:
def create_data(output_path, input_path="fechas1/fechas1_train.csv"):
    promt= []
    syntetic_date = []
    for date in read_file(input_path)[1]:
        random_date=generate_random_date()
        promt.append(f"{random_date} {date}")
        syntetic_date.append(parse_date_with_prompt(promt[-1]))
    print(f"{promt[0]} -> {syntetic_date[0]}")

    with open(output_path, mode="w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        # Escribir los encabezados
        writer.writerow(["Instrucción", "Fecha Resultado"])
        # Escribir las filas de datos combinando las listas
        for instruccion,fecha  in zip(promt, syntetic_date):
            writer.writerow([instruccion, fecha])

In [47]:
create_data("fechas1/fechas1_train_sintetic.csv")
create_data("fechas1/fechas1_test_sintetic.csv", "fechas1/fechas1_test.csv")

13/05/57 por favor el siguiente jueves -> 17/05/57
31/07/08 pasado mañana gracias -> 02/08/08


### 2.4.2 Tokenizer

In [32]:
tokenizer=TextTokenizer("fechas1/fechas1_train_sintetic.csv")
train_dataset = TextDataset("fechas1/fechas1_train_sintetic.csv", tokenizer)

In [67]:
tokenizer.word2idx

{'gracias': 15,
 'siguiente': 16,
 'el': 17,
 'viernes': 18,
 'de': 19,
 'por': 20,
 'jueves': 21,
 'tres': 22,
 'este': 23,
 'mañana': 24,
 'martes': 25,
 'favor': 26,
 'en': 27,
 'que': 28,
 'lunes': 29,
 'días': 30,
 'próximo': 31,
 'un': 32,
 'pasado': 33,
 'miércoles': 34,
 'par': 35,
 'viene': 36,
 '<unk>': 0,
 '<pad>': 1,
 '<sos>': 2,
 '<eos>': 3,
 '/': 4,
 '0': 5,
 '1': 6,
 '2': 7,
 '3': 8,
 '4': 9,
 '5': 10,
 '6': 11,
 '7': 12,
 '8': 13,
 '9': 14}

In [33]:
a="13/05/57 por favor el siguiente jueves"
print(tokenizer.encode(a))
print(tokenizer.decode(tokenizer.encode(a)))

tensor([ 2,  6,  8,  4,  5, 10,  4, 10, 12, 34, 29, 35, 32, 24,  3])
13/05/57 por favor el siguiente jueves


In [34]:
a="13/05/57"
print(tokenizer.encode(a))
print(tokenizer.decode(tokenizer.encode(a)))

tensor([ 2,  6,  8,  4,  5, 10,  4, 10, 12,  3])
13/05/57


### 2.4.3 Dataset

In [35]:
tokenizer=TextTokenizer("fechas1/fechas1_train_sintetic.csv")
train_dataset = TextDataset("fechas1/fechas1_train_sintetic.csv", tokenizer)

In [36]:
data=train_dataset[0]
print(data)
print(tokenizer.decode(data[0]))
print(tokenizer.decode(data[1]))
print(train_dataset.promts[0], train_dataset.dates[0])

(tensor([ 2,  6,  8,  4,  5, 10,  4, 10, 12, 34, 29, 35, 32, 24,  3,  1,  1,  1,
         1,  1,  1,  1,  1,  1]), tensor([ 2,  6, 12,  4,  5, 10,  4, 10, 12,  3,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1]))
13/05/57 por favor el siguiente jueves
17/05/57
13/05/57 por favor el siguiente jueves 17/05/57


### 2.4.3 Training

In [37]:
model = TextTransformer(vocab_size=len(tokenizer.idx2word.keys()), d_model=256, nb_layers=4, 
                         d_ff=512, n_heads=8, d_head=32, dropout=0.1, seq_len=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=3e-4)

nb_epochs = 5
batch_size = 16
model.train()

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
for e in range(nb_epochs):
    avg_loss = 0
    for x, y in trainloader:
        x = x.to(device)
        y = y.to(device)
        opt.zero_grad()
        loss = model.loss(x, y)
        loss.backward()
        opt.step()
        avg_loss += loss.item()
    print('epoch %d/%d: avg_loss: %.2f' % (e,nb_epochs,avg_loss/len(trainloader)))
       
torch.save([model, opt], 'model_24.pt')
torch.save(tokenizer, 'tokenizer_24.pth')


epoch 0/5: avg_loss: 0.18
epoch 1/5: avg_loss: 0.09
epoch 2/5: avg_loss: 0.08
epoch 3/5: avg_loss: 0.07
epoch 4/5: avg_loss: 0.07


### 1.3.2 Test

In [38]:
import editdistance

In [39]:
# [model, opt] = torch.load('model_24.pt')
# tokenizer = torch.load('tokenizer_24.pth')
testset=TextDataset('fechas1/fechas1_test_sintetic.csv', tokenizer)

In [49]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x,y = testset[2]
x = x.to(device)
y_pred = model.generate(x[None,...])
hyp = ' '.join([str(i) for i in y_pred[1:-1]])
y = y.numpy().tolist()
y = y[:y.index(3)]
ref = ' '.join([str(i) for i in y[1:]])
print(hyp)
print(ref)
print(editdistance.eval(hyp, ref))

7 7 4 5 6 4 9 5
7 7 4 5 6 4 9 5
0


In [46]:
model.eval()
err = 0
num = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for i,(x, y) in enumerate(testset):    
    x = x.to(device)    
    y_pred = model.generate(x[None,...])
    hyp = ' '.join([str(i) for i in y_pred[1:-1]])
    # print('hyp', hyp)

    y = y.numpy().tolist()
    # find the first 3 <eos> in list y
    y = y[:y.index(3)]
    ref = ' '.join([str(i) for i in y[1:]])
    # print('ref', ref)
    # print('(%d/%d)' % (i, len(testset)) )
    
    # edit distance
    err += editdistance.eval(hyp, ref)
    num += len(ref.split())
    
print(f'error rate {err/num:.2%},  ({err}/{num})')

error rate 15.17%,  (1214/8000)
