In [1]:
import copy
from diffusers import UNet2DConditionModel
import torch
from src.models.evolve_prior_model import PriorModelEvolution

### Implementation

In [2]:
device = 'cuda'
DS_unet = UNet2DConditionModel.from_pretrained(
    'Lykon/dreamshaper-8',
    subfolder='unet',
    torch_dtype=torch.float16
).to(device)
SD_inpaint_unet = UNet2DConditionModel.from_pretrained(
    'stable-diffusion-v1-5/stable-diffusion-inpainting',
    subfolder='unet',
    torch_dtype=torch.float16
).to(device)
SD_unet = UNet2DConditionModel.from_pretrained(
    'stable-diffusion-v1-5/stable-diffusion-v1-5',
    subfolder='unet',
    torch_dtype=torch.float16
).to(device)

An error occurred while trying to fetch stable-diffusion-v1-5/stable-diffusion-inpainting: stable-diffusion-v1-5/stable-diffusion-inpainting does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


In [3]:
alpha = 1.0
beta = 1.1

# alpha * (W_inp - W_base)
for (name, param1), (_, param2) in zip(SD_inpaint_unet.named_parameters(), SD_unet.named_parameters()):
    if 'conv_in' in name:
        with torch.no_grad():
            if len(param1.data.shape) == 4:
                # conv weights
                param1.data[:, :4, :, :].sub_(param2) * alpha
                param1.data[:, 4:, :, :] * alpha
            else:
                # conv bias
                param1.data.sub_(param2) * alpha
    else:
        with torch.no_grad():
            param1.data.sub_(param2) * alpha

# beta * (W_DS - W_base)
for (name, param1), (_, param2) in zip(DS_unet.named_parameters(), SD_unet.named_parameters()):
    with torch.no_grad():
        param1.data.sub_(param2) * beta

# W_base + alpha * (W_inp - W_base) + beta * (W_DS - W_base)
# Here, we treat W_inp as W_base in the above expression to implement easier.
for (name, param1), (_, param2), (_, param3) in zip(SD_inpaint_unet.named_parameters(), SD_unet.named_parameters(), DS_unet.named_parameters()):
    if 'conv_in' in name:
        if len(param1.data.shape) == 4:
            param1.data[:, :4, :, :].add_(param2)
            param1.data[:, :4, :, :].add_(param3)
        else:
            param1.data.add_(param2)
            param1.data.add_(param3)
    else:
        param1.data.add_(param2)
        param1.data.add_(param3)

In [4]:
SD_inpaint_unet.conv_in.bias.data

tensor([-0.1215, -0.2026,  0.1359,  0.0339,  0.0492, -0.0740, -0.0136,  0.0609,
        -0.0308,  0.0450, -0.0215, -0.0568,  0.1191,  0.0833, -0.0334, -0.1379,
        -0.1210,  0.0368,  0.0194,  0.0812, -0.1058,  0.0873, -0.0411,  0.0468,
         0.0703,  0.0294,  0.1013, -0.1018, -0.0536,  0.0298, -0.0308,  0.0386,
         0.0598, -0.0557,  0.0455,  0.0764, -0.1082,  0.0995,  0.0306,  0.0428,
         0.0332, -0.0571, -0.1096, -0.0758,  0.0496, -0.0625,  0.0792,  0.0332,
        -0.1465, -0.0933, -0.1064, -0.1371,  0.0522,  0.0326, -0.1034,  0.1141,
         0.0147,  0.0958,  0.0993, -0.1476,  0.1267, -0.0881, -0.1114, -0.1041,
         0.0145, -0.0811,  0.1450, -0.0221, -0.0632,  0.1099,  0.0255,  0.1118,
        -0.0884, -0.0123, -0.0899,  0.0185,  0.1083,  0.1432, -0.0765,  0.1047,
         0.0667, -0.1564,  0.0427, -0.1071, -0.1011,  0.1199,  0.0596,  0.1506,
        -0.1110, -0.1161,  0.1810,  0.1586,  0.1204, -0.1119, -0.1589, -0.0591,
        -0.1099, -0.0024,  0.1182, -0.06

In [4]:
merged_unet = copy.deepcopy(SD_inpaint_unet)

In [None]:
def evolve_prior_model(base, inpaint, shaper):
    """ Prior Model Evolution
    (section 3.3 in https://arxiv.org/abs/2405.18172)
    
    This implementation use in-place PyTorch operation for tensors.
    """
    # default values from paper
    alpha = 1.0
    beta = 1.1

    # alpha * (W_inp - W_base)
    for (name, param1), (_, param2) in zip(inpaint.named_parameters(), base.named_parameters()):
        if 'conv_in' in name:
            with torch.no_grad():
                if len(param1.data.shape) == 4:
                    # conv weights
                    param1.data[:, :4, :, :].sub_(param2) * alpha
                    param1.data[:, 4:, :, :] * alpha
                else:
                    # conv bias
                    param1.data.sub_(param2) * alpha
        else:
            with torch.no_grad():
                param1.data.sub_(param2) * alpha

    # beta * (W_DS - W_base)
    for (name, param1), (_, param2) in zip(shaper.named_parameters(), base.named_parameters()):
        with torch.no_grad():
            param1.data.sub_(param2) * beta
    
    # W_base + alpha * (W_inp - W_base) + beta * (W_DS - W_base)
    # Here, we treat W_inp as W_base in the above expression to implement easier.
    for (name, param1), (_, param2), (_, param3) in zip(inpaint.named_parameters(), base.named_parameters(), shaper.named_parameters()):
        if 'conv_in' in name:
            if len(param1.data.shape) == 4:
                param1.data[:, :4, :, :].add_(param2)
                param1.data[:, :4, :, :].add_(param3)
            else:
                param1.data.add_(param2)
                param1.data.add_(param3)
        else:
            param1.data.add_(param2)
            param1.data.add_(param3)
    
    return copy.deepcopy(inpaint)

### Testing

In [2]:
model_evolver = PriorModelEvolution()

In [3]:
device = 'cuda'
my_unet = UNet2DConditionModel.from_pretrained(
    'stable-diffusion-v1-5/stable-diffusion-inpainting',
    subfolder='unet',
    torch_dtype=torch.float16
).to(device)

An error occurred while trying to fetch stable-diffusion-v1-5/stable-diffusion-inpainting: stable-diffusion-v1-5/stable-diffusion-inpainting does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


In [4]:
model_evolver(my_unet)

UNet2DConditionModel(
  (conv_in): Conv2d(9, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_fe

In [4]:
del model_evolver
torch.cuda.empty_cache()