In [114]:
import torch
import torch.nn as nn
import pandas as pd
import re
import string
from torch.utils.data import Dataset,DataLoader

In [115]:
class CustomDataset(Dataset):
    """Next-word-prediction dataset built from a delimited file with a ``text`` column.

    Each item is one sentence. ``input_ids`` holds the vocabulary index of every token,
    and ``labels_ids`` holds the index of the token that follows it. The last token is
    paired with the empty-string token ``''``, which marks the end of the sentence.
    """

    # Matches any ASCII punctuation character. re.escape is required: string.punctuation
    # contains "\" and "]". Spliced raw into "[...]", the "\]" becomes an escaped bracket,
    # so backslashes were silently left in the text.
    _PUNCTUATION_RE = re.compile('[' + re.escape(string.punctuation) + ']')

    def __init__(self, csv_file, sep='\t'):
        """Load the corpus and build the vocabulary.

        Args:
            csv_file: path, URL or file-like object accepted by ``pandas.read_csv``.
            sep: column separator (tab by default).
        """
        self.data = pd.read_csv(csv_file, sep=sep)
        self.word2index = {}  # word -> index (integer)
        self.index2word = {}  # index (integer) -> word
        self.build_vocab()

    @classmethod
    def _tokenize(cls, text):
        """Lower-case ``text``, strip ASCII punctuation and split on whitespace."""
        return cls._PUNCTUATION_RE.sub('', text.lower()).split()

    def build_vocab(self):
        """Fill ``word2index``/``index2word`` from every sentence, plus the '' end token.

        The '' token always gets the last index (``len(vocabulary) - 1``).
        """
        words = self._tokenize(' '.join(self.data['text']))
        # The vocabulary must not contain duplicate words. sorted() makes the indices
        # reproducible; iterating a set of str depends on hash randomization.
        unique_words = sorted(set(words))
        self.word2index = {word: index for index, word in enumerate(unique_words)}
        self.index2word = {index: word for word, index in self.word2index.items()}
        end_index = len(self.word2index)
        self.word2index[''] = end_index
        self.index2word[end_index] = ''

    def __len__(self):
        """Number of sentences (rows) in the corpus."""
        return len(self.data)

    def __getitem__(self, index):
        """Preprocess sentence ``index`` and return its token and next-token ids.

        Returns:
            dict with ``'input_ids'`` and ``'labels_ids'``. Both are 1-D ``torch.long``
            tensors of equal length; they are empty if the sentence has no words.
        """
        tokens = self._tokenize(self.data.iloc[index]['text'])
        # Each token's label is the following token; the last one is followed by ''.
        labels = tokens[1:] + [''] if tokens else []

        input_ids = [self.word2index[token] for token in tokens]
        label_ids = [self.word2index[label] for label in labels]

        # Explicit dtype: torch.tensor([]) would default to float32 and break nn.Embedding.
        return {'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'labels_ids': torch.tensor(label_ids, dtype=torch.long)}

In [160]:
class RNN(nn.Module):
    """Word-level language model: embedding -> nn.RNN (ReLU) -> linear layer producing vocabulary logits."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, batch_first=True):
        """Build the layers.

        Args:
            vocab_size: number of words; size of the embedding table and of the output logits.
            embedding_dim: dimension of each word vector.
            hidden_dim: dimension of the RNN internal state.
            num_layers: number of stacked RNN layers.
            batch_first: if True (default), inputs are (batch, seq_len) index tensors.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True matters because the DataLoader yields (batch, seq_len) tensors.
        # With nn.RNN's default (False), that layout was read as (seq_len, batch). Every
        # token then became an independent length-1 sequence and no context was carried.
        self.rnn = nn.RNN(embedding_dim, hidden_dim, num_layers,
                          nonlinearity='relu', batch_first=batch_first)
        self.fcl = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers  # read by init_hidden
        self.batch_first = batch_first

    def forward(self, word, hidden=None):
        """Return ``(logits, hidden)`` for a batch of token-index sequences.

        ``hidden`` is an optional initial state of shape (num_layers, batch, hidden_dim);
        None means zeros. The logits are raw scores; no softmax is applied here.
        """
        x = self.embedding(word)
        output, hidden = self.rnn(x, hidden)
        output = self.fcl(output)
        return output, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero initial hidden state of shape (num_layers, batch_size, hidden_dim)."""
        return torch.zeros(self.num_layers, batch_size, self.hidden_dim,
                           device=self.fcl.weight.device)

Hiperparâmetros

In [117]:
# Seed the global torch RNG so weight init (and any DataLoader shuffling) reproduces on Restart & Run All.
seed = 42
torch.manual_seed(seed)

embedding_dim = 128   # size of each word vector
hidden_dim = 256      # size of the RNN internal state
num_layers = 1        # number of stacked RNN layers
# NOTE(review): 1e-5 is very small for Adam (its default is 1e-3); consider raising it if the loss plateaus.
learning_rate = 0.00001
num_epochs = 50

In [118]:
# Corpus: one sentence per row in the `text` column of a tab-separated file.
corpus_url = 'https://raw.githubusercontent.com/giacicunb/enap_pln2024/main/corpora/simple_corpus.csv'
corpus = CustomDataset(corpus_url)
# batch_size=1: sentences have different lengths and the default collate_fn stacks tensors,
# which requires equal lengths. shuffle=True visits sentences in a fresh order every epoch
# instead of always following the file order.
dataloader = DataLoader(corpus, batch_size=1, shuffle=True)

In [119]:
# One entry per distinct corpus word plus the '' end-of-sentence token.
vocab_size = len(corpus.index2word)
print(vocab_size)

362


In [None]:
# Peek at the first vocabulary entries instead of dumping all of them (hundreds of words).
dict(list(corpus.word2index.items())[:10])

In [162]:
# Use the num_layers hyperparameter instead of a hardcoded 1 so the config cell actually controls depth.
language_model = RNN(vocab_size, embedding_dim, hidden_dim, num_layers=num_layers)

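`nn.CrossEntropyLoss` aplica implicitamente `log_softmax` sobre os logits (a saída bruta da camada linear do modelo); por isso o modelo não possui uma camada softmax explícita.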

In [163]:
# Takes raw logits (log-softmax is applied internally) and averages the loss over all target tokens.
loss_function = nn.CrossEntropyLoss(reduction='mean')

In [164]:
# Adam over every trainable parameter of the language model.
optimizer = torch.optim.Adam(params=language_model.parameters(), lr=learning_rate)

In [165]:
# Explicit training mode (a no-op for this model, which has no dropout/batch-norm, but keeps intent clear).
language_model.train()

for epoch in range(num_epochs):

    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:

        tokens, labels = batch['input_ids'], batch['labels_ids']

        # A sentence with no words left after punctuation removal can't be fed to nn.RNN
        # (zero-length sequence), and a mean loss over zero targets would be NaN.
        if tokens.numel() == 0:
            continue

        optimizer.zero_grad()

        output, _ = language_model(tokens)

        # Flatten logits to (N, vocab_size) and targets to (N,) as CrossEntropyLoss expects.
        output_flat = output.reshape(-1, output.shape[-1])
        labels_flat = labels.reshape(-1)

        loss = loss_function(output_flat, labels_flat)

        loss.backward()

        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Average over the batches actually trained on; max() guards an all-empty corpus.
    print(f'Epoch {epoch+1} ======= Loss: {total_loss/max(num_batches, 1)}')



In [170]:
def predict_next_word(model, input_text, word2index=None, index2word=None):
    """Return the word the model scores highest as the continuation of ``input_text``.

    The text is normalised like the training corpus (lower-cased, ASCII punctuation
    removed, split on whitespace). It is mapped to vocabulary indices and passed to
    ``model`` as one sequence of shape (1, num_words). The prediction is the arg-max
    of the logits at the last position.

    Args:
        model: callable returning ``(logits, hidden)``, where ``logits[:, -1]`` holds
            the vocabulary scores for the last input position.
        input_text: prompt text.
        word2index: word -> index mapping; defaults to ``corpus.word2index``.
        index2word: index -> word mapping; defaults to ``corpus.index2word``.

    Returns:
        The predicted word ('' is the end-of-sentence token used in training).

    Raises:
        ValueError: if no words remain after normalisation.
        KeyError: if any word is missing from the vocabulary.
    """
    if word2index is None:
        word2index = corpus.word2index
    if index2word is None:
        index2word = corpus.index2word

    # re.escape is needed so the "\" in string.punctuation is removed too.
    punctuation = '[' + re.escape(string.punctuation) + ']'
    words = re.sub(punctuation, '', input_text.lower()).split()
    if not words:
        raise ValueError(f'no words left in {input_text!r} after removing punctuation')

    unknown = [word for word in words if word not in word2index]
    if unknown:
        raise KeyError(f'words not in the vocabulary: {unknown}')

    input_tensor = torch.tensor([[word2index[word] for word in words]], dtype=torch.long)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output, _ = model(input_tensor)

    predicted_index = output[:, -1].argmax(dim=-1)
    return index2word[predicted_index.item()]

In [172]:
# Single-word prompt for a quick qualitative check of the trained model.
input_text = "cacau"
predicted_word = predict_next_word(model=language_model, input_text=input_text)
print(f'A proxima palavra apos {input_text} eh {predicted_word}')

A proxima palavra apos cacau eh algumas
