In [1]:
import torch
from torch.utils.data import DataLoader

from utils.preproc_functions import *
from utils.model import Word2Vec, CustomDataset
from utils.embedding_trainer import train_word2vec

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Vladimir\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
# Load lenta data from source.
# NUM_DOCS is passed to get_corpus as `num_doc`; presumably it caps the number of
# documents loaded (TODO confirm against get_corpus). Lower it for quick experiments.
NUM_DOCS = 1500
corpus = get_corpus(num_doc=NUM_DOCS)

In [3]:
# Prepare training data and vocab from corpus.
# `corpus_prepros` preprocesses the raw corpus; `data_preparation` then builds the
# training samples and the vocabulary for the CBOW variant (method='cbow').
preproc_corpus = corpus_prepros(corpus)
data, vocab = data_preparation(preproc_corpus, method='cbow')

# Fail fast: an empty dataset or vocabulary would otherwise only show up later as an
# obscure error while splitting the data or training the model.
assert len(data) > 0, "data_preparation returned no training samples"
assert len(vocab) > 0, "data_preparation returned an empty vocabulary"

100%|██████████| 1500/1500 [00:09<00:00, 153.82it/s]


In [4]:
# --- Configuration: everything a reader might want to tune ---
VOCAB_SIZE = len(vocab)
EMBEDDING_DIM = 10
BATCH_SIZE = 64
TRAIN_FRACTION = 0.90  # leading share of `data` used for training; the rest is validation
EPOCHS = 10
LEARNING_RATE = 0.001
SEED = 42

# Seed torch's global RNG so that DataLoader shuffling (and any torch-based parameter
# initialisation inside Word2Vec) repeats across "Restart Kernel -> Run All".
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Split the data into training and validation sets.
# The split is positional: the first TRAIN_FRACTION of samples train the model,
# the tail is held out for validation.
split_index = int(len(data) * TRAIN_FRACTION)
train_dataset = CustomDataset(data[:split_index])
valid_dataset = CustomDataset(data[split_index:])

# Define dataloaders.
# Training batches are shuffled and the last incomplete batch is dropped;
# validation keeps its order and every sample.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Define and train model
model = Word2Vec(VOCAB_SIZE, EMBEDDING_DIM)
model = train_word2vec(
    model,
    train_dataloader,
    valid_dataloader,
    device,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
)

# Get result dict: word -> embedding vector (numpy array).
# NOTE(review): this assumes the model's first parameter is the embedding matrix with
# one row per vocabulary entry, and that the iteration order of `vocab` matches the row
# indices used during training -- confirm against Word2Vec / data_preparation.
params = list(model.parameters())
word_vectors = params[0].detach()
unique_words = list(vocab.keys())

# zip() would silently truncate on a length mismatch and yield a wrong/partial mapping,
# so check the row count explicitly.
assert word_vectors.shape[0] == len(unique_words), (
    f"first model parameter has {word_vectors.shape[0]} rows, "
    f"but the vocabulary has {len(unique_words)} words"
)

# Move the whole matrix to host memory once instead of once per word.
word_matrix = word_vectors.cpu().numpy()
word_dict = dict(zip(unique_words, word_matrix))

Epoch 1/10: Train Loss:  7.965456402063736 ||| Validation Loss:  7.476635755341628
Epoch 2/10: Train Loss:  7.044787921297568 ||| Validation Loss:  7.16904864804498
Epoch 3/10: Train Loss:  6.737176429657709 ||| Validation Loss:  7.042442236275509
Epoch 4/10: Train Loss:  6.560011132887798 ||| Validation Loss:  6.978517044001612
Epoch 5/10: Train Loss:  6.439085873590637 ||| Validation Loss:  6.94392278605494
Epoch 6/10: Train Loss:  6.345834787540172 ||| Validation Loss:  6.925826810968333
Epoch 7/10: Train Loss:  6.26879526576322 ||| Validation Loss:  6.913416102836872
Epoch 8/10: Train Loss:  6.201809193864579 ||| Validation Loss:  6.903815118197737
Epoch 9/10: Train Loss:  6.142847096132609 ||| Validation Loss:  6.8980178339727996
Epoch 10/10: Train Loss:  6.0906869780999 ||| Validation Loss:  6.89230867254323
