# 利用16张low level图片, 生成16张图

In [1]:
import torch.nn as nn
import torch
import re
import os
from PIL import Image
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from custom_pipeline_low_level import Generator4Embeds
from diffusion_prior import DiffusionPriorUNet, Pipe

In [2]:
input_dir = f"/home/tom/fsas/eeg_data/generated_images/demo/input"
output_dir = f"/home/tom/fsas/eeg_data/generated_images/demo/output"
device = "cuda" if torch.cuda.is_available() else "cpu" 

In [3]:
class CLIPEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # self.clip = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14').to(torch.bfloat16)
        # self.clip_size = (224, 224)

        self.preprocess = CLIPImageProcessor(
            # size={"height": 512, "width": 512},
            size={"shortest_edge": 512}, 
            crop_size={"height": 512, "width": 512},
        )


        # for param in self.clip.parameters():
        #     param.requires_grad = False
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter", 
        # "laion2b_s32b_b79k",
        subfolder="models/image_encoder",
        torch_dtype=torch.float16,
        ).to("cuda")
    
clip_encoder = CLIPEncoder().to(device)



In [7]:
diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)
pipe = Pipe(diffusion_prior, device=device)
pipe.diffusion_prior.load_state_dict(torch.load(f'/home/tom/fsas/eeg_data/diffusion_prior_old/sub-08/diffusion_prior.pt', map_location=device))
train_eeg_embeddings = torch.load('/home/tom/fsas/eeg_data/ATM_S_eeg_features_sub-08_train_old.pt', map_location=device).unsqueeze(1) # (66160, 1, 1024)

def extract_label(filename):
    match = re.search(r'_(\d+)\.png$', filename)
    if match:
        return int(match.group(1)) 
    return None

seed_value = 42
gen = torch.Generator(device=device)
gen.manual_seed(seed_value)

for i, (file_name) in enumerate(os.listdir(input_dir)):
    label = extract_label(file_name)
    input_path = os.path.join(input_dir, file_name)
    low_level_image = Image.open(input_path) 
    low_level_image = clip_encoder.preprocess(low_level_image, return_tensors="pt").pixel_values  # [1, 3, 224, 224]

    generator = Generator4Embeds(num_inference_steps=5, device=device, img2img_strength=0.8, low_level_image=low_level_image)
    h = pipe.generate(c_embeds=train_eeg_embeddings[label * 10], num_inference_steps=10, guidance_scale=2.0)
    reconstructed_image = generator.generate(h, generator=gen)

    output_path = os.path.join(output_dir, file_name)
    reconstructed_image.save(output_path)

  pipe.diffusion_prior.load_state_dict(torch.load(f'/home/tom/fsas/eeg_data/diffusion_prior_old/sub-08/diffusion_prior.pt', map_location=device))
  train_eeg_embeddings = torch.load('/home/tom/fsas/eeg_data/ATM_S_eeg_features_sub-08_train_old.pt', map_location=device).unsqueeze(1) # (66160, 1, 1024)
Loading pipeline components...: 100%|██████████| 7/7 [00:03<00:00,  2.21it/s]
  state_dict = torch.load(model_file, map_location="cpu")
10it [00:00, 12.60it/s]


latents torch.Size([1, 4, 64, 64])
noise torch.Size([1, 4, 64, 64])


 80%|████████  | 4/5 [00:00<00:00, 11.32it/s]


FileNotFoundError: [Errno 2] No such file or directory: '/home/tom/fsas/eeg_data/generated_images/demo/output/reconstructed_image_114.png'

In [None]:
pipe_2 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float, variant="fp16")

image_processor = VaeImageProcessor()

if hasattr(pipe, 'vae'):
    for param in pipe.vae.parameters():
        param.requires_grad = False

vae = pipe_2.vae.to(device)
vae.requires_grad_(False)
vae.eval()

def evaluate(eeg_encoder, data_loader, train=True):
    eeg_encoder = eeg_encoder.to(device)
    eeg_encoder.eval()
    
    count = 0
    for batch_idx, (eeg_data, labels, _, _, _, _) in enumerate(data_loader):
        eeg_data = eeg_data.to(device)
        eeg_feature = eeg_encoder(eeg_data)
        x_reconstructed = vae.decode(eeg_feature).sample
        prefix = "train" if train else "test" 
        img_reconstructed = image_processor.postprocess(x_reconstructed, output_type="pil")
        for i, (label) in enumerate(labels):
            save_path = f"/home/tom/fsas/eeg_data/generated_images/tmp/{prefix}/{label+1}.png" 
            # 每个label生成一张图即可
            if os.path.exists(save_path):
                pass

            img_reconstructed[i].save(save_path)
            count = count + 1 
            if count == 200:
                return