In [1]:
import numpy as np
import tensorflow as tf
from itertools import compress
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics.pairwise import pairwise_distances

In [2]:
# Download the Shakespeare corpus and keep the path of the local copy
path_to_file = tf.keras.utils.get_file(
    fname='shakespeare.txt',
    origin='https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt',
)

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
[1m1115394/1115394[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [4]:
# Read the corpus and tokenise it on whitespace only (punctuation stays attached
# to the words). The encoding is passed explicitly so decoding does not depend on
# the platform's default locale; the file is streamed line by line instead of
# being loaded into an intermediate list first.
with open(path_to_file, encoding='utf-8') as f:
    words = [word for line in f for word in line.split()]

print(f'Number of words: {len(words)}')

Number of words: 202651


In [12]:
import random
from collections import Counter, deque
from typing import List, Tuple

def prepareData(words: List[str], vocab_size: int = 50000):
    """
    Prepares the data for word vectorization by converting words to indices and creating dictionaries for word-to-index and index-to-word mappings.

    The vocabulary consists of the reserved '<unk>' token (always index 0) followed by the
    `vocab_size - 1` most frequent words, ordered by decreasing frequency. Every word outside
    the vocabulary is mapped to index 0. A literal '<unk>' token found in the corpus is treated
    as an unknown word as well, so it can never collide with the reserved entry.

    Parameters:
    words:      The corpus of words to be processed.
    vocab_size: The maximum size of the vocabulary, including the <unk> token. Default is 50,000.

    Returns:
    Tuple[List[int], List, dict, dict]:
        - data:        The corpus converted to a list of word indices.
        - count:       A list whose first entry is the list ['<unk>', unknown_count] and whose
                       remaining entries are (word, frequency) tuples.
        - dictionary:  A dictionary mapping words to their corresponding indices.
        - reverse_dictionary:  A dictionary mapping indices to their corresponding words.
    """
    unk_token = '<unk>'

    ## Count the real words only: a literal '<unk>' in the corpus must not get a second
    ## vocabulary slot, otherwise two words would end up sharing one index
    frequencies = Counter(word for word in words if word != unk_token)

    ## Rare words are replaced by <unk> token (its count is filled in below)
    count = [[unk_token, -1]]
    count.extend(frequencies.most_common(vocab_size - 1))

    ## Word to index mapping; the position in `count` is the index, so <unk> is 0
    dictionary = {word: index for index, (word, _) in enumerate(count)}

    ## Convert corpus to list of indices, falling back to 0 (<unk>) for unknown words
    data = [dictionary.get(word, 0) for word in words]
    unk_count = data.count(0)

    count[0][1] = unk_count
    reverse_dictionary = {index: word for word, index in dictionary.items()}

    return data, count, dictionary, reverse_dictionary



In [13]:
def skipgram (data: List[int], batch_size: int, num_skips: int, skip_window: int, data_index: int = 0):
    """
    Build one skip-gram training batch of (center word, context word) pairs.

    A window of 2 * skip_window + 1 consecutive words slides over `data`, starting at
    `data_index` and wrapping around at the end of the list. For every window position the
    center word is emitted `num_skips` times, each time paired with a different, randomly
    chosen word from the same window.

    Parameters:
    data:        List of word indices.
    batch_size:  Number of (input, label) pairs to produce; must be a multiple of num_skips.
    num_skips:   Number of pairs generated per center word; at most 2 * skip_window.
    skip_window: Number of words considered on each side of the center word.
    data_index:  Position in `data` of the first word of the first window. Default is 0.

    Returns:
    Tuple[np.ndarray, np.ndarray]: int32 array of shape (batch_size,) with the center words and
    int32 array of shape (batch_size, 1) with the matching context words.
    """
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window

    span = 2 * skip_window + 1
    total = len(data)
    centers = np.empty(batch_size, dtype=np.int32)
    contexts = np.empty((batch_size, 1), dtype=np.int32)

    # Pre-fill the sliding window with the first `span` words
    cursor = data_index
    window = deque(maxlen=span)
    for _ in range(span):
        window.append(data[cursor])
        cursor = (cursor + 1) % total

    row = 0
    for _ in range(batch_size // num_skips):
        # Window positions already used for this center word (the center itself included)
        taken = {skip_window}

        for _ in range(num_skips):
            # Rejection sampling: draw positions until an unused one comes up
            pick = random.randint(0, span - 1)
            while pick in taken:
                pick = random.randint(0, span - 1)
            taken.add(pick)

            centers[row] = window[skip_window]
            contexts[row, 0] = window[pick]
            row += 1

        # Slide the window one word forward (the oldest word drops out automatically)
        window.append(data[cursor])
        cursor = (cursor + 1) % total

    return centers, contexts

In [14]:
def cbow(data: List[int], batch_size: int, num_skips: int, skip_window: int, data_index: int = 0):
    """
    Generate a batch of data for the CBOW model.

    A window of 2 * skip_window + 1 consecutive words slides over `data`, starting at
    `data_index` and wrapping around at the end of the list. Each window yields one example:
    the words surrounding the center (in corpus order) are the input, the center word is
    the label.

    Parameters:
    data:        List of word indices.
    batch_size:  Number of examples (windows) in the batch.
    num_skips:   Kept for signature parity with `skipgram` and only validated. It does not
                 affect the output: every example always carries all 2 * skip_window
                 context words.
    skip_window: How many words to consider left and right.
    data_index:  Index to start with in the data list. Default is 0.

    Returns:
    Tuple[np.ndarray, np.ndarray]: int32 array of shape (batch_size, 2 * skip_window) with the
    context words and int32 array of shape (batch_size, 1) with the corresponding labels.
    """
    # Same argument validation as before, so existing callers see identical failures
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window

    window_size = 2 * skip_window + 1
    # Each row receives every window word except the center. Sizing the row by num_skips
    # made the row assignment fail whenever num_skips < 2 * skip_window.
    context_size = window_size - 1
    batch = np.ndarray(shape=(batch_size, context_size), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)

    # Create a buffer to store the data
    buffer = deque(maxlen=window_size)
    for _ in range(window_size):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    # Selects every position of the window except the center; identical for all windows
    mask = [1] * window_size
    mask[skip_window] = 0

    # Generates the batch of context words and labels
    for i in range(batch_size):
        batch[i] = list(compress(buffer, mask))
        labels[i, 0] = buffer[skip_window]

        # Move the window
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    return batch, labels

In [17]:
## Sanity check for the CBOW batch generator - run this cell to inspect a few examples
## Note: the corpus is only split on whitespace, so punctuation stays attached to the words

data, count, dictionary, reverse_dictionary = prepareData(words)

# Build one small batch
batch_size, num_skips, skip_window = 8, 4, 2
batch, labels = cbow(data, batch_size, num_skips, skip_window)

# Translate the first 5 examples back into words and show them
for context_ids, (target_id,) in zip(batch[:5], labels[:5]):
    context_words = [reverse_dictionary[idx] for idx in context_ids]
    target_word = reverse_dictionary[target_id]
    print(f"Context words: {context_words}, Target word: {target_word}")

Context words: ['First', 'Citizen:', 'we', 'proceed'], Target word: Before
Context words: ['Citizen:', 'Before', 'proceed', 'any'], Target word: we
Context words: ['Before', 'we', 'any', 'further,'], Target word: proceed
Context words: ['we', 'proceed', 'further,', 'hear'], Target word: any
Context words: ['proceed', 'any', 'hear', 'me'], Target word: further,
