Skip to content

Commit

Permalink
Merge branch 'SebastianAigner-main' into development
Browse files Browse the repository at this point in the history
Add support for full CORS headers for dream server.
  • Loading branch information
lstein committed Sep 4, 2022
2 parents 5116c81 + fd7a72e commit 4406fd1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
11 changes: 11 additions & 0 deletions ldm/dream/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class DreamServer(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == "/":
self.send_response(200)
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
self.send_header("Content-type", "text/html")
self.end_headers()
with open("./static/dream_web/index.html", "rb") as content:
Expand All @@ -33,6 +35,8 @@ def do_GET(self):
elif self.path == "/cancel":
self.canceled.set()
self.send_response(200)
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(bytes('{}', 'utf8'))
Expand All @@ -55,6 +59,8 @@ def do_GET(self):

def do_POST(self):
self.send_response(200)
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
self.send_header("Content-type", "application/json")
self.end_headers()

Expand Down Expand Up @@ -199,6 +205,11 @@ def image_progress(sample, step):
print(f"Canceled.")
return

def do_OPTIONS(self):
self.send_response(200)
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
self.end_headers()

class ThreadingDreamServer(ThreadingHTTPServer):
def __init__(self, server_address):
Expand Down
16 changes: 8 additions & 8 deletions ldm/simplet2i.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def process_image(image,seed):
0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]'

if len(with_variations) > 0 or variation_amount > 1.0:
if len(with_variations) > 0 or variation_amount > 0.0:
assert seed is not None,\
'seed must be specified when using with_variations'
if variation_amount == 0.0:
Expand Down Expand Up @@ -346,6 +346,7 @@ def process_image(image,seed):
callback=step_callback,
)
else:
init_latent = None
make_image = self._txt2img(
prompt,
steps=steps,
Expand All @@ -361,11 +362,11 @@ def process_image(image,seed):
if variation_amount > 0 or len(with_variations) > 0:
# use fixed initial noise plus random noise per iteration
seed_everything(seed)
initial_noise = self._get_noise(init_img,width,height)
initial_noise = self._get_noise(init_latent,width,height)
for v_seed, v_weight in with_variations:
seed = v_seed
seed_everything(seed)
next_noise = self._get_noise(init_img,width,height)
next_noise = self._get_noise(init_latent,width,height)
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
if variation_amount > 0:
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
Expand All @@ -377,17 +378,16 @@ def process_image(image,seed):
x_T = None
if variation_amount > 0:
seed_everything(seed)
target_noise = self._get_noise(init_img,width,height)
target_noise = self._get_noise(init_latent,width,height)
x_T = self.slerp(variation_amount, initial_noise, target_noise)
elif initial_noise is not None:
# i.e. we specified particular variations
x_T = initial_noise
else:
seed_everything(seed)
if self.device.type == 'mps':
x_T = self._get_noise(init_img,width,height)
x_T = self._get_noise(init_latent,width,height)
# make_image will do the equivalent of get_noise itself
print(f' DEBUG: seed at make_image() invocation time ={seed}')
image = make_image(x_T)
results.append([image, seed])
if image_callback is not None:
Expand Down Expand Up @@ -617,8 +617,8 @@ def load_model(self):
return self.model

# returns a tensor filled with random numbers from a normal distribution
def _get_noise(self,init_img,width,height):
if init_img:
def _get_noise(self,init_latent,width,height):
if init_latent is not None:
if self.device.type == 'mps':
return torch.randn_like(init_latent, device='cpu').to(self.device)
else:
Expand Down

0 comments on commit 4406fd1

Please sign in to comment.