# Textual Inversion with Determined

This notebook generates images from textual inversion models trained with the `detsd.DetSDTextualInversionTrainer` class and saved as Determined checkpoints. It requires a GPU.

### Pre-Launch Setup

A [Hugging Face User Access Token](https://huggingface.co/docs/hub/security-tokens) is required to download the [Stable Diffusion weights](https://huggingface.co/CompVis/stable-diffusion-v1-4). To use this notebook, replace the placeholder in the following lines of the `detsd-notebook.yaml` file with your token:
```yaml
environment:
    environment_variables:
        - HF_AUTH_TOKEN=YOUR_HF_AUTH_TOKEN_HERE
```
Then launch the notebook from the root of the repo directory by running either
```bash
det -m MASTER_URL_WITH_PORT notebook start --config-file detsd-notebook.yaml --context .
```
to load the entire repo into the JupyterLab instance, or 
```bash
det -m MASTER_URL_WITH_PORT notebook start --config-file detsd-notebook.yaml -i detsd -i startup-hook.sh -i learned_embeddings_dict_demo.pt -i textual_inversion.ipynb
```
to load only the files relevant to this notebook. Either way, the uploaded copy of `textual_inversion.ipynb` can then be opened and run in JupyterLab.
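Once the notebook is running, a check like the one below can confirm that the token set in `detsd-notebook.yaml` actually reached the environment. This is a minimal sketch: `check_hf_token` is a hypothetical helper, not part of `detsd`, and it only inspects the `HF_AUTH_TOKEN` environment variable without contacting the Hugging Face Hub.

```python
import os


def check_hf_token(var_name: str = "HF_AUTH_TOKEN") -> str:
    """Return the Hugging Face token from the environment, or raise if unusable.

    Hypothetical helper: it only checks that the variable is set and was changed
    from the YAML placeholder; it does not validate the token with the Hub.
    """
    token = os.environ.get(var_name, "").strip()
    if not token:
        raise RuntimeError(f"{var_name} is not set; check detsd-notebook.yaml")
    if token == "YOUR_HF_AUTH_TOKEN_HERE":
        raise RuntimeError(f"{var_name} still holds the placeholder value")
    return token
```

Calling `check_hf_token()` in a cell before creating the pipeline surfaces a missing token early, rather than as a download error later.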

Update the Jupyter widget packages for better progress-bar rendering (a reload may be required for the change to take effect). If the notebook was launched with one of the commands above, all other dependencies were already installed at agent start-up by the repo's `startup-hook.sh` script.

In [None]:
! pip install -qq jupyterlab-widgets==1.1.1 ipywidgets==7.7.2

## Creating the Pipeline

Import the `DetSDTextualInversionPipeline` class from `detsd.py` (loaded via the `--context` or `-i detsd` flag above), which is used to generate Stable Diffusion images.

In [None]:
import torch

from detsd import DetSDTextualInversionPipeline

Instantiate the pipeline with the default arguments:

In [None]:
detsd_pipeline = DetSDTextualInversionPipeline()

Note: `DetSDTextualInversionPipeline` is initialized with `use_fp16=True` by default, which increases inference speed and reduces memory usage at the cost of somewhat lower image quality. All available arguments can be viewed by uncommenting and running the cell below.

In [None]:
# ? DetSDTextualInversionPipeline

## Load Determined Checkpoints

We can now load textual inversion checkpoints into the model. These are assumed to have been trained with `DetSDTextualInversionTrainer`, also contained in `detsd.py`. The checkpoints are specified by their UUIDs, and all of them must exist on the master you are currently logged into.

In [None]:
# Code for logging into the master, if not already logged in.
# Not required if notebook was launched as described above.

# from determined.experimental import client
# client.login(master=MASTER_URL_WITH_PORT, user=USER, password=PASS)

Fill in the `uuids` list below with the UUID strings of any Determined checkpoints you wish to incorporate into the model.

In [None]:
uuids = []
detsd_pipeline.load_from_uuids(uuids)
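Typos in hand-copied UUIDs only surface once the pipeline asks the master for the checkpoint. A small local check can catch them earlier. This sketch assumes Determined checkpoint UUIDs use the standard 8-4-4-4-12 hexadecimal form; `normalize_checkpoint_uuids` is a hypothetical helper built on Python's `uuid` module and does not verify that the checkpoints exist.

```python
import uuid


def normalize_checkpoint_uuids(raw_uuids):
    """Strip whitespace, validate each entry as a UUID, and return canonical strings.

    Hypothetical helper: assumes checkpoint UUIDs are standard hyphenated UUIDs.
    The canonical form returned by str(uuid.UUID(...)) is lower-case and hyphenated.
    """
    cleaned = []
    for raw in raw_uuids:
        candidate = raw.strip()
        try:
            parsed = uuid.UUID(candidate)
        except ValueError as err:
            raise ValueError(f"not a valid UUID: {raw!r}") from err
        cleaned.append(str(parsed))
    return cleaned
```

Usage: `detsd_pipeline.load_from_uuids(normalize_checkpoint_uuids(uuids))`.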

A sample embedding (with corresponding concept token `det-logo-demo`) is also included in the repo and can be loaded as follows, provided `learned_embeddings_dict_demo.pt` was uploaded with the notebook (via either `--context .` or `-i learned_embeddings_dict_demo.pt`):

In [None]:
from os.path import exists

demo_concept_path = 'learned_embeddings_dict_demo.pt'
# Load the demo embedding only if the file was uploaded with the notebook.
if exists(demo_concept_path):
    detsd_pipeline.load_from_checkpoint_dir(checkpoint_dir='.', learned_embeddings_filename=demo_concept_path)

## Generate Images

Finally, let's generate some art.

Grab the first concept loaded into the pipeline and store it as `first_concept`. If no concepts were loaded above, `first_concept` falls back to the default value `brain logo, sharp lines, connected circles, concept art`; in that case, the pipeline is simply vanilla Stable Diffusion.

In [None]:
all_added_concepts = detsd_pipeline.all_added_concepts
if all_added_concepts:
    first_concept = all_added_concepts[0]
else:
    first_concept = 'brain logo, sharp lines, connected circles, concept art'
print(f'Using "{first_concept}" as first_concept in the below\n')
print(f'All available concepts: {all_added_concepts}')

Create a directory for saved images and a counter to track the number of images created.

In [None]:
save_dir = 'generated_images'
# -p: don't fail if the directory already exists (e.g. when re-running this cell).
! mkdir -p {save_dir}
num_generated_images = 0
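A pure-Python alternative to the shell cell above is sketched below. Besides creating the directory, it resumes numbering from any existing files so that re-running the cells does not overwrite earlier images. It assumes images are named `N.png`, as in the save loop further down; `prepare_save_dir` is a hypothetical helper, not part of `detsd`.

```python
import re
from pathlib import Path


def prepare_save_dir(save_dir: str) -> int:
    """Create save_dir if needed and return the largest N among existing N.png files.

    Returns 0 for an empty directory. Because the save loop increments the
    counter before writing, starting from this value avoids overwriting files.
    """
    path = Path(save_dir)
    path.mkdir(parents=True, exist_ok=True)
    indices = []
    for file in path.glob("*.png"):
        # Only count files named exactly like "12.png"; ignore anything else.
        match = re.fullmatch(r"(\d+)\.png", file.name)
        if match:
            indices.append(int(match.group(1)))
    return max(indices, default=0)
```

Usage: `num_generated_images = prepare_save_dir(save_dir)`.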

The code below creates `batch_size * num_images_per_prompt` images in total from the prompt.

If you are generating with the demo embedding (i.e. `det-logo-demo` is `first_concept`), we recommend a relatively low guidance scale, around 3, instead of the 7.5 used below.

In [None]:
prompt = f'a watercolor painting on textured paper of a {first_concept} using soft strokes, pastel colors, incredible composition, masterpiece'
batch_size = 2
num_images_per_prompt = 2

generator = torch.Generator(device='cuda').manual_seed(2147483647)
output = detsd_pipeline(prompt=[prompt] * batch_size,
                        num_images_per_prompt=num_images_per_prompt,
                        num_inference_steps=50,
                        generator=generator,
                        guidance_scale=7.5
                       )
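To make the image count concrete: with `batch_size = 2` and `num_images_per_prompt = 2`, `output.images` holds 2 × 2 = 4 images. The sketch below maps each position in that list back to its prompt. It assumes the diffusers convention of placing all images for one prompt next to each other; `image_layout` is a hypothetical helper, and the ordering should be checked against your diffusers version before relying on it.

```python
def image_layout(batch_size: int, num_images_per_prompt: int):
    """Return (prompt_index, sample_index) for each position in output.images.

    Assumes images are grouped by prompt: all samples for prompt 0 first,
    then all samples for prompt 1, and so on.
    """
    total = batch_size * num_images_per_prompt
    return [(i // num_images_per_prompt, i % num_images_per_prompt) for i in range(total)]
```

Here `image_layout(2, 2)` gives `[(0, 0), (0, 1), (1, 0), (1, 1)]`, and `len(output.images) == batch_size * num_images_per_prompt` should hold.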

Visualize and save:

In [None]:
from pathlib import Path

for img, nsfw in zip(output.images, output.nsfw_content_detected):
    # Skip black images which are made when NSFW is detected.
    if not nsfw:
        num_generated_images += 1
        display(img)
        img.save(Path(save_dir).joinpath(f'{num_generated_images}.png'))

Explanation of some of the arguments above:
* `num_inference_steps`: the number of denoising steps to run during generation; ~50 is typical.
* `guidance_scale`: controls how strongly the prompt steers generation. 7.5 is the default, and larger values lead to stronger adherence to the prompt.
* `generator`: pass in a fixed `torch.Generator` instance for reproducible results.
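For intuition about `guidance_scale`: Stable Diffusion uses classifier-free guidance, where each denoising step predicts the noise both with and without the prompt and then extrapolates from the unconditional prediction toward the conditional one. The sketch below applies that combination to plain Python lists standing in for the noise tensors. It illustrates the formula only; `apply_guidance` is a hypothetical name, not the pipeline's own code.

```python
def apply_guidance(noise_uncond, noise_cond, guidance_scale):
    """Classifier-free guidance, elementwise: uncond + scale * (cond - uncond)."""
    if len(noise_uncond) != len(noise_cond):
        raise ValueError("predictions must have the same length")
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_cond)]
```

A scale of 1 returns the conditional prediction unchanged, and larger scales push the result further in the prompt's direction, which is why higher values track the prompt more closely. In diffusers' `StableDiffusionPipeline`, guidance is only applied when `guidance_scale > 1`.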

`DetSDTextualInversionPipeline`'s `__call__` method accepts the same arguments as its underlying Hugging Face `StableDiffusionPipeline` instance; see the [Hugging Face documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline.__call__) for information on all available arguments.