In [40]:
import functools
import logging
import os
from typing import Any, Callable, Iterable

import datasets
import numpy as np
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")  # Report only TF errors by default

import tensorflow as tf

import transformer_document_embedding as tde

logging.basicConfig(
    format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
)

2022-12-13 10:32:10,362 : INFO : {"jsonrpc": "2.0", "method": "SyncRequest", "params": {"data": {"file_name": "/home/dburian/docs/transformer_document_embedding/notebooks/doc2vec_imdb.sync.py", "contents": "# ---\n# jupyter:\n#  jupytext:\n#   text_representation:\n#    extension: .py\n#    format_name: percent\n#    format_version: '1.3'\n#    jupytext_version: 1.3.4\n#  kernelspec:\n#   display_name: Python 3\n#   language: python\n#   name: python3\n# ---\n# %%\nimport functools\nimport logging\nimport os\nfrom typing import Any, Callable, Iterable\n\nimport datasets\nimport numpy as np\nfrom gensim.models.doc2vec import Doc2Vec, TaggedDocument\n\nos.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # Report only TF errors by default\n\nimport tensorflow as tf\n\nimport transformer_document_embedding as tde\n\nlogging.basicConfig(\n  format=\"%(asctime)s : %(levelname)s : %(message)s\", level=logging.INFO\n)\n# %%\nimdb = datasets.load_dataset(\"imdb\")\nprint(imdb)\n# %%\ndef get

In [22]:
imdb = datasets.load_dataset("imdb")
print(imdb)



  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [41]:
def get_id(id: int, split: str) -> int:
    split_ind = ["test", "train", "unsuper"].index(split)
    return split_ind * 25000


def add_doc_id(
    example: dict[str, Any], idx: int, get_id: Callable[[int], int]
) -> dict[str, Any]:
    return {"id": get_id(idx), "text": example["text"], "label": example["label"]}


train_add_doc_id = functools.partial(
    add_doc_id, get_id=lambda idx: get_id(idx, "train")
)
unsuper_add_doc_id = functools.partial(
    add_doc_id, get_id=lambda idx: get_id(idx, "unsuper")
)
test_add_doc_id = functools.partial(add_doc_id, get_id=lambda idx: get_id(idx, "train"))
train = imdb["train"].map(train_add_doc_id, with_indices=True)
unsuper = imdb["unsupervised"].map(unsuper_add_doc_id, with_indices=True)
test = imdb["test"].map(test_add_doc_id, with_indices=True)

  0%|          | 0/25000 [00:00<?, ?ex/s]

  0%|          | 0/50000 [00:00<?, ?ex/s]

  0%|          | 0/25000 [00:00<?, ?ex/s]

2022-12-13 10:33:08,041 : INFO : {"jsonrpc": "2.0", "method": "SyncRequest", "params": {"data": {"file_name": "/home/dburian/docs/transformer_document_embedding/notebooks/doc2vec_imdb.sync.py", "contents": "# ---\n# jupyter:\n#  jupytext:\n#   text_representation:\n#    extension: .py\n#    format_name: percent\n#    format_version: '1.3'\n#    jupytext_version: 1.3.4\n#  kernelspec:\n#   display_name: Python 3\n#   language: python\n#   name: python3\n# ---\n# %%\nimport functools\nimport logging\nimport os\nfrom typing import Any, Callable, Iterable\n\nimport datasets\nimport numpy as np\nfrom gensim.models.doc2vec import Doc2Vec, TaggedDocument\n\nos.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # Report only TF errors by default\n\nimport tensorflow as tf\n\nimport transformer_document_embedding as tde\n\nlogging.basicConfig(\n  format=\"%(asctime)s : %(levelname)s : %(message)s\", level=logging.INFO\n)\n# %%\nimdb = datasets.load_dataset(\"imdb\")\nprint(imdb)\n# %%\ndef get

In [37]:
for i, x in enumerate(train):
    print(x)
    if i > 5:
        break

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

In [42]:
def preprocess_document(text: str) -> list[str]:
    """Preprocesses document according to Paragraph Vector paper."""
    words = text.split()
    while len(words) < 10:
        words.insert(0, "NULL")

    return words


class MyGensimCorpus:
    def __iter__(self) -> Iterable[TaggedDocument]:
        for doc in train:
            yield TaggedDocument(preprocess_document(doc["text"]), [doc["id"]])

        for doc in unsuper:
            yield TaggedDocument(preprocess_document(doc["text"]), [doc["id"]])


for i, x in enumerate(MyGensimCorpus()):
    print(x)
    if i > 3:
        break

TaggedDocument<['I', 'rented', 'I', 'AM', 'CURIOUS-YELLOW', 'from', 'my', 'video', 'store', 'because', 'of', 'all', 'the', 'controversy', 'that', 'surrounded', 'it', 'when', 'it', 'was', 'first', 'released', 'in', '1967.', 'I', 'also', 'heard', 'that', 'at', 'first', 'it', 'was', 'seized', 'by', 'U.S.', 'customs', 'if', 'it', 'ever', 'tried', 'to', 'enter', 'this', 'country,', 'therefore', 'being', 'a', 'fan', 'of', 'films', 'considered', '"controversial"', 'I', 'really', 'had', 'to', 'see', 'this', 'for', 'myself.<br', '/><br', '/>The', 'plot', 'is', 'centered', 'around', 'a', 'young', 'Swedish', 'drama', 'student', 'named', 'Lena', 'who', 'wants', 'to', 'learn', 'everything', 'she', 'can', 'about', 'life.', 'In', 'particular', 'she', 'wants', 'to', 'focus', 'her', 'attentions', 'to', 'making', 'some', 'sort', 'of', 'documentary', 'on', 'what', 'the', 'average', 'Swede', 'thought', 'about', 'certain', 'political', 'issues', 'such', 'as', 'the', 'Vietnam', 'War', 'and', 'race', 'issues

2022-12-13 10:33:08,427 : INFO : {"jsonrpc": "2.0", "method": "SyncRequest", "params": {"data": {"file_name": "/home/dburian/docs/transformer_document_embedding/notebooks/doc2vec_imdb.sync.py", "contents": "# ---\n# jupyter:\n#  jupytext:\n#   text_representation:\n#    extension: .py\n#    format_name: percent\n#    format_version: '1.3'\n#    jupytext_version: 1.3.4\n#  kernelspec:\n#   display_name: Python 3\n#   language: python\n#   name: python3\n# ---\n# %%\nimport functools\nimport logging\nimport os\nfrom typing import Any, Callable, Iterable\n\nimport datasets\nimport numpy as np\nfrom gensim.models.doc2vec import Doc2Vec, TaggedDocument\n\nos.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # Report only TF errors by default\n\nimport tensorflow as tf\n\nimport transformer_document_embedding as tde\n\nlogging.basicConfig(\n  format=\"%(asctime)s : %(levelname)s : %(message)s\", level=logging.INFO\n)\n# %%\nimdb = datasets.load_dataset(\"imdb\")\nprint(imdb)\n# %%\ndef get

In [43]:
model = Doc2Vec(vector_size=400, window=5, dm_concat=1, workers=6)
model.build_vocab(MyGensimCorpus())

2022-12-13 10:33:17,623 : INFO : using concatenative 4400-dimensional layer1
2022-12-13 10:33:17,637 : INFO : Doc2Vec lifecycle event {'params': 'Doc2Vec<dm/c,d400,n5,w5,mc5,s0.001,t6>', 'datetime': '2022-12-13T10:33:17.637903', 'gensim': '4.2.0', 'python': '3.10.2 (main, Jan 15 2022, 19:56:27) [GCC 11.1.0]', 'platform': 'Linux-5.17.0-1-MANJARO-x86_64-with-glibc2.35', 'event': 'created'}
2022-12-13 10:33:17,776 : INFO : collecting all words and their counts
2022-12-13 10:33:17,777 : INFO : PROGRESS: at example #0, processed 0 words (0 words/s), 0 word types, 0 tags
2022-12-13 10:33:18,546 : INFO : PROGRESS: at example #10000, processed 2317452 words (3016356 words/s), 152756 word types, 0 tags
2022-12-13 10:33:19,378 : INFO : PROGRESS: at example #20000, processed 4663219 words (2821971 words/s), 242583 word types, 0 tags
2022-12-13 10:33:20,160 : INFO : PROGRESS: at example #30000, processed 7020384 words (3017394 words/s), 315459 word types, 0 tags
2022-12-13 10:33:20,941 : INFO : PR

In [44]:
model.train(MyGensimCorpus(), total_examples=model.corpus_count, epochs=1)

2022-12-13 10:33:40,791 : INFO : Doc2Vec lifecycle event {'msg': 'training model with 6 workers on 100929 vocabulary and 4400 features, using sg=0 hs=0 sample=0.001 negative=5 window=5 shrink_windows=True', 'datetime': '2022-12-13T10:33:40.791066', 'gensim': '4.2.0', 'python': '3.10.2 (main, Jan 15 2022, 19:56:27) [GCC 11.1.0]', 'platform': 'Linux-5.17.0-1-MANJARO-x86_64-with-glibc2.35', 'event': 'train'}
2022-12-13 10:33:41,797 : INFO : EPOCH 0 - PROGRESS: at 0.48% examples, 59495 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:33:42,804 : INFO : EPOCH 0 - PROGRESS: at 1.33% examples, 87875 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:33:43,947 : INFO : EPOCH 0 - PROGRESS: at 2.35% examples, 96167 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:33:44,977 : INFO : EPOCH 0 - PROGRESS: at 3.38% examples, 104609 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:33:46,044 : INFO : EPOCH 0 - PROGRESS: at 4.39% examples, 107553 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:33:47,111 : INF

2022-12-13 10:34:54,536 : INFO : EPOCH 0 - PROGRESS: at 69.83% examples, 125495 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:34:55,554 : INFO : EPOCH 0 - PROGRESS: at 70.86% examples, 125479 words/s, in_qsize 12, out_qsize 0
2022-12-13 10:34:56,592 : INFO : EPOCH 0 - PROGRESS: at 71.91% examples, 125705 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:34:57,635 : INFO : EPOCH 0 - PROGRESS: at 72.99% examples, 125815 words/s, in_qsize 12, out_qsize 0
2022-12-13 10:34:58,676 : INFO : EPOCH 0 - PROGRESS: at 73.97% examples, 125918 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:34:59,687 : INFO : EPOCH 0 - PROGRESS: at 75.11% examples, 126279 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:35:00,813 : INFO : EPOCH 0 - PROGRESS: at 76.22% examples, 126340 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:35:01,841 : INFO : EPOCH 0 - PROGRESS: at 77.28% examples, 126484 words/s, in_qsize 11, out_qsize 0
2022-12-13 10:35:02,860 : INFO : EPOCH 0 - PROGRESS: at 78.48% examples, 126820 words/s,

In [33]:
model.dv[25000]

KeyError: "Key '75000' not present"

2022-12-13 10:15:08,329 : INFO : {"jsonrpc": "2.0", "method": "SyncRequest", "params": {"data": {"file_name": "/home/dburian/docs/transformer_document_embedding/notebooks/doc2vec_imdb.sync.py", "contents": "# ---\n# jupyter:\n#  jupytext:\n#   text_representation:\n#    extension: .py\n#    format_name: percent\n#    format_version: '1.3'\n#    jupytext_version: 1.3.4\n#  kernelspec:\n#   display_name: Python 3\n#   language: python\n#   name: python3\n# ---\n# %%\nimport logging\nimport os\nfrom typing import Any, Iterable\n\nimport datasets\nimport numpy as np\nfrom gensim.models.doc2vec import Doc2Vec, TaggedDocument\n\nos.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # Report only TF errors by default\n\nimport tensorflow as tf\n\nimport transformer_document_embedding as tde\n\nlogging.basicConfig(\n  format=\"%(asctime)s : %(levelname)s : %(message)s\", level=logging.INFO\n)\n# %%\nimdb = datasets.load_dataset(\"imdb\")\nprint(imdb)\n# %%\ntrain = imdb[\"train\"]\nunsuper = 

In [55]:


def add_feature_vec(doc: dict[str, Any]) -> dict[str, Any]:
    feature_vector = model.dv[doc["id"]]
    return {"features": feature_vector, "label": doc["label"]}


softmax_train = train.map(add_feature_vec, remove_columns=["text", "id"]).to_tf_dataset(
    1
)
softmax_train = softmax_train.unbatch()
softmax_train = softmax_train.map(lambda doc: (doc["features"], doc["label"]))
softmax_train = softmax_train.shuffle(1000)
softmax_train = softmax_train.batch(2)
# tf_dataset_softmax_train = tf.data.Dataset.from_generator(
#     lambda: softmax_train,
#     output_types=(tf.float32, tf.int32),
#     output_shapes=((1,), (0,)),
# )
for x in softmax_train.take(5):
    print(x)

  0%|          | 0/25000 [00:00<?, ?ex/s]

(<tf.Tensor: shape=(2, 400), dtype=float32, numpy=
array([[ 1.10688299e-01,  6.38920590e-02,  2.41885632e-01,
         2.34931186e-01, -4.56061773e-02,  3.65586340e-01,
        -8.90850648e-02, -3.02319497e-01, -2.70096540e-01,
        -7.35404417e-02,  3.56509656e-01, -1.11738861e-01,
        -9.09382850e-02,  3.32294494e-01,  1.94648579e-01,
         8.21251646e-02, -3.01080644e-01,  1.45174950e-01,
        -2.20949024e-01, -1.33951873e-01, -4.03231382e-02,
        -2.66635984e-01, -2.49873862e-01,  1.12693347e-01,
        -3.30327690e-01, -1.34036690e-01, -3.52654338e-01,
        -2.92524785e-01,  1.25382155e-01,  1.55797213e-01,
        -1.35885194e-01,  8.83836821e-02, -2.33235762e-01,
         1.55905247e-01,  5.79732433e-02,  4.71169986e-02,
         2.50259846e-01, -3.04724336e-01, -1.06475003e-01,
         3.75052720e-01, -2.27052763e-01, -2.10722089e-01,
        -1.01622686e-01,  3.12036037e-01,  2.60935009e-01,
         2.99622715e-01, -2.36304790e-01, -1.35474578e-01,
     

2022-12-13 10:54:02,479 : INFO : {"jsonrpc": "2.0", "method": "SyncRequest", "params": {"data": {"file_name": "/home/dburian/docs/transformer_document_embedding/notebooks/doc2vec_imdb.sync.py", "contents": "# ---\n# jupyter:\n#  jupytext:\n#   text_representation:\n#    extension: .py\n#    format_name: percent\n#    format_version: '1.3'\n#    jupytext_version: 1.3.4\n#  kernelspec:\n#   display_name: Python 3\n#   language: python\n#   name: python3\n# ---\n# %%\nimport functools\nimport logging\nimport os\nfrom typing import Any, Callable, Iterable\n\nimport datasets\nimport numpy as np\nfrom gensim.models.doc2vec import Doc2Vec, TaggedDocument\n\nos.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # Report only TF errors by default\n\nimport tensorflow as tf\n\nimport transformer_document_embedding as tde\n\nlogging.basicConfig(\n  format=\"%(asctime)s : %(levelname)s : %(message)s\", level=logging.INFO\n)\n# %%\nimdb = datasets.load_dataset(\"imdb\")\nprint(imdb)\n# %%\ndef get

In [56]:
softmax_net = tf.keras.Sequential(
    [
        tf.keras.layers.Input(400),
        tf.keras.layers.Dense(50, activation=tf.nn.relu),
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)
softmax_net.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.losses.BinaryCrossentropy(),
    metrics=[tf.metrics.BinaryAccuracy()],
)

In [57]:
softmax_net.fit(softmax_train)



<keras.callbacks.History at 0x7f5da2248e20>