In [1]:
import numpy as np
import glob
import random
import re
import torch
import PIL
import argparse
import os
import sys
sys.path.insert(0, '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion')
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from omegaconf import OmegaConf
from PIL import Image
from itertools import islice
from einops import rearrange, repeat
from torch import autocast
from contextlib import nullcontext
from pytorch_lightning import seed_everything
import cv2
import time
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim_org import DDIMSampler
from ldm.models.diffusion.dpm_solver_org import DPMSolverSampler

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

  from .autonotebook import tqdm as notebook_tqdm


[2025-01-16 12:01:39,027] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


## Define necessary functions

In [2]:
from notebook_utils.pmath import *
# we utilize geoopt package for hyperbolic calculation
import geoopt.manifolds.stereographic.math as gmath
def chunk(it, size):
    """Lazily split an iterable into consecutive tuples of at most ``size`` items.

    The last tuple may be shorter than ``size``. Iteration ends as soon as an
    empty tuple would be produced, i.e. once the source is exhausted.
    """
    source = iter(it)

    def take_next():
        # Pull up to `size` items; the empty tuple is the stop sentinel below.
        return tuple(islice(source, size))

    return iter(take_next, ())


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load weights from a checkpoint file.

    Args:
        config: configuration object exposing a ``model`` entry understood by
            ``instantiate_from_config``.
        ckpt: path of the checkpoint; it must contain a ``"state_dict"`` entry.
        verbose: when true, print the missing / unexpected state-dict keys.

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are reported below instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if verbose:
        reports = (("missing keys:", missing_keys),
                   ("unexpected keys:", unexpected_keys))
        for heading, keys in reports:
            if len(keys) > 0:
                print(heading)
                print(keys)

    model.cuda()
    model.eval()
    return model


def load_img(path, size=(256, 256)):
    """Read an image file and return it as a normalized NCHW float tensor.

    Args:
        path: path of the image file.
        size: ``(width, height)`` the image is resized to (passed to
            ``cv2.resize`` as ``dsize``).

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, height, width)`` in RGB
        channel order with values scaled from ``[0, 255]`` to ``[-1, 1]``.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from ``path``.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"cv2.imread could not read an image from: {path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    width, height = size
    image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LANCZOS4)
    image = np.array(image).astype(np.uint8)

    # [0, 255] -> [-1, 1]
    image = (image / 127.5 - 1.0).astype(np.float32)
    # HWC -> 1 x C x H x W
    image = image[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(image)


def get_unconditional_embedding(model, scale, n_samples, device, prompts):
    """Return the conditioning used as the "unconditional" branch for guidance.

    When ``scale`` differs from 1.0, the model is conditioned on ``n_samples``
    all-zero ``(1, 3, 224, 224)`` images; otherwise it is conditioned on
    ``prompts`` directly. In both cases the fifth output of
    ``model.get_learned_conditioning`` is returned.
    """
    if scale == 1.0:
        conditioning_input = prompts
    else:
        blank_image = torch.zeros((1, 3, 224, 224)).to(device)
        # The same zero tensor is referenced n_samples times.
        conditioning_input = [blank_image] * n_samples

    outputs = model.get_learned_conditioning(conditioning_input)
    _, _, _, _, uc = outputs
    return uc

def feature_fusion(model, prompt_1, prompt_2, alpha):
    """Fuse two prompts with different attribute levels in hyperbolic space.

    Both prompts are encoded with ``model.get_learned_conditioning`` and the
    fourth output is taken as the hyperbolic code. Each code is rescaled to
    radius ``alpha``. The result is the rescaled code of ``prompt_1``
    Mobius-added to the residual of ``prompt_2`` (its full code minus its
    rescaled code).

    Args:
        model: object exposing ``get_learned_conditioning``.
        prompt_1: the first image.
        prompt_2: the second image.
        alpha: the fusion ratio between the two prompts, value: [0, 6.2126].

    Returns:
        fused_hyp_code: the fused latent code in hyperbolic space.
    """
    _, _, _, hyp_code_1, _ = model.get_learned_conditioning(prompt_1)
    _, _, _, hyp_code_2, _ = model.get_learned_conditioning(prompt_2)
    rescaled_hyp_code_1 = rescale(alpha, hyp_code_1)
    rescaled_hyp_code_2 = rescale(alpha, hyp_code_2)
    # Only the residual of prompt_2 is used in the fusion; the residual of
    # prompt_1 was previously computed here but never used, so it is dropped.
    delta_hyp_code_2 = mobius_add(hyp_code_2, -rescaled_hyp_code_2)
    fused_hyp_code = mobius_add(rescaled_hyp_code_1, delta_hyp_code_2)
    return fused_hyp_code

def get_hyp_codes(model, prompts):
    """Encode ``prompts`` and return the hyperbolic-space latent codes.

    Returns the third and fourth outputs of
    ``model.get_learned_conditioning(prompts)`` as ``(feature, feature_dist)``.
    """
    outputs = model.get_learned_conditioning(prompts)
    _, _, hyp_feature, hyp_feature_dist, _ = outputs
    return hyp_feature, hyp_feature_dist

def get_hyp_codes_given_feature(model, feature):
    """Return hyperbolic-space latent codes for latent codes given in CLIP space.

    Calls ``model.get_learned_conditioning`` with ``input_feature=False`` and
    ``input_code=True`` and returns its third and fourth outputs.
    """
    outputs = model.get_learned_conditioning(
        feature, input_feature=False, input_code=True)
    _, _, hyp_feature, hyp_feature_dist, _ = outputs
    return hyp_feature, hyp_feature_dist

def get_condition_given_feature(model, feature):
    """Return the final conditioning (fifth output) for an input latent code.

    Calls ``model.get_learned_conditioning`` with ``input_feature=False`` and
    ``input_code=True``.

    NOTE(review): the previous comment described the input as a hyperbolic-space
    code, but these flags are the same ones get_hyp_codes_given_feature uses for
    CLIP-space codes - confirm which space ``feature`` is expected to be in.
    """
    outputs = model.get_learned_conditioning(
        feature, input_feature=False, input_code=True)
    _, _, _, _, condition = outputs
    return condition

def get_condition_given_hyp_codes(model, hyp_codes):
    """Return the CLIP-space conditioning for latent codes in hyperbolic space.

    Calls ``model.get_learned_conditioning`` with ``input_feature=True`` and
    returns its fifth output.
    """
    outputs = model.get_learned_conditioning(hyp_codes, input_feature=True)
    _, _, _, _, condition = outputs
    return condition


# rescale function
def rescale(target_radius, x):
    """Move ``x`` along its ray from the origin so that its radius becomes ``target_radius``.

    The current radius is measured with ``dist0`` (presumably the hyperbolic
    distance to the origin - confirm in notebook_utils.pmath), and ``x`` is
    Mobius-scaled by ``target_radius / current_radius`` with curvature -1.
    """
    curvature = torch.tensor(-1.0)
    # Scaling by 1 first mirrors the original computation exactly.
    current_radius = dist0(
        gmath.mobius_scalar_mul(r=torch.tensor(1), x=x, k=curvature))
    radius_ratio = target_radius / current_radius
    return gmath.mobius_scalar_mul(r=radius_ratio, x=x, k=curvature)


# function for generating images with fixed radius (also contains raw geodesic images of 'shorten' images, and stretched images to boundary)
def geo_interpolate_fix_r(x, y, interval, target_radius, save_codes=False, boundary_radius=6.2126):
    """Interpolate between two hyperbolic codes along a geodesic at a fixed radius.

    ``x`` and ``y`` are first Mobius-scaled by ``target_radius / boundary_radius``
    (so codes lying at ``boundary_radius`` end up at ``target_radius``). For each
    position ``t`` in ``interval`` the point on the geodesic between the two
    shortened codes is computed, then rescaled to ``target_radius`` and,
    separately, to ``boundary_radius``.

    Args:
        x: start code.
        y: end code.
        interval: iterable of geodesic positions ``t`` (0 gives the start, 1 the end).
        target_radius: radius at which the interpolation is carried out.
        save_codes: unused; kept for backward compatibility.
        boundary_radius: radius treated as the boundary of the ball
            (defaults to the previously hard-coded 6.2126).

    Returns:
        Tuple of four lists, one entry per element of ``interval``:
            - feature_geo: raw points on the geodesic.
            - feature_geo_normalized: points rescaled to ``target_radius``.
            - feature_geo_current_target_boundaries: points rescaled to
              ``boundary_radius``.
            - dist_to_start: ``gmath.dist`` between the shortened start code and
              the ``target_radius`` point.
    """
    # Constant tensors are built once instead of on every call/iteration.
    k = torch.tensor(-1.0)
    unit = torch.tensor(1)

    feature_geo = []
    feature_geo_normalized = []
    dist_to_start = []
    feature_geo_current_target_boundaries = []

    target_radius_ratio = torch.tensor(target_radius / boundary_radius)
    geodesic_start_short = gmath.mobius_scalar_mul(r=target_radius_ratio, x=x, k=k)
    geodesic_end_short = gmath.mobius_scalar_mul(r=target_radius_ratio, x=y, k=k)

    for t in interval:
        # Raw point on the geodesic; its radius generally differs from target_radius.
        feature_geo_current = gmath.geodesic(
            t=torch.tensor(t), x=geodesic_start_short, y=geodesic_end_short, k=k)
        feature_geo.append(feature_geo_current)

        # The current radius is needed for both rescalings; compute it once.
        current_radius = dist0(
            gmath.mobius_scalar_mul(r=unit, x=feature_geo_current, k=k))

        # Fix the radius to target_radius.
        feature_geo_current_target_radius = gmath.mobius_scalar_mul(
            r=target_radius / current_radius, x=feature_geo_current, k=k)
        feature_geo_normalized.append(feature_geo_current_target_radius)
        dist_to_start.append(
            gmath.dist(geodesic_start_short, feature_geo_current_target_radius, k=k))

        # Push the point out to the boundary radius.
        feature_geo_current_target_boundary = gmath.mobius_scalar_mul(
            r=boundary_radius / current_radius, x=feature_geo_current, k=k)
        feature_geo_current_target_boundaries.append(feature_geo_current_target_boundary)

    return feature_geo, feature_geo_normalized, feature_geo_current_target_boundaries, dist_to_start

# function for generating images with fixed radius with optional latent codes list output


def geo_interpolate_fix_r_with_codes(x, y, interval, target_radius, boundary_radius=6.2126):
    """Geodesic interpolation at a fixed radius, returning the codes of the last step.

    Please use this with batch_size = 1.

    ``x`` and ``y`` are Mobius-scaled by ``target_radius / boundary_radius``;
    for each ``t`` in ``interval`` the geodesic point between the shortened
    codes is computed and rescaled to ``target_radius`` and to
    ``boundary_radius``.

    Args:
        x: start code.
        y: end code.
        interval: iterable of geodesic positions ``t``; must not be empty.
        target_radius: radius at which the interpolation is carried out.
        boundary_radius: radius treated as the boundary of the ball
            (defaults to the previously hard-coded 6.2126).

    Returns:
        Tuple ``(dist_to_start, feature_geo_current,
        feature_geo_current_target_radius, feature_geo_current_target_boundary)``
        where ``dist_to_start`` has one entry per element of ``interval`` and the
        three codes belong to the LAST element of ``interval`` only.

    Raises:
        ValueError: if ``interval`` yields no positions (previously this
            surfaced as an UnboundLocalError at the return statement).
    """
    k = torch.tensor(-1.0)
    unit = torch.tensor(1)
    dist_to_start = []

    target_radius_ratio = torch.tensor(target_radius / boundary_radius)
    geodesic_start_short = gmath.mobius_scalar_mul(r=target_radius_ratio, x=x, k=k)
    geodesic_end_short = gmath.mobius_scalar_mul(r=target_radius_ratio, x=y, k=k)

    # Sentinel so an empty interval (including an exhausted iterator) is detected.
    feature_geo_current = None
    for t in interval:
        # Raw point on the geodesic, not yet at a fixed radius.
        feature_geo_current = gmath.geodesic(
            t=torch.tensor(t), x=geodesic_start_short, y=geodesic_end_short, k=k)

        # The current radius is needed for both rescalings; compute it once.
        current_radius = dist0(
            gmath.mobius_scalar_mul(r=unit, x=feature_geo_current, k=k))

        # Fix the radius to target_radius.
        feature_geo_current_target_radius = gmath.mobius_scalar_mul(
            r=target_radius / current_radius, x=feature_geo_current, k=k)
        dist_to_start.append(
            gmath.dist(geodesic_start_short, feature_geo_current_target_radius, k=k))

        # Push the point out to the boundary radius.
        feature_geo_current_target_boundary = gmath.mobius_scalar_mul(
            r=boundary_radius / current_radius, x=feature_geo_current, k=k)

    if feature_geo_current is None:
        raise ValueError("interval must contain at least one interpolation position")

    return dist_to_start, feature_geo_current, feature_geo_current_target_radius, feature_geo_current_target_boundary


def geo_perturbation(x, distances, perturb_codes, target_radius=6.2126, num_samples=10, save_codes=False, boundary_radius=6.2126):
    """Perturb a hyperbolic code towards a set of direction codes.

    ``x`` is Mobius-scaled by ``target_radius / boundary_radius`` to obtain the
    start point. For every requested distance and every direction code, the
    direction is rescaled to ``target_radius``, a point is taken on the geodesic
    from the start point towards it such that its geodesic distance from the
    start equals the requested distance, and that point is rescaled both to
    ``target_radius`` and to ``boundary_radius``.

    Args:
        x (tensor): starting point in the Poincare ball.
        distances (list): target hyperbolic distances, e.g. [0.01, 0.1, 0.2].
        perturb_codes (iterable of tensors): codes used as perturbation
            directions (they replace the random directions of an earlier version).
        target_radius (float): radius the perturbed points are rescaled to.
        num_samples (int): unused; one sample is produced per entry of
            ``perturb_codes``. Kept for backward compatibility.
        save_codes (bool, optional): unused; kept for backward compatibility.
        boundary_radius (float): radius treated as the boundary of the ball
            (defaults to the previously hard-coded 6.2126).

    Returns:
        tuple: five lists with one entry per (distance, perturb_code) pair,
        ordered with ``distances`` as the outer loop:
            - feature_geo: perturbed points before radius normalization.
            - feature_geo_normalized: perturbed points rescaled to ``target_radius``.
            - feature_geo_current_target_boundaries: perturbed points rescaled
              to ``boundary_radius``.
            - dist_to_start: hyperbolic distance between the start point and the
              normalized perturbed point.
            - perturbation_distances: hyperbolic distance between the start
              point and the un-normalized perturbed point.
    """
    # Constant tensors are built once instead of on every iteration.
    k = torch.tensor(-1.0)
    unit = torch.tensor(1)

    feature_geo = []
    feature_geo_normalized = []
    dist_to_start = []
    feature_geo_current_target_boundaries = []
    perturbation_distances = []

    # 1. Scale x so that a code at the boundary ends up at target_radius.
    target_radius_ratio = torch.tensor(target_radius / boundary_radius)
    geodesic_start_short = gmath.mobius_scalar_mul(r=target_radius_ratio, x=x, k=k)
    print('start_radius', dist0(gmath.mobius_scalar_mul(r=unit, x=geodesic_start_short, k=k)))

    for distance in distances:
        for perturb_code in perturb_codes:
            # 2./3. Rescale the direction code to target_radius.
            direction_radius = dist0(gmath.mobius_scalar_mul(r=unit, x=perturb_code, k=k))
            geodesic_end_short = gmath.mobius_scalar_mul(
                r=target_radius / direction_radius, x=perturb_code, k=k)

            # 4. Hyperbolic distance between the start point and the direction point.
            distance_x_y = gmath.dist(geodesic_start_short, geodesic_end_short, k=k)

            # 5. Geodesic position t at which the distance from the start equals
            #    `distance`; falls back to 1 when the two points coincide.
            if distance_x_y > 0:
                t = distance / distance_x_y
            else:
                t = torch.tensor(1.0)

            # 6./7. Point on the geodesic (un-normalized radius).
            feature_geo_current = gmath.geodesic(
                t=t, x=geodesic_start_short, y=geodesic_end_short, k=k)
            feature_geo.append(feature_geo_current)

            # The current radius is needed for both rescalings; compute it once.
            current_radius = dist0(gmath.mobius_scalar_mul(r=unit, x=feature_geo_current, k=k))

            # 8. Rescale to target_radius.
            feature_geo_current_target_radius = gmath.mobius_scalar_mul(
                r=target_radius / current_radius, x=feature_geo_current, k=k)
            feature_geo_normalized.append(feature_geo_current_target_radius)

            # 9. Distance from the start point to the normalized perturbed point.
            dist_to_start.append(
                gmath.dist(geodesic_start_short, feature_geo_current_target_radius, k=k))

            # 10. Distance from the start point to the un-normalized perturbed point.
            perturbation_distances.append(
                gmath.dist(geodesic_start_short, feature_geo_current, k=k))

            # 11. Rescale to the boundary radius.
            feature_geo_current_target_boundary = gmath.mobius_scalar_mul(
                r=boundary_radius / current_radius, x=feature_geo_current, k=k)
            feature_geo_current_target_boundaries.append(feature_geo_current_target_boundary)

    return feature_geo, feature_geo_normalized, feature_geo_current_target_boundaries, dist_to_start, perturbation_distances

### Define animalfaces variables

In [76]:
# AnimalFaces configuration.
# NOTE: every "Define ... variables" cell assigns the same global names; run
# exactly one of them before the "Load the model" cell.

# Input images
init_image_path = './inputs/same_domain_test/animalfaces/cls_4_2.jpg'
ref_image_path = './inputs/same_domain_test/animalfaces/cls_4_2.jpg'
ref_image_2_path = './inputs/same_domain_test/animalfaces/cls_4_2.jpg'

# Defined before the sampling below so the random draw can be seeded with it.
seed = 3408

# Sampling images: sort the glob result (its order is filesystem dependent) and
# draw with a dedicated seeded RNG so the same images are picked on every run.
files = sorted(glob.glob("/data2/mhf/DXL/Lingxiao/datasets/animals/*/*.jpg"))
sampled_imgs = random.Random(seed).sample(files, 0)

# Output handling
outdir = './outputs/animalfaces'
skip_grid = False
skip_save = True

# Sampler settings
ddim_steps = 50
ddim_eta = 0.0
n_iter = 1
C = 4           # first entry of the latent shape [C, 64, 64] used when sampling
f = 8
n_samples = 5   # used as the batch size
n_rows = 0      # 0 -> falls back to the batch size
scale = 1.3     # passed as unconditional_guidance_scale
strength = 0.95

# Model files
config_path = './configs/stable-diffusion/v2_inference_animalfaces.yaml'
ckpt = '/data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/animal_faces/2024-09-11T11-30-12_v1/checkpoints/epoch=000037.ckpt'
precision = 'autocast'

### Define ffhq variables

In [14]:
# FFHQ configuration.
# NOTE: every "Define ... variables" cell assigns the same global names; run
# exactly one of them before the "Load the model" cell.

# Input images
init_image_path = './inputs/same_domain_test/ffhq/2.png'
ref_image_path = './inputs/same_domain_test/ffhq/2.png'
# init_image_path = '/data2/mhf/DXL/Lingxiao/datasets/FFHQ/ffhq512/00009.png'
# ref_image_path = '/data2/mhf/DXL/Lingxiao/datasets/FFHQ/ffhq512/00009.png'
ref_image_2_path = './inputs/same_domain_test/ffhq/40.jpg'

# Defined before the sampling below so the random draw can be seeded with it.
seed = 3408

# Sampling images: sort the glob result (its order is filesystem dependent) and
# draw with a dedicated seeded RNG so the same images are picked on every run.
files = sorted(glob.glob("/data2/mhf/DXL/Lingxiao/datasets/FFHQ/*/*.png"))
sampled_imgs = random.Random(seed).sample(files, 2)

# Output handling
outdir = './outputs/ffhq'
skip_grid = False
skip_save = False

# Sampler settings
ddim_steps = 50
ddim_eta = 0.0
n_iter = 1
C = 4           # first entry of the latent shape [C, 64, 64] used when sampling
f = 8
n_samples = 5   # used as the batch size
n_rows = 0      # 0 -> falls back to the batch size
scale = 1.3     # passed as unconditional_guidance_scale
strength = 1.0

# Model files
config_path = './configs/stable-diffusion/v2_inference_ffhq.yaml'
ckpt = '/data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/vgg_faces/2024-10-02T02-12-51_v1/checkpoints/epoch=000037-ffhq.ckpt'
precision = 'autocast'

### Define vggfaces variables

In [171]:
# VGGFaces configuration.
# NOTE: every "Define ... variables" cell assigns the same global names; run
# exactly one of them before the "Load the model" cell.

# Input images
init_image_path = './inputs/same_domain_test/vggfaces/cls_4_2.jpg'
ref_image_path = './inputs/same_domain_test/vggfaces/cls_4_2.jpg'
ref_image_2_path = './inputs/same_domain_test/vggfaces/3.jpg'

# Defined before the sampling below so the random draw can be seeded with it.
seed = 3408

# Sampling images: sort the glob result (its order is filesystem dependent) and
# draw with a dedicated seeded RNG so the same images are picked on every run.
files = sorted(glob.glob("/data2/mhf/DXL/Lingxiao/datasets/vggfaces/*/*.jpg"))
sampled_imgs = random.Random(seed).sample(files, 2)

# Output handling
outdir = './outputs/vggfaces'
skip_grid = False
skip_save = True

# Sampler settings
ddim_steps = 50
ddim_eta = 0.0
n_iter = 1
C = 4           # first entry of the latent shape [C, 64, 64] used when sampling
f = 8
n_samples = 10  # used as the batch size
n_rows = 0      # 0 -> falls back to the batch size
scale = 1.3     # passed as unconditional_guidance_scale
strength = 0.95

# Model files
config_path = './configs/stable-diffusion/v2_inference_vggfaces.yaml'
ckpt = '/data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/vgg_faces/2024-10-02T02-12-51_v1/checkpoints/epoch=000020-vgg.ckpt'
precision = 'autocast'

### Define flowers variables

In [44]:
# Flowers configuration.
# NOTE: every "Define ... variables" cell assigns the same global names; run
# exactly one of them before the "Load the model" cell.

# Input images
init_image_path = './inputs/same_domain_test/flowers/cls_3_2.jpg'
ref_image_path = './inputs/same_domain_test/flowers/cls_3_2.jpg'
ref_image_2_path = './inputs/same_domain_test/flowers/cls_4_1.jpg'

# Defined before the sampling below so the random draw can be seeded with it.
seed = 3408

# Sampling images: sort the glob result (its order is filesystem dependent) and
# draw with a dedicated seeded RNG so the same images are picked on every run.
files = sorted(glob.glob("/data2/mhf/DXL/Lingxiao/datasets/flowers/dataset/train/*/*.jpg"))
sampled_imgs = random.Random(seed).sample(files, 10)

# Output handling
outdir = './outputs/flowers'
skip_grid = False
skip_save = True

# Sampler settings
ddim_steps = 50
ddim_eta = 0.0
n_iter = 1
C = 4           # first entry of the latent shape [C, 64, 64] used when sampling
f = 8
n_samples = 8   # used as the batch size
n_rows = 0      # 0 -> falls back to the batch size
scale = 1.3     # passed as unconditional_guidance_scale
strength = 1.0

# Model files
config_path = './configs/stable-diffusion/v2_inference_flowers.yaml'
ckpt = '/data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/flowers/2024-10-12T05-45-07_v1/checkpoints/epoch=000298.ckpt'
precision = 'autocast'

### Define nabirds variables

In [34]:
# NABirds configuration.
# NOTE: every "Define ... variables" cell assigns the same global names; run
# exactly one of them before the "Load the model" cell.

# Input images
init_image_path = './inputs/same_domain_test/nabirds/cls_1_1.jpg'
ref_image_path = './inputs/same_domain_test/nabirds/cls_1_1.jpg'
ref_image_2_path = './inputs/same_domain_test/nabirds/cls_2_1.jpg'

# Defined before the sampling below so the random draw can be seeded with it.
seed = 3408

# Sampling images: sort the glob result (its order is filesystem dependent) and
# draw with a dedicated seeded RNG so the same images are picked on every run.
files = sorted(glob.glob("/data2/mhf/DXL/Lingxiao/datasets/nabirds/images/*/*.jpg"))
sampled_imgs = random.Random(seed).sample(files, 0)

# Output handling
outdir = './outputs/nabirds'
skip_grid = False
skip_save = True

# Sampler settings
ddim_steps = 50
ddim_eta = 0.0
n_iter = 1
C = 4           # first entry of the latent shape [C, 64, 64] used when sampling
f = 8
n_samples = 5   # used as the batch size
n_rows = 0      # 0 -> falls back to the batch size
scale = 1.3     # passed as unconditional_guidance_scale
strength = 0.95

# Model files
config_path = './configs/stable-diffusion/v2_inference_nabirds.yaml'
ckpt = '/data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/nabirds/2024-10-17T10-15-57_v1/checkpoints/epoch=000083.ckpt'
precision = 'autocast'

### Load the model

In [4]:
# seed_everything(seed)

# Build the diffusion model from the selected dataset configuration.
config = OmegaConf.load(str(config_path))
model = load_model_from_config(config, str(ckpt))

# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print("Successfully loaded model!")

# DDIM sampler wrapped around the loaded model; used by the sampling cell.
sampler = DDIMSampler(model)


Loading model from /data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/animal_faces/2024-09-11T11-30-12_v1/checkpoints/epoch=000037.ckpt
Global Step: 82620
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...


Some weights of the model checkpoint at /data2/mhf/DXL/Lingxiao/Cache/huggingface/hub/models--openai--clip-vit-large-patch14 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.e

Use hyperbolic: True
Loading HAE from checkpoint: /data2/mhf/DXL/Lingxiao/Codes/hyperediting/exp_out/hyper_animalfaces_512_5_30_init_v1/checkpoints/iteration_29000.pt
Successfully loaded model!


### Load Delta Mapper for text-guided editing

In [5]:
from argparse import Namespace
sys.path.insert(0, '/data2/mhf/DXL/Lingxiao/Codes')
from DeltaHyperEditing.models.delta_hyp_clip import hae_clip

# Pick the Delta Mapper checkpoint for the dataset in use.
# NOTE(review): this should presumably match the dataset chosen in the
# "Define ... variables" cell - confirm before running.
# model_path = '/data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_animalfaces_512_5_30_v2/checkpoints/iteration_52500.pt'
model_path = '/data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_ffhq_512_5_30_v2/checkpoints/iteration_62500.pt'
# model_path = '/data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_flowers_512_5_5_v1/checkpoints/iteration_102500.pt'
# model_path = '/data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_nabirds_512_5_30_v1/checkpoints/iteration_162500.pt'
# model_path = '/data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_vggfaces_512_5_30_v2/checkpoints/iteration_62500.pt'

# The loaded checkpoint gets its own name: the configuration cells store the
# diffusion checkpoint *path* in the global `ckpt`, and reusing (then deleting)
# that name here made re-running the "Load the model" cell fail.
# NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
mapper_ckpt = torch.load(model_path, map_location='cpu')
opts = mapper_ckpt['opts']
del mapper_ckpt  # only the stored options are needed; release the rest
opts['checkpoint_path'] = model_path
opts['load_mapper'] = True
# instantiate the model with the checkpoint and its stored options
opts = Namespace(**opts)
net = hae_clip(opts)
net.eval()
net.to(device)
print('Model successfully loaded!')

Some weights of the model checkpoint at /data2/mhf/DXL/Lingxiao/Cache/huggingface/hub/models--openai--clip-vit-large-patch14 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.lay

Use hyperbolic: True
Loading HAE from checkpoint: /data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_ffhq_512_5_30_v2/checkpoints/iteration_62500.pt
Model successfully loaded!


In [77]:
# Output root for this run.
outpath = outdir
os.makedirs(outpath, exist_ok=True)

# One grid row per batch unless an explicit positive row count was configured.
batch_size = n_samples
if not n_rows > 0:
    n_rows = batch_size

# Individual samples go into a "samples" sub-folder.
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)

# Continue numbering after whatever already exists on disk; the "- 1" discounts
# the "samples" folder itself when counting entries in the output root.
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1

### Load images

In [78]:
# Load every input image and replicate single images across the batch.
# Naming caveat: `init_image` is the 512x512 copy and `init_image_resized` the
# 224x224 one, whereas `ref_image` is 224x224 and `ref_image_resized` 512x512.

# --- init image ---
assert os.path.isfile(init_image_path)
init_image = repeat(
    load_img(init_image_path, [512, 512]).to(device),
    '1 ... -> b ...', b=batch_size)
init_image_resized = repeat(
    load_img(init_image_path, [224, 224]).to(device),
    '1 ... -> b ...', b=batch_size)
# Encode the 512x512 copy with the first-stage model (move to latent space).
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))

# --- first reference image ---
assert os.path.isfile(ref_image_path)
ref_image = repeat(
    load_img(ref_image_path, [224, 224]).to(device),
    '1 ... -> b ...', b=batch_size)
ref_image_resized = repeat(
    load_img(ref_image_path, [512, 512]).to(device),
    '1 ... -> b ...', b=batch_size)
# Encode the 512x512 copy with the first-stage model (move to latent space).
init_ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref_image_resized))

# --- second reference image (224x224 only) ---
assert os.path.isfile(ref_image_2_path)
ref_image_2 = repeat(
    load_img(ref_image_2_path, [224, 224]).to(device),
    '1 ... -> b ...', b=batch_size)

# --- randomly sampled images: kept as a list of un-batched 224x224 tensors ---
sampled_images = []
for sampled_img in sampled_imgs:
    assert os.path.isfile(sampled_img)
    sampled_image = load_img(sampled_img, [224, 224]).to(device)
    sampled_images.append(sampled_image)

## Edit latent codes using text instructions

In [28]:
# Text pair driving the edit: `prompt` describes the source, `prompt_delta` the
# desired result. Both are passed to net.get_fake_delta_s_given_data below.
prompt = 'a face of a baby'
prompt_delta = 'a crying face of a baby'
# prompt_delta = 'a smiling face of a baby'

In [29]:
# Predict an edit offset for the reference image from the two text prompts.
fake_delta_s = net.get_fake_delta_s_given_data(ref_image, prompt, prompt_delta, feature_type='hyperbolic', device=device)
print(f"fake_delta_s shape: {fake_delta_s.shape}")
# Encode both reference images; the *_2 outputs are not used further in this cell.
logits, ocodes, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
_, ocodes_2, feature_2, feature_dist_2, feature_euc_2 = model.get_learned_conditioning(ref_image_2)
print(f"feature shape: {feature.shape}")
# Apply the offset to the third conditioning output (plain addition).
reconstruct_code = feature + fake_delta_s
# reconstruct_code = feature_dist + fake_delta_s
print(f"reconstruct_code shape: {reconstruct_code.shape}")
# Map the edited code to the conditioning used for sampling.
reconstruct_feature = get_condition_given_feature(model, reconstruct_code)
# Pair of conditionings to sample from: original first, edited second.
reconstruct_codes = [feature_euc, reconstruct_feature]
# reconstruct_codes = [feature_euc, reconstruct_code]

fake_delta_s shape: torch.Size([10, 1, 512])
feature shape: torch.Size([10, 1, 512])
reconstruct_code shape: torch.Size([10, 1, 512])


In [11]:
# Rescale the text-edited code to several radii and collect the conditionings.
# Requires `reconstruct_code` from the text-editing cell above.
_, hyp_codes = get_hyp_codes_given_feature(model, reconstruct_code)
print(hyp_codes.shape)
# Keep only the first code of the batch, with a leading batch axis of size 1.
hyp_code = hyp_codes[0].unsqueeze(0)
print(gmath.dist0(hyp_code, k=torch.tensor(-1.0)))
rescaled_codes = []
target_radii = [6.2126, 4, 2.5, 1, 0.5, 0]
for i in target_radii:
    hyp_code_rescaled = rescale(i, hyp_code)
    # Replicate the single code across the batch.
    hyp_code_rescaled = repeat(hyp_code_rescaled, '1 ... -> b ...', b=batch_size)
    # NOTE(review): this overwrites the global `feature_euc` that other cells
    # read; after the loop it holds the conditioning of the last radius (0).
    feature_euc = get_condition_given_hyp_codes(model, hyp_code_rescaled)
    rescaled_codes.append(feature_euc)

torch.Size([4, 512])
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)


In [63]:
# Linear blend between the original conditioning and the text-edited one.
# NOTE(review): reads the globals `feature_euc` and `reconstruct_feature`;
# several other cells overwrite `feature_euc`, so run the text-editing cell
# immediately before this one.
interpolated_codes = []
interval = [0, 0.3, 0.4, 0.5, 0.6, 0.7, 1]
for i in interval:
    # i = 0 gives the original conditioning, i = 1 the edited one.
    feature = (1-i) * feature_euc + i * reconstruct_feature
    interpolated_codes.append(feature)

## Manipulate latent codes in Hyperbolic space

### Moving latent codes from edge to center

In [113]:
# Hyperbolic code of the first reference image (batch axis of size 1 kept).
_, hyp_code = get_hyp_codes(model, ref_image[0].unsqueeze(0))
print(hyp_code.shape)
print(gmath.dist0(hyp_code, k=torch.tensor(-1.0)))
# this is used for generating figure of varying radius in our paper
rescaled_codes = []
# Radii from the boundary value (6.2126) down to the origin (0).
target_radii = [6.2126, 5, 4, 3, 2.5, 1, 0.5, 0]
for i in target_radii:
    hyp_code_rescaled = rescale(i, hyp_code)
    # Replicate the single code across the batch.
    hyp_code_rescaled = repeat(hyp_code_rescaled, '1 ... -> b ...', b=batch_size)
    # NOTE(review): overwrites the global `feature_euc` read by other cells.
    feature_euc = get_condition_given_hyp_codes(model, hyp_code_rescaled)
    rescaled_codes.append(feature_euc)
    

torch.Size([1, 512])
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)


### Interpolate latent codes in Hyperbolic space along the geodesic

In [48]:
# Hyperbolic codes of the two reference images (first batch element of each).
_, hyp_code = get_hyp_codes(model, ref_image[0].unsqueeze(0))
_, hyp_code_2 = get_hyp_codes(model, ref_image_2[0].unsqueeze(0))
print(hyp_code.shape)
print(gmath.dist0(hyp_code, k=torch.tensor(-1.0)))
# Interpolate between the two codes along the geodesic; the positions are
# denser around 0.5.
interpolated_codes = []
interval = [0, 0.3, 0.4, 0.45, 0.48, 0.49, 0.495, 0.50, 0.505, 0.51, 0.52, 0.55, 0.6, 0.7, 0.8, 1]
# Despite the singular names, each returned value is a list with one entry per
# position in `interval`.
feature_geo, feature_geo_current_target_radius, feature_geo_current_target_boundary, dist_to_start = geo_interpolate_fix_r(
    hyp_code, hyp_code_2, interval, target_radius=6.2126)
# Sample from the codes that were rescaled to the boundary radius.
for i in feature_geo_current_target_boundary:
    # Replicate the single code across the batch.
    feature = repeat(i, '1 ... -> b ...', b=batch_size)
    # NOTE(review): overwrites the globals `feature` and `feature_euc`.
    feature_euc = get_condition_given_hyp_codes(model, feature)
    interpolated_codes.append(feature_euc)

torch.Size([1, 512])
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)


### Feature fusion in hyperbolic space

In [57]:
# Fuse the two reference images at several radii and collect the conditionings.
fused_codes = []
fused_scale = [0, 0.45, 0.8, 1.0, 1.5, 2.0, 3, 5, 6.2126]
# The loop variable must not be called `scale`: that global holds the guidance
# scale used by the sampling cell and would be left at 6.2126 after this loop.
for fusion_alpha in fused_scale:
    fused_hyp_code = feature_fusion(model, ref_image[0].unsqueeze(0), ref_image_2[0].unsqueeze(0), fusion_alpha)
    # Replicate the single fused code across the batch.
    fused_hyp_code = repeat(fused_hyp_code, '1 ... -> b ...', b=batch_size)
    fused_feature_euc = get_condition_given_hyp_codes(model, fused_hyp_code)
    fused_codes.append(fused_feature_euc)

### Perturbation given an image embedding with certain radius

In [28]:
# Perturb the reference code towards the codes of the randomly sampled images.
perturbed_codes = []
logits, ocodes, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
# First entry: the unperturbed conditioning of the reference image.
perturbed_codes.append(feature_euc)
_, hyp_code = get_hyp_codes(model, ref_image[0].unsqueeze(0))
# NOTE(review): `sampled_images` is a Python list of tensors and is empty when
# the configuration cell samples 0 images - confirm the model accepts that.
_, perturb_codes = get_hyp_codes(model, sampled_images)
print(hyp_code.shape)
print(gmath.dist0(hyp_code, k=torch.tensor(-1.0)))
# Target hyperbolic distance(s) between the reference and the perturbed codes.
distances = [4.0]
# Passed through for interface compatibility; geo_perturbation does not read it.
num_samples = len(sampled_images)
feature_geo, feature_geo_normalized, feature_geo_current_target_boundaries, dist_to_start, perturbation_distances = geo_perturbation(
    hyp_code, distances, target_radius=6.2126, perturb_codes=perturb_codes, num_samples=num_samples)
# Sample from the perturbed codes that were rescaled to the boundary radius.
for i in feature_geo_current_target_boundaries:
    # Replicate the single code across the batch.
    feature = repeat(i, '1 ... -> b ...', b=batch_size)
    # NOTE(review): overwrites the globals `feature` and `feature_euc`.
    feature_euc = get_condition_given_hyp_codes(model, feature)
    perturbed_codes.append(feature_euc)

torch.Size([1, 512])
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)


## Manipulate latent codes in Euclidean space

### Interpolate latent codes in Euclidean space

In [41]:
# Conditionings of the two reference images.
logits, ocodes, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
# The second image's outputs use the *_2 names throughout (as in the
# text-editing cell) so `ocodes` keeps referring to the first image.
_, ocodes_2, feature_2, feature_dist_2, feature_euc_2 = model.get_learned_conditioning(ref_image_2)
print('feature_euc shape:', feature_euc.shape)
# Straight-line (Euclidean) blend between the two final conditionings.
interpolated_codes = []
interval = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
for i in interval:
    # i = 0 gives the first image's conditioning, i = 1 the second's.
    feature = (1-i) * feature_euc + i * feature_euc_2
    interpolated_codes.append(feature)

feature_euc shape: torch.Size([5, 1, 1024])


### Compare Images before and after hyperbolic space

In [8]:
# Collect the second (`ocodes`) and fifth (`feature_euc`) conditioning outputs
# of the reference image so samples generated from each can be compared.
logits, ocodes, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
print('feature_euc shape:', feature_euc.shape)
print('ocodes shape:', ocodes.shape)
compare_codes = []
compare_codes.append(ocodes)
compare_codes.append(feature_euc)

feature_euc shape: torch.Size([4, 1, 1024])
ocodes shape: torch.Size([4, 1, 1024])


  k = torch.tensor(k)


### Get the condition only

In [79]:
# Plain conditioning of the reference image, with no manipulation applied.
logits, ocodes, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
print('feature_euc shape:', feature_euc.shape)
print('ocodes shape:', ocodes.shape)
# Single-element list iterated by the sampling cell (`for c in code`).
code = []
code.append(feature_euc)

feature_euc shape: torch.Size([5, 1, 1024])
ocodes shape: torch.Size([5, 1, 1024])


## Sample Images

In [80]:
# Per-run overrides of values set in the configuration cell.
scale = 1.3           # passed as unconditional_guidance_scale when sampling
strength = 1.0        # 1.0 -> sample from pure noise; < 1.0 -> start from the encoded init image
skip_save = True      # True -> individual samples are not written to disk
change_noise = False  # False -> one starting noise shared by all codes; True -> redrawn per code

In [81]:
# Sample one batch of images for every conditioning in `code` and write all
# batches out together as a single grid image.
sampler.make_schedule(ddim_num_steps=ddim_steps,
                      ddim_eta=ddim_eta, verbose=False)
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(strength * ddim_steps)
print(f"target t_enc is {t_enc} steps")

precision_scope = autocast if precision == "autocast" else nullcontext
n_rows = n_rows if n_rows > 0 else batch_size
uc = get_unconditional_embedding(model, scale, n_samples, device, ref_image)


def _draw_start_latent():
    """Return the latent the sampler starts from.

    For strength < 1.0 this is `init_latent` passed through
    `sampler.stochastic_encode` at step `t_enc`; otherwise it is fresh
    Gaussian noise of shape [n_samples, 4, 64, 64].
    """
    if strength < 1.0:
        return sampler.stochastic_encode(
            init_latent, torch.tensor([t_enc] * batch_size).to(device))
    return torch.randn([n_samples, 4, 64, 64], device=device)


with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            tic = time.time()
            all_samples = list()
            shape = [C, 64, 64]
            if not change_noise:
                # One starting latent shared by every conditioning code.
                z_enc = _draw_start_latent()

            # `code` can be swapped for another list of conditionings
            # (rescaled_codes, interpolated_codes, reconstruct_codes,
            # perturbed_codes, compare_codes, fused_codes).
            for c in code:
                if change_noise:
                    # Fresh starting latent for each conditioning code.
                    z_enc = _draw_start_latent()

                if strength == 1.0:
                    # Full sampling run starting from z_enc.
                    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                     conditioning=c,
                                                     batch_size=n_samples,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=uc,
                                                     eta=ddim_eta,
                                                     x_T=z_enc)
                else:
                    # Decode the partially noised init latent for t_enc steps.
                    samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
                                                  unconditional_conditioning=uc)

                x_samples = model.decode_first_stage(samples_ddim)
                # Map decoder output from [-1, 1] to [0, 1].
                x_samples = torch.clamp(
                    (x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                if not skip_save:
                    for x_sample in x_samples:
                        x_sample = 255. * \
                            rearrange(x_sample.cpu().numpy(),
                                      'c h w -> h w c')
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png"))
                        base_count += 1
                all_samples.append(x_samples)

            # Build and save the grid once, after all codes were sampled, and
            # only when a grid was requested. Previously the save ran
            # unconditionally and hit an undefined `grid` when skip_grid was True.
            if not skip_grid and all_samples:
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * \
                    rearrange(grid, 'c h w -> h w c').cpu().numpy()
                Image.fromarray(grid.astype(np.uint8)).save(
                    os.path.join(outpath, f'grid-{grid_count:04}.png'))
                print(f"grid saved to {outpath}/grid-{grid_count:04}.png")
                grid_count += 1
                del grid

        toc = time.time()

print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
        f" \nEnjoy.")

target t_enc is 50 steps
Data shape for DDIM sampling is (5, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:16<00:00,  3.08it/s]


grid saved to ./outputs/animalfaces/grid-0133.png
Your samples are ready and waiting for you here: 
./outputs/animalfaces 
 
Enjoy.


## Mean of the variation

In [133]:
# For every conditioning in `rescaled_codes`, sample a batch of images, average
# the batch into one mean image and collect the mean images into a grid.
sampler.make_schedule(ddim_num_steps=ddim_steps,
                      ddim_eta=ddim_eta, verbose=False)
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(strength * ddim_steps)
print(f"target t_enc is {t_enc} steps")

precision_scope = autocast if precision == "autocast" else nullcontext
n_rows = n_rows if n_rows > 0 else batch_size
uc = get_unconditional_embedding(model, scale, n_samples, device, ref_image)

with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            tic = time.time()
            all_samples = list()
            shape = [C, 64, 64]
            # Starting latent, shared by all codes: the stochastically encoded
            # init latent for strength < 1.0, otherwise Gaussian noise.
            if strength < 1.0:
                z_enc = sampler.stochastic_encode(
                    init_latent, torch.tensor([t_enc]*batch_size).to(device))
            else:
                z_enc = torch.randn([n_samples, 4, 64, 64], device=device)

            for c in rescaled_codes:
                # NOTE(review): this cell always runs the full `sampler.sample`
                # path with z_enc as x_T, even when strength < 1.0; the
                # `sampler.decode(z_enc, c, t_enc, ...)` variant is not used
                # here. Confirm this is intended.
                samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                 conditioning=c,
                                                 batch_size=n_samples,
                                                 shape=shape,
                                                 verbose=False,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=uc,
                                                 eta=ddim_eta,
                                                 x_T=z_enc)

                x_samples = model.decode_first_stage(samples_ddim)
                # Map decoder output from [-1, 1] to [0, 1].
                x_samples = torch.clamp(
                    (x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                print(x_samples.shape)
                # Pixel-wise mean over the batch, kept as a batch of one image.
                x_samples_mean = x_samples.mean(0).unsqueeze(0)
                if not skip_save:
                    for x_sample in x_samples:
                        x_sample = 255. * \
                            rearrange(x_sample.cpu().numpy(),
                                      'c h w -> h w c')
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png"))
                        base_count += 1
                all_samples.append(x_samples_mean)

            # Build and save the grid once, after the loop, and only when a
            # grid was requested. Previously the save ran unconditionally and
            # hit an undefined `grid` when skip_grid was True.
            if not skip_grid and all_samples:
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * \
                    rearrange(grid, 'c h w -> h w c').cpu().numpy()
                Image.fromarray(grid.astype(np.uint8)).save(
                    os.path.join(outpath, f'grid-{grid_count:04}.png'))
                grid_count += 1
                del grid

        toc = time.time()

print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
      f" \nEnjoy.")

target t_enc is 30 steps
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.51it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.50it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.49it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.49it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.49it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.49it/s]


torch.Size([20, 3, 512, 512])
Your samples are ready and waiting for you here: 
./outputs/ffhq 
 
Enjoy.


## Evaluate Benchmark

### Define necessary functions

In [16]:
import os
import shutil
from tqdm import tqdm
import random
import cv2

# File-name suffixes that `is_image_file` accepts as images.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    """Return True when `filename` ends with one of IMG_EXTENSIONS.

    The comparison ignores case, so suffixes such as ``.TIFF`` or ``.Jpg``
    (which have no exact entry in the list) are accepted as well.

    Args:
        filename: file name or path to check.

    Returns:
        bool: whether the name carries a known image suffix.
    """
    lowered = filename.lower()
    return any(lowered.endswith(extension.lower()) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    """Recursively collect the paths of all image files below `dir`.

    Directories are visited in sorted order and the file names inside each
    directory are sorted too, so the result is deterministic across runs and
    file systems (os.walk yields file names in arbitrary order).

    Args:
        dir: root directory to scan.

    Returns:
        list[str]: paths (root joined with file name) of every file for which
        `is_image_file` is true.

    Raises:
        AssertionError: if `dir` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images


def find_classes(directory):
    """List the class sub-directories of `directory` and index them.

    Returns:
        (classes, class_to_idx): the sub-directory names in sorted order, and
        a dict mapping each name to its position in that list.

    Raises:
        FileNotFoundError: if `directory` contains no sub-directory.
    """
    classes = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    classes.sort()
    if len(classes) == 0:
        raise FileNotFoundError(
            f"Couldn't find any class folder in {directory}.")

    class_to_idx = dict(zip(classes, range(len(classes))))
    return classes, class_to_idx


def copy_dataset(path, new_path):
    """Flatten every image found under `path` into the single folder `new_path`.

    Each image is read with cv2 and written again as
    ``<parent folder>_<file name>``; files cv2 cannot read are skipped.
    The number of images found is printed first.
    """
    images = make_dataset(path)
    print(len(images))
    for src in tqdm(images):
        pixels = cv2.imread(src)
        parts = src.split('/')
        flat_name = parts[-2] + '_' + parts[-1]
        # Create the target folder lazily, on the first image encountered.
        if not os.path.exists(new_path):
            os.makedirs(new_path)
        if pixels is None:
            continue
        cv2.imwrite(new_path + '/' + flat_name, pixels)
            

def copy_dataset_new(path, new_path):
    """Copy a class-structured image dataset from `path` to `new_path`.

    The class sub-folders of `path` are recreated under `new_path`; each image
    is read with cv2 and written again under its original file name. Files
    cv2 cannot read are skipped.
    """
    class_names, _ = find_classes(path)
    for class_name in tqdm(class_names):
        src_dir = os.path.join(path, class_name)
        dst_dir = os.path.join(new_path, class_name)
        for src in make_dataset(src_dir):
            pixels = cv2.imread(src)
            # Create the class folder lazily, on the first image encountered.
            if not os.path.exists(dst_dir):
                os.makedirs(dst_dir)
            if pixels is None:
                continue
            cv2.imwrite(dst_dir + '/' + src.rsplit('/', 1)[-1], pixels)
            

def sample_dataset(path, new_path, size):
    """Write a random sample of `size` images from `path` into `new_path`.

    Images are drawn without replacement with `random.sample`, read with cv2
    and written again under their original file name; files cv2 cannot read
    are skipped.
    """
    chosen = random.sample(make_dataset(path), k=size)

    for src in tqdm(chosen):
        pixels = cv2.imread(src)
        # Create the target folder lazily, on the first sampled image.
        if not os.path.exists(new_path):
            os.makedirs(new_path)
        if pixels is None:
            continue
        cv2.imwrite(new_path + '/' + src.rsplit('/', 1)[-1], pixels)
            

def select_subset(path, new_path, num_classes, num_samples_per_class=None):
    """Copy a random subset of class folders from `path` into `new_path`.

    `num_classes` class folders are drawn at random (without replacement).
    For each drawn class either the whole folder is copied, or, when
    `num_samples_per_class` is given, that many randomly sampled images.

    Args:
        path: dataset root containing one sub-folder per class.
        new_path: destination root; one sub-folder per selected class is created.
        num_classes: number of classes to select.
        num_samples_per_class: images to sample per class, or None to copy the
            whole class folder with `shutil.copytree`.

    Raises:
        ValueError: from `random.sample` if `num_classes` exceeds the number
            of available classes.
    """
    classes, _ = find_classes(path)
    test_classes = random.sample(classes, k=num_classes)
    train_classes = set(classes) - set(test_classes)
    print(len(test_classes))
    print(len(train_classes))
    # Iterate over the *sampled* classes. The previous loop indexed
    # `classes[i]` and therefore always copied the first `num_classes`
    # folders, ignoring the random selection above.
    for done, test_class in enumerate(test_classes, start=1):
        source_dir = os.path.join(path, str(test_class))
        destination_dir = os.path.join(new_path, str(test_class))
        if num_samples_per_class is not None:
            sample_dataset(source_dir, destination_dir, num_samples_per_class)
        else:
            # copytree raises if destination_dir already exists.
            shutil.copytree(source_dir, destination_dir)
        # 1-based progress so the last message reads "finish N/N".
        print('finish {}/{}'.format(done, num_classes))
    print("Finish creating new dataset!")

In [8]:
# Build a 30-class evaluation subset of the animals dataset. With
# num_samples_per_class=None each selected class folder is copied in full.
dataset_path = '/data2/mhf/DXL/Lingxiao/datasets/animals'
select_subset(dataset_path, '/data2/mhf/DXL/Lingxiao/datasets/animals_eva_random', 30, None)

30
119
finish 0/30
finish 1/30
finish 2/30
finish 3/30
finish 4/30
finish 5/30
finish 6/30
finish 7/30
finish 8/30
finish 9/30
finish 10/30
finish 11/30
finish 12/30
finish 13/30
finish 14/30
finish 15/30
finish 16/30
finish 17/30
finish 18/30
finish 19/30
finish 20/30
finish 21/30
finish 22/30
finish 23/30
finish 24/30
finish 25/30
finish 26/30
finish 27/30
finish 28/30
finish 29/30
Finish creating new dataset!


In [32]:
# Sanity check: count the image files found under a single class folder.
test123 = '/data2/mhf/DXL/Lingxiao/datasets/animals/n02085620'
images = make_dataset(test123)
print(len(images))

913


In [21]:
# Draw 30 random images of nabirds class 0299 into the 30-image evaluation set.
sample_dataset(
    '/data2/mhf/DXL/Lingxiao/datasets/nabirds_eva/test/0299', '/data2/mhf/DXL/Lingxiao/datasets/nabirds_eva_30/0299', 30)

100%|██████████| 30/30 [00:00<00:00, 33.54it/s]


In [8]:
# Settings for the LPIPS sampling cell below.
# image_folder: dataset root; each sub-folder is treated as one class.
image_folder = '/data2/mhf/DXL/Lingxiao/datasets/flowers/dataset/train_eva'
# outdir: generated images are written to <outdir>/<class folder>/<file name>.
outdir = './outputs/flowers_eva_genrated_test_cfg_1.7_strength_0.9'
# skip_save: when True no generated image is written to disk.
skip_save = False
# ddim_steps / ddim_eta: passed to sampler.make_schedule and the sampler calls.
ddim_steps = 50
ddim_eta = 0.0
# n_iter: not referenced by the sampling cell below.
n_iter = 1
# C: channel count used for the sampler `shape` ([C, 64, 64]).
C = 4
# f: not referenced by the sampling cell below.
f = 8
# n_samples: batch size (the sampling cell sets batch_size = n_samples).
n_samples = 16
# total_samples: not referenced by the sampling cell below.
total_samples = 128
# n_rows: grid row length; 0 makes the sampling cell fall back to batch_size.
n_rows = 0
# scale: forwarded to the sampler as `unconditional_guidance_scale`.
scale = 1.7
# strength: must lie in [0, 1]; t_enc = int(strength * ddim_steps).
strength = 0.9
# precision: 'autocast' runs sampling under torch.autocast, anything else under nullcontext.
precision = 'autocast'

### Sample images for calculating LPIPS

In [10]:
# Generate one output image per input image, class by class, and save each
# result as <outdir>/<class folder>/<original file name>.
sampler.make_schedule(ddim_num_steps=ddim_steps,
                      ddim_eta=ddim_eta, verbose=False)
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
precision_scope = autocast if precision == "autocast" else nullcontext
batch_size = n_samples
n_rows = n_rows if n_rows > 0 else batch_size
classes, class_to_idx = find_classes(image_folder)
count = 0
# Loop-invariant settings, computed once.
t_enc = int(strength * ddim_steps)
shape = [C, 64, 64]
for class_name in classes:
    class_path = os.path.join(image_folder, class_name)
    images = make_dataset(class_path)
    print(len(images))
    # NOTE(review): only full batches are processed, so the trailing
    # len(images) % batch_size images of each class get no generated
    # counterpart. Confirm this is acceptable for the evaluation.
    for i in range(len(images) // batch_size):
        batch_paths = images[i * batch_size:(i + 1) * batch_size]
        input_images = []
        input_images_resized = []
        for image_path in batch_paths:
            assert os.path.isfile(image_path)
            # 512x512 copy for the first-stage encoder, 224x224 copy for the
            # conditioning model.
            input_images.append(load_img(image_path, [512, 512]).to(device))
            input_images_resized.append(load_img(image_path, [224, 224]).to(device))

        # Stacking adds a new batch axis in front; squeeze(1) then drops
        # axis 1 again when it has size 1.
        init_image = torch.stack(input_images).squeeze(1)
        ref_image = torch.stack(input_images_resized).squeeze(1)
        # NOTE(review): the three calls below run outside torch.no_grad().
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
        uc = get_unconditional_embedding(model, scale, n_samples, device, ref_image)

        logits, ocodes, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
        code = [feature_euc]
        with torch.no_grad():
            with precision_scope("cuda"):
                with model.ema_scope():
                    tic = time.time()
                    all_samples = list()
                    # Starting latent: encoded init latent for strength < 1.0,
                    # otherwise Gaussian noise.
                    if strength < 1.0:
                        z_enc = sampler.stochastic_encode(
                            init_latent, torch.tensor([t_enc]*batch_size).to(device))
                    else:
                        z_enc = torch.randn([n_samples, 4, 64, 64], device=device)

                    # `code` can be swapped for another list of conditionings
                    # (rescaled_codes, interpolated_codes, reconstruct_codes,
                    # perturbed_codes, compare_codes, fused_codes).
                    for c in code:
                        if strength == 1.0:
                            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                             conditioning=c,
                                                             batch_size=n_samples,
                                                             shape=shape,
                                                             verbose=False,
                                                             unconditional_guidance_scale=scale,
                                                             unconditional_conditioning=uc,
                                                             eta=ddim_eta,
                                                             x_T=z_enc)
                        else:
                            samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
                                                          unconditional_conditioning=uc)

                        x_samples = model.decode_first_stage(samples_ddim)
                        # Map decoder output from [-1, 1] to [0, 1].
                        x_samples = torch.clamp(
                            (x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                        if not skip_save:
                            # Sample k of the batch is saved under the folder
                            # and file name of input image k.
                            for source_path, x_sample in zip(batch_paths, x_samples):
                                x_sample = 255. * \
                                    rearrange(x_sample.cpu().numpy(),
                                              'c h w -> h w c')
                                output_dir = os.path.join(
                                    outdir, os.path.basename(os.path.dirname(source_path)))
                                os.makedirs(output_dir, exist_ok=True)
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(output_dir, os.path.basename(source_path)))

    toc = time.time()
    count += 1
    print('finish {}/{}'.format(count, len(classes)))

print(f"Your samples are ready and waiting for you \n"
        f" \nEnjoy.")

27
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 1/19
38
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


finish 2/19
35
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.00it/s]


finish 3/19
49
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 4/19
36
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 5/19
68
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.00it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.00it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 6/19
73
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.01s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:42<00:00,  1.07it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:41<00:00,  1.09it/s]


finish 7/19
38
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:41<00:00,  1.08it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.01s/it]


finish 8/19
44
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


finish 9/19
38
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.00it/s]


finish 10/19
36
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 11/19
60
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 12/19
65
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.00it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.01s/it]


finish 13/19
38
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.00it/s]


finish 14/19
49
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 15/19
46
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 16/19
34
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 17/19
47
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.00it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 18/19
72
Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:45<00:00,  1.00s/it]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.00it/s]


Running DDIM Sampling with 45 timesteps


Decoding image: 100%|██████████| 45/45 [00:44<00:00,  1.01it/s]


finish 19/19
Your samples are ready and waiting for you 
 
Enjoy.


### Sample images for calculating FID

In [7]:
# Settings for the FID sampling cell below.
# image_folder: dataset root; each sub-folder is treated as one class.
image_folder = '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/train'
# outdir: generated images are written to <outdir>/<class folder>/<index>_<file name>.
outdir = './outputs/animals_eva_genrated_cfg_1.3_strength_0.95_r_5.6_30_samples_30_each'
# files: pool of images from which the perturbation references are drawn.
files = glob.glob("/data2/mhf/DXL/Lingxiao/datasets/animals/n02085620/*.jpg")
# skip_save: when True no generated image is written to disk.
skip_save = False
# ddim_steps / ddim_eta: passed to sampler.make_schedule and the sampler calls.
ddim_steps = 50
ddim_eta = 0.0
# n_iter: not referenced by the sampling cell below.
n_iter = 1
# C: channel count used for the sampler `shape` ([C, 64, 64]).
C = 4
# f: not referenced by the sampling cell below.
f = 8
# n_samples: batch size (the sampling cell sets batch_size = n_samples).
n_samples = 5
# n_perturb_samples: number of images drawn from `files` per source image.
n_perturb_samples = 5
# n_selected: maximum number of source images used per class.
n_selected = 30
# n_rows: grid row length; 0 makes the sampling cell fall back to batch_size.
n_rows = 0
# scale: forwarded to the sampler as `unconditional_guidance_scale`.
scale = 1.3
# strength: must lie in [0, 1]; t_enc = int(strength * ddim_steps).
strength = 0.95
# precision: 'autocast' runs sampling under torch.autocast, anything else under nullcontext.
precision = 'autocast'

In [32]:
# Generate images for FID evaluation. For each selected source image the cell
# decodes one batch conditioned on the image's own code, plus one batch for
# every perturbed code derived from randomly drawn reference images.
sampler.make_schedule(ddim_num_steps=ddim_steps,
                      ddim_eta=ddim_eta, verbose=False)
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
precision_scope = autocast if precision == "autocast" else nullcontext
batch_size = n_samples
n_rows = n_rows if n_rows > 0 else batch_size
classes, class_to_idx = find_classes(image_folder)
count = 0
for class_name in classes:
    class_path = os.path.join(image_folder, class_name)
    images = make_dataset(class_path)
    # Use at most n_selected randomly chosen source images per class.
    if n_selected < len(images):
        images = random.sample(images, n_selected)
    for image_path in images:
        assert os.path.isfile(image_path)
        # 512x512 copy for the first-stage encoder, 224x224 copy for the
        # conditioning model; both are repeated batch_size times.
        init_image = load_img(image_path, [512, 512]).to(device)
        ref_image = load_img(image_path, [224, 224]).to(device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        ref_image = repeat(ref_image, '1 ... -> b ...', b=batch_size)
        # sampling image
        sampled_imgs = random.sample(files, n_perturb_samples)
        # perturbed_codes[0] is the unperturbed conditioning of the source
        # image; perturbed conditionings are appended after it.
        perturbed_codes = []
        logits, ocodes, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
        perturbed_codes.append(feature_euc)
        # load sampled images
        if n_perturb_samples > 0:
            sampled_images = []
            for sampled_img in sampled_imgs:
                assert os.path.isfile(sampled_img)
                sampled_image = load_img(sampled_img, [224, 224]).to(device)
                sampled_images.append(sampled_image)

            # NOTE(review): get_hyp_codes / geo_perturbation /
            # get_condition_given_hyp_codes are defined outside this cell;
            # the meaning of `distances` and `target_radius` should be
            # confirmed against their definitions.
            _, hyp_code = get_hyp_codes(model, ref_image[0].unsqueeze(0))
            _, perturb_codes = get_hyp_codes(model, sampled_images)
            distances = [5.6]
            num_samples = len(sampled_images)
            feature_geo, feature_geo_normalized, feature_geo_current_target_boundaries, dist_to_start, perturbation_distances = geo_perturbation(
                hyp_code, distances, target_radius=6.2126, perturb_codes=perturb_codes, num_samples=num_samples)
            # One extra conditioning per entry of
            # feature_geo_current_target_boundaries, repeated to batch size.
            for i in feature_geo_current_target_boundaries:
                feature = repeat(i, '1 ... -> b ...', b=batch_size)
                feature_euc = get_condition_given_hyp_codes(model, feature)
                perturbed_codes.append(feature_euc)
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
        t_enc = int(strength * ddim_steps)

        uc = get_unconditional_embedding(model, scale, n_samples, device, ref_image)

        with torch.no_grad():
            with precision_scope("cuda"):
                with model.ema_scope():
                    tic = time.time()
                    all_samples = list()
                    shape = [C, 64, 64]
                    # encode (scaled latent)
                    if strength < 1.0:
                        z_enc = sampler.stochastic_encode(
                            init_latent, torch.tensor([t_enc]*batch_size).to(device))
                    # print(f"z_enc shape: {z_enc.shape}")
                    else:
                        z_enc = torch.randn([n_samples, 4, 64, 64], device=device)
                    # decode it
                    # `i` is reused here as a running index over every image
                    # saved for this source image; it prefixes the output
                    # file name so the batches do not overwrite each other.
                    i = 0
                    # for c in rescaled_codes:
                    # for c in interpolated_codes:
                    # for c in reconstruct_codes:
                    for c in perturbed_codes:
                    # for c in compare_codes:
                    # for c in fused_codes:
                    # for c in code:
                        if strength == 1.0:
                            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                            conditioning=c,
                                                            batch_size=n_samples,
                                                            shape=shape,
                                                            verbose=False,
                                                            unconditional_guidance_scale=scale,
                                                            unconditional_conditioning=uc,
                                                            eta=ddim_eta,
                                                            x_T=z_enc)
                        else:
                            samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=uc)

                        x_samples = model.decode_first_stage(samples_ddim)
                        # Map decoder output from [-1, 1] to [0, 1].
                        x_samples = torch.clamp(
                            (x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                        if not skip_save:
                            for x_sample in x_samples:
                                x_sample = 255. * \
                                    rearrange(x_sample.cpu().numpy(),
                                            'c h w -> h w c')
                                # Save under the source image's class folder.
                                output_dir = os.path.join(outdir, image_path.split('/')[-2])
                                if not os.path.exists(output_dir):
                                    os.makedirs(output_dir)
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(output_dir, str(i) + '_' + image_path.split('/')[-1]))
                                i += 1

            toc = time.time()
            # `count` is never reset per class, while the total shown is the
            # number of source images of the current class.
            count += 1
            print('finish {}/{}'.format(count, len(images)))
    # NOTE(review): this `break` ends the class loop after the first class,
    # so only one class is sampled. Confirm whether it is a debugging leftover.
    break
print(f"Your samples are ready and waiting for you \n"
      f" \nEnjoy.")

start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


finish 1/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.09it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:13<00:00,  3.38it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:14<00:00,  3.18it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:14<00:00,  3.30it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:13<00:00,  3.40it/s]


finish 2/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:14<00:00,  3.20it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:14<00:00,  3.32it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:14<00:00,  3.29it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.99it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.97it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.98it/s]


finish 3/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.98it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.08it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


finish 4/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 5/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.07it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 6/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 7/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.98it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.99it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 8/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


finish 9/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.08it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 10/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 11/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.09it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


finish 12/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.07it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.07it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 13/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.08it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 14/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 15/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 16/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.07it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.97it/s]


finish 17/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.99it/s]


finish 18/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.08it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 19/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.98it/s]


finish 20/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 21/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.97it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 22/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 23/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.99it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.98it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 24/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


finish 25/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


finish 26/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.99it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 27/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.99it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.98it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 28/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  2.98it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 29/30
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.02it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.06it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.04it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.01it/s]


Running DDIM Sampling with 47 timesteps


Decoding image: 100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


finish 30/30
Your samples are ready and waiting for you 
 
Enjoy.


In [20]:
# Generated flower samples and the sibling `_fid` folder that receives the copy.
# NOTE(review): `copy_dataset` is defined outside this cell; presumably it
# rewrites the class-structured folder into the layout used for FID - confirm.
flowers_generated_dir = '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/flowers_eva_random_genrated_cfg_1.7_strength_0.95_r_5.5_10_samples_12_each'
flowers_generated_fid_dir = '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/flowers_eva_random_genrated_cfg_1.7_strength_0.95_r_5.5_10_samples_12_each_fid'
copy_dataset(flowers_generated_dir, flowers_generated_fid_dir)

2040


  0%|          | 0/2040 [00:00<?, ?it/s]

100%|██████████| 2040/2040 [00:10<00:00, 191.55it/s]


In [4]:
# Animals test split and the sibling `_fid` folder that receives the copy.
# NOTE(review): `copy_dataset` is defined outside this cell; presumably it
# rewrites the class-structured folder into the layout used for FID - confirm.
animals_test_dir = '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/test_30/test'
animals_test_fid_dir = '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/test_30/test_fid'
copy_dataset(animals_test_dir, animals_test_fid_dir)

1050


100%|██████████| 1050/1050 [00:01<00:00, 915.03it/s]


In [2]:
# Dataset types whose files under `path_1` carry two extra leading characters
# in front of the file name used under `path_2`.
_PREFIXED_DATASET_TYPES = ('nabirds', 'animals', 'vggfaces')
_SUPPORTED_DATASET_TYPES = ('flowers',) + _PREFIXED_DATASET_TYPES


def _image_key(image_path, dataset_type, strip_prefix):
    """Return the file-name key used to match an image across the two roots.

    Args:
        image_path: Path of an image file.
        dataset_type: One of ``_SUPPORTED_DATASET_TYPES``.
        strip_prefix: When True, drop the first two characters of the file
            name for the prefixed dataset types. Ignored for ``'flowers'``.

    Returns:
        The key string; for ``'flowers'`` the last two ``'_'``-separated
        fields of the file name joined by ``'_'``.

    Raises:
        ValueError: If ``dataset_type`` is not supported.
        IndexError: For ``'flowers'`` when the file name contains no ``'_'``.
    """
    file_name = os.path.basename(image_path)
    if dataset_type == 'flowers':
        fields = file_name.split('_')
        return fields[-2] + '_' + fields[-1]
    if dataset_type in _PREFIXED_DATASET_TYPES:
        # presumably the two characters are a one-digit sample index plus a
        # separator added at generation time - TODO confirm
        return file_name[2:] if strip_prefix else file_name
    raise ValueError(
        f"Unsupported dataset_type {dataset_type!r}; "
        f"expected one of {_SUPPORTED_DATASET_TYPES}")


def remove_redundant_images(path_1, path_2, dataset_type):
    """Copy the images of ``path_2`` that have no counterpart in ``path_1``.

    Despite the name, nothing is deleted: for every class folder found under
    ``path_1``, each image of the same class under ``path_2`` whose key is not
    among the keys of the ``path_1`` images is re-written with OpenCV into
    ``path_2 + '_clear'/<class_name>/``.

    Args:
        path_1: Root folder whose class sub-folders are listed with
            ``find_classes`` (in this notebook: the generated samples).
        path_2: Root folder with the same class sub-folders (in this notebook:
            the real test images).
        dataset_type: ``'flowers'``, ``'nabirds'``, ``'animals'`` or
            ``'vggfaces'``; selects how file names are matched.

    Raises:
        ValueError: If ``dataset_type`` is not supported. Previously this led
            to an unbound or stale ``image_name`` inside the loops.
    """
    # Fail fast instead of part-way through the copy.
    if dataset_type not in _SUPPORTED_DATASET_TYPES:
        raise ValueError(
            f"Unsupported dataset_type {dataset_type!r}; "
            f"expected one of {_SUPPORTED_DATASET_TYPES}")

    classes, _ = find_classes(path_1)
    for class_name in tqdm(classes):
        # assumes make_dataset returns an iterable of image file paths - TODO confirm
        images = make_dataset(os.path.join(path_1, class_name))
        images_2 = make_dataset(os.path.join(path_2, class_name))

        # A set keeps the membership test below O(1) per image.
        used_names = {
            _image_key(image, dataset_type, strip_prefix=True)
            for image in images
        }

        new_path = os.path.join(path_2 + '_clear', class_name)
        for image in images_2:
            image_name = _image_key(image, dataset_type, strip_prefix=False)
            if image_name in used_names:
                continue

            img = cv2.imread(image)
            # Created lazily so classes without any unmatched image get no folder.
            os.makedirs(new_path, exist_ok=True)
            # cv2.imread returns None for files it cannot decode; skip those.
            if img is not None:
                cv2.imwrite(os.path.join(new_path, image_name), img)
                    
            

In [5]:
# Write the VGGFaces test images that have no matching generated sample
# into the sibling `test_clear` folder.
remove_redundant_images(
    path_1='/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/vggfaces_eva_genrated_cfg_1.3_strength_0.95_r_5.4_30_samples_6_each_128',
    path_2='/data2/mhf/DXL/Lingxiao/datasets/vggfaces_eva/test',
    dataset_type='vggfaces',
)

100%|██████████| 572/572 [00:16<00:00, 35.64it/s]
