-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
480f5c5
commit 1e36565
Showing
5 changed files
with
179 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Copyright 2022 Lunar Ring. All rights reserved. | ||
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
|
||
torch.backends.cudnn.benchmark = False | ||
torch.set_grad_enabled(False) | ||
import warnings | ||
|
||
warnings.filterwarnings("ignore") | ||
import warnings | ||
from latent_blending import LatentBlending | ||
from movie_util import concatenate_movies | ||
from stable_diffusion_holder import StableDiffusionHolder | ||
|
||
import argparse | ||
|
||
MAX_BRANCHES = 8 | ||
# 0:48 seconds for a 15s, total is more like 1:00 | ||
T_COMPUTE_MAX_ALLOWED = 30 | ||
INFERENCE_STEPS = 50 | ||
DURATION_TRANSITION = 2 # In seconds | ||
FPS = 60 | ||
PROMPT = "a van gogh painting" | ||
|
||
# Create the parser | ||
parser = argparse.ArgumentParser(description="Create a new instance") | ||
|
||
# Add the arguments | ||
parser.add_argument( | ||
"--prompt", | ||
type=str, | ||
help="Caption for the images", | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice. | ||
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt") | ||
fp_ckpt = "/root/.cache/torch/hub/checkpoints/v2-1_768-ema-pruned.ckpt" | ||
sdh = StableDiffusionHolder(fp_ckpt) | ||
|
||
# Specify a list of prompts below | ||
list_prompts = [ | ||
""" | ||
There's a lady who's sure all that glitters is gold | ||
""", | ||
""" | ||
And she's buying a stairway to Heaven | ||
""", | ||
""" | ||
When she gets there she knows, if the stores are all closed | ||
""", | ||
""" | ||
With a word she can get what she came for | ||
""", | ||
""" | ||
Ooh, ooh, and she's buying a stairway to Heaven | ||
""", | ||
""" | ||
There's a sign on the wall, but she wants to be sure | ||
""", | ||
""" | ||
'Cause you know sometimes words have two meanings | ||
""", | ||
] | ||
|
||
# %% Next let's set up all parameters | ||
list_seeds = [] | ||
list_prompts = [] | ||
for idx in range(10): | ||
# list_prompts.append("a beautiful high-resolution closeup photo of a human eye centered in the frame") | ||
list_prompts.append(args.prompt) | ||
list_seeds.append(idx) | ||
|
||
|
||
# Spawn latent blending | ||
lb = LatentBlending(sdh) | ||
# Dimensions must be divisible by 64 | ||
lb.set_width(768) | ||
lb.set_height(768) | ||
crossfeed_power = 0.8 | ||
crossfeed_range = ( | ||
0.1 # The crossfeed is active until 20% of num_iteration, then switched off | ||
) | ||
crossfeed_decay = 0.5 # The power of the crossfeed decreases over diffusion iterations, here it would be 0.5*0.2=0.1 in the end of the range. | ||
lb.set_parental_crossfeed(crossfeed_power, crossfeed_range, crossfeed_decay) | ||
|
||
fp_movie = f"{args.prompt.split(' ')[0]}_{DURATION_TRANSITION}s_{T_COMPUTE_MAX_ALLOWED}s_{crossfeed_power}cp_{crossfeed_range}r_{crossfeed_decay}d" | ||
|
||
list_movie_parts = [] | ||
for i in range(len(list_prompts) - 1): | ||
# For a multi transition we can save some computation time and recycle the latents | ||
if i == 0: | ||
lb.set_prompt1(list_prompts[i]) | ||
lb.set_prompt2(list_prompts[i + 1]) | ||
recycle_img1 = False | ||
else: | ||
lb.swap_forward() | ||
lb.set_prompt2(list_prompts[i + 1]) | ||
recycle_img1 = True | ||
|
||
fixed_seeds = list_seeds[i : i + 2] | ||
fp_movie_part = f"tmp_part_{str(i).zfill(3)}_{fp_movie}s.mp4" | ||
lb.run_transition( | ||
recycle_img1=recycle_img1, | ||
fixed_seeds=fixed_seeds, | ||
t_compute_max_allowed=int(T_COMPUTE_MAX_ALLOWED / 4), | ||
# nmb_max_branches=MAX_BRANCHES, | ||
num_inference_steps=INFERENCE_STEPS, | ||
) | ||
|
||
# Save movie | ||
lb.write_movie_transition(fp_movie_part, DURATION_TRANSITION, fps=FPS) | ||
list_movie_parts.append(fp_movie_part) | ||
|
||
# Finally, concatenate the result | ||
concatenate_movies(f"outputs/{fp_movie}.mp4", list_movie_parts) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule clip
deleted from
a1d071
Submodule taming-transformers
deleted from
3ba01b
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import json | ||
import os | ||
from utils import get_rest_response | ||
|
||
# API_HOST = os.environ.get('TEST_API') | ||
API_HOST = "http://10.0.0.31:8000/api/" | ||
|
||
|
||
def test_get_config(): | ||
endpoint = "brokers/" | ||
query = "?name=edge" | ||
response = get_rest_response( | ||
"/cloud", {"hostname": API_HOST, "endpoint": endpoint, "query": query} | ||
) | ||
print(response) | ||
|
||
|
||
def test_get_config_name(): | ||
endpoint = "/config/mqtt" | ||
query = "?name=hivemq" | ||
response = get_rest_response( | ||
"/cloud", {"hostname": API_HOST, "endpoint": endpoint, "query": query} | ||
) | ||
assert response["code"] == 0 |