# Pickling and FastText models

## Setup

Here, we import the necessary modules and define the convenience functions used throughout this example.

In [None]:
from sklearn import pipeline, preprocessing, base, cluster
from typing import Self, Any
from fasttext import FastText
import numpy as np
from lorem_text import lorem
import tempfile
import pathlib
import pickle


def make_fake_data(n_samples: int = 100) -> str:
    return "\n".join(lorem.sentence() for _ in range(n_samples)) + "\n"


def save_pipeline(pipe: pipeline.Pipeline, filename: str = "saved_pipeline.pkl") -> None:
    # `pipe` rather than `pipeline`, so the argument does not shadow the sklearn module
    with open(filename, "wb") as f:
        pickle.dump(pipe, f)


def load_pipeline(filename: str = "saved_pipeline.pkl") -> pipeline.Pipeline:
    with open(filename, "rb") as f:
        return pickle.load(f)


def train_model(n_examples: int = 100) -> FastText._FastText:
    with tempfile.TemporaryDirectory() as d:
        data_file = pathlib.Path(d) / "training_data.txt"
        data_file.write_text(make_fake_data(n_examples))
        model = FastText.train_unsupervised(str(data_file))
    return model

## The problem

Suppose you wanted to cluster sentence embeddings taken from Meta's [FastText](https://fasttext.cc/) model and, being a disciplined ML engineer, you also wanted to use scikit-learn's well-known transformer interface (not to be confused with HuggingFace's `transformers`) so that you could drop it into a [Pipeline](https://scikit-learn.org/stable/modules/compose.html#pipeline-chaining-estimators). And suppose that your organization persists models as [pickles](https://docs.python.org/3/library/pickle.html).

### You define your `FastTextTransformer`

To do this, we:
  1. inherit from `base.BaseEstimator` (which provides `get_params`/`set_params`) and `base.TransformerMixin` (which provides `fit_transform`)
  2. implement both the `fit` and `transform` methods
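To see what the two base classes contribute on their own, here is a minimal sketch with a hypothetical `TextLengthTransformer` (not part of the original example) that needs no FastText model: it maps each string to its scaled length, picks up `get_params` and `fit_transform` from the mixins, and then slots into a `Pipeline` with `KMeans` exactly as the real transformer will.

```python
import numpy as np
from sklearn import base, cluster, pipeline


class TextLengthTransformer(base.BaseEstimator, base.TransformerMixin):
    """Hypothetical toy transformer: maps each string to one feature, its scaled length."""

    def __init__(self, scale: float = 1.0):
        # BaseEstimator.get_params reads constructor arguments back from same-named attributes
        self.scale = scale

    def fit(self, X, y=None):
        # Nothing to learn; returning self keeps the call chainable
        return self

    def transform(self, X):
        return np.asarray([[len(s) * self.scale] for s in X], dtype=float)


t = TextLengthTransformer(scale=0.5)
print(t.get_params())  # {'scale': 0.5}, provided by BaseEstimator
print(t.fit_transform(["ab", "abcd"]))  # fit_transform is provided by TransformerMixin

# Because it follows the interface, it drops straight into a Pipeline
pipe = pipeline.make_pipeline(
    TextLengthTransformer(), cluster.KMeans(n_clusters=2, n_init=10, random_state=0)
)
labels = pipe.fit_predict(["a", "ab", "abcdefgh", "abcdefghi"])
print(labels)
```

Short and long strings should land in different clusters; `n_init=10` is used here only so the sketch also runs on scikit-learn versions that predate `n_init="auto"`.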

In [None]:
class FastTextTransformer(base.BaseEstimator, base.TransformerMixin):
    def __init__(self, model: FastText._FastText):
        self.model = model
        super().__init__()

    def fit(self, *args: Any, **kwargs: Any) -> Self:
        # No-Op
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        text = np.atleast_1d(X)

        if text.ndim != 1:
            raise ValueError(f"`X` must be 1-dimensional, received {text.ndim}d data")

        return np.asarray([self.model.get_sentence_vector(s) for s in text.tolist()])

### Train your FastText model and construct your clustering pipeline

#### Train the model

In [None]:
fasttext_model = train_model()

#### Train your cluster model

In [None]:
cluster_training_data = make_fake_data().strip().split("\n")
random_state = 1024  # an int seed gives every KMeans fit the same initialization, for comparability

unpicklable_pipeline = pipeline.make_pipeline(
    FastTextTransformer(model=fasttext_model), cluster.KMeans(n_clusters=2, n_init="auto", random_state=random_state)
).fit(cluster_training_data)

#### Save your pipeline

Running the cell below should fail with
> `TypeError: cannot pickle 'fasttext_pybind.fasttext' object`

In [None]:
save_pipeline(unpicklable_pipeline)

### What happened?

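`FastText._FastText` is a thin Python wrapper around `fasttext_pybind.fasttext`, an object from fastText's compiled C++ extension (built with pybind11). The model's weights live in C++ memory rather than in Python attributes, and the extension type defines no pickling support, so `pickle` has no way to capture or rebuild its state. Since pickling a pipeline recursively pickles every object it references, this one attribute makes the whole pipeline unpicklable.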
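The same failure can be reproduced without fastText. In this sketch a `threading.Lock` stands in for the compiled `fasttext_pybind.fasttext` object (an analogy, not the real type): it is another object that defines no pickling support, and the hypothetical `ModelWrapper` and `FakePipeline` classes show that one such attribute buried inside a container is enough to make the whole container fail to pickle.

```python
import pickle
import threading


class ModelWrapper:
    """Hypothetical stand-in for FastText._FastText."""

    def __init__(self) -> None:
        # The lock plays the role of the compiled fasttext_pybind object:
        # pickle has no recipe for serializing it
        self.handle = threading.Lock()


class FakePipeline:
    """Hypothetical stand-in for an sklearn Pipeline: it just holds its steps."""

    def __init__(self, steps: list[object]) -> None:
        self.steps = steps


def try_pickle(obj: object) -> str:
    """Return "ok" if obj pickles, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
    except TypeError as exc:
        return str(exc)
    return "ok"


print(try_pickle(FakePipeline([1, 2, 3])))  # ok
print(try_pickle(FakePipeline([ModelWrapper()])))  # fails on the nested lock
```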

### Are we out of luck?

At this point we have two choices:
1. accept that we can't use `FastText` models in our pipeline transformations
2. learn some pickling dark magic, and make an object that we can pickle

Option $1$ is no fun and wouldn't make for a very interesting post so let's go with option $2$.

### What can we do?

Fortunately, we have a couple of things going for us:
1. `fasttext` provides its own serialization functionality: a trained model's `save_model` method writes it to a file, and `fasttext.load_model` reads it back
2. Object pickling in Python, like most things, relies on a protocol that we are free to hack away at
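A quick way to see that protocol in action is to ask an object for its pickling "recipe" directly. This sketch uses two toy classes (`Plain` and `Custom`, both hypothetical): `pickle` calls `__reduce_ex__`, which for an ordinary class returns a default recipe built around `copyreg.__newobj__`, but defers to `__reduce__` when a class overrides it.

```python
import copyreg
import pickle


class Plain:
    def __init__(self, x: int) -> None:
        self.x = x


class Custom:
    def __init__(self, x: int) -> None:
        self.x = x

    def __reduce__(self):
        # Recipe: "to rebuild me, call Custom(self.x)"
        return (Custom, (self.x,))


# What pickle actually asks each object for
plain_recipe = Plain(1).__reduce_ex__(pickle.DEFAULT_PROTOCOL)
custom_recipe = Custom(1).__reduce_ex__(pickle.DEFAULT_PROTOCOL)

print(plain_recipe[0] is copyreg.__newobj__)  # True: the default recipe
print(custom_recipe == (Custom, (1,)))  # True: our override is used instead

# Unpickling simply replays the recipe
print(pickle.loads(pickle.dumps(Custom(3))).x)  # 3
```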

### The solution

The solution is to wrap `fasttext.load_model` and `FastText._FastText.save_model` in Python's pickling protocol. All this requires is extending `FastText._FastText` with:

1. a method to serialize a trained model to `bytes`
2. a constructor to instantiate a model from its serialized `bytes` representation
3. an override of the [\_\_reduce\_\_](https://docs.python.org/3/library/pickle.html#object.__reduce__) method that returns a `tuple` containing
  > - A callable object that will be called to create the initial version of the object.
  > - A tuple of arguments for the callable object. An empty tuple must be given if the callable does not accept any argument.

In our case, `__reduce__` will return a `tuple` containing the constructor from item 2 and a `tuple` holding the `bytes` of our serialized instance, as produced by the method from item 1.
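Before applying this to fastText, here is the same three-part recipe on a toy class. `LockedCounter` is hypothetical: it keeps an integer alongside a `threading.Lock` (which cannot be pickled), serializes only the integer to `bytes`, rebuilds itself from those bytes with a classmethod constructor, and returns both from `__reduce__`. The restored copy gets a brand-new lock, just as the restored fastText model will get a freshly loaded C++ object.

```python
import pickle
import threading


class LockedCounter:
    """Hypothetical object holding plain state plus an unpicklable resource."""

    def __init__(self, value: int = 0) -> None:
        self.value = value
        self._lock = threading.Lock()  # cannot be pickled on its own

    def increment(self) -> None:
        with self._lock:
            self.value += 1

    # 1. serialize the meaningful state to bytes
    def serialize(self) -> bytes:
        return self.value.to_bytes(8, "big", signed=True)

    # 2. constructor that rebuilds an instance from those bytes
    @classmethod
    def load(cls, data: bytes) -> "LockedCounter":
        return cls(int.from_bytes(data, "big", signed=True))

    # 3. (callable, args): unpickling calls LockedCounter.load(<serialized bytes>)
    def __reduce__(self):
        return type(self).load, (self.serialize(),)


counter = LockedCounter()
counter.increment()
counter.increment()

restored = pickle.loads(pickle.dumps(counter))
print(restored.value)  # 2
```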

### Putting it all together

Finally, we land on an implementation like the following:

In [None]:
import pathlib
import tempfile
from typing import Callable, Self

import fasttext
from fasttext import FastText


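class PicklableFastText(FastText._FastText):  # type: ignore[no-any-unimported]
    # fastText's persistence API only works with file paths, so we round-trip through a temporary file
    _tmp_filename = "model.bin"

    @classmethod
    def from_pretrained(cls, model: FastText._FastText) -> Self:  # type: ignore[no-any-unimported]
        # Wrap an already-trained model without retraining: bypass __init__ and copy its
        # attributes, including the reference to the underlying C++ model
        self = cls.__new__(cls)
        self.__dict__.update(model.__dict__)
        return self

    @classmethod
    def load(cls, saved_model: bytes | str) -> Self:
        # A path to a model file written by `save_model`
        if isinstance(saved_model, str) and pathlib.Path(saved_model).exists():
            return cls.from_pretrained(fasttext.load_model(saved_model))

        # Raw bytes (e.g. from `serialize`): write them to a temporary file and load that
        if isinstance(saved_model, bytes):
            with tempfile.TemporaryDirectory() as d:
                model_path = pathlib.Path(d) / cls._tmp_filename
                model_path.write_bytes(saved_model)
                return cls.load(str(model_path))

        raise FileNotFoundError(saved_model)

    def serialize(self) -> bytes:
        # Save to a temporary file and return its contents as bytes
        with tempfile.TemporaryDirectory() as d:
            model_path = pathlib.Path(d) / self._tmp_filename
            self.save_model(str(model_path))
            return model_path.read_bytes()

    def __reduce__(self) -> tuple[Callable[[bytes | str], Self], tuple[bytes]]:
        # Tell pickle to rebuild this object by calling `PicklableFastText.load(<model bytes>)`
        return self.load, (self.serialize(),)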
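Both `load` and `serialize` repeat the same "temporary directory plus path" dance. If you need this for more than one path-only library, one option is to factor it into two small helpers. `dump_via_tempfile` and `load_via_tempfile` below are hypothetical names, not part of fastText; the sketch exercises them with toy text-file save/load functions so it runs without fastText installed.

```python
import pathlib
import tempfile
from typing import Callable, TypeVar

T = TypeVar("T")


def dump_via_tempfile(save: Callable[[str], None], filename: str = "model.bin") -> bytes:
    """Call a path-based `save(path)` and return whatever it wrote, as bytes."""
    with tempfile.TemporaryDirectory() as d:
        path = pathlib.Path(d) / filename
        save(str(path))
        return path.read_bytes()


def load_via_tempfile(data: bytes, load: Callable[[str], T], filename: str = "model.bin") -> T:
    """Write `data` to a temporary file and hand its path to a path-based `load(path)`."""
    with tempfile.TemporaryDirectory() as d:
        path = pathlib.Path(d) / filename
        path.write_bytes(data)
        # `load` must finish reading before the directory is deleted on exit
        return load(str(path))


# Toy save/load pair standing in for `save_model` / `fasttext.load_model`
blob = dump_via_tempfile(lambda p: pathlib.Path(p).write_text("hello"))
print(blob)  # b'hello'
print(load_via_tempfile(blob, lambda p: pathlib.Path(p).read_text()))  # hello
```

With these helpers, `serialize` would reduce to something like `dump_via_tempfile(self.save_model)` and the bytes branch of `load` to `cls.from_pretrained(load_via_tempfile(saved_model, fasttext.load_model))`; whether that is worth an extra abstraction for a single class is a judgment call.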

#### Update our original pipeline with our `PicklableFastText` model

In [None]:
picklable_pipeline = pipeline.make_pipeline(
    FastTextTransformer(model=PicklableFastText.from_pretrained(fasttext_model)),
    cluster.KMeans(n_clusters=2, n_init="auto", random_state=random_state),
).fit(cluster_training_data)

#### Save our new Pipeline

In [None]:
save_pipeline(picklable_pipeline)

#### Does it work?

In [None]:
sentence = lorem.sentence()

print(picklable_pipeline.transform([sentence]))
print(load_pipeline().transform([sentence]))