Generate a music video as a latent-space interpolation of StyleGAN3: you define a story as a list of prompts, we generate an image for each prompt, and the video then interpolates between these images in the rhythm of the music.

By [nielsrolf](https://github.com/nielsrolf), inspired by [mikaelalafriz](https://github.com/mikaelalafriz)'s [lucid-sonic-dreams](https://github.com/mikaelalafriz/lucid-sonic-dreams), using [StyleGAN-3](https://github.com/NVlabs/stylegan3) by NVIDIA and OpenAI's [CLIP](https://github.com/openai/CLIP), with the code for finding the latent points that tell the story taken from nshepperd and Katherine Crowson.

In [None]:
# The story the video tells: one prompt per chapter, separated by "->".
# Each prompt becomes one key latent; the video interpolates from one to the next.
story = "Night sky->Sunrise in Athens->An antique Greek statue->A beauitful woman->Love->A broken heart->A man crying->A creepy psychedelic experience->Death and destroyment->Darkness" #@param {type: "string"}

# Prompts whose latents the frames are pulled towards when the matching frequency band is loud.
# NOTE(review): text_prompt_mids is not referenced in the visible cells (the mids slerp is commented out).
text_prompt_bass = 'Mysterious and deep, violet' #@param {type: "string"}
text_prompt_treble = "Dreamy and warm, full of love - trending on ArtStation" #@param {type: "string"}
text_prompt_mids = "" #@param {type: "string"}

# Appended as " - <suffix>" to every non-empty story prompt.
style_suffix = "trending on ArtStation" #@param {type: "string"}


# How much of a pulsating effect the bass creates. A bass sound moves towards text_prompt_bass and moves back once it is released
bass_pulse_impact = 0.06 #@param {type: "number"}
# How much of a pulsating effect the mids create. A mid sound moves towards text_prompt_mids and moves back once it is released
mids_pulse_impact = 0 #@param {type: "number"}
# How much of a pulsating effect a high-pitched sound creates. A treble sound moves towards text_prompt_treble and moves back once it is released
trebles_pulse_impact = 0.1 #@param {type: "number"}

# The three story speeds are relative weights: they are normalized to sum to 1 further down.
# How much the bass pushes the story forward
bass_story_speed = 1 #@param {type: "number"}
# How much the mids push the story forward
mids_story_speed = 1 #@param {type: "number"}
# How much the trebles push the story forward
trebles_story_speed = 0 #@param {type: "number"}


# It can take quite long to generate a 6 min video; use these inputs to crop the audio (in seconds) and make the video shorter
start_second =  0#@param {type: "number"}
end_second = 360 #@param {type: "number"}

# Music upload: path to a local audio file
audio_file = "" #@param {type: "string"}

# Alternative: link to YouTube or SoundCloud. If it starts with "http" it is downloaded and replaces audio_file
youtube_dl_link = "https://soundcloud.com/lautundluise/christopher-schwarzwalder-das" #@param {type: "string"}

# Speed at which to try approximating the text. Too fast seems to give strange results. Maximum is 100.
# Only feeds `smoothing`, computed as (100 - speed) / 100 further down.
speed = 20  #@param {type: "number"}

# Change the seed to generate variations of the same prompt
seed = 244 #@param {type: "number"}

# Model type to use; must be one of the keys of model_map.
model = 'Landscapes'  #@param ['Painted Faces', 'Animal Faces', 'Flickr Faces', 'Wiki Art', 'Landscapes']


# Maps every selectable model name to the URL of its pretrained StyleGAN3 pickle.
# Every value must be a full URL: the download helper only has the basename
# locally after it has fetched the file once.
model_map = {
    'Painted Faces': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-metfacesu-1024x1024.pkl',
    'Animal Faces': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl',
    'Flickr Faces': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl',
    'Wiki Art': 'https://pollinations.ai/ipfs/QmdxNyN5pDbaesaXbaBtEGGitks4nyaXCDH8cnwWjyivEx/wikiart-1024-stylegan3-t-17.2Mimg.pkl',
    'Landscapes': 'https://ipfs.io/ipfs/QmZkrYwEUnykVQJfJw3opTj1HfdNUCm87amsR3LHp1QnuV/lhq-256-stylegan3-t-25Mimg.pkl',
}


# Directory the final video is copied to at the end of the notebook.
output_path = "/content/output"

# Number of optimizer steps used to fit each prompt's latent.
steps = 150

#@markdown ---




In [None]:
# One prompt per chapter; every non-empty prompt gets the style suffix appended,
# empty prompts are passed through unchanged.
prompts = [
    "" if chapter == "" else f"{chapter} - {style_suffix}"
    for chapter in story.split("->")
]
prompts

In [None]:
# URL of the pickle for the selected model.
model_url = model_map[model]
# Maps speed 0..100 to 1.0..0.0 (only referenced by commented-out EMA code in the visible cells).
smoothing = (100.0-speed)/100.0

# Turn the three story speeds into relative weights that sum to 1.
summed_speed = bass_story_speed + mids_story_speed + trebles_story_speed
if summed_speed == 0:
  # Without this check the normalization fails with a bare ZeroDivisionError.
  raise ValueError(
      "bass_story_speed, mids_story_speed and trebles_story_speed must not sum to 0: "
      "at least one frequency band has to push the story forward"
  )
bass_story_speed /= summed_speed
mids_story_speed /= summed_speed
trebles_story_speed /= summed_speed

In [None]:
# If a link is given, download its audio track as WAV and use it instead of
# any uploaded audio_file.
if youtube_dl_link.startswith("http"):
  print(f"Downloading from {youtube_dl_link}...")
  !pip install -q youtube-dl
  # Clear youtube-dl's cache directory before downloading.
  !youtube-dl --rm-cache-dir
  !youtube-dl --extract-audio --audio-format wav {youtube_dl_link} --output /tmp/audio_file.wav
  audio_file = "/tmp/audio_file.wav"
  from glob import glob
  # List the WAV files in /tmp so a failed download is easy to spot.
  print(glob("/tmp/*.wav"))


import librosa
import librosa.display
from matplotlib import pyplot as plt
import numpy as np

def specshow(spec):
  """Display `spec` as a mel-scaled spectrogram over time.

  Reads the module-level sample rate `sr`; the frequency axis is capped at 8000.
  """
  _, axes = plt.subplots()
  librosa.display.specshow(spec, sr=sr, fmax=8000,
                           x_axis='time', y_axis='mel', ax=axes)
  plt.show()


audio, sr = librosa.load(audio_file)
# The crop bounds come from "number" form fields and may be floats, but slice
# indices must be integers.
audio = audio[int(start_second * sr):int(end_second * sr)]
if len(audio) == 0:
  # Otherwise the frame-rate computation below divides by zero.
  raise ValueError(
      f"No audio left between start_second={start_second} and end_second={end_second}"
  )
# Mel spectrogram, keeping every second time column.
spec = librosa.feature.melspectrogram(y=audio, sr=sr)[:,::2]
specshow(spec)
# spec_s = np.log(1 + spec*10)
# NOTE(review): melspectrogram returns a power spectrogram by default, so
# power_to_db may be the intended conversion - confirm before changing.
spec_s = librosa.amplitude_to_db(spec)
# Shift so the quietest bin is 0.
spec_s = spec_s - spec_s.min()
specshow(spec_s)

# Spectrogram columns per second of audio; later used as the video frame rate.
seconds = len(audio) / sr
frame_rate = spec_s.shape[1] / seconds

In [None]:
def specshow(spec):
  """Display `spec` as a mel spectrogram, then plot its per-column sum.

  Reads the module-level sample rate `sr`. The second figure shows
  `spec.sum(0)`, i.e. one value per time column.
  """
  _, axes = plt.subplots()
  librosa.display.specshow(spec, sr=sr, fmax=8000,
                           x_axis='time', y_axis='mel', ax=axes)
  plt.show()

  plt.figure(figsize=(15, 7))
  plt.plot(spec.sum(0))
  plt.show()

def _keep_above(band, threshold):
  """Return `band - threshold` with every negative entry set to 0."""
  shifted = band - threshold
  shifted[shifted < 0] = 0
  return shifted


# Lowest 12 mel bins: only the part above the band's mean counts.
bass = _keep_above(spec_s[:12], spec_s[:12].mean())

# Everything between bass and treble, used as-is.
mids = spec_s[12:-35]

# Highest 35 mel bins: only the part above half the band's mean counts.
treble = _keep_above(spec_s[-35:], spec_s[-35:].mean() / 2)

specshow(bass)
specshow(mids)
specshow(treble)


In [None]:
def get_spec_slice(spec, i, chapters=None):
  """Return the time columns of `spec` that belong to chapter `i`.

  The time axis (axis 1) is split into `chapters` consecutive pieces whose
  lengths differ by at most one column and which together cover every column
  exactly once.

  Args:
    spec: 2-D array indexed as [frequency, time].
    i: zero-based chapter index.
    chapters: number of chapters; defaults to `len(prompts) - 1`, i.e. one
      chapter per transition between consecutive story prompts.

  Raises:
    ValueError: if there is not at least one chapter.
  """
  if chapters is None:
    chapters = len(prompts) - 1
  if chapters < 1:
    raise ValueError("need at least two prompts (one chapter) to slice the spectrogram")
  tsteps = spec.shape[1]
  # Integer arithmetic keeps the boundaries exact; float boundaries can round
  # down below the true end and silently drop the last column.
  start = i * tsteps // chapters
  end = (i + 1) * tsteps // chapters
  return spec[:, start:end]


In [None]:
#@title Licensed under the MIT License { display-mode: "form" }

# Copyright (c) 2021 nshepperd; Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
# Check GPU and CUDA: print the visible GPUs and the installed CUDA compiler version
!nvidia-smi
!nvcc --version

In [None]:
# Length of the cropped audio in seconds (displayed as the cell output).
len(audio)/sr

In [None]:
# Pin torch/torchvision to CUDA 11.1 builds.
!pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
#!pip install --upgrade https://download.pytorch.org/whl/nightly/cu111/torch-1.11.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu111/torchvision-0.12.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl
# Fetch the StyleGAN3 and CLIP sources; both directories are added to sys.path in the next cell.
!git clone https://github.com/NVlabs/stylegan3
!git clone https://github.com/openai/CLIP
!pip install -e ./CLIP
!pip install einops ninja

In [None]:
import sys
sys.path.append('./CLIP')
sys.path.append('./stylegan3')

import io
import os, time
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from IPython.display import display
from einops import rearrange
from google.colab import files

In [None]:
# Everything runs on the first CUDA device.
device = torch.device('cuda:0')
print('Using device:', device, file=sys.stderr)
# Seed torch's global RNG with the user-chosen seed.
torch.manual_seed(seed)

In [None]:
# Define necessary functions

def fetch(url_or_path):
    """Open a remote or local resource as a binary file-like object.

    An http:// or https:// URL is downloaded completely and returned as an
    in-memory buffer positioned at the start; a failed request raises via
    `raise_for_status`. Anything else is opened as a local file in 'rb' mode,
    and the caller is responsible for closing it.
    """
    location = str(url_or_path)
    if not (location.startswith('http://') or location.startswith('https://')):
        return open(url_or_path, 'rb')
    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)

def fetch_model(url_or_path):
    basename = os.path.basename(url_or_path)
    if os.path.exists(basename):
        return basename
    else:
        !wget -N '{url_or_path}'
        return basename

def norm1(prompt):
    """Scale `prompt` to unit Euclidean length along its last dimension."""
    length = prompt.square().sum(dim=-1, keepdim=True).sqrt()
    return prompt / length

def spherical_dist_loss(x, y):
    """Return 2 * arcsin(|x_hat - y_hat| / 2)**2 along the last dimension.

    Both inputs are normalized to unit length first, so only their directions
    matter: identical directions give 0.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1) / 2
    return 2 * half_chord.arcsin() ** 2

class MakeCutouts(torch.nn.Module):
    """Cut `cutn` random square crops out of an image batch and resize them.

    Every crop is average-pooled to `cut_size` x `cut_size`. `cut_pow` is the
    exponent applied to the uniform random number that picks the crop size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def forward(self, input):
        """Return all crops concatenated along the batch dimension.

        `input` is indexed as [batch, channel, y, x]; the result holds
        `cutn * batch` images, ordered crop by crop.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        pooled = []
        for _ in range(self.cutn):
            # Crop side length, drawn between `smallest` and `largest`.
            fraction = torch.rand([])**self.cut_pow
            side = int(fraction * (largest - smallest) + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            crop = input[:, :, top:top + side, left:left + side]
            pooled.append(F.adaptive_avg_pool2d(crop, self.cut_size))
        return torch.cat(pooled)

# Shared cutout module: 32 crops of 224x224 with cut_pow 0.5.
make_cutouts = MakeCutouts(224, 32, 0.5)

def embed_image(image):
  """Embed random cutouts of an image batch with CLIP.

  The cutouts of all images are embedded in one pass and then regrouped as
  [cutout, image, embedding].
  """
  batch_size = image.shape[0]
  cutout_embeds = clip_model.embed_cutout(make_cutouts(image))
  return rearrange(cutout_embeds, '(cc n) c -> cc n c', n=batch_size)

def embed_url(url):
  """Return the CLIP embedding of the image at `url`, averaged over its cutouts.

  `url` may be anything `fetch` accepts; the image is converted to RGB first.
  """
  picture = Image.open(fetch(url)).convert('RGB')
  batch = TF.to_tensor(picture).to(device).unsqueeze(0)
  return embed_image(batch).mean(0).squeeze(0)

In [None]:
class CLIP:
  """Frozen CLIP ViT-B/32 model with helpers that return unit-length embeddings."""

  def __init__(self):
    model_name = "ViT-B/32"
    loaded, _ = clip.load(model_name)
    # The CLIP weights themselves are never trained here.
    self.model = loaded.requires_grad_(False)
    # Per-channel normalization applied to images before encoding.
    self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                          std=[0.26862954, 0.26130258, 0.27577711])

  @torch.no_grad()
  def embed_text(self, prompt):
      """Return the normalized CLIP text embedding of `prompt` as float."""
      tokens = clip.tokenize(prompt).to(device)
      return norm1(self.model.encode_text(tokens).float())

  def embed_cutout(self, image):
      """Return the normalized CLIP image embedding of `image`."""
      return norm1(self.model.encode_image(self.normalize(image)))

clip_model = CLIP()

In [None]:
# Load stylegan model

network_url = model_url

# SECURITY: unpickling executes arbitrary code contained in the file. The
# pickle is downloaded from the URL in model_url, so only use model URLs you
# trust.
with open(fetch_model(network_url), 'rb') as fp:
  G = pickle.load(fp)['G_ema'].to(device)

# # Fix the coordinate grid to w_avg
# shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0))
# G.synthesis.input.affine.bias.data.add_(shift.squeeze(0))
# G.synthesis.input.affine.weight.data.zero_()

# # Arbitrary coordinate grid (dubious idea)
# with torch.no_grad():
#   grid = G.synthesis.input(G.mapping.w_avg.unsqueeze(0))
#   def const(x):
#     def f(w):
#       n = w.shape[0]
#       return x.broadcast_to([n, *x.shape[1:]])
#     return f
#   G.synthesis.input.forward = const(grid)
# grid.requires_grad_()

# Standard deviation of the mapped latents over 10000 random z samples, used
# to rescale latents during optimization. This is a pure statistic, so no
# autograd graph is recorded for it.
with torch.no_grad():
  zs = torch.randn([10000, G.mapping.z_dim], device=device)
  w_stds = G.mapping(zs, None).std(0)

In [None]:
# Get latents for start and end
from functools import lru_cache


# Preprocessing pipeline: resize to 224, then map values from [-1, 1] to
# [0, 1] with clamping.
# NOTE(review): `tf` is not referenced anywhere else in the visible cells.
tf = Compose([
  Resize(224),
  lambda x: torch.clamp((x+1)/2,min=0,max=1),
  ])


@lru_cache(maxsize=None)
def get_latents_for(text_prompt):
  """Find a StyleGAN latent whose image matches `text_prompt` according to CLIP.

  The torch RNG is re-seeded with the global `seed` on every call. The search
  starts from the best of 32 randomly sampled latents and then runs `steps`
  AdamW updates on the spherical distance between the CLIP embeddings of the
  image cutouts and the text embedding. Results are cached per prompt.

  Args:
    text_prompt: the text to approximate.

  Returns:
    The optimized latent, detached from the autograd graph, with a leading
    batch dimension of 1, ready to be passed to `G.synthesis`.
  """
  target = clip_model.embed_text(text_prompt)
  torch.manual_seed(seed)

  # Init
  # Method 1: sample 32 inits (8 batches of 4) and choose the one closest to prompt

  with torch.no_grad():
    qs = []
    losses = []
    for _ in range(8):
      q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
      images = G.synthesis(q * w_stds + G.mapping.w_avg)
      embeds = embed_image(images.add(1).div(2))
      # Average over the cutouts: one loss per candidate in the batch.
      loss = spherical_dist_loss(embeds, target).mean(0)
      i = torch.argmin(loss)
      qs.append(q[i])
      losses.append(loss[i])
    qs = torch.stack(qs)
    losses = torch.stack(losses)
    print(losses)
    print(losses.shape, qs.shape)
    i = torch.argmin(losses)
    q = qs[i].unsqueeze(0).requires_grad_()

  # Method 2: Random init depending only on the seed.

  # q = (G.mapping(torch.randn([1,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
  # q.requires_grad_()

  # Sampling loop: q is optimized in standardized coordinates (w = q * w_stds + w_avg).
  opt = torch.optim.AdamW([q], lr=0.1, betas=(0.0,0.999))
  loop = tqdm(range(steps))
  for _ in loop:
    opt.zero_grad()
    image = G.synthesis(q * w_stds + G.mapping.w_avg, noise_mode='const')
    embed = embed_image(image.add(1).div(2))
    loss = spherical_dist_loss(embed, target).mean()
    loss.backward()
    opt.step()
    loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())

  # Build the result from q after the loop, so the last optimizer step is
  # included and steps == 0 simply returns the initial latent.
  return (q * w_stds + G.mapping.w_avg).detach()


# Target latents for the bass and treble pulses.
latent_bass = get_latents_for(text_prompt_bass)
latent_treble = get_latents_for(text_prompt_treble)

# One key latent per story prompt, in story order. get_latents_for is cached,
# so a prompt that occurs more than once is only optimized once.
latent_story = [get_latents_for(prompt) for prompt in prompts]

In [None]:
def slerp(val, low, high):
  """Spherically interpolate from `low` to `high`, one result per entry of `val`.

  Args:
    val: 1-D tensor of interpolation weights (0 gives `low`, 1 gives `high`).
    low, high: tensors whose first dimension is a batch dimension; all other
      dimensions are flattened for the interpolation. The batch dimensions
      are broadcast against each other and against `val`.

  Returns:
    A tensor with one row per weight and the trailing dimensions of `low`.
    Where the angle between `low` and `high` has a sine of exactly 0, plain
    linear interpolation is used instead.
  """
  weight = val[:, None]
  tail_dims = list(low.shape[1:])
  start = low.reshape([low.shape[0], -1])
  stop = high.reshape([high.shape[0], -1])

  # Angle between the two directions.
  start_unit = start / torch.norm(start, dim=1, keepdim=True)
  stop_unit = stop / torch.norm(stop, dim=1, keepdim=True)
  cosine = torch.sum(start_unit * stop_unit, dim=1, keepdim=True)
  omega = torch.acos(torch.clamp(cosine, -1, 1))

  sin_omega = 0*start + 0*weight + torch.sin(omega) # broadcast to the output shape
  linear = (1.0 - weight) * start + weight * stop
  spherical = (torch.sin((1.0 - weight) * omega) / sin_omega * start
               + torch.sin(weight * omega) / sin_omega * stop)
  mixed = torch.where(sin_omega != 0, spherical, linear)
  return mixed.reshape([mixed.shape[0]] + tail_dims)

In [None]:
# For every chapter, walk from its start latent to its end latent. Progress
# per frame is the cumulative share of the chapter's weighted band energy, so
# louder frames move the story further.
latent_chapters = []
for chapter in range(len(prompts)-1):
  latent_start, latent_end = latent_story[chapter], latent_story[chapter + 1]
  story_speed = (
      get_spec_slice(bass, chapter).sum(0) * bass_story_speed
      + get_spec_slice(mids, chapter).sum(0) * mids_story_speed
      + get_spec_slice(treble, chapter).sum(0) * trebles_story_speed
  )
  print(story_speed.shape)
  # Normalize so the cumulative sum ends at 1 at the chapter's last frame.
  story_speed = story_speed / story_speed.sum()
  progress = torch.tensor(story_speed.cumsum(0)).to(device)
  latent_chapters.append(slerp(progress, latent_start, latent_end))

# From here on latent_story holds one latent per frame instead of one per prompt.
latent_story = torch.cat(latent_chapters, dim=0)
latent_story.shape

In [None]:


def _pulse_envelope(band, impact):
  """Return the per-frame energy of `band`, rescaled to the range [0, impact].

  `band` is indexed as [frequency, time]; the result is a tensor on `device`
  with one value per time column.
  """
  energy = band.sum(0)
  shifted = energy - energy.min()
  span = shifted.max()
  # A constant band has no pulse: keep the zeros instead of dividing 0 by 0.
  scaled = shifted / span if span > 0 else shifted
  return torch.tensor(scaled * impact).to(device)


# Each band is scaled by its own impact setting.
bright = _pulse_envelope(treble, trebles_pulse_impact)
middle = _pulse_envelope(mids, mids_pulse_impact)
deep = _pulse_envelope(bass, bass_pulse_impact)

In [None]:

# Pull every frame's latent towards the treble latent by that frame's treble pulse ...
latent_story = slerp(bright, latent_story, latent_treble)
# latent_story = slerp(middle, latent_story, latent_treble)
# ... and then towards the bass latent by that frame's bass pulse.
latents = slerp(deep, latent_story, latent_bass)

In [None]:
# CLIP is no longer needed once all latents are computed; drop the module and
# the model instance from the notebook namespace.
del clip
del clip_model

In [None]:
import torchvision

In [None]:
# Remove frames left over from a previous run.
!rm -rf parts

In [None]:
os.makedirs("parts", exist_ok=True)
def write_frames(frames, start_id):
  """Save every frame as parts/output_<id>.jpg.

  Ids are zero-padded to 8 digits and count up from `start_id`, so the files
  sort in frame order.
  """
  for offset, frame in enumerate(frames):
    TF.to_pil_image(frame).save(f'parts/output_{start_id + offset:08}.jpg')

# Render the video frames in batches of `batch` latents.
batch = 10
# Rendering needs no gradients; this keeps GPU memory use down.
with torch.no_grad():
  # range() also yields the final, possibly shorter batch, so exactly
  # len(latents) frames are written and the video matches the audio length.
  for start_frame in range(0, len(latents), batch):
    end_frame = start_frame + batch
    # The generator output is mapped from [-1, 1] to [0, 1] for saving.
    frames = G.synthesis(latents[start_frame:end_frame], noise_mode='const').add(1).div(2).clamp(0,1).cpu()
    write_frames(frames, start_frame)

In [None]:

import soundfile as sf
# Save the cropped audio so that it can be muxed with the rendered frames.
sf.write("audio_cut.wav", audio, sr)

In [None]:
# Remove videos from a previous run.
!rm *.avi
# Encode the frames at one frame per spectrogram column (frame_rate columns per second) ...
!ffmpeg  -r {frame_rate} -i parts/%*.jpg -y -c:v libx264 vid_no_audio.avi
# ... and add the cropped audio track.
!ffmpeg -i audio_cut.wav -i vid_no_audio.avi final_video.avi

In [None]:
# Also write a smaller copy: half the width and height (rounded to even numbers), encoded with x265 at CRF 28.
!ffmpeg -i final_video.avi -vf "scale=trunc(iw/4)*2:trunc(ih/4)*2" -c:v libx265 -crf 28 vid7.mkv

In [None]:
!mkdir {output_path}
!cp final_video.avi {output_path}