[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jamescalam/applied-ml-minicourse/blob/main/code/08-populating-index.ipynb)

# 08: Populating the Index

In this chapter we will use everything we have learned so far about the diffusion pipeline, vector search, and cloud storage to build an initial database of *prompt vectors* and their respective images.

<img src="https://github.com/jamescalam/applied-ml-minicourse/raw/main/images/generating-images.png" style="width:80%">

To do all of this we will start by initializing four components:

1. A prompt dataset
2. The `StableDiffusionPipeline`
3. Cloud Storage
4. Pinecone

## Prompt Dataset

To begin we will download the prompt dataset from Hugging Face *Datasets* called `'bartman081523/stable-diffusion-discord-prompts'`. It contains almost *3.9M* prompts. We will not index them all in this example, but feel free to do so if you're feeling *very* patient.

```python
from datasets import load_dataset

prompts = load_dataset(
    'bartman081523/stable-diffusion-discord-prompts',
    split='train'
)
prompts
```



```
Dataset({
    features: ['text'],
    num_rows: 3884798
})
```

## StableDiffusionPipeline

We will rely on the Stable Diffusion pipeline both for creating prompt vectors *and* for generating images. It should be moved to a GPU where possible.

```python
import torch
from diffusers import StableDiffusionPipeline

# set the hardware device (GPU if available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# init all of the pipeline models and move them to the selected device
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token="<<ACCESS_TOKEN>>"
)
pipe.to(device)
print(device)
```

## Connect to Cloud Storage

Next we need to connect to our *Cloud Storage* instance. As before we do this using the `cloud-storage.json` credentials.

```python
import os

# set credentials
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
```

And connect to the `diffusion` bucket:

```python
from google.cloud import storage

# connect to bucket (we named it 'diffusion')
storage_client = storage.Client()
bucket = storage_client.get_bucket('diffusion')
```

## Connect to Pinecone

We must also initialize our connection to Pinecone and create a vector index to store the *prompt vectors*. First we initialize a connection with our [API key](https://app.pinecone.io/).

```python
import pinecone

pinecone.init(
    api_key='<<YOUR_API_KEY>>',  # app.pinecone.io
    environment='us-west1-gcp'
)
```

Then create a new index with the dimensionality of CLIP's pooled text embedding, using the cosine similarity metric.
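
Setting `metric='cosine'` means stored vectors are ranked by cosine similarity to the query vector, a score that depends only on the angle between two vectors and not on their length. A minimal pure-Python sketch of the score itself, using made-up 3-dimensional toy vectors (this is not Pinecone's internal implementation):

```python
import math

def cosine_similarity(a: list, b: list) -> float:
    """Cosine similarity: dot product divided by the product of the vector norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# toy 3-d vectors (real prompt vectors have many more dimensions)
query = [1.0, 0.0, 1.0]
same_direction = [2.0, 0.0, 2.0]  # same direction, twice the length
orthogonal = [0.0, 1.0, 0.0]      # shares no direction with the query

print(cosine_similarity(query, same_direction))  # very close to 1.0
print(cosine_similarity(query, orthogonal))      # 0.0
```

Because length is ignored, a prompt vector and a scaled copy of it score as (almost exactly) identical.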

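```python
# dimensionality of the CLIP text encoder's pooled output (768 for SD v1.4)
dim = pipe.text_encoder.config.hidden_size
print(dim)

index_name = 'diffusion'

# create index (only if it does not already exist)
if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        name=index_name,
        dimension=dim,
        metric='cosine'
    )

# connect to index
index = pinecone.Index(index_name)

# view index stats
index.describe_index_stats()
```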

We should see that the index is currently empty as we have not added anything yet. Let's take a look at adding a single item before we move on to doing this for a larger number of samples.

## Adding a Record

When adding a record we will perform a few steps:

1. Generate *prompt vector*
2. Generate image
3. Create unique ID shared between the prompt vector and image
4. Upload image to Cloud Storage
5. Insert prompt vector to Pinecone

These five steps will be repeated for every prompt. Step **2** is very time consuming, which is why we will not do this for all 3.9M+ records.
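
To make the data flow of these five steps concrete, here is a minimal sketch that replaces every external service with an in-memory stand-in. `fake_embed`, `fake_generate`, `fake_bucket`, and `fake_index` are hypothetical placeholders, not real APIs; the point is only how one generated ID ties the stored vector to the stored image:

```python
import uuid

# Hypothetical in-memory stand-ins for the Cloud Storage bucket and the Pinecone index
fake_bucket = {}  # object path -> image bytes
fake_index = {}   # record ID -> (vector, metadata)

def fake_embed(prompt: str) -> list:
    # placeholder for the tokenizer + CLIP text encoder: a tiny toy "vector"
    return [float(len(prompt)), float(prompt.count(" "))]

def fake_generate(prompt: str) -> bytes:
    # placeholder for the diffusion pipeline: dummy "image" bytes
    return f"image for: {prompt}".encode()

def add_record(prompt: str) -> str:
    vec = fake_embed(prompt)                  # 1. generate prompt vector
    image = fake_generate(prompt)             # 2. generate image
    _id = str(uuid.uuid4())                   # 3. unique ID shared by both
    path = f"images/{_id}.png"
    fake_bucket[path] = image                 # 4. "upload" image under the ID
    fake_index[_id] = (vec, {"prompt": prompt, "image_url": path})  # 5. "upsert" vector
    return _id

record_id = add_record("a person surfing")
vec, meta = fake_index[record_id]
print(vec)                               # [16.0, 2.0]
print(meta["image_url"] in fake_bucket)  # True: the shared ID links vector and image
```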

### Generating a Prompt Vector

Starting with the prompt vector, we generate it using the first two components of our `StableDiffusionPipeline`: the tokenizer and the CLIP text encoder. The encoder's pooled output becomes our prompt vector:

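```python
prompt = "a person surfing"

# tokenize the prompt and move the tensors to the same device as the model
tokens = pipe.tokenizer(
    prompt, padding='max_length',
    truncation=True, return_tensors='pt'
).to(device)
# encode to CLIP's pooled text embedding; [0] drops the batch dimension
with torch.no_grad():
    vec = pipe.text_encoder(**tokens).pooler_output[0].cpu().tolist()
```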
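
CLIP's pooled text embedding (`pooler_output` in Hugging Face `transformers`, also used by `embed_and_diffuse` later in this chapter) is, as we understand the implementation, the last-layer hidden state at each sequence's end-of-text token, located by taking the argmax over the token IDs (end-of-text has the largest ID in CLIP's vocabulary). That is why every prompt yields exactly one vector regardless of its length. A toy plain-Python sketch of that selection, with made-up word-token IDs and hidden states:

```python
# Toy illustration of CLIP-style pooling; all hidden-state numbers are made up.
EOT_ID = 49407  # end-of-text token, the largest ID in CLIP's vocabulary

# start token (49406), word tokens (made-up IDs), end-of-text, then padding (0 in this toy)
input_ids = [
    [49406, 10, 20, EOT_ID, 0],
    [49406, 10, 20, 30, EOT_ID],
]
# made-up last-layer hidden states: one 2-d vector per token
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [0.9, 1.0]],
    [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8], [1.9, 2.0]],
]

def pool_at_eot(ids: list, states: list) -> list:
    # position of the largest token ID; max() returns the first position on ties
    pos = max(range(len(ids)), key=lambda i: ids[i])
    return states[pos]

pooled = [pool_at_eot(ids, states) for ids, states in zip(input_ids, hidden)]
print(pooled)  # [[0.7, 0.8], [1.9, 2.0]]
```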

### Generating Image

To generate the equivalent image for our prompt vector, we simply run the prompt through the full pipeline.

```python
out = pipe(prompt)
out
```

And view the image like so:

```python
out.images[0]
```

### Unique ID

The unique ID is shared by the prompt vector stored in Pinecone and the image stored in Cloud Storage, linking the two. We use the `uuid` library to create it:

```python
import uuid

_id = str(uuid.uuid4())
_id
```

```
'c5280325-206f-44d8-a9a4-0b8cdf9fd2cd'
```

### Upload to Cloud Storage

Next we use the unique ID to upload the generated image to GCP Cloud Storage. First we save the image to file:

```python
out.images[0].save('tmp.png')
```

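Then upload it using the unique ID as the filename. We store it under an `images/` prefix, the same layout used when building the full database below:

```python
blob = bucket.blob(f'images/{_id}.png')
blob.upload_from_filename('tmp.png')
```

### Insert to Pinecone

The final step is inserting the prompt vector and any relevant metadata into Pinecone. Every record in Pinecone requires an ID and a vector, and can *optionally* include a metadata dictionary. Here the metadata holds the prompt and the image's path in the bucket:

```python
# create metadata: the prompt and where its image lives in Cloud Storage
metadata = {
    "prompt": prompt,
    "image_url": f"images/{_id}.png"
}
# format as an (id, vector, metadata) tuple
to_upsert = (_id, vec, metadata)
# upsert to index
index.upsert([to_upsert])

# view index stats
index.describe_index_stats()
```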

We should now see that a single item exists within the index.

All we need to do now is repeat this process for many records in our prompts dataset.
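
Since each record is just an `(id, vector, metadata)` tuple, a cheap sanity check before upserting can catch shape mistakes such as a vector that still carries its batch dimension (`[[...]]` instead of `[...]`). A sketch of such a check; `find_problems` is a hypothetical helper, not part of the Pinecone client:

```python
def find_problems(records: list, dim: int) -> list:
    """Hypothetical pre-upsert check for (id, vector, metadata) tuples.

    Returns a list of human-readable problems; an empty list means all records look OK.
    """
    problems = []
    for _id, vec, meta in records:
        if not isinstance(_id, str):
            problems.append(f"{_id!r}: ID must be a string")
        if len(vec) != dim or not all(isinstance(x, (int, float)) for x in vec):
            problems.append(f"{_id}: expected a flat list of {dim} numbers")
        if not isinstance(meta, dict):
            problems.append(f"{_id}: metadata must be a dict")
    return problems

good = [("id-1", [0.1, 0.2, 0.3], {"prompt": "a person surfing"})]
# the batch dimension was not removed, so the vector is nested
nested = [("id-2", [[0.1, 0.2, 0.3]], {"prompt": "a person surfing"})]

print(find_problems(good, dim=3))    # []
print(find_problems(nested, dim=3))  # ['id-2: expected a flat list of 3 numbers']
```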

## Building our Database

We will repeat the above steps and batch prompts together where possible.

Before starting, we will trim down the prompts dataset. Many short prompts are nonsensical or not very interesting, so we filter those out first:

```python
prompts = prompts.filter(
    lambda x: len(x['text']) > 30
)
prompts
```


```
Dataset({
    features: ['text'],
    num_rows: 3582032
})
```

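And remove duplicates. We use `dict.fromkeys` rather than `set` so that the prompts keep their original order:

```python
# dict keys are unique and preserve insertion order
prompts = list(dict.fromkeys(prompts['text']))
len(prompts)
```

```
3546559
```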
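
A Python `set` removes duplicates but does not preserve order, and string hashing is randomized between interpreter runs, so taking the first N items of `list(set(...))` can select a different subset each time. A toy sketch of the filter-then-deduplicate step on made-up prompts, where `dict.fromkeys` keeps the first occurrence of each prompt in its original position:

```python
raw = [
    "cat",                                                        # too short, dropped
    "a watercolor painting of a lighthouse at dusk, trending",
    "an astronaut riding a horse on mars, highly detailed art",
    "a watercolor painting of a lighthouse at dusk, trending",   # duplicate, dropped
]

MIN_CHARS = 30  # same length threshold as the dataset filter

# keep prompts longer than the threshold
long_enough = [p for p in raw if len(p) > MIN_CHARS]
# dict keys are unique and keep insertion order, unlike a set
unique = list(dict.fromkeys(long_enough))
print(unique)
```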

We can build a reasonable database of images with just 10-20K of these prompts. Naturally the more the merrier, but to avoid finishing this minicourse next year let's go with a smaller number of `10000` (go for fewer if you'd rather not wait).

```python
prompts = prompts[:10000]
```

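Next, we define a function called `embed_and_diffuse` that generates the images and prompt vectors for a batch of prompts. If the safety checker flags any image in the batch as NSFW, it returns an empty dict so the whole batch can be skipped: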

```python
def embed_and_diffuse(prompts: list) -> dict:
    # __diffuse images__
    out = pipe(prompts)
    # skip the whole batch if the safety checker flagged any image
    if any(out.nsfw_content_detected):
        return {}
    # __create text embeddings__
    text_inputs = pipe.tokenizer(
        prompts, padding=True,
        truncation=True, return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    # get pooled embeddings, move to CPU and convert to list for pinecone
    text_embeds = text_embeds.pooler_output.cpu().tolist()
    return {
        'text_embeds': text_embeds,
        'images': out.images
    }
```
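
`embed_and_diffuse` discards every prompt in a batch when any single image is flagged. If losing fewer prompts matters more than simplicity, one option is to drop only the flagged items while keeping prompts, images, and embeddings aligned. A sketch with a hypothetical `drop_flagged` helper and toy data, assuming the NSFW flags arrive as one boolean per image:

```python
def drop_flagged(prompts: list, images: list, embeds: list, nsfw_flags: list):
    """Hypothetical alternative: keep only the items whose NSFW flag is False."""
    kept = [
        (p, img, emb)
        for p, img, emb, flagged in zip(prompts, images, embeds, nsfw_flags)
        if not flagged
    ]
    # unzip back into three aligned lists
    return (
        [p for p, _, _ in kept],
        [img for _, img, _ in kept],
        [emb for _, _, emb in kept],
    )

# toy batch of three where the middle item was flagged
p, imgs, embs = drop_flagged(
    ["prompt a", "prompt b", "prompt c"],
    ["img a", "img b", "img c"],
    [[0.1], [0.2], [0.3]],
    [False, True, False],
)
print(p)  # ['prompt a', 'prompt c']
```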

Now iterate through and populate the database.

```python
from tqdm.auto import tqdm

batch_size = 8  # we will run through in batches

for i in tqdm(range(0, len(prompts), batch_size)):
    i_end = min(len(prompts), i+batch_size)
    # get batch of prompts (without overwriting the full `prompts` list)
    batch = prompts[i:i_end]
    data = embed_and_diffuse(batch)
    if not data:
        # nsfw content detected so skip
        continue
    # create batch of ids
    ids = [str(uuid.uuid4()) for _ in range(len(batch))]
    # add images to cloud storage
    for _id, image in zip(ids, data['images']):
        image.save('tmp.png', format='png')
        # push to cloud storage
        blob = bucket.blob(f'images/{_id}.png')
        blob.upload_from_filename('tmp.png')
    # create metadata
    meta = [{
        'prompt': prompt,
        'image_url': f'images/{_id}.png'
    } for _id, prompt in zip(ids, batch)]
    # add to pinecone
    index.upsert(list(zip(ids, data['text_embeds'], meta)))
```
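
The loop slices `prompts` into batches. Pulling that slicing into a small generator gives each batch its own name, so the full list cannot be accidentally reassigned mid-loop. A minimal sketch with a hypothetical `batches` helper and toy data; with 10,000 prompts and `batch_size = 8`, the loop runs 1,250 times:

```python
import math

def batches(items: list, batch_size: int):
    """Yield consecutive slices of `items` without modifying the original list."""
    for i in range(0, len(items), batch_size):
        # slicing past the end is safe, so the last batch may simply be shorter
        yield items[i:i + batch_size]

toy_prompts = [f"prompt {n}" for n in range(20)]
chunks = list(batches(toy_prompts, 8))
print([len(c) for c in chunks])  # [8, 8, 4]

# number of batches for the 10,000 prompts used in this chapter
print(math.ceil(10_000 / 8))  # 1250
```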

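Once this has finished running, the index should contain up to 10,001 prompt vectors, each with a matching image in Cloud Storage (the +1 is the first `"a person surfing"` record). Expect fewer if any batches were skipped because the safety checker flagged NSFW content.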

If the above is taking too long to run, feel free to stop the execution and continue with the course. 10K images is not a prerequisite for the remainder of the course.