In [1]:
import glob
import numpy as np
import re
import torch
import PIL
import argparse
import os
import sys
sys.path.insert(0, '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion')
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from omegaconf import OmegaConf
from PIL import Image
from itertools import islice
from einops import rearrange, repeat
from torch import autocast
from contextlib import nullcontext
from pytorch_lightning import seed_everything
import cv2
import time
import random
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim_org import DDIMSampler
from ldm.models.diffusion.dpm_solver_org import DPMSolverSampler

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  from .autonotebook import tqdm as notebook_tqdm


## Define neccessary functions

In [2]:
from notebook_utils.pmath import *
# we utilize geoopt package for hyperbolic calculation
import geoopt.manifolds.stereographic.math as gmath
def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


def load_img(path, size=[256, 256]):
    image = cv2.imread(path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    w, h = image.shape[:2]
    print(f"loaded input image of size ({w}, {h}) from {path}")
    # resize to integer multiple of 32
    # w, h = map(lambda x: x - x % 32, (w, h))
    w, h = size
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
    image = np.array(image).astype(np.uint8)

    image = (image / 127.5 - 1.0).astype(np.float32)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image


def load_model_and_get_prompt_embedding(model, scale, n_samples, device, prompts, inv=False):

    if inv:
        inv_emb = model.get_learned_conditioning(prompts, inv)
        c = uc = inv_emb
    else:
        inv_emb = None

    if scale != 1.0:
        uc = model.get_learned_conditioning(
            n_samples * [torch.zeros((1, 3, 224, 224))])
    else:
        uc = None
    c = model.get_learned_conditioning(prompts)

    return c, uc, inv_emb

def get_hyp_codes(model, prompts):
    # return latent codes in the hyperbolic space for the given prompts
    logits, feature, feature_dist, feature_euc = model.get_learned_conditioning(prompts)
    return feature, feature_dist

def get_hyp_codes_given_feature(model, feature):
    # return latent codes in the hyperbolic space for the given latent codes in CLIP space
    logits, feature, feature_dist, feature_euc = model.get_learned_conditioning(feature, input_feature=False, input_code=True)
    return feature, feature_dist

def get_condition_given_feature(model, feature):
    # return latent codes in the CLIP space for the given latent codes in hyperbolic space
    logits, feature, feature_dist, feature_euc = model.get_learned_conditioning(feature, input_feature=False, input_code=True)
    return feature_euc

def get_condition_given_hyp_codes(model, hyp_codes):
    # return latent codes in the CLIP space for the given latent codes in hyperbolic space
    logits, feature, feature_dist, feature_euc = model.get_learned_conditioning(hyp_codes, input_feature=True)
    return feature_euc


# rescale function
def rescale(target_radius, x):
    r_change = target_radius / \
        dist0(gmath.mobius_scalar_mul(
            r=torch.tensor(1), x=x, k=torch.tensor(-1.0)))
    return gmath.mobius_scalar_mul(r=r_change, x=x, k=torch.tensor(-1.0))


# function for generating images with fixed radius (also contains raw geodesic images of 'shorten' images, and stretched images to boundary)
def geo_interpolate_fix_r(x, y, interval, target_radius, save_codes=False):
    feature_geo = []
    feature_geo_normalized = []
    dist_to_start = []
    feature_geo_current_target_boundaries = []
    target_radius_ratio = torch.tensor(target_radius/6.2126)
    geodesic_start_short = gmath.mobius_scalar_mul(
        r=target_radius_ratio, x=x, k=torch.tensor(-1.0))
    geodesic_end_short = gmath.mobius_scalar_mul(
        r=target_radius_ratio, x=y, k=torch.tensor(-1.0))
    index = 0
    for i in interval:
        # this is raw image on geodesic, instead of fixed radius
        feature_geo_current = gmath.geodesic(t=torch.tensor(
            i), x=geodesic_start_short, y=geodesic_end_short, k=torch.tensor(-1.0))

        # here we fix the radius and don't revert them now
        r_change = target_radius / \
            dist0(gmath.mobius_scalar_mul(r=torch.tensor(1),
                  x=feature_geo_current, k=torch.tensor(-1.0)))
        feature_geo.append(feature_geo_current)
        feature_geo_current_target_radius = gmath.mobius_scalar_mul(
            r=r_change, x=feature_geo_current, k=torch.tensor(-1.0))
        feature_geo_normalized.append(feature_geo_current_target_radius)
        dist = gmath.dist(
            geodesic_start_short, feature_geo_current_target_radius, k=torch.tensor(-1.0))
        dist_to_start.append(dist)

        # here is to revert the feature to boundary
        r_change_to_boundary = 6.2126 / \
            dist0(gmath.mobius_scalar_mul(r=torch.tensor(1),
                  x=feature_geo_current, k=torch.tensor(-1.0)))
        feature_geo_current_target_boundary = gmath.mobius_scalar_mul(
            r=r_change_to_boundary, x=feature_geo_current, k=torch.tensor(-1.0))
        feature_geo_current_target_boundaries.append(feature_geo_current_target_boundary)

    return feature_geo, feature_geo_normalized, feature_geo_current_target_boundaries, dist_to_start

# function for generating images with fixed radius with optional latent codes list output


def geo_interpolate_fix_r_with_codes(x, y, interval, target_radius):
    # please use this with batch_size = 1
    feature_geo = []
    feature_geo_normalized = []
    dist_to_start = []
    target_radius_ratio = torch.tensor(target_radius/6.2126)
    geodesic_start_short = gmath.mobius_scalar_mul(
        r=target_radius_ratio, x=x, k=torch.tensor(-1.0))
    geodesic_end_short = gmath.mobius_scalar_mul(
        r=target_radius_ratio, x=y, k=torch.tensor(-1.0))
    for i in interval:
        # this is raw image on geodesic, instead of fixed radius
        feature_geo_current = gmath.geodesic(t=torch.tensor(
            i), x=geodesic_start_short, y=geodesic_end_short, k=torch.tensor(-1.0))

        # here we fix the radius and don't revert them now
        r_change = target_radius / \
            dist0(gmath.mobius_scalar_mul(r=torch.tensor(1),
                  x=feature_geo_current, k=torch.tensor(-1.0)))
        feature_geo.append(feature_geo_current)
        feature_geo_current_target_radius = gmath.mobius_scalar_mul(
            r=r_change, x=feature_geo_current, k=torch.tensor(-1.0))
        feature_geo_normalized.append(feature_geo_current_target_radius)
        dist = gmath.dist(
            geodesic_start_short, feature_geo_current_target_radius, k=torch.tensor(-1.0))
        dist_to_start.append(dist)
        # print(feature_geo_current_target_radius.norm())

        # here is to revert the feature to boundary
        r_change_to_boundary = 6.2126 / \
            dist0(gmath.mobius_scalar_mul(r=torch.tensor(1),
                  x=feature_geo_current, k=torch.tensor(-1.0)))
        feature_geo_current_target_boundary = gmath.mobius_scalar_mul(
            r=r_change_to_boundary, x=feature_geo_current, k=torch.tensor(-1.0))
        # print(feature_geo_current_target_boundary.norm())

    return dist_to_start, feature_geo_current, feature_geo_current_target_radius, feature_geo_current_target_boundary


def geo_perturbation(x, distances, perturb_codes, target_radius=6.2126, num_samples=10, save_codes=False):
    """
    该函数在双曲空间围绕给定的点 x 生成随机扰动，并确保扰动点的半径与 x 相同。
    函数会为每个给定的双曲距离生成多个样本。

    参数:
        x (tensor): Poincaré圆盘中的起始点。
        distances (list): 目标的双曲距离列表，例如 [0.01, 0.1, 0.2]。
        target_radius (float): 要将扰动点缩放到的目标半径。
        num_samples (int): 每个距离生成的样本数量。
        save_codes (bool, 可选): 是否保存其他数据，默认为 False。
    
    返回:
        tuple: 包含以下列表的元组：
            - feature_geo: 未归一化的扰动点。
            - feature_geo_normalized: 归一化到目标半径的扰动点。
            - feature_geo_current_target_boundaries: 扰动点缩放到边界。
            - dist_to_start: 扰动点与起始点 x 的双曲距离。
            - perturbation_distances: 每个扰动点与起始点的双曲距离。
    """
    
    feature_geo = []
    feature_geo_normalized = []
    dist_to_start = []
    feature_geo_current_target_boundaries = []
    perturbation_distances = []  # 保存每个 feature_geo_current 和 geodesic_start_short 之间的距离
    
    # 1. 计算目标半径的比例，并缩放输入向量 x
    target_radius_ratio = torch.tensor(target_radius / 6.2126)
    
    # 缩放 x 到目标半径，得到 geodesic_start_short
    geodesic_start_short = gmath.mobius_scalar_mul(r=target_radius_ratio, x=x, k=torch.tensor(-1.0))
    print('start_radius', dist0(gmath.mobius_scalar_mul(r=torch.tensor(1), x=geodesic_start_short, k=torch.tensor(-1.0))))
    # 外层循环遍历每个距离
    for distance in distances:
        # 内层循环生成每个距离下的多个样本
        for perturb_code in perturb_codes:
            # 2. 生成随机方向
            # random_direction = torch.randn_like(geodesic_start_short)/100
            random_direction = perturb_code
            
            #print(dist0(gmath.mobius_scalar_mul(r=torch.tensor(1), x=random_direction, k=torch.tensor(-1.0))))
            # 3. 使用 dist0 计算随机方向的双曲距离，并调整使其与 geodesic_start_short 一样长
            r_change_direction_to = target_radius / dist0(gmath.mobius_scalar_mul(r=torch.tensor(1), x=random_direction, k=torch.tensor(-1.0)))
            geodesic_end_short = gmath.mobius_scalar_mul(r=r_change_direction_to, x=random_direction, k=torch.tensor(-1.0))
            # print(dist0(gmath.mobius_scalar_mul(r=torch.tensor(1), x=geodesic_end_short, k=torch.tensor(-1.0))))
            # 4. 计算 geodesic_start_short 和 geodesic_end_short 之间的双曲距离
            distance_x_y = gmath.dist(geodesic_start_short, geodesic_end_short, k=torch.tensor(-1.0))
            # print('current_distance',distance_x_y)
            # 5. 计算插值比例 t，确保 geodesic_start_short 和 feature_geo_current 之间的双曲距离等于指定的 distance
            if distance_x_y > 0:
                t = distance / distance_x_y  # 插值比例，确保按照双曲距离采样
            else:
                t = torch.tensor(1.0)  # 当距离极小时，设为1
            
            # 6. 沿着 geodesic_start_short 和 geodesic_end_short 插值生成 feature_geo_current
            feature_geo_current = gmath.geodesic(t=t, x=geodesic_start_short, y=geodesic_end_short, k=torch.tensor(-1.0))
            
            # 7. 保存未归一化的 feature_geo_current
            feature_geo.append(feature_geo_current)

            # print('current_radius', dist0(gmath.mobius_scalar_mul(r=torch.tensor(1), x=feature_geo_current, k=torch.tensor(-1.0))))
            # 8. 修正到目标半径，确保 feature_geo_current 的半径与 target_radius 一致
            r_change = target_radius / dist0(gmath.mobius_scalar_mul(r=torch.tensor(1), x=feature_geo_current, k=torch.tensor(-1.0)))
            # print('change ratio for interpolation sample',r_change)
            feature_geo_current_target_radius = gmath.mobius_scalar_mul(r=r_change, x=feature_geo_current, k=torch.tensor(-1.0))
            feature_geo_normalized.append(feature_geo_current_target_radius)
            
            # 9. 计算扰动点到起始点的双曲距离，并保存
            dist = gmath.dist(geodesic_start_short, feature_geo_current_target_radius, k=torch.tensor(-1.0))
            dist_to_start.append(dist)
            
            # 10. 保存每个 feature_geo_current 和 geodesic_start_short 的双曲距离
            perturbation_distances.append(gmath.dist(geodesic_start_short, feature_geo_current, k=torch.tensor(-1.0)))
            
            # 11. 将扰动点调整到边界
            r_change_to_boundary = 6.2126 / dist0(gmath.mobius_scalar_mul(r=torch.tensor(1), x=feature_geo_current, k=torch.tensor(-1.0)))
            feature_geo_current_target_boundary = gmath.mobius_scalar_mul(r=r_change_to_boundary, x=feature_geo_current, k=torch.tensor(-1.0))
            feature_geo_current_target_boundaries.append(feature_geo_current_target_boundary)
    
    return feature_geo, feature_geo_normalized, feature_geo_current_target_boundaries, dist_to_start, perturbation_distances

### Define animalfaces variables

In [3]:
# define variables
init_image_path = './inputs/same_domain_test/animalfaces/test1.jpg'
ref_image_path = './inputs/same_domain_test/animalfaces/test1.jpg'
ref_image_2_path = './inputs/same_domain_test/animalfaces/test2.jpg'
# sampling image
random.seed(50)
files = glob.glob("/data2/mhf/DXL/Lingxiao/datasets/animals/*/*.jpg")
sampled_imgs = random.sample(files, 5)

outdir = './outputs/animalfaces_yi'
skip_grid = False
skip_save = True
ddim_steps = 50
ddim_eta = 0.0
n_iter = 1
C = 4
f = 8
n_samples = 4
n_rows = 0
scale = 7.5
strength = 1.0
config_path = './configs/stable-diffusion/v2_inference_animalfaces.yaml'
ckpt = '/data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/animal_faces/2024-09-11T11-30-12_v1/checkpoints/epoch=000035.ckpt'
seed = 3408
precision = 'autocast'

### Define vggfaces variables

In [3]:
# define variables
init_image_path = './inputs/same_domain_test/vggfaces/40.jpg'
ref_image_path = './inputs/same_domain_test/vggfaces/40.jpg'
ref_image_2_path = './inputs/same_domain_test/vggfaces/18.jpg'
outdir = './outputs/vggfaces'
skip_grid = False
skip_save = True
ddim_steps = 50
ddim_eta = 0.0
n_iter = 1
C = 4
f = 8
n_samples = 4
n_rows = 0
scale = 7.5
strength = 1.0
config_path = './configs/stable-diffusion/v2_inference_vggfaces.yaml'
ckpt = '/data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/vgg_faces/2024-10-02T02-12-51_v1/checkpoints/epoch=000037-ffhq.ckpt'
seed = 3408
precision = 'autocast'

### Define flowers variables

In [3]:
# define variables
init_image_path = './inputs/same_domain_test/flowers/0.jpg'
ref_image_path = './inputs/same_domain_test/flowers/0.jpg'
ref_image_2_path = './inputs/same_domain_test/flowers/1.jpg'
outdir = './outputs/flowers'
skip_grid = False
skip_save = True
ddim_steps = 50
ddim_eta = 0.0
n_iter = 1
C = 4
f = 8
n_samples = 4
n_rows = 0
scale = 7.5
strength = 1.0
config_path = './configs/stable-diffusion/v2_inference_flowers.yaml'
ckpt = '/data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/flowers/2024-10-12T05-45-07_v1/checkpoints/epoch=000298.ckpt'
seed = 3408
precision = 'autocast'

### Load the model

In [4]:
seed_everything(seed)

config = OmegaConf.load(f"{config_path}")
model = load_model_from_config(config, f"{ckpt}")

device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
print("Successfully loaded model!")

sampler = DDIMSampler(model)


Global seed set to 3408


Loading model from /data2/mhf/DXL/Lingxiao/Codes/Paint-by-Example-test/models/Paint-by-Example/animal_faces/2024-09-11T11-30-12_v1/checkpoints/epoch=000035.ckpt
Global Step: 77724
No module 'xformers'. Proceeding without it.
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Some weights of the model checkpoint at /data2/mhf/DXL/Lingxiao/Cache/huggingface/hub/models--openai--clip-vit-large-patch14 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.k_proj.

Use hyperbolic: True
Loading HAE from checkpoint: /data2/mhf/DXL/Lingxiao/Codes/hyperediting/exp_out/hyper_styleGANinversion_animalfaces_512_5_30_init_v2/checkpoints/iteration_11000.pt
Successfully loaded model!


### Load Delta Mapper for text-guided editing

In [5]:
from argparse import Namespace
sys.path.insert(0, '/data2/mhf/DXL/Lingxiao/Codes')
from DeltaHyperEditing.models.delta_hyp_clip import hae_clip

# model_path = '/data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_animalfaces_512_5_30_v3/checkpoints/iteration_8000.pt'
model_path = '/data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_ffhq_512_5_30_v1/checkpoints/iteration_36000.pt'
# model_path = '/data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_flowers_512_5_30_v1/checkpoints/iteration_26000.pt'
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
del ckpt
opts['checkpoint_path'] = model_path
# instantialize model with checkpoints and args
opts = Namespace(**opts)
net = hae_clip(opts)
net.eval()
net.to(device)
print('Model successfully loaded!')

Some weights of the model checkpoint at /data2/mhf/DXL/Lingxiao/Cache/huggingface/hub/models--openai--clip-vit-large-patch14 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.final_layer_norm.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.10.self_attn

Use hyperbolic: True
Loading HAE from checkpoint: /data2/mhf/DXL/Lingxiao/Codes/DeltaHyperEditing/exp_out/hyper_ffhq_512_5_30_v1/checkpoints/iteration_36000.pt
Model successfully loaded!


In [5]:
os.makedirs(outdir, exist_ok=True)
outpath = outdir

batch_size = n_samples
n_rows = n_rows if n_rows > 0 else batch_size

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1

### Load images

In [6]:
# load images
# load init image
assert os.path.isfile(init_image_path)
init_image = load_img(init_image_path, [512, 512]).to(device)
init_image_resized = load_img(init_image_path, [224, 224]).to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_image_resized = repeat(init_image_resized, '1 ... -> b ...', b=batch_size)
init_latent = model.get_first_stage_encoding(
    model.encode_first_stage(init_image))  # move to latent space

# load ref image
assert os.path.isfile(ref_image_path)
ref_image = load_img(ref_image_path, [224, 224]).to(device)
ref_image = repeat(ref_image, '1 ... -> b ...', b=batch_size)

assert os.path.isfile(ref_image_2_path)
ref_image_2 = load_img(ref_image_2_path, [224, 224]).to(device)
ref_image_2 = repeat(ref_image_2, '1 ... -> b ...', b=batch_size)

# load sampled images
sampled_images = []
for sampled_img in sampled_imgs:
    assert os.path.isfile(sampled_img)
    sampled_image = load_img(sampled_img, [224, 224]).to(device)
    sampled_images.append(sampled_image)

sampler.make_schedule(ddim_num_steps=ddim_steps,
                        ddim_eta=ddim_eta, verbose=False)

assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(strength * ddim_steps)
print(f"target t_enc is {t_enc} steps")

loaded input image of size (175, 159) from ./inputs/same_domain_test/animalfaces/test1.jpg
loaded input image of size (175, 159) from ./inputs/same_domain_test/animalfaces/test1.jpg
loaded input image of size (175, 159) from ./inputs/same_domain_test/animalfaces/test1.jpg
loaded input image of size (198, 226) from ./inputs/same_domain_test/animalfaces/test2.jpg
loaded input image of size (134, 147) from /data2/mhf/DXL/Lingxiao/datasets/animals/n02088364/n02088364_13877.JPEG_193_107_340_241.jpg
loaded input image of size (126, 132) from /data2/mhf/DXL/Lingxiao/datasets/animals/n02093428/n02093428_8947.JPEG_165_10_297_136.jpg
loaded input image of size (154, 167) from /data2/mhf/DXL/Lingxiao/datasets/animals/n02109047/n02109047_6862.JPEG_17_59_184_213.jpg
loaded input image of size (187, 193) from /data2/mhf/DXL/Lingxiao/datasets/animals/n02090379/n02090379_2046.JPEG_235_78_428_265.jpg
loaded input image of size (127, 148) from /data2/mhf/DXL/Lingxiao/datasets/animals/n02099712/n02099712

## Edit latent codes using text instructions

In [9]:
prompt = 'a face'
prompt_delta = 'a face with red hair'

In [10]:
fake_delta_s = net.get_fake_delta_s_given_data(ref_image, prompt, prompt_delta, feature_type='hyperbolic', device=device)
print(f"fake_delta_s shape: {fake_delta_s.shape}")
logits, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
_, feature_2, feature_dist_2, feature_euc_2 = model.get_learned_conditioning(ref_image_2)
print(f"feature shape: {feature.shape}")
reconstruct_code = feature + fake_delta_s
# reconstruct_code = feature_dist + fake_delta_s
print(f"reconstruct_code shape: {reconstruct_code.shape}")
reconstruct_feature = get_condition_given_feature(model, reconstruct_code)
reconstruct_codes = [feature_euc, reconstruct_feature]
# reconstruct_codes = [feature_euc, reconstruct_code]

fake_delta_s shape: torch.Size([4, 1, 512])


  k = torch.tensor(k)


feature shape: torch.Size([4, 1, 512])
reconstruct_code shape: torch.Size([4, 1, 512])


In [11]:
_, hyp_codes = get_hyp_codes_given_feature(model, reconstruct_code)
print(hyp_codes.shape)
hyp_code = hyp_codes[0].unsqueeze(0)
print(gmath.dist0(hyp_code, k=torch.tensor(-1.0)))
rescaled_codes = []
target_radii = [6.2126, 4, 2.5, 1, 0.5, 0]
for i in target_radii:
    hyp_code_rescaled = rescale(i, hyp_code)
    hyp_code_rescaled = repeat(hyp_code_rescaled, '1 ... -> b ...', b=batch_size)
    feature_euc = get_condition_given_hyp_codes(model, hyp_code_rescaled)
    rescaled_codes.append(feature_euc)

torch.Size([4, 512])
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)


## Manipulate latent codes in Hyperbolic space

### Moving latent codes from edge to center

In [13]:
_, hyp_code = get_hyp_codes(model, ref_image[0].unsqueeze(0))
print(hyp_code.shape)
print(gmath.dist0(hyp_code, k=torch.tensor(-1.0)))
# this is used for generating figure of varying radius in our paper
rescaled_codes = []
target_radii = [6.2126, 4, 2.5, 1, 0.5, 0]
for i in target_radii:
    hyp_code_rescaled = rescale(i, hyp_code)
    hyp_code_rescaled = repeat(hyp_code_rescaled, '1 ... -> b ...', b=batch_size)
    feature_euc = get_condition_given_hyp_codes(model, hyp_code_rescaled)
    rescaled_codes.append(feature_euc)
    

torch.Size([1, 512])
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)


### Interpolate latent codes in Hyperbolic space along the geodesic

In [49]:
_, hyp_code = get_hyp_codes(model, ref_image[0].unsqueeze(0))
_, hyp_code_2 = get_hyp_codes(model, ref_image_2[0].unsqueeze(0))
print(hyp_code.shape)
print(gmath.dist0(hyp_code, k=torch.tensor(-1.0)))
print(gmath.dist0(hyp_code_2, k=torch.tensor(-1.0)))
# this is used for generating figure of varying radius in our paper
interpolated_codes = []
interval = [0, 0.3, 0.4, 0.5, 0.6, 0.7, 1]
feature_geo, feature_geo_current_target_radius, feature_geo_current_target_boundary, dist_to_start = geo_interpolate_fix_r(
    hyp_code, hyp_code_2, interval, target_radius=6.2126)
for i in feature_geo_current_target_boundary:
    feature = repeat(i, '1 ... -> b ...', b=batch_size)
    feature_euc = get_condition_given_hyp_codes(model, feature)
    interpolated_codes.append(feature_euc)

torch.Size([1, 512])
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)


### Perturbation given an image embedding with certain radius

In [7]:
interpolated_codes = []
_, hyp_code = get_hyp_codes(model, ref_image[0].unsqueeze(0))
_, perturb_codes = get_hyp_codes(model, sampled_images)
print(hyp_code.shape)
print(gmath.dist0(hyp_code, k=torch.tensor(-1.0)))
distances = [5.7]
num_samples = len(sampled_images)
feature_geo, feature_geo_normalized, feature_geo_current_target_boundaries, dist_to_start, perturbation_distances = geo_perturbation(
    hyp_code, distances, target_radius=6.2126, perturb_codes=perturb_codes, num_samples=num_samples)
for i in feature_geo_current_target_boundaries:
    feature = repeat(i, '1 ... -> b ...', b=batch_size)
    feature_euc = get_condition_given_hyp_codes(model, feature)
    interpolated_codes.append(feature_euc)

  k = torch.tensor(k)


torch.Size([1, 512])
tensor([6.2126], device='cuda:0', grad_fn=<MulBackward1>)
start_radius tensor([6.2126], device='cuda:0', grad_fn=<DivBackward0>)
tensor(6.2126, device='cuda:0', grad_fn=<DivBackward0>)
current_distance tensor([11.4315], device='cuda:0', grad_fn=<MulBackward1>)
current_radius tensor([1.0788], device='cuda:0', grad_fn=<DivBackward0>)
change ratio for interpolation sample tensor([5.7588], device='cuda:0', grad_fn=<MulBackward0>)
tensor(6.2126, device='cuda:0', grad_fn=<DivBackward0>)
current_distance tensor([11.5921], device='cuda:0', grad_fn=<MulBackward1>)
current_radius tensor([0.9823], device='cuda:0', grad_fn=<DivBackward0>)
change ratio for interpolation sample tensor([6.3247], device='cuda:0', grad_fn=<MulBackward0>)
tensor(6.2126, device='cuda:0', grad_fn=<DivBackward0>)
current_distance tensor([11.4044], device='cuda:0', grad_fn=<MulBackward1>)
current_radius tensor([1.1005], device='cuda:0', grad_fn=<DivBackward0>)
change ratio for interpolation sample tenso

In [12]:
perturbation_distances

[tensor([5.0000], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0000], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0003], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0005], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0010], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0005], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0003], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0001], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0003], device='cuda:0', grad_fn=<MulBackward1>),
 tensor([5.0000], device='cuda:0', grad_fn=<MulBackward1>)]

### Interpolate latent codes in Euclidean space

In [79]:
logits, feature, feature_dist, feature_euc = model.get_learned_conditioning(ref_image)
_, feature_2, feature_dist_2, feature_euc_2 = model.get_learned_conditioning(ref_image_2)
print('feature_euc shape:', feature_euc.shape)
interpolated_codes = []
interval = [0, 0.3, 0.4, 0.5, 0.6, 0.7, 1]
for i in interval:
    feature = (1-i) * feature_euc + i * feature_euc_2
    interpolated_codes.append(feature)

feature_euc shape: torch.Size([4, 1, 1024])


## Sample Images

In [8]:
precision_scope = autocast if precision == "autocast" else nullcontext
n_rows = n_rows if n_rows > 0 else batch_size
with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            tic = time.time()
            all_samples = list()
            shape = [C, 64, 64]
            # encode (scaled latent)
            if strength < 1.0:
                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
            # print(f"z_enc shape: {z_enc.shape}")
            else:
                z_enc = torch.randn([n_samples, 4, 64, 64], device=device)
            # decode it
            # for c in rescaled_codes:
            for c in interpolated_codes:
            # for c in reconstruct_codes:
                samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                    conditioning=c,
                                                    batch_size=n_samples,
                                                    shape=shape,
                                                    verbose=False,
                                                    unconditional_guidance_scale=scale,
                                                    unconditional_conditioning=c,
                                                    eta=ddim_eta,
                                                    x_T=z_enc)
                '''
                samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
                                            unconditional_conditioning=uc,)
                                            '''

                x_samples = model.decode_first_stage(samples_ddim)
                x_samples = torch.clamp(
                    (x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                if not skip_save:
                    for x_sample in x_samples:
                        x_sample = 255. * \
                            rearrange(x_sample.cpu().numpy(),
                                        'c h w -> h w c')
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png"))
                        base_count += 1
                all_samples.append(x_samples)

                if not skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

            # to image
            grid = 255. * \
                rearrange(grid, 'c h w -> h w c').cpu().numpy()
            Image.fromarray(grid.astype(np.uint8)).save(
                os.path.join(outpath, f'grid-{grid_count:04}.png'))
            grid_count += 1
            del grid

        toc = time.time()

print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
        f" \nEnjoy.")

Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  7.08it/s]


Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  7.06it/s]


Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.93it/s]


Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.89it/s]


Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:07<00:00,  7.00it/s]


Your samples are ready and waiting for you here: 
./outputs/animalfaces_yi 
 
Enjoy.


## Mean of the variation

In [27]:
precision_scope = autocast if precision == "autocast" else nullcontext
with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            tic = time.time()
            all_samples = list()
            shape = [C, 64, 64]
            # encode (scaled latent)
            if strength < 1.0:
                z_enc = sampler.stochastic_encode(
                    init_latent, torch.tensor([t_enc]*batch_size).to(device))
            # print(f"z_enc shape: {z_enc.shape}")
            else:
                z_enc = torch.randn([n_samples, 4, 64, 64], device=device)
            # decode it
            for c in rescaled_codes:
                samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                 conditioning=c,
                                                 batch_size=n_samples,
                                                 shape=shape,
                                                 verbose=False,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=c,
                                                 eta=ddim_eta,
                                                 x_T=z_enc)
                '''
                samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
                                            unconditional_conditioning=uc,)
                                            '''

                x_samples = model.decode_first_stage(samples_ddim)
                x_samples = torch.clamp(
                    (x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                print(x_samples.shape)
                x_samples_mean = x_samples.mean(0).unsqueeze(0)
                if not skip_save:
                    for x_sample in x_samples:
                        x_sample = 255. * \
                            rearrange(x_sample.cpu().numpy(),
                                      'c h w -> h w c')
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png"))
                        base_count += 1
                all_samples.append(x_samples_mean)

                if not skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

            # to image
            grid = 255. * \
                rearrange(grid, 'c h w -> h w c').cpu().numpy()
            Image.fromarray(grid.astype(np.uint8)).save(
                os.path.join(outpath, f'grid-{grid_count:04}.png'))
            grid_count += 1
            del grid

        toc = time.time()

print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
      f" \nEnjoy.")

Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:32<00:00,  1.53it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.51it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.51it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.50it/s]


torch.Size([20, 3, 512, 512])
Data shape for DDIM sampling is (20, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:33<00:00,  1.50it/s]


torch.Size([20, 3, 512, 512])
Your samples are ready and waiting for you here: 
./outputs 
 
Enjoy.
