disco  
Copyright (C) 2022-present NAVER Corp.  
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license  

# Conditional Tuning with CDPG

In the conditional case, the context is no longer fixed, so we rely on a more generic tuner, the `CDPGTuner`. In most cases we favor a seq2seq model, and our features depend on both the context and the sample.

For our experiment, we're going to summarize news articles with T5, making sure that the model does not hallucinate organizations.

## Expressing Preferences

Using spaCy, we can extract organization names from a text.

In [None]:
import spacy

In [None]:
nlp = spacy.load("en_core_web_sm")

In [None]:
def organizations(text):
    """Return the set of organization names that spaCy finds in a text."""
    doc = nlp(text)
    return {ent.text for ent in doc.ents if ent.label_ == "ORG"}

Now that we can extract the set of organizations mentioned in a text, we can build a scorer. We want every organization mentioned in a sample (the summary) to also appear in its context (the article being summarized); in other words, we don't want hallucinated organizations.
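The constraint boils down to set inclusion. The sketch below shows the check on hand-written entity sets (hypothetical values standing in for the output of `organizations`), so it runs without spaCy or disco; the helper name `no_hallucinated_orgs` is ours, not part of disco.

```python
def no_hallucinated_orgs(summary_orgs: set, context_orgs: set) -> bool:
    """True when every organization in the summary also appears in the context."""
    # Set inclusion: a summary mentioning no organization passes trivially.
    return summary_orgs <= context_orgs


# Hypothetical entity sets, standing in for organizations(article) and organizations(summary).
context_orgs = {"NAVER", "Reuters", "BBC"}

print(no_hallucinated_orgs({"Reuters"}, context_orgs))
print(no_hallucinated_orgs({"Reuters", "CNN"}, context_orgs))
print(no_hallucinated_orgs(set(), context_orgs))
```

Since the comparison is on exact strings, two surface forms of the same organization (an abbreviation and the full name, for instance) count as different organizations, which makes the constraint somewhat stricter than "no hallucination" in the everyday sense.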

In [None]:
from disco.scorers.boolean_scorer import BooleanScorer

In [None]:
# A sample s passes when every organization it mentions also appears in its context c
organization_scorer = BooleanScorer(lambda s, c: organizations(s.text) <= organizations(c))

For this task, we use T5, a seq2seq model available through Hugging Face Transformers, in its base version (`t5-base`).

In [None]:
from disco.distributions import LMDistribution

In [None]:
base = LMDistribution(network="t5-base", tokenizer="t5-base", nature="seq2seq")

We then simply state that we want all samples to respect our preferences, by multiplying the base distribution by our scorer.

In [None]:
target = base * organization_scorer
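Multiplying `base` by a `BooleanScorer` yields an unnormalized energy-based model P(x, c) = a(x|c) b(x, c), where a is the base model and b is 1 or 0: samples that break the constraint get zero weight, and the others keep their base score. Normalizing per context gives the base distribution conditioned on the constraint. Here is a toy illustration in plain Python (not disco's API), with four hypothetical summaries and made-up probabilities:

```python
# Hypothetical base probabilities a(x|c) for four candidate summaries of one context.
base_probs = {"s1": 0.4, "s2": 0.3, "s3": 0.2, "s4": 0.1}
# Hypothetical constraint outcomes b(x, c): True when no organization is hallucinated.
constraint = {"s1": False, "s2": True, "s3": True, "s4": False}

# Unnormalized target P(x, c) = a(x|c) * b(x, c); multiplying by False gives 0.0.
unnormalized = {x: p * constraint[x] for x, p in base_probs.items()}

# Partition function Z(c): the base probability mass that satisfies the constraint.
z = sum(unnormalized.values())

# Normalized target p_c(x) = P(x, c) / Z(c), i.e. the base model conditioned on the constraint.
target_probs = {x: w / z for x, w in unnormalized.items()}

print(z)
print(target_probs)
```

Only s2 and s3 keep any probability mass, and their relative odds (3 to 2) are the same as under the base model. In the real setting Z(c) cannot be computed by enumeration, which is why the target is approximated by tuning rather than sampled from directly.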

## Tuning, Conditionally

We now want to tune a model to approximate this target distribution. For this we need many contexts, so we use a `DatasetContextDistribution`, which draws them from a dataset in Hugging Face's Datasets repository: here, the CNN / DailyMail dataset. Let's see how this works.

In [None]:
from disco.distributions.dataset_context_distribution import DatasetContextDistribution

In [None]:
dataset = DatasetContextDistribution(dataset="cnn_dailymail", subset="1.0.0", split="train", key="article")
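A context distribution returns a batch of contexts together with their log-scores. The following is a minimal stand-in, not disco's implementation: the class `UniformContexts` is hypothetical, and it assumes contexts are drawn uniformly at random with replacement, so each one has log-probability -log(N) for a dataset of N texts. The `prefix` argument mirrors the one passed to `DatasetContextDistribution` later in this notebook.

```python
import math
import random


class UniformContexts:
    """Hypothetical stand-in for a dataset-backed context distribution."""

    def __init__(self, texts, prefix=""):
        self.texts = list(texts)
        self.prefix = prefix

    def sample(self, sampling_size=32):
        # Draw contexts uniformly at random, with replacement.
        picks = random.choices(self.texts, k=sampling_size)
        contexts = [self.prefix + text for text in picks]
        # Under a uniform distribution over N texts, each context has log-probability -log(N).
        log_scores = [-math.log(len(self.texts))] * sampling_size
        return contexts, log_scores


# Usage with two toy "articles" and the T5 task prefix.
toy = UniformContexts(["First article.", "Second article."], prefix="summarize: ")
contexts, log_scores = toy.sample(sampling_size=4)
print(contexts)
print(log_scores)
```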

Out of curiosity, we can sample a few articles and extract the set of organizations mentioned in the first one:

In [None]:
articles, log_scores = dataset.sample(sampling_size=2**3)

In [None]:
articles[0]

In [None]:
organizations(articles[0])

We're using the online scheme, sampling directly from the model being tuned. It is also possible to rely on the offline scheme; see the [Tuning notebook](./3.tuning_DPG.ipynb).
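For reference, here is the gradient that conditional DPG estimates, written following the CDPG formulation (disco's implementation may add details such as baselines). With P(x, c) the unnormalized target, $Z_c = \sum_x P(x, c)$, $\tau$ the context distribution and $q$ the proposal used for sampling, tuning minimizes the expected cross-entropy between the normalized target $p_c = P(\cdot, c)/Z_c$ and the model $\pi_\theta$:

$$
\nabla_\theta\, \mathbb{E}_{c \sim \tau}\, \mathrm{CE}\big(p_c, \pi_\theta(\cdot \mid c)\big)
= - \mathbb{E}_{c \sim \tau}\, \mathbb{E}_{x \sim q(\cdot \mid c)}
\left[ \frac{P(x, c)}{Z_c\, q(x \mid c)}\, \nabla_\theta \log \pi_\theta(x \mid c) \right],
\qquad
Z_c \approx \frac{1}{N} \sum_{i=1}^{N} \frac{P(x_i, c)}{q(x_i \mid c)}
$$

In the online scheme $q = \pi_\theta$ itself, so samples always come from the current model; in the offline scheme $q$ is a separate proposal that is only updated from time to time.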

In [None]:
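# freeze=False leaves the network's weights trainable so the tuner can update them; length sets the maximum sample length
model = LMDistribution(network="t5-base", tokenizer="t5-base", nature="seq2seq", length=256, freeze=False)
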
We can now instantiate a tuner. We are going to:
  * tune `model` to approximate `target`, drawing our samples from `model` itself;
  * use a context distribution that fetches articles from CNN / DailyMail, each prefixed with the task instruction "summarize: ", which tells T5 which task to perform.

In [None]:
from disco.tuners import CDPGTuner

In [None]:
tuner = CDPGTuner(model, target,
        context_distribution=DatasetContextDistribution(
                dataset="cnn_dailymail", subset="1.0.0", split="train", key="article", prefix="summarize: "),
        n_gradient_steps=1000,
        n_samples_per_step=2**8,
        sampling_size=2**5,
        scoring_size=2**5)
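To get a rough sense of the sampling budget, here is the arithmetic implied by these settings. It assumes, as the parameter names suggest but without confirmation from disco's documentation here, that each gradient step uses `n_samples_per_step` samples, drawn in batches of `sampling_size` and scored in batches of `scoring_size`.

```python
# Values copied from the CDPGTuner call above.
n_gradient_steps = 1000
n_samples_per_step = 2**8  # 256
sampling_size = 2**5       # 32
scoring_size = 2**5        # 32

# Under the stated assumption: batches needed per gradient step.
sampling_batches_per_step = n_samples_per_step // sampling_size
scoring_batches_per_step = n_samples_per_step // scoring_size

# Total number of summaries generated over the whole run.
total_samples = n_gradient_steps * n_samples_per_step

print(sampling_batches_per_step, scoring_batches_per_step, total_samples)
```

Under this reading, lowering `sampling_size` or `scoring_size` reduces peak memory at the cost of more batches, without changing how many samples feed each gradient step.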

We want to monitor progress, so we attach a logger.

In [None]:
from disco.tuners.loggers.console import ConsoleLogger

In [None]:
ConsoleLogger(tuner)

To use a `NeptuneLogger` instead of, or in addition to, the console logger, simply uncomment the following cell. This assumes we have actually set up the service and defined the `NEPTUNE_API_TOKEN` environment variable.

In [None]:
# from disco.tuners.loggers.neptune import NeptuneLogger
# import os
# NEPTUNE_API_TOKEN = os.environ["NEPTUNE_API_TOKEN"]
# NeptuneLogger(tuner,
#     project="disco", api_token=NEPTUNE_API_TOKEN
# )

Let's dance!

In [None]:
tuner.tune()