### Imports & utils

In [2]:
%pip install datasets transformers[torch]



In [3]:
from __future__ import annotations
import typing
from dataclasses import dataclass, field
import warnings
from contextlib import contextmanager
import itertools
import functools
import logging
from ast import literal_eval

In [4]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional
import transformers
import transformers.modeling_outputs
import datasets

In [6]:
@contextmanager
def localize_globals(*exceptions: str, restore_values: bool = True):
    exceptions: typing.Set[str] = set(exceptions)

    old_globals: typing.Dict[str, typing.Any] = dict(globals())
    allowed: typing.Set[str] = set(old_globals.keys())
    allowed.update(exceptions)

    yield None

    new_globals: typing.Dict[str, typing.Any] = globals()

    for name in tuple(new_globals.keys()):
        if name not in allowed:
            del new_globals[name]

    if not restore_values:
        return

    new_globals.update(
        {k: v for k, v in old_globals.items() if k not in exceptions}
    )

In [7]:
logging.basicConfig(
    level=logging.INFO,
    format="[{levelname}] {message}",
    style="{",
)

### Data preprocessing

In [8]:
data: pd.DataFrame = pd.read_csv("yc_essential_data.csv")

# Limit to the columns we're interested in
data = data[["name", "one_liner", "long_description", "tags"]]

# Convert tags to a list
data["tags"] = data["tags"].apply(literal_eval)
assert isinstance(data.at[0, "tags"], list), "Didn't work!"

# Okay, apparently an empty string makes a nan by default
# Gotta reverse it
data["one_liner"].replace(
    to_replace=np.nan,
    value="",
    inplace=True,
)

data["long_description"].replace(
    to_replace=np.nan,
    value="",
    inplace=True,
)

# Preview the results
data.head()

Unnamed: 0,name,one_liner,long_description,tags
0,Wufoo,Online form builder.,Wufoo is a web application that helps anybody ...,"[SaaS, Productivity]"
1,Project Wedding,,"Finding wedding vendors is hard. In 2007, a co...",[]
2,Clustrix,,Clustrix provides the leading scale-out relati...,[]
3,Inkling,,"Inkling, based in Chicago, Illinois, offers co...",[]
4,Audiobeta,,AudioBeta develops web-based applications that...,[]


In [9]:
# Gather all unique tags
with localize_globals("all_tags"):
    tags_set: set[str] = set(itertools.chain.from_iterable(data["tags"]))

    all_tags: pd.Series = pd.Series(sorted(tags_set))

all_tags

0          3D Printed Foods
1               3D Printing
2                        AI
3              AI Assistant
4      AI-Enhanced Learning
               ...         
324          Women's Health
325     Workflow Automation
326               eLearning
327                 eSports
328                    web3
Length: 329, dtype: object

### Pretrained models

In [10]:
tokenizer: transformers.DistilBertTokenizer = transformers.DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased",
)

nlp_model: transformers.DistilBertModel = transformers.DistilBertModel.from_pretrained(
    "distilbert-base-uncased",
)

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [11]:
MAX_TOKENS: int = 512
EMBEDDING_SIZE: int = 768

### Dataset preparation

In [12]:
with localize_globals("complete_dataset", "col_pad_len"):
    def preprocess(batch: dict[str, typing.Any]) -> dict[str, typing.Any]:
        for column in (
            "name",
            "one_liner",
            "long_description",
        ):
            tmp = tokenizer(
                batch[column],
                truncation=True,
                padding="max_length",
                max_length=MAX_TOKENS,
                return_tensors="pt",
            ).data

            # logging.info(f"!! {tmp['input_ids'].shape}, {tmp['attention_mask'].shape}")
            batch[column] = tmp["input_ids"]
            batch[f"{column}_mask"] = tmp["attention_mask"]

        # TODO: Since this is the target, process it separately?
        if "tags" in batch:
            batch["tags"] = torch.stack([
                torch.tensor(all_tags.apply(tags.__contains__), dtype=torch.float)
                for tags in batch["tags"]
            ])

        return batch

    complete_dataset = (
        datasets.Dataset
        .from_pandas(data)
        # .with_format(None)
        .with_transform(preprocess)
    )

complete_dataset

Dataset({
    features: ['name', 'one_liner', 'long_description', 'tags'],
    num_rows: 4423
})

In [13]:
with localize_globals("train_dataset", "val_dataset", "test_dataset"):
    train_test_split = complete_dataset.train_test_split(test_size=0.2)
    train_dataset = train_test_split["train"]

    test_val_split = train_test_split["test"].train_test_split(test_size=0.3)
    val_dataset = test_val_split["train"]
    test_dataset = test_val_split["test"]

train_dataset, val_dataset, test_dataset

(Dataset({
     features: ['name', 'one_liner', 'long_description', 'tags'],
     num_rows: 3538
 }),
 Dataset({
     features: ['name', 'one_liner', 'long_description', 'tags'],
     num_rows: 619
 }),
 Dataset({
     features: ['name', 'one_liner', 'long_description', 'tags'],
     num_rows: 266
 }))

### Model definition

In [14]:
class MultiInputModule(nn.Module):
    """
    Takes multiple inputs from named columns of a dataset,
    passes them to separate sub-modules, and collects the
    result with a single collector module.

    Note: the arguments are passed to the collector by
    their order, the names are only used for column selection.
    This behaviour relies on nn.ModuleDict preserving the
    order of insertion, which should hold for Python >= 3.6.
    If that's not the case, you'll get arbitrary but consistent (?)
    order within a single `MultiInputModule` instance.
    """

    inputs: nn.ModuleDict
    collector: nn.Module

    def __init__(
        self,
        inputs: nn.ModuleDict,
        collector: nn.Module,
    ) -> None:
        super().__init__()

        self.collector = collector
        self.inputs = inputs

    def forward(
        self,
        # TODO: **?
        input_dict: typing.Mapping[str, torch.Tensor],
    ) -> torch.Tensor:
        assert set(input_dict.keys()).issuperset(self.inputs.keys()), \
            f"Missing parameters: expected {set(self.inputs.keys())}, got only {set(input_dict.keys())}"

        return self.collector(*(
            self.inputs[name](input_dict[name])
            for name in self.inputs
        ))


In [15]:
class ConcatenationModule(nn.Module):
    """
    Takes a list of tensors and concatenates them
    into a single tensor along a new axis
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        *tensors: torch.Tensor,
    ) -> torch.Tensor:
        return torch.cat(tensors, dim=-1)


In [16]:
class NLPWrapperModule(nn.Module):
    """
    Wraps an nlp module and performs the following pre- and postprocessing:
    - Takes a dictionary and `**`-unwraps it for the submodule's input
    - Takes `.last_hidden_state` from the submodule's result and returns only it
    """

    submodule: nn.Module

    def __init__(self, submodule: nn.Module) -> None:
        super().__init__()

        self.submodule = submodule

    def forward(
        self,
        params: typing.Mapping[str, torch.Tensor],
    ) -> torch.Tensor:
        # logging.info(f"! input_ids:{params['input_ids'].shape}, attention_mask:{params['attention_mask'].shape}")

        return self.submodule(**params).last_hidden_state

In [17]:
class YCTagPredictorConfig(transformers.modeling_utils.PretrainedConfig):
    model_type: typing.ClassVar[str] = "yc_tag_predictor"

    def __init__(self, **kwargs: typing.Any) -> None:
        super().__init__(**kwargs)


class YCTagPredictorModel(transformers.modeling_utils.PreTrainedModel):
    config_class = YCTagPredictorConfig

    def __init__(self, config: YCTagPredictorConfig) -> None:
        super().__init__(config)

        input_embedder: nn.Module = NLPWrapperModule(nlp_model)

        # I can't afford to also tune BERT, nor do I need to
        input_embedder.train(False)
        for param in input_embedder.parameters():
            param.requires_grad = False

        self.model = nn.Sequential(
            MultiInputModule(
                inputs=nn.ModuleDict(dict(
                    name=input_embedder,
                    one_liner=input_embedder,
                    long_description=input_embedder,
                )),
                collector=ConcatenationModule(),
            ),
            nn.Flatten(),
            nn.Linear(
                in_features=EMBEDDING_SIZE * MAX_TOKENS * 3,
                out_features=len(all_tags),
            ),
            # nn.Softmax(dim=-1),
        )

    def forward(
        self,
        *,
        name: torch.Tensor,
        name_mask: torch.Tensor,
        one_liner: torch.Tensor,
        one_liner_mask: torch.Tensor,
        long_description: torch.Tensor,
        long_description_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        return self.model(dict(
            name=dict(
                input_ids=name,
                attention_mask=name_mask,
            ),
            one_liner=dict(
                input_ids=one_liner,
                attention_mask=one_liner_mask,
            ),
            long_description=dict(
                input_ids=long_description,
                attention_mask=long_description_mask,
            ),
        ))

In [18]:
model = YCTagPredictorModel(
    YCTagPredictorConfig(),
)

In [19]:
with localize_globals():
    model.to(torch.device("cpu"))
    actual_shape = model(**next(train_dataset.iter(1)))[0].shape
    target_shape = next(iter(train_dataset))["tags"].shape

    logging.info(f"{actual_shape=}, {target_shape=}")

    assert actual_shape == target_shape, "Bad model result shape"


### Model training

In [26]:
training_args = transformers.TrainingArguments(
    output_dir="./training_output",
    logging_dir="./training_logs",
    label_names=["tags"],
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    remove_unused_columns=False,
    num_train_epochs=10,
    # warmup_steps=100,
    # weight_decay=0.01,
)

In [21]:
class CustomTrainer(transformers.trainer.Trainer):
    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, typing.Any],
        return_outputs: bool = False,
    ) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, transformers.modeling_outputs.ModelOutput]]:
        # print("!!!", flush=True)
        labels = inputs.pop("tags")
        outputs = model(**inputs)
        loss = nn.functional.cross_entropy(
            outputs.view(-1), labels.view(-1),
        )
        return (loss, outputs) if return_outputs else loss

In [27]:
trainer: transformers.trainer.Trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [28]:
trainer.train(resume_from_checkpoint=True)

There were missing keys in the checkpoint model loaded: ['model.0.inputs.one_liner.submodule.embeddings.word_embeddings.weight', 'model.0.inputs.one_liner.submodule.embeddings.position_embeddings.weight', 'model.0.inputs.one_liner.submodule.embeddings.LayerNorm.weight', 'model.0.inputs.one_liner.submodule.embeddings.LayerNorm.bias', 'model.0.inputs.one_liner.submodule.transformer.layer.0.attention.q_lin.weight', 'model.0.inputs.one_liner.submodule.transformer.layer.0.attention.q_lin.bias', 'model.0.inputs.one_liner.submodule.transformer.layer.0.attention.k_lin.weight', 'model.0.inputs.one_liner.submodule.transformer.layer.0.attention.k_lin.bias', 'model.0.inputs.one_liner.submodule.transformer.layer.0.attention.v_lin.weight', 'model.0.inputs.one_liner.submodule.transformer.layer.0.attention.v_lin.bias', 'model.0.inputs.one_liner.submodule.transformer.layer.0.attention.out_lin.weight', 'model.0.inputs.one_liner.submodule.transformer.layer.0.attention.out_lin.bias', 'model.0.inputs.one_l

Step,Training Loss
1500,595.1178
2000,616.3283


KeyboardInterrupt: ignored

In [29]:
trainer.evaluate()

Step,Training Loss,Validation Loss
1500,595.1178,
2000,616.3283,
2245,616.3283,709.569946


{'eval_loss': 709.5699462890625}