In [1]:
from diffusers import StableDiffusionPipeline
import ttools.modules
from tqdm import tqdm
import torch
import random
import pydiffvg
import skimage
import skimage.io
import PIL
import utils
import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
print(device)

cuda


In [3]:
num_paths = 500
max_width = 4.0
use_lpips_loss = True
num_iter = 100
use_blob = True

In [4]:
model_path = "CompVis/stable-diffusion-v1-4"
pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to(device)

height = 512
width = 512
num_images_per_prompt = 1
num_inference_steps = num_iter
guidance_scale = 7.5
do_classifier_free_guidance = guidance_scale > 1.0
generator = None

pipeline.vae.requires_grad_(False)
pipeline.unet.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0): CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, ele

In [5]:
# 3. Prepare prompt embeddings
prompt = "a panda rowing a boat minimal 2d vector graphics"
if prompt is not None and isinstance(prompt, str):
    batch_size = 1
else:
    batch_size = len(prompt)
prompt_embeds = pipeline._encode_prompt(
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
)

# 4. Prepare timesteps
pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipeline.scheduler.timesteps

In [6]:
pydiffvg.set_print_timing(False)

gamma = 1.0

# Set the device
pydiffvg.set_use_gpu(True)

perception_loss = ttools.modules.LPIPS().to(pydiffvg.get_device())

# Load the image and scale it to [0, 1]
# target = torch.from_numpy(skimage.io.imread(target)).to(torch.float32) / 255.0
# target = target.pow(gamma)
# target = target.to(pydiffvg.get_device())
# target = target.unsqueeze(0)
# target = target.permute(0, 3, 1, 2) # NHWC -> NCHW
#target = torch.nn.functional.interpolate(target, size = [256, 256], mode = 'area')
canvas_width, canvas_height = width, height

random.seed(1234)
torch.manual_seed(1234)

shapes = []
shape_groups = []
if use_blob:
    for i in range(num_paths):
        num_segments = random.randint(3, 5)
        num_control_points = torch.zeros(num_segments, dtype = torch.int32) + 2
        points = []
        p0 = (random.random(), random.random())
        points.append(p0)
        for j in range(num_segments):
            radius = 0.05
            p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
            p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5))
            p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5))
            points.append(p1)
            points.append(p2)
            if j < num_segments - 1:
                points.append(p3)
                p0 = p3
        points = torch.tensor(points)
        points[:, 0] *= canvas_width
        points[:, 1] *= canvas_height
        path = pydiffvg.Path(num_control_points = num_control_points,
                                points = points,
                                stroke_width = torch.tensor(1.0),
                                is_closed = True)
        shapes.append(path)
        path_group = pydiffvg.ShapeGroup(shape_ids = torch.tensor([len(shapes) - 1]),
                                            fill_color = torch.tensor([random.random(),
                                                                    random.random(),
                                                                    random.random(),
                                                                    random.random()]))
        shape_groups.append(path_group)
else:
    for i in range(num_paths):
        num_segments = random.randint(1, 3)
        num_control_points = torch.zeros(num_segments, dtype = torch.int32) + 2
        points = []
        p0 = (random.random(), random.random())
        points.append(p0)
        for j in range(num_segments):
            radius = 0.05
            p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
            p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5))
            p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5))
            points.append(p1)
            points.append(p2)
            points.append(p3)
            p0 = p3
        points = torch.tensor(points)
        points[:, 0] *= canvas_width
        points[:, 1] *= canvas_height
        #points = torch.rand(3 * num_segments + 1, 2) * min(canvas_width, canvas_height)
        path = pydiffvg.Path(num_control_points = num_control_points,
                                points = points,
                                stroke_width = torch.tensor(1.0),
                                is_closed = False)
        shapes.append(path)
        path_group = pydiffvg.ShapeGroup(shape_ids = torch.tensor([len(shapes) - 1]),
                                            fill_color = None,
                                            stroke_color = torch.tensor([random.random(),
                                                                        random.random(),
                                                                        random.random(),
                                                                        random.random()]))
        shape_groups.append(path_group)

scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, shape_groups)

render = pydiffvg.RenderFunction.apply
img = render(canvas_width, # width
                canvas_height, # height
                2,   # num_samples_x
                2,   # num_samples_y
                0,   # seed
                None,
                *scene_args)
pydiffvg.imwrite(img.cpu(), './init.png', gamma=gamma)

points_vars = []
stroke_width_vars = []
color_vars = []
for path in shapes:
    path.points.requires_grad = True
    points_vars.append(path.points)
if not use_blob:
    for path in shapes:
        path.stroke_width.requires_grad = True
        stroke_width_vars.append(path.stroke_width)
if use_blob:
    for group in shape_groups:
        group.fill_color.requires_grad = True
        color_vars.append(group.fill_color)
else:
    for group in shape_groups:
        group.stroke_color.requires_grad = True
        color_vars.append(group.stroke_color)

# Optimizers for points, 
points_optim = torch.optim.Adam(points_vars, lr=1.0)
if len(stroke_width_vars) > 0:
    width_optim = torch.optim.Adam(stroke_width_vars, lr=0.1)
color_optim = torch.optim.Adam(color_vars, lr=0.01)

for t in tqdm(range(num_iter)):

    points_optim.zero_grad()
    if len(stroke_width_vars) > 0:
        width_optim.zero_grad()
    color_optim.zero_grad()
    # Forward pass: render the image.
    scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, shape_groups)
    img = render(canvas_width, 
                    canvas_height, 
                    2,   # num_samples_x
                    2,   # num_samples_y
                    t,   # seed
                    None,
                    *scene_args)

    # Compose img with white background
    img = img[:, :, 3:4] * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device = pydiffvg.get_device()) * (1 - img[:, :, 3:4])
    # Save the intermediate render.

    # pydiffvg.imwrite(img.cpu(), 'results/iter_{}.png'.format(t), gamma=gamma)
    img = img[:, :, :3]
    # Convert img from HWC to NCHW
    img = img.unsqueeze(0)
    img = img.permute(0, 3, 1, 2) # NHWC -> NCHW
    
    with torch.no_grad():
        # 5. Prepare latents
        latents = img + 0.01 * torch.randn_like(img)
        latents = pipeline.vae.encode(latents.to(dtype=torch.float32)).latent_dist.sample()
        latents = latents * pipeline.vae.config.scaling_factor
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

        # 6. Predict the noise residual
        noise_pred = pipeline.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            return_dict=False,
        )[0]

        # 7. perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # 8. Compute the previous noisy sample x_t -> x_t-1
        latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        target = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]

    if use_lpips_loss:
        loss = perception_loss(img, target) + (img.mean() - target.mean()).pow(2)
    else:
        loss = (img - target).pow(2).mean()

    # Backpropagate the gradients
    loss.backward()

    # Perform Gradient Descent
    points_optim.step()
    if len(stroke_width_vars) > 0:
        width_optim.step()
    color_optim.step()

    if len(stroke_width_vars) > 0:
        for path in shapes:
            path.stroke_width.data.clamp_(1.0, max_width)
    if use_blob:
        for group in shape_groups:
            group.fill_color.data.clamp_(0.0, 1.0)
    else:
        for group in shape_groups:
            group.stroke_color.data.clamp_(0.0, 1.0)

    # if t % 10 == 0 or t == num_iter - 1:
    #     pydiffvg.save_svg('results/painterly_rendering/iter_{}.svg'.format(t), canvas_width, canvas_height, shapes, shape_groups)

img = render(canvas_width, # width
                canvas_height, # height
                2,   # num_samples_x
                2,   # num_samples_y
                0,   # seed
                None,
                *scene_args)
# Compose img with white background
img = img[:, :, 3:4] * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device = pydiffvg.get_device()) * (1 - img[:, :, 3:4])
# Save the intermediate render.
pydiffvg.imwrite(img.cpu(), './final.png'.format(t), gamma=gamma)

LPIPS is untested
100%|██████████| 100/100 [01:01<00:00,  1.63it/s]


In [None]:
pydiffvg.set_print_timing(False)

gamma = 1.0

# Set the device
pydiffvg.set_use_gpu(True)

perception_loss = ttools.modules.LPIPS().to(pydiffvg.get_device())

# Load the image and scale it to [0, 1]
target = torch.from_numpy(skimage.io.imread(target)).to(torch.float32) / 255.0
target = target.pow(gamma)
target = target.to(pydiffvg.get_device())
target = target.unsqueeze(0)
target = target.permute(0, 3, 1, 2) # NHWC -> NCHW
#target = torch.nn.functional.interpolate(target, size = [256, 256], mode = 'area')
canvas_width, canvas_height = target.shape[3], target.shape[2]

random.seed(1234)
torch.manual_seed(1234)

shapes = []
shape_groups = []
if use_blob:
    for i in range(num_paths):
        num_segments = random.randint(3, 5)
        num_control_points = torch.zeros(num_segments, dtype = torch.int32) + 2
        points = []
        p0 = (random.random(), random.random())
        points.append(p0)
        for j in range(num_segments):
            radius = 0.05
            p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
            p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5))
            p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5))
            points.append(p1)
            points.append(p2)
            if j < num_segments - 1:
                points.append(p3)
                p0 = p3
        points = torch.tensor(points)
        points[:, 0] *= canvas_width
        points[:, 1] *= canvas_height
        path = pydiffvg.Path(num_control_points = num_control_points,
                                points = points,
                                stroke_width = torch.tensor(1.0),
                                is_closed = True)
        shapes.append(path)
        path_group = pydiffvg.ShapeGroup(shape_ids = torch.tensor([len(shapes) - 1]),
                                            fill_color = torch.tensor([random.random(),
                                                                    random.random(),
                                                                    random.random(),
                                                                    random.random()]))
        shape_groups.append(path_group)
else:
    for i in range(num_paths):
        num_segments = random.randint(1, 3)
        num_control_points = torch.zeros(num_segments, dtype = torch.int32) + 2
        points = []
        p0 = (random.random(), random.random())
        points.append(p0)
        for j in range(num_segments):
            radius = 0.05
            p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
            p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5))
            p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5))
            points.append(p1)
            points.append(p2)
            points.append(p3)
            p0 = p3
        points = torch.tensor(points)
        points[:, 0] *= canvas_width
        points[:, 1] *= canvas_height
        #points = torch.rand(3 * num_segments + 1, 2) * min(canvas_width, canvas_height)
        path = pydiffvg.Path(num_control_points = num_control_points,
                                points = points,
                                stroke_width = torch.tensor(1.0),
                                is_closed = False)
        shapes.append(path)
        path_group = pydiffvg.ShapeGroup(shape_ids = torch.tensor([len(shapes) - 1]),
                                            fill_color = None,
                                            stroke_color = torch.tensor([random.random(),
                                                                        random.random(),
                                                                        random.random(),
                                                                        random.random()]))
        shape_groups.append(path_group)

scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, shape_groups)

render = pydiffvg.RenderFunction.apply
img = render(canvas_width, # width
                canvas_height, # height
                2,   # num_samples_x
                2,   # num_samples_y
                0,   # seed
                None,
                *scene_args)
pydiffvg.imwrite(img.cpu(), './init.png', gamma=gamma)

points_vars = []
stroke_width_vars = []
color_vars = []
for path in shapes:
    path.points.requires_grad = True
    points_vars.append(path.points)
if not use_blob:
    for path in shapes:
        path.stroke_width.requires_grad = True
        stroke_width_vars.append(path.stroke_width)
if use_blob:
    for group in shape_groups:
        group.fill_color.requires_grad = True
        color_vars.append(group.fill_color)
else:
    for group in shape_groups:
        group.stroke_color.requires_grad = True
        color_vars.append(group.stroke_color)

# Optimizers for points, 
points_optim = torch.optim.Adam(points_vars, lr=1.0)
if len(stroke_width_vars) > 0:
    width_optim = torch.optim.Adam(stroke_width_vars, lr=0.1)
color_optim = torch.optim.Adam(color_vars, lr=0.01)

for t in tqdm(range(num_iter)):

    points_optim.zero_grad()
    if len(stroke_width_vars) > 0:
        width_optim.zero_grad()
    color_optim.zero_grad()
    # Forward pass: render the image.
    scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, shape_groups)
    img = render(canvas_width, 
                    canvas_height, 
                    2,   # num_samples_x
                    2,   # num_samples_y
                    t,   # seed
                    None,
                    *scene_args)

    # Compose img with white background
    img = img[:, :, 3:4] * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device = pydiffvg.get_device()) * (1 - img[:, :, 3:4])
    # Save the intermediate render.

    # pydiffvg.imwrite(img.cpu(), 'results/iter_{}.png'.format(t), gamma=gamma)
    img = img[:, :, :3]
    # Convert img from HWC to NCHW
    img = img.unsqueeze(0)
    img = img.permute(0, 3, 1, 2) # NHWC -> NCHW

    

    if use_lpips_loss:
        loss = perception_loss(img, target) + (img.mean() - target.mean()).pow(2)
    else:
        loss = (img - target).pow(2).mean()

    # Backpropagate the gradients
    loss.backward()

    # Perform Gradient Descent
    points_optim.step()
    if len(stroke_width_vars) > 0:
        width_optim.step()
    color_optim.step()

    if len(stroke_width_vars) > 0:
        for path in shapes:
            path.stroke_width.data.clamp_(1.0, max_width)
    if use_blob:
        for group in shape_groups:
            group.fill_color.data.clamp_(0.0, 1.0)
    else:
        for group in shape_groups:
            group.stroke_color.data.clamp_(0.0, 1.0)

    # if t % 10 == 0 or t == num_iter - 1:
    #     pydiffvg.save_svg('results/painterly_rendering/iter_{}.svg'.format(t), canvas_width, canvas_height, shapes, shape_groups)

img = render(canvas_width, # width
                canvas_height, # height
                2,   # num_samples_x
                2,   # num_samples_y
                0,   # seed
                None,
                *scene_args)
# Compose img with white background
img = img[:, :, 3:4] * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device = pydiffvg.get_device()) * (1 - img[:, :, 3:4])
# Save the intermediate render.
pydiffvg.imwrite(img.cpu(), './final.png'.format(t), gamma=gamma)

In [None]:



# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = pipeline.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            return_dict=False,
        )[0]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # call the callback, if provided
        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
            progress_bar.update()

image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]

In [None]:
image = utils.denormalize(image)
image = utils.pt_to_numpy(image)
image = utils.numpy_to_pil(image)
image[0]