# $ \text{Video Killed The Radio Star}$ $\color{red}{...Diffusion}$

Notebook by David Marx ([@DigThatData](https://twitter.com/digthatdata))

Shared under MIT license


# $\text{FAQ}$

**What is this?**

Point this notebook at a YouTube URL and it'll make a music video for you.

**How does this animation technique work?**

For each text prompt you provide, the notebook will...

1. Generate an image based on that text prompt (using stable diffusion)
2. Use the generated image as the `init_image` to recombine with the text prompt to generate variations similar to the first image. This produces a sequence of extremely similar images based on the original text prompt
3. Images are then intelligently reordered to find the smoothest animation sequence of those frames
4. This image sequence is then repeated to pad out the animation duration as needed

The technique demonstrated in this notebook was inspired by a [video](https://www.youtube.com/watch?v=WJaxFbdjm8c) created by Ben Gillin.
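
To make steps 3 and 4 above concrete, here is a minimal sketch of reordering frames for smoothness and then repeating the sequence to fill a duration. It is not the implementation this notebook uses (that lives in `vktrs.tsp`): the greedy nearest-neighbour ordering below is a simpler stand-in, and `greedy_order`, `pad_sequence`, and the toy two-number "frames" are hypothetical names and data chosen for illustration.

```python
from itertools import cycle, islice
import math


def greedy_order(frames):
    """Return frame indices so that each frame is followed by its
    nearest not-yet-used neighbour (a cheap stand-in for a TSP solve)."""
    if not frames:
        return []
    remaining = list(range(1, len(frames)))
    order = [0]  # always start from the first frame
    while remaining:
        last = frames[order[-1]]
        nearest = min(remaining, key=lambda j: math.dist(last, frames[j]))
        remaining.remove(nearest)
        order.append(nearest)
    return order


def pad_sequence(order, n_frames):
    """Repeat the ordered sequence until it is n_frames long."""
    return list(islice(cycle(order), n_frames))


# Toy "frames": each frame is summarized by a small feature vector.
frames = [[0.0, 0.0], [5.0, 5.0], [1.0, 0.0], [4.0, 5.0]]
order = greedy_order(frames)
padded = pad_sequence(order, 10)
print(order)   # [0, 2, 3, 1]
print(padded)  # [0, 2, 3, 1, 0, 2, 3, 1, 0, 2]
```

With real images you would compare pixel or embedding distances instead of these hand-made vectors; the idea of "visit similar frames consecutively, then loop" stays the same.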

**How are lyrics transcribed?**

This notebook uses OpenAI's recently released 'whisper' model for automatic speech recognition.
OpenAI was kind enough to offer several sizes of this model, each with its own pros and cons.
This notebook uses the largest whisper model to transcribe the actual lyrics, and the
smallest model to segment them. Neither model is perfect, but the results
so far seem pretty decent.

The first draft of this notebook relied on subtitles from youtube videos to determine timing, which was
then aligned with user-provided lyrics. Youtube's automated captions are powerful and I'll update the
notebook shortly to leverage those again, but for the time being we're just using whisper for everything
and not referencing user-provided captions at all.
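
One simple way to turn word-level timestamps into lyric lines is to start a new line wherever the singer pauses. The sketch below assumes whisper was run with word timestamps, so each word record carries `word`, `start`, and `end` keys; the `split_on_pauses` helper, the `max_gap` threshold, and the sample words are hypothetical and only meant to illustrate the idea.

```python
def split_on_pauses(words, max_gap=0.5):
    """Group word records into lines, starting a new line whenever the
    silence before a word is longer than `max_gap` seconds."""
    lines, current = [], [words[0]]
    for prev, word in zip(words, words[1:]):
        if word["start"] - prev["end"] > max_gap:
            lines.append(current)
            current = [word]
        else:
            current.append(word)
    lines.append(current)
    # Collapse each group of words into one record with its own timing.
    return [
        {
            "start": line[0]["start"],
            "end": line[-1]["end"],
            "text": "".join(w["word"] for w in line).strip(),
        }
        for line in lines
    ]


# Hand-made word timings (seconds); whisper prefixes words with a space.
words = [
    {"word": " city", "start": 0.0, "end": 0.4},
    {"word": " lights", "start": 0.5, "end": 1.0},
    {"word": " fading", "start": 2.2, "end": 2.8},
    {"word": " out", "start": 2.9, "end": 3.2},
]
segments = split_on_pauses(words)
for seg in segments:
    print(seg)
```

A smaller `max_gap` produces more, shorter lines (and therefore more scene changes in the video); a larger one keeps longer phrases together.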

**Something didn't work quite right in the transcription process. How do I fix the timing or the actual lyrics?**

The notebook is divided into several steps. Between each step, a "storyboard" file is updated. If you want to
make modifications, you can edit this file directly and those edits should be reflected when you next load the
file. Depending on what you changed and what step you run next, your changes may be ignored or even overwritten.
Still playing with different solutions here.

**Can I provide my own images to 'bring to life' and associate with certain lyrics/sequences?**

Yes, you can! As described above: you just need to modify the storyboard. Will describe this functionality in
greater detail after the implementation stabilizes a bit more.

**This gave me an idea and I'd like to use just a part of your process here. What's the best way to reuse just some of the machinery you've developed here?**

Most of the functionality in this notebook has been offloaded to a library I published to PyPI called `vktrs`. I strongly encourage you to import anything you need 
from there rather than cutting and pasting functions into a notebook. Similarly, if you have ideas for improvements, please don't hesitate to submit a PR!

**How can I support your work or work like it?**

This notebook was made possible thanks to ongoing support from [stability.ai](https://stability.ai/). The best way to support my work is to share it with your friends, [report bugs](https://github.com/dmarx/video-killed-the-radio-star/issues/new), [suggest features](https://github.com/dmarx/video-killed-the-radio-star/discussions) or to donate to open source non-profits :) 

## $0.$ Setup

In [None]:
# @title # 📊 Check GPU Status

try:
    from vktrs.utils import gpu_info
except ImportError:
    import pandas as pd
    import subprocess
    
    def gpu_info():
        outv = subprocess.run([
            'nvidia-smi',
                # these lines concatenate into a single query string
                '--query-gpu='
                'timestamp,'
                'name,'
                'utilization.gpu,'
                'utilization.memory,'
                'memory.used,'
                'memory.free,'
                ,
            '--format=csv'
            ],
            stdout=subprocess.PIPE).stdout.decode('utf-8')

        header, rec = outv.split('\n')[:-1]
        return pd.DataFrame({' '.join(k.strip().split('.')).capitalize():v for k,v in zip(header.split(','), rec.split(','))}, index=[0]).T

gpu_info()

In [None]:
#%%capture

local = True

if not local:
    # @title # 🛠️ Installations
    !apt-get install cargo
    !pip install vktrs[api,hf]

    #!pip install git+https://github.com/openai/whisper@v20230314
    !pip install openai-whisper


    # these are only needed for hf
    !pip install "ipywidgets>=7,<8"
    !sudo apt -qq install git-lfs
    !git config --global credential.helper store

    !pip install panel prefetch_generator

    !pip uninstall -y protobuf
    !pip install protobuf==4.22.1

    # FML... sorry y'all...
    !pip install omegaconf huggingface_hub

In [None]:
# @title # 🔑 Provide your API Key
# @markdown Running this cell will prompt you to enter your API Key below. 

# @markdown To get your API key, visit https://beta.dreamstudio.ai/membership

# @markdown ---

# @markdown A note on security best practices: **don't publish your API key.**

# @markdown We're using a form field designed for sensitive data like passwords.
# @markdown This notebook does not save your API key in the notebook itself,
# @markdown but instead loads your API Key into the colab environment. This way,
# @markdown you can make changes to this notebook and share it without concern
# @markdown that you might accidentally share your API Key. 
# @markdown 

use_stability_api = False # @param {type:'boolean'}
mount_gdrive = True # @param {type:'boolean'}

if local:
    mount_gdrive=False

import os
from pathlib import Path
import time

from omegaconf import OmegaConf


os.environ['XDG_CACHE_HOME'] = os.environ.get(
    'XDG_CACHE_HOME',
    str(Path('~/.cache').expanduser())
)
if mount_gdrive:
    from google.colab import drive
    drive.mount('/content/drive')
    Path('/content/drive/MyDrive/AI/models/.cache/').mkdir(parents=True, exist_ok=True) 
    # This rm+ln solution is not great. Be careful not to run this locally. 
    # Low risk, but could be annoying    
    !rm -rf /root/.cache
    !ln -sf /content/drive/MyDrive/AI/models/.cache/ /root/
    # Following line will be sufficient pending merge of https://github.com/openai/whisper/pull/257
    os.environ['XDG_CACHE_HOME']='/content/drive/MyDrive/AI/models/.cache'

model_dir_str=str(Path(os.environ['XDG_CACHE_HOME']))
proj_root_str = '${active_project}'
if mount_gdrive:
    proj_root_str = '/content/drive/MyDrive/AI/VideoKilledTheRadioStar/${active_project}'


# notebook config
cfg = OmegaConf.create({
    'active_project':str(time.time()),
    'project_root':proj_root_str,
    'gdrive_mounted':mount_gdrive,
    'use_stability_api':use_stability_api,
    'model_dir':model_dir_str,
    'output_dir':'${active_project}/frames'
})

with open('config.yaml','w') as fp:
    OmegaConf.save(config=cfg, f=fp.name)

###################

if use_stability_api:
    import os, getpass
    os.environ['STABILITY_KEY'] = getpass.getpass('Enter your API Key')
else:
    try:
        from google.colab import output
        output.enable_custom_widget_manager()
    except ImportError:
        # assume local use
        pass
    
    from huggingface_hub import notebook_login

    # to do: if gdrive mounted, check for API token... somewhere on drive?
    # looks like we should be able to find the token through an environment variable
    notebook_login()


In [None]:
from huggingface_hub import notebook_login

# to do: if gdrive mounted, check for API token... somewhere on drive?
# looks like we should be able to find the token through an environment variable
notebook_login()

## $1.$ 📋 Set Project Name (create/resume)

In [None]:

import time
from vktrs.utils import sanitize_folder_name
from omegaconf import OmegaConf

project_name = 'importance_sampling' # @param {type:'string'}
if not project_name:
    project_name = str(time.time())

project_name = sanitize_folder_name(project_name)

workspace = OmegaConf.load('config.yaml')
workspace.active_project = project_name

with open('config.yaml','w') as fp:
    OmegaConf.save(config=workspace, f=fp.name)

# @markdown To create a new project, enter a unique project name.
# @markdown If you leave `project_name` blank, the current unix timestamp will be used
# @markdown  (seconds since 1970-01-01 00:00).

# @markdown If you use the name of an existing project, the workspace will switch to that project.

# @markdown Non-alphanumeric characters (excluding '-' and '_') will be replaced with hyphens.


# reset workspace
if 'df' in locals():
    del df
if 'df_regen' in locals():
    del df_regen

## $2.$ 🔊 Infer speech from audio

In [None]:
from omegaconf import OmegaConf
from pathlib import Path

workspace = OmegaConf.load('config.yaml')
use_stability_api = workspace.use_stability_api
model_dir = workspace.model_dir

root = workspace.project_root
root = Path(root)
root.mkdir(parents=True, exist_ok=True)


import copy
import datetime as dt
import gc
from itertools import chain, cycle
import json
import os
import re
import string
from subprocess import Popen, PIPE
import textwrap
import time
import warnings

from IPython.display import display
import numpy as np
import pandas as pd
import panel as pn
from tqdm.autonotebook import tqdm

import tokenizations
import webvtt
import whisper

from vktrs.utils import remove_punctuation
from vktrs.utils import get_audio_duration_seconds
from vktrs.youtube import (
    YoutubeHelper,
    parse_timestamp,
    vtt_to_token_timestamps,
    srv2_to_token_timestamps,
)

storyboard = OmegaConf.create()

d_ = dict(
    # all this does is make it so each of the following lines can be preceded with a comma
    # otw the first parameter would be offset from the other in the colab form
    _=""

    , video_url = 'https://www.youtube.com/watch?v=fregObNcHC8' # @param {type:'string'}
    , audio_fpath = '' # @param {type:'string'}
    , whisper_seg = True # @param {type:'boolean'}
)
d_.pop('_')
storyboard.params = d_

if not storyboard.params.audio_fpath:
    storyboard.params.audio_fpath = None


# @markdown `video_url` - URL of a youtube video to download as a source for audio and potentially for text transcription as well.

# @markdown `audio_fpath` - Optionally provide an audio file instead of relying on a youtube download. Name it something other than 'audio.mp3', 
# @markdown                 otherwise it might get overwritten accidentally.

# @markdown `whisper_seg` - Whether or not to use openai's whisper model for lyric segmentation. This is currently the only option, but that will change in a few days.


storyboard_fname = root / 'storyboard.yaml'
OmegaConf.save(config=storyboard, f=storyboard_fname)


###############################
# Download audio from youtube #
###############################

video_url = storyboard.params.video_url

if video_url:
    # check if user provided an audio filepath (or we already have one from youtube) before attempting to download
    if storyboard.params.get('audio_fpath') is None:
        helper = YoutubeHelper(
            video_url,
            ydl_opts = {
                'outtmpl':{'default':str( root / f"ytdlp_content.%(ext)s" )},
                'writeautomaticsub':True,
                'subtitlesformat':'srv2/vtt'
                },
        )

        # estimate video end
        video_duration = dt.timedelta(seconds=helper.info['duration'])
        storyboard.params['video_duration'] = video_duration.total_seconds()

        audio_fpath = str( root / 'audio.mp3' )
        input_audio = helper.info['requested_downloads'][-1]['filepath']
        !ffmpeg -y -i "{input_audio}" -acodec libmp3lame "{audio_fpath}"

        # to do: write audio and subtitle paths/meta to storyboard
        storyboard.params.audio_fpath = audio_fpath

        if False:
            subtitle_format = helper.info['requested_subtitles']['en']['ext']
            subtitle_fpath = helper.info['requested_subtitles']['en']['filepath']

            if subtitle_format == 'srv2':
                with open(subtitle_fpath, 'r') as f:
                    srv2_xml = f.read() 
                token_start_times = srv2_to_token_timestamps(srv2_xml)
                # to do: handle timedeltas...
                #storyboard.params.token_start_times = token_start_times

            elif subtitle_format == 'vtt':
                captions = webvtt.read(subtitle_fpath)
                token_start_times = vtt_to_token_timestamps(captions)
                # to do: handle timedeltas...
                #storyboard.params.token_start_times = token_start_times

            # If unable to download supported subtitles, force use whisper
            else:
                storyboard.params.whisper_seg = True


# estimate video end
if storyboard.params.get('video_duration') is None:
    # estimate duration from audio file
    audio_fpath = storyboard.params['audio_fpath']
    storyboard.params['video_duration'] = get_audio_duration_seconds(audio_fpath)

if storyboard.params.get('video_duration') is None:
    raise RuntimeError('unable to determine audio duration. was a video url or path to a file supplied?')

# force use
storyboard.params.whisper_seg = True

OmegaConf.save(config=storyboard, f=storyboard_fname)

whisper_seg = storyboard.params.whisper_seg

###################################################
# 💬 Transcribe and segment speech using whisper #
###################################################

# handle OOM... or try to, anyway
if 'hf_helper' in locals():
    del hf_helper.img2img
    del hf_helper.text2img
    del hf_helper


if whisper_seg:
    # from vktrs.asr import (
    #     #whisper_lyrics,
    #     #whisper_transcribe,
    #     #whisper_align,
    #     whisper_transmit_meta_across_alignment,
    #     whisper_segment_transcription,
    # )

    # #prompt_starts = whisper_lyrics(audio_fpath=storyboard.params.audio_fpath)

    audio_fpath = storyboard.params.audio_fpath
    #whispers = whisper_transcribe(audio_fpath)




    # [experimental] whisper alt

    import json

    def calculate_interword_gaps(segment):
        end_prev = -1
        gaps = []
        #for seg in timings['segments']:
        for word in segment['words']:
            if end_prev < 0:
                end_prev = word['end']
                continue 
            gap = word['start'] - end_prev
            gaps.append(gap)
            end_prev = word['end']
        return gaps 

    def trivial_subsegmentation(segment, threshold=0, gaps=None):
        """
        split on gaps in detected vocal activity. 
        Contiguity = gap between adjacent tokens is less than the input threshold.
        """
        if gaps is None:
            gaps = calculate_interword_gaps(segment)
        out_segments = []
        this_segment = [segment['words'][0]]
        for word, preceding_pause in zip(segment['words'][1:], gaps):
            if preceding_pause <= threshold:
                this_segment.append(word)
            else:
                out_segments.append(this_segment)
                this_segment = [word]
        out_segments.append(this_segment)
        #return out_segments
        outv = [dict(
            start=words[0]['start'],
            end=words[-1]['end'],
            text=''.join([w['word'] for w in words]).strip(),
        ) for words in out_segments]
        # duplicate the timing/text under the key names the rest of the notebook expects
        for rec in outv:
            rec['ts'] = rec['start']
            rec['prompt'] = rec['text']
        return outv

    #trivial_subsegmentation(seg)

    audio_fpath = Path(storyboard.params.audio_fpath)

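    !whisper --model large --word_timestamps True "{storyboard.params.audio_fpath}"
    # whisper names its outputs after the audio file's stem (audio.mp3 -> audio.json, ...)
    # and writes them to the current working directory


    with Path(f"{audio_fpath.stem}.json").open() as f: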
        timings = json.load(f)

    storyboard.prompt_starts = timings['segments']

    # segments = []
    # for seg in timings['segments']:
    #     segments.extend(trivial_subsegmentation(seg))
    # storyboard.prompt_starts = segments
    # storyboard.prompt_starts

    prompt_starts = storyboard.prompt_starts



    # to do: dropdown selectors
    # segmentation_model = 'tiny'
    # transcription_model = 'large'

    # storyboard.params.whisper = dict(
    #     segmentation_model = segmentation_model
    #     ,transcription_model = transcription_model
    # )

    # whispers = {
    #     #'tiny':None, # 5.83 s
    #     #'large':None # 3.73 s
    # }
    # # accelerated runtime required for whisper
    # # to do: pypi package for whisper

    # # to do: use transcripts we've already built if we have them
    # #scripts = storyboard.params.whisper.get('transcriptions')
    
    # for k in set([segmentation_model, transcription_model]):
    #     #if k in scripts:

    #     options = whisper.DecodingOptions(
    #         language='en',
    #     )
    #     # to do: be more proactive about cleaning up these models when we're done with them
    #     model = whisper.load_model(k).to('cuda')
    #     start = time.time()
    #     print(f"Transcribing audio with whisper-{k}")
        
    #     # to do: calling transcribe like this unnecessarily re-processes audio each time.
    #     whispers[k] = model.transcribe(audio_fpath) # re-processes audio each time, ~10s overhead?
    #     print(f"elapsed: {time.time()-start}")
    #     del model
    #     gc.collect()
    

    ## TO DO
    #######################
    # save transcriptions #
    #######################

    # transcriptions = {}
    # transcription_root = root / 'whispers'
    # transcription_root.mkdir(parents=True, exist_ok=True)
    # writer = whisper.utils.get_writer(output_format='vtt', output_dir=transcription_root) # output dir doesn't do anything...?
    # for k in whispers:
    #     outpath = str( transcription_root / f"{k}.vtt" )
    #     transcriptions[k] = outpath
    #     with open(outpath,'w') as f:
    #         # to do: upstream PR to control verbosity
    #         writer.write_result(
    #             whispers[k],
    #             file=f,
    #         )
    # storyboard.params.whisper.transcriptions = transcriptions

    #tiny2large, large2tiny, whispers_tokens = whisper_align(whispers)
    # sanitize and tokenize
    # whispers_tokens = {}
    # for k in whispers:
    #     whispers_tokens[k] = [
    #     remove_punctuation(tok) for tok in whispers[k]['text'].split()
    #     ]

    # # align sequences
    # tiny2large, large2tiny = tokenizations.get_alignments(
    #     whispers_tokens[segmentation_model], #whispers_tokens['tiny'],
    #     whispers_tokens[transcription_model] #whispers_tokens['large']
    # )
    # #return tiny2large, large2tiny, whispers_tokens

    # token_large_index_segmentations = whisper_transmit_meta_across_alignment(
    #     whispers,
    #     large2tiny,
    #     whispers_tokens,
    # )
    # prompt_starts = whisper_segment_transcription(
    #     token_large_index_segmentations,
    # )


    ### checkpoint the processing work we've done to this point

    prompt_starts_copy = copy.deepcopy(prompt_starts)
    
    # to do: deal with timedeltas in asr.py and yt.py
    # for rec in prompt_starts_copy:
    #     for k,v in list(rec.items()):
    #         if isinstance(v, dt.timedelta):
    #             rec[k] = v.total_seconds()
    
    storyboard.prompt_starts = prompt_starts_copy

    OmegaConf.save(config=storyboard, f=storyboard_fname)

###############################
# Review/Modify transcription #
###############################

# @markdown ---
# @markdown NB: When this cell finishes running, a table will appear
# @markdown at the bottom of the output window. This table is editable
# @markdown and can be used to correct errors in the transcription.
# @markdown
# @markdown Additionally, the `override_prompt` field can be used to provide an 
# @markdown alternative text prompt for image generation. If this feature is
# @markdown used, both the lyric and the theme prompt (which you will specify 
# @markdown in the cell that follows this) will be ignored. If you want to use
# @markdown an `override_prompt` and also want to stay on theme, you will have 
# @markdown to append the desired `theme_prompt` to the end of the 
# @markdown `override_prompt` manually.


# https://panel.holoviz.org/reference/widgets/Tabulator.html
pn.extension('tabulator') # I don't know that specifying 'tabulator' here is even necessary...

tabulator_formatters = {
    'bool': {'type': 'tickCross'}
}

# reset workspace
if 'df_regen' in locals():
    del df_regen

df = pd.DataFrame(prompt_starts).rename(
    columns={
        'ts':'Timestamp (sec)',
        'prompt':'Lyric',
    }
)

if 'td' in df:
    del df['td']

df['override_prompt'] = ''

df_pre = copy.deepcopy(df)
outv = pn.widgets.Tabulator(df, formatters=tabulator_formatters)
if local: # really the issue isn't "local", it's VSCode
    outv = df
outv


In [None]:
try_to_segment_further = True

if try_to_segment_further:
    segments = []
    for seg in timings['segments']:
        segments.extend(trivial_subsegmentation(seg))

pd.DataFrame(segments)

In [None]:
use_new_segmentation = True
if use_new_segmentation:
    prompt_starts=segments
    storyboard.prompt_starts = prompt_starts

    OmegaConf.save(config=storyboard, f=storyboard_fname)

## $3.$ 🎬 Animate
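
The cell below builds each image prompt as `"{lyric}, {theme_prompt}"`, replaces it with `override_prompt` when one is set, and, if `prompt_lag` is enabled, prepends the previous record's prompt. Here is that logic isolated as a small standalone sketch so the interaction between the three options is easier to see. The `build_prompt` helper and the sample records are hypothetical; the notebook itself does this inline in a loop.

```python
def build_prompt(records, idx, theme_prompt, prompt_lag=True):
    """Assemble the text prompt for records[idx]."""
    rec = records[idx]
    # Default: the lyric followed by the shared theme.
    prompt = f"{rec['prompt']}, {theme_prompt}"
    # An override replaces both the lyric and the theme.
    override = (rec.get("override_prompt") or "").strip()
    if override:
        prompt = override
    # Prompt lag: prepend the previous record's (override or lyric) text.
    if prompt_lag and idx > 0:
        prev = records[idx - 1]
        prev_prompt = (prev.get("override_prompt") or "").strip() or prev["prompt"]
        prompt = f"{prev_prompt}, {prompt}"
    return prompt


records = [
    {"prompt": "the lights go down"},
    {"prompt": "we drive all night", "override_prompt": "a neon highway at midnight"},
]
theme = "retrofuturism travel poster"

print(build_prompt(records, 0, theme))
# the lights go down, retrofuturism travel poster
print(build_prompt(records, 1, theme))
# the lights go down, a neon highway at midnight
print(build_prompt(records, 1, theme, prompt_lag=False))
# a neon highway at midnight
```

Note that an overridden prompt drops the theme entirely, so append the theme text to `override_prompt` yourself if you want the scene to stay on style.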

In [None]:
import copy
import datetime as dt
from pathlib import Path
import random
import string

from bokeh.models.widgets.tables import (
    NumberFormatter, 
    BooleanFormatter,
    CheckboxEditor,
)
import numpy as np
from omegaconf import OmegaConf, DictConfig
import pandas as pd
import panel as pn
import PIL
from tqdm.autonotebook import tqdm

from vktrs.tsp import (
    tsp_permute_frames,
    batched_tsp_permute_frames,
)

from vktrs.utils import (
    add_caption2image,
    save_frame,
    remove_punctuation,
    get_image_sequence,
    archive_images,
)

# to do: is there a way to check if this is in the env already?
pn.extension('tabulator')

# this processes optional edits to the transcription (above) 
if ('prompt_starts' in locals()) \
and ('df_pre' in locals()):
    if isinstance(prompt_starts, DictConfig):
        prompt_starts = OmegaConf.to_container(prompt_starts)
    # update prompt_starts if any changes were made above
    if not np.all(df_pre.values == df.values):
        df_pre = copy.deepcopy(df)
        for i, rec in enumerate(prompt_starts):
            rec['ts'] = float(df.loc[i,'Timestamp (sec)'])
            rec['prompt'] = df.loc[i,'Lyric']
            rec['override_prompt'] = df.loc[i,'override_prompt']
        
        # ...actually, I think the above code might not do anything
        # probably need to checkpoint prompt_starts into the storyboard on disk.
        # let's do that here just to be safe.    
        workspace = OmegaConf.load('config.yaml')
        root = Path(workspace.project_root)

        storyboard_fname = root / 'storyboard.yaml'
        storyboard = OmegaConf.load(storyboard_fname)

        storyboard.prompt_starts = prompt_starts
        OmegaConf.save(config=storyboard, f=storyboard_fname)


#####################################
# @title ## 🎨 Generate init images
#####################################

workspace = OmegaConf.load('config.yaml')
root = Path(workspace.project_root)

storyboard_fname = root / 'storyboard.yaml'
storyboard = OmegaConf.load(storyboard_fname)

prompt_starts = storyboard.prompt_starts
use_stability_api = workspace.use_stability_api
model_dir = workspace.model_dir


# scavenged from hf.py

from pathlib import Path
import torch
from torch import autocast
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from diffusers.models import AutoencoderKL


class HfHelper:
    def __init__(
        self,
        device = 'cuda',
        device_img2img = None,
        device_text2img = None,
        model_path = '.',
        #model_id = "CompVis/stable-diffusion-v1-4",
        model_id = "CompVis/stable-diffusion-v1-5",
        download=True,
    ):
        if not device_img2img:
            device_img2img = device
        if not device_text2img:
            device_text2img = device
        self.device = device
        self.device_img2img = device_img2img
        self.device_text2img = device_text2img
        self.model_path = model_path
        self.model_id = model_id
        self.download = download
        self.load_pipelines()

    def load_pipelines(
        self,
    ):
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        if self.download:
            img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.model_id,
                revision="fp16", 
                torch_dtype=torch.float16,
                use_auth_token=True,
                vae=vae
            )
            img2img = img2img.to(self.device)
            img2img.save_pretrained(self.model_path)
        else:
            img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.model_path,
                local_files_only=True,
                vae=vae,
            ).to(self.device)

        text2img = StableDiffusionPipeline(
            vae=img2img.vae,
            text_encoder=img2img.text_encoder,
            tokenizer=img2img.tokenizer,
            unet=img2img.unet,
            feature_extractor=img2img.feature_extractor,
            scheduler=img2img.scheduler,
            safety_checker=img2img.safety_checker,
        )
        #return text2img, img2img
        text2img.enable_attention_slicing()
        img2img.enable_attention_slicing()
        self.text2img = text2img
        self.img2img = img2img

    def get_image_for_prompt(
        self,
        prompt,
        **kwargs
    ):
        f = self.text2img if kwargs.get('image') is None else self.img2img
        #if kwargs.get('image_consistency') is not None:
        #kwargs['strength'] = 1- kwargs['image_consistency'] 
        if kwargs.get('start_schedule') is not None:
            #kwargs['strength'] = kwargs['start_schedule']
            kwargs['strength'] = kwargs.pop('start_schedule')
        with autocast(self.device):
            return f(prompt, **kwargs)




if use_stability_api:
    from vktrs.api import get_image_for_prompt
elif 'hf_helper' not in locals():
    #from vktrs.hf import HfHelper
    # this needs to not be in the same cell as the login.
    # some sort of stupid race condition.
    try:
        hf_helper = HfHelper(
            download=False,
            model_path=str(Path(model_dir) / 'huggingface' / 'diffusers')
        )
    except Exception:  # no usable local copy of the model, so download it
        hf_helper = HfHelper(
            download=True,
            model_path=str(Path(model_dir) / 'huggingface' / 'diffusers')
        )
    # I give up.
    def get_image_for_prompt(*args, **kargs):
        # ugly hotfix. todo: unify function signatures...
        if 'init_image' in kargs:
            img = kargs.pop('init_image')
            if img:
                kargs['image'] = img
        result = hf_helper.get_image_for_prompt(*args, **kargs)
        return result.images


def get_variations_w_init(prompt, init_image, **kargs):
    return list(get_image_for_prompt(prompt=prompt, init_image=init_image, **kargs))

def get_close_variations_from_prompt(prompt, n_variations=2, image_consistency=.7):
    """
    prompt: a text prompt
    n_variations: total number of images to return
    image_consistency: float in [0,1], controls similarity between images generated by the prompt.
                        you can think of this as controlling how much "visual vibration" there will be.
                        - 0 = each variation is regenerated almost independently of the first image
                        - 1 = variations are nearly identical to the first image
    """
    images = list(get_image_for_prompt(prompt))
    for _ in range(n_variations - 1):
        img = get_variations_w_init(prompt, images[0], start_schedule=(1-image_consistency))[0]
        images.append(img)
    return images


d_ = dict(
    _=''
    , theme_prompt = "retrofuturism travel poster" # @param {type:'string'}
    , height = 512 # @param {type:'integer'}
    , width = 512 # @param {type:'integer'}
    , display_frames_as_we_get_them = True # @param {type:'boolean'}
)
d_.pop('_')

regenerate_all_init_images = False # @param {type:'boolean'}

prompt_lag = True # @param {type:'boolean'}

# @markdown `theme_prompt` - Text that will be appended to the end of each lyric, useful for e.g. applying a consistent aesthetic style

# @markdown `display_frames_as_we_get_them` - Displaying frames will make the notebook slightly slower

# regenerate all images if the theme prompt has changed or user specifies

# @markdown `prompt_lag` - Extend prompt with lyrics from previous frame. Can improve temporal consistency of narrative. 
# @markdown  Especially useful for lyrics segmented into short prompts.

if d_['theme_prompt'] != storyboard.params.get('theme_prompt'):
    regenerate_all_init_images = True

storyboard.params.update(d_)

if regenerate_all_init_images:
    for i, rec in enumerate(prompt_starts):
        rec['frame0_fpath'] = None
        archive_images(i, root=root)
    print("archival process complete")

# anchor images will be regenerated if there's no associated frame0_fpath
# regenerate specific images if
# * manually tagged by user in df_regen
# * associated fpath doesn't exist (i.e. deleted)
if 'df_regen' in locals():
    for i, _ in df_regen.iterrows():
        rec = prompt_starts[i]
        regen = not _['keep']
        if rec.get('frame0_fpath') is None:
            regen = True
        elif not Path(rec['frame0_fpath']).exists():
            regen=True
        if regen:
            rec['frame0_fpath'] = None
            rec['prompt'] = df_regen.loc[i, 'Lyric']
            rec['override_prompt'] = df_regen.loc[i, 'override_prompt']
            print(rec)
            archive_images(i, root=root)
    print("archival process complete")


theme_prompt = storyboard.params.theme_prompt
display_frames_as_we_get_them = storyboard.params.display_frames_as_we_get_them
height = storyboard.params.height
width = storyboard.params.width

proj_name = workspace.active_project

# TODO: this is just for development
max_idx = 10

print("Ensuring each prompt has an associated image")
for idx, rec in enumerate(prompt_starts):
    if idx> max_idx: break # TODO: remove this when done with development
    lyric = rec['prompt']
    prompt = f"{lyric}, {theme_prompt}"
    override = rec.get('override_prompt','').strip()
    if override:
        print('override prompt detected')
        prompt = override
    #print(
    #    f"\n[{idx} | {rec['ts']}] - {lyric} - {prompt}"
    #)
    
    if prompt_lag and (idx > 0):
        rec_prev = prompt_starts[idx -1]
        prev_prompt = rec_prev.get('override_prompt','').strip()
        if not prev_prompt:
            prev_prompt = rec_prev['prompt']
        prompt = f"{prev_prompt}, {prompt}"
    print(
        f"\n[{idx} | {rec['ts']}] - {lyric} - {prompt}"
    )
    if rec.get('frame0_fpath') is None:
        init_image = list(get_image_for_prompt(
              prompt,
              height=height,
              width=width,
              )
          )[0]
    # this shouldn't be necessary, but is a consequence of
    # the globbing thing we're doing atm
    if 'anchor' not in str(rec.get('frame0_fpath')):
        rec['frame0_fpath'] = save_frame(
            init_image,
            idx,
            root_path = root / 'frames',
            name='anchor',
            )

        if display_frames_as_we_get_them:
            print(lyric)
            display(init_image)


##############
# checkpoint #
##############

prompt_starts_copy = copy.deepcopy(prompt_starts)
storyboard.prompt_starts = prompt_starts_copy

OmegaConf.save(config=storyboard, f=storyboard_fname)


###############
# flag regens #
###############

# @markdown ---
# @markdown NB: When this cell finishes running, a table will appear at the bottom of the output window. This table is editable and can be used to correct errors in the transcription (see above).

# @markdown Additionally, this table can be used to trigger regeneration of
# @markdown images you don't want to keep. On the far left of the table, you
# @markdown should see a `keep` column that defaults to "true". Double 
# @markdown clicking this value should flip it to "false". Rerunning this cell
# @markdown will regenerate the `init_image` for all scenes where `keep=false`.
# @markdown Images that are flagged for regeneration will be moved to the
# @markdown project's `archive` folder.

# @markdown Image regeneration can also be triggered by deleting the image from 
# @markdown the `frames` folder.


df_regen = pd.DataFrame(prompt_starts)
if 'override_prompt' not in df_regen:
    df_regen['override_prompt'] = ''

df_regen = df_regen[['ts','prompt','override_prompt']].rename(
    columns={
        'ts':'Timestamp (sec)',
        'prompt':'Lyric',
    }
)



df_regen['keep'] = True

# move the "keep" column to the front
df_regen= df_regen[['keep', 'Timestamp (sec)', 'Lyric', 'override_prompt']]

pn.widgets.Tabulator(
    df_regen,
    # formatters/editors are keyed by column name, so target the `keep` column
    formatters={'keep': BooleanFormatter()},
    editors={'keep': CheckboxEditor()}
    )
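
The prompt assembly in the cell above (lyric plus theme, optional `override_prompt`, optional prompt lag) can be summarized as a small pure function. This is a minimal sketch for illustration only: `compose_prompt` is a hypothetical helper that does not exist in `vktrs`, and the record contents are made up. Passing `prev_rec` stands in for the `prompt_lag and (idx > 0)` condition.

```python
def compose_prompt(rec, theme_prompt, prev_rec=None):
    """Build the text prompt for one storyboard record.

    Hypothetical helper mirroring the loop above; not part of vktrs.
    """
    def scene_text(r, with_theme):
        # an explicit override always wins over the transcribed lyric
        override = (r.get('override_prompt') or '').strip()
        if override:
            return override
        return f"{r['prompt']}, {theme_prompt}" if with_theme else r['prompt']

    prompt = scene_text(rec, with_theme=True)
    if prev_rec is not None:
        # prompt lag: prepend the previous scene's text (override if set,
        # otherwise the bare lyric without the theme)
        prompt = f"{scene_text(prev_rec, with_theme=False)}, {prompt}"
    return prompt


# made-up records, shaped like entries of `prompt_starts`
recs = [
    {'ts': 0.0, 'prompt': 'a quiet street at dawn'},
    {'ts': 4.2, 'prompt': 'neon rain', 'override_prompt': 'a shattered television'},
]
print(compose_prompt(recs[0], 'oil painting'))
print(compose_prompt(recs[1], 'oil painting', prev_rec=recs[0]))
```

Keeping this logic in one function would also let the variation-generation cell reuse it, since that cell currently rebuilds the prompt without prompt lag.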


In [None]:
from diffusers import EulerAncestralDiscreteScheduler

# scavenged from img2img pipeline __call__
from typing import Union, List, Optional, Dict, Any, Callable
@torch.no_grad()
def img2img__call__(
    self=hf_helper.img2img,
    prompt: Union[str, List[str]] = None,
    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    image_latent=None,
    #prompt_embeds=None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
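        image (`torch.FloatTensor` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process. Ignored when `image_latent` is provided.
        image_latent (`torch.FloatTensor`, *optional*):
            Pre-noised latents to start denoising from. When provided, `image` is not preprocessed or encoded
            and these latents are used as-is.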
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter will be modulated by `strength`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
    Returns:
        `dict` with the following keys:
            `image`: the generated images, post-processed according to `output_type`.
            `in_latent`: a detached copy of the latents the denoising loop started from.
            `out_latent`: the latents after the final denoising step.
            `prompt_embeds`: the text embeddings used for conditioning.
        Unlike the upstream pipeline, no safety checker is run here and `return_dict` is ignored.
    """

    # Swap in the Euler ancestral sampler, reusing the existing scheduler's config
    self.scheduler = EulerAncestralDiscreteScheduler.from_config(self.scheduler.config, device=self.device)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device
    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    #if prompt_embeds is None:
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 5. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

    if image_latent is None:
        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )
    else:
        latents = image_latent
    in_latent = latents.detach().clone()

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            
            latent_model_input=latent_model_input.to(self.device)
            #print(latent_model_input.device)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)


    image = self.decode_latents(latents)
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload last model to CPU
    # TODO: ...? is this something I want? guessing not...
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    #return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
    return dict(image=image, in_latent=in_latent, out_latent=latents, prompt_embeds=prompt_embeds)
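
The docstring notes that `num_inference_steps` is modulated by `strength`. The sketch below shows that relationship in isolation. It mirrors the arithmetic of `get_timesteps` in the diffusers img2img pipeline as I understand it, so treat it as an assumption and check it against your installed version; `effective_denoising_steps` is a hypothetical name.

```python
def effective_denoising_steps(num_inference_steps, strength):
    """Number of denoising steps actually run for a given img2img strength.

    Assumption: follows the min/max arithmetic of the diffusers img2img
    `get_timesteps` helper; verify against your installed diffusers version.
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be in [0, 1], got {strength}")
    # how far back into the noise schedule the init image is pushed
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    # steps skipped at the start of the schedule
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start


for strength in (0.0, 0.5, 1.0):
    steps = effective_denoising_steps(30, strength)
    print(f"strength={strength}: {steps} of 30 steps")
```

Because this notebook passes `strength=(1-image_consistency)`, a higher `image_consistency` means fewer denoising steps per variation as well as a result closer to the init image.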

In [None]:

# @title ## 🚀 Generate animation frames

###################
# improved resume #
###################

import copy
import datetime as dt
from itertools import cycle
from pathlib import Path

from omegaconf import OmegaConf
from PIL import Image
from vktrs.utils import (
    add_caption2image,
    get_image_sequence,
)


workspace = OmegaConf.load('config.yaml')
root = Path(workspace.project_root)

storyboard_fname = root / 'storyboard.yaml'
storyboard = OmegaConf.load(storyboard_fname)

if 'prompt_starts' not in locals():
    prompt_starts = OmegaConf.to_container(storyboard.prompt_starts)
else:
    ##########################
    # checkpoint any changes #
    ##########################
    prompt_starts_copy = copy.deepcopy(prompt_starts)

    storyboard.prompt_starts = prompt_starts_copy

    # OmegaConf.save accepts a path directly; no need to open the file first
    OmegaConf.save(config=storyboard, f=storyboard_fname)


#################################################
# Math                                          #
#                                               #
#    This block computes how many frames are    #
#    needed for each segment based on the start #
#    times for each prompt                      #
#################################################

# to do: 
# * make this more portable and add to vktrs lib

fps = 12 # @param {type:'integer'}
storyboard.params.fps = fps

ifps = 1/fps

# estimate video end
video_duration = storyboard.params['video_duration']

# dummy prompt for last scene duration
prompt_starts = OmegaConf.to_container(storyboard.prompt_starts)
prompt_starts.append({'ts':video_duration})

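# make sure we respect the duration of the previous phrase:
# each segment gets the largest whole number of frames that ends
# strictly before the next prompt's timestamp
frame_start=0
prompt_starts[0]['anim_start']=frame_start
for i, rec in enumerate(prompt_starts[1:], start=1):
    rec_prev = prompt_starts[i-1]
    k=0
    while (rec_prev['anim_start'] + k*ifps) < rec['ts']:
        k+=1
    # step back to the last frame that fits; clamp so a prompt whose successor
    # starts at or before its own anim_start gets 0 frames rather than -1
    k = max(k-1, 0)
    rec_prev['frames'] = k
    rec_prev['anim_duration'] = k*ifps
    frame_start+=k*ifps
    rec['anim_start']=frame_start

# drop the dummy prompt
prompt_starts = prompt_starts[:-1]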

# to do: given a 0 duration prompt, assume its duration is captured in the next prompt 
#        and guesstimate a corrected prompt start time and duration 


##############
# checkpoint #
##############

prompt_starts_copy = copy.deepcopy(prompt_starts)

storyboard.prompt_starts = prompt_starts_copy

# OmegaConf.save accepts a path directly; no need to open the file first
OmegaConf.save(config=storyboard, f=storyboard_fname)


##################################
# Generate animation frames      #
##################################

d_ = dict(
    _=''
    , n_variations=30 # @param {type:'integer'}
    , image_consistency=0.7 # @param {type:"slider", min:0, max:1, step:0.01}  
    , max_video_duration_in_seconds = 300 # @param {type:'integer'}
)
d_.pop('_')


# @markdown `fps` - Frames-per-second of generated animations

# @markdown `n_variations` - How many unique variations to generate for a given text prompt. This determines the frequency of the visual "pulsing" effect

# @markdown `image_consistency` - controls similarity between images generated by the prompt.
# @markdown - 0: ignore the init image
# @markdown - 1: as true as possible to the init image

# @markdown `max_video_duration_in_seconds` - Early stopping if you don't want to generate a video the full duration of the provided audio. Default = 5min.


storyboard.params.update(d_)
storyboard.params.max_frames = storyboard.params.fps * storyboard.params.max_video_duration_in_seconds

# to do: compute and report the number of unique image generations

display_frames_as_we_get_them = storyboard.params.display_frames_as_we_get_them
image_consistency = storyboard.params.image_consistency
max_frames = storyboard.params.max_frames

n_variations = storyboard.params.n_variations
theme_prompt = storyboard.params.get('theme_prompt')


##################################

# ooo ain't I fancy

import os  # needed for os.path.getmtime in get_image_sequence below
import random
from scipy.spatial.distance import cosine

def cosine_similarity(u,v):
    return 1-cosine(u,v)

def get_image_sequence(idx, root):
    print("getting image sequence")
    root = Path(root)
    images = (root / 'frames' ).glob(f'{idx}-*.png')
    images = sorted(list(images), key=os.path.getmtime)
    return images

def importance_sample_variation(
    idx, root, 
    prompt, image_consistency,
    n_context=3, n_proposals=2,
    n_steps_per_variation=50,
):
    print(f"getting image sequence for {idx}")
    images = get_image_sequence(idx, root)
    print([str(imfpath) for imfpath in images])
    context = images[-n_context:]
    anchor_fpath = random.choice(context)
    print(f"using anchor image:{anchor_fpath}")
    anchor = Image.open(anchor_fpath)    
    #anchor_npy = np.array(anchor).ravel() 
    anchor_latent = None

    previous_fpath = images[-1]
    print(f"previous image:{previous_fpath}")
    previous = Image.open(previous_fpath)
    #previous_npy = np.array(previous).ravel() 
    # TODO: super inefficient, just need the latent for the previous image.
    prev_dummy = img2img__call__(
            prompt=prompt, 
            image=previous, 
            strength=(1-image_consistency),
            #image_latent=anchor_latent,
        )
    previous_latent = prev_dummy["in_latent"].cpu().ravel()

    best_proposal = None
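    # cosine similarity can be negative, so start below any possible score;
    # otherwise best_proposal could stay None when every proposal scores <= 0
    best_similarity = -float('inf')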
    # generate proposals from chosen anchor image, select best relative to most recent previous
    for j in range(n_proposals):
        print(f"proposal {j}")
        #proposed_img = get_variations_w_init(prompt, anchor, start_schedule=(1-image_consistency))[0]
        proposal = img2img__call__(
            prompt=prompt, 
            image=anchor, 
            strength=(1-image_consistency),
            image_latent=anchor_latent,
            num_inference_steps=n_steps_per_variation,
        )
        proposed_img = proposal['image'][0]
        if anchor_latent is None:
            anchor_latent = proposal['in_latent'].detach().clone()#.cpu().ravel()
    
        proposed_latent = proposal['out_latent']#.detach().clone()#.cpu().ravel()

        
        display(proposed_img)
        #proposed_npy = np.array(proposed_img).ravel()
        #proposed_similarity = cosine_similarity(proposed_npy, anchor_npy)
        #proposed_similarity = cosine_similarity(proposed_npy, previous_npy)
        proposed_similarity = cosine_similarity(proposed_latent.detach().clone().cpu().ravel(), previous_latent)
        print(f"score: {proposed_similarity}")
        if proposed_similarity > best_similarity:
            best_similarity = proposed_similarity
            best_proposal = proposed_img
            print("proposal accepted")
        else:
            print("proposal rejected")
    return best_proposal


n_context=3
n_proposals=3
use_importance_sampling=True
n_steps_per_variation=30

##################################

# load init_images and generate variations as needed
# to do: request multiple images in single request
print("Fetching variations")
for idx, rec in enumerate(prompt_starts):
    new_images = []
    images_fpaths = get_image_sequence(idx, root=root)
    curr_variation_count = len(images_fpaths)
    print(curr_variation_count)
    if curr_variation_count < n_variations:
        # to do: 
        # * prompt lag
        lyric = rec['prompt']
        prompt = f"{lyric}, {theme_prompt}"
        if rec.get('override_prompt'):
            prompt = rec['override_prompt']


        # next line is here to permit user to specify more variations for a specific entry
        tot_variations = rec.get('n_variations', n_variations)
        tot_variations = min(tot_variations, rec['frames']) # don't generate variations we won't use
        tot_variations -= curr_variation_count  # only generate variations we still need

        if not use_importance_sampling:
            init_image = Image.open(rec['frame0_fpath'])

        for _ in range(tot_variations):

            if use_importance_sampling:
                # TODO: generate proposals in batch
                img = importance_sample_variation(
                    idx, root, 
                    prompt, image_consistency,
                    n_context=n_context, n_proposals=n_proposals, n_steps_per_variation=n_steps_per_variation,
                )
            else:
                img = get_variations_w_init(prompt, init_image, start_schedule=(1-image_consistency))[0]

            save_frame(
                img,
                idx,
                root_path= root / 'frames',
            )
            if display_frames_as_we_get_them:
                display(img)


##############
# checkpoint #
##############

prompt_starts_copy = copy.deepcopy(prompt_starts)

storyboard.prompt_starts = prompt_starts_copy

# to do: deal with these td objects
OmegaConf.save(config=storyboard, f=storyboard_fname)

# @markdown ---
# @markdown Running this cell will generate as many variation frames as required 
# @markdown per `n_variations`. To trigger regeneration of images that didn't
# @markdown generate correctly (e.g. because an NSFW classifier was triggered),
# @markdown just delete those images.
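
The "Math" block in the cell above converts prompt start times into per-segment frame counts. Here is a self-contained sketch of that computation for experimenting outside the notebook. `frames_per_segment` is a hypothetical name, the timestamps are invented, and the clamp to zero frames is my addition for the case where two prompts share a timestamp.

```python
def frames_per_segment(start_times, video_duration, fps):
    """Return anim_start, frames, and anim_duration for each prompt.

    Each segment gets the largest number of whole frames that ends
    strictly before the next prompt's timestamp.
    """
    ifps = 1 / fps
    # the video end acts as a dummy final prompt, as in the notebook
    boundaries = list(start_times) + [video_duration]
    anim_start = 0.0
    segments = []
    for next_ts in boundaries[1:]:
        k = 0
        while anim_start + k * ifps < next_ts:
            k += 1
        k = max(k - 1, 0)  # assumption: never report a negative frame count
        segments.append(
            {'anim_start': anim_start, 'frames': k, 'anim_duration': k * ifps}
        )
        anim_start += k * ifps
    return segments


# invented timestamps (seconds); fps=4 keeps the arithmetic exact
segments = frames_per_segment([0.0, 1.0, 2.5], video_duration=4.0, fps=4)
for seg in segments:
    print(seg)
```

Note that a segment whose boundary falls exactly on a frame multiple gives up one frame, so the animation can run slightly short of the audio.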

In [None]:
# frames_m = np.array([np.array(f).ravel() for f in frames])
# dmat = pdist(frames_m, metric='cosine')

In [None]:
storyboard.params.n_variations = 7

In [None]:
# @title ## 🎞️ Compile your video!

import shutil
from subprocess import Popen, PIPE

from pathlib import Path
from PIL import Image
from itertools import cycle

from omegaconf import OmegaConf
from tqdm.autonotebook import tqdm

try:
    from prefetch_generator import BackgroundGenerator
except ImportError:
    !pip install prefetch_generator
    from prefetch_generator import BackgroundGenerator

from vktrs.tsp import (
    tsp_permute_frames,
    batched_tsp_permute_frames,
)

from vktrs.utils import (
    add_caption2image,
    get_image_sequence,
    save_frame,
    remove_punctuation,
)


# reload config
workspace = OmegaConf.load('config.yaml')
root = Path(workspace.project_root)

# storyboard_fname = root / 'storyboard.yaml'
# storyboard = OmegaConf.load(storyboard_fname)

########################
# rendering parameters #
########################

output_filename = 'output.mp4' # @param {type:'string'}
add_caption = False # @param {type:'boolean'}
optimal_ordering = True # @param {type:'boolean'}
upscale = True # @param {type:'boolean'}

# @markdown `add_caption` - Whether or not to overlay the prompt text on the image

# @markdown `optimal_ordering` - Intelligently permutes animation frames to provide a smoother animation.

# @markdown `upscale` - Naively (lanczos interpolation) upscale video 2x. This can be a way to force
# @markdown  services like youtube to deliver your video without mangling it with compression
# @markdown  artifacts. Thanks [@gandamu_ml](https://twitter.com/gandamu_ml) for this trick!


# this parameter is currently not exposed in the form
max_variations_per_opt_pass = 15

if optimal_ordering:
    opt_batch_size = min(storyboard.params.n_variations, max_variations_per_opt_pass)

# I think it might be more efficient to write the video to the local disk first, then move it
# afterwards, rather than writing into google drive
final_output_filename = str( root / output_filename )
storyboard.params.output_filename = final_output_filename

# to do: move/duplicate fps computations here (?)
fps = storyboard.params.fps
input_audio = storyboard.params.audio_fpath

#####################################

# helper function for readability
def process_sequence(idx):
    im_paths = get_image_sequence(idx, root)
    images = [Image.open(fp) for fp in im_paths]
    
    if add_caption:
        rec = prompt_starts[idx]
        images = [add_caption2image(im, rec['prompt']) for im in images]

    # to do: persist the ordering in the storyboard
    if optimal_ordering:
        images = batched_tsp_permute_frames(
            images,
            #max_variations_per_opt_pass
            opt_batch_size
        )
    return images

############################

cmd_in = ['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-']
cmd_out = ['-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '1', '-preset', 'veryslow', '-shortest', output_filename]

if input_audio:
    cmd_in += ['-i', str(input_audio)]

# NB: it might be more efficient to perform this upscaling step as a 
# separate step after compiling the video frames
if upscale:
    height=storyboard.params.height
    width=storyboard.params.width
    cmd_out = ['-vf', f'scale={2*width}x{2*height}:flags=lanczos'] + cmd_out


cmd = cmd_in + cmd_out

prompt_starts = storyboard.prompt_starts
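# pass a generator expression (not a list) so sequences are processed lazily
# in the background thread instead of all up front
batch_gen = BackgroundGenerator(
    ((idx, rec, process_sequence(idx))
        for idx, rec in enumerate(prompt_starts))
    ,max_prefetch=2)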

p = Popen(cmd, stdin=PIPE)
for idx, rec, batch in tqdm(batch_gen, total=len(prompt_starts)): 
    frame_factory = cycle(batch)
    k = 0
    while k < rec['frames']:
        try:
            im = next(frame_factory)
        except StopIteration:
            break
        im.save(p.stdin, 'PNG')
        k+=1
p.stdin.close()

print("Encoding video...")
p.wait()

if output_filename != final_output_filename:
    print(f"Local video compilation complete. Moving video to: {final_output_filename}")
    shutil.move(output_filename, final_output_filename)
print("Video complete.")
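
The ffmpeg invocation above is assembled from three parts: the piped PNG input, an optional audio input, and the encoder settings with an optional 2x lanczos upscale. The sketch below builds the same argument list as a pure function so it can be inspected without running ffmpeg. `build_ffmpeg_cmd` is a hypothetical helper; it only reuses the flags already present in this cell.

```python
def build_ffmpeg_cmd(output_filename, fps, input_audio=None,
                     upscale=False, width=None, height=None):
    """Assemble the ffmpeg argument list used to encode piped PNG frames."""
    # frames arrive as PNG images on stdin
    cmd_in = ['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png',
              '-r', str(fps), '-i', '-']
    if input_audio:
        cmd_in += ['-i', str(input_audio)]

    cmd_out = []
    if upscale:
        if width is None or height is None:
            raise ValueError("width and height are required when upscale=True")
        cmd_out += ['-vf', f'scale={2*width}x{2*height}:flags=lanczos']
    cmd_out += ['-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p',
                '-crf', '1', '-preset', 'veryslow', '-shortest',
                str(output_filename)]
    return cmd_in + cmd_out


# file names and dimensions here are placeholders
cmd = build_ffmpeg_cmd('output.mp4', fps=12, input_audio='audio.m4a',
                       upscale=True, width=512, height=512)
print(' '.join(cmd))
```

Printing the command before launching `Popen` makes it easier to rerun the encode by hand when something goes wrong.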

In [None]:
# @title ## 📺 Enjoy your animation!

# to do: Merge with 'compile' cell?

output_filename = storyboard.params.output_filename

download_video = True # @param {type:'boolean'}
compress_video = False # @param {type:'boolean'}

# @markdown Compressing to `*.tar.gz` format can reduce filesize, which in turn reduces
# @markdown your download time. You may need to install additional software
# @markdown to "decompress" the file after downloading to view your video.

# @markdown NB: Your video will probably download way faster from https://drive.google.com

#  NB: only embed short videos
embed_video_in_notebook = False

if compress_video:
    uncompressed_fname = output_filename
    output_filename = f"{output_filename}.tar.gz"
    print(f"Compressing to: {output_filename}")
    !tar -czvf {output_filename} {uncompressed_fname}

if download_video:
    from google.colab import files
    files.download(output_filename)

if embed_video_in_notebook:
    from IPython.display import display, Video
    display(Video(output_filename, embed=True))

# ⚖️ I put on my robe and lawyer hat

### Notebook license

This notebook and the accompanying [git repository](https://github.com/dmarx/video-killed-the-radio-star/) and its contents are shared under the MIT license.

<!-- Note to self: lawyers should really be forced to use some sort of markup or pseudocode to eliminate ambiguity 

...oh shit, if laws were actually described in code, we could just run queries against it
-->

```
MIT License

Copyright (c) 2022 David Marx

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

### DreamStudio API TOS

The default behavior of this notebook uses the [DreamStudio](https://beta.dreamstudio.ai/) API to generate images. Users of the DreamStudio API are subject to the DreamStudio usage terms: https://beta.dreamstudio.ai/terms-of-service

### Stable Diffusion

As of the date of this writing (2022-09-29), all publicly available model checkpoints are subject to the restrictions of the Open RAIL license: https://huggingface.co/spaces/CompVis/stable-diffusion-license. 

