Skip to content

Commit

Permalink
Dreamin
Browse files Browse the repository at this point in the history
  • Loading branch information
JimothyJohn committed Aug 5, 2023
1 parent 480f5c5 commit 1e36565
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 20 deletions.
130 changes: 130 additions & 0 deletions dreaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import warnings

warnings.filterwarnings("ignore")
import warnings
from latent_blending import LatentBlending
from movie_util import concatenate_movies
from stable_diffusion_holder import StableDiffusionHolder

import argparse

# Script-wide settings.
# MAX_BRANCHES is only referenced by a commented-out argument further down.
MAX_BRANCHES = 8
# 0:48 seconds for a 15s, total is more like 1:00
# NOTE(review): presumably a compute-time budget in seconds; the transition
# loop passes a quarter of this value to run_transition -- confirm the unit.
T_COMPUTE_MAX_ALLOWED = 30
INFERENCE_STEPS = 50
DURATION_TRANSITION = 2  # In seconds
FPS = 60
# Fallback caption used when --prompt is not given on the command line.
PROMPT = "a van gogh painting"

# Create the parser
parser = argparse.ArgumentParser(description="Create a new instance")

# Add the arguments
# Without a default, omitting --prompt leaves args.prompt as None and the
# script later fails on args.prompt.split(...); fall back to PROMPT instead.
parser.add_argument(
    "--prompt",
    type=str,
    default=PROMPT,
    help="Caption for the images",
)

args = parser.parse_args()

# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
# Hard-coded absolute path to the checkpoint file handed to StableDiffusionHolder.
# NOTE(review): presumably this file must already exist at this location
# (nothing in this script downloads it) -- confirm for your environment.
fp_ckpt = "/root/.cache/torch/hub/checkpoints/v2-1_768-ema-pruned.ckpt"
sdh = StableDiffusionHolder(fp_ckpt)

# Specify a list of prompts below
# NOTE(review): list_prompts is rebound to a fresh list in the parameter
# set-up section that follows and refilled from args.prompt, so the entries
# below are not used by the rest of this script as written.
list_prompts = [
"""
There's a lady who's sure all that glitters is gold
""",
"""
And she's buying a stairway to Heaven
""",
"""
When she gets there she knows, if the stores are all closed
""",
"""
With a word she can get what she came for
""",
"""
Ooh, ooh, and she's buying a stairway to Heaven
""",
"""
There's a sign on the wall, but she wants to be sure
""",
"""
'Cause you know sometimes words have two meanings
""",
]

# %% Next let's set up all parameters
# Ten keyframes: every keyframe uses the command-line prompt, and the seed of
# each keyframe is simply its index (0..9).
# list_prompts.append("a beautiful high-resolution closeup photo of a human eye centered in the frame")
list_seeds = list(range(10))
list_prompts = [args.prompt for _ in list_seeds]


# Spawn latent blending
lb = LatentBlending(sdh)
# Dimensions must be divisible by 64
lb.set_width(768)
lb.set_height(768)

# Parental crossfeed settings (passed to set_parental_crossfeed below).
crossfeed_power = 0.8
# The crossfeed is active until 10% of num_iteration, then switched off
crossfeed_range = 0.1
# The power of the crossfeed decreases over diffusion iterations.
# NOTE(review): the earlier comment quoted "0.5*0.2=0.1", which did not match
# the values used here; with power 0.8 and decay 0.5 the end-of-range power
# would presumably be 0.8*0.5=0.4 -- confirm against LatentBlending.
crossfeed_decay = 0.5
lb.set_parental_crossfeed(crossfeed_power, crossfeed_range, crossfeed_decay)

# Base name for the output movie: first word of the prompt plus the settings.
fp_movie = f"{args.prompt.split(' ')[0]}_{DURATION_TRANSITION}s_{T_COMPUTE_MAX_ALLOWED}s_{crossfeed_power}cp_{crossfeed_range}r_{crossfeed_decay}d"

# Render one movie part per consecutive pair of prompts, then join the parts.
list_movie_parts = []
for i in range(len(list_prompts) - 1):
    # For a multi transition we can save some computation time and recycle the latents
    recycle_img1 = i > 0
    if recycle_img1:
        # Later transitions: swap forward instead of setting prompt 1 again.
        lb.swap_forward()
    else:
        # First transition: both prompts have to be set explicitly.
        lb.set_prompt1(list_prompts[i])
    lb.set_prompt2(list_prompts[i + 1])

    # Seeds for the two keyframes of this transition.
    fixed_seeds = list_seeds[i : i + 2]
    fp_movie_part = f"tmp_part_{i:03d}_{fp_movie}s.mp4"
    lb.run_transition(
        recycle_img1=recycle_img1,
        fixed_seeds=fixed_seeds,
        t_compute_max_allowed=int(T_COMPUTE_MAX_ALLOWED / 4),
        # nmb_max_branches=MAX_BRANCHES,
        num_inference_steps=INFERENCE_STEPS,
    )

    # Save movie
    lb.write_movie_transition(fp_movie_part, DURATION_TRANSITION, fps=FPS)
    list_movie_parts.append(fp_movie_part)

# Finally, concatenate the result
concatenate_movies(f"outputs/{fp_movie}.mp4", list_movie_parts)
43 changes: 25 additions & 18 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
from huggingface_hub import hf_hub_download
from cog import BasePredictor, Input, Path
from stable_diffusion_holder import StableDiffusionHolder
from movie_util import concatenate_movies
Expand All @@ -23,23 +21,27 @@

class Predictor(BasePredictor):
def setup(self) -> None:
# First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1",filename="v2-1_768-ema-pruned.ckpt",)
# Load checkpoint from pre-downloaded location
# TODO: Evaluate fastest approach to do this
fp_ckpt = "/root/.cache/torch/hub/checkpoints/v2-1_768-ema-pruned.ckpt"

# Load stable diffusion model from pre-downloaded checkpoint
self.sdh = StableDiffusionHolder(fp_ckpt)
# Spawn latent blending
self.lb = LatentBlending(self.sdh)
# Dimensions must be divisible by 64
# Dimensions must match model fine-tuning
self.lb.set_width(768)
self.lb.set_height(768)

def predict(
self,
caption1: str = Input(description="Image caption to start with"),
caption2: str = Input(description="Image caption to end with"),
caption: str = Input(description="Image to dream"),
# caption2: str = Input(description="Image caption to end with"),
transition_time: int = Input(
description="Length of time to transition", ge=1, le=60, default=5
description="Length of time to transition", ge=1, le=5, default=2
),
transitions: int = Input(
description="Number of images to drem", ge=2, le=32, default=5
),
compute_time: float = Input(
description="Max compute time per transition", ge=10, le=300, default=15
Expand All @@ -48,13 +50,13 @@ def predict(
description="This fraction of the latents in the last branch are copied from the parents",
ge=0.1,
le=0.9,
default=0.5,
default=0.8,
),
crossfeed_range: float = Input(
description="Crossfeed is active until this fraction of num_iteration",
ge=0.1,
le=0.9,
default=0.5,
default=0.8,
),
crossfeed_decay: float = Input(
description="The power of the crossfeed decreases over diffusion iterations",
Expand All @@ -63,13 +65,19 @@ def predict(
default=0.5,
),
) -> Path:
"""Run a single prediction on the model"""
# processed_input = preprocess(image)
# output = self.model(processed_image, scale)
# return postprocess(output)
# %% Next let's set up all parameters
list_seeds = [0, 1]
list_prompts = [caption1, caption2]
# TODO: Add support for user-selected seeds
list_seeds = []
list_prompts = []
"""
if caption2 != "":
list_seeds = [0, 1]
list_prompts = [caption1, caption2]
else:
"""
for i in range(transitions):
list_seeds.append(i)
list_prompts.append(caption)

self.lb.set_parental_crossfeed(
crossfeed_power, crossfeed_range, crossfeed_decay
)
Expand All @@ -92,7 +100,6 @@ def predict(
recycle_img1=recycle_img1,
fixed_seeds=fixed_seeds,
t_compute_max_allowed=compute_time,
# nmb_max_branches=MAX_BRANCHES,
num_inference_steps=INFERENCE_STEPS,
)

Expand Down
1 change: 0 additions & 1 deletion src/clip
Submodule clip deleted from a1d071
1 change: 0 additions & 1 deletion src/taming-transformers
Submodule taming-transformers deleted from 3ba01b
24 changes: 24 additions & 0 deletions tests/test_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json
import os
from utils import get_rest_response

# Base URL of the API under test. Set the TEST_API environment variable to
# point the tests at another host; when it is unset or empty, the previously
# hard-coded address is used.
API_HOST = os.environ.get("TEST_API") or "http://10.0.0.31:8000/api/"


def test_get_config():
    """Request the "brokers/" endpoint filtered by name "edge" and print the reply.

    The response is only printed; nothing about it is asserted.
    """
    request = {
        "hostname": API_HOST,
        "endpoint": "brokers/",
        "query": "?name=edge",
    }
    print(get_rest_response("/cloud", request))


def test_get_config_name():
    """Request "/config/mqtt" filtered by name "hivemq" and expect code 0."""
    request = {
        "hostname": API_HOST,
        "endpoint": "/config/mqtt",
        "query": "?name=hivemq",
    }
    reply = get_rest_response("/cloud", request)
    assert reply["code"] == 0

0 comments on commit 1e36565

Please sign in to comment.