<a href="https://colab.research.google.com/github/keigoyoshida7/Stable-diffusion/blob/main/keigo_yoshida'22_Stable_Diffusion_txt_to_album.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Keigo Yoshida '22 Stable Diffusion txt to album



# Check the GPU runtime

In [None]:
# Print the NVIDIA driver and GPU status of the current runtime.
!nvidia-smi

# Setup Environment for disc jacket

In [None]:
# @title install library
dependencies = [
  "albumentations==0.4.3",
  "diffusers==0.7.1",
  "opencv-python==4.1.2.30",
  "pudb==2019.2",
  "imageio==2.9.0",
  "imageio-ffmpeg==0.4.2",
  "pytorch-lightning==1.6.5",
  "omegaconf==2.1.1",
  "test-tube>=0.7.5",
  "streamlit>=0.73.1",
  "einops==0.3.0",
  "torch-fidelity==0.3.0",
  "transformers==4.19.2",
  "kornia==0.6.8",
]
dep_list = " ".join(map(lambda dep: f"\"{dep}\"",dependencies))
print(f"installing {dep_list}")
!pip install {dep_list}
!pip install -e "git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers"
!pip install -e "git+https://github.com/openai/CLIP.git@main#egg=clip"
!pip install "invisible-watermark==0.1.5"

print("finish!")

In [None]:
# NOTE(review): this points pip at a local `taming-transformers` directory relative to the
# current working directory. The previous cell already installs the same package from git,
# so this looks redundant and presumably fails if that directory does not exist - confirm
# whether this cell is still needed.
!pip install -e taming-transformers

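## Restart the runtime (ランタイムの再起動)

Menu > Runtime > Restart runtime (メニュー > ランタイム > ランタイムを再起動), then continue from the next cell so that the packages installed above are picked up.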

In [None]:
# @title install stable-diffusion
# Clone the CompVis stable-diffusion repository into /content and install it in editable
# mode; it presumably provides the `ldm` package imported by the later cells - confirm.
# NOTE(review): when this cell is re-run, `git clone` reports that the target directory
# already exists; the editable install below still runs.
%cd /content/
!git clone https://github.com/CompVis/stable-diffusion.git
%cd /content/stable-diffusion
!pip install -e .


In [None]:
# @title download model
# https://huggingface.co/CompVis/stable-diffusion-v1-4
# Download the checkpoint from Google Drive into /content.
# NOTE(review): the model-loading cell reads /content/sd-v1-4.ckpt, so the downloaded file
# is presumably saved under that name - confirm, and verify that the Drive copy matches
# the Hugging Face release linked above.
%cd /content
!gdown "https://drive.google.com/uc?export=download&id=12Hzk3DEzGCmI2ha4XrSCLypz7ia9mn9w"

# execution

In [None]:
# @title import module

%cd /content/stable-diffusion

import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from imwatermark import WatermarkEncoder
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

from IPython.display import Image as IPImage, display


In [None]:
# @title define functions etc

# Load the feature extractor and the safety checker that check_safety() uses to screen
# generated images.
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

def chunk(it, size):
    """Lazily split ``it`` into consecutive tuples of at most ``size`` items.

    The last tuple may be shorter than ``size``; iteration ends as soon as an
    empty tuple would be produced.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if batch == ():
                return
            yield batch

    return _batches()


def numpy_to_pil(images):
    """
    Turn one numpy image, or a batch of them, into a list of PIL images.

    A 3-dimensional input is treated as a single image and wrapped into a batch
    of one. Values are multiplied by 255, rounded and cast to uint8 first.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_uint8 = (batch * 255).round().astype("uint8")

    pil_images = []
    for frame in as_uint8:
        pil_images.append(Image.fromarray(frame))
    return pil_images


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: Configuration object whose ``model`` entry is handed to
            ``instantiate_from_config``.
        ckpt: Path of a checkpoint file that contains a ``"state_dict"`` entry.
        verbose: When true, print the missing and unexpected state-dict keys.

    Returns:
        The model in eval mode, moved to the GPU when CUDA is available.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints obtained from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; they are only reported when verbose is set.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    # Only move to the GPU when one exists, so a CPU-only runtime does not crash here.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


def put_watermark(img, wm_encoder=None):
    """Embed the encoder's watermark into ``img``; return ``img`` untouched without an encoder.

    The image is converted from RGB to BGR for the encoder (method ``'dwtDct'``)
    and the channel order is flipped back before building the resulting PIL image.
    """
    if wm_encoder is None:
        return img

    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    marked = wm_encoder.encode(bgr, 'dwtDct')
    return Image.fromarray(marked[:, :, ::-1])


def load_replacement(x):
    """Return the bundled replacement picture resized to match ``x``.

    The picture at ``assets/rick.jpeg`` is loaded as RGB, resized to the first
    two dimensions of ``x`` and scaled to the 0..1 range in ``x``'s dtype.
    This is best effort: on any failure ``x`` itself is returned.
    """
    try:
        shape = x.shape
        picture = Image.open("assets/rick.jpeg").convert("RGB")
        # PIL expects (width, height), i.e. the reverse of the array's leading axes.
        picture = picture.resize((shape[1], shape[0]))
        scaled = (np.array(picture) / 255.0).astype(x.dtype)
        assert scaled.shape == x.shape
    except Exception:
        return x
    return scaled


def check_safety(x_image):
    """Run the safety checker on a batch and swap flagged images for a replacement.

    Args:
        x_image: Batch of images as a numpy array (assumed float values in 0..1,
            as expected by numpy_to_pil - TODO confirm).

    Returns:
        ``(x_checked_image, has_nsfw_concept)`` as returned by the safety checker,
        with every flagged entry replaced via ``load_replacement``.
    """
    pil_batch = numpy_to_pil(x_image)
    clip_features = safety_feature_extractor(pil_batch, return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image,
        clip_input=clip_features.pixel_values,
    )
    assert x_checked_image.shape[0] == len(has_nsfw_concept)

    for idx, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[idx] = load_replacement(x_checked_image[idx])
    return x_checked_image, has_nsfw_concept

# Load the v1 inference configuration and build the model from the downloaded checkpoint.
config = OmegaConf.load("/content/stable-diffusion/configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "/content/sd-v1-4.ckpt")

# Pick the device that later cells also use (e.g. for the fixed start code) and make sure
# the model lives on it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)


In [None]:
#@title **Setup Environment for music**

import subprocess, time

print("Setting up environment...")
start_time = time.time()
# Each entry is the argument vector of one pip invocation; they run in order.
all_process = [
    ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
    ['pip', 'install', '-U', 'sentence-transformers'],
    ['pip', 'install', 'httpx'],
]
failed_commands = []
for process in all_process:
    # stdout is captured to keep the cell output short; it is shown only on failure.
    completed = subprocess.run(process, stdout=subprocess.PIPE)
    running = completed.stdout.decode('utf-8')
    if completed.returncode != 0:
        # Surface the failure instead of silently continuing with a broken environment.
        failed_commands.append(process)
        print(f"Command failed (exit code {completed.returncode}): {' '.join(process)}")
        print(running)

end_time = time.time()
elapsed = end_time - start_time
if failed_commands:
    print(f"Finished with {len(failed_commands)} failed command(s) after {elapsed:.0f} seconds")
else:
    print(f"Environment set up in {elapsed:.0f} seconds")


In [None]:
#@title **Define Mubert methods and pre-compute things**

import numpy as np
from sentence_transformers import SentenceTransformer
# Sentence-embedding model used to embed both the Mubert tag list and the user prompts.
minilm = SentenceTransformer('all-MiniLM-L6-v2')

mubert_tags_string = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco 
house,electronic'
# Split the comma-separated tag string into an array and pre-compute one embedding per
# tag, so prompts can later be matched against the tags without re-encoding them.
mubert_tags = np.array(mubert_tags_string.split(','))
mubert_tags_embeddings = minilm.encode(mubert_tags)

from IPython.display import Audio, display
import httpx
import json

def get_track_by_tags(tags, pat, duration, maxit=20, autoplay=False, loop=False):
    """Request a generated track from Mubert and play it inline once it can be downloaded.

    Args:
        tags: Tags sent as the ``tags`` parameter of the RecordTrackTTM request.
        pat: Personal access token sent as ``pat``.
        duration: Value sent as ``duration`` (presumably seconds - confirm against
            the Mubert API).
        maxit: Maximum number of polls of the download link, one second apart.
        autoplay: Passed through to ``IPython.display.Audio``.
        loop: Request mode ``"loop"`` when true, ``"track"`` otherwise.

    Raises:
        AssertionError: If the API response status is not 1; the message is the
            error text returned by the API.
    """
    mode = "loop" if loop else "track"
    r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
        json={
            "method":"RecordTrackTTM",
            "params": {
                "pat": pat,
                "duration": duration,
                "tags": tags,
                "mode": mode
            }
        })

    rdata = json.loads(r.text)
    assert rdata['status'] == 1, rdata['error']['text']
    trackurl = rdata['data']['tasks'][0]['download_link']

    # Poll the download link until it answers with HTTP 200.
    print('Generating track ', end='')
    for _ in range(maxit):
        r = httpx.get(trackurl)
        if r.status_code == 200:
            display(Audio(trackurl, autoplay=autoplay))
            break
        time.sleep(1)
        print('.', end='')
    else:
        # The link never became available: report it instead of ending silently.
        print(f'\nTrack was not ready after {maxit} attempts: {trackurl}')

def find_similar(em, embeddings, method='cosine'):
    """Score every row of ``embeddings`` against ``em`` and rank the rows, closest first.

    Args:
        em: Query vector.
        embeddings: Iterable of reference vectors with the same length as ``em``.
        method: ``'cosine'`` for cosine distance (1 - cosine similarity) or
            ``'norm'`` for the Euclidean distance.

    Returns:
        ``(scores, order)`` where ``scores[i]`` is the distance of row ``i`` and
        ``order`` is ``np.argsort(scores)``, i.e. row indices from smallest to
        largest distance.

    Raises:
        ValueError: If ``method`` is not a supported name. Previously this case
            silently produced empty results.
    """
    if method == 'cosine':
        # The query norm does not depend on the row, so compute it once.
        em_norm = np.linalg.norm(em)
        scores = [
            1 - np.dot(ref, em) / (np.linalg.norm(ref) * em_norm)
            for ref in embeddings
        ]
    elif method == 'norm':
        scores = [np.linalg.norm(ref - em) for ref in embeddings]
    else:
        raise ValueError(f"unknown method {method!r}; expected 'cosine' or 'norm'")
    return np.array(scores), np.argsort(scores)

def get_tags_for_prompts(prompts, top_n=3, debug=False):
    """Pick the ``top_n`` Mubert tags whose embeddings are closest to each prompt.

    Args:
        prompts: Sequence of prompt strings to embed.
        top_n: Number of tags to keep per prompt.
        debug: When true, print each prompt with its tags and similarity scores.

    Returns:
        A list of ``(prompt, tags)`` pairs, ``tags`` being a list of tag strings.
    """
    embedded_prompts = minilm.encode(prompts)
    ret = []
    for i, prompt_vec in enumerate(embedded_prompts):
        distances, ranking = find_similar(prompt_vec, mubert_tags_embeddings)
        best = ranking[:top_n]
        top_tags = mubert_tags[best]
        if debug:
            # find_similar returns distances; convert back to similarities for display.
            top_prob = 1 - distances[best]
            print(f"Prompt: {prompts[i]}\nTags: {', '.join(top_tags)}\nScores: {top_prob}\n\n\n")
        ret.append((prompts[i], list(top_tags)))
    return ret

In [None]:
#@markdown **Get personal access token in Mubert and define API methods**
email = "sample@gmail.com" #@param {type:"string"}

# Exchange the e-mail address for a personal access token (`pat`), which is later passed
# to get_track_by_tags().
r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess', 
    json={
        "method":"GetServiceAccess",
        "params": {
            "email": email,
            # NOTE(review): license and token are hard-coded; presumably these are shared
            # demo credentials - confirm they are meant to be public.
            "license":"ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5",
            "token":"4951f6428e83172a4f39de05d5b3ab10d58560b8",
            "mode": "loop"
        }
    })

rdata = json.loads(r.text)
# Any non-success status is reported as an e-mail problem; the API's own error text is
# not shown here.
assert rdata['status'] == 1, "probably incorrect e-mail"
pat = rdata['data']['pat']
# NOTE(review): printing the token leaves it in the saved notebook output.
print(f'Got token: {pat}')

In [None]:
# @title parameter settings { run: "auto" }

# @markdown ### basic

# `prompt` drives both the image generation and the music; `duration` and `loop` are
# passed to generate_track_by_prompt() by the generation cell.
prompt = "wings of freedom" # @param { type: "string" }
duration = 90 #@param {type:"number"}
loop = False #@param {type:"boolean"}

def generate_track_by_prompt(prompt, duration, loop=False):
    """Choose Mubert tags for ``prompt`` and play a track generated from them.

    Errors raised while requesting or fetching the track are printed rather than
    propagated. A blank line is printed at the end in every case.
    """
    matches = get_tags_for_prompts([prompt])
    tags = matches[0][1]
    try:
        get_track_by_tags(tags, pat, duration, autoplay=True, loop=loop)
    except Exception as err:
        print(err)
    print('\n')

# Output directory; individual images go into its `samples` sub-directory.
outdir = '/content/results' # @param { type: "string" }
width = 512 # @param { type: "integer" }
height = 512 # @param { type: "integer" }
# Number of images per batch (used as the batch size by the generation cell).
n_samples = 1 # @param { type: "integer", min: 1 }
seed = 43 # @param { type: "integer" }
# Guidance scale; the generation cell skips the empty-prompt conditioning when it is 1.0.
scale = 7.5 # @param { type: "number", min: 0.0 }

# @markdown ----
# @markdown ### advanced

# Number of sampling rounds.
n_iter = 1 # @param { type: "integer" }
skip_grid = False # @param { type: "boolean" }
skip_save = False # @param { type: "boolean" }
ddim_steps = 50 # @param { type: "integer" }
# Use the PLMS sampler instead of DDIM.
plms = False # @param { type: "boolean" }
# Draw the starting noise once and reuse it for every iteration.
fixed_code = False # @param { type: "boolean" }
ddim_eta = 0.0 # @param { type: "number" }
# The latent shape is [latent_channels, height // downsampling_factor, width // downsampling_factor].
latent_channels = 4 # @param { type: "integer", min: 1 }
downsampling_factor = 8 # @param { type: "integer", min: 1 }
# 0 means "use the batch size" when the grid is assembled.
n_rows = 0 # @param { type: "integer", min: 0 }
precision = "autocast" # @param [ "full", "autocast" ]


In [None]:
# @title drop the album

# Release cached GPU memory before sampling.
torch.cuda.empty_cache()

# Seed the random number generators through pytorch_lightning's seed_everything.
seed_everything(seed)

# Choose the sampler: PLMS when requested, DDIM otherwise.
if plms:
    sampler = PLMSSampler(model)
else:
    sampler = DDIMSampler(model)

os.makedirs(outdir, exist_ok=True)
outpath = outdir

# Byte string that put_watermark() embeds into every saved image.
wm = "DDA 2022"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

batch_size = n_samples
# n_rows == 0 falls back to the batch size.
# NOTE(review): this overwrites the form value, so a later run sees the resolved number
# unless the parameter cell is executed again.
n_rows = n_rows if n_rows > 0 else batch_size
assert prompt is not None
# One batch that contains the same prompt repeated batch_size times.
data = [ batch_size * [prompt] ]

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
# Continue the file numbering from what is already on disk.
base_count = len(os.listdir(sample_path))
# The -1 presumably discounts the `samples` sub-directory - confirm.
grid_count = len(os.listdir(outpath)) - 1

# With fixed_code the starting latent is drawn once and reused for every iteration.
start_code = None
if fixed_code:
    start_code = torch.randn([ n_samples, latent_channels, height // downsampling_factor, width // downsampling_factor ], device=device)

# "autocast" samples under torch.autocast; any other value uses a no-op context.
precision_scope = autocast if precision=="autocast" else nullcontext
with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            # NOTE(review): tic/toc are recorded but never reported.
            tic = time.time()
            all_samples = list()
            for n in trange(n_iter, desc="Sampling"):
                for prompts in tqdm(data, desc="data"):
                    # The empty-prompt conditioning is only computed when the guidance
                    # scale differs from 1.
                    uc = None
                    if scale != 1.0:
                        uc = model.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = model.get_learned_conditioning(prompts)
                    shape = [ latent_channels, height // downsampling_factor, width // downsampling_factor ]
                    samples_ddim, _ = sampler.sample(S = ddim_steps,
                                                      conditioning = c,
                                                      batch_size = n_samples,
                                                      shape = shape,
                                                      verbose = False,
                                                      unconditional_guidance_scale = scale,
                                                      unconditional_conditioning = uc,
                                                      eta = ddim_eta,
                                                      x_T = start_code)

                    # Decode the latents, rescale from [-1, 1] to [0, 1] and move the
                    # channel axis last before handing the batch to the safety checker.
                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)

                    # Back to a channel-first tensor for saving and for the grid.
                    x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

                    if not skip_save:
                        # Save every image of the batch as a numbered, watermarked PNG.
                        for x_sample in x_checked_image_torch:
                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                            img = Image.fromarray(x_sample.astype(np.uint8))
                            img = put_watermark(img, wm_encoder)
                            img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                            base_count += 1

                    if not skip_grid:
                        all_samples.append(x_checked_image_torch)

            if not skip_grid:
                # additionally, save as grid
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                img = Image.fromarray(grid.astype(np.uint8))
                img = put_watermark(img, wm_encoder)
                save_path = os.path.join(outpath, f'grid-{grid_count:04}.png')
                img.save(save_path)
                # Show the saved grid inline in the notebook.
                display(IPImage(save_path))
                grid_count += 1

            toc = time.time()

# Finally generate and play music for the same prompt.
generate_track_by_prompt(prompt, duration, loop)




# Set up environment for music video

# Install library

In [None]:
%cd /content

# NOTE(review): this install is unpinned, unlike the dependency list at the top of the
# notebook; consider pinning a version for reproducibility.
!pip install stable_diffusion_videos[realesrgan]

# Hugging face access token

In [None]:
# NOTE(security): never commit a real token in this field. A token that was previously
# committed here must be treated as leaked and revoked.
access_tokens = "" # @param {type:"string"}
if not access_tokens:
    import os
    from getpass import getpass
    # Prefer the HF_TOKEN environment variable; otherwise prompt without echoing the input.
    access_tokens = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")

# Import library

In [None]:
import torch
from stable_diffusion_videos import StableDiffusionWalkPipeline, Interface
from IPython.display import HTML
from base64 import b64encode

# Loading model

In [None]:
model_id = "CompVis/stable-diffusion-v1-4"

# Load the half-precision weights (the `fp16` revision, float16 dtype) and move the
# pipeline to the GPU; the access token authenticates the download.
pipeline = StableDiffusionWalkPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    revision="fp16",
    use_auth_token=access_tokens
).to("cuda")

# Wrap the pipeline in the package's Interface, which a later cell launches.
interface = Interface(pipeline)

# Define utility function

In [None]:
def visualize_video_colab(video_path):
  """Return an HTML ``<video>`` element that plays the MP4 file inline.

  The whole file is read and embedded as a base64 ``data:`` URL, so the
  rendered output grows with the size of the video.

  Args:
    video_path: Path of the MP4 file to embed.

  Returns:
    An ``IPython.display.HTML`` object containing the video player.
  """
  # The context manager closes the file handle even if read() fails.
  with open(video_path, 'rb') as video_file:
    mp4 = video_file.read()
  data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
  return HTML("""
    <video width=400 controls>
        <source src="%s" type="video/mp4">
    </video>
  """ % data_url)

# Drop music video

In [None]:
# NOTE(review): presumably this starts the interactive UI of stable_diffusion_videos and,
# with debug=True, keeps the cell running until it is interrupted - confirm. The cells
# below do not depend on it.
interface.launch(debug=True)

# Music video settings

In [None]:
# @markdown src prompt
prompt1 = "take my breath" # @param {type:"string"}
# @markdown dst prompt
prompt2 = "just the way god made you" # @param {type:"string"}

# @markdown seeds
seed1 = 12 #@param {type:"integer"}
seed2 = 1212 #@param {type:"integer"}

# @markdown FPS
fps = 10 #@param {type:"integer"}

# @markdown number of interpolation steps
num_interpolation_steps = 10 #@param {type:"integer"}

# @markdown video size
# NOTE: `height` and `width` reuse the names from the image parameter cell, so running
# this cell also changes the size used by the image generation cell.
height = 512 #@param {type:"integer"}
width = 512 #@param {type:"integer"}

# Music video inference

In [None]:
# Generate the video from prompt1/seed1 to prompt2/seed2; the returned value is passed
# to visualize_video_colab() as the path of the video file.
video_path = pipeline.walk(
    [prompt1, prompt2],
    [seed1, seed2],
    fps=fps,
    num_interpolation_steps=num_interpolation_steps,
    # Form fields may deliver non-int numbers, hence the explicit casts.
    height=int(height),
    width=int(width),
)

In [None]:
# Display the generated video inline (last expression of the cell is rendered).
visualize_video_colab(video_path)

In [None]:
# Archive the default output directory for download.
# NOTE(review): the path is hard-coded, so a changed `outdir` is not picked up here.
!zip -r /content/results.zip /content/results