In [1]:
from pathlib import Path

import torch
import torchtext
import torchtext.transforms as T
from torch import nn
from torch.utils.data import DataLoader
from torchlake.common.schemas.nlp import NlpContext
from torchlake.common.utils.platform import get_file_size, get_num_workers
from torchlake.common.utils.text import (build_vocab, get_context,
                                         get_unigram_counts, is_corpus_title,
                                         is_longer_text)
from torchlake.representation.models import HellingerPCA
from torchlake.representation.models.hellinger.helper import \
    CoOccurrenceCounter
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WikiText2, WikiText103

To utilize this notebook, you have to install `portalocker` first

In [2]:
torch.__version__

'2.1.0+cu118'

In [3]:
torchtext.__version__

'0.16.0+cpu'

# setting


In [4]:
DATASET_NAME = "WikiText2"

In [5]:
artifact_path = Path('../../artifacts/hellinger-pca')
artifact_path.mkdir(exist_ok=True)

artifact_dataset_path = artifact_path / DATASET_NAME
artifact_dataset_path.mkdir(exist_ok=True)

data_path = Path('../../data') / DATASET_NAME

In [6]:
CONTEXT = NlpContext(device="cpu")

In [7]:
BATCH_SIZE = 32
CONTEXT_SIZE = 2

In [8]:
DEVICE = torch.device(CONTEXT.device)

In [9]:
tokenizer = get_tokenizer('basic_english')

In [10]:
def datapipe_factory(datapipe, context_size: int = 1, transform = None):
    datapipe = (
        datapipe
        .map(lambda text: text.strip())
        .map(lambda text: text.lower())
        .filter(lambda text: is_longer_text(text, context_size))
        .filter(lambda text: not is_corpus_title(text))
        .map(tokenizer)
    )

    if transform:
      datapipe = datapipe.map(transform)

    return datapipe

In [11]:
train_datapipe, val_datapipe, test_datapipe = WikiText2(data_path.as_posix())

In [12]:
cloned_train_datapipe = datapipe_factory(train_datapipe, CONTEXT_SIZE)
vocab = build_vocab(cloned_train_datapipe)



In [13]:
VOCAB_SIZE = len(vocab)
VOCAB_SIZE

20351

In [14]:
# write_json_file(
#     artifact_dataset_path.joinpath("vocab.json"),
#     list(vocab.get_stoi().keys()),
# )

In [15]:
vocab.lookup_tokens(range(20))

['<unk>',
 '<bos>',
 '<eos>',
 '<pad>',
 'the',
 ',',
 '.',
 'of',
 'and',
 'in',
 'to',
 'a',
 'was',
 "'",
 '@-@',
 'on',
 'as',
 's',
 'that',
 'for']

In [16]:
text_transform = T.Sequential(
    T.VocabTransform(vocab),
    T.Truncate(CONTEXT.max_seq_len - 2),
    T.AddToken(token=CONTEXT.bos_idx, begin=True),
    T.AddToken(token=CONTEXT.eos_idx, begin=False),
    T.ToTensor(),
    T.PadTransform(CONTEXT.max_seq_len, CONTEXT.padding_idx),
)

In [17]:
train_datapipe = datapipe_factory(
    train_datapipe,
    CONTEXT_SIZE,
    text_transform,
)

val_datapipe = datapipe_factory(
    val_datapipe,
    CONTEXT_SIZE,
    text_transform,
)

test_datapipe = datapipe_factory(
    test_datapipe,
    CONTEXT_SIZE,
    text_transform,
)



In [18]:
NUM_WORKERS = get_num_workers()

In [19]:
collate_fn = lambda data: get_context(
    data,
    CONTEXT_SIZE - 1,
    0,
    enable_symmetric_context=False,
    flatten_output=True,
)

train_dataloader = DataLoader(
    train_datapipe,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

val_dataloader = DataLoader(
    val_datapipe,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

test_dataloader = DataLoader(
    test_datapipe,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

In [20]:
# check context size
# for data in train_datapipe:
#     if len(data) < 5:
#         print(data)
#         break

In [21]:
# sample
# for data in train_datapipe:
#     if len(data) > 5:
#         print(data)
#         print(vocab.lookup_tokens(data))
#         break

In [22]:
# count = 0
# for data in train_datapipe:
#   count += len(data)
# count

In [23]:
# number of words in training corpurs
# wikitext2 count 1,993,228
# wikitext103 count 101,227,641

In [24]:
for gram, context in train_dataloader:
    print(gram.shape, context.shape)
    break

torch.Size([8160, 1]) torch.Size([8160, 1])


# training


In [25]:
word_counts = get_unigram_counts(
    map(lambda x: x.tolist(), train_datapipe),
    VOCAB_SIZE,
).to(DEVICE)



In [26]:
sum(word_counts)

tensor(4485120)

In [27]:
counter = CoOccurrenceCounter(VOCAB_SIZE, CONTEXT.padding_idx)

for gram, context in train_dataloader:
    counter.update_counts(gram, context)

In [28]:
for (g, c), count in counter.counts.most_common(5):
    print(g, c, count)

4 7 16928
2 6 14193
17 13 13596
4 9 11584
5 0 11460


In [29]:
# in paper, they used 10000 context words from 170000 vocabulary
model = HellingerPCA(VOCAB_SIZE, maximum_context_size=VOCAB_SIZE // 20).to(DEVICE)

In [30]:
model.fit(counter, word_counts)

Step 1: Build co-occurrence matrix
Step 2: Select most significant context words
Step 1: Compute the kernel matrix
Step 2: Center the kernel matrix
Step 3: Eigenvalue decomposition


In [30]:
model.fit(counter, word_counts)

Step 1: Build co-occurrence matrix
Step 2: Select most significant context words
Step 1: Compute the kernel matrix
Step 2: Center the kernel matrix
Step 3: Eigenvalue decomposition


In [31]:
model.get_embedding().shape

torch.Size([20351, 50])

In [32]:
model.transform(next(iter(train_datapipe))).shape

torch.Size([256, 50])

# Save & Load

In [33]:
model_path = artifact_path / f"hellinger-pca.{DATASET_NAME}.pth"

In [34]:
model.save(model_path)

In [35]:
get_file_size(model_path, "G")

'1.62GB'

In [36]:
model.load(model_path)

# Evaluation


## word analogy


In [37]:
from torchlake.language_model.datasets import WordAnalogyDataset 

In [38]:
analogy_data_path = Path('../../data') / "word-analogy"

In [39]:
dataset = WordAnalogyDataset(analogy_data_path)

In [40]:
dataset.data[:20]

[('Athens', 'Greece', 'Baghdad', 'Iraq'),
 ('Athens', 'Greece', 'Bangkok', 'Thailand'),
 ('Athens', 'Greece', 'Beijing', 'China'),
 ('Athens', 'Greece', 'Berlin', 'Germany'),
 ('Athens', 'Greece', 'Bern', 'Switzerland'),
 ('Athens', 'Greece', 'Cairo', 'Egypt'),
 ('Athens', 'Greece', 'Canberra', 'Australia'),
 ('Athens', 'Greece', 'Hanoi', 'Vietnam'),
 ('Athens', 'Greece', 'Havana', 'Cuba'),
 ('Athens', 'Greece', 'Helsinki', 'Finland'),
 ('Athens', 'Greece', 'Islamabad', 'Pakistan'),
 ('Athens', 'Greece', 'Kabul', 'Afghanistan'),
 ('Athens', 'Greece', 'London', 'England'),
 ('Athens', 'Greece', 'Madrid', 'Spain'),
 ('Athens', 'Greece', 'Moscow', 'Russia'),
 ('Athens', 'Greece', 'Oslo', 'Norway'),
 ('Athens', 'Greece', 'Ottawa', 'Canada'),
 ('Athens', 'Greece', 'Paris', 'France'),
 ('Athens', 'Greece', 'Rome', 'Italy'),
 ('Athens', 'Greece', 'Stockholm', 'Sweden')]

In [41]:
tokens = [list(vocab[word.lower()] for word in pairs) for pairs in dataset.data]
tokens[:10]

[[12958, 6970, 18304, 5313],
 [12958, 6970, 9016, 6480],
 [12958, 6970, 10276, 1346],
 [12958, 6970, 4686, 1270],
 [12958, 6970, 0, 4409],
 [12958, 6970, 14017, 2280],
 [12958, 6970, 16617, 481],
 [12958, 6970, 14323, 1572],
 [12958, 6970, 17076, 6546],
 [12958, 6970, 0, 6575]]

In [42]:
vectors = torch.stack([model.transform(token) for token in tokens if len(token) == 4])

In [43]:
# word: 25.2
metric = nn.CosineSimilarity()

country_a, capital_a, country_b, capital_b = vectors.transpose(0, 1)
score = metric(country_a - capital_a - country_b, -capital_b)
print(score.mean().item())

0.25176307559013367
