# $ \text{Video Killed The Radio Star}$ $\color{red}{...Diffusion}$

Notebook by David Marx ([@DigThatData](https://twitter.com/digthatdata))

Shared under MIT license


# $\text{FAQ}$

**What is this?**

Point this notebook at a YouTube URL and it'll make a music video for you.

**How does this animation technique work?**

For each text prompt you provide, the notebook will...

1. Generate an image from that text prompt (using Stable Diffusion).
2. Use the generated image as the `init_image` and recombine it with the text prompt to generate variations similar to the first image. This produces a sequence of extremely similar images based on the original text prompt.
3. Reorder the images to find the smoothest animation sequence of those frames.
4. Repeat this image sequence to pad out the animation duration as needed.
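
Steps 3 and 4 can be sketched roughly as follows. This is a minimal illustration, not the notebook's actual implementation: it assumes each frame has already been reduced to a numeric vector (e.g. flattened pixels or an embedding), and it uses a greedy nearest-neighbor ordering rather than a proper traveling-salesman solver. The names `greedy_frame_order` and `pad_sequence` are hypothetical.

```python
import math

def greedy_frame_order(frames):
    """Order frames so each step moves to the nearest unvisited frame.

    `frames` is a list of equal-length numeric vectors. This is a
    nearest-neighbor heuristic, not an exact shortest-path solution.
    """
    if not frames:
        return []
    remaining = set(range(1, len(frames)))
    order = [0]  # assumption: start from the first generated image
    while remaining:
        last = frames[order[-1]]
        # pick the closest unvisited frame; break ties by index
        nxt = min(remaining, key=lambda i: (math.dist(last, frames[i]), i))
        order.append(nxt)
        remaining.remove(nxt)
    return order

def pad_sequence(order, n_frames):
    """Cycle through the ordered frames until `n_frames` are emitted."""
    return [order[i % len(order)] for i in range(n_frames)]

# Toy one-dimensional "frames" standing in for image vectors
frames = [[0.0], [10.0], [1.0], [9.0]]
order = greedy_frame_order(frames)
padded = pad_sequence(order, 6)
```

A greedy ordering is cheap but can leave a large jump at the end of the sequence; a TSP solver (the setup cell installs `python-tsp`) trades more compute for a smoother overall path.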

The technique demonstrated in this notebook was inspired by a [video](https://www.youtube.com/watch?v=WJaxFbdjm8c) created by Ben Gillin.

**How are lyrics transcribed?**

This notebook uses OpenAI's recently released Whisper model for automatic speech recognition.
OpenAI was kind enough to offer several different sizes of this model, each with its own pros and cons.
This notebook uses the largest Whisper model to transcribe the actual lyrics, and the
smallest model to perform the lyric segmentation. Neither of these models is perfect, but the results
so far seem pretty decent.

The first draft of this notebook relied on subtitles from YouTube videos to determine timing, which was
then aligned with user-provided lyrics. YouTube's automated captions are powerful and I'll update the
notebook shortly to leverage those again, but for the time being we're just using Whisper for everything
and not referencing user-provided captions at all.

**Something didn't work quite right in the transcription process. How do I fix the timing or the actual lyrics?**

The notebook is divided into several steps. Between each step, a "storyboard" file is updated. If you want to
make modifications, you can edit this file directly and those edits should be reflected when you next load the
file. Depending on what you changed and what step you run next, your changes may be ignored or even overwritten.
Still playing with different solutions here.

**Can I provide my own images to 'bring to life' and associate with certain lyrics/sequences?**

Yes, you can! As described above: you just need to modify the storyboard. Will describe this functionality in
greater detail after the implementation stabilizes a bit more.

**This gave me an idea and I'd like to use just a part of your process here. What's the best way to reuse just some of the machinery you've developed here?**

Most of the functionality in this notebook has been offloaded to a library I published to PyPI called `vktrs`. I strongly encourage you to import anything you need 
from there rather than cutting and pasting functions into a notebook. Similarly, if you have ideas for improvements, please don't hesitate to submit a PR!

**How can I support your work or work like it?**

This notebook was made possible thanks to ongoing support from [stability.ai](https://stability.ai/). The best way to support my work is to share it with your friends, [report bugs](https://github.com/dmarx/video-killed-the-radio-star/issues/new), [suggest features](https://github.com/dmarx/video-killed-the-radio-star/discussions) or to donate to open source non-profits :) 

## $0.$ Setup

In [None]:
# @title # 📊 Check GPU Status

import pandas as pd
import subprocess

def gpu_info():
    outv = subprocess.run([
        'nvidia-smi',
            # these lines concatenate into a single query string
            '--query-gpu='
            'timestamp,'
            'name,'
            'utilization.gpu,'
            'utilization.memory,'
            'memory.used,'
            'memory.free,'
            ,
        '--format=csv'
        ],
        stdout=subprocess.PIPE).stdout.decode('utf-8')

    header, rec = outv.split('\n')[:-1]
    return pd.DataFrame({' '.join(k.strip().split('.')).capitalize():v for k,v in zip(header.split(','), rec.split(','))}, index=[0]).T

gpu_info()
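
`gpu_info()` above unpacks exactly one header row and one record, so it assumes a single GPU. Below is a minimal sketch of parsing the same kind of CSV text for any number of GPUs using only the standard library. The helper name `parse_gpu_csv` is hypothetical, and the sample text is made up for illustration rather than captured from a real machine.

```python
import csv
import io

def parse_gpu_csv(text):
    """Parse `nvidia-smi --format=csv` style text into one dict per GPU."""
    reader = csv.reader(io.StringIO(text.strip()), skipinitialspace=True)
    rows = [row for row in reader if row]
    header, records = rows[0], rows[1:]
    keys = [h.strip() for h in header]
    return [dict(zip(keys, (v.strip() for v in rec))) for rec in records]

# Hypothetical output for a two-GPU machine; the values are invented.
sample = (
    "name, memory.used [MiB], memory.free [MiB]\n"
    "GPU-A, 100 MiB, 900 MiB\n"
    "GPU-B, 250 MiB, 750 MiB\n"
)
gpus = parse_gpu_csv(sample)
```

Each dict can then be passed to `pd.DataFrame(gpus)` to get one row per device.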

In [None]:
%%capture
# @title # 🛠️ Installations

try: 
    import google.colab
    local=False
except ImportError:
    # google.colab is only importable inside Colab
    local=True

# local only additional dependencies
if local:
    %pip install pandas torch pillow beautifulsoup4 scipy toolz numpy lxml

# dependencies for both colab and local
%pip install yt-dlp python-tsp stability-sdk diffusers transformers ftfy accelerate omegaconf
%pip install openai-whisper  panel prefetch_generator huggingface_hub ipywidgets

In [None]:
# @title # 🔑 Provide your API Key

# TODO: separate assets folder not attached to any project to reduce re-download burden

# @markdown Running this cell will prompt you to enter your API Key below. 

# @markdown To get your API key, visit https://beta.dreamstudio.ai/membership

# @markdown ---

# @markdown A note on security best practices: **don't publish your API key.**

# @markdown We're using a form field designed for sensitive data like passwords.
# @markdown This notebook does not save your API key in the notebook itself,
# @markdown but instead loads your API Key into the colab environment. This way,
# @markdown you can make changes to this notebook and share it without concern
# @markdown that you might accidentally share your API Key. 
# @markdown 

use_stability_api = True # @param {type:'boolean'}
mount_gdrive = True # @param {type:'boolean'}

try: 
    import google.colab
    local=False
except ImportError:
    # google.colab is only importable inside Colab
    local=True

if local:
    mount_gdrive=False

import os
from pathlib import Path
import time

from omegaconf import OmegaConf


os.environ['XDG_CACHE_HOME'] = os.environ.get(
    'XDG_CACHE_HOME',
    str(Path('~/.cache').expanduser())
)
if mount_gdrive:
    from google.colab import drive
    drive.mount('/content/drive')
    Path('/content/drive/MyDrive/AI/models/.cache/').mkdir(parents=True, exist_ok=True) 
    os.environ['XDG_CACHE_HOME']='/content/drive/MyDrive/AI/models/.cache'

model_dir_str=str(Path(os.environ['XDG_CACHE_HOME']))
proj_root_str = '${active_project}'
application_root = str(Path('.').absolute())
if mount_gdrive:
    #proj_root_str = '/content/drive/MyDrive/AI/VideoKilledTheRadioStar/${active_project}'
    application_root = '/content/drive/MyDrive/AI/VideoKilledTheRadioStar'


# notebook config
cfg = OmegaConf.create({
    'active_project':str(time.time()),
    #'project_root':proj_root_str,
    'application_root':application_root,
    'project_root':"${application_root}/${active_project}",
    'shared_assets_root':"${application_root}/shared_assets",
    'gdrive_mounted':mount_gdrive,
    'use_stability_api':use_stability_api,
    'model_dir':model_dir_str,
    #'output_dir':'${active_project}/frames' # this worked with gdrive mounted? maybe there's a `cd` invoked somewhere?
    'output_dir':'${project_root}/frames'
})

with open('config.yaml','w') as fp:
    OmegaConf.save(config=cfg, f=fp.name)

###################

# add some tracking to reduce duplicated processing
assets_dir = Path(cfg.shared_assets_root)
assets_dir.mkdir(parents=True, exist_ok=True)

# TODO: let's use jsonl here instead
video_assets_meta_fname = assets_dir / 'video_assets_meta.yaml'
if not video_assets_meta_fname.exists():
    video_assets_meta = OmegaConf.create()
    video_assets_meta.videos = []
    with video_assets_meta_fname.open('w') as fp:
        OmegaConf.save(config=video_assets_meta, f=fp.name)
video_assets_meta = OmegaConf.load(video_assets_meta_fname)

audio_assets_meta_fname = assets_dir / 'audio_assets_meta.yaml'
if not audio_assets_meta_fname.exists():
    audio_assets_meta = OmegaConf.create()
    audio_assets_meta.content = []
    with audio_assets_meta_fname.open('w') as fp:
        OmegaConf.save(config=audio_assets_meta, f=fp.name)
audio_assets_meta = OmegaConf.load(audio_assets_meta_fname)

###################

if use_stability_api:
    import os, getpass
    if not os.environ.get('STABILITY_KEY'):
        os.environ['STABILITY_KEY'] = getpass.getpass('Enter your Stability API Key, then press enter to continue')
else:
    if not local:
        from google.colab import output
        output.enable_custom_widget_manager()
        
    from huggingface_hub import notebook_login
    notebook_login()

## $1.$ 📋 Set Project Name (create/resume)

In [None]:

import time
import string
from omegaconf import OmegaConf

def sanitize_folder_name(fp):
    outv = ''
    whitelist = string.ascii_letters + string.digits + '-_'
    for token in str(fp):
        if token not in whitelist:
            token = '-'
        outv += token
    return outv

project_name = 'offspring-handgrenades' # @param {type:'string'}
if not project_name:
    project_name = str(time.time())

project_name = sanitize_folder_name(project_name)

workspace = OmegaConf.load('config.yaml')
workspace.active_project = project_name

with open('config.yaml','w') as fp:
    OmegaConf.save(config=workspace, f=fp.name)

# @markdown To create a new project, enter a unique project name.
# @markdown If you leave `project_name` blank, the current unix timestamp will be used
# @markdown  (seconds since 1970-01-01 00:00).

# @markdown If you use the name of an existing project, the workspace will switch to that project.

# @markdown Non-alphanumeric characters (excluding '-' and '_') will be replaced with hyphens.


# reset workspace
if 'df' in locals():
    del df
if 'df_regen' in locals():
    del df_regen
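
The sanitization rule described in the form notes above can also be written as a single regular expression. This is an equivalent sketch rather than the notebook's own helper; the name `sanitize_folder_name_re` is hypothetical.

```python
import re

def sanitize_folder_name_re(name):
    """Replace every character outside [A-Za-z0-9_-] with a hyphen."""
    return re.sub(r'[^A-Za-z0-9_-]', '-', str(name))

cleaned = sanitize_folder_name_re('my project/v1.0')
```

Names that already use only letters, digits, hyphens, and underscores pass through unchanged.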

## $2.$ 🔊 Infer speech from audio

In [36]:
from omegaconf import OmegaConf
from pathlib import Path

workspace = OmegaConf.load('config.yaml')
use_stability_api = workspace.use_stability_api
model_dir = workspace.model_dir
root = Path(workspace.project_root)
root.mkdir(parents=True, exist_ok=True)


import copy
import datetime as dt
import gc
from itertools import chain, cycle
import json
import os
import re
import string
import subprocess
from subprocess import Popen, PIPE
import textwrap
import time
import warnings

from IPython.display import display
import numpy as np
import pandas as pd
import panel as pn
from tqdm.autonotebook import tqdm

# TODO: add support for whisper API
import whisper

storyboard = OmegaConf.create()

d_ = dict(
    # all the underscore does is make it so each of the following lines can be preceded with a comma
    # otw the first parameter would be offset from the other in the colab form
    _=""

    , video_url = 'https://www.youtube.com/watch?v=B1BdQcJ2ZYY' # @param {type:'string'}
    , audio_fpath = '' # @param {type:'string'}
)

# @markdown `video_url` - URL of a youtube video to download as a source for audio and potentially for text transcription as well.

# @markdown `audio_fpath` - Optionally provide an audio file instead of relying on a youtube download. Name it something other than 'audio.mp3', 
# @markdown                 otherwise it might get overwritten accidentally.


d_.pop('_')
storyboard.params = d_

if not storyboard.params.audio_fpath:
    storyboard.params.audio_fpath = None

storyboard_fname = root / 'storyboard.yaml'
with open(storyboard_fname,'wb') as fp:
    OmegaConf.save(config=storyboard, f=fp.name)

###############################
# Download audio from youtube #
###############################

# this should modify the existing record for the URL rather than creating a new one...
force_redownload=False

video_url = storyboard.params.video_url
download_video=True

# check if the video has already been downloaded/processed before redownloading
assets_dir = Path(workspace.shared_assets_root)
video_assets_meta_fname = assets_dir / 'video_assets_meta.yaml'
video_assets_meta = OmegaConf.load(video_assets_meta_fname)
audio_assets_meta_fname = assets_dir / 'audio_assets_meta.yaml'
audio_assets_meta = OmegaConf.load(audio_assets_meta_fname)

if not force_redownload:
    for rec in video_assets_meta.videos:
        if rec.video_url == video_url:
            print("previously downloaded video detected")
            download_video=False
            # populate storyboard with previous processing results
            if rec.get('audio_fpath'):
                storyboard.params.audio_fpath = rec.get('audio_fpath')
                for audio_meta in audio_assets_meta.content:
                    # compare as strings: records may hold either str or Path values
                    if str(audio_meta.audio_fpath) == str(storyboard.params.audio_fpath):
                        print("previously processed audio located")
                        storyboard.params.video_duration = audio_meta.duration    

                        whisper_seg_fpath = Path(audio_meta.whisper_segmentation)
                        with whisper_seg_fpath.open() as f:
                            timings = json.load(f)
                        storyboard.prompt_starts = timings['segments']
                        break
            break


if download_video:
    # check if user provided an audio filepath (or we already have one from youtube) before attempting to download
    # TODO: this needs to be reorganized a bit for the new shared assets folder
    video_assets_meta_record = {}
    video_assets_meta_record['video_url'] = video_url
    if storyboard.params.get('audio_fpath') is None:
        ytdl_prefix = "DOWNLOADED__"
        #ytdl_fname = f"{str(root / ytdl_prefix)}%(title)s.%(ext)s"
        # should probably download to a tmp dir then move it after
        ytdl_fname = f"{str(assets_dir / ytdl_prefix)}%(title)s.%(ext)s"

        # force re-download for now... existing mp4 causes download step to be skipped.
        # TODO: intelligently detect when to re-download or not.
        # -> pretty sure this is mainly going to be changing the checkpointing sequence,
        #    i.e. so we can compare the input video_url with whatever may already be on the storyboard
        #if Path(ytdl_fname).exists():
        #    Path(ytdl_fname).unlink()
        
        !yt-dlp -o "{ytdl_fname}" {video_url}

        #matched_files = root.glob(ytdl_prefix+"*")
        matched_files = assets_dir.glob(ytdl_prefix+"*")
        most_recent_file = max(matched_files, key=os.path.getctime)
        print(f"downloaded: {most_recent_file}")
        ytdl_fname = most_recent_file
        # new attribute cause why not
        storyboard.params.downloaded_video_fpath = ytdl_fname
        video_assets_meta_record['video_fpath'] = str(ytdl_fname.absolute())

        #audio_fpath = str( root / 'audio.m4a' )
        #audio_fpath = str( assets_dir / 'audio.m4a' )
        audio_fpath = ytdl_fname.with_suffix('.m4a')
        input_audio = ytdl_fname
        !ffmpeg -y -i "{input_audio}" -vn -c:a aac "{audio_fpath}"

        storyboard.params.audio_fpath = audio_fpath
        video_assets_meta_record['audio_fpath'] = str(audio_fpath.absolute())
    
    video_assets_meta.videos.append(video_assets_meta_record)
        
with open(storyboard_fname,'wb') as fp:
    OmegaConf.save(config=storyboard, f=fp.name)

with open(video_assets_meta_fname, 'wb') as fp:
    OmegaConf.save(config=video_assets_meta, f=fp.name)

###################################################
# 💬 Transcribe and segment speech using whisper #
###################################################

# for video duration
def get_audio_duration_seconds(audio_fpath):
    outv = subprocess.run([
        'ffprobe'
        ,'-i',audio_fpath
        ,'-show_entries', 'format=duration'
        ,'-v','quiet'
        ,'-of','csv=p=0'
        ],
        stdout=subprocess.PIPE
        ).stdout.decode('utf-8')
    return float(outv.strip())

audio_fpath = Path(storyboard.params.audio_fpath)

force_retranscription = False # @param {type:'boolean'}

if not force_retranscription:
    for audio_meta in audio_assets_meta.content:
        # compare as strings: `audio_fpath` is a Path, the record may hold a str
        if str(audio_meta.audio_fpath) == str(audio_fpath):
            print("previously processed audio detected")
            #storyboard.prompt_starts = audio_meta.whisper_segmentation
            whisper_seg_fpath = Path(audio_meta.whisper_segmentation)
            with whisper_seg_fpath.open() as f:
                timings = json.load(f)
            storyboard.prompt_starts = timings['segments']
            storyboard.params['video_duration'] = audio_meta.duration
            
if not storyboard.get('prompt_starts'):
    audio_meta = {}
    audio_meta['audio_fpath'] = storyboard.params.audio_fpath
    # outputs text files as audio.* locally
    #!whisper --model large --word_timestamps True -o {str(root)} "{storyboard.params.audio_fpath}"
    !whisper --model large --word_timestamps True -o {str(assets_dir)} "{storyboard.params.audio_fpath}"

    #with Path(root / 'audio.json').open() as f:
    whisper_seg_fpath = Path(storyboard.params.audio_fpath).with_suffix('.json')
    audio_meta['whisper_segmentation'] = str(whisper_seg_fpath)
    
    with whisper_seg_fpath.open() as f:
        timings = json.load(f)

    storyboard.prompt_starts = timings['segments']
    
    # i don't think this is reliable unfortunately.
    #storyboard.params['video_duration'] = storyboard.prompt_starts[-1]['end']
    audio_meta['duration'] = get_audio_duration_seconds(audio_fpath)
    storyboard.params['video_duration'] = audio_meta['duration']
    
    # checkpoint new audio processing metadata
    audio_assets_meta.content.append(audio_meta)
    with open(audio_assets_meta_fname, 'wb') as fp:
        OmegaConf.save(config=audio_assets_meta, f=fp.name)


prompt_starts = storyboard.prompt_starts

### checkpoint the processing work we've done to this point

prompt_starts_copy = copy.deepcopy(prompt_starts)
storyboard.prompt_starts = prompt_starts_copy

# pass the path directly; opening the file in read mode first is unnecessary
OmegaConf.save(config=storyboard, f=storyboard_fname)

###############################
# Review/Modify transcription #
###############################

# @markdown ---
# @markdown NB: When this cell finishes running, a table will appear
# @markdown at the bottom of the output window. This table is editable
# @markdown and can be used to correct errors in the transcription.
# @markdown
# @markdown additionally, the `override_prompt` field can be used to provide an 
# @markdown alternative text prompt for image generation. If this feature is
# @markdown used, both the lyric and the theme prompt (which you will specify 
# @markdown in the cell that follows this) will be ignored. If you want to use
# @markdown an `override_prompt` and also want to stay on theme, you will have 
# @markdown to append the desired `theme_prompt` to the end of the 
# @markdown `override_prompt` manually.


# https://panel.holoviz.org/reference/widgets/Tabulator.html
pn.extension('tabulator') # I don't know that specifying 'tabulator' here is even necessary...

tabulator_formatters = {
    'bool': {'type': 'tickCross'}
}

# reset workspace
if 'df_regen' in locals():
    del df_regen
    
try:
    prompt_starts = OmegaConf.to_container(prompt_starts)
except:
    print("huh. that's weird.")
    print(prompt_starts)
    pass
df = pd.DataFrame(prompt_starts)[['start','end','text']]

#df['override_prompt'] = ''
# TODO: check if user has edited lyrics or... anything i guess

df_pre = copy.deepcopy(df)
outv = pn.widgets.Tabulator(df, formatters=tabulator_formatters)
#if local: # really the issue isn't "local", it's VSCode
#    outv = df
outv


[youtube] Extracting URL: https://www.youtube.com/watch?v=B1BdQcJ2ZYY
[youtube] B1BdQcJ2ZYY: Downloading webpage
[youtube] B1BdQcJ2ZYY: Downloading android player API JSON
[info] B1BdQcJ2ZYY: Downloading 1 format(s): 242+251
[dashsegments] Total fragments: 1
[download] Destination: /home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.f242.webm
[download] 100% of    1.18MiB in 00:00:00 at 4.24MiB/s (frag 1/1)
[dashsegments] Total fragments: 1
[download] Destination: /home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.f251.webm
[download] 100% of    1.82MiB in 00:00:00 at 4.98MiB/s (frag 1/1)
[Merger] Merging formats into "/home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.webm"
Deleting original file /home/dmarx/projects/video-killed-the

In [51]:
# time to wrap some of the loading logic for portability

from omegaconf import OmegaConf
from pathlib import Path
from safetensors.numpy import load_file as load_safetensors

def save_storyboard(storyboard):
    # relies on the module-level `storyboard_fname` defined in an earlier cell
    OmegaConf.save(config=storyboard, f=storyboard_fname)

def load_storyboard():
    workspace = OmegaConf.load('config.yaml')
    root = Path(workspace.project_root)
    storyboard_fname = root / 'storyboard.yaml'
    storyboard = OmegaConf.load(storyboard_fname)
    return workspace, storyboard

def load_audio_meta(workspace, storyboard):
    assets_dir = Path(workspace.shared_assets_root)
    audio_assets_meta_fname = assets_dir / 'audio_assets_meta.yaml'
    audio_assets_meta = OmegaConf.load(audio_assets_meta_fname)
    audio_meta=dict()
    for idx, rec in enumerate(audio_assets_meta.content):
        if rec.audio_fpath == storyboard.params.audio_fpath:
            audio_meta = rec
            break
    return audio_meta



# load audio features
workspace, storyboard = load_storyboard()
audio_meta = load_audio_meta(workspace, storyboard)
#structural_features = load_safetensors(audio_meta.structural_features)
#structural_features['beat_times']

In [37]:
# @markdown Run this cell to preview an alternative segmentation
# @markdown of the lyrics. If you like what you see, run the next 
# @markdown cell to keep this segmentation, or skip
# @markdown the next cell to keep your lyrics segmented as above.
try_to_segment_further = False # @param {'type':'boolean'}

def calculate_interword_gaps(segment):
    end_prev = -1
    gaps = []
    for word in segment['words']:
        if end_prev < 0:
            end_prev = word['end']
            continue 
        gap = word['start'] - end_prev
        gaps.append(gap)
        end_prev = word['end']
    return gaps 

def trivial_subsegmentation(segment, threshold=0, gaps=None):
    """
    split on gaps in detected vocal activity. 
    Contiguity = gap between adjacent tokens is less than the input threshold.
    """
    if gaps is None:
        gaps = calculate_interword_gaps(segment)
    out_segments = []
    # use the `segment` argument rather than a global `seg` left over from a loop
    this_segment = [segment['words'][0]]
    for word, preceding_pause in zip(segment['words'][1:], gaps):
        if preceding_pause <= threshold:
            this_segment.append(word)
        else:
            out_segments.append(this_segment)
            this_segment = [word]
    out_segments.append(this_segment)

    outv = [dict(
        start=seg[0]['start'],
        end=seg[-1]['end'],
        text=''.join([w['word'] for w in seg]).strip(),
    ) for seg in out_segments]

    return outv

segments = timings['segments']
segments_df = df
if try_to_segment_further:
    segments = []
    for seg in timings['segments']:
        segments.extend(trivial_subsegmentation(seg))

    segments_df = pd.DataFrame(segments)
pn.widgets.Tabulator(segments_df[['start', 'end','text']], formatters=tabulator_formatters)
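
To make the gap-based splitting concrete, here is a small self-contained sketch run on made-up word timings. It follows the same idea as `trivial_subsegmentation` (start a new sub-segment whenever the pause before a word exceeds the threshold) but is written independently; `split_on_gaps` and the sample `words` are hypothetical, and the word dicts only assume the `word`, `start`, and `end` keys used in the cell above.

```python
def split_on_gaps(words, threshold=0.0):
    """Group word dicts into runs whose inter-word pause is <= threshold."""
    if not words:
        return []
    groups = [[words[0]]]
    for prev, word in zip(words, words[1:]):
        pause = word['start'] - prev['end']
        if pause <= threshold:
            groups[-1].append(word)   # still contiguous speech
        else:
            groups.append([word])     # pause too long: start a new sub-segment
    return [
        dict(
            start=g[0]['start'],
            end=g[-1]['end'],
            text=''.join(w['word'] for w in g).strip(),
        )
        for g in groups
    ]

# Invented timings, in seconds, for illustration only
words = [
    {'word': ' the', 'start': 0.0, 'end': 0.5},
    {'word': ' humans', 'start': 0.5, 'end': 1.0},
    {'word': ' are', 'start': 2.0, 'end': 2.5},
    {'word': ' dead', 'start': 2.5, 'end': 3.0},
]
subsegments = split_on_gaps(words, threshold=0.25)
```

With these timings the one-second pause before "are" exceeds the threshold, so the four words split into two sub-segments.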

In [39]:
# @markdown Run this cell (with the box ticked) to use the alternative version of the segmentation.
# @markdown The "alternative" or "further segmented" version should be broken up into more scenes (or will be unchanged).

use_new_segmentation = False # @param {'type':'boolean'}

if use_new_segmentation:
    prompt_starts=segments
    storyboard.prompt_starts = prompt_starts
    save_storyboard(storyboard)
    # with open(storyboard_fname) as fp:
    #     OmegaConf.save(config=storyboard, f=fp.name)

save_storyboard(storyboard)
    
#prompt_starts = OmegaConf.to_container(prompt_starts)
df = pd.DataFrame(prompt_starts)[['start','end','text']]

#df['override_prompt'] = ''

df_pre = copy.deepcopy(df)
outv = pn.widgets.Tabulator(df, formatters=tabulator_formatters)
#if local: # really the issue isn't "local", it's VSCode
#    outv = df
outv

# TODO: edits to dataframe should propagate to prompt_starts, or we should tell user when that does/n't apply

In [52]:
# Music Structure Analysis
# - beat and tempo detection
# - Self-similarity graph

import numpy as np
import scipy
import sklearn.cluster
import librosa

from safetensors.numpy import save_file as save_safetensors
from safetensors.numpy import load_file as load_safetensors


def analyze_audio_structure(
    audio_fpath,
    BINS_PER_OCTAVE = 12 * 3, # should be a multiple of twelve: https://github.com/MTG/essentia/blob/master/src/examples/python/tutorial_spectral_constantq-nsg.ipynb
    N_OCTAVES = 7,
):
    """
    via librosa docs
    https://librosa.org/doc/latest/auto_examples/plot_segmentation.html#sphx-glr-auto-examples-plot-segmentation-py
    cites: McFee and Ellis, 2014 - https://brianmcfee.net/papers/ismir2014_spectral.pdf
    """
    y, sr = librosa.load(audio_fpath)

    C = librosa.amplitude_to_db(np.abs(librosa.cqt(y=y, sr=sr,
                                            bins_per_octave=BINS_PER_OCTAVE,
                                            n_bins=N_OCTAVES * BINS_PER_OCTAVE)),
                                ref=np.max)

    # reduce dimensionality via beat-synchronization
    tempo, beats = librosa.beat.beat_track(y=y, sr=sr, trim=False)
    Csync = librosa.util.sync(C, beats, aggregate=np.median)
    
    # I have concerns about this frame fixing operation
    beat_times = librosa.frames_to_time(librosa.util.fix_frames(beats, x_min=0), sr=sr)

    # width=3 prevents links within the same bar 
    # mode='affinity' here implements S_rep (after Eq. 8)
    R = librosa.segment.recurrence_matrix(Csync, width=3, mode='affinity', sym=True)
    # Enhance diagonals with a median filter (Equation 2)
    df = librosa.segment.timelag_filter(scipy.ndimage.median_filter)
    Rf = df(R, size=(1, 7))
    # build the sequence matrix (S_loc) using mfcc-similarity
    mfcc = librosa.feature.mfcc(y=y, sr=sr)
    Msync = librosa.util.sync(mfcc, beats)
    path_distance = np.sum(np.diff(Msync, axis=1)**2, axis=0)
    sigma = np.median(path_distance)
    path_sim = np.exp(-path_distance / sigma)
    R_path = np.diag(path_sim, k=1) + np.diag(path_sim, k=-1)
    # compute the balanced combination
    deg_path = np.sum(R_path, axis=1)
    deg_rec = np.sum(Rf, axis=1)
    mu = deg_path.dot(deg_path + deg_rec) / np.sum((deg_path + deg_rec)**2)
    A = mu * Rf + (1 - mu) * R_path

    # compute normalized laplacian and its spectrum
    L = scipy.sparse.csgraph.laplacian(A, normed=True)
    evals, evecs = scipy.linalg.eigh(L)
    # clean this up with a median filter. can help smooth over discontinuities
    evecs = scipy.ndimage.median_filter(evecs, size=(9, 1))
    return dict(
        y=y, 
        #sr=np.array(sr).astype(np.uint8), # too low precision
        sr=np.array(sr).astype(np.uint32), # uint16 works but it's a single value, doesn't really matter.
        tempo=tempo,
        beats=beats,
        beat_times=beat_times,
        evecs=evecs,
    )


audio_structure_features = analyze_audio_structure(audio_fpath=storyboard.params.audio_fpath)
audio_features_fpath = Path(storyboard.params.audio_fpath).with_suffix('.audio_features.safetensors')
save_safetensors(audio_structure_features, audio_features_fpath)

print(storyboard.params.audio_fpath)
#audio_meta = load_audio_meta(workspace, storyboard)


assets_dir = Path(workspace.shared_assets_root)
audio_assets_meta_fname = assets_dir / 'audio_assets_meta.yaml'
audio_assets_meta = OmegaConf.load(audio_assets_meta_fname)
audio_meta=dict()
for idx, rec in enumerate(audio_assets_meta.content):
    if rec.audio_fpath == storyboard.params.audio_fpath:
        audio_meta = rec
        break


print(audio_meta)
if 'audio_fpath' not in audio_meta:
    audio_meta['audio_fpath'] = storyboard.params.audio_fpath
    #audio_assets_meta.append(audio_meta)
    audio_assets_meta.content.append(audio_meta)

#audio_meta['structural_features'] = str(audio_features_fpath) # maybe i need to use dot notation for it to hit the actual object?
#audio_meta.structural_features = str(audio_features_fpath) # nope, doesn't work either.
with open(audio_assets_meta_fname, 'wb') as fp:
    OmegaConf.save(config=audio_assets_meta, f=fp.name)
print(audio_meta)

  y, sr = librosa.load(audio_fpath)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


/home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.m4a
{'audio_fpath': PosixPath('/home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.m4a'), 'whisper_segmentation': '/home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.json', 'duration': 114.331}
{'audio_fpath': PosixPath('/home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.m4a'), 'whisper_segmentation': '/home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.json', 'duration': 114.331}


In [53]:
### for adjusting start times
import librosa

workspace, storyboard = load_storyboard()
audio_meta = load_audio_meta(workspace, storyboard)
print(audio_meta)

#structural_features = analyze_audio_structure(audio_fpath=storyboard.params.audio_fpath)
structural_features = load_safetensors(audio_meta['structural_features'])

beats = structural_features['beats']
sr = int(structural_features['sr'])

#beat_times = librosa.samples_to_time(beats, sr=sr)
beat_times = librosa.frames_to_time(beats, sr=sr)
#beat_times = structural_features['beat_times']

# as written, this won't work if we do the additional segmentation thing.
#scenes = OmegaConf.load(audio_meta['whisper_segmentation'])
#scene_starts = [s['start'] for s in scenes['segments']]

scene_starts = [s['start'] for s in storyboard.prompt_starts]

#scene_starts_as_frames = librosa.time_to_samples(scene_starts, sr=sr)
scene_starts_as_frames = librosa.time_to_frames(scene_starts, sr=sr)
downbeat_indices = librosa.util.match_events(scene_starts_as_frames, beats)
downbeat_times = beat_times[downbeat_indices]

#starts_paired_with_beats = np.vstack([np.array(scene_starts), np.array(downbeat_times)]).T

adjust_to_closest_beat = True
only_adjust_down = False

if adjust_to_closest_beat:
    for idx, scene in enumerate(storyboard.prompt_starts):
        #print(f"scene {idx}: {scene['start']} || {downbeat_times[idx]}")
        if scene['start'] > downbeat_times[idx]:
            print(f"scene {idx}: {scene['start']} -> {downbeat_times[idx]}")
            scene['start'] = float(downbeat_times[idx])
        elif (not only_adjust_down) and scene['start'] < downbeat_times[idx]:
            print(f"scene {idx}: {scene['start']} <-> {downbeat_times[idx]}")
            scene['start'] = float(downbeat_times[idx])
        
# TODO: checkpoint changes

save_storyboard(storyboard)

# options:
# - use closest beat for all scenes
# - use beat if earlier than current scene start
# - ..use onsets instead of beats for this part

{'audio_fpath': PosixPath('/home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.m4a'), 'whisper_segmentation': '/home/dmarx/projects/video-killed-the-radio-star/shared_assets/DOWNLOADED__The Humans Are Dead - Full version.json', 'duration': 114.331}
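
The snapping logic in the cell above can be illustrated without librosa. The sketch below works directly on times in seconds and uses a binary search to find the nearest beat; it is an assumption-laden stand-in for the `match_events` approach, and `snap_to_nearest_beat` is a hypothetical helper. It assumes `beat_times` is sorted in ascending order.

```python
from bisect import bisect_left

def snap_to_nearest_beat(start, beat_times, only_adjust_down=False):
    """Return the beat time closest to `start` (beat_times sorted ascending).

    With `only_adjust_down=True`, the start is left unchanged whenever the
    nearest beat falls after it.
    """
    if not beat_times:
        return start
    i = bisect_left(beat_times, start)
    # the nearest beat is either just before or just after the insertion point
    candidates = beat_times[max(i - 1, 0):i + 1]
    nearest = min(candidates, key=lambda b: abs(b - start))
    if only_adjust_down and nearest > start:
        return start
    return nearest

beat_times = [0.0, 0.5, 1.0, 1.5]  # invented beat grid, in seconds
snapped = [snap_to_nearest_beat(s, beat_times) for s in (0.6, 0.9, 2.0)]
```

Working in seconds sidesteps the frame/time conversions, at the cost of not reusing the beat frame indices already stored in the features file.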


## Theme -> Scene Assignment

In [None]:
# TODO: add to setup
# %pip install librosa # scikit-learn

# TODO: need to save the updated prompt_starts to the storyboard before this step.
#       otw the cluster assignments won't propagate to the newly added subscenes properly.
#       i think.

## development
#storyboard = OmegaConf.load('/home/dmarx/projects/video-killed-the-radio-star/nirvana-sold-structure/storyboard.yaml')

# if False, themes will be rotated sequentially such that no two adjacent frames
# will use the same theme prompt (if multiple theme prompts were provided).
infer_thematic_structure = True # @param {type:'boolean'}
theme_prompt = ( # @param {type:'string'}
    " awkward 1990s photos | "
    " awkward 1980s photos | "
    " awkward 1970s photos | "
    " awkward 1960s photos | "
    " awkward 2000s photos | " # TODO: test with fewer than 5 themes
    #"abstract art, inspired by peter gabriel | "
    #"abstract art, inspired by radiohead | "
    #"abstract art, inspired by coldplay | "
    #"abstract art, inspired by david bowie | "
    #"abstract art, inspired by talking heads |"
    #"abstract art, inspired by kusama | "
    #"abstract art, inspired by alex grey "
    
)
storyboard.params.theme_prompt = theme_prompt
themes = [prompt.strip() for prompt in theme_prompt.split('|') if prompt.strip()]

    
def laplacian_segmentation(
    audio_fpath=None,
    evecs=None,
    n_clusters = 5,
    n_spectral_features = None,
    # probably don't need to expose these parameters
):
    """
    segment audio by clustering a self-similarity matrix.
    via librosa docs
    https://librosa.org/doc/latest/auto_examples/plot_segmentation.html#sphx-glr-auto-examples-plot-segmentation-py
    cites: McFee and Ellis, 2014 - https://brianmcfee.net/papers/ismir2014_spectral.pdf
    """
    if evecs is None:
        if audio_fpath is None:
            raise Exception("One of `audio_fpath` or `evecs` must be provided")
        features = analyze_audio_structure(audio_fpath)
        evecs = features['evecs']
    
    if n_spectral_features is None:
        n_spectral_features = n_clusters

    # cumulative normalization is needed for symmetric normalized Laplacian eigenvectors
    Cnorm = np.cumsum(evecs**2, axis=1)**0.5
    k = n_spectral_features
    X = evecs[:, :k] / Cnorm[:, k-1:k]

    # use these k components to cluster beats into segments
    KM = sklearn.cluster.KMeans(n_clusters=n_clusters, n_init="auto")
    seg_ids = KM.fit_predict(X)

    return seg_ids #, beat_times, tempo


import numpy as np
import sklearn.cluster  # used by `laplacian_segmentation` above

from safetensors.numpy import save_file as save_safetensors
from safetensors.numpy import load_file as load_safetensors

# audio_assets_meta_fname = assets_dir / 'audio_assets_meta.yaml'
# audio_assets_meta = OmegaConf.load(audio_assets_meta_fname)
# audio_meta=dict()
# for rec in audio_assets_meta.content:
#     if rec.audio_fpath == storyboard.params.audio_fpath:
#         audio_meta = rec
#         break


if (len(themes) > 1):
    if infer_thematic_structure:
        audio_features_fpath = audio_meta.get('structural_features')
        if audio_features_fpath:
            audio_structure_features = load_safetensors(audio_features_fpath)
        else:
            # nothing below can run without the precomputed features, so stop here
            raise RuntimeError("Run the cell above first to compute the structural features that will be used in this cell.")
                
        beat_times = audio_structure_features['beat_times']
        evecs = audio_structure_features['evecs']
        segment_labels = laplacian_segmentation(
            evecs=evecs,
            n_clusters=len(themes),
            n_spectral_features=len(themes),
        )
        
        #adjust scene start times
        
        
        # TODO: perform this inference only after fixing long scenes
        # TODO: swap out rec['end'] -> rec_prev['start'] here
        for rec in storyboard.prompt_starts:
            beat_indices = np.where((beat_times >= rec['start']) & (beat_times <= rec['end']))[0]
            segments_this_interval = segment_labels[beat_indices]
            if len(segments_this_interval) == 0:
                dominant_label = 0
            else:
                dominant_label = int(np.argmax(np.bincount(segments_this_interval)))
            rec['structural_segmentation_label'] = dominant_label
            rec['_theme'] = themes[dominant_label]
    else:
        # rotate through the themes sequentially
        for idx, rec in enumerate(storyboard.prompt_starts):
            rec['_theme'] = themes[idx % len(themes)]
else:
    for rec in storyboard.prompt_starts:
        rec['_theme'] = theme_prompt

## checkpoint
# OmegaConf.save accepts a path directly; no need to open the file first
OmegaConf.save(config=storyboard, f=storyboard_fname)
        
pn.extension('tabulator') # I don't know that specifying 'tabulator' here is even necessary...
tabulator_formatters = {
    'bool': {'type': 'tickCross'}
}
        
df_themes = pd.DataFrame(storyboard.prompt_starts).rename(columns={'_theme':'theme', 'structural_segmentation_label':'theme_id'})[['start','end','text','theme_id', 'theme']]
outv = pn.widgets.Tabulator(df_themes, formatters=tabulator_formatters)
outv
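
The theme assignment above gives each scene the segment label that occurs most often among the beats falling inside that scene. A minimal sketch of that majority vote, assuming beat-level integer labels like those returned by `laplacian_segmentation`; the helper name `dominant_label` and its `default` argument are hypothetical:

```python
import numpy as np

def dominant_label(beat_times, segment_labels, start, end, default=0):
    """Most frequent segment label among beats in the interval [start, end]."""
    beat_times = np.asarray(beat_times)
    segment_labels = np.asarray(segment_labels)
    in_scene = (beat_times >= start) & (beat_times <= end)
    labels = segment_labels[in_scene]
    if labels.size == 0:
        # no beat falls inside this scene: fall back to a default label
        return default
    # bincount tallies each non-negative integer label; argmax picks the winner
    return int(np.bincount(labels).argmax())

beat_times = [0, 1, 2, 3, 4, 5]
segment_labels = [0, 0, 1, 1, 1, 2]
print(dominant_label(beat_times, segment_labels, start=1.5, end=4.5))
```

Ties go to the lowest label id, since `argmax` returns the first maximum.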


In [None]:
# https://github.com/mjhydri/1d-statespace
#%pip install jump-reward-inference
#%pip install git+https://github.com/mjhydri/1D-StateSpace

In [None]:
#from jump_reward_inference.joint_tracker import joint_inference


#estimator = joint_inference(1, plot=True) 

#output = estimator.process("music file directory")

In [None]:
# audio_assets_meta_fname = assets_dir / 'audio_assets_meta.yaml'
# audio_assets_meta = OmegaConf.load(audio_assets_meta_fname)


# from safetensors.numpy import save_file as save_safetensors
# from safetensors.numpy import load as load_safetensors
# #import numpy as np
# audio_structure_features = analyze_audio_structure(audio_fpath=storyboard.params.audio_fpath)
# audio_features_fpath = Path(storyboard.params.audio_fpath).with_suffix('.audio_features.safetensors')
# save_safetensors(audio_structure_features, audio_features_fpath)


# for audio_meta in audio_assets_meta.content:
#     if audio_meta.audio_fpath == storyboard.params.audio_fpath:
#         #print("found it!")
#         audio_meta['structural_features'] = str(audio_features_fpath)
#         with open(audio_assets_meta_fname, 'wb') as fp:
#             OmegaConf.save(config=audio_assets_meta, f=fp.name)
#         break
        

In [None]:
### Subdivide unusually long scenes

import numpy as np
import copy

# TODO: wrap these steps in functions for legibility
# TODO: make this parameterizable via storyboard (and save analysis to storyboard)
# TODO: use beat counts to estimate a smart scene duration

# estimate parameters of scene duration distribution 
#storyboard = OmegaConf.load('/home/dmarx/projects/video-killed-the-radio-star/nirvana-sold-structure/storyboard.yaml')
scene_durations = []
scenes_ = []
for idx, rec in enumerate(storyboard.prompt_starts):
    rec=dict(rec)
    if idx > 0:
        # are we maybe doubling up 'start' time stamps? like there are more unique 'end's than 'start's?
        duration = rec['start'] - prev['start']
        prev['duration_'] = duration
        scene_durations.append(duration)
    prev = rec
    scenes_.append(prev)
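# handle last record: it has no successor, so use its own 'end' timestamp.
# The loop above has already appended it to `scenes_`; appending it again
# here would duplicate the final scene.
# TODO: swap out rec['end'] -> rec_prev['start'] here
rec['duration_'] = rec['end'] - rec['start']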


mu = sum(scene_durations)/len(scene_durations)
sigma = np.std(scene_durations)

# 1sd filter to concentrate on mode
scene_durations2 = [s for s in scene_durations if (mu - sigma) < s < (mu+sigma)]
mu2 = sum(scene_durations2)/len(scene_durations2)
sigma2 = np.std(scene_durations2)

# # second 1sd filter
# scene_durations3 = [s for s in scene_durations2 if (mu2 - sigma2) < s < (mu2+sigma2)]
# mu3 = sum(scene_durations3)/len(scene_durations3)
# sigma3 = np.std(scene_durations3)

###########

# break up "outlier" segments into smaller chunks
# this heuristic could be improved with beat synchronization and onset detection.
# ... also could probably leverage the 'end' time of the scene
# TODO: hierarchical theme structure analysis
# TODO: MSA segmentation for fully instrumental (i.e. arbitrary) audio
threshold = mu2 + sigma
scenes = []
#for rec in storyboard.prompt_starts:
for rec in list(scenes_):
    gap_remaining = rec['duration_']
    while gap_remaining > threshold:
        #step = min(max(mu2-sigma, (np.random.normal() + mu2)*sigma), mu2+sigma)
        step=mu2
        step = float(step)
        rec['duration_'] = step
        # new_rec = {
        #     'start': rec['start'] + step,
        #     'duration_': step,
        #     #'_prompt':rec.get('_prompt')
        #     'prompt':rec.get('prompt'), #,rec.get('_prompt'))
        #     'structural_segmentation_label' = rec.get('structural_segmentation_label'),
        #     '_theme' = rec.get('_theme'),
        # }
        ######### TODO: infer appropriate structural segmentation for new subscene
        new_rec = copy.deepcopy(rec)
        new_rec['start'] = rec['start'] + step
        new_rec['duration_'] = step
        ### maybe i could add a flag or something to clarify that this was an "inferred" subscene
        new_rec['parent_scene'] = rec.get('uid') # something like this?
        new_rec['inferred_subscene'] = True # or this?
         
        scenes.append(rec)
        rec = new_rec
        gap_remaining -= step
    scenes.append(rec)

    
# TODO: Adjust to beats


###############

storyboard.prompt_starts = scenes
#storyboard_fname = '/home/dmarx/projects/video-killed-the-radio-star/nirvana-sold-structure/storyboard.yaml'

storyboard_fname = Path(workspace.project_root) / 'storyboard.yaml'
# OmegaConf.save accepts a path directly; no need to open the file first
OmegaConf.save(config=storyboard, f=storyboard_fname)
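
The subdivision heuristic above can be summarized as: estimate a "typical" scene duration from the durations within one standard deviation of the mean, then chop any scene longer than typical-plus-sigma into typical-length pieces. A minimal sketch under simplifying assumptions: it works on a bare list of start times, includes the final scene in the statistics, and the names `subdivide_long_scenes` and `final_end` are hypothetical:

```python
import numpy as np

def subdivide_long_scenes(starts, final_end):
    """Return scene start times with unusually long scenes split into chunks."""
    starts = [float(s) for s in starts]
    durations = np.diff(starts + [float(final_end)])
    mu, sigma = durations.mean(), durations.std()
    # 1-sd filter to concentrate on the mode of the duration distribution
    core = durations[(durations > mu - sigma) & (durations < mu + sigma)]
    step = float(core.mean()) if len(core) else float(mu)
    threshold = step + sigma
    new_starts = []
    for start, duration in zip(starts, durations):
        new_starts.append(start)
        # peel off typical-length subscenes until the remainder is short enough
        while duration > threshold:
            start += step
            duration -= step
            new_starts.append(start)
    return new_starts

print(subdivide_long_scenes([0, 2, 4, 6, 8, 10, 30], final_end=32))
```

In this example the 20-second scene starting at `10` is the only outlier, so it is the only one that gains inferred subscenes.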


## $3.$ 🎬 Animate

In [None]:
#####################################
# @title ## 🎨 Generate init images
#####################################


import copy
import datetime as dt
from pathlib import Path
import random
import string
import time
import os
import io

from bokeh.models.widgets.tables import (
    NumberFormatter, 
    BooleanFormatter,
    CheckboxEditor,
)
import numpy as np
from omegaconf import OmegaConf, DictConfig
import pandas as pd
import panel as pn
import PIL
from PIL import Image
from tqdm.autonotebook import tqdm

import torch
from torch import autocast
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)


# to do: is there a way to check if this is in the env already?
pn.extension('tabulator')

# this processes optional edits to the transcription (above) 
if ('prompt_starts' in locals()) \
and ('df_pre' in locals()):
    if isinstance(prompt_starts, DictConfig):
        prompt_starts = OmegaConf.to_container(prompt_starts)
    # update prompt_starts if any changes were made above
    if not np.all(df_pre.values == df.values):
        print("override value detected")
        df_pre = copy.deepcopy(df)
        for i, rec in enumerate(prompt_starts):
            rec['start'] = float(df.loc[i,'start'])
            rec['end'] = float(df.loc[i,'end'])
            rec['text'] = df.loc[i,'text']
            #rec['prompt'] = df.loc[i,'prompt'] ## TODO...
        
        # TO DO: check if we're checkpointing like, too much or whatever
        # ...actually, I think the above code might not do anything
        # probably need to checkpoint prompt_starts into the storyboard on disk.
        # let's do that here just to be safe.    
        workspace = OmegaConf.load('config.yaml')
        root = Path(workspace.project_root)

        storyboard_fname = root / 'storyboard.yaml'
        storyboard = OmegaConf.load(storyboard_fname)

        storyboard.prompt_starts = prompt_starts
        # OmegaConf.save accepts a path directly; no need to open the file first
        OmegaConf.save(config=storyboard, f=storyboard_fname)
            
##################################################################################


workspace = OmegaConf.load('config.yaml')
root = Path(workspace.project_root)

storyboard_fname = root / 'storyboard.yaml'
storyboard = OmegaConf.load(storyboard_fname)

## development
#storyboard = OmegaConf.load('/home/dmarx/projects/video-killed-the-radio-star/nirvana-sold-structure/storyboard.yaml')


prompt_starts = storyboard.prompt_starts
#use_stability_api = workspace.use_stability_api
model_dir = workspace.model_dir

########################################

# misc utils

def rand_str(n_char=5):
    return ''.join(random.choice(string.ascii_lowercase) for i in range(n_char))

def save_frame(
    img: Image,
    idx:int=0,
    root_path=Path('./frames'),
    name=None,
):
    root_path.mkdir(parents=True, exist_ok=True)
    if name is None:
        name = rand_str()
    outpath = root_path / f"{idx}-{name}.png"
    img.save(outpath)
    return str(outpath)

def get_image_sequence(idx, root, init_first=True):
    root = Path(root)
    images = (root / 'frames' ).glob(f'{idx}-*.png')
    images = [str(fp) for fp in images]
    if init_first:
        # put the anchor (init) image at the front of the sequence
        init_image = None
        images2 = []
        for fp in images:
            if 'anchor' in fp:
                init_image = fp
            else:
                images2.append(fp)
        if not init_image and images2:
            # no anchor on disk: promote the first remaining frame
            init_image, images2 = images2[0], images2[1:]
        images = ([init_image] if init_image else []) + images2
    return images

def archive_images(idx, root, archive_root = None):
    root = Path(root)
    if archive_root is None:
        archive_root = root / 'archive'
    archive_root = Path(archive_root)
    archive_root.mkdir(parents=True, exist_ok=True)
    old_images = get_image_sequence(idx, root=root)
    if not old_images:
        return
    print(f"moving {len(old_images)} old images for scene {idx} to {archive_root}")
    for old_fp in old_images:
        old_fp = Path(old_fp)
        im_name = Path(old_fp.name)
        new_path = archive_root / im_name
        if new_path.exists():
            im_name = f"{im_name.stem}-{time.time()}{im_name.suffix}"
            new_path = archive_root / im_name
        old_fp.rename(new_path)

############################

device = 'cuda'
model_id = "CompVis/stable-diffusion-v1-5"
download=True

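# fall back to an earlier definition (or False) if the workspace config lacks the key
use_stability_api = workspace.get('use_stability_api', globals().get('use_stability_api', False))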
model_dir = workspace.model_dir
model_path= str(Path(model_dir) / 'huggingface' / 'diffusers')


if 'get_image_for_prompt' not in locals():

    if use_stability_api:
        import warnings
        from stability_sdk import client
        import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

        # TODO: update this stuff to reflect updates to API/sdk
        def get_image_for_prompt(prompt, max_retries=5, **kargs):
            stability_api = client.StabilityInference(
                key=os.environ['STABILITY_KEY'], 
                verbose=False,
            )

            # auto-retry if mitigation triggered
            while max_retries:
                try:
                    answers = stability_api.generate(prompt=prompt, **kargs)
                    response = process_response(answers)
                    for img in response:
                        yield img
                    break # NB: this breaks us out of the while loop, not the for loop.

                # TODO: better regen handling
                except RuntimeError:
                    print("runtime error")
                    max_retries -= 1
                    warnings.warn(f"mitigation triggered, retries remaining: {max_retries}")


        def process_response(answers):
            for resp in answers:
                for artifact in resp.artifacts:
                    if artifact.finish_reason == generation.FILTER:
                        warnings.warn(
                            "Your request activated the API's safety filters and could not be processed."
                            "Please modify the prompt and try again.")
                        raise RuntimeError
                    if artifact.type == generation.ARTIFACT_IMAGE:
                        img = Image.open(io.BytesIO(artifact.binary))
                        yield img

        # leverage stability API internal parallelism for batch variation requests
        # TODO: make sure this behaves appropriately for regeneration on NSFW trigger. only regen as needed, not whole batch
        def get_variations_w_init(prompt, init_image, n_variations=2, image_consistency=.7, **kargs):
             return list(
                 get_image_for_prompt(
                     prompt=prompt, 
                     init_image=init_image, 
                     start_schedule=(1-image_consistency), 
                     #num_samples=n_variations,
                     samples=n_variations,
                     **kargs,
                 )
             )
                        
    else:

        if download:
            img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
                model_id,
                revision="fp16", 
                torch_dtype=torch.float16,
                use_auth_token=True
            )
            img2img = img2img.to(device)
            img2img.save_pretrained(model_path)
        else:
            img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
                model_path,
                local_files_only=True
            ).to(device)

        text2img = StableDiffusionPipeline(
            vae=img2img.vae,
            text_encoder=img2img.text_encoder,
            tokenizer=img2img.tokenizer,
            unet=img2img.unet,
            feature_extractor=img2img.feature_extractor,
            scheduler=img2img.scheduler,
            safety_checker=img2img.safety_checker,
        )
        text2img.enable_attention_slicing()
        img2img.enable_attention_slicing()


        def get_image_for_prompt_hf(
            prompt,
            **kwargs
        ):
            f = text2img if kwargs.get('image') is None else img2img
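            n_retries = 5
            with autocast(device):
                while n_retries > 0:
                    n_retries-=1
                    result = f(prompt, **kwargs)
                    if not any(result.nsfw_content_detected):
                        return result.images
                    else:
                        print(f"nsfw content detected. retries remaining: {n_retries}")
            # every attempt was flagged; fail loudly instead of silently returning None
            raise RuntimeError(f"nsfw content detected on every attempt for prompt: {prompt!r}")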

        def get_image_for_prompt(*args, **kargs):
            if 'init_image' in kargs:
                kargs['image'] = kargs.pop('init_image')
            if 'start_schedule' in kargs:
                kargs['strength'] = kargs.pop('start_schedule')
            return get_image_for_prompt_hf(*args, **kargs)

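        # mirrors the signature of the stability API variant above so callers can
        # pass `n_variations` and `image_consistency` to either backend
        def get_variations_w_init(prompt, init_image, n_variations=2, image_consistency=.7, **kargs):
            images = []
            # one request per variation; TODO: modify for updated batch request
            for _ in range(n_variations):
                images.extend(
                    get_image_for_prompt(
                        prompt=prompt,
                        init_image=init_image,
                        start_schedule=(1-image_consistency),
                        **kargs,
                    )
                )
            return images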

        
###############################



# def get_close_variations_from_prompt(prompt, n_variations=2, image_consistency=.7):
#     """
#     prompt: a text prompt
#     n_variations: total number of images to return
#     image_consistency: float in [0,1], controls similarity between images generated by the prompt.
#                         you can think of this as controlling how much "visual vibration" there will be.
#                         - 0=ignore the init image, 1=as close as possible to the init image
#     """
#     images = list(get_image_for_prompt(prompt))
#     for _ in range(n_variations - 1):
#         img = get_variations_w_init(prompt, images[0], start_schedule=(1-image_consistency))[0]
#         images.append(img)
#     return images


##################
##  PARAMETERS  ##
##################

d_ = dict(
    _=''
    #, theme_prompt = " by peter gabriel | art by radiohead" # @param {type:'string'}
    , height = 512 # @param {type:'integer'}
    , width = 512 # @param {type:'integer'}
    , display_frames_as_we_get_them = True # @param {type:'boolean'}
)
d_.pop('_')

regenerate_all_init_images = True # @param {type:'boolean'}

# TODO: make this an integer
prompt_lag = True # @param {type:'boolean'}

# @markdown `theme_prompt` - Text that will be appended to the end of each lyric, useful for e.g. applying a consistent aesthetic style

# @markdown `display_frames_as_we_get_them` - Displaying frames will make the notebook slightly slower

# regenerate all images if the theme prompt has changed or user specifies

# @markdown `prompt_lag` - Extend prompt with lyrics from previous frame. Can improve temporal consistency of narrative. 
# @markdown  Especially useful for lyrics segmented into short prompts.

#d_['theme_prompt'] = [prompt.strip() for prompt in d_['theme_prompt'].split('|')]

#TODO: check from storyboard
#if d_['theme_prompt'] != storyboard.params.get('theme_prompt'):
#    regenerate_all_init_images = True

display_frames_as_we_get_them = d_.pop('display_frames_as_we_get_them')
storyboard.params.update(d_)

if regenerate_all_init_images:
    for i, rec in enumerate(prompt_starts):
        rec['frame0_fpath'] = None
        archive_images(i, root=root)
    print("archival process complete")

# anchor images will be regenerated if there's no associated frame0_fpath
# regenerate specific images if
# * manually tagged by user in df_regen
# * associated fpath doesn't exist (i.e. deleted)
if 'df_regen' in locals():
    for i, _ in df_regen.iterrows():
        rec = prompt_starts[i]
        regen = not _['keep']
        if rec.get('frame0_fpath') is None:
            regen = True
        elif not Path(rec['frame0_fpath']).exists():
            regen=True
        if regen:
            rec['frame0_fpath'] = None
            #rec['prompt'] = df_regen.loc[i, 'Lyric']
            #rec['override_prompt'] = df_regen.loc[i, 'override_prompt']
            print(rec)
            archive_images(i, root=root)
    print("archival process complete")


theme_prompts = storyboard.params.theme_prompt
#display_frames_as_we_get_them = storyboard.params.display_frames_as_we_get_them
height = storyboard.params.height
width = storyboard.params.width

proj_name = workspace.active_project

## Poking around? This is probably the code you're looking for. ##

print("Ensuring each prompt has an associated image")
for idx, rec in enumerate(prompt_starts):
    print(idx, rec)
    #theme = rec.get('theme_prompt')
    theme = rec.get('_theme')
#     if not theme and theme_prompts:
#         theme = theme_prompts[idx % len(theme_prompts)]
#         #print(
#             #f"len(theme_prompts): {len(theme_prompts)}\n"
#             #f"idx: {idx}\n"
#             #f"theme index: {len(theme_prompts) % (idx+1) -1}\n"
#             #f"old theme index: {len(theme_prompts) % (idx+1) -1}\n"
#             #f"new theme index: {idx % len(theme_prompts)}\n"
#             # )
#     #else:
#     #    print(f"theme prompt from storyboard: {theme}")
        
    prompt = rec.get('prompt')
    if not prompt:
        prompt = f"{rec['text']}, {theme}"
    
    #override = rec.get('override_prompt','').strip()
    #if override:
    #    print('override prompt detected')
    #    prompt = override

        if prompt_lag and (idx > 0):
            rec_prev = prompt_starts[idx -1]
            prev_text = rec_prev.get('text')
            if not prev_text:
                prev_text = rec_prev.get('prompt').split(',')[0]
            this_text = rec.get('text')
            if this_text:
                prompt = f"{prev_text} {this_text}, {theme}"
            else:
                prompt = rec_prev['_prompt']
    rec['_prompt'] = prompt
    
    print(
        f"scene: {idx}\t time: {rec['start']}\n"
        f"spoken text: {rec.get('text')}\n"
        f"image prompt: {rec['_prompt']}\n"
    )
    if rec.get('frame0_fpath') is None:
        init_image = list(get_image_for_prompt(
              rec['_prompt'],
              height=height,
              width=width,
              )
          )[0]
    # this shouldn't be necessary, but is a consequence of
    # the globbing thing we're doing atm
    if 'anchor' not in str(rec.get('frame0_fpath')):
        # to do: save_frame doesn't need to be a function.
        rec['frame0_fpath'] = save_frame(
            init_image,
            idx,
            root_path = root / 'frames',
            name='anchor',
            )

        if display_frames_as_we_get_them:
            #print(lyric)
            print(rec.get('text'))
            display(init_image)


##############
# checkpoint #
##############

prompt_starts_copy = copy.deepcopy(prompt_starts)
storyboard.prompt_starts = prompt_starts_copy

# OmegaConf.save accepts a path directly; no need to open the file first
OmegaConf.save(config=storyboard, f=storyboard_fname)


###############
# flag regens #
###############

# @markdown ---
# @markdown NB: When this cell finishes running, a table will appear at the bottom of the output window. This table is editable and can be used to correct errors in the transcription (see above).

# @markdown Additionally, this table can be used to trigger regeneration of
# @markdown images you don't want to keep. On the far left of the table,
# @markdown you should see a `keep` column that defaults to "true". Double 
# @markdown clicking this value should flip it to "false". Rerunning this cell
# @markdown will regenerate the `init_image` for all scenes where `keep=false`.
# @markdown Images that are flagged for regeneration will be moved to the
# @markdown project's `archive` folder.

# @markdown Image regeneration can also be triggered by deleting the image from 
# @markdown the `frames` folder.


df_regen = pd.DataFrame(prompt_starts)
#if 'override_prompt' not in df_regen:
#    df_regen['override_prompt'] = ''

#df_regen = df_regen[['ts','prompt','override_prompt']].rename(
#     columns={
#         'ts':'Timestamp (sec)',
#         'prompt':'Lyric',
#     }
# )


# TODO: clean up logging


df_regen['keep'] = True

# move the "keep" column to the front
#df_regen= df_regen[['keep', 'Timestamp (sec)', 'Lyric', 'override_prompt']]
df_regen= df_regen.rename(columns={'_prompt':'prompt'})[['keep', 'start', 'text', 'prompt']]

# formatters/editors are keyed by column name, so target the `keep` column
pn.widgets.Tabulator(
    df_regen,
    formatters={'keep': BooleanFormatter()},
    editors={'keep': CheckboxEditor()}
    )
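
The prompt construction in the cell above can be summarized as: an explicit `prompt` on the record wins; otherwise the lyric text is combined with the scene's theme, and with `prompt_lag` enabled the previous scene's text is prepended. A minimal sketch of that logic as a standalone function (`build_prompt` is a hypothetical name; the record keys mirror those used in the storyboard):

```python
def build_prompt(rec, prev_rec=None, theme="", prompt_lag=True):
    """Assemble the image prompt for one scene record."""
    # a user-supplied prompt overrides everything else
    if rec.get("prompt"):
        return rec["prompt"]
    text = rec.get("text")
    if prompt_lag and prev_rec is not None:
        # prefer the previous lyric; otherwise use the head of its prompt
        prev_text = prev_rec.get("text") or (prev_rec.get("prompt") or "").split(",")[0]
        if text:
            return f"{prev_text} {text}, {theme}"
        # no lyric for this scene: reuse the previous scene's final prompt
        return prev_rec["_prompt"]
    return f"{text}, {theme}"

print(build_prompt({"text": "world"}, {"text": "hello"}, theme="oil painting"))
```

Carrying the previous lyric forward is what gives short, fragmentary lyric segments enough context to produce related images.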


In [None]:
# For outputting a storyboard for ChatGPT prompting

# import rich
# import json
# #rich.print(dict(storyboard))

# outter_cols = ['audio_fpath','video_duration']

# cols = ['start','text', 'structural_segmentation_label'] #,'no_speech_prob','avg_logprob','compression_ratio']
# recs = [{k:rec[k] for k in cols} for rec in storyboard['prompt_starts'] if rec.get('inferred_subscene') is None]
# recs = [{k:storyboard.params[k]} for k in outter_cols] + recs
# rich.print(recs)

In [None]:

# @title ## 🚀 Generate animation frames

###################
# improved resume #
###################

import copy
import datetime as dt
from itertools import cycle
from pathlib import Path

from omegaconf import OmegaConf
from PIL import Image


workspace = OmegaConf.load('config.yaml')
root = Path(workspace.project_root)

storyboard_fname = root / 'storyboard.yaml'
storyboard = OmegaConf.load(storyboard_fname)

## development
#storyboard = OmegaConf.load('/home/dmarx/projects/video-killed-the-radio-star/nirvana-sold-structure/storyboard.yaml')


if 'prompt_starts' not in locals():
    prompt_starts = OmegaConf.to_container(storyboard.prompt_starts)
else:
    ##########################
    # checkpoint any changes #
    ##########################
    prompt_starts_copy = copy.deepcopy(prompt_starts)

    storyboard.prompt_starts = prompt_starts_copy

    # OmegaConf.save accepts a path directly; no need to open the file first
    OmegaConf.save(config=storyboard, f=storyboard_fname)


#################################################
# Math                                          #
#                                               #
#    This block computes how many frames are    #
#    needed for each segment based on the start #
#    times for each prompt                      #
#################################################

# TODO: leverage previous beat detection, onsets, etc. for frame timings

# TODO: experiment with tying instantaneous framerate to a music attribute (i.e. so motion can change speed mid animation)

fps = 16 # @param {type:'integer'}
storyboard.params.fps = fps

ifps = 1/fps

# estimate video end
video_duration = storyboard.params['video_duration']

# dummy prompt for last scene duration
prompt_starts = OmegaConf.to_container(storyboard.prompt_starts)
prompt_starts.append({'start':video_duration})

# make sure we respect the duration of the previous phrase
frame_start=0
prompt_starts[0]['anim_start']=frame_start
for i, rec in enumerate(prompt_starts[1:], start=1):
    rec_prev = prompt_starts[i-1]
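    k=0
    while (rec_prev['anim_start'] + k*ifps) < rec['start']:
        k+=1
    # step back one frame so this scene ends before the next one starts;
    # clamp at 0 so overlapping start times can't yield a negative frame count
    k = max(k-1, 0)
    rec_prev['frames'] = k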
    rec_prev['anim_duration'] = k*ifps
    frame_start+=k*ifps
    rec['anim_start']=frame_start

# drop the dummy frame
prompt_starts = prompt_starts[:-1]

# to do: given a 0 duration prompt, assume its duration is captured in the next prompt 
#        and guesstimate a corrected prompt start time and duration 


##############
# checkpoint #
##############

prompt_starts_copy = copy.deepcopy(prompt_starts)

storyboard.prompt_starts = prompt_starts_copy

# OmegaConf.save accepts a path directly; no need to open the file first
OmegaConf.save(config=storyboard, f=storyboard_fname)


##################################
# Generate animation frames #
##################################

d_ = dict(
    _=''
    , n_variations=8 # @param {type:'integer'}
    , image_consistency=0.72 # @param {type:"slider", min:0, max:1, step:0.01}  
    , max_video_duration_in_seconds = 300 # @param {type:'integer'}
)
d_.pop('_')


# @markdown `fps` - Frames-per-second of generated animations

# @markdown `n_variations` - How many unique variations to generate for a given text prompt. This determines the frequency of the visual "pulsing" effect

# @markdown `image_consistency` - controls similarity between images generated by the prompt.
# @markdown - 0: ignore the init image
# @markdown - 1: as true as possible to the init image

# @markdown `max_video_duration_in_seconds` - Early stopping if you don't want to generate a video the full duration of the provided audio. Default = 5min.


storyboard.params.update(d_)
storyboard.params.max_frames = storyboard.params.fps * storyboard.params.max_video_duration_in_seconds

display_frames_as_we_get_them=True
#display_frames_as_we_get_them = storyboard.params.display_frames_as_we_get_them

image_consistency = storyboard.params.image_consistency
max_frames = storyboard.params.max_frames

n_variations = storyboard.params.n_variations
#theme_prompt = storyboard.params.get('theme_prompt')


# load init_images and generate variations as needed
# to do: request multiple images in single request
print("Fetching variations")
for idx, rec in enumerate(prompt_starts):
    new_images = []
    images_fpaths = get_image_sequence(idx, root=root)
    curr_variation_count = len(images_fpaths)
    print(f"curr_variation_count:{curr_variation_count}")
    if curr_variation_count < n_variations:
        prompt = rec['_prompt']

        init_image = Image.open(rec['frame0_fpath'])
        # TODO: user should be able to specify basically anything per-entry
        # next line is here to permit user to specify more variations for a specific entry
        tot_variations = rec.get('n_variations', n_variations)
        tot_variations = min(tot_variations, rec['frames']) # don't generate variations we won't use
        print(f"tot_variations:{tot_variations}")
        tot_variations -= curr_variation_count  # only generate variations we still need
        print(f"tot_variations to request:{tot_variations}")
        
        # why do we have a scene with 0 frames? something funny with start times overlapping i think.
        # seems like the subsegmentation thing didn't fully update this part properly?
        if tot_variations < 1:
            continue
            
        #for _ in range(tot_variations):
        #    img = get_variations_w_init(prompt, init_image, start_schedule=(1-image_consistency))[0]
        image_variations = get_variations_w_init(
            prompt=prompt, 
            init_image=init_image, 
            image_consistency=image_consistency,
            n_variations=tot_variations,
        )
        for img in image_variations:
            save_frame(
                img,
                idx,
                root_path= root / 'frames',
            )
            if display_frames_as_we_get_them:
                display(img)


# TODO: last frame not generating variations for some reason
                
##############
# checkpoint #
##############

prompt_starts_copy = copy.deepcopy(prompt_starts)

storyboard.prompt_starts = prompt_starts_copy

# OmegaConf.save accepts a path directly, so there's no need to open the file first
OmegaConf.save(config=storyboard, f=storyboard_fname)

# @markdown ---
# @markdown Running this cell will generate as many variation frames as required
# @markdown per `n_variations`. To trigger regeneration of images that didn't
# @markdown generate correctly (e.g. because an NSFW classifier was triggered),
# @markdown just delete those images and run this cell again.
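
The bookkeeping in the cell above (how many variations to request for each scene) can be sketched on its own. This is a minimal illustration, not the notebook's actual code path: `variations_to_request` is a hypothetical helper name, and it assumes the per-scene target is capped by the scene's frame count, as in the loop above.

```python
def variations_to_request(existing, n_variations, frames):
    """How many new variation images a scene still needs.

    existing:     images already on disk for this scene
    n_variations: requested variations (global default or per-scene override)
    frames:       number of video frames the scene occupies
    """
    # don't generate variations we won't use
    target = min(n_variations, frames)
    # only generate what's missing; never return a negative count
    return max(target - existing, 0)


print(variations_to_request(existing=0, n_variations=5, frames=24))
print(variations_to_request(existing=3, n_variations=5, frames=24))
print(variations_to_request(existing=0, n_variations=5, frames=2))
```

Deleting a bad image lowers `existing` by one, which is why rerunning the cell regenerates exactly the images you removed.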

In [None]:
# @title ## 📺 Compile your video and enjoy your animation!

import shutil

# to do: skip tsp if n_variations ==1

from pathlib import Path
from PIL import Image
from itertools import cycle

from omegaconf import OmegaConf
from tqdm.autonotebook import tqdm

try:
    import google.colab
    local = False
except ImportError:
    local = True

try:
    from prefetch_generator import BackgroundGenerator
except ImportError:
    !pip install prefetch_generator
    from prefetch_generator import BackgroundGenerator

# reload config
workspace = OmegaConf.load('config.yaml')
root = Path(workspace.project_root)

# storyboard_fname = root / 'storyboard.yaml'
# storyboard = OmegaConf.load(storyboard_fname)

########################
# rendering parameters #
########################

output_filename = 'output.mp4' # @param {type:'string'}
#add_caption = False # @param {type:'boolean'}
add_caption = True
## TODO: DEBUGGING
optimal_ordering = True # @param {type:'boolean'} 
#optimal_ordering = False # @param {type:'boolean'}
#upscale = True # @param {type:'boolean'}
upscale = False

download_video = True # @param {type:'boolean'}

# TODO: add inference of `local` again

# @markdown NB: Your video will probably download way faster from https://drive.google.com


# @markdown `add_caption` - Whether or not to overlay the prompt text on the image

# @markdown `optimal_ordering` - Intelligently permutes animation frames to provide a smoother animation.

# @markdown `upscale` - Naively upscale the video 2x (lanczos interpolation). This can be a way to force
# @markdown services like youtube to deliver your video without mangling it with compression
# @markdown artifacts. Thanks [@gandamu_ml](https://twitter.com/gandamu_ml) for this trick!

# TODO: make sure image is being written to correct location?
# I think it might be more efficient to write the video to the local disk first, then move it
# afterwards, rather than writing into google drive
final_output_filename = str( root / output_filename )
storyboard.params.output_filename = final_output_filename

# to do: move/duplicate fps computations here (?)
fps = storyboard.params.fps


#####################################

import time

import numpy as np
from scipy.spatial.distance import pdist, squareform
from itertools import cycle
from python_tsp.exact import solve_tsp_dynamic_programming

import textwrap
from PIL import Image, ImageDraw, ImageFont

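# TODO: GPU acceleration
def tsp_sort(frames):
    """Order frames so that neighboring frames are as similar as possible."""
    # flatten each image into a single pixel vector (one row per frame)
    frames_m = np.array([np.array(f).ravel() for f in frames])
    # pairwise cosine distances, expanded into a square matrix for the solver
    dmat = squareform(pdist(frames_m, metric='cosine'))
    # exact solver: runtime grows very quickly with the number of frames
    permutation, _ = solve_tsp_dynamic_programming(dmat)
    return permutation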

def add_caption2image(
      image, 
      caption, 
      text_font='LiberationSans-Regular.ttf', 
      font_size=20,
      fill_color=(255, 255, 255),
      stroke_color=(0, 0, 0), #stroke_fill
      stroke_width=2,
      align='center',
      ):
    # via https://stackoverflow.com/a/59104505/819544
    wrapper = textwrap.TextWrapper(width=50)
    # join the wrapped lines; this also avoids an IndexError on an empty caption
    caption_new = '\n'.join(wrapper.wrap(text=caption))

    draw = ImageDraw.Draw(image)

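    # `text_font` must resolve to a font file that is available on this system
    font = ImageFont.truetype(text_font, size=font_size)
    # `textsize` was removed in Pillow 10; measure the bounding box instead
    # (`multiline_textbbox` requires Pillow >= 8.0)
    left, top, right, bottom = draw.multiline_textbbox(
        (0, 0), caption_new, font=font, stroke_width=stroke_width
    )
    w, h = right - left, bottom - top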
    W,H = image.size
    x,y = 0.5*(W-w),0.90*H-h
    draw.text(
        (x,y), 
        caption_new,
        font=font,
        fill=fill_color, 
        stroke_fill=stroke_color,
        stroke_width=stroke_width,
        align=align,
    )

    return image


# prep everything...
ffmpeg_cmd_script = ""
for idx, rec in enumerate(storyboard.prompt_starts):
    if 'frame_order' not in rec:
        im_paths = get_image_sequence(idx, root)
        #if not im_paths:
        #    print(idx)
        #    print(rec)
        #    raise
        # to do: persist the ordering in the storyboard
        if optimal_ordering:
            print(f"computing frame order for scene {idx}")
            images = [Image.open(fp) for fp in im_paths]
            try:
                frame_order = tsp_sort(images)
                im_paths = [im_paths[j] for j in frame_order]
                images = [images[j] for j in frame_order]
            except ValueError:
                # the solver couldn't order these frames; keep the original sequence
                pass
        # TODO: actually persist frame order to storyboard...
        rec['frame_order'] = im_paths
    else:
        im_paths = rec['frame_order']

    images = [Image.open(fp) for fp in im_paths]

    if add_caption:
        new_paths = []
        #images_captioned = [add_caption2image(im, rec['prompt']) for im in images]
        #images_captioned = [add_caption2image(im, rec['text']) for im in images]
        #for fp, im in zip(im_paths, images_captioned):
        for fp, im in zip(im_paths, images):
            fp = Path(fp)
            #fp = fp.with_stem(fp.stem + '-captioned')
            fp = fp.parent / 'captioned' / fp.name
            fp.parent.mkdir(exist_ok=True, parents=True)
            if not rec.get('inferred_subscene', False):
                im = add_caption2image(im, rec['text'])
            im.save(fp)
            new_paths.append(fp)
        im_paths = new_paths
    
    frame_picker = cycle(im_paths)
    for _ in range(rec.frames):
        fpath = Path(next(frame_picker))
        ffmpeg_cmd_script += f"file '{fpath.absolute()}'\nduration {1/fps}\n"
    
# write the concat script once, after all scenes have been processed
with open(root/'scenes.txt', 'w') as f:
    f.write(ffmpeg_cmd_script)


# paths are quoted so that project folders containing spaces don't break the command
if upscale:
    height = storyboard.params.height
    width = storyboard.params.width
    !ffmpeg -y -f concat -safe 0 -i "{root/'scenes.txt'}" -i "{storyboard.params.audio_fpath}" -r {storyboard.params.fps} -pix_fmt yuv420p -crf 25 -preset veryslow -vf scale={2*width}x{2*height}:flags=lanczos -shortest "{storyboard.params.output_filename}"
else:
    !ffmpeg -y -f concat -safe 0 -i "{root/'scenes.txt'}" -i "{storyboard.params.audio_fpath}" -r {storyboard.params.fps} -pix_fmt yuv420p -crf 25 -preset veryfast -shortest "{storyboard.params.output_filename}"


# EASTER EGG FEATURE
#  NB: only embed short videos
embed_video_in_notebook = False

output_filename = storyboard.params.output_filename

if download_video and not local:
    from google.colab import files
    files.download(output_filename)

if embed_video_in_notebook:
    from IPython.display import display, Video
    display(Video(output_filename, embed=True))
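
To get a feel for what `optimal_ordering` does, here is a toy version of the frame reordering. The cell above uses scipy's `pdist` and the dynamic programming solver from `python_tsp`; the sketch below swaps in a brute-force search over permutations using only the standard library, so it is only practical for a handful of frames. `cosine_distance`, `brute_force_order`, and `toy_frames` are hypothetical names made up for this illustration.

```python
from itertools import permutations
from math import sqrt


def cosine_distance(u, v):
    """1 - cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = sqrt(sum(a * a for a in u))
    norm_v = sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def brute_force_order(frames):
    """Return the frame order with the shortest closed tour, starting at frame 0."""
    n = len(frames)
    if n < 3:
        return list(range(n))  # nothing to reorder
    dist = [[cosine_distance(a, b) for b in frames] for a in frames]
    best_order, best_cost = None, float("inf")
    # fix frame 0 as the start and try every ordering of the remaining frames
    for rest in permutations(range(1, n)):
        order = (0,) + rest
        # closed tour: the last frame loops back to the first
        cost = sum(dist[order[i]][order[(i + 1) % n]] for i in range(n))
        if cost < best_cost:
            best_order, best_cost = list(order), cost
    return best_order


# four tiny "frames": 0 and 2 look alike, 1 and 3 look alike
toy_frames = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.2],
    [0.2, 1.0],
]
order = brute_force_order(toy_frames)
print(order)
```

The tour is closed because each scene's frames are looped to fill its duration, so the jump from the last frame back to the first should be small too.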

In [None]:
!ffmpeg -y -f concat -safe 0 -i {root/'scenes.txt'} -vn -i "{video_assets_meta_record['video_fpath']}" -r {storyboard.params.fps} -pix_fmt yuv420p -crf 25 -preset veryfast -shortest test_alt_audio.mp4
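
Both ffmpeg commands above read `scenes.txt`, which lists one `file` entry and one `duration` entry per video frame. Here's a rough sketch of how that script gets built for a single scene, assuming the same `cycle`-based padding as the rendering cell. `build_concat_script` is a hypothetical helper and the file names are placeholders.

```python
from itertools import cycle


def build_concat_script(frame_paths, n_frames, fps):
    """Build ffmpeg concat-demuxer text that loops frame_paths for n_frames frames."""
    if not frame_paths:
        return ""  # a scene with no images contributes nothing
    picker = cycle(frame_paths)  # repeat the image sequence to pad out the scene
    lines = []
    for _ in range(n_frames):
        lines.append(f"file '{next(picker)}'")
        lines.append(f"duration {1 / fps}")  # each frame is shown for 1/fps seconds
    return "\n".join(lines) + "\n"


script = build_concat_script(["a.png", "b.png", "c.png"], n_frames=5, fps=10)
print(script)
```

With three images and five frames, the sequence wraps around: a, b, c, a, b.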

# ⚖️ I put on my robe and lawyer hat

### Notebook license

This notebook and the accompanying [git repository](https://github.com/dmarx/video-killed-the-radio-star/) and its contents are shared under the MIT license.

<!-- Note to self: lawyers should really be forced to use some sort of markup or pseudocode to eliminate ambiguity 

...oh shit, if laws were actually described in code, we could just run queries against it
-->

```
MIT License

Copyright (c) 2022 David Marx

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

### DreamStudio API TOS

The default behavior of this notebook uses the [DreamStudio](https://beta.dreamstudio.ai/) API to generate images. Users of the DreamStudio API are subject to the DreamStudio usage terms: https://beta.dreamstudio.ai/terms-of-service

### Stable Diffusion

As of this writing (2022-09-29), all publicly available Stable Diffusion model checkpoints are subject to the restrictions of the Open RAIL license: https://huggingface.co/spaces/CompVis/stable-diffusion-license.

