# IMDB movie review text generation

Once you have fine-tuned your model, you can test it interactively with this notebook.

In [None]:
from transformers import pipeline

# Replace YOUR_USERNAME_HERE with your username, and adjust the checkpoint directory if needed
path_to_model = "/scratch/project_462001095/data/users/YOUR_USERNAME_HERE/gpt-imdb-model/checkpoint-5000/"
generator = pipeline("text-generation", model=path_to_model)

In [None]:
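def print_output(output):
    # The pipeline returns a list of dicts, one per generated sequence
    for item in output:
        text = item['generated_text']
        # IMDB reviews contain HTML line breaks; show them as real newlines
        text = text.replace("<br />", "\n")
        print('-', text)
        print()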

In [None]:
output = generator("This movie was")
print_output(output)

## Experiment with the generation strategy

You can experiment with the text generation settings if you wish. The available text generation strategies are discussed here: https://huggingface.co/docs/transformers/generation_strategies

Note that here we use the easy-to-use `TextGenerationPipeline` (the `generator` object, which we call directly), whereas the linked page discusses the `model.generate()` method. The same parameters can be used with both; the pipeline simply takes care of some of the pre- and post-processing, such as tokenizing the prompt and decoding the generated tokens back into text.

In particular, the following parameters of the `generator()` call might be interesting:

- `max_new_tokens`: the maximum number of tokens to generate
- `num_beams`: activate beam search by setting this > 1
- `do_sample`: activate multinomial sampling if set to `True`
- `num_return_sequences`: the number of candidate sequences to return (available only for beam search and sampling; with beam search it cannot exceed `num_beams`)

This blog post explains the different generation strategies in more detail: https://huggingface.co/blog/how-to-generate
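To make the difference between these strategies concrete, here is a minimal pure-Python sketch of greedy search, beam search and multinomial sampling on a hypothetical toy "language model", similar to the example in the blog post. The bigram table and the function names are illustrative assumptions, not part of `transformers`; the real `generate()` works on token IDs and logits from the whole prefix and is considerably more involved.

```python
import math
import random

# Hypothetical toy "language model": next-token probabilities conditioned
# only on the previous token (a bigram table). A real model conditions on
# the whole prefix; this is just enough to show how the strategies differ.
NEXT_TOKEN_PROBS = {
    "The": {"nice": 0.5, "dog": 0.4, "car": 0.1},
    "nice": {"woman": 0.4, "house": 0.3, "guy": 0.3},
    "dog": {"has": 0.9, "runs": 0.05, "and": 0.05},
    "car": {"is": 0.3, "drives": 0.5, "turns": 0.2},
}


def sequence_log_prob(tokens):
    """Sum of log-probabilities of each token given the previous one."""
    return sum(
        math.log(NEXT_TOKEN_PROBS[prev][nxt]) for prev, nxt in zip(tokens, tokens[1:])
    )


def greedy(prompt, steps):
    """Greedy search: always pick the single most likely next token."""
    tokens = [prompt]
    for _ in range(steps):
        probs = NEXT_TOKEN_PROBS[tokens[-1]]
        tokens.append(max(probs, key=probs.get))
    return tokens


def beam_search(prompt, steps, num_beams):
    """Beam search: keep the `num_beams` most likely partial sequences per step."""
    beams = [[prompt]]
    for _ in range(steps):
        # Extend every beam by every possible next token
        candidates = [
            beam + [nxt] for beam in beams for nxt in NEXT_TOKEN_PROBS[beam[-1]]
        ]
        candidates.sort(key=sequence_log_prob, reverse=True)
        beams = candidates[:num_beams]
    return beams[0]


def sample(prompt, steps, rng):
    """Multinomial sampling: draw each next token in proportion to its probability."""
    tokens = [prompt]
    for _ in range(steps):
        probs = NEXT_TOKEN_PROBS[tokens[-1]]
        tokens.append(rng.choices(list(probs), weights=list(probs.values()))[0])
    return tokens


print("greedy:", greedy("The", 2))  # ['The', 'nice', 'woman']
print("beam search:", beam_search("The", 2, num_beams=2))  # ['The', 'dog', 'has']
print("sampling:", sample("The", 2, random.Random(0)))  # depends on the seed
```

Greedy search commits to "nice" (0.5) at the first step and ends with "The nice woman" (0.5 × 0.4 = 0.2), while beam search with `num_beams=2` also keeps "dog" alive and finds the more probable "The dog has" (0.4 × 0.9 = 0.36). Sampling can return any continuation that has non-zero probability.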

In [None]:
output = generator("This movie was awful because", num_return_sequences=1, max_new_tokens=100, do_sample=True)
print_output(output)
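If you want to try several of these settings on the same prompt, a small helper like the following could be used. The preset names and values are arbitrary examples chosen for illustration, and `try_presets` is a hypothetical helper, not part of `transformers`; it only assumes that `generator` is called as `generator(prompt, **kwargs)` and returns a list of dicts with a `generated_text` key, as the pipeline does above.

```python
# Illustrative generation settings to compare on one prompt.
# These are example values, not recommended defaults.
GENERATION_PRESETS = {
    "greedy": {"max_new_tokens": 50, "do_sample": False},
    "beam search": {"max_new_tokens": 50, "num_beams": 5, "num_return_sequences": 3},
    "sampling": {"max_new_tokens": 50, "do_sample": True, "num_return_sequences": 3},
}


def try_presets(generator, prompt, presets=GENERATION_PRESETS):
    """Call `generator` once per preset and return {preset_name: output}.

    `generator` is any callable with the pipeline's calling convention:
    generator(prompt, **kwargs) -> list of {"generated_text": ...} dicts.
    """
    results = {}
    for name, kwargs in presets.items():
        results[name] = generator(prompt, **kwargs)
    return results
```

In this notebook you could then run `results = try_presets(generator, "This movie was")` and pass each `results[name]` to `print_output()`. Greedy search returns a single sequence, which is why only the beam search and sampling presets ask for more than one.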

## Compare with the original model without fine-tuning

We can also load the original pretrained `distilgpt2` model and see how it behaves without fine-tuning, using the same prompt and generation settings as above.

In [None]:
generator_orig = pipeline("text-generation", model='distilgpt2')

In [None]:
output = generator_orig("This movie was awful because", num_return_sequences=1, max_new_tokens=100, do_sample=True)
print_output(output)
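Because `do_sample=True` produces different text on every run, a single pair of outputs says little about the effect of fine-tuning. A possible sketch for a side-by-side comparison over several prompts is shown below. `compare_generators` is a hypothetical helper, and the optional `reset_seed` callable is an assumption meant for something like `lambda: set_seed(42)` (with `from transformers import set_seed`), so that each call is reproducible.

```python
def compare_generators(generators, prompts, reset_seed=None, **generate_kwargs):
    """Run every generator on every prompt with identical settings.

    `generators` maps a label (e.g. "fine-tuned", "original") to a callable
    with the pipeline's calling convention. Returns a list of
    (prompt, label, texts) tuples so the outputs can be read side by side.
    """
    rows = []
    for prompt in prompts:
        for label, gen in generators.items():
            # Optionally reset the random state before each call
            if reset_seed is not None:
                reset_seed()
            output = gen(prompt, **generate_kwargs)
            # Same clean-up as print_output: HTML line breaks -> newlines
            texts = [item["generated_text"].replace("<br />", "\n") for item in output]
            rows.append((prompt, label, texts))
    return rows
```

For example, `compare_generators({"fine-tuned": generator, "original": generator_orig}, ["This movie was awful because"], reset_seed=lambda: set_seed(42), max_new_tokens=100, do_sample=True)` returns one row per model and prompt, which you can then print in whatever layout you prefer.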