In [1]:
# Read data.txt

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

# Read data.txt line by line
with open('data.txt', 'r') as f:
    lines = f.readlines()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Create a dataclass from torch
from torch.utils.data import Dataset, DataLoader

# Create a class for the dataset
class WordEmbeddingDataset(Dataset):
    def __init__(self, lines, context_size=2):
        # Transform args to self
        for key, value in locals().items():
            if key != 'self':
                setattr(self, key, value)
                
        # Read data.txt line by line
        with open('data.txt', 'r') as f:
            self.lines = f.readlines()

        # Process lines
        self.process_lines()

        # Concatenate all lines
        self.text = ' '.join(self.lines)

        # Create a list of words
        self.words = self.text.split()
        self.vocab_size = len(set(self.words))


        # Create a dictionary of words
        # for one-hot encoding
        self.word2idx = {word: idx for idx, word in enumerate(set(self.words))}
        self.idx2word = {idx: word for idx, word in enumerate(set(self.words))}

        # Convert words to vectors with one-hot encoding
        self.words = [self.one_hot_encode(word) for word in tqdm(self.words, desc='One-hot encoding')]
        
        data = []
        target = []
        # Create a list of tuples
        # (next_word, [context_words])
        for i in tqdm(range(context_size, len(self.words) - context_size), desc='Preparing data'):
            context = []
            for j in range(i - context_size, i):
                context.append(self.words[j])
            data.append(context)
            target.append(self.words[i])

        # Convert to numpy array, this makes it faster aparently
        data = np.array(data)
        target = np.array(target)

        # Here data has the shape of (n_samples, context_size, vocab_size)
        # Target has the shape of (n_samples, vocab_size)
        # Let's reshape data to (n_samples, context_size * vocab_size)
        data = data.reshape(data.shape[0], data.shape[1] * data.shape[2])
            
        # Convert to torch tensor
        self.data = torch.tensor(data, dtype=torch.float32).to(device)
        self.target = torch.tensor(target, dtype=torch.float32).to(device)


    def process_lines(self):
        self.lines = [line.lower() for line in self.lines]
        self.lines = [line.replace('\n', '') for line in self.lines]
        self.lines = [''.join([c for c in line if c.isalnum() or c == ' ']) for line in self.lines]
            
    # One-hot encoding
    def one_hot_encode(self, word):
        x = np.zeros(len(self.word2idx))
        x[self.word2idx[word]] = 1
        return x

    def one_hot_decode(self, x):
        return self.idx2word[np.argmax(x)]

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

    def get_all(self):
        return self.data, self.target

In [3]:
# Create an instance of the dataset
dataset = WordEmbeddingDataset(lines, context_size=5)

One-hot encoding: 100%|██████████| 24042/24042 [00:00<00:00, 67706.98it/s]
Preparing data: 100%|██████████| 24032/24032 [00:00<00:00, 681732.20it/s]


In [4]:
# Inicialize our model
from torch import nn

# Create a class for the model
class WordEmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super(WordEmbeddingModel, self).__init__()
        
        # This will be a NN with 3 layers
        # 1. Input layer
        # 2. Hidden layer
        # 3. Output layer

        # Input layer
        self.lay1 = nn.Linear(vocab_size * context_size, 128)
        # Activation function
        self.relu1 = nn.ReLU()

        self.lay2 = nn.Linear(128, embedding_dim)
        # Activation function
        self.relu2 = nn.ReLU()

        self.lay3 = nn.Linear(embedding_dim, vocab_size)
        # Activation function
        self.relu3 = nn.ReLU()

        # Output layer
        self.lay4 = nn.Linear(vocab_size, vocab_size)
        # Activation function
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.lay1(x)
        x = self.relu1(x)

        x = self.lay2(x)
        x = self.relu2(x)

        x = self.lay3(x)
        x = self.relu3(x)

        x = self.lay4(x)
        x = self.softmax(x)

        return x

    def embedd(self, x):
        x = self.lay1(x)
        x = self.relu1(x)

        x = self.lay2(x)
        x = self.relu2(x)

        return x

# Create an instance of the model
# which will try to predict the next word
# given a word
model = WordEmbeddingModel(vocab_size=dataset.vocab_size, embedding_dim=20, context_size=5).to(device)

# Create a loss function
loss_fn = nn.MSELoss()

# Create an optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Create a data loader
dataloader = DataLoader(dataset, batch_size=100, shuffle=False)

# Train the model
max_patience = 10
patience = 0

previous_loss = np.inf
for epoch in range(100):
    for i, batch in tqdm(enumerate(dataloader), desc=f'Epoch {epoch + 1}', total=len(dataloader), leave=True):
        data, target = batch
        
        # Forward pass
        y_pred = model(data)

        # Compute loss
        loss = loss_fn(y_pred, target)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch: {epoch + 1}, Loss: {loss.item():.8f}')

    # Save
    if loss.item() < previous_loss:
        previous_loss = loss.item()
        print('Saving model...')
        torch.save(model.state_dict(), f'model.pth')
        patience = 0

    # Early stopping
    else:
        patience += 1
        if patience == max_patience:
            print('Early stopping...')
            break


Epoch 1: 100%|██████████| 241/241 [00:02<00:00, 87.19it/s] 


Epoch: 1, Loss: 0.00027107
Saving model...


Epoch 2: 100%|██████████| 241/241 [00:02<00:00, 113.22it/s]


Epoch: 2, Loss: 0.00027485


Epoch 3: 100%|██████████| 241/241 [00:02<00:00, 113.22it/s]


Epoch: 3, Loss: 0.00027394


Epoch 4: 100%|██████████| 241/241 [00:02<00:00, 113.25it/s]


Epoch: 4, Loss: 0.00026483
Saving model...


Epoch 5: 100%|██████████| 241/241 [00:02<00:00, 113.20it/s]


Epoch: 5, Loss: 0.00026713


Epoch 6: 100%|██████████| 241/241 [00:02<00:00, 113.17it/s]


Epoch: 6, Loss: 0.00026533


Epoch 7: 100%|██████████| 241/241 [00:02<00:00, 113.16it/s]


Epoch: 7, Loss: 0.00025007
Saving model...


Epoch 8: 100%|██████████| 241/241 [00:02<00:00, 113.13it/s]


Epoch: 8, Loss: 0.00024246
Saving model...


Epoch 9: 100%|██████████| 241/241 [00:02<00:00, 113.08it/s]


Epoch: 9, Loss: 0.00023827
Saving model...


Epoch 10: 100%|██████████| 241/241 [00:02<00:00, 113.24it/s]


Epoch: 10, Loss: 0.00022381
Saving model...


Epoch 11: 100%|██████████| 241/241 [00:02<00:00, 113.18it/s]


Epoch: 11, Loss: 0.00022083
Saving model...


Epoch 12: 100%|██████████| 241/241 [00:02<00:00, 113.17it/s]


Epoch: 12, Loss: 0.00022672


Epoch 13: 100%|██████████| 241/241 [00:02<00:00, 113.12it/s]


Epoch: 13, Loss: 0.00022082
Saving model...


Epoch 14: 100%|██████████| 241/241 [00:02<00:00, 113.17it/s]


Epoch: 14, Loss: 0.00022307


Epoch 15: 100%|██████████| 241/241 [00:02<00:00, 113.15it/s]


Epoch: 15, Loss: 0.00022091


Epoch 16: 100%|██████████| 241/241 [00:02<00:00, 113.13it/s]


Epoch: 16, Loss: 0.00022189


Epoch 17: 100%|██████████| 241/241 [00:02<00:00, 113.11it/s]


Epoch: 17, Loss: 0.00021555
Saving model...


Epoch 18: 100%|██████████| 241/241 [00:02<00:00, 113.13it/s]


Epoch: 18, Loss: 0.00021121
Saving model...


Epoch 19: 100%|██████████| 241/241 [00:02<00:00, 113.09it/s]


Epoch: 19, Loss: 0.00021241


Epoch 20: 100%|██████████| 241/241 [00:02<00:00, 113.14it/s]


Epoch: 20, Loss: 0.00020656
Saving model...


Epoch 21: 100%|██████████| 241/241 [00:02<00:00, 113.09it/s]


Epoch: 21, Loss: 0.00020472
Saving model...


Epoch 22: 100%|██████████| 241/241 [00:02<00:00, 113.12it/s]


Epoch: 22, Loss: 0.00020794


Epoch 23: 100%|██████████| 241/241 [00:02<00:00, 113.05it/s]


Epoch: 23, Loss: 0.00020514


Epoch 24: 100%|██████████| 241/241 [00:02<00:00, 113.10it/s]


Epoch: 24, Loss: 0.00020553


Epoch 25: 100%|██████████| 241/241 [00:02<00:00, 113.11it/s]


Epoch: 25, Loss: 0.00020261
Saving model...


Epoch 26: 100%|██████████| 241/241 [00:02<00:00, 113.16it/s]


Epoch: 26, Loss: 0.00020574


Epoch 27: 100%|██████████| 241/241 [00:02<00:00, 113.14it/s]


Epoch: 27, Loss: 0.00020217
Saving model...


Epoch 28: 100%|██████████| 241/241 [00:02<00:00, 113.20it/s]


Epoch: 28, Loss: 0.00020226


Epoch 29: 100%|██████████| 241/241 [00:02<00:00, 113.16it/s]


Epoch: 29, Loss: 0.00020184
Saving model...


Epoch 30: 100%|██████████| 241/241 [00:02<00:00, 113.16it/s]


Epoch: 30, Loss: 0.00020135
Saving model...


Epoch 31: 100%|██████████| 241/241 [00:02<00:00, 113.14it/s]


Epoch: 31, Loss: 0.00020248


Epoch 32: 100%|██████████| 241/241 [00:02<00:00, 113.07it/s]


Epoch: 32, Loss: 0.00019887
Saving model...


Epoch 33: 100%|██████████| 241/241 [00:02<00:00, 113.12it/s]


Epoch: 33, Loss: 0.00020113


Epoch 34: 100%|██████████| 241/241 [00:02<00:00, 113.00it/s]


Epoch: 34, Loss: 0.00019870
Saving model...


Epoch 35: 100%|██████████| 241/241 [00:02<00:00, 113.02it/s]


Epoch: 35, Loss: 0.00021201


Epoch 36: 100%|██████████| 241/241 [00:02<00:00, 113.02it/s]


Epoch: 36, Loss: 0.00022449


Epoch 37: 100%|██████████| 241/241 [00:02<00:00, 113.01it/s]


Epoch: 37, Loss: 0.00020745


Epoch 38: 100%|██████████| 241/241 [00:02<00:00, 113.01it/s]


Epoch: 38, Loss: 0.00020399


Epoch 39: 100%|██████████| 241/241 [00:02<00:00, 112.93it/s]


Epoch: 39, Loss: 0.00020416


Epoch 40: 100%|██████████| 241/241 [00:02<00:00, 112.97it/s]


Epoch: 40, Loss: 0.00020468


Epoch 41: 100%|██████████| 241/241 [00:02<00:00, 112.85it/s]


Epoch: 41, Loss: 0.00020406


Epoch 42: 100%|██████████| 241/241 [00:02<00:00, 112.86it/s]


Epoch: 42, Loss: 0.00021894


Epoch 43: 100%|██████████| 241/241 [00:02<00:00, 112.93it/s]


Epoch: 43, Loss: 0.00021604


Epoch 44: 100%|██████████| 241/241 [00:15<00:00, 15.38it/s]


Epoch: 44, Loss: 0.00022354
Early stopping...


In [6]:
# Load the model
model.load_state_dict(torch.load('model.pth'))

<All keys matched successfully>

# Analysing the results of the word embedding

In [11]:
# Get top words
words = dataset.words
words = [dataset.one_hot_decode(word) for word in words]

# Predict next word
def predict_next_word(context):
    # Convert to one-hot
    context = [dataset.one_hot_encode(word) for word in context]
    # Flatten
    context = [item for sublist in context for item in sublist]
    # Convert to tensor
    context = torch.tensor(context, dtype=torch.float32).to(device)
    # Reshape
    context = context.reshape(1, context.shape[0])
    # Predict
    y_pred = model(context)
    # Get the word
    word = words[np.argmax(y_pred.detach().cpu().numpy())]
    return word

In [13]:
predict_next_word(['inform', 'the', 'commander', 'that', 'lord'])

'blob'

In [14]:
predict_next_word(['the', 'commander', 'that', 'lord', 'blob'])

'coughing'

In [15]:
predict_next_word(['commander', 'that', 'lord', 'blob', 'coughing'])

'the'

In [19]:
# Embedd a sentence
def embedd_sentence(sentence):
    # Convert to one-hot
    sentence = [dataset.one_hot_encode(word) for word in sentence]
    # Flatten
    sentence = [item for sublist in sentence for item in sublist]
    # Convert to tensor
    sentence = torch.tensor(sentence, dtype=torch.float32).to(device)
    # Reshape
    sentence = sentence.reshape(1, sentence.shape[0])
    # Embedd
    embedding = model.embedd(sentence)
    return embedding

embedd_sentence(['inform', 'the', 'commander', 'that', 'lord'])

tensor([[0.0000, 0.1742, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 3.1596, 0.0000, 0.0000,
         0.0000, 3.2487]], device='cuda:0', grad_fn=<ReluBackward0>)