In [1]:
import os
import time

from tqdm import tqdm
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn 
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random

In [2]:
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
device = torch.device("cuda:7")
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!

# Load model.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)

# Ensure sampler uses "trailing" timesteps.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")


  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
num_samples = 20
num_prompts = 2
prompt = ["draw an orange cat in vangogh style" for i in range(num_prompts)]
for i in range(num_samples//num_prompts):
    images = pipe(prompt, num_inference_steps=2, guidance_scale=0).images
    for idx,image in enumerate(images):
        image.save(f"./datasets/generated/cats/{i*num_prompts+idx}.jpg")

In [3]:
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])
image_shape = 1024
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((image_shape,image_shape)),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 常用标准化

])


In [4]:
vgg_net = torchvision.models.vgg16()
vgg_encoder = nn.Sequential(*list(vgg_net.features)).to(device)

In [5]:
style_layers, content_layers = [0, 5, 10, 19, 28], [25]

def build_style_hooker():
    style_features=[]
    def style_hooker(module, fea_in, fea_out):
        style_features.append(fea_out)
    return style_hooker,style_features
def build_content_hooker():
    content_features=[]
    def content_hooker(module,fea_in,fea_out):
        content_features.append(fea_out)
    return content_hooker,content_features

def hook_model(layers,hook,model):
    for i,layer in model.named_children():
        if int(i) in layers:
            layer.register_forward_hook(hook)
# style_hooker,style_features = build_style_hooker()
# hook_model(style_layers,style_hooker,vgg_encoder)
# img = Image.open("./bedroom.jpg")
# img_tensor = preprocess(img).to(device)
# print(img_tensor.shape)
# a=vgg_encoder(img_tensor)

In [13]:
def compute_gram(x):
    #x [batch_size,channels,h,w]
    x = x.reshape(x.shape[0],x.shape[1],-1) #[batch_size,channels,h*w]
    gram = x@x.transpose(1,2) #[batch_size,channels,channels]
    return gram/(x.shape[2])
def style_distance(x1,x2):
    #return [batch_size]
    gram_diff = (compute_gram(x1)-compute_gram(x2))**2
    return torch.mean(gram_diff,dim=(-1,-2))
# print(style_distance(style_features[0],style_features[0]))

In [6]:
def frozen_all(model):
    for parameters in model.parameters():
        parameters.requires_grad = False
def activate_all(model):
    for parameters in model.parameters():
        parameters.requires_grad = True

In [7]:
vae = pipe.vae
unet = pipe.unet
clip1 = pipe.text_encoder
clip2 = pipe.text_encoder_2
model_components = {"vae":vae,"unet":unet,"clip1":clip1,"clip2":clip2,"vgg_encoder":vgg_encoder}

In [8]:
forget_prompts_num=10
forget_prompts = []
neutral_prompts_num=10
neutral_prompts = [
    "An ancient library hidden in a lush, mystical forest, bathed in golden sunlight filtering through the trees",
    "A futuristic cityscape at night, illuminated by neon lights with flying cars zooming between skyscrapers",
    "A serene mountain landscape in the early morning, with low-hanging clouds and the first light of dawn painting the peaks in warm colors",
    "A bustling medieval market square during a festival, with colorful stalls, flags waving, and lively crowds",
    "An underwater city with bioluminescent plants and creatures, ancient ruins, and bubble-like homes",
    "A steampunk laboratory filled with intricate machinery, glowing tubes, gears, and an inventor working on a new creation",
    "A vast desert with a nomadic caravan at sunset, silhouetted against the vibrant colors of the sky, with distant mountains",
    "An enchanted forest in autumn, with a path winding through trees with multicolored leaves, and a hint of magic in the air",
    "A space station orbiting a distant planet, with astronauts looking out at the stars from a large viewing dome",
    "A post-apocalyptic city being reclaimed by nature, with skyscrapers covered in vines and wildlife wandering the streets"
]
for i in range(forget_prompts_num):
    color = random.choice(['red','black','blue','white'])
    prompt = f"draw a {color} cat with vangogh style" #remember to change for robust test
    forget_prompts.append(prompt)

In [9]:
forget_tensors = []
forget_dataset_dir = 'datasets/generated/cats'
for file in os.listdir(forget_dataset_dir):
    image = Image.open(os.path.join(forget_dataset_dir,file))
    forget_tensors.append(preprocess(image))
forget_tensors=torch.stack(forget_tensors,dim=0).to(device)
print(forget_tensors.shape)

torch.Size([20, 3, 1024, 1024])


In [33]:
forget_tensors[1,:].repeat(2,1,1,1).shape

torch.Size([2, 3, 1024, 1024])

In [21]:
epochs = 100
batch_size = 2
vgg_encoder.eval()
style_hooker,style_features = build_style_hooker()
hook_model(style_layers,style_hooker,vgg_encoder)
task_name ="naive_forget"
save_dir = f'model_weights/{task_name}/{time.strftime("%Y%m%d")}'

# unet.load_state_dict(torch.load("model_weights/naive_forget/20240402/naive_forget_156959506.5.pt"))

for k,v in model_components.items():
    frozen_all(v)
activate_all(unet)

In [24]:
optim = torch.optim.Adam(unet.parameters())
unet.train()
for epoch in tqdm(range(epochs)):
    loss_epoch = 0

    for batch in range(forget_tensors.shape[0]):
        optim.zero_grad()
        style_features.clear()
        prompt = random.choices(forget_prompts,k=batch_size) #remember to alter
        sample_images =  pipe(prompt, num_inference_steps=4, guidance_scale=0).images
        sample_tensors = []
        for image in sample_images:
            sample_tensors.append(preprocess(image))
        sample_tensors = torch.stack(sample_tensors,dim=0).to(device)
        selected_forget_tensors = forget_tensors[batch,:].repeat(batch_size,1,1,1)
        surrogate_tensor = torch.randn_like(sample_tensors)
        in_tensors = torch.cat([sample_tensors,selected_forget_tensors],dim=0)
        vgg_encoder(in_tensors)
        # gram distance for every sample between the forget pics and sampled pics
        # print(style_features)
        gram_distance = style_distance(style_features[0][:batch_size,:],style_features[0][batch_size:,:])
        # print(gram_distance.shape)
        #mse for every sample [batch_size]
        torch.autograd.set_detect_anomaly(True)
        loss_forget = torch.sum((sample_tensors-surrogate_tensor)**2,dim=(-1,-2,-3))
        loss_final = torch.sum(torch.exp(-gram_distance)*loss_forget)
        # loss_forget.requires_grad=True
        # loss_final.requires_grad = True
        # loss_forget.requires_grad = True
        with torch.no_grad():
            loss_epoch += loss_final.item()
        loss_forget.backward()
        
        # loss_final.backward()#remmeber to alter
        optim.step()
    
    if epoch%5==0:
        image = pipe(random.choice(forget_prompts), num_inference_steps=4, guidance_scale=0).images[0]
        image.save(f"outputs/{epoch}.jpg")
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(unet.state_dict(),os.path.join(save_dir,f"{task_name}_{round(loss_epoch,3)}.pt"))
    print(f"epoch:{epoch},loss:{loss_epoch}")

  0%|                                                                                                                                        | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

False


  0%|                                                                                                                                        | 0/100 [00:02<?, ?it/s]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn