Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ldm/models/diffusion/ksampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def sample(
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
def route_callback(k_callback_values):
    """Adapt k-diffusion's per-step callback dict to img_callback(image, step).

    k-diffusion invokes its callback with a dict; the caller-supplied
    img_callback expects the decoded latent ('x') and the step index ('i').
    """
    if img_callback is None:
        return
    img_callback(k_callback_values['x'], k_callback_values['i'])

sigmas = self.model.get_sigmas(S)
if x_T:
Expand All @@ -78,7 +81,8 @@ def sample(
}
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args
model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback
),
None,
)
13 changes: 12 additions & 1 deletion ldm/simplet2i.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def prompt2image(
ddim_eta=None,
skip_normalize=False,
image_callback=None,
step_callback=None,
# these are specific to txt2img
width=None,
height=None,
Expand All @@ -228,9 +229,14 @@ def prompt2image(
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
step_callback // a function or method that will be called each step
image_callback // a function or method that will be called each time an image is generated

To use the callback, define a function of method that receives two arguments, an Image object
To use the step callback, define a function that receives two arguments:
- Image GPU data
- The step number

To use the image callback, define a function or method that receives two arguments, an Image object
and the seed. You can then do whatever you like with the image, including converting it to
different formats and manipulating it. For example:

Expand Down Expand Up @@ -285,6 +291,7 @@ def process_image(image,seed):
skip_normalize=skip_normalize,
init_img=init_img,
strength=strength,
callback=step_callback,
)
else:
images_iterator = self._txt2img(
Expand All @@ -297,6 +304,7 @@ def process_image(image,seed):
skip_normalize=skip_normalize,
width=width,
height=height,
callback=step_callback,
)

with scope(self.device.type), self.model.ema_scope():
Expand Down Expand Up @@ -348,6 +356,7 @@ def _txt2img(
skip_normalize,
width,
height,
callback,
):
"""
An infinite iterator of images from the prompt.
Expand All @@ -371,6 +380,7 @@ def _txt2img(
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
img_callback=callback
)
yield self._samples_to_images(samples)

Expand All @@ -386,6 +396,7 @@ def _img2img(
skip_normalize,
init_img,
strength,
callback, # Currently not implemented for img2img
):
"""
An infinite iterator of images from the prompt and the initial image
Expand Down
73 changes: 47 additions & 26 deletions scripts/dream_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

print("Loading model...")
from ldm.simplet2i import T2I
from ldm.dream.pngwriter import PngWriter
model = T2I(sampler_name='k_lms')

# to get rid of annoying warning messages from pytorch
Expand Down Expand Up @@ -56,44 +57,64 @@ def do_POST(self):

print(f"Request to generate with prompt: {prompt}")

outputs = []
def image_done(image, seed):
    """Handle a finished image: write it to disk, log it, and stream a
    newline-delimited JSON 'result' event back to the HTTP client."""
    # Shallow copy is sufficient — only top-level keys are replaced below.
    config = post_data.copy()
    config['initimg'] = ''  # don't echo the (potentially large) init image back

    # Persist the generated image as a PNG.
    writer = PngWriter(
        "./outputs/img-samples/", config['prompt'], 1
    )
    writer.write_image(image, seed)

    # Record every written file together with its request config.
    with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
        for file_path, _ in writer.files_written:
            log.write(f"{file_path}: {json.dumps(config)}\n")

    # One JSON event per line so the browser can parse the stream incrementally.
    payload = {'event': 'result', 'files': writer.files_written, 'config': config}
    self.wfile.write(bytes(json.dumps(payload) + '\n', "utf-8"))

def image_progress(image, step):
    """Stream a 'step' progress event to the client as one JSON line.

    NOTE(review): `image` is received from the sampler but is not used;
    only the step number is forwarded.
    """
    event = json.dumps({'event': 'step', 'step': step}) + '\n'
    self.wfile.write(bytes(event, "utf-8"))

# outputs = []
if initimg is None:
# Run txt2img
outputs = model.txt2img(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps)
model.prompt2image(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps,

step_callback=image_progress,
image_callback=image_done)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
initimg = initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))

# Run img2img
outputs = model.img2img(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps)
model.prompt2image(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps,

step_callback=image_progress,
image_callback=image_done)
# Remove the temp file
os.remove("./img2img-tmp.png")

print(f"Prompt generated with output: {outputs}")

post_data['initimg'] = '' # Don't send init image back

# Append post_data to log
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
for output in outputs:
log.write(f"{output[0]}: {json.dumps(post_data)}\n")

outputs = [x + [post_data] for x in outputs] # Append config to each output
result = {'outputs': outputs}
self.wfile.write(bytes(json.dumps(result), "utf-8"))
print(f"Prompt generated!")

if __name__ == "__main__":
# Change working directory to the stable-diffusion directory
Expand Down
3 changes: 2 additions & 1 deletion static/dream_web/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ <h2 id="header">Stable Diffusion Dream Server</h2>
</fieldset>
</form>
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
<br>
<progress id="progress" value="0" max="1"></progress>
</div>
<hr>
<div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>
Expand Down
53 changes: 34 additions & 19 deletions static/dream_web/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,28 @@ function toBase64(file) {
});
}

function appendOutput(output) {
function appendOutput(src, seed, config) {
let outputNode = document.createElement("img");
outputNode.src = output[0];
outputNode.src = src;

let outputConfig = output[2];
let altText = output[1].toString() + " | " + outputConfig.prompt;
let altText = seed.toString() + " | " + config.prompt;
outputNode.alt = altText;
outputNode.title = altText;

// Reload image config
outputNode.addEventListener('click', () => {
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
form.querySelector(`*[name=${k}]`).value = outputConfig[k];
form.querySelector(`*[name=${k}]`).value = config[k];
}
document.querySelector("#seed").value = output[1];
document.querySelector("#seed").value = seed;

saveFields(document.querySelector("#generate-form"));
});

document.querySelector("#results").prepend(outputNode);
}

function appendOutputs(outputs) {
    // Append every generated output entry to the results panel in order.
    for (const entry of outputs) {
        appendOutput(entry);
    }
}

function saveFields(form) {
for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
Expand All @@ -59,21 +52,43 @@ async function generateSubmit(form) {
let formData = Object.fromEntries(new FormData(form));
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;

// Post as JSON
document.querySelector('progress').setAttribute('max', formData.steps);

// Post as JSON, using Fetch streaming to get results
fetch(form.action, {
method: form.method,
body: JSON.stringify(formData),
}).then(async (result) => {
let data = await result.json();
}).then(async (response) => {
const reader = response.body.getReader();

let noOutputs = true;
while (true) {
let {value, done} = await reader.read();
value = new TextDecoder().decode(value);
if (done) break;

for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);

if (data.event == 'result') {
noOutputs = false;
document.querySelector("#no-results-message")?.remove();

for (let [file, seed] of data.files) {
appendOutput(file, seed, data.config);
}
} else if (data.event == 'step') {
document.querySelector('progress').setAttribute('value', data.step.toString());
}
}
}

// Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt;
document.querySelector('progress').setAttribute('value', '0');

if (data.outputs.length != 0) {
document.querySelector("#no-results-message")?.remove();
appendOutputs(data.outputs);
} else {
if (noOutputs) {
alert("Error occurred while generating.");
}
});
Expand Down