disco  
Copyright (C) 2022-present NAVER Corp.  
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license  

# Tuning with DPG

Once we have expressed our preferences over the generated sequences through an Energy-Based Model (EBM), we cannot sample from that EBM directly. What we can do is approximate it by fine-tuning a model.  
Let's first look at classic, unconditional DPG, i.e. DPG with a fixed context.
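
In symbols, and using notation that is not part of this notebook ($P$ for the EBM, $Z$ for its partition function, $\pi_\theta$ for the model being tuned), the situation can be summarized as follows. This is a sketch of the usual DPG objective, not a description of disco's internals.

$$
p(x) = \frac{P(x)}{Z}, \qquad Z = \sum_x P(x), \qquad
\theta^\ast = \arg\min_\theta \mathrm{KL}(p \,\|\, \pi_\theta) = \arg\max_\theta \mathbb{E}_{x \sim p} \log \pi_\theta(x)
$$

$P(x)$ can be evaluated for any given sequence, but $Z$ sums over all sequences and $p$ has no token-by-token factorization to sample from, which is why we tune an autoregressive model $\pi_\theta$ to stand in for it.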

## Expressing Preferences

Let's stick with our _amazing_ use case: we want the word "amazing" to appear in our samples (see the [Expressing Preferences](./2.expressing_preferences.ipynb) notebook for detailed explanations).

In [None]:
from disco.scorers import BooleanScorer

In [None]:
import re

is_amazing = lambda s, c: bool(re.search(r"\bamazing\b", s.text))
amazing_scorer = BooleanScorer(is_amazing)
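
To see what the `\b` word boundaries in the predicate accept and reject, here is a small check that runs without disco. `FakeSample` is a hypothetical stand-in that only mimics the `.text` attribute the predicate reads; it is not a disco class.

```python
import re
from collections import namedtuple

# Hypothetical stand-in for a sample: only the `.text` attribute matters here.
FakeSample = namedtuple("FakeSample", ["text"])

is_amazing = lambda s, c: bool(re.search(r"\bamazing\b", s.text))

checks = {
    "What an amazing view.": True,      # exact word
    "It was amazing!": True,            # punctuation counts as a boundary
    "She sang amazingly well.": False,  # longer word, no boundary after "amazing"
    "Amazing, isn't it?": False,        # the pattern is case-sensitive
}
results = {text: is_amazing(FakeSample(text), None) for text in checks}
```

If capitalized occurrences should count too, passing `re.IGNORECASE` to `re.search` would be one way to relax the predicate.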

In [None]:
from disco.distributions import LMDistribution

For a pointwise constraint:

In [None]:
base = LMDistribution()
pw_target = base * amazing_scorer

For a distributional one:

In [None]:
from disco.distributions.single_context_distribution import SingleContextDistribution

In [None]:
incipit = "It was a cold and stormy night"

In [None]:
dc_target = base.constrain([amazing_scorer], [1/2],
        n_samples=2**10,
        context_distribution=SingleContextDistribution(incipit))
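
The two targets differ in kind. Multiplying by a 0/1 scorer removes every sequence without the word, whereas the distributional constraint only asks that the word appear in half of the samples. Assuming the usual exponential-family formulation, where the target is proportional to `base(x) * exp(lam * feature(x))`, a single binary feature has a closed-form coefficient. The sketch below illustrates that relation; the function names and the base rate of 0.01 are made up, and disco itself estimates the coefficients numerically from samples (presumably what `n_samples` controls).

```python
import math

def solve_lambda(base_rate, target_rate):
    """Coefficient lam such that tilting by exp(lam * feature) moves a binary
    feature's expected value from base_rate to target_rate."""
    odds = lambda r: r / (1.0 - r)
    return math.log(odds(target_rate) / odds(base_rate))

def tilted_rate(base_rate, lam):
    """Expected value of the binary feature after tilting by exp(lam * feature)."""
    up = base_rate * math.exp(lam)
    return up / (up + 1.0 - base_rate)

lam = solve_lambda(base_rate=0.01, target_rate=0.5)  # 0.01 is a made-up base rate
```

The rarer the word is under the base model, the larger the coefficient needed to reach the requested rate.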

## Tuning

We then instantiate the model we want to tune; what actually gets tuned is the "network" inside the distribution.

In [None]:
model = LMDistribution(freeze=False)

Let's check the initial rate for our constraint.

In [None]:
from disco.samplers import AccumulationSampler

In [None]:
sampler = AccumulationSampler(model, total_size=2**9)
samples, log_scores = sampler.sample(context=incipit)

In [None]:
sum(is_amazing(s, None) for s in samples) / len(samples)
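
With a few hundred samples and a rare word, this rate is a noisy estimate. A minimal helper, not part of disco, that reports the binomial standard error alongside the rate makes it easier to judge whether a later change is meaningful. The counts below are made up for illustration.

```python
import math

def rate_with_stderr(flags):
    """Return the fraction of True values and its binomial standard error."""
    n = len(flags)
    rate = sum(flags) / n
    return rate, math.sqrt(rate * (1.0 - rate) / n)

# Made-up example: 5 hits out of 512 samples.
flags = [True] * 5 + [False] * 507
rate, stderr = rate_with_stderr(flags)
```

In the notebook, `flags` would be `[is_amazing(s, None) for s in samples]`.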

### Offline

In the offline scheme, we sample from a companion proposal distribution, which may itself be updated during tuning.

In [None]:
proposal = LMDistribution()

We can now instantiate a tuner. We are going to:
  * tune `model` to approximate `dc_target`, drawing our samples from `proposal`;
  * use a fixed incipit as the context;
  * check the divergence every `divergence_evaluation_interval` gradient steps, at which point the proposal may also be updated.
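
To make the roles of model, target, and proposal concrete, here is a toy version of the loop on a four-outcome categorical distribution, in plain Python. It is a sketch under stated assumptions, not disco's implementation: all names are hypothetical, the target can be normalized exactly only because the space is tiny, and the rule used here (promote the model to proposal whenever the exact KL improves) is one simple choice among several.

```python
import math
import random

def softmax(logits):
    top = max(logits)
    exps = [math.exp(l - top) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) for two categorical distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def toy_offline_dpg(ebm, n_steps=500, batch=64, eval_every=4, lr=1.0, seed=0):
    """Tune softmax logits towards ebm / sum(ebm), sampling from a proposal."""
    rng = random.Random(seed)
    k = len(ebm)
    target = [e / sum(ebm) for e in ebm]  # tractable only because this is a toy
    logits = [0.0] * k                    # model being tuned, starts uniform
    proposal = softmax(logits)            # separate distribution used for sampling
    best = kl(target, proposal)
    for step in range(1, n_steps + 1):
        pi = softmax(logits)
        grad = [0.0] * k
        for x in rng.choices(range(k), weights=proposal, k=batch):
            weight = ebm[x] / proposal[x]  # importance weight P(x) / q(x)
            for j in range(k):
                # gradient of log pi(x) w.r.t. logit j is 1[j == x] - pi[j]
                grad[j] += weight * ((1.0 if j == x else 0.0) - pi[j])
        logits = [l + lr * g / batch for l, g in zip(logits, grad)]
        if step % eval_every == 0:
            current = kl(target, softmax(logits))
            if current < best:  # model got closer: promote it to proposal
                best, proposal = current, softmax(logits)
    return softmax(logits), target

# Base distribution times a 0/1 scorer, as in a pointwise constraint.
toy_ebm = [0.4 * 0, 0.3 * 0, 0.2 * 1, 0.1 * 1]
toy_model, toy_target = toy_offline_dpg(toy_ebm)
```

After tuning, `toy_model` should put nearly all of its mass on the two outcomes the scorer accepts, in roughly the 2:1 ratio of the base distribution.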

In [None]:
from disco.tuners import DPGTuner

In [None]:
tuner = DPGTuner(model, dc_target, proposal,
        context=incipit,
        n_gradient_steps=1000,
        n_samples_per_step=2**8,
        sampling_size=2**5,
        scoring_size=2**5,
        divergence_evaluation_interval=2**2,
        n_kl_samples=2**10)

There are loggers we can use to monitor the tuning. They are built on the observer pattern, so it's easy to add more specific ones. Beyond the simple `ConsoleLogger`, disco also provides loggers for Neptune, Weights & Biases, and others.
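
For readers unfamiliar with the observer pattern, here is a generic, self-contained sketch of the idea. Every class and method name in it (`Observable`, `subscribe`, `notify`, `ToyTuner`, `ListLogger`) is hypothetical and chosen for illustration; it does not reproduce disco's actual logger interface.

```python
class Observable:
    """Minimal event source: callbacks subscribe, the owner notifies them."""
    def __init__(self):
        self._callbacks = []

    def subscribe(self, callback):
        self._callbacks.append(callback)

    def notify(self, *args):
        for callback in self._callbacks:
            callback(*args)

class ToyTuner:
    """Hypothetical tuner exposing one event; not disco's DPGTuner."""
    def __init__(self):
        self.metric_updated = Observable()

    def tune(self, n_steps=3):
        for step in range(n_steps):
            self.metric_updated.notify("loss", 1.0 / (step + 1), step)

class ListLogger:
    """Example observer that stores every metric it is told about."""
    def __init__(self, tuner):
        self.records = []
        tuner.metric_updated.subscribe(self.on_metric)

    def on_metric(self, name, value, step):
        self.records.append((step, name, value))

toy_tuner = ToyTuner()
toy_logger = ListLogger(toy_tuner)
toy_tuner.tune()
```

The tuner never needs to know which loggers exist, which is what makes adding a new one cheap.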

In [None]:
from disco.tuners.loggers.console import ConsoleLogger

In [None]:
ConsoleLogger(tuner)

Let's dance!

In [None]:
tuner.tune()

Are we doing better?

In [None]:
sampler = AccumulationSampler(model, total_size=512)
samples, log_scores = sampler.sample(context=incipit)

In [None]:
sum(is_amazing(s, None) for s in samples) / len(samples)

### Online

In the online scheme, the model being tuned is also the one providing the samples, so we don't need a proposal.
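
With the model as its own sampler, the importance weights compare the EBM to the current model. A sketch of the corresponding gradient, using notation that is not in this notebook ($P$ for the EBM, $Z$ for its partition function, $p = P/Z$, and $\pi_\theta$ for the model):

$$
\nabla_\theta \mathrm{KL}(p \,\|\, \pi_\theta)
= -\,\mathbb{E}_{x \sim \pi_\theta}\!\left[\frac{P(x)}{Z\,\pi_\theta(x)}\,\nabla_\theta \log \pi_\theta(x)\right]
$$

Here the ratio is treated as a constant weight: no gradient flows through the sampling step or through the $\pi_\theta(x)$ in the denominator.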

In [None]:
model = LMDistribution(freeze=False)

_Note that, for an actual tuning run, you might want to move the networks to GPUs first, for example with:_
```
model.to("cuda")
dc_target.scorers[0].to("cuda")
```

In [None]:
tuner = DPGTuner(model, dc_target,
        context=incipit,
        n_gradient_steps=100,
        n_samples_per_step=2**8,
        sampling_size=2**5,
        scoring_size=2**5,
        divergence_evaluation_interval=1)

_Again, for an actual tuning run, you might want to set up logging, for example with:_
```
from disco.tuners.loggers.wandb import WandBLogger
logger = WandBLogger(tuner, "my_project", "my_run")
```

In [None]:
tuner.tune()

In [None]:
sampler = AccumulationSampler(model, total_size=2**9)
samples, log_scores = sampler.sample(context=incipit)

In [None]:
sum(is_amazing(s, None) for s in samples) / len(samples)