<a href="https://colab.research.google.com/github/charlielito/stable-diffusion-videos/blob/fix%2Ftpu_colab_example/flax_stable_diffusion_videos.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Flax Stable Diffusion Videos

This notebook lets you generate videos by interpolating the latent space of [Stable Diffusion](https://github.com/CompVis/stable-diffusion), using a TPU for faster inference.

Compared with the standard Colab GPU, this runs ~6x faster after the first run. The first run takes about as long as the GPU version because the code is compiled during it.

You can either dream up different versions of the same prompt, or morph between different text prompts (with seeds set for each for reproducibility).
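To make "morphing" concrete, here is a minimal sketch of spherical linear interpolation (slerp), a common way to blend two latent or embedding vectors. It is an illustration only: the `slerp` function, its `dot_threshold` parameter, and the toy 2-D vectors are assumptions made for this example, not code taken from the pipeline, and it assumes neither vector is all zeros.

```python
import math

def slerp(t, v0, v1, dot_threshold=0.9995):
    """Spherical linear interpolation between two equal-length, non-zero vectors."""
    norm0 = math.sqrt(sum(x * x for x in v0))
    norm1 = math.sqrt(sum(x * x for x in v1))
    dot = sum(a * b for a, b in zip(v0, v1)) / (norm0 * norm1)
    # Clamp to [-1, 1] to guard against rounding error before acos.
    dot = max(-1.0, min(1.0, dot))
    if abs(dot) > dot_threshold:
        # Nearly parallel vectors: fall back to plain linear interpolation.
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    theta = math.acos(dot)
    s0 = math.sin((1 - t) * theta) / math.sin(theta)
    s1 = math.sin(t * theta) / math.sin(theta)
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]

# Toy stand-ins for the latents of two prompts (hypothetical values).
start = [1.0, 0.0]
end = [0.0, 1.0]

# Five evenly spaced frames from `start` to `end`.
frames = [slerp(i / 4, start, end) for i in range(5)]
```

Each intermediate vector would be decoded into one video frame; fixing the seeds fixes the endpoints, which is why the same prompts and seeds reproduce the same video.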

If you like this notebook:
- consider giving the [repo a star](https://github.com/nateraw/stable-diffusion-videos) ⭐️
- consider following us on GitHub [@nateraw](https://github.com/nateraw) [@charlielito](https://github.com/charlielito)

You can file any issues/feature requests [here](https://github.com/nateraw/stable-diffusion-videos/issues).

Enjoy 🤗

## Setup

In [1]:
#@title Set up JAX
#@markdown If you see an error, make sure you are using a TPU backend. Select `Runtime` in the menu above, then select the option "Change runtime type" and then select `TPU` under the `Hardware accelerator` setting.

# Do not upgrade: the latest JAX version that works with Colab TPUs is 0.3.25
# !pip install --upgrade jax jaxlib

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [2]:
# Install with --no-deps because flax tries to install jax>0.3.16 even though we already have 0.3.25.
# pip would then install jax~=0.4, which is incompatible with Colab TPUs in our code.
!pip install --no-deps flax==0.6.3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax==0.6.3
  Downloading flax-0.6.3-py3-none-any.whl (197 kB)
Installing collected packages: flax
  Attempting uninstall: flax
    Found existing installation: flax 0.6.8
    Uninstalling flax-0.6.8:
      Successfully uninstalled flax-0.6.8
Successfully installed flax-0.6.3


In [3]:
!pip install diffusers==0.12.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting diffusers==0.12.0
  Downloading diffusers-0.12.0-py3-none-any.whl (604 kB)
Collecting huggingface-hub>=0.10.0
  Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)
Installing collected packages: huggingface-hub, diffusers
Successfully installed diffusers-0.12.0 huggingface-hub-0.13.4


In [4]:
%%capture
! pip install stable_diffusion_videos

## Run the App 🚀

### Load the Interface

This step will take a couple of minutes the first time you run it.

In [5]:
import numpy as np
import jax
import jax.numpy as jnp

from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from stable_diffusion_videos import FlaxStableDiffusionWalkPipeline, Interface

pipeline, params = FlaxStableDiffusionWalkPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    revision="bf16", 
    dtype=jnp.bfloat16
)
p_params = replicate(params)

interface = Interface(pipeline, params=p_params)

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
Some of the weights of FlaxStableDiffusionSafetyChecker were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/safety_checker:
[('concept_embeds',), ('concept_embeds_weights',), ('special_care_embeds',), ('special_care_embeds_weights',), ('vision_model', 'vision_model', 'embeddings', 'class_embedding'), ('vision_model', 'vision_model', 'embeddings', 'patch_embedding', 'kernel'), ('vision_model', 'vision_model', 'embeddings', 'position_embedding', 'embedding'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('vision_m

In [None]:
#@title Connect to Google Drive to Save Outputs

#@markdown If you want to connect Google Drive, click the checkbox below and run this cell. You'll be prompted to authenticate.

#@markdown If you just want to save your outputs in this Colab session, don't worry about this cell.

connect_google_drive = True #@param {type:"boolean"}

#@markdown Then, in the interface, use this path as the `output` in the Video tab to save your videos to Google Drive:

#@markdown > /content/gdrive/MyDrive/stable_diffusion_videos


if connect_google_drive:
    from google.colab import drive

    drive.mount('/content/gdrive')
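
If you call `walk` programmatically rather than typing the path into the interface, a small helper can choose the output directory for you. This is a sketch under stated assumptions: `pick_output_dir` is a hypothetical helper (not part of `stable_diffusion_videos`), the Drive path is the one suggested above, and `./dreams` is the local fallback used later in this notebook.

```python
import os

def pick_output_dir(drive_root="/content/gdrive/MyDrive",
                    subdir="stable_diffusion_videos",
                    fallback="./dreams"):
    """Return an output directory, preferring Google Drive when it is mounted."""
    if os.path.isdir(drive_root):
        # Drive is mounted: save under the folder suggested above.
        base = os.path.join(drive_root, subdir)
    else:
        # Drive is not mounted: keep outputs in the Colab session only.
        base = fallback
    os.makedirs(base, exist_ok=True)
    return base

# Example (hypothetical usage): output_dir = pick_output_dir()
```

The returned path could then be passed as `output_dir` to `pipeline.walk`.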

### Launch

This cell launches a Gradio Interface. Here's how I suggest you use it:

1. Use the "Images" tab to generate images you like.
    - Find two images you want to morph between
    - These images should use the same settings (guidance scale, height, width)
    - Keep track of the seeds/settings you used so you can reproduce them

2. Generate videos using the "Videos" tab
    - Using the images you found from the step above, provide the prompts/seeds you recorded
    - Set the `num_interpolation_steps` - for testing you can use a small number like 3 or 5, but to get great results you'll want to use something larger (60-200 steps). 
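
To get a feel for how `num_interpolation_steps` translates into video length, here is a rough calculation. It assumes one frame is rendered per interpolation step, which is what the seconds-to-frames conversion in the music video example later in this notebook implies; `clip_duration_seconds` is a hypothetical helper written for this illustration.

```python
def clip_duration_seconds(num_interpolation_steps, fps):
    """Length of one prompt-to-prompt segment, assuming one frame per step."""
    if fps <= 0:
        raise ValueError("fps must be positive")
    return num_interpolation_steps / fps

# Compare a quick test setting with the larger step counts suggested above.
for steps in (5, 60, 200):
    print(steps, "steps at 30 fps ->", clip_duration_seconds(steps, 30), "s")
```

Under that assumption, small step counts are only useful for smoke tests: a handful of frames at a normal frame rate lasts a fraction of a second.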

💡 **Pro tip** - Click the link below that looks like `https://<id-number>.gradio.app`, and you'll be able to view the interface in full screen.

In [None]:
interface.launch(debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.



---

## Use `walk` programmatically

The other option is to skip the interface and call `walk` programmatically. Here's how you would do that.

First, we define a helper function for displaying videos in Colab.

In [None]:
from IPython.display import HTML
from base64 import b64encode

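def visualize_video_colab(video_path):
    """Embed an MP4 file in the notebook output as a base64 data URL."""
    # Read inside a context manager so the file handle is closed promptly.
    with open(video_path, 'rb') as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML("""
    <video width=400 controls>
        <source src="%s" type="video/mp4">
    </video>
    """ % data_url)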

Walk! 🚶‍♀️

In [None]:
video_path = pipeline.walk(
    p_params,
    ['a cat', 'a dog'],
    [42, 1337],
    fps=5,                      # use 5 for testing, 25 or 30 for better quality
    num_interpolation_steps=30,  # use 3-5 for testing, 30 or more for better results
    height=512,                 # use multiples of 64 if > 512. Multiples of 8 if < 512.
    width=512,                  # use multiples of 64 if > 512. Multiples of 8 if < 512.
    jit=True                    # To use all TPU cores
)
visualize_video_colab(video_path)
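
The `height` and `width` comments in the call above give a sizing rule: multiples of 64 above 512, multiples of 8 below. A minimal sketch of a helper that rounds an arbitrary size down to an allowed value follows; `snap_dimension` is a hypothetical name, and rounding down (rather than to the nearest value) is a choice made here for simplicity.

```python
def snap_dimension(value):
    """Round a height or width down to a size allowed by the rule above."""
    # Multiples of 64 above 512, multiples of 8 otherwise (512 satisfies both).
    step = 64 if value > 512 else 8
    snapped = step * (value // step)
    # Never return 0 for very small inputs.
    return max(step, snapped)

height = snap_dimension(600)  # a request of 600 becomes 576
width = snap_dimension(500)   # a request of 500 becomes 496
```

Larger sizes need more TPU memory, so it may be necessary to lower `batch_size` when going above 512.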

### Bonus! Music videos

First, we'll need to install `youtube-dl`.

In [None]:
%%capture
! pip install youtube-dl

Then, we can download an example music file. Here we download one from my SoundCloud:

In [None]:
! youtube-dl -f bestaudio --extract-audio --audio-format mp3 --audio-quality 0 -o "music/thoughts.%(ext)s" https://soundcloud.com/nateraw/thoughts

In [None]:
from IPython.display import Audio

Audio(filename='music/thoughts.mp3')

In [None]:
# Seconds in the song
audio_offsets = [7, 9]
fps = 8

# Convert seconds to frames
num_interpolation_steps = [(b-a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]

video_path = pipeline.walk(
    p_params,
    prompts=['blueberry spaghetti', 'strawberry spaghetti'],
    seeds=[42, 1337],
    num_interpolation_steps=num_interpolation_steps,
    height=512,                            # use multiples of 64
    width=512,                             # use multiples of 64
    audio_filepath='music/thoughts.mp3',   # Use your own file
    audio_start_sec=audio_offsets[0],      # Start second of the provided audio
    fps=fps,                               # important to set yourself based on the num_interpolation_steps you defined
    batch_size=2,                          # on a TPU v2, typically at most 3 for 512x512
    output_dir='./dreams',                 # Where images will be saved
    name=None,                             # Subdir of output_dir; defaults to a timestamp
    jit=True                               # To use all TPU cores
)
visualize_video_colab(video_path)
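
The seconds-to-frames conversion above generalizes to more than two timestamps. The sketch below wraps the same list comprehension in a hypothetical helper, `offsets_to_steps`, and adds a sanity check; the three-offset input is an invented example, and it presumably needs to be paired with three prompts and three seeds, one per offset.

```python
def offsets_to_steps(audio_offsets, fps):
    """Frames per segment for consecutive audio timestamps given in seconds."""
    pairs = list(zip(audio_offsets, audio_offsets[1:]))
    if any(b <= a for a, b in pairs):
        raise ValueError("audio_offsets must be strictly increasing")
    # Same formula as in the cell above: segment length in seconds times fps.
    return [(b - a) * fps for a, b in pairs]

# Three timestamps define two segments: 7s to 9s and 9s to 14s.
steps = offsets_to_steps([7, 9, 14], fps=8)
print(steps)  # [16, 40]
```

Because each segment's frame count is tied to `fps`, the same `fps` value should be passed to `walk` so the video stays in sync with the audio.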