## Image generation pipeline

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image, make_image_grid
import torch
import sys
sys.path.append('src')
from brain_encoder import BrainEncoder
from omegaconf import OmegaConf
from safetensors.torch import load_model
import os
from dataset import build_dataloaders, select_random_dimension
from PIL import Image
import numpy as np

In [None]:
def get_concat_v_cut_center(images):
    """Stack images vertically on a canvas as wide as the narrowest image.

    Images wider than the canvas are centre-cropped: each one is pasted at a
    (possibly negative) x offset so that the same amount is cut from its left
    and right edges.

    Args:
        images: Non-empty iterable of PIL images (objects exposing ``width``
            and ``height`` that ``Image.paste`` accepts).

    Returns:
        A new RGB image of size ``(min(widths), sum(heights))`` with the
        inputs stacked top to bottom in the order given.

    Raises:
        ValueError: If ``images`` is empty.
    """
    images = list(images)
    if not images:
        raise ValueError('get_concat_v_cut_center() needs at least one image')

    canvas_width = min(int(im.width) for im in images)
    canvas_height = sum(int(im.height) for im in images)
    dst = Image.new('RGB', (canvas_width, canvas_height))

    top = 0
    for im in images:
        # A negative left offset lets paste() clip both sides evenly, keeping
        # the centre of an image that is wider than the canvas.
        left = (canvas_width - int(im.width)) // 2
        dst.paste(im, (left, top))
        top += int(im.height)
    return dst

In [None]:
# Load the SDXL base text-to-image pipeline with fp16 weights and move it to
# the GPU.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", 
    torch_dtype=torch.float16
).to("cuda")

# Attach the SDXL IP-Adapter weights (the "vit-h" variant) so the pipeline can
# be conditioned on image embeddings instead of a text prompt.
pipe.load_ip_adapter(
    "h94/IP-Adapter", 
    subfolder="sdxl_models", 
    weight_name="ip-adapter_sdxl_vit-h.safetensors", 
    torch_dtype=torch.float16
)

# Use the IP-Adapter conditioning at scale 1.
pipe.set_ip_adapter_scale(1)

In [None]:
# Path to the training config. The default is the location on the original
# training machine; set the BRAIN_ENCODER_CONFIG environment variable to run
# the notebook elsewhere without editing this cell.
config_path = os.environ.get(
    'BRAIN_ENCODER_CONFIG',
    '/home/jovyan/shares/SR004.nfs2/nkiselev/visual_stimuli_reconstruction/CreationOfIntelligentSystems_Simultaneous_fMRI-EEG/train/configs/pioneer.yaml',
)
config = OmegaConf.load(config_path)

In [None]:
# Build the brain encoder with the keyword arguments stored under
# `model_kwargs` in the loaded config.
model = BrainEncoder(**config.model_kwargs)

In [None]:
# Training step of the checkpoint to restore.
steps = 14000
# Checkpoints are read from <output_dir>/checkpoint-<steps>/model.safetensors.
filename = os.path.join(config.output_dir, f'checkpoint-{steps}', 'model.safetensors')
# Load the safetensors weights into `model`; the call's return value is left
# as the cell output.
load_model(model, filename)

In [None]:
# Build the train/validation dataloaders from the `dataloaders_kwargs` section
# of the config.
train_dataloader, val_dataloader = build_dataloaders(**config.dataloaders_kwargs)

In [None]:
idx = 21        # index of the validation sample to reconstruct
image_idx = 0   # which frame of that sample serves as the reference image

x = val_dataloader.dataset[idx]
sub_id = x['id']
# Dataset items are unbatched; add a leading batch dimension of size 1.
fmri_embeds = x['fmri'].unsqueeze(0)
eeg_embeds = x['eeg'].unsqueeze(0)
image_embeds = x['frames'].unsqueeze(0)
frame_paths = x['frame_paths']

# Frame paths end in '.pt'; the matching picture is expected next to each one
# with a '.jpg' extension.
# NOTE(review): presumably the '.pt' files hold precomputed frame embeddings -
# confirm against the dataset code.
image_path = frame_paths[0][image_idx].replace('.pt', '.jpg')
image = Image.open(image_path)
# Keep only the selected frame: (1, n_frames, dim) -> (1, dim).
image_embeds = image_embeds[:, image_idx, :]

# Render inline as the cell output. Image.show() would instead hand the picture
# to an external viewer, which shows nothing in a remote notebook.
image

In [None]:
# Inference only: switch the encoder to eval mode and disable autograd.
model.eval()
with torch.no_grad():
    # Map the subject id plus EEG and fMRI inputs to a single embedding and
    # cast it to the diffusion pipeline's dtype.
    # NOTE(review): the tensor stays on the model's device; confirm the
    # pipeline moves it to the GPU (or move it explicitly).
    combined_embeds = model(sub_id, eeg_embeds, fmri_embeds).to(pipe.dtype)

In [None]:
# With classifier-free guidance the pipeline takes the negative and positive
# image embeddings concatenated along the batch axis; an all-zero tensor is
# used as the negative one. unsqueeze(1) adds the per-adapter image axis.
ip_adapter_image_embeds = torch.cat([
    torch.zeros_like(image_embeds),
    image_embeds
]).unsqueeze(1)

ip_adapter_combined_embeds = torch.cat([
    torch.zeros_like(combined_embeds),
    combined_embeds
]).unsqueeze(1)

num_inference_steps = 30
guidance_scale = 5.0
num_images_per_prompt = 3
# Fixed seed: reruns reproduce the same pictures, and both rows below start
# from the same initial noise, so they differ only through the conditioning.
seed = 0

# Sampling settings shared by both runs. The prompt is empty so the IP-Adapter
# embedding is the only conditioning signal.
generation_kwargs = dict(
    prompt='',
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    num_images_per_prompt=num_images_per_prompt,
)

# Reference: images generated from the embedding of the true frame.
original_images = pipe(
    ip_adapter_image_embeds=[ip_adapter_image_embeds],
    generator=torch.Generator(device=pipe.device).manual_seed(seed),
    **generation_kwargs
).images

# Reconstruction: images generated from the embedding predicted from EEG/fMRI.
brain_images = pipe(
    ip_adapter_image_embeds=[ip_adapter_combined_embeds],
    generator=torch.Generator(device=pipe.device).manual_seed(seed),
    **generation_kwargs
).images

In [None]:
# One row per source: images from the true-frame embedding, then images from
# the brain-derived embedding.
original_grid = make_image_grid(original_images, 1, len(original_images))
brain_grid = make_image_grid(brain_images, 1, len(brain_images))

# Scale the reference frame to the width of a grid row, keeping its aspect
# ratio. Reading the width from the grid avoids assuming 1024 px tiles.
reference_width = original_grid.width
reference_height = int(reference_width * image.size[1] / image.size[0])

comparison = get_concat_v_cut_center([
    image.resize((reference_width, reference_height)),
    original_grid,
    brain_grid
])

# Render inline as the cell output. Image.show() would instead hand the picture
# to an external viewer, which shows nothing in a remote notebook.
comparison

---

## Example of image reconstruction with CLIP-ViT-H-14 and SDXL + IP-Adapter

In [None]:
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

In [None]:
# Load the SDXL base text-to-image pipeline with fp16 weights and move it to
# the GPU.
# NOTE(review): this repeats the pipeline setup of the first section; if both
# sections run in one kernel the weights are loaded a second time.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", 
    torch_dtype=torch.float16
).to("cuda")

# Attach the SDXL IP-Adapter weights (the "vit-h" variant) so the pipeline can
# be conditioned on image embeddings.
pipe.load_ip_adapter(
    "h94/IP-Adapter", 
    subfolder="sdxl_models", 
    weight_name="ip-adapter_sdxl_vit-h.safetensors", 
    torch_dtype=torch.float16
)

# Use the IP-Adapter conditioning at scale 1.
pipe.set_ip_adapter_scale(1)

In [None]:
# Download the example picture (requires network access) and resize it to
# 1024x1024; the resize does not preserve the aspect ratio.
image = load_image('https://preview.redd.it/sdxl-is-really-good-with-cats-v0-n7izni8y1f9c1.png?auto=webp&s=d939d6d86f4402e0ea588faf3ca86d67903a826c')
image = image.resize((1024, 1024))
# Show the input image as the cell output.
image

In [None]:
import open_clip

# Create the OpenCLIP ViT-H/14 model (LAION-2B weights) in fp16 on the GPU.
# Of the three returned values the second is discarded and the third is used
# below to preprocess the input image.
# NOTE(review): per open_clip's API the third value should be the evaluation
# preprocessing transform - confirm.
image_encoder, _, feature_extractor = open_clip.create_model_and_transforms(
    'ViT-H-14', pretrained='laion2b_s32b_b79k', precision='fp16', device='cuda')

In [None]:
# Preprocess the picture, add a batch dimension and move it to the GPU in fp16
# to match the encoder's precision.
image_processed = feature_extractor(image)[None, ...].to("cuda", dtype=torch.float16)

# Inference only: eval mode plus no_grad avoids building an autograd graph
# (and holding its activations on the GPU) for an embedding that is never
# back-propagated through.
image_encoder.eval()
with torch.no_grad():
    image_embeds = image_encoder.encode_image(image_processed)

In [None]:
# With classifier-free guidance the pipeline takes the negative and positive
# image embeddings concatenated along the batch axis; an all-zero tensor is
# used as the negative one. unsqueeze(1) adds the per-adapter image axis.
ip_adapter_image_embeds = torch.cat([
    torch.zeros_like(image_embeds),
    image_embeds
]).unsqueeze(1)

# Generate one image from the CLIP embedding alone (empty text prompt).
# NOTE(review): this rebinds `image` from the input picture to the generated
# one, so re-running the encoding cell afterwards would embed the generated
# image instead; a separate name would make the cells idempotent.
# NOTE(review): no generator/seed is passed, so the output changes on each run.
image = pipe(
    prompt='', 
    ip_adapter_image_embeds=[ip_adapter_image_embeds], 
    num_inference_steps=30,
    guidance_scale=5.0,
).images[0]

# Show the generated image as the cell output.
image