<a href="https://colab.research.google.com/github/melismeric/playlist2image/blob/main/Spotify_Playlist_to_Image_(VQGAN%2BCLIP)_%5BPixtape%5D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

![](https://media.giphy.com/media/7NBy1JuBHskLFqymik/giphy-downsized-large.gif) ![](https://media.giphy.com/media/D5b0iTxU5OOO2DpiPW/giphy-downsized-large.gif) ![](https://media.giphy.com/media/SSOLzb8jZUpEumiLCj/giphy-downsized-large.gif) 
# Spotify Playlist to Image Generation with VQGAN and CLIP.


Notebook for creating images and videos from Spotify playlists according to song names and mood of the playlist.

This notebook is built on research and code by Justin John (https://github.com/justinjohn0306) and Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). The original BigGAN + CLIP method was made by https://twitter.com/advadnoun. 

Modified by: Melis Meriç (https://github.com/melismeric)

In [None]:
# @markdown Type your playlist url and customize the images

# URL of the Spotify playlist to visualise; later cells extract the playlist id from it.
playlist = "https://open.spotify.com/playlist/7rN9qgoHNc8CEn5DeTjH60?si=eb0e4ce2610f47c7" #@param {type:"string"}

# The four values below are concatenated into every text prompt by the prompt-construction cell:
# "[abstract ]<style> <medium> of <song name>" plus a separate "<lighting>" prompt part.
style = "Fauvist" #@param ["Cubist", "Fauvist", "Impressionist", "Flat", "Brutalist", "Art deco", "Baroque", "Bauhaus", "Classic", "Constructivist", "Dadaist", "Futurist", "Renaissance", "Minimalist", "Pop art", "Rococo", "Surrealist", "Suprematist", "Vitrai", "Pointilist"]
medium = "watercolor painting" #@param ["watercolor painting", "oil on canvas painting", "illustration", "album cover", "poster art", "collage", "Instax", "glass"]
abstract = True #@param {type:"boolean"}
lighting = "glowing neon" #@param ["Volumetric lighting" , "global illumination", "sunrays shine upon it", "iridescent", "glowing neon", "flickering light"]
#mood_classification_method = "NN2" #@param ["NN2", "Mood Map"]

# Connecting to Spotify API and Text Construction for VQGAN+CLIP
This part connects to the Spotify API and fetches the song names and audio features of the playlist's songs, which are then used for prompt construction.

In [None]:
#@markdown #**Import Libraries**

!pip install spotipy 
import spotipy
from spotipy.oauth2 import SpotifyClientCredentials
import pandas as pd
!pip install webcolors==1.3
import math
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_validate
pd.options.mode.chained_assignment = None  # default='warn'


from scipy.spatial import KDTree
from webcolors import (
    css3_hex_to_names,
    html4_hex_to_names,
    hex_to_rgb,
)

# Connect to the Spotify Web API using the client-credentials flow.
import os

# NOTE(review): credentials committed in a notebook are public. Prefer exporting
# SPOTIPY_CLIENT_ID / SPOTIPY_CLIENT_SECRET / LASTFM_API_KEY in the environment;
# the literals are kept only as fallbacks so the notebook still runs unchanged.
SPOTIFY_CLIENT_ID = os.environ.get("SPOTIPY_CLIENT_ID", "6961f944410f459e8325982c59f1b41e")
SPOTIFY_CLIENT_SECRET = os.environ.get("SPOTIPY_CLIENT_SECRET", "be255af2092c410286aefa1b9f5bedf7")

sp = spotipy.Spotify(auth_manager=SpotifyClientCredentials(client_id=SPOTIFY_CLIENT_ID,
                                                           client_secret=SPOTIFY_CLIENT_SECRET))

# Later cells use both names; share one client instead of authenticating twice.
spotify = sp


# Last.fm
API_KEY = os.environ.get("LASTFM_API_KEY", '66b80802ba47c165803085164751ea30')
USER_AGENT = 'Dataquest'

import requests
import time

headers = {
    'user-agent': USER_AGENT
}

payload = {
    'api_key': API_KEY,
    'method': 'chart.gettopartists',
    'format': 'json'
}

# Smoke test of the Last.fm key; the timeout keeps a stalled connection from hanging the cell.
r = requests.get('https://ws.audioscrobbler.com/2.0/', headers=headers, params=payload, timeout=10)
r.status_code

In [None]:
#@markdown #**Helper Functions**

# Get Audio Features from Spotify
def get_track_features(track_ids, spotify):
    """Fetch Spotify audio features for `track_ids` and return the mood-model columns.

    The ids are sent to `spotify.audio_features` in batches of 50. The result is a
    DataFrame with one row per returned feature record, restricted to the columns
    listed in the module-level `features_for_mood`.
    """
    batch = 50
    rows = []
    for start in range(0, len(track_ids), batch):
        rows += spotify.audio_features(tracks=track_ids[start:start + batch])

    # Identifier/metadata columns that are never used as model inputs.
    unused = ['id', 'analysis_url', 'key', 'time_signature', 'track_href', 'type', 'uri']
    frame = pd.DataFrame(rows).drop(unused, axis=1)
    return frame[features_for_mood]

# COLOR DETECTION
# code from https://gist.github.com/mathebox/e0805f72e7db3269ec22
def rgb_to_hsv(r, g, b):
    """Convert an RGB triple to (hue, saturation, value).

    Hue is a fraction of a full turn. Saturation is the channel spread divided by
    the largest channel (0 when the largest channel is 0). Value is the largest
    channel, left on the input scale rather than normalised.
    """
    r, g, b = float(r), float(g), float(b)
    high = max(r, g, b)
    low = min(r, g, b)
    spread = high - low

    s = 0 if high == 0 else spread / high

    if high == low:
        # Grey: hue is undefined, report 0.
        h = 0.0
    elif high == b:
        h = ((r - g) / spread + 4) / 6
    elif high == g:
        h = ((b - r) / spread + 2) / 6
    else:
        h = ((g - b) / spread + (6 if g < b else 0)) / 6

    return h, s, high

def hsv_to_rgb(h, s, v):
    """Convert (hue, saturation, value) to an (r, g, b) triple on the scale of `v`.

    `h` is a fraction of a full turn; it is split into six sectors and the
    position inside the sector interpolates between neighbouring channels.
    """
    sector = math.floor(h * 6)
    frac = h * 6 - sector
    p = v * (1 - s)
    q = v * (1 - frac * s)
    t = v * (1 - (1 - frac) * s)

    which = int(sector % 6)
    if which == 0:
        return v, t, p
    if which == 1:
        return q, v, p
    if which == 2:
        return p, v, t
    if which == 3:
        return p, q, v
    if which == 4:
        return t, p, v
    return v, p, q

def modifyColor(d,m,r,g,b):
  """Keep the hue of (r, g, b) but replace saturation with `d` and value with (1-m)*255.

  Returns the resulting colour as a tuple of three rounded channel values.
  """
  hue = rgb_to_hsv(r, g, b)[0]
  red, green, blue = hsv_to_rgb(hue, d, (1-m)*255)
  return (round(red), round(green), round(blue))

def convert_rgb_to_names(rgb_tuple):
    """Return the name of the CSS3 colour whose RGB value is nearest to `rgb_tuple`."""
    # Iterate the hex -> name table once so names and points stay index-aligned.
    hex_codes = list(css3_hex_to_names)
    names = [css3_hex_to_names[code] for code in hex_codes]
    points = [hex_to_rgb(code) for code in hex_codes]

    # Nearest-neighbour lookup in RGB space.
    _, nearest = KDTree(points).query(rgb_tuple)
    return f'{names[nearest]}'

def get_colour_name(rgb_triplet):
    """Return the HTML4 colour name with the smallest squared RGB distance to `rgb_triplet`.

    When several colours are equally distant, the one listed last in
    `html4_hex_to_names` is returned.
    """
    by_distance = {
        sum((channel - wanted) ** 2 for channel, wanted in zip(hex_to_rgb(code), rgb_triplet)): name
        for code, name in html4_hex_to_names.items()
    }
    return by_distance[min(by_distance)]

## Mood Detection

# Mood labels returned by getRusselMood: entry i belongs to the i-th coordinate in that
# function's point list (18 labels, 18 points). lookup_tags also keeps only Last.fm tags
# that appear in this list.
russelMoods = ["calm", "angry", "anxious", "sad", "mournful", "confident", "glad", "dreamy", "hopeful", "romantic", "cheerful", "earnest", "gleeful", "brooding", "cynical", "gloomy", "aggressive", "exciting"] 
# Mood labels returned by getMood: entry i belongs to the i-th point of that function's
# 5 x 5 grid (25 labels, 25 points).
gracenoteMoods = ["Somber", "Gritty", "Serious", "Brooding", "Aggressive", "Melancholy", "Cool", "Yearning", "Urgent", "Defiant", "Sentimental", "Sophisticated", "Sensual", "Fiery", "Energizing", "Tender", "Romantic",  "Empowering", "Stirring", "Rowdy", "Peaceful", "Easygoing", "Upbeat", "Lively", "Excited"]


# source https://codereview.stackexchange.com/questions/28207/finding-the-closest-point-to-a-list-of-points
def closest_node(node, nodes):
    """Return the index of the row of `nodes` with the smallest squared distance to `node`."""
    offsets = np.asarray(nodes) - node
    squared = (offsets ** 2).sum(axis=1)
    return np.argmin(squared)

# better way to detect mood (https://neokt.github.io/projects/audio-music-mood-classification/)
# Gracenote mood keywords
def getMood(energy, valence):
  """Return the gracenoteMoods label whose grid point is closest to (energy, valence)."""
  levels = (0.2, 0.4, 0.5, 0.6, 0.8)
  # 5 x 5 grid; the first coordinate varies fastest, matching the order of gracenoteMoods.
  grid = [(first, second) for second in levels for first in levels]
  return gracenoteMoods[closest_node((energy, valence), grid)]

# Russel mood keywords
def getRusselMood(valence, energy):
  """Return the russelMoods label whose anchor point is closest to (valence, energy)."""
  # One anchor per entry of russelMoods, in the same order.
  anchors = [
      (0.4, 0.1), (0.35, 0.85), (0.25, 0.75), (0.2, 0.25), (0.2, 0.4), (0.8, 0.4),
      (0.9, 0.6), (0.35, 0.3), (0.55, 0.45), (0.55, 0.15), (0.8, 0.65), (0.7, 0.3),
      (0.8, 0.8), (0.8, 0.2), (0.3, 0.55), (0.1, 0.6), (0.25, 0.9), (0.6, 0.75),
  ]
  nearest = closest_node((valence, energy), anchors)
  return russelMoods[nearest]

def moodClassify():
  """Store the classifier's mood prediction for every song in songData["nn_preds"].

  Works on the notebook globals `songData`, `scaler`, `nn` and `features_for_mood`.
  NOTE(review): the "Loading libraries" cell rebinds `nn` via `from torch import nn`,
  so this only works while `nn` still refers to the trained classifier.
  """
  scaled = scaler.transform(songData[features_for_mood])
  songData["nn_preds"] = nn.predict(scaled)
  
# Last.fm api
def lastfm_get(payload):
    """Send a GET request to the Last.fm API and return the `requests` response.

    `payload` holds the method-specific query parameters. The API key and the JSON
    format flag are added to a copy, so the caller's dict is left untouched.
    """
    # define headers and URL
    headers = {'user-agent': USER_AGENT}
    url = 'https://ws.audioscrobbler.com/2.0/'

    # Add API key and format without mutating the caller's payload
    params = dict(payload)
    params['api_key'] = API_KEY
    params['format'] = 'json'

    # Without a timeout a stalled connection would block the notebook indefinitely.
    response = requests.get(url, headers=headers, params=params, timeout=10)
    return response


def lookup_tags(artist, track):
    """Return the Last.fm top tags of a track that are also mood words in `russelMoods`.

    Returns None when the HTTP status is not 200. Returns an empty list when the
    response body carries no 'toptags' entry. Matching tags are returned in the
    order they appear in `russelMoods`, so the result is the same on every run.
    """
    response = lastfm_get({
        'method': 'track.getTopTags',
        'artist': artist,
        'track': track
    })

    # if there's an error, just return nothing
    if response.status_code != 200:
        return None

    # rate limiting (skipped when the response object is flagged as cached)
    if not getattr(response, 'from_cache', False):
        time.sleep(0.25)

    # A 200 response may still lack the tag list; treat that as "no tags".
    toptags = response.json().get('toptags') or {}
    tag_names = {t['name'] for t in toptags.get('tag') or []}

    return [mood for mood in russelMoods if mood in tag_names]


In [None]:
#@markdown #**Preparing Playlist Dataset**

tracks = pd.DataFrame()

# Labelled playlist dataset used to train the mood model (CSV shared on Google Drive).
url = "https://drive.google.com/file/d/15U_9wgBwA8b5jMfsAfQX2RzNhEjKsTFE/view?usp=sharing"
file_id=url.split('/')[-2]
dwn_url='https://drive.google.com/uc?id=' + file_id
data = pd.read_csv(dwn_url, index_col=0)

moods = []

# Audio-feature columns fed to the mood classifier (also read by get_track_features).
features_for_mood = ['energy', 'liveness', 'tempo', 'speechiness',
                                     'acousticness', 'instrumentalness', 'danceability', 'duration_ms',
                                     'loudness', 'valence','mode']

# Playlists whose songs will be turned into prompts
playlists = [playlist]
playlist_frames = []
for link in playlists:
  # Last path segment without the query string; independent of the URL prefix length.
  playlist_id = link.split("/")[-1].split("?")[0]
  try:
      # Fetch the playlist once and reuse the response for every field.
      pl_tracks = spotify.playlist_tracks(playlist_id)['items']
      # Entries without a track object or track id cannot be looked up for audio features.
      pl_tracks = [item for item in pl_tracks if item.get('track') and item['track'].get('id')]
      names = [item['track']['name'] for item in pl_tracks]
      artists2 = [item['track']['album']['artists'] for item in pl_tracks]
      ids = [item['track']['id'] for item in pl_tracks]
  except Exception as err:
      # Best effort: report the playlist that failed and carry on with the others.
      print (link, err)
      continue
  if not ids:
      print (link)
      continue
  features = get_track_features(ids, spotify)
  # First listed album artist of each track.
  artists = [album_artists[0]['name'] for album_artists in artists2]
  features['id'] = ids
  features['name'] = names
  features['artist'] = artists
  # Map three audio features onto a 0-255 RGB triple used later for colour names.
  features['R'] = round(features['energy'] * 255)
  features['G'] = round(features['valence'] * 255)
  features['B'] = round(features['acousticness'] * 255)
  playlist_frames.append(features)

# DataFrame.append no longer exists in pandas 2; concatenate once instead.
# ignore_index keeps row labels unique when more than one playlist is listed.
if playlist_frames:
  tracks = pd.concat(playlist_frames, ignore_index=True)

songData = pd.DataFrame(data=tracks)




In [None]:
#@markdown #**Training NN2 Model For Mood Prediction**
# Set to True to additionally run a grid search over alpha / hidden layer size.
hyper_opt = False

#split into trainval and test
trainx, testx, trainy, testy = train_test_split(data[features_for_mood], data['mood'], test_size = 0.33,
                                                random_state = 42, stratify=data['mood'])

# The scaler is fitted on the training split only and reused for every later prediction.
scaler = StandardScaler()
train_scaled = scaler.fit_transform(trainx)

# One place for the model settings; random_state makes training reproducible across runs.
nn_params = dict(hidden_layer_sizes=8, max_iter=15000, alpha=1.0, random_state=42)

nn = MLPClassifier(**nn_params)
scores = cross_val_score(nn, train_scaled, trainy, cv=5)
print ("cv score: " + str(scores.mean()))

if hyper_opt:
    params = {"alpha": np.logspace(-4, 2, 7), 'hidden_layer_sizes': [1, 2, 5, 10, 20, 50, 100]}
    clf = GridSearchCV(nn, params)
    clf.fit(train_scaled, trainy)
    print("hyperparam optimized score : " + str(clf.best_score_))
    # Report the winning settings instead of stopping in the debugger.
    print("hyperparam optimized params : " + str(clf.best_params_))

results = cross_validate(nn, train_scaled, trainy, return_train_score=True)

# Final model used by the prompt-construction cell.
nn = MLPClassifier(**nn_params)
nn.fit(train_scaled, trainy)
test_preds = nn.predict(scaler.transform(testx))
# Held-out accuracy, shown as the cell output.
accuracy_score(testy, test_preds)

In [None]:
#@markdown #**Prompt construction**

text_prompt = []
songData['Color'] = ""
songData['Mood Color'] = ""
songData['Gracenote Mood'] = ""
songData['Russel Mood'] = ""
songData['Last.fm Tag Mood'] = ""

# The classifier output does not depend on the row being processed, so predict
# once for the whole frame instead of once per song.
if len(songData.index):
  songData["nn_preds"] = nn.predict(scaler.transform(songData[features_for_mood]))

for ind in songData.index:
  # .at reads/writes a single cell directly (no chained indexing).
  r, g, b = songData.at[ind, 'R'], songData.at[ind, 'G'], songData.at[ind, 'B']
  energy, valence = songData.at[ind, 'energy'], songData.at[ind, 'valence']

  songData.at[ind, 'Color'] = convert_rgb_to_names((r, g, b))
  songData.at[ind, 'Mood Color'] = convert_rgb_to_names(
      modifyColor(songData.at[ind, 'danceability'], songData.at[ind, 'mode'], r, g, b))

  songData.at[ind, 'Gracenote Mood'] = getMood(energy, valence)
  songData.at[ind, 'Russel Mood'] = getRusselMood(valence, energy)

  # Tag lookup is best effort: a failed request or missing tags yields an empty string.
  try:
    tags = lookup_tags(songData.at[ind, 'artist'], songData.at[ind, 'name']) or []
  except Exception:
    tags = []

  songData.at[ind, 'Last.fm Tag Mood'] = ', '.join(tags)



In [None]:
# construct text prompt
# Start from an empty list so re-running this cell does not append duplicate prompts.
text_prompt = []
subject_prefix = "abstract " if abstract else ""

for ind in songData.index:
  song = songData.loc[ind]
  russel_mood = song['Russel Mood']
  lastfm_moods = song['Last.fm Tag Mood']

  # Each part is "<text>: <weight>"; parts are joined with " | ".
  parts = [
      # subject: style, medium and song title
      subject_prefix + style + " " + medium + " of " + song['name'] + ": 2",
      # color
      song['Color'] + " " + song['Mood Color'] + ": 0.9",
      # mood predicted by the classifier
      song['nn_preds'] + ": 1.8 ",
  ]
  # Russel mood keyword, unless the Last.fm tags already contain it
  if russel_mood not in lastfm_moods:
    parts.append(russel_mood + ": 0.8 ")
  # last.fm
  if lastfm_moods:
    parts.append(lastfm_moods + ": 0.8 ")

  # lighting
  parts.append(lighting + ": 0.9")

  #additional
  if (song['liveness'] > 0.6):
    parts.append("concert poster: 0.9")

  if (song['danceability'] > 0.8):
    parts.append("psychedelic: 0.9")

  text_prompt.append(" | ".join(parts))

In [None]:
#@markdown #**Download playlist cover image**
# Download playlist cover image
# If it doesn't have a cover image add a default initial image

playlist_URI = playlist.split("/")[-1].split("?")[0]

imageUrl = sp.playlist_cover_image(playlist_URI)[0]['url']

if "mosaic" not in imageUrl:
  s = imageUrl
else:
  s = "https://ik.imagekit.io/dreamco/color1.jpg"

!wget --output-document="playlistImage" $s

In [None]:
text_prompt

# VQGAN + CLIP 

In [None]:
#@markdown #**Licensed under the MIT License (*Double-click me to read the license agreement*)**
#@markdown ---

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
#@markdown #**Check GPU type**
#@markdown ### Factory reset runtime if you don't have the desired GPU.

#@markdown ---




#@markdown V100 = Excellent (*Available only for Colab Pro users*)

#@markdown P100 = Very Good

#@markdown T4 = Good (*preferred*)

#@markdown K80 = Meh

#@markdown P4 = (*Not Recommended*) 

#@markdown ---

!nvidia-smi -L

In [None]:
#@markdown #**Anti-Disconnect for Google Colab**
#@markdown ## Run this to stop it from disconnecting automatically 
#@markdown  **(disconnects anyhow after 6 - 12 hrs for using the free version of Colab.)**
#@markdown  *(Pro users will get about 24 hrs usage time[depends])*
#@markdown ---

# Keep-alive helper: injects JavaScript into the notebook page.
import IPython
# The script defines ClickConnect(), which logs "Working" and clicks the element matching
# "colab-toolbar-button#connect", and schedules it with setInterval every 60000 ms.
js_code = '''
function ClickConnect(){
console.log("Working");
document.querySelector("colab-toolbar-button#connect").click()
}
setInterval(ClickConnect,60000)
'''
# NOTE(review): `display` is not imported in this cell — presumably the notebook-provided
# global; confirm it is available where this runs.
display(IPython.display.Javascript(js_code))

In [None]:
#@markdown #**Installation of libraries**
# @markdown This cell will take a little while because it has to download several libraries

#@markdown ---
 
# Every command below redirects its output to /dev/null, so a failed clone or install
# is silent here and only surfaces later as an import error.
print("Installing CLIP...")
!git clone https://github.com/openai/CLIP                 &> /dev/null
 
print("Installing Python Libraries for AI...")
!git clone https://github.com/CompVis/taming-transformers &> /dev/null
!pip install transformers                                 &> /dev/null
# pytorch-lightning (and pillow further down) are the only pinned versions in this cell.
!pip install ftfy regex tqdm omegaconf pytorch-lightning==1.6.5  &> /dev/null
!pip install kornia                                       &> /dev/null
!pip install einops                                       &> /dev/null
!pip install wget                                         &> /dev/null
 
print("Installing libraries for metadata management...")
!pip install stegano                                      &> /dev/null
!apt install exempi                                       &> /dev/null
!pip install python-xmp-toolkit                           &> /dev/null
!pip install imgtag                                       &> /dev/null
!pip install pillow==7.1.2                                &> /dev/null
 
print("Installing Python libraries for creating videos...")
!pip install imageio-ffmpeg                               &> /dev/null
# Output folder for the per-iteration frames that ascend_txt writes as steps/NNNN.png.
!mkdir steps
print("Installation completed.")

In [None]:
#@markdown #**Selection of models to download**
#@markdown ---

#@markdown By default, the notebook downloads Model 16384 from ImageNet. There are others that are not downloaded by default, since it would be in vain if you are not going to use them, so if you want to use them, simply select the models to download.

#@markdown ---

# One checkbox per VQGAN model. Each selected model is downloaded as a pair of files in the
# working directory: <name>.yaml (config) and <name>.ckpt (checkpoint). The "Initializing"
# cell builds its config/checkpoint paths as f'{model}.yaml' / f'{model}.ckpt', so the file
# names written here must match the model name chosen there.
imagenet_1024 = False #@param {type:"boolean"}
imagenet_16384 = True #@param {type:"boolean"}
gumbel_8192 = False #@param {type:"boolean"}
coco = False #@param {type:"boolean"}
faceshq = False #@param {type:"boolean"}
wikiart_1024 = False #@param {type:"boolean"}
wikiart_16384 = False #@param {type:"boolean"}
sflckr = False #@param {type:"boolean"}
ade20k = False #@param {type:"boolean"}
ffhq = False #@param {type:"boolean"}
celebahq = False #@param {type:"boolean"}

# NOTE(review): the two WikiArt mirrors are fetched over plain http.
if imagenet_1024:
  !curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 1024
  !curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1'  #ImageNet 1024
if imagenet_16384:
  !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 16384
  !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 16384
if gumbel_8192:
  !curl -L -o gumbel_8192.yaml -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #Gumbel 8192
  !curl -L -o gumbel_8192.ckpt -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #Gumbel 8192
if coco:
  !curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO
  !curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO
if faceshq:
  !curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ
  !curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ
if wikiart_1024: 
  !curl -L -o wikiart_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart.yaml' #WikiArt 1024
  !curl -L -o wikiart_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart.ckpt' #WikiArt 1024
if wikiart_16384: 
  !curl -L -o wikiart_16384.yaml -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml' #WikiArt 16384
  !curl -L -o wikiart_16384.ckpt -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt' #WikiArt 16384
if sflckr:
  !curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR
  !curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR
if ade20k:
  !curl -L -o ade20k.yaml -C - 'https://static.miraheze.org/intercriaturaswiki/b/bf/Ade20k.txt' #ADE20K
  !curl -L -o ade20k.ckpt -C - 'https://app.koofr.net/content/links/0f65c2cd-7102-4550-a2bd-07fd383aac9e/files/get/last.ckpt?path=%2F2020-11-20T21-45-44_ade20k_transformer%2Fcheckpoints%2Flast.ckpt' #ADE20K
if ffhq:
  !curl -L -o ffhq.yaml -C - 'https://app.koofr.net/content/links/0fc005bf-3dca-4079-9d40-cdf38d42cd7a/files/get/2021-04-23T18-19-01-project.yaml?path=%2F2021-04-23T18-19-01_ffhq_transformer%2Fconfigs%2F2021-04-23T18-19-01-project.yaml&force' #FFHQ
  !curl -L -o ffhq.ckpt -C - 'https://app.koofr.net/content/links/0fc005bf-3dca-4079-9d40-cdf38d42cd7a/files/get/last.ckpt?path=%2F2021-04-23T18-19-01_ffhq_transformer%2Fcheckpoints%2Flast.ckpt&force' #FFHQ
if celebahq:
  !curl -L -o celebahq.yaml -C - 'https://app.koofr.net/content/links/6dddf083-40c8-470a-9360-a9dab2a94e96/files/get/2021-04-23T18-11-19-project.yaml?path=%2F2021-04-23T18-11-19_celebahq_transformer%2Fconfigs%2F2021-04-23T18-11-19-project.yaml&force' #CelebA-HQ
  !curl -L -o celebahq.ckpt -C - 'https://app.koofr.net/content/links/6dddf083-40c8-470a-9360-a9dab2a94e96/files/get/last.ckpt?path=%2F2021-04-23T18-11-19_celebahq_transformer%2Fcheckpoints%2Flast.ckpt&force' #CelebA-HQ

In [None]:
# @title Loading libraries and definitions
 
import argparse
import math
from pathlib import Path
import sys
 
sys.path.append('./taming-transformers')
from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
 
from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
from imgtag import ImgTag    # metadatos 
from libxmp import *         # metadatos
import libxmp                # metadatos
from stegano import lsb
import json
ImageFile.LOAD_TRUNCATED_IMAGES = True
 
def sinc(x):
    """Element-wise sin(pi*x) / (pi*x), with entries where x == 0 set to 1."""
    scaled = math.pi * x
    ratio = torch.sin(scaled) / scaled
    return torch.where(x != 0, ratio, x.new_ones([]))
 
 
def lanczos(x, a):
    """Lanczos window of half-width `a` evaluated at `x`, scaled so the values sum to 1.

    Entries outside the open interval (-a, a) are zero.
    """
    inside = (x > -a) & (x < a)
    window = torch.where(inside, sinc(x) * sinc(x/a), x.new_zeros([]))
    return window / window.sum()
 
 
def ramp(ratio, width):
    """Return a symmetric 1-D tensor of sample positions spaced `ratio` apart around 0.

    ceil(width / ratio + 1) non-negative positions are generated by repeated addition,
    mirrored to the negative side, and the two outermost positions are dropped.
    """
    count = math.ceil(width / ratio + 1)
    half = torch.empty([count])
    position = 0
    for slot in range(count):
        half[slot] = position
        position += ratio
    mirrored = -half[1:].flip([0])
    return torch.cat([mirrored, half])[1:-1]
 
 
def resample(input, size, align_corners=True):
    """Resize a (n, c, h, w) tensor to `size` = (dh, dw) with bicubic interpolation.

    Along every axis that shrinks, the image is first filtered with a Lanczos kernel
    (reflect-padded so the spatial size is unchanged) before interpolating.
    """
    n, c, h, w = input.shape
    dh, dw = size

    def _prefilter(planes, src, dst, vertical):
        # Only axes that get smaller need the low-pass filter.
        if dst >= src:
            return planes
        taps = lanczos(ramp(dst / src, 2), 2).to(planes.device, planes.dtype)
        margin = (taps.shape[0] - 1) // 2
        if vertical:
            planes = F.pad(planes, (0, 0, margin, margin), 'reflect')
            return F.conv2d(planes, taps[None, None, :, None])
        planes = F.pad(planes, (margin, margin, 0, 0), 'reflect')
        return F.conv2d(planes, taps[None, None, None, :])

    # Treat every channel as its own single-channel image for the 1-D convolutions.
    planes = input.view([n * c, 1, h, w])
    planes = _prefilter(planes, h, dh, vertical=True)
    planes = _prefilter(planes, w, dw, vertical=False)

    restored = planes.view([n, c, h, w])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)
 
 
class ReplaceGrad(torch.autograd.Function):
    """Return `x_forward` in the forward pass while sending the gradient to `x_backward`."""

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape of the gradient target is needed in backward.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for x_forward; the incoming gradient is summed down to x_backward's shape.
        routed = grad_output.sum_to_size(ctx.shape)
        return None, routed


replace_grad = ReplaceGrad.apply
 
 
class ClampWithGrad(torch.autograd.Function):
    """Clamp to [min, max] in the forward pass with a selectively masked gradient."""

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        # How far each element lies outside the range (0 when inside).
        overshoot = input - input.clamp(ctx.min, ctx.max)
        # Keep the gradient where grad * overshoot is non-negative; zero it elsewhere.
        keep = grad_in * overshoot >= 0
        return grad_in * keep, None, None


clamp_with_grad = ClampWithGrad.apply
 
 
def vector_quantize(x, codebook):
    """Replace each vector along the last axis of `x` by its nearest `codebook` row.

    Nearness is squared Euclidean distance. The result goes through `replace_grad`
    with `x` as the gradient target.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    # ||x - c||^2 expanded as ||x||^2 + ||c||^2 - 2 x.c
    d = x_sq + code_sq - 2 * x @ codebook.T
    nearest = d.argmin(-1)
    selector = F.one_hot(nearest, codebook.shape[0]).to(d.dtype)
    return replace_grad(selector @ codebook, x)
 
 
class Prompt(nn.Module):
    """Weighted distance between a batch of embeddings and a fixed target embedding.

    `embed`, `weight` and `stop` are stored as buffers. A negative `weight` flips the
    sign of the distance before it is scaled by the weight's magnitude.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        buffers = (
            ('embed', embed),
            ('weight', torch.as_tensor(weight)),
            ('stop', torch.as_tensor(stop)),
        )
        for buffer_name, value in buffers:
            self.register_buffer(buffer_name, value)

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        # Chord length between unit vectors, turned into a squared-angle style distance.
        chord = input_normed.sub(embed_normed).norm(dim=2)
        dists = chord.div(2).arcsin().pow(2).mul(2)
        signed = dists * self.weight.sign()
        # Forward value is the signed distance; the gradient comes from max(distance, stop).
        floored = torch.maximum(signed, self.stop)
        return self.weight.abs() * replace_grad(signed, floored).mean()
 
 
def parse_prompt(prompt):
    """Split a prompt of the form "text[:weight[:stop]]" into (text, weight, stop).

    Missing fields default to weight 1 and stop -inf. A trailing field is treated as
    weight/stop only if it parses as a number, so colons inside the text (for example
    a song title such as "Symphony: Allegro: 2") stay part of the text instead of
    raising ValueError.
    """
    # Try the most specific layout first: text:weight:stop, then text:weight.
    for numeric_fields in (2, 1):
        vals = prompt.rsplit(':', numeric_fields)
        if len(vals) != numeric_fields + 1:
            continue
        try:
            weight = float(vals[1])
            stop = float(vals[2]) if numeric_fields == 2 else float('-inf')
        except ValueError:
            # Not numeric: the colon belongs to the text, try the shorter layout.
            continue
        return vals[0], weight, stop
    return prompt, 1.0, float('-inf')
 
 
class MakeCutouts(nn.Module):
    """Cut `cutn` random square crops from an image batch, resize and augment them.

    Every crop is resampled to `cut_size` x `cut_size`, passed through the
    augmentation pipeline and finally perturbed with random noise.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow
        augmentations = [
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomSolarize(0.01, 0.01, p=0.7),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
        ]
        self.augs = nn.Sequential(*augmentations)
        self.noise_fac = 0.1

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        largest = min(sideX, sideY)
        smallest = min(sideX, sideY, self.cut_size)
        target = (self.cut_size, self.cut_size)

        cutouts = []
        for _ in range(self.cutn):
            # Crop side drawn between `smallest` and `largest`, shaped by cut_pow.
            size = int(torch.rand([])**self.cut_pow * (largest - smallest) + smallest)
            left = torch.randint(0, sideX - size + 1, ())
            top = torch.randint(0, sideY - size + 1, ())
            crop = input[:, :, top:top + size, left:left + size]
            cutouts.append(resample(crop, target))

        batch = self.augs(torch.cat(cutouts, dim=0))
        if not self.noise_fac:
            return batch
        # Per-cutout noise amplitude drawn uniformly from [0, noise_fac).
        facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
        return batch + facs * torch.randn_like(batch)
 
 
def load_vqgan_model(config_path, checkpoint_path):
    """Build the VQGAN described by `config_path` and load `checkpoint_path` into it.

    Supports the VQModel, Net2NetTransformer (its first-stage model is returned) and
    GumbelVQ targets; any other target raises ValueError. The returned model is in
    eval mode with gradients disabled and has its `loss` attribute removed.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target
    params = config.model.params

    def _freeze_and_load(module):
        # Inference only: eval mode, no parameter gradients, then restore the weights.
        module.eval().requires_grad_(False)
        module.init_from_ckpt(checkpoint_path)
        return module

    if target == 'taming.models.vqgan.VQModel':
        model = _freeze_and_load(vqgan.VQModel(**params))
    elif target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = _freeze_and_load(cond_transformer.Net2NetTransformer(**params))
        model = parent_model.first_stage_model
    elif target == 'taming.models.vqgan.GumbelVQ':
        gumbel = vqgan.GumbelVQ(**params)
        print(params)
        model = _freeze_and_load(gumbel)
    else:
        raise ValueError(f'unknown model type: {target}')

    del model.loss
    return model
 
 
def resize_image(image, out_size):
    """Resize `image` so its area is at most that of `out_size`, keeping its aspect ratio.

    Images that already have a smaller area keep their size. Lanczos resampling is used.
    """
    src_w, src_h = image.size
    aspect = src_w / src_h
    area = min(src_w * src_h, out_size[0] * out_size[1])
    new_w = round((area * aspect)**0.5)
    new_h = round((area / aspect)**0.5)
    return image.resize((new_w, new_h), Image.LANCZOS)

def download_img(img_url):
    """Download `img_url` to "input.jpg" and return the saved file name, or None on failure.

    The `wget` package is imported here because no cell imports it at module level;
    without it the NameError was swallowed and the function always returned None.
    """
    import wget

    try:
        return wget.download(img_url,out="input.jpg")
    except Exception:
        # Best effort: a failed download is reported as None.
        return None


In [None]:
# @title Model Functions

def synth(z):
    """Decode latent `z` into an image tensor with values clamped to [0, 1].

    `z` is quantized against the model's codebook first; the Gumbel model keeps its
    codebook under `quantize.embed`, the other models under `quantize.embedding`.
    """
    quantizer = model.quantize
    codebook = quantizer.embed.weight if is_gumbel else quantizer.embedding.weight
    z_q = vector_quantize(z.movedim(1, 3), codebook).movedim(3, 1)

    # Shift/scale the decoder output by (x + 1) / 2 before clamping.
    return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

def add_xmp_data(nombrefichero):
    """Append generation metadata to the XMP Dublin Core block of the image file.

    Written fields, in order: creator, title (the joined prompts or 'None'), i, model,
    seed and input_images, all taken from notebook globals.
    """
    imagen = ImgTag(filename=nombrefichero)
    title = " | ".join(args.prompts) if args.prompts else 'None'
    fields = [
        ('creator', 'VQGAN+CLIP'),
        ('title', title),
        ('i', str(i)),
        ('model', name_model),
        ('seed', str(seed)),
        ('input_images', str(input_images)),
    ]
    for field_name, value in fields:
        # A fresh options dict per call, as in the original call sites.
        imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, field_name, value, {"prop_array_is_ordered":True, "prop_value_is_array":True})
    imagen.close()

def add_stegano_data(filename):
    """Hide the generation metadata as a JSON string inside the image and overwrite the file.

    The metadata (prompts, iteration, model name, seed, input images) is read from
    notebook globals.
    """
    title = " | ".join(args.prompts) if args.prompts else None
    payload = json.dumps({
        "title": title,
        "notebook": "VQGAN+CLIP",
        "i": i,
        "model": name_model,
        "seed": str(seed),
        "input_images": input_images
    })
    stamped = lsb.hide(filename, payload)
    stamped.save(filename)

@torch.no_grad()
def checkin(i, losses):
    """Log the current losses and show the current image.

    The image is synthesized from the global latent `z`, saved as 'progress.png',
    tagged with metadata and displayed in the notebook.
    """
    losses_str = ', '.join([f'{loss.item():g}' for loss in losses])
    total = sum(losses).item()
    tqdm.write(f'i: {i}, max_iter: {max_iterations}, loss: {total:g}, losses: {losses_str}')

    snapshot = 'progress.png'
    out = synth(z)
    TF.to_pil_image(out[0].cpu()).save(snapshot)
    add_stegano_data(snapshot)
    add_xmp_data(snapshot)
    display.display(display.Image(snapshot))

def ascend_txt():
    """Compute the list of loss terms for the current latent and save the current frame.

    The list holds an optional MSE term against `z_orig` (when args.init_weight is
    non-zero) followed by one term per prompt in `pMs`. The synthesized image is
    written to steps/<i>.png and tagged with metadata, where `i` is the global
    iteration counter.
    """
    global i
    out = synth(z)
    iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

    result = []
    if args.init_weight:
        result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
    result.extend([prompt(iii) for prompt in pMs])

    # First image of the batch as an 8-bit array, channels moved to the last axis.
    frame = out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8)
    frame = np.transpose(frame, (1, 2, 0))
    filename = f"steps/{i:04}.png"
    imageio.imwrite(filename, np.array(frame))
    add_stegano_data(filename)
    add_xmp_data(filename)
    return result

def train(i):
    """Run one optimisation step on the latent ``z``.

    Clears the gradients, gathers the loss terms from ``ascend_txt``, shows a
    progress preview every ``args.display_freq`` steps, back-propagates the
    summed loss, steps the optimiser and finally clamps ``z`` element-wise
    between ``z_min`` and ``z_max``.

    Args:
        i: Current step number; decides whether a preview is shown.
    """
    opt.zero_grad()
    loss_terms = ascend_txt()

    is_display_step = i % args.display_freq == 0
    if is_display_step:
        checkin(i, loss_terms)

    total_loss = sum(loss_terms)
    total_loss.backward()
    opt.step()

    # The clamp is done outside autograd so it is not recorded as part of
    # the next step's graph.
    with torch.no_grad():
        clamped = z.maximum(z_min).minimum(z_max)
        z.copy_(clamped)

In [None]:
# @title Initializing

# --- User-tunable settings -------------------------------------------------
width =  512
height =  512
model = "vqgan_imagenet_f16_16384"
images_interval =  50
target_images = ""
seed = -1
input_images = ""
init_image =  "/content/playlistImage"

# Maps each supported model identifier to the display name stored in the
# image metadata.
model_names = {
    "vqgan_imagenet_f16_16384": 'ImageNet 16384',
    "vqgan_imagenet_f16_1024": "ImageNet 1024",
    "wikiart_1024": "WikiArt 1024",
    "wikiart_16384": "WikiArt 16384",
    "coco": "COCO-Stuff",
    "faceshq": "FacesHQ",
    "sflckr": "S-FLCKR",
    "ade20k": "ADE20K",
    "ffhq": "FFHQ",
    "celebahq": "CelebA-HQ",
    "gumbel_8192": "Gumbel 8192",
}
name_model = model_names[model]

# The Gumbel model exposes its codebook under different attribute names, so
# later cells branch on this flag.
is_gumbel = model == "gumbel_8192"

# -1 means "pick a random seed".
if seed == -1:
    seed = None

# `target_images` is a "|"-separated list of image paths; empty or "None"
# means there are no image prompts.
if target_images == "None" or not target_images:
    target_images = []
else:
    target_images = [image.strip() for image in target_images.split("|")]

args = argparse.Namespace(
    prompts="",
    image_prompts=target_images,
    noise_prompt_seeds=[],
    noise_prompt_weights=[],
    size=[width, height],
    init_image=init_image,
    init_weight=0.,
    clip_model='ViT-B/32',
    vqgan_config=f'{model}.yaml',
    vqgan_checkpoint=f'{model}.ckpt',
    step_size=0.1,
    cutn=64,
    cut_pow=1.,
    display_freq=images_interval,
    seed=seed,
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
seed = torch.seed() if args.seed is None else args.seed
torch.manual_seed(seed)
print('Using seed:', seed)

# Load the config that belongs to the selected model. A hard-coded config
# file name here would be paired with `args.vqgan_checkpoint` of a different
# model as soon as `model` above is changed.
# NOTE: from here on `model` is the loaded VQGAN object, no longer the
# identifier string.
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

# Generate Images


In [None]:
# @markdown The default option uses the first 5 songs of the playlist to create the visuals. If you would like to go on for the full playlist please check this box.
fullPlaylist = False #@param {type:"boolean"}

iterations_per_song = 250  # optimisation steps spent on each song
default_song_limit = 5     # songs used when the full playlist is not requested

i = 0  # global step counter; also used to name the frames in steps/

# Never ask for more songs than there are prompts, otherwise a short playlist
# would fail with an IndexError when indexing `text_prompt`.
if fullPlaylist:
    limit = len(text_prompt)
else:
    limit = min(default_song_limit, len(text_prompt))

# --- Values that depend only on the loaded models: computed once -----------
cut_size = perceptor.visual.input_resolution
if is_gumbel:
    e_dim = model.quantize.embedding_dim
    n_toks = model.quantize.n_embed
    codebook = model.quantize.embed.weight
else:
    e_dim = model.quantize.e_dim
    n_toks = model.quantize.n_e
    codebook = model.quantize.embedding.weight

f = 2**(model.decoder.num_resolutions - 1)
toksX, toksY = args.size[0] // f, args.size[1] // f
sideX, sideY = toksX * f, toksY * f

# Bounds used by train() to clamp the latent after every optimiser step.
z_min = codebook.min(dim=0).values[None, :, None, None]
z_max = codebook.max(dim=0).values[None, :, None, None]

normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

# NOTE: `args.init_image` is overwritten for every song below, so re-run the
# "Initializing" cell before re-running this one; otherwise the first song
# starts from the last frame of the previous run.
for ind in range(limit):
    print(ind)
    max_iterations = iterations_per_song * (ind + 1)
    if ind == 0:
        init_image = args.init_image
    else:
        # Continue from the frame written at the end of the previous song.
        init_image = f"steps/{ind * iterations_per_song:04}.png"

    texts = [phrase.strip() for phrase in text_prompt[ind].split("|")]
    if texts == ['']:
        texts = []

    print(texts)

    if init_image == "None":
        init_image = None
    elif init_image and init_image.lower().startswith("http"):
        init_image = download_img(init_image)

    if init_image or target_images != []:
        input_images = True

    args.prompts = texts
    args.init_image = init_image

    # Built per song; NOTE(review): could be hoisted out of the loop if
    # MakeCutouts is stateless - confirm against its definition.
    make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)

    if args.init_image:
        pil_image = Image.open(args.init_image).convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        z, *_ = model.encode(TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1)
    else:
        # No init image: start from randomly chosen codebook entries.
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        z = one_hot @ codebook
        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
    z_orig = z.clone()
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=args.step_size)

    pMs = []

    for prompt in args.prompts:
        txt, weight, stop = parse_prompt(prompt)
        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in args.image_prompts:
        path, weight, stop = parse_prompt(prompt)
        img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    # The loop variable must not be called `seed`: that would overwrite the
    # global seed that the metadata helpers write into every frame.
    for noise_seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(noise_seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))

    try:
        with tqdm() as pbar:
            while True:
                train(i)
                if i == max_iterations:
                    break
                i += 1
                pbar.update()
    except KeyboardInterrupt:
        # Stop the whole run. Moving on to the next song would try to open a
        # steps/ frame that this interrupted song never produced.
        break



# Generate The Video

In [None]:
#@markdown **Generate a video with the result (You can edit frame rate and stuff by double-clicking this tab)**
from subprocess import Popen, PIPE

init_frame = 1 #This is the frame where the video will start
last_frame = i #You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.

min_fps = 10
max_fps = 60

total_frames = last_frame-init_frame

length = 15 #Desired video time in seconds

# Fit the frames into `length` seconds, but keep the rate within sane bounds.
fps = np.clip(total_frames/length,min_fps,max_fps)

# Use a dedicated variable for the frame number: looping with `i` would
# overwrite the global step counter that `last_frame` is read from, so every
# re-run of this cell would silently drop one more frame.
frame_paths = [f"steps/{frame_number:04}.png" for frame_number in range(init_frame, last_frame)]

tqdm.write('Generating video...')
p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', 'video.mp4'], stdin=PIPE)
try:
    # Stream the frames one at a time so that only a single image file is
    # open at any moment, instead of one handle per frame.
    for frame_path in tqdm(frame_paths):
        with Image.open(frame_path) as frame:
            frame.save(p.stdin, 'PNG')
finally:
    # Always close the pipe so ffmpeg sees end-of-input and can exit, even
    # when a frame is missing.
    p.stdin.close()

print("The video is now being compressed, wait...")
return_code = p.wait()
if return_code != 0:
    raise RuntimeError(f"ffmpeg exited with status {return_code}; video.mp4 was not created correctly")
print("The video is ready")

In [None]:
#@markdown **View video in browser**

# @markdown *This process may take a little longer. If you don't want to wait, download it by executing the next cell instead of using this cell.*
# Read the file inside a context manager so the handle is closed right away
# instead of staying open until garbage collection.
with open('video.mp4', 'rb') as video_file:
    mp4 = video_file.read()

# Embed the whole video in the page as a base64 data URL.
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
display.HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

In [None]:
#@markdown #**Download the result video**
# Only works inside Google Colab: `google.colab` is imported here, and
# "video.mp4" is the file that the video-generation cell asks ffmpeg to write.
from google.colab import files
files.download("video.mp4")