In [4]:
import torch
from diffusers import DiffusionPipeline

from safetensors.torch import load_file

pipe = DiffusionPipeline.from_pretrained(
    "segmind/SSD-1B",
    torch_dtype=torch.float16,
).to("cuda")

tensors = load_file("/root/bigdisk/project_structured_prompt/stage1_sdxl_training/checkpoint_lr_high/unet/checkpoint-12001.unet.safetensors")
prefix = 'tt'

with torch.no_grad():
    model = pipe.unet
    prefix = "finetuned"
    for k, v in tensors.items():
        print(
            f"Weight diff for {k}: {torch.abs(model.state_dict()[k].flatten()[:8].cpu() - v.flatten()[:8].cpu()).mean()}"
        )
        v = v.to("cuda").to(torch.float16)

pipe.unet.load_state_dict(tensors, strict=False)  # should take < 2 seconds


Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading pipeline components...: 100%|██████████| 7/7 [00:09<00:00,  1.39s/it]


Weight diff for down_blocks.1.attentions.0.norm.bias: 0.000795598840340972
Weight diff for down_blocks.1.attentions.0.proj_in.bias: 0.00035599491093307734
Weight diff for down_blocks.1.attentions.0.proj_in.weight: 0.00046758464304730296
Weight diff for down_blocks.1.attentions.0.proj_out.bias: 0.0003796196833718568
Weight diff for down_blocks.1.attentions.0.proj_out.weight: 0.0006582210771739483
Weight diff for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight: 0.0008608432835899293
Weight diff for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias: 0.00046892924001440406
Weight diff for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight: 0.0005575929535552859
Weight diff for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight: 0.0005012233159504831
Weight diff for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight: 0.0005378096248023212
Weight diff for down_blocks.1.attentions.0.transformer_blocks.0.attn2

_IncompatibleKeys(missing_keys=['conv_in.weight', 'conv_in.bias', 'time_embedding.linear_1.weight', 'time_embedding.linear_1.bias', 'time_embedding.linear_2.weight', 'time_embedding.linear_2.bias', 'add_embedding.linear_1.weight', 'add_embedding.linear_1.bias', 'add_embedding.linear_2.weight', 'add_embedding.linear_2.bias', 'down_blocks.0.resnets.0.norm1.weight', 'down_blocks.0.resnets.0.norm1.bias', 'down_blocks.0.resnets.0.conv1.weight', 'down_blocks.0.resnets.0.conv1.bias', 'down_blocks.0.resnets.0.time_emb_proj.weight', 'down_blocks.0.resnets.0.time_emb_proj.bias', 'down_blocks.0.resnets.0.norm2.weight', 'down_blocks.0.resnets.0.norm2.bias', 'down_blocks.0.resnets.0.conv2.weight', 'down_blocks.0.resnets.0.conv2.bias', 'down_blocks.0.resnets.1.norm1.weight', 'down_blocks.0.resnets.1.norm1.bias', 'down_blocks.0.resnets.1.conv1.weight', 'down_blocks.0.resnets.1.conv1.bias', 'down_blocks.0.resnets.1.time_emb_proj.weight', 'down_blocks.0.resnets.1.time_emb_proj.bias', 'down_blocks.0.res

In [5]:
gen = torch.Generator().manual_seed(10)

import pandas as pd
import numpy as np
from PIL import Image
from PIL import ImageDraw, ImageFont

data = pd.read_csv("/root/bigdisk/project_structured_prompt/stage0_prompt_decomposition/scripts/journeydb_subsampled.csv")
# random split with seed 0
data = data.sample(frac=1, random_state=0)

data = data.iloc[int(len(data) * 0.8) :]

# reset index
data = data.reset_index(drop=True)
STEP = 4
for i in range(0, 2, STEP):
    captions = data["content"][i:i + STEP].tolist()
    captions_2 = data["style"][i:i + STEP].tolist()
    original_images = data["img_path"][i:i + STEP].tolist()
    print(captions)
    img = pipe(
        height=1024,
        width=1024,
        prompt=captions,
        prompt_2=captions_2,
        generator=gen,
        guidance_scale=4.0,
    ).images
    for j, im in enumerate(img):
        org_im = Image.open(original_images[j])
        # concat images
        
        org_im = org_im.resize((512, 512))
        org_im = Image.fromarray(np.array(org_im))
        org_im = np.array(org_im)

        im = im.resize((512, 512))
        im = np.array(im)
        
        im = np.concatenate((org_im, im), axis=1)
        # put the caption on bottom of the iamge
    
        im = Image.fromarray(im)
        
        caption = captions[j] + "\n" + captions_2[j]
        font = ImageFont.load_default()
        text_w, text_h = 25, 25
        caption_image = Image.new('RGB', (im.width, im.height + text_h), (255, 255, 255))
        draw = ImageDraw.Draw(caption_image)
        caption_image.paste(im, (0, 0))
        draw.text((0, im.height), caption, font=font, fill='black')

        im = caption_image
                
            

        im.save(f"images/{i + j}.png")

['Jason the killer movie cover', 'Leonard, hairy baby child actor Johnny Galecki, beautiful eyes, realistic toddler, portrait, black, dark, dark background', 'goddess, guest', 'Tyler Hoechlin, Slytherin robes']


100%|██████████| 50/50 [00:15<00:00,  3.32it/s]
