In [1]:
import torch
from diffusers import StableDiffusionPipeline
from typing import Callable, List, Optional, Union
import inspect
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

# Configure CUDA device ordering and visibility through environment variables.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

In [3]:
# Load the "fp16" revision of Stable Diffusion v1-4 with float16 weights and move it to CUDA.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

Fetching 16 files: 100%|█████████████████████| 16/16 [00:00<00:00, 55553.70it/s]


In [11]:
# `-p` creates missing parents (including imgs/) and does not fail when the
# folders already exist, so this cell can be re-run safely.
!mkdir -p imgs/homonym_duplication imgs/meaning_edit imgs/meaning_sum

## Getting Images

An edited version of `StableDiffusionPipeline.__call__()` that accepts precomputed text embeddings directly as input instead of a prompt string.

In [4]:
def get_images(text_embeddings, pipe, img_name, prompt=None, negative_prompt=None, num_images_per_prompt=3):
    """Denoise latents conditioned on precomputed text embeddings and save the images.

    This is an edited copy of the pipeline's ``__call__`` in which the text-encoder
    output is supplied by the caller instead of being computed from a prompt.

    Args:
        text_embeddings: Tensor unpacked as ``(batch, seq_len, dim)``. The latent
            batch is built for a single prompt (``batch_size`` is fixed to 1).
        pipe: Pipeline object providing ``tokenizer``, ``text_encoder``, ``unet``,
            ``vae``, ``scheduler``, ``device``, ``progress_bar`` and ``numpy_to_pil``.
        img_name: Path stem below ``imgs/``; image ``i`` is written to
            ``imgs/<img_name>_<i>.png``.
        prompt: Optional original prompt. It is only used to validate
            ``negative_prompt`` and in error messages.
        negative_prompt: Optional string (or list of strings) encoded as the
            unconditional branch of classifier-free guidance; ``None`` uses ``""``.
        num_images_per_prompt: Number of images generated per call.

    Returns:
        The list of generated images as returned by ``pipe.numpy_to_pil``.

    Raises:
        TypeError: If ``prompt`` is given and its type differs from ``negative_prompt``.
        ValueError: If a list ``negative_prompt`` does not have ``batch_size`` entries.
    """
    import os  # local import: keeps the function independent of other cells' imports

    # Fixed sampling settings.
    height = 512
    width = 512
    num_inference_steps = 50
    guidance_scale = 7.5
    eta = 0.0
    generator = None
    latents = None
    output_type = "pil"
    callback = None
    callback_steps = 1
    batch_size = 1

    with torch.no_grad():
        # Duplicate the conditioning once per requested image.
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""]
            # The embeddings are passed in directly, so `prompt` is optional: only
            # compare the types when a prompt was actually supplied.
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Encode the unconditional prompt at the same sequence length as the conditioning.
            max_length = text_embeddings.shape[1]
            uncond_input = pipe.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # One batch holding [unconditional, conditional] so the UNet runs once per step.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents_shape = (batch_size * num_images_per_prompt, pipe.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if pipe.device.type == "mps":
                # On "mps" the noise is sampled on the CPU and then moved to the device.
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    pipe.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=pipe.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(pipe.device)

        pipe.scheduler.set_timesteps(num_inference_steps)
        timesteps_tensor = pipe.scheduler.timesteps.to(pipe.device)

        latents = latents * pipe.scheduler.init_noise_sigma

        # Only pass `eta` to schedulers whose `step` accepts it.
        accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(pipe.progress_bar(timesteps_tensor)):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # Rescale the latents before decoding them with the VAE.
        latents = 1 / 0.18215 * latents
        image = pipe.vae.decode(latents).sample

        # Map the decoder output from [-1, 1] to [0, 1] and convert to channel-last numpy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = pipe.numpy_to_pil(image)

        # `img_name` may contain sub-folders; make sure the target folder exists.
        os.makedirs(os.path.dirname("imgs/" + img_name), exist_ok=True)
        for i in range(len(image)):
            image[i].save("imgs/" + img_name + "_" + str(i) + ".png")

    # The original built the result but never returned it.
    return image

## Editing Embedding

In [5]:
def w_b(w, b):
    """Return the component of ``w`` that lies in the span of the vectors in ``b``.

    Computes ``sum_j <w, b[j]> * b[j]``, which is the orthogonal projection of
    ``w`` onto ``span(b)`` when the ``b[j]`` are orthonormal.

    Args:
        w: 1-D tensor.
        b: Sequence of 1-D tensors with the same shape, dtype and device as ``w``.

    Returns:
        A tensor with the shape, dtype and device of ``w`` (all zeros when ``b``
        is empty).
    """
    # Take shape, dtype and device from `w` instead of assuming a 768-dim
    # half-precision CUDA tensor.
    v_b = torch.zeros_like(w)
    for b_j in b:
        v_b += torch.dot(w, b_j) * b_j
    return v_b

def normal(v):
    """Scale the 1-D tensor ``v`` by the reciprocal of its Euclidean length.

    There is no guard for a zero-length ``v``.
    """
    inv_length = 1 / torch.dot(v, v).sqrt()
    return inv_length * v

def norm(v):
    """Return the Euclidean length of the 1-D tensor ``v`` as a 0-dim tensor."""
    squared_length = torch.dot(v, v)
    return squared_length.sqrt()

def project(a, b):
    """Return the projection of ``a`` onto the direction of ``b``.

    The result is ``(<a, b> / <b, b>) * b``; when ``<b, b>`` is zero the
    coefficient is taken as 0, so the result is ``0 * b``.
    """
    denom = torch.dot(b, b)
    if denom == 0:
        return 0 * b
    return (torch.dot(a, b) / denom) * b

def edit_embed(orig_embed, meaning_1, meaning_2):
    """Push an ambiguous embedding towards ``meaning_2``.

    Builds a two-vector basis from ``meaning_1`` and the part of ``meaning_2``
    orthogonal to it, subtracts the component of ``orig_embed`` along that basis,
    and adds a vector of length ``norm(meaning_2)`` along the part of
    ``meaning_2`` that is orthogonal to ``meaning_1``.

    Args:
        orig_embed: 1-D embedding to edit (not modified in place).
        meaning_1: 1-D direction whose contribution is removed.
        meaning_2: 1-D direction the embedding is pushed towards.

    Returns:
        The edited 1-D embedding.
    """
    unit_1 = normal(meaning_1)
    unit_2 = normal(meaning_2 - project(meaning_2, unit_1))
    in_plane = w_b(orig_embed, [unit_1, unit_2])
    # pushing ambiguous towards meaning_2
    push = norm(meaning_2) * normal(meaning_2 - project(meaning_2, meaning_1))
    return orig_embed - in_plane + push

## Getting Encodings

In [75]:
def one_prompt_embed(prompt_1, pipe):
    """Tokenize ``prompt_1`` and return the text encoder's first output for it.

    Args:
        prompt_1: Prompt text passed to ``pipe.tokenizer``.
        pipe: Pipeline object providing ``tokenizer``, ``text_encoder`` and ``device``.

    Returns:
        The first element of the ``pipe.text_encoder`` output, computed without
        gradient tracking.
    """
    text_inputs = pipe.tokenizer(
        prompt_1,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        # Cut prompts longer than `model_max_length` instead of handing the
        # encoder more input ids than the padded length.
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # Inference only: avoid keeping the encoder's autograd graph alive through
    # every embedding that callers store.
    with torch.no_grad():
        text_embeddings_1 = pipe.text_encoder(text_input_ids.to(pipe.device))[0]

    return text_embeddings_1

## Find Meaning Directions

In [7]:
def diff_svd(vectors_m, vectors_n, n, model_dim=768):
    """SVD of the accumulated scatter of paired vectors around their pairwise means.

    For each of the first ``n`` pairs, the mean of the two vectors is subtracted
    from both and half of each deviation's outer product is added to a
    ``(model_dim, model_dim)`` matrix, which is then decomposed with
    ``np.linalg.svd``.

    Args:
        vectors_m: Sequence of at least ``n`` 1-D tensors of length ``model_dim``.
        vectors_n: Sequence paired element-wise with ``vectors_m``.
        n: Number of pairs to use.
        model_dim: Length of each vector.

    Returns:
        ``(u, s)`` where ``u`` is the left singular-vector matrix as a float16
        tensor on the inputs' device and ``s`` is the NumPy array of singular values.
    """
    # Work on the device of the inputs; with no inputs fall back to CUDA as before.
    device = vectors_m[0].device if n > 0 else torch.device("cuda")

    subspace = torch.zeros((model_dim, model_dim), device=device)
    for i in range(n):
        mu = (1/2)*(vectors_m[i] + vectors_n[i])
        dev_m = vectors_m[i] - mu
        dev_n = vectors_n[i] - mu
        subspace += (1/2)*torch.outer(dev_m, dev_m)
        subspace += (1/2)*torch.outer(dev_n, dev_n)

    u_m, s_m, _ = np.linalg.svd(subspace.detach().cpu().numpy(), full_matrices=True)
    return torch.tensor(u_m).to(device=device, dtype=torch.float16), s_m

def _word_vector(sentence, w, pipe):
    """Embed ``sentence`` and return the encoder output at the position of word ``w``."""
    full_vec = one_prompt_embed(sentence, pipe)
    # Position = index of `w` among the space-separated words, shifted by one.
    # NOTE(review): presumably the shift skips a start token and every word maps
    # to exactly one token - confirm for multi-token words.
    w_idx = sentence.split(" ").index(w) + 1
    return full_vec[:, w_idx, :].squeeze(0)


def find_vectors(w, sentences_1, sentences_2, sentences_amb, pipe, min_dim=10, threshold=0.99, model_dim=768):
    """Estimate one direction per sense of the word ``w``.

    Steps:
      1. Embed ``w`` in each sense-1, sense-2 and ambiguous sentence.
      2. Run ``diff_svd`` on (sense, ambiguous) pairs to get a basis per sense.
      3. Keep the smallest number of leading basis vectors (at least ``min_dim``,
         at most the number of singular values) whose singular values reach
         ``threshold`` of the total for both senses.
      4. Average the normalised in-subspace components of the sense vectors,
         remove the component along each vector of the other sense, and normalise.
      5. Scale each direction by the largest dot product between it and that
         sense's vectors.

    Args:
        w: Target word; must appear as a space-separated word in every sentence.
        sentences_1: Sentences using ``w`` in sense 1.
        sentences_2: Sentences using ``w`` in sense 2 (paired by index).
        sentences_amb: Ambiguous sentences (paired by index).
        pipe: Pipeline object forwarded to ``one_prompt_embed``.
        min_dim: Minimum number of basis vectors to keep.
        threshold: Fraction of the singular-value sum the kept vectors must reach.
        model_dim: Embedding size.

    Returns:
        ``(v_1, v_2)``: the scaled direction tensors for sense 1 and sense 2.

    Raises:
        ValueError: If ``sentences_1`` is empty, or ``w`` is missing from a sentence.
    """
    n = len(sentences_1)
    if n == 0:
        raise ValueError("find_vectors needs at least one sentence per sense")

    vectors_1 = []
    vectors_2 = []
    vectors_amb = []
    for i in range(n):
        vectors_1.append(_word_vector(sentences_1[i], w, pipe))
        vectors_2.append(_word_vector(sentences_2[i], w, pipe))
        vectors_amb.append(_word_vector(sentences_amb[i], w, pipe))

    u_1, s_1 = diff_svd(vectors_1, vectors_amb, n, model_dim)
    u_2, s_2 = diff_svd(vectors_2, vectors_amb, n, model_dim)

    # Grow `dim` until both spectra reach `threshold` and `min_dim` is met. The
    # cap stops the loop when `threshold` is unreachable or `min_dim` exceeds
    # the number of basis vectors.
    total_1 = sum(s_1)
    total_2 = sum(s_2)
    max_dim = min(len(s_1), len(s_2))
    dim = 0
    while dim < max_dim and (
        sum(s_1[:dim])/total_1 < threshold or sum(s_2[:dim])/total_2 < threshold or dim < min_dim
    ):
        dim += 1
    u_b_1 = [u_1[:, j] for j in range(dim)]
    u_b_2 = [u_2[:, j] for j in range(dim)]

    # Normalised component of every sense vector inside its own subspace.
    diff_1 = [normal(w_b(vectors_1[i], u_b_1)) for i in range(n)]
    diff_2 = [normal(w_b(vectors_2[i], u_b_2)) for i in range(n)]

    # Accumulators share dtype and device with the bases returned by diff_svd.
    v_1 = u_1.new_zeros(model_dim)
    v_2 = u_2.new_zeros(model_dim)
    for i in range(dim):
        v_1 += sum([torch.dot(diff_1[j], u_b_1[i]) for j in range(n)])/n * u_b_1[i]
        v_2 += sum([torch.dot(diff_2[j], u_b_2[i]) for j in range(n)])/n * u_b_2[i]

    # Remove from each direction its component along the other sense's vectors.
    for i in range(n):
        v_1 = v_1 - project(v_1, normal(vectors_2[i]))
        v_2 = v_2 - project(v_2, normal(vectors_1[i]))

    v_1 = normal(v_1)
    v_2 = normal(v_2)

    scale_1 = max([torch.dot(vectors_1[j], v_1) for j in range(n)])
    scale_2 = max([torch.dot(vectors_2[j], v_2) for j in range(n)])
    return scale_1 * v_1, scale_2 * v_2

## Generate All Images for Sense Editing Experiments

In [8]:
def edit_prompts(word, prompt_dict, sentences_1, sentences_2, sentences_amb, pipe, neg_prompt="", repeat=5):
    """Generate images for the sense-1, sense-2 and unedited embedding of each prompt.

    Args:
        word: Ambiguous word; must appear as a space-separated word in every prompt.
        prompt_dict: Maps each prompt to the filename prefix (below ``imgs/``).
        sentences_1, sentences_2, sentences_amb: Sentences forwarded to ``find_vectors``.
        pipe: Pipeline object used for encoding and generation.
        neg_prompt: When non-empty, a second round is generated with this negative
            prompt and ``"_neg"`` appended to the file names.
        repeat: Number of ``get_images`` calls per variant and round.
    """
    v_1, v_2 = find_vectors(word, sentences_1, sentences_2, sentences_amb, pipe, threshold=0.95, min_dim=3)

    def _edited_copy(base_embed, token_pos, meaning_1, meaning_2):
        # Edit a detached copy so the original embedding stays untouched.
        edited = base_embed.detach().clone()
        token_vec = edited[:, token_pos, :].squeeze(0).clone()
        edited[:, token_pos, :] = edit_embed(token_vec, meaning_1, meaning_2).clone()
        return edited

    for orig_prompt, filename in prompt_dict.items():
        orig_embed = one_prompt_embed(orig_prompt, pipe)
        idx = orig_prompt.split(" ").index(word) + 1

        # (file-name tag, embedding) in the order the images are generated.
        variants = (
            ("sense_1_", _edited_copy(orig_embed, idx, v_2, v_1)),
            ("sense_2_", _edited_copy(orig_embed, idx, v_1, v_2)),
            ("amb_", orig_embed),
        )

        for i in range(repeat):
            for tag, embed in variants:
                get_images(embed, pipe, filename + tag + str(i))

        if neg_prompt == "":
            continue
        for i in range(repeat):
            for tag, embed in variants:
                get_images(
                    embed,
                    pipe,
                    filename + tag + str(i) + "_neg",
                    prompt=orig_prompt,
                    negative_prompt=neg_prompt,
                )

In [23]:
# One example sentence per sense of "bass", plus the bare ambiguous prompt.
bass_sentence_music, bass_sentence_fish, bass_sentence_amb = (
    ["the musician played a double bass"],
    ["the fisherman caught a sea bass"],
    ["a bass"],
)

In [24]:
# Generate sense-edited and unedited images for the prompt "a bass", with and
# without the negative prompt.
edit_prompts(
    "bass",
    {"a bass":"meaning_edit/bass_"},
    bass_sentence_music,
    bass_sentence_fish,
    bass_sentence_amb,
    pipe,
    neg_prompt="disfigured, deformed, bad anatomy, low quality, jpeg artifacts",
    repeat=1,
)

100%|███████████████████████████████████████████| 51/51 [00:10<00:00,  4.85it/s]
100%|███████████████████████████████████████████| 51/51 [00:10<00:00,  4.75it/s]
100%|███████████████████████████████████████████| 51/51 [00:10<00:00,  4.85it/s]
100%|███████████████████████████████████████████| 51/51 [00:10<00:00,  4.84it/s]
100%|███████████████████████████████████████████| 51/51 [00:10<00:00,  4.84it/s]
100%|███████████████████████████████████████████| 51/51 [00:10<00:00,  4.83it/s]


In [77]:
# Bind the arguments of the edit_prompts call above to globals for step-by-step inspection.
(word, prompt_dict, sentences_1, sentences_2, sentences_amb, pipe) = (
    "bass",
    {"a bass":"meaning_edit/bass_"},
    bass_sentence_music,
    bass_sentence_fish,
    bass_sentence_amb,
    pipe,
)

In [78]:
# Same find_vectors settings that edit_prompts uses internally.
v_1, v_2 = find_vectors(
    word,
    sentences_1,
    sentences_2,
    sentences_amb,
    pipe,
    threshold=0.95,
    min_dim=3,
)

In [79]:
# Inspect the shape of the first direction returned by find_vectors.
v_1.shape

torch.Size([768])

In [80]:
# Inspect the shape of the second direction returned by find_vectors.
v_2.shape

torch.Size([768])

In [81]:
# Scratch copy of the embedding-collection loop at the top of find_vectors, run at
# module level so the intermediate tensors can be inspected.
# NOTE(review): this cell leaks `i`, `w_idx`, `full_vec_*` and `vec_*` as globals,
# and a later cell reads `i`; keep the statements as they are.
n = len(sentences_1)
vectors_1 = []
vectors_2 = []
vectors_amb = []
for i in range(n):
    # Encoder output at the target word's position in the sense-1 sentence.
    full_vec_1 = one_prompt_embed(sentences_1[i], pipe)
    w_idx = sentences_1[i].split(" ").index(word) + 1
    vec_1 = full_vec_1[:,w_idx,:].squeeze(0)
    vectors_1.append(vec_1)

    # Same for the sense-2 sentence.
    full_vec_2 = one_prompt_embed(sentences_2[i], pipe)
    w_idx = sentences_2[i].split(" ").index(word) + 1
    vec_2 = full_vec_2[:,w_idx,:].squeeze(0)
    vectors_2.append(vec_2)

    # Same for the ambiguous sentence.
    full_vec_amb = one_prompt_embed(sentences_amb[i], pipe)
    w_idx = sentences_amb[i].split(" ").index(word) + 1
    vec_amb = full_vec_amb[:,w_idx,:].squeeze(0)
    vectors_amb.append(vec_amb)

In [82]:
# Number of sense-2 vectors collected by the loop above.
len(vectors_2)

1

In [83]:
# Recompute the sense-2 word vector for a single sentence.
# NOTE(review): `i` is not defined in this cell; it is whatever value an earlier
# cell's loop left behind, so this depends on hidden kernel state.
full_vec_2 = one_prompt_embed(sentences_2[i], pipe)
w_idx = sentences_2[i].split(" ").index(word) + 1
vec_2 = full_vec_2[:,w_idx,:].squeeze(0)