In [None]:
import os
import json
import re

import pandas as pd

from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.rnn as rnn_utils

from torch.utils.data import Dataset, DataLoader

from google.colab import drive
# Mount Google Drive (Colab only). The dataset paths used later in this notebook
# are located under this mount point.
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
# Define vocabulary class
class Vocabulary:
    """Word-level vocabulary that maps lowercase word tokens to integer ids.

    Id 0 always belongs to the ``<UNK>`` placeholder; real words receive ids
    starting at 1, in the order in which they are first seen in the corpus.
    """

    def __init__(self, min_freq=1):
        # Words occurring fewer than `min_freq` times are left out of the vocabulary.
        self.min_freq = min_freq
        # Both lookup directions are seeded with the unknown-word placeholder.
        self.word2idx = {'<UNK>': 0}
        self.idx2word = {0: '<UNK>'}

    def build_vocab(self, texts):
        """Count the tokens of all `texts` and register the frequent ones.

        Args:
            texts: iterable of strings; they are joined with spaces, lowercased
                and split into word tokens before counting.
        """
        corpus = ' '.join(texts).lower()
        frequencies = Counter(re.findall(r'\w+', corpus))

        frequent_tokens = (
            token for token, freq in frequencies.items() if freq >= self.min_freq
        )
        # Numbering starts at 1 because id 0 is taken by <UNK>.
        for position, token in enumerate(frequent_tokens, start=1):
            self.word2idx[token] = position
            self.idx2word[position] = token

    def encode(self, text):
        """Return the list of word ids for `text`; unknown words become 0.

        TODO: find a smarter strategy for out-of-vocabulary words than mapping
        all of them to the single <UNK> id.
        """
        tokens = re.findall(r'\w+', text.lower())
        lookup = self.word2idx
        return [lookup.get(token, 0) for token in tokens]

    def vocab_size(self):
        """Return the number of entries in the vocabulary, <UNK> included."""
        return len(self.word2idx)


class FastTextVocabulary(Vocabulary):
    """Vocabulary that additionally indexes the character n-grams of every word.

    N-gram ids start at 1; id 0 is reserved and is what `encode_word` returns
    for an n-gram that was never registered.
    """

    def __init__(self, min_freq=1, ngram_range=(3, 6)):
        """
        Args:
            min_freq: minimum word frequency, forwarded to `Vocabulary`.
            ngram_range: inclusive (min_n, max_n) range of n-gram lengths.
        """
        super().__init__(min_freq)
        self.ngram_range = ngram_range
        self.ngram2idx = {}
        self.idx2ngram = {}
        # Next free n-gram id. Starts at 1 because 0 is the reserved id.
        self.ngram_count = 1

    def _get_ngrams(self, word):
        """Return all character n-grams of `word` wrapped in '<' and '>' markers."""
        ngrams = []
        word = f'<{word}>'
        for n in range(self.ngram_range[0], self.ngram_range[1] + 1):
            # range() is empty when the wrapped word is shorter than n.
            ngrams.extend([word[i:i+n] for i in range(len(word) - n + 1)])
        return ngrams

    def build_vocab(self, texts):
        """Build the word vocabulary, then register the n-grams of every kept word."""
        super().build_vocab(texts)

        for word, word_idx in self.word2idx.items():
            if word_idx == 0:
                # Id 0 is the <UNK> placeholder, not a real word: indexing its
                # n-grams would only add rows that can never be looked up by
                # lowercased tokens.
                continue
            for ngram in self._get_ngrams(word):
                if ngram not in self.ngram2idx:
                    self.ngram2idx[ngram] = self.ngram_count
                    self.idx2ngram[self.ngram_count] = ngram
                    self.ngram_count += 1

    def encode_word(self, word):
        """Return `(word_id, ngram_ids)` for a single word string.

        Unknown words and unknown n-grams are encoded as 0.
        """
        word_idx = self.word2idx.get(word, 0)
        ngram_idxs = [self.ngram2idx.get(ng, 0) for ng in self._get_ngrams(word)]
        return word_idx, ngram_idxs

    def ngram_vocab_size(self):
        """Return the number of rows an n-gram embedding table needs.

        Ids run from 0 (reserved) to len(ngram2idx), so the size is one more
        than the number of registered n-grams. Returning only len(ngram2idx)
        would make the highest id index out of bounds.
        """
        return len(self.ngram2idx) + 1

In [None]:
# Define dataset
class CustomColeridgeDataset(Dataset):
    """Skip-gram dataset: one item is the list of (center, context) pairs of one document.

    Args:
        csv_file: path to a CSV file with an 'Id' column.
        json_dir: directory holding one '<Id>.json' file per document; each file
            is a list of objects with a 'text' field.
        vocab: object providing `encode(text) -> list of word ids`.
        window_size: number of neighbours taken on each side of the center word.
        n_samples: number of CSV rows to sample.
        random_state: seed used for the sampling.
    """

    def __init__(self, csv_file, json_dir, vocab, window_size=2, n_samples=1000, random_state=42):
        self.train = pd.read_csv(csv_file)
        self.train_items = self.train.sample(n=n_samples, random_state=random_state)
        self.json_dir = json_dir
        self.vocab = vocab
        self.window_size = window_size

    def __len__(self):
        return len(self.train_items)

    def _load_text(self, idx):
        """Read the JSON file of sampled row `idx` and return its full text."""
        train_id = self.train_items.iloc[idx]['Id']
        curr_path = os.path.join(self.json_dir, train_id + '.json')

        with open(curr_path, 'r', encoding='utf-8') as file:
            sections = json.load(file)

        # Join with a space: with '' the last word of a section and the first
        # word of the next one would be fused into a single bogus token.
        return ' '.join(section['text'] for section in sections)

    def _window_pairs(self, items):
        """Return every (center, context) pair of `items` within the window."""
        pairs = []
        n_items = len(items)
        for i, center in enumerate(items):
            start = max(0, i - self.window_size)
            stop = min(n_items, i + self.window_size + 1)
            for j in range(start, stop):
                if i != j:
                    pairs.append((center, items[j]))
        return pairs

    def __getitem__(self, idx):
        """Return a list of `(center_word_id, context_word_id)` tuples."""
        word_indices = self.vocab.encode(self._load_text(idx))
        return self._window_pairs(word_indices)


class FastTextColeridgeDataset(CustomColeridgeDataset):
    """Variant whose pairs also carry the character n-gram ids of both words.

    `vocab` must additionally provide `encode_word(word) -> (word_id, ngram_ids)`.
    """

    def __getitem__(self, idx):
        """Return a list of `(center_id, center_ngrams, context_id, context_ngrams)`."""
        # `encode_word` needs the word *strings*. Passing it the integer ids
        # produced by `vocab.encode` (as was done before) yields word id 0 for
        # every word and n-grams of the id's decimal digits.
        # The tokenisation mirrors the one used by `Vocabulary.encode`.
        tokens = re.findall(r'\w+', self._load_text(idx).lower())

        # Encode each distinct word once; documents repeat words heavily.
        encoded_by_token = {}
        encoded = []
        for token in tokens:
            if token not in encoded_by_token:
                encoded_by_token[token] = self.vocab.encode_word(token)
            encoded.append(encoded_by_token[token])

        return [
            (center_word, center_ngrams, context_word, context_ngrams)
            for (center_word, center_ngrams), (context_word, context_ngrams)
            in self._window_pairs(encoded)
        ]


In [None]:
# Define loss function
def skipgram_loss(scores, true_labels):
    """Binary cross-entropy between raw pair scores (logits) and their labels.

    Args:
        scores: tensor of unnormalised scores.
        true_labels: tensor of target labels with the same shape as `scores`.

    Returns:
        The loss as produced by `nn.BCEWithLogitsLoss` with default settings.
    """
    criterion = nn.BCEWithLogitsLoss()
    return criterion(scores, true_labels)

In [None]:
class FastTextSkipGramModel(nn.Module):
    """Skip-gram scorer where a word vector is its word embedding plus the sum
    of its character n-gram embeddings.

    Args:
        vocab_size: number of rows of the word embedding table.
        ngram_vocab_size: number of rows of the n-gram embedding table; must be
            larger than the highest n-gram id that will be looked up.
        embedding_dim: dimensionality of all embeddings.
        padding_idx: n-gram id that marks padding/unknown n-grams. Positions
            holding this id contribute nothing to the word vector.
    """

    def __init__(self, vocab_size, ngram_vocab_size, embedding_dim, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.word_embeddings = nn.Parameter(torch.randn(vocab_size, embedding_dim) * 0.01)
        self.ngram_embeddings = nn.Parameter(torch.randn(ngram_vocab_size, embedding_dim) * 0.01)

    def _embed(self, word_idx, ngram_idxs):
        """Return word embedding + masked sum of n-gram embeddings, one row per word."""
        word_vectors = self.word_embeddings[word_idx]        # (batch, dim)
        ngram_vectors = self.ngram_embeddings[ngram_idxs]    # (batch, max_ngrams, dim)

        # Zero out padded positions. Without the mask, row `padding_idx` would be
        # added once per pad slot, so the vector of a word would change with the
        # length of the longest word that happens to share its batch.
        keep = (ngram_idxs != self.padding_idx).unsqueeze(-1).to(ngram_vectors.dtype)
        return word_vectors + (ngram_vectors * keep).sum(dim=1)

    def forward(self, center_word_idx, center_ngram_idxs, context_word_idx, context_ngram_idxs):
        """Score each (center, context) pair by the dot product of their vectors.

        Args:
            center_word_idx: long tensor (batch,) of center word ids.
            center_ngram_idxs: long tensor (batch, max_ngrams) of padded n-gram ids.
            context_word_idx: long tensor (batch,) of context word ids.
            context_ngram_idxs: long tensor (batch, max_ngrams) of padded n-gram ids.

        Returns:
            Float tensor (batch,) of unnormalised similarity scores (logits).
        """
        center_embedding = self._embed(center_word_idx, center_ngram_idxs)
        context_embedding = self._embed(context_word_idx, context_ngram_idxs)

        # Compute similarity (dot product) between center and context embeddings
        return torch.sum(center_embedding * context_embedding, dim=1)

In [None]:
# Load data
# Single source of truth for the data location and the sampling settings.
# Absolute paths are used throughout so the cell does not depend on the current
# working directory, and the same (N_SAMPLES, RANDOM_STATE) pair is passed to
# both the vocabulary sample and the dataset so they cover the same documents.
DATA_DIR = '/content/drive/My Drive/Datasets/Coleridge/datasets'
TRAIN_CSV = os.path.join(DATA_DIR, 'train.csv')
TRAIN_JSON_DIR = os.path.join(DATA_DIR, 'train')
N_SAMPLES = 1000
RANDOM_STATE = 42

vocab = FastTextVocabulary(min_freq=5)

train = pd.read_csv(TRAIN_CSV)
train_items = train.sample(n=N_SAMPLES, random_state=RANDOM_STATE)
texts = []

for train_id in train_items['Id']:
    curr_path = os.path.join(TRAIN_JSON_DIR, train_id + '.json')
    with open(curr_path, 'r', encoding='utf-8') as file:
        curr_json = json.load(file)
    # Join sections with a space so words at section boundaries are not fused.
    texts.append(' '.join(cj['text'] for cj in curr_json))

vocab.build_vocab(texts)

dataset = FastTextColeridgeDataset(
    csv_file=TRAIN_CSV,
    json_dir=TRAIN_JSON_DIR,
    vocab=vocab,
    n_samples=N_SAMPLES,
    random_state=RANDOM_STATE,
)


def flatten_pairs(documents):
    """Collate function: merge the per-document pair lists into one flat list."""
    return [pair for document_pairs in documents for pair in document_pairs]


dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=flatten_pairs)

In [None]:

vocab_size = vocab.vocab_size()
# N-gram ids start at 1 (0 is the padding/unknown id), so the embedding table
# needs (highest id + 1) rows; anything smaller makes the last n-gram index
# out of bounds.
ngram_vocab_size = max(vocab.ngram2idx.values(), default=0) + 1
embedding_dim = 100
model = FastTextSkipGramModel(vocab_size, ngram_vocab_size, embedding_dim)

optimizer = optim.Adam(model.parameters(), lr=0.01)

# NOTE(review): every pair is labelled 1 and no negative pairs are sampled, so
# the loss can be driven down simply by making all vectors similar. Consider
# adding negative sampling before relying on the learned embeddings.
for batch in dataloader:
    if not batch:
        # All documents of this batch were too short to yield any pair.
        continue

    center_word_idxs, center_ngram_idxs, context_word_idxs, context_ngram_idxs = zip(*batch)

    # Convert word indices to tensors
    center_word_idxs = torch.tensor(center_word_idxs, dtype=torch.long)
    context_word_idxs = torch.tensor(context_word_idxs, dtype=torch.long)

    # Words have different numbers of n-grams: convert each list to a tensor and
    # right-pad with 0 so they can be stacked into a (batch, max_ngrams) tensor.
    center_ngram_idxs = [torch.tensor(ngrams, dtype=torch.long) for ngrams in center_ngram_idxs]
    context_ngram_idxs = [torch.tensor(ngrams, dtype=torch.long) for ngrams in context_ngram_idxs]

    center_ngram_idxs_padded = rnn_utils.pad_sequence(center_ngram_idxs, batch_first=True, padding_value=0)
    context_ngram_idxs_padded = rnn_utils.pad_sequence(context_ngram_idxs, batch_first=True, padding_value=0)

    # Reset gradients: PyTorch accumulates them across backward() calls, so
    # without this every step would also apply all previous batches' gradients.
    optimizer.zero_grad()

    # Pass everything as tensors into the model
    scores = model(center_word_idxs, center_ngram_idxs_padded, context_word_idxs, context_ngram_idxs_padded)

    true_labels = torch.ones_like(scores)
    loss = skipgram_loss(scores, true_labels)
    loss.backward()
    optimizer.step()

    print(f'Loss: {loss.item()}')

Loss: 0.3760277330875397
