In [1]:
import os
import time

from tqdm import tqdm
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn 
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random
from torch.cuda.amp import autocast


In [2]:
from diffusers import AutoPipelineForText2Image
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch

# The notebook was written for GPU index 6. Fall back to the first GPU, and
# then to the CPU, so the cell still runs on hosts with fewer devices instead
# of failing when the models are moved with `.to(device)`.
_PREFERRED_GPU_INDEX = 6
if torch.cuda.is_available() and torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
    device = torch.device(f"cuda:{_PREFERRED_GPU_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Every component below is loaded from the matching subfolder of this checkpoint.
model_name = "stabilityai/sd-turbo"

# Image <-> latent autoencoder.
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", use_safetensors=True).to(device)
# Prompt tokenizer and text encoder that produce the UNet conditioning.
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    model_name, subfolder="text_encoder", use_safetensors=True).to(device)
# Conditional denoising network.
unet = UNet2DConditionModel.from_pretrained(
    model_name, subfolder="unet", use_safetensors=True
).to(device)
# Scheduler is plain Python state (no `.to(device)`); it is shared by sampling
# and by the Fisher-information cell further down.
scheduler =EulerDiscreteScheduler.from_pretrained(model_name,subfolder="scheduler")

In [16]:
# Display the loaded scheduler so its configuration is visible in the cell output.
scheduler

EulerDiscreteScheduler {
  "_class_name": "EulerDiscreteScheduler",
  "_diffusers_version": "0.27.2",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "interpolation_type": "linear",
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "sigma_max": null,
  "sigma_min": null,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "timestep_spacing": "trailing",
  "timestep_type": "discrete",
  "trained_betas": null,
  "use_karras_sigmas": false
}

In [3]:
# Pool of scene descriptions. In this notebook it is only referenced by the
# commented-out dataset-generation loop, which draws prompts from it with
# random.choices.
fim_prompts = [
    "Serene alpine meadow at dawn",
    "Cyberpunk cityscape at twilight",
    "Medieval castle overlooking a village",
    "Mystical forest with glowing mushrooms",
    "Futuristic space station in orbit",
    "Underwater coral city teeming with marine life",
    "Steampunk airship above a Victorian skyline",
    "Desert oasis with palm trees and clear skies",
    "Ancient pyramids beneath a starry night",
    "Lunar base on the moon's surface",
    "Bustling market in a fantasy realm",
    "Gothic cathedral at dusk with a storm approaching",
    "Post-apocalyptic wasteland with ruins",
    "High-speed train in a mountain tunnel",
    "Giant robot overlooking a cityscape",
    "Enchanted garden with floating lanterns",
    "Abandoned factory overtaken by nature",
    "Deep sea exploration with bioluminescent creatures",
    "Island paradise with a hidden waterfall",
    "Subterranean cave network with vast chambers",
    "Flying cars over a neon metropolis",
    "Snowy village with a cozy inn",
    "Haunted mansion in a dark forest",
    "Solar powered farm in a desert",
    "Interdimensional portal opening in a library",
    "Viking longship sailing through a stormy sea",
    "Alien planet with unique flora and fauna",
    "Colony on Mars with domed habitats",
    "Art deco skyscrapers with flying buttresses",
    "Sunken pirate ship in crystal clear water",
    "Harbor at sunset with seagulls and sailboats",
    "Zen rock garden in a tranquil temple",
    "Giant's causeway on an alien world",
    "Festival of lights in a bustling city",
    "Dinosaurs roaming a prehistoric landscape",
    "Flying buttress bridge over a chasm",
    "Orbiting satellite with Earth in the background",
    "Cybernetic enhancements in a futuristic surgery",
    "Crystal clear lake in a mountain valley",
    "Glacial ice cave with refracted light",
    "Aerial view of a sprawling desert city",
    "Underwater research lab with marine visitors",
    "Mammoth herd on the tundra under a aurora",
    "Rustic windmill farm against a sunset",
    "Space elevator with cargo ascending",
    "Volcanic eruption with lava flowing down the slopes",
    "Time-lapse of stars moving in the night sky",
    "Ornate gothic library filled with ancient tomes",
    "Sailing through a bioluminescent bay",
    "Dune buggy race across a sandy plain"
]

In [4]:
def generate_img(prompts,gen_shape=(512,512),num_inference_steps=4,guidance_scale=0.75,requires_grad=True):
    """Sample images for ``prompts`` with the notebook-level diffusion components.

    Uses the module-level ``tokenizer``, ``text_encoder``, ``unet``, ``vae``,
    ``scheduler`` and ``device``.

    Args:
        prompts: A single prompt string or a sequence of prompt strings.
        gen_shape: ``(height, width)`` of the decoded images in pixels; the
            latent grid is ``height // 8`` by ``width // 8``.
        num_inference_steps: Number of scheduler steps to run.
        guidance_scale: Weight ``w`` in
            ``uncond + w * (text - uncond)`` applied to the UNet prediction.
        requires_grad: When False, the UNet and VAE forward passes run under
            ``torch.no_grad()``. When True they run in whatever autograd mode
            the caller is in.

    Returns:
        The tensor returned by ``vae.decode(...).sample``, one image per prompt.

    Raises:
        ValueError: If no prompt is given.
    """
    import contextlib

    # A bare string would otherwise be treated as a sequence of characters
    # when computing the batch size.
    if isinstance(prompts, str):
        prompts = [prompts]
    else:
        prompts = list(prompts)
    batch_size = len(prompts)
    if batch_size == 0:
        raise ValueError("generate_img needs at least one prompt")
    height, width = gen_shape

    text_input = tokenizer(
        prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")

    # The text encoder is never differentiated: both the conditional and the
    # unconditional embeddings are computed without building a graph.
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Order matters: the guidance step below splits the prediction as
    # (unconditional, conditional).
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Configure the sampling schedule first so that init_noise_sigma is read
    # from the schedule that is actually used for sampling.
    scheduler.set_timesteps(num_inference_steps)
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        device=device
    )
    latents = latents * scheduler.init_noise_sigma

    # nullcontext (rather than forcing grad on) keeps an enclosing
    # torch.no_grad() of the caller in effect when requires_grad is True.
    grad_context = contextlib.nullcontext() if requires_grad else torch.no_grad()
    with grad_context:
        for t in scheduler.timesteps:
            # Duplicate the latents so the unconditional and conditional
            # predictions come out of a single UNet forward pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

            # predict the noise residual
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the latent scaling before decoding; fall back to the constant
        # the notebook used if the VAE config does not expose a scaling factor.
        latent_scale = getattr(vae.config, "scaling_factor", 0.18215)
        latents = 1 / latent_scale * latents
        image = vae.decode(latents).sample
    return image


In [5]:
# Number of prompts drawn per batch by the (disabled) generation loop below.
num_prompts=5
# Total number of images the generation loop would write; the Fisher cell
# further down also uses it as its iteration count.
fim_samples_num = 100
# Side length in pixels; images are generated and later resized to
# (image_shape, image_shape).
image_shape = 512
# Filled only by the disabled loop below, which also writes it to prompts.txt.
prompts_all = []
# One-off dataset generation, kept for reference. It samples prompts from
# fim_prompts, generates images without gradients and saves them as
# datasets/generated/sd_fim/<index>.jpg together with prompts.txt.
# for i in tqdm(range(fim_samples_num//num_prompts)):
#     prompt_selected = random.choices(fim_prompts,k=num_prompts)
#     images = generate_img(prompt_selected,gen_shape=(image_shape,image_shape),num_inference_steps=2,requires_grad=False)
#     prompts_all.extend(prompt_selected)
#     for idx,image in enumerate(images):
#         image = (image / 2 + 0.5).clamp(0, 1)
#         image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
#         Image.fromarray(image).save(f"datasets/generated/sd_fim/{i * num_prompts + idx}.jpg")
# prompts_all_text = "\n".join(prompts_all)
# with open("datasets/generated/sd_fim/prompts.txt","w") as f:
#     f.write(prompts_all_text)        

In [6]:
# Per-channel statistics for the optional Normalize step; the transforms
# below do not reference them while that step stays disabled.
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])


def _to_unit_range(x):
    """Map values from [-1, 1] to [0, 1], clamping anything outside."""
    return (x/2+0.5).clamp(0,1)


# PIL image -> float tensor, resized to a square of side image_shape.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize((image_shape,image_shape)),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # commonly used normalisation
]
preprocess = transforms.Compose(_preprocess_steps)

# Model output tensor -> PIL image.
_postprocess_steps = [
    _to_unit_range,
    transforms.ToPILImage(),
]
postprocess = transforms.Compose(_postprocess_steps)

In [7]:
def _fim_image_order(filename):
    """Sort key: purely numeric stems first in numeric order, other names after them alphabetically."""
    stem = os.path.splitext(filename)[0]
    if stem.isdigit():
        return (0, int(stem), "")
    return (1, 0, stem)


fim_dataset_dir = "datasets/generated/sd_fim"

# os.listdir returns entries in arbitrary order, but images are paired with
# the lines of prompts.txt by position. Sort by the numeric file stem so that
# "<k>.jpg" ends up at index k (a plain string sort would put "10.jpg"
# before "2.jpg").
fim_image_files = sorted(
    (name for name in os.listdir(fim_dataset_dir) if name.endswith(".jpg")),
    key=_fim_image_order,
)
if not fim_image_files:
    raise FileNotFoundError(f"no .jpg images found in {fim_dataset_dir}")

fim_dataset = []
for file in fim_image_files:
    # The context manager closes the file handle once the pixels have been
    # converted; convert("RGB") guarantees three channels so stacking works.
    with Image.open(os.path.join(fim_dataset_dir,file)) as image:
        fim_dataset.append(preprocess(image.convert("RGB")))
fim_dataset = torch.stack(fim_dataset)

with open(os.path.join(fim_dataset_dir,"prompts.txt"),"r") as f:
    # One prompt per line, without the trailing newline.
    fim_prompts_datasets = [line.rstrip("\n") for line in f]

# Image i and prompt i are used together downstream, so the counts must agree.
if len(fim_prompts_datasets) != len(fim_dataset):
    raise ValueError(
        f"found {len(fim_dataset)} images but {len(fim_prompts_datasets)} prompts in {fim_dataset_dir}"
    )


In [6]:
# Show the shape of the stacked image tensor as a sanity check.
fim_dataset.shape

torch.Size([100, 3, 512, 512])

In [8]:
import pickle

# Number of (image, prompt) pairs per Fisher iteration.
batch_size = 2
# Destination of the pickled {parameter name: accumulated squared gradient} dict.
save_path = os.path.join("model_weights","fim_sd.pt")
def save_fim():
    """Accumulate squared UNet gradients of the denoising loss and pickle them.

    For ``fim_samples_num`` iterations a random batch of ``batch_size``
    (image, prompt) pairs is drawn from ``fim_dataset`` /
    ``fim_prompts_datasets``, encoded to latents, noised at random timesteps,
    and the MSE between the UNet prediction and the noise is back-propagated.
    ``grad ** 2 / batch_size`` is summed per parameter over all iterations and
    the resulting dict is written to ``save_path`` with ``pickle``.

    Relies on the notebook-level ``unet``, ``vae``, ``text_encoder``,
    ``tokenizer``, ``scheduler`` and ``device``.

    Raises:
        ValueError: If there are no (image, prompt) pairs to sample from.
    """
    torch.cuda.empty_cache()

    # Only indices that have both an image and a prompt can be sampled.
    num_pairs = min(len(fim_dataset), len(fim_prompts_datasets))
    if num_pairs == 0:
        raise ValueError("save_fim needs at least one (image, prompt) pair")

    fisher_dict = {name:torch.zeros_like(param) for name,param in unet.named_parameters()}
    text_input = tokenizer(
        list(fim_prompts_datasets), padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )
    # Same latent scale that generate_img divides by before decoding.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)
    alphas_cumprod = scheduler.alphas_cumprod.to(device)

    for _ in tqdm(range(fim_samples_num)):
        torch.cuda.empty_cache()
        with torch.no_grad():
            # Draw from the whole dataset; drawing from range(batch_size)
            # would only ever touch the first batch_size items.
            sample_idx = random.choices(range(num_pairs),k=batch_size)
            selected_input_ids = text_input.input_ids[sample_idx,:]
            text_embeddings = text_encoder(selected_input_ids.to(device))[0]

            # fim_dataset comes from ToTensor (values in [0, 1]); the VAE is
            # fed [-1, 1] images, the inverse of the x/2+0.5 post-processing.
            selected_samples = fim_dataset[sample_idx,:].to(device) * 2 - 1
            selected_x = vae.encode(selected_samples).latent_dist.sample() * latent_scale

            noise = torch.randn_like(selected_x)
            time_steps = torch.randint(
                0, scheduler.config.num_train_timesteps, (batch_size,), device=device
            )
            # Variance-preserving forward noising. This equals the Euler
            # scheduler's add_noise followed by scale_model_input, but does
            # not depend on the scheduler's current sampling schedule or its
            # cached step index.
            signal_scale = alphas_cumprod[time_steps].sqrt().view(-1, 1, 1, 1)
            noise_scale = (1 - alphas_cumprod[time_steps]).sqrt().view(-1, 1, 1, 1)
            x_t = signal_scale * selected_x + noise_scale * noise

        # backward() adds to .grad, so clear it to square this batch's
        # gradient rather than the running sum of all previous ones.
        unet.zero_grad()
        pred = unet(x_t,time_steps,text_embeddings).sample
        loss = F.mse_loss(pred,noise)
        loss.backward()

        for name, param in unet.named_parameters():
            # Parameters that did not take part in the forward pass have no gradient.
            if param.grad is None:
                continue
            grad = param.grad.detach()
            if torch.isnan(grad).any():
                print("NAN detected")
            fisher_dict[name] += (grad ** 2) / batch_size

    # Do not leave stale gradients on the shared UNet.
    unet.zero_grad()

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Kept as a plain pickle so existing readers of this file keep working.
    with open(save_path, 'wb') as f:
        pickle.dump(fisher_dict, f)
save_fim()         

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:11<00:00,  1.40it/s]
