## Image generation pipeline

### Preliminaries

In [None]:
# Load IPython's autoreload extension and select mode 2, so edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image, make_image_grid
import torch
import sys
sys.path.append('src')
from brain_encoder import BrainEncoder
from diffusion_prior.model import DiffusionPriorUNet
from diffusion_prior.pipeline import DiffusionPrior
from omegaconf import OmegaConf
from safetensors.torch import load_model
import os
from dataset import BrainStimuliDataset, select_random_dimension
from PIL import Image
import numpy as np

In [None]:
def get_concat_v_cut_center(images):
    """Stack images vertically on a canvas as wide as the narrowest image.

    Each image is centred horizontally on the canvas, so an image wider
    than the canvas is cropped around its centre instead of losing only
    its right-hand side.

    Args:
        images: Non-empty iterable of ``PIL.Image.Image`` objects, listed
            from top to bottom.

    Returns:
        A new RGB ``PIL.Image.Image`` whose width is the smallest input
        width and whose height is the sum of all input heights.

    Raises:
        ValueError: If ``images`` is empty.
    """
    images = list(images)
    if not images:
        raise ValueError("get_concat_v_cut_center() needs at least one image")

    canvas_width = min(im.width for im in images)
    canvas_height = sum(im.height for im in images)
    dst = Image.new('RGB', (canvas_width, canvas_height))

    top = 0
    for im in images:
        # For a wider image the offset is negative; paste() clips whatever
        # falls outside the canvas, which yields the centre crop.
        left = (canvas_width - im.width) // 2
        dst.paste(im, (left, top))
        top += im.height
    return dst

### SDXL + IP-Adapter pipeline

In [None]:
# Load the SDXL base text-to-image pipeline with float16 weights and move it
# to the CUDA device.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", 
    torch_dtype=torch.float16
).to("cuda")

# Attach the IP-Adapter weights (the "vit-h" SDXL weight file) to the pipeline.
# NOTE(review): torch_dtype is passed here as well; confirm that
# load_ip_adapter actually honours it rather than silently ignoring it.
pipe.load_ip_adapter(
    "h94/IP-Adapter", 
    subfolder="sdxl_models", 
    weight_name="ip-adapter_sdxl_vit-h.safetensors", 
    torch_dtype=torch.float16
)

# Set the IP-Adapter scale to 1 (presumably full weight for the image
# conditioning - confirm against the diffusers documentation).
pipe.set_ip_adapter_scale(1)

### BrainEncoder

In [None]:
# Directory with the training YAML configs. This is an absolute,
# machine-specific path; it is reused further down to locate the
# diffusion-prior config.
config_dir = '/home/jovyan/shares/SR004.nfs2/nkiselev/visual_stimuli_reconstruction/CreationOfIntelligentSystems_Simultaneous_fMRI-EEG/train/configs'
# Alternative config, currently disabled:
# config_name = 'pioneer.yaml'
config_name = 'improved-dataloader.yaml'
config_path = os.path.join(config_dir, config_name)
# Parse the YAML file into an OmegaConf object; later cells read
# model_kwargs, output_dir and dataloader_kwargs from it.
config = OmegaConf.load(config_path)

In [None]:
# Build the brain encoder from the config's model_kwargs and cast it to float16.
model = BrainEncoder(**config.model_kwargs).to(torch.float16)

In [None]:
# Checkpoint number to restore; the weights are read from
# <config.output_dir>/checkpoint-<steps>/model.safetensors.
steps = 31000
filename = os.path.join(config.output_dir, f'checkpoint-{steps}', 'model.safetensors')
# Load the safetensors weights into the already constructed encoder.
load_model(model, filename)

### DiffusionPrior

In [None]:
# The diffusion-prior config lives in the same directory as the encoder config.
prior_config_name = 'diffusion-prior.yaml'
prior_config_path = os.path.join(config_dir, prior_config_name)
# Parse it into an OmegaConf object; model_kwargs and output_dir are read below.
prior_config = OmegaConf.load(prior_config_path)

In [None]:
# Build the diffusion-prior network from the prior config's model_kwargs.
diffusion_prior = DiffusionPriorUNet(**prior_config.model_kwargs)

In [None]:
# Checkpoint number for the prior. Note that this rebinds the `steps` and
# `filename` globals that were used for the brain-encoder checkpoint.
steps = 14000
filename = os.path.join(prior_config.output_dir, f'checkpoint-{steps}', 'model.safetensors')
# Load the safetensors weights into the already constructed prior network.
load_model(diffusion_prior, filename)

In [None]:
# Wrap the prior network in its sampling pipeline, configured for the CUDA device.
prior_pipe = DiffusionPrior(diffusion_prior=diffusion_prior, device="cuda")

### Dataset

In [None]:
# Build the dataset with the same dataset kwargs that the training config
# passes to its dataloader.
dataset = BrainStimuliDataset(**config.dataloader_kwargs.dataset)

### Inference

In [None]:
# Dataset item to reconstruct, and which of its frames serves as ground truth.
idx = 40010
image_idx = 0

# One dataset item; the keys read here are 'id', 'fmri', 'eeg', 'frames'
# and 'frame_paths'.
x = dataset[idx]
sub_id = torch.tensor([x['id']])
# Add a leading batch dimension and match the float16 encoder weights.
fmri_embeds = x['fmri'].unsqueeze(0).to(torch.float16)
eeg_embeds = x['eeg'].unsqueeze(0).to(torch.float16)
image_embeds = x['frames'].unsqueeze(0).to(torch.float16)
frame_paths = x['frame_paths']

# Use the SAME frame index for the displayed picture and for its embedding,
# so the two stay in sync when image_idx is changed.
# Assumes frame_paths[i] is the '.pt' embedding file of frame i and that a
# '.jpg' with the same stem sits next to it - TODO confirm against the dataset.
image_path = frame_paths[image_idx].replace('.pt', '.jpg')
image = Image.open(image_path)
# Keep only the selected frame; the frame axis (dim 1) is dropped.
image_embeds = image_embeds[:, image_idx, :]

image.show()

In [None]:
# Put both networks into eval mode and run inference without autograd tracking.
model.eval()
prior_pipe.diffusion_prior.eval()
with torch.no_grad():
    # Brain encoder: subject id + EEG + fMRI inputs -> one combined embedding,
    # cast to the dtype of the SDXL pipeline.
    combined_embeds = model(sub_id, eeg_embeds, fmri_embeds).to(pipe.dtype)
    # Pass the combined embedding through the diffusion prior
    # (1000 sampling steps, guidance scale 5.0) and cast the result the same way.
    combined_embeds_enhanced = prior_pipe.generate(
        combined_embeds=combined_embeds, 
        num_inference_steps=1000, 
        guidance_scale=5.0,
    ).to(pipe.dtype)

### Generation

In [None]:
def _with_zero_embeds(embeds):
    """Stack an all-zero copy of ``embeds`` in front of it along dim 0, then
    insert a new axis at dim 1.

    The zero half comes first; presumably the pipeline reads it as the
    negative (unconditional) image embedding for classifier-free guidance -
    confirm against the diffusers IP-Adapter documentation.
    """
    return torch.cat([torch.zeros_like(embeds), embeds]).unsqueeze(1)


def _generate_from_embeds(ip_adapter_embeds):
    """Run the SDXL pipeline with an empty prompt, conditioned only on the
    given IP-Adapter embeddings, and return the generated PIL images."""
    return pipe(
        prompt='',
        ip_adapter_image_embeds=[ip_adapter_embeds],
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt
    ).images


# Conditioning inputs: ground-truth frame embedding, raw brain-encoder output,
# and the brain-encoder output after the diffusion prior.
ip_adapter_image_embeds = _with_zero_embeds(image_embeds)
ip_adapter_combined_embeds = _with_zero_embeds(combined_embeds)
# Assumes combined_embeds_enhanced has the same layout as combined_embeds -
# TODO confirm against DiffusionPrior.generate.
ip_adapter_combined_embeds_enhanced = _with_zero_embeds(combined_embeds_enhanced)

# Sampling settings shared by all three runs.
num_inference_steps = 30
guidance_scale = 5.0
num_images_per_prompt = 3

original_images = _generate_from_embeds(ip_adapter_image_embeds)
brain_images = _generate_from_embeds(ip_adapter_combined_embeds)
# Fix: this run previously reused ip_adapter_combined_embeds, so the
# "enhanced" images were just a second sample from the raw embeddings and
# the diffusion-prior output was never used.
brain_enhanced_images = _generate_from_embeds(ip_adapter_combined_embeds_enhanced)

In [None]:
# Lay each batch of generated images out as a single-row grid
# (one column per image).
original_grid, brain_grid, brain_enhanced_grid = (
    make_image_grid(batch, rows=1, cols=len(batch))
    for batch in (original_images, brain_images, brain_enhanced_images)
)

# Stack the three rows top to bottom and display the result.
get_concat_v_cut_center([
    # Optional first row, currently disabled: the ground-truth frame
    # resized to the width of one grid.
    # image.resize((
    #     int(num_images_per_prompt * 1024),
    #     int(num_images_per_prompt * 1024 * image.size[1] / image.size[0])
    # )),
    original_grid,
    brain_grid,
    brain_enhanced_grid,
]).show()