In [None]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator
import inspect

# sdxl unclip requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder2
from generative_models.sgm.models.diffusion import DiffusionEngine
from generative_models.sgm.util import append_dims
from omegaconf import OmegaConf

# Run CUDA matmuls in full FP32 precision (TF32 disabled).
torch.backends.cuda.matmul.allow_tf32 = False

import utils
from models import *

# Accelerate chooses the compute device; mixed_precision='fp16' configures its autocast.
accelerator = Accelerator(mixed_precision='fp16', split_batches=True)
device = accelerator.device
print("device:", device)
# Checkpoint tag: later cells load f"{tag}.pth" from the model's train_logs directory.
tag = 'last'




device: cuda


In [None]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    model_name = "fmri_model_v1_1ses_50ep"
    print("model_name:", model_name)

    # other variables can be specified in the following string:
    jupyter_args = f"--data_path={os.getcwd()}/data \
                    --cache_dir={os.getcwd()}/cache \
                    --model_name={model_name} --subj=1 \
                    --hidden_dim=1024 --n_blocks=4 --new_test"
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output
    %load_ext autoreload
    %autoreload 2

model_name: fmri_model_v1_1ses_50ep
--data_path=/workspace/MindEyeV2/MindEyeV2/src/data                     --cache_dir=/workspace/MindEyeV2/MindEyeV2/src/cache                     --model_name=fmri_model_v1_1ses_50ep --subj=1                     --hidden_dim=1024 --n_blocks=4 --new_test


In [None]:
# Evaluation configuration. Inside Jupyter the argument list comes from `jupyter_args`
# (built in the previous cell); otherwise from the real command line.
parser = argparse.ArgumentParser(description="model training configuration")
parser.add_argument("--model_name", type=str, default="fmri_model_v1_1ses_50ep",
                    help="will load ckpt for model found in ../train_logs/model_name")
parser.add_argument("--data_path", type=str, default=os.getcwd(),
                    help="path to where NSD data is stored / where to download it to")
parser.add_argument("--cache_dir", type=str, default=os.getcwd(),
                    help="path to where misc. files downloaded from huggingface are stored. defaults to current src directory.")
parser.add_argument("--subj", type=int, default=1, choices=[1, 2, 3, 4, 5, 6, 7, 8],
                    help="validate on which subject?")
parser.add_argument("--blurry_recon", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--n_blocks", type=int, default=4)
parser.add_argument("--hidden_dim", type=int, default=2048)
parser.add_argument("--new_test", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--seed", type=int, default=42)

# parse_args(None) reads sys.argv[1:], i.e. the non-interactive path.
args = parser.parse_args(jupyter_args if utils.is_interactive() else None)

# Publish every parsed option as a module-level global (model_name, subj, data_path, ...).
globals().update(vars(args))

# seed all random functions
utils.seed_everything(seed)

# Output directory for evaluation artifacts (makedirs also creates the parent "evals").
os.makedirs(f"evals/{model_name}", exist_ok=True)

In [None]:
# Load this subject's beta matrix fully into memory, keyed by "subj0<N>".
# `with` closes the HDF5 file once the array is read; it was previously left open.
voxels = {}
with h5py.File(f'{data_path}/betas_all_subj0{subj}_fp32_renorm.hdf5', 'r') as f:
    betas = f['betas'][:]
betas = torch.Tensor(betas).to("cpu")
num_voxels = betas[0].shape[-1]
voxels[f'subj0{subj}'] = betas
print(f"num_voxels for subj0{subj}: {num_voxels}")

# Test-split size and shard location. Subjects 3/6 and 4/8 have smaller test sets than
# the default, and the "new_test" split uses different counts than the original split.
if new_test:
    _num_test_by_subj = {3: 2371, 4: 2188, 6: 2371, 8: 2188}
    num_test = _num_test_by_subj.get(subj, 3000)
    test_url = f"{data_path}/wds/subj0{subj}/new_test/0.tar"
else:
    _num_test_by_subj = {3: 2113, 4: 1985, 6: 2113, 8: 1985}
    num_test = _num_test_by_subj.get(subj, 2770)
    test_url = f"{data_path}/wds/subj0{subj}/test/0.tar"

print(test_url)


def my_split_by_node(urls):
    """Identity node splitter: shards are not partitioned across nodes."""
    return urls


# Fields decoded from each WebDataset sample, in tuple order.
_test_keys = ["behav", "past_behav", "future_behav", "olds_behav"]
test_data = (
    wds.WebDataset(test_url, resampled=False, nodesplitter=my_split_by_node)
    .decode("torch")
    .rename(**{key: f"{key}.npy" for key in _test_keys})
    .to_tuple(*_test_keys)
)
# batch_size=num_test: one batch is meant to hold the entire test split.
test_dl = torch.utils.data.DataLoader(
    test_data, batch_size=num_test, shuffle=False, drop_last=True, pin_memory=True,
)
print(f"loaded test dl for subj{subj}!\n")

num_voxels for subj01: 15724
/workspace/MindEyeV2/MindEyeV2/src/data/wds/subj01/new_test/0.tar
Loaded test dl for subj1!



In [None]:
# COCO images as an h5py dataset. The file is intentionally left open because `images`
# reads lazily from it (presumably indexed by later cells).
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images']

# For every test trial collect its row in this subject's beta matrix (behav[:, 0, 5])
# and its COCO image index (behav[:, 0, 0]), across all batches the loader yields.
test_images_idx = []
test_voxels_idx = []
test_i = -1
for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
    # Append to test_voxels_idx itself; this previously appended to test_images_idx,
    # which corrupted test_voxels_idx from the second batch on.
    test_voxels_idx = np.append(test_voxels_idx, behav[:, 0, 5].cpu().numpy())
    test_images_idx = np.append(test_images_idx, behav[:, 0, 0].cpu().numpy())
assert test_i >= 0, f"test loader for {test_url} yielded no batches"
test_images_idx = test_images_idx.astype(int)
test_voxels_idx = test_voxels_idx.astype(int)

# Gather voxels for ALL batches at once. Previously only the last batch was kept, leaving
# test_voxels misaligned with test_images_idx whenever the loader yielded more than one batch.
test_voxels = voxels[f'subj0{subj}'][torch.from_numpy(test_voxels_idx).long()]

assert (test_i + 1) * num_test == len(test_voxels) == len(test_images_idx)
print(test_i, len(test_voxels), len(test_images_idx), len(np.unique(test_images_idx)))

0 3000 3000 1000


In [None]:
# Shape of the CLIP image-token target the brain backbone predicts (tokens x width).
clip_seq_dim = 49
clip_emb_dim = 768

# Frozen OpenCLIP ViT-B-32 image encoder; only_tokens=True presumably returns the
# token sequence rather than the pooled embedding.
clip_img_embedder = FrozenOpenCLIPImageEmbedder(
    arch="ViT-B-32", version='openai', output_tokens=True, only_tokens=True,
)
clip_img_embedder.to("cpu")

# Optional VAE used to decode the backbone's blurry latent into a low-level reconstruction.
# NOTE(review): a later cell rebuilds this VAE with sample_size=256 (here 224) — confirm
# which is intended.
if blurry_recon:
    from diffusers import AutoencoderKL
    autoenc = AutoencoderKL(
        down_block_types=['DownEncoderBlock2D'] * 4,
        up_block_types=['UpDecoderBlock2D'] * 4,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=224,
    )
    # map_location='cpu' keeps this loadable on CPU-only hosts even if the weights were
    # saved from GPU (matches every other torch.load in this notebook).
    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth', map_location='cpu')
    autoenc.load_state_dict(ckpt)
    del ckpt  # weights are copied into autoenc; drop the duplicate dict
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to("cpu")
    utils.count_params(autoenc)
    print("loading blurry recon model")

class MindEyeModule(nn.Module):
    """Parameter-free container; sub-networks (ridge, backbone, diffusion_prior) are attached as attributes."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # The container itself does no computation.
        return x

model = MindEyeModule()

class RidgeRegression(torch.nn.Module):
    """Per-subject linear projection from voxels to a shared hidden space.

    Holds one ``Linear(input_size, out_features)`` per entry of ``input_sizes``;
    ``forward`` applies the one selected by ``subj_idx``.
    """

    def __init__(self, input_sizes, out_features):
        super().__init__()
        self.out_features = out_features
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(n_in, out_features) for n_in in input_sizes]
        )

    def forward(self, x, subj_idx):
        # Only index 0 along dim 1 is projected; the singleton dim is restored on output.
        projected = self.linears[subj_idx](x[:, 0])
        return projected.unsqueeze(1)

# Single-subject ridge (subject index 0) mapping voxels into the backbone's hidden space.
model.ridge = RidgeRegression([num_voxels], out_features=hidden_dim)

from diffusers.models.autoencoders.vae import Decoder
from models import BrainNetwork
# Backbone predicting clip_seq_dim x clip_emb_dim CLIP image tokens.
model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=1,
                              clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim)
utils.count_params(model.ridge)
utils.count_params(model.backbone)
utils.count_params(model)

# Diffusion-prior hyper-parameters; these should match the checkpoint loaded afterwards.
out_dim = clip_emb_dim
depth = 6
dim_head = 52
# Derived from dim_head instead of repeating the literal 52 (same value: 768 // 52 == 14).
# NOTE(review): 768 is not a multiple of 52, so heads * dim_head = 728 != out_dim — confirm
# this matches the training configuration.
heads = clip_emb_dim // dim_head
timesteps = 100

prior_network = PriorNetwork(
    dim=out_dim,
    depth=depth,
    dim_head=dim_head,
    heads=heads,
    causal=False,
    num_tokens=clip_seq_dim,
    learned_query_mode="pos_emb"
)

model.diffusion_prior = BrainDiffusionPrior(
    net=prior_network,
    image_embed_dim=out_dim,
    condition_on_text_encodings=False,
    timesteps=timesteps,
    cond_drop_prob=0.2,
    image_embed_scale=None,
)
model.to("cpu")

utils.count_params(model.diffusion_prior)
utils.count_params(model)

# Load trained MindEye weights: a plain torch checkpoint first, else a DeepSpeed ZeRO checkpoint.
outdir = os.path.abspath(f'../train_logs/{model_name}')
print(f"\n---loading {outdir}/{tag}.pth ckpt---\n")
try:
    checkpoint = torch.load(outdir+f'/{tag}.pth', map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    del checkpoint
    load_result = model.load_state_dict(state_dict, strict=False)
    del state_dict  # weights now live in `model`; drop the duplicate copy
except Exception as e:
    # `except Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit propagate.
    print(f"Plain checkpoint load failed ({e!r}); falling back to DeepSpeed ZeRO checkpoint.")
    import deepspeed
    state_dict = deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir=outdir, tag=tag)
    load_result = model.load_state_dict(state_dict, strict=False)
    del state_dict
# strict=False tolerates key mismatches; report them so a wrong checkpoint isn't silently accepted.
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint, e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} unexpected checkpoint keys, e.g. {load_result.unexpected_keys[:5]}")
print("ckpt loaded!")

param counts:
83,653,863 total
0 trainable
Loading blurry recon model
param counts:
16,102,400 total
16,102,400 trainable
param counts:
50,076,528 total
50,076,528 trainable
param counts:
66,178,928 total
66,178,928 trainable
param counts:
55,096,640 total
55,096,624 trainable
param counts:
121,275,568 total
121,275,552 trainable

---loading /workspace/MindEyeV2/MindEyeV2/train_logs/fmri_model_v1_1ses_50ep/last.pth ckpt---

ckpt loaded!


In [None]:
# %%
# Caption model: GIT (git-base-coco) variant whose `pixel_values` input receives CLIP-style
# embeddings (see the generate() call in the reconstruction loop).
from transformers import AutoProcessor, AutoModelForCausalLM
from modeling_git import GitForCausalLMClipEmb
processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
clip_text_model = GitForCausalLMClipEmb.from_pretrained("microsoft/git-base-coco")
clip_text_model.to("cpu")  # kept on CPU to avoid GPU OOM; lower minibatch_size if RAM is tight
clip_text_model.eval().requires_grad_(False)

# Token-sequence shape the caption model consumes.
clip_text_seq_dim = 49
clip_text_emb_dim = 768

class CLIPConverter(torch.nn.Module):
    """Map predicted CLIP image tokens onto the caption model's token grid.

    Two linear maps are applied: ``linear1`` mixes along the token axis
    (clip_seq_dim -> clip_text_seq_dim), then ``linear2`` along the feature axis
    (clip_emb_dim -> clip_text_emb_dim). Input is expected as (batch, tokens, features).
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(clip_seq_dim, clip_text_seq_dim)  # token axis
        self.linear2 = nn.Linear(clip_emb_dim, clip_text_emb_dim)  # feature axis

    def forward(self, x):
        features_first = x.permute(0, 2, 1)                    # (B, F, T)
        token_mixed = self.linear1(features_first)             # (B, F, T')
        return self.linear2(token_mixed.permute(0, 2, 1))      # (B, T', F')

clip_convert = CLIPConverter()

# Converter weights. This path used to be computed but never loaded, so clip_convert stayed
# randomly initialised and the generated captions could not reflect the brain data.
# NOTE(review): epoch10.pth is assumed to hold CLIPConverter weights (a raw state_dict or one
# wrapped under 'model_state_dict') — TODO confirm against the converter's training script.
expected_ckpt_path = os.path.join(cache_dir, "epoch10.pth")
if os.path.exists(expected_ckpt_path):
    try:
        convert_state = torch.load(expected_ckpt_path, map_location='cpu')
        if isinstance(convert_state, dict) and 'model_state_dict' in convert_state:
            convert_state = convert_state['model_state_dict']
        clip_convert.load_state_dict(convert_state, strict=True)
        del convert_state
        print(f"Loaded CLIPConverter weights from {expected_ckpt_path}")
    except Exception as e:
        print(f"WARNING: could not load CLIPConverter weights from {expected_ckpt_path}: {e}")
else:
    print(f"WARNING: {expected_ckpt_path} not found; clip_convert is randomly initialised, so captions will be meaningless.")
clip_convert.eval().requires_grad_(False)

# Re-load the MindEye checkpoint (same file as the earlier model cell), with a DeepSpeed fallback.
print(f"\n---attempting to load {outdir}/{tag}.pth checkpoint---\n")
checkpoint_path = os.path.join(outdir, f"{tag}.pth")

if not os.path.exists(outdir):
    print(f"WARNING: Directory {outdir} doesn't exist! Creating it...")
    os.makedirs(outdir, exist_ok=True)

if not os.path.exists(checkpoint_path):
    print(f"WARNING: Checkpoint file {checkpoint_path} not found!")
    print("Continuing without loading a model checkpoint. Results may not be meaningful.")
else:
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model_state_dict']
        model.load_state_dict(state_dict, strict=False)
        del checkpoint
        print("Checkpoint loaded successfully!")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        try:
            import deepspeed
            if os.path.exists(os.path.join(outdir, tag)):
                state_dict = deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
                    checkpoint_dir=outdir, tag=tag)
                model.load_state_dict(state_dict, strict=False)
                del state_dict
                print("DeepSpeed checkpoint loaded successfully!")
            else:
                print(f"DeepSpeed checkpoint directory {os.path.join(outdir, tag)} not found!")
        except Exception as deep_e:
            print(f"DeepSpeed loading also failed: {deep_e}")

  return self.fget.__get__(instance, owner)()



---attempting to load /workspace/MindEyeV2/MindEyeV2/train_logs/fmri_model_v1_1ses_50ep/last.pth checkpoint---

Checkpoint loaded successfully!


In [None]:
# prep unCLIP
# Read the SDXL-unCLIP config and split out the pieces DiffusionEngine takes.
config = OmegaConf.to_container(
    OmegaConf.load("generative_models/configs/unclip6.yaml"), resolve=True
)
unclip_params = config["model"]["params"]

# Sub-configs passed straight to DiffusionEngine
network_config = unclip_params["network_config"]
denoiser_config = unclip_params["denoiser_config"]
first_stage_config = unclip_params["first_stage_config"]
conditioner_config = unclip_params["conditioner_config"]
sampler_config = unclip_params["sampler_config"]

# Scalar settings
scale_factor = unclip_params["scale_factor"]
disable_first_stage_autocast = unclip_params["disable_first_stage_autocast"]
offset_noise_level = unclip_params["loss_fn_config"]["params"]["offset_noise_level"]

# Overrides: sgm's own AutoencoderKL as the first stage, 38 sampler steps
first_stage_config['target'] = 'sgm.models.autoencoder.AutoencoderKL'
sampler_config['params']['num_steps'] = 38

# Build the sgm DiffusionEngine from the unCLIP sub-configs.
diffusion_engine = DiffusionEngine(
    network_config=network_config,
    denoiser_config=denoiser_config,
    first_stage_config=first_stage_config,
    conditioner_config=conditioner_config,
    sampler_config=sampler_config,
    scale_factor=scale_factor,
    disable_first_stage_autocast=disable_first_stage_autocast,
)

# Inference only: freeze parameters and park the engine on CPU.
# (Weights stay fp32; diffusion_engine.half() would roughly halve memory if needed.)
diffusion_engine.eval().requires_grad_(False)
diffusion_engine.to("cpu")

# Locate and load the pretrained unCLIP weights; stop with guidance if they are missing.
ckpt_path = os.path.join(cache_dir, "unclip6_epoch0_step110000.ckpt")
print(f"Looking for unCLIP checkpoint at: {ckpt_path}")

if not os.path.exists(ckpt_path):
    print(f"ERROR: unCLIP checkpoint not found at {ckpt_path}")
    print("You need to download the unCLIP model checkpoint first.")
    print("This is typically available from Hugging Face or the model creator's repository.")
    print("After downloading, place it in your cache directory, which is set to:")
    print(f"  {cache_dir}")
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    raise FileNotFoundError(f"Missing required checkpoint: {ckpt_path}")

print(f"Found unCLIP checkpoint at: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location='cpu')
diffusion_engine.load_state_dict(ckpt['state_dict'])
# The loaded dict duplicates every engine weight (the image embedder alone is ~1.9B params);
# free it now that the weights are copied into diffusion_engine.
del ckpt



Initialized embedder #0: FrozenOpenCLIPImageEmbedder with 1909889025 params. Trainable: False
Initialized embedder #1: ConcatTimestepEmbedderND with 0 params. Trainable: False
Initialized embedder #2: ConcatTimestepEmbedderND with 0 params. Trainable: False
Looking for unCLIP checkpoint at: /workspace/MindEyeV2/MindEyeV2/src/cache/unclip6_epoch0_step110000.ckpt
Found unCLIP checkpoint at: /workspace/MindEyeV2/MindEyeV2/src/cache/unclip6_epoch0_step110000.ckpt


In [None]:
# --- Ensure blurry_recon argument is used ---
print(f"Blurry Reconstruction Flag (from args): {blurry_recon}")


def _infer_vector_suffix(engine, unet_config, default_dim=1024):
    """Build the (1, D) "vector" conditioning appended to every unCLIP sample.

    Preferred: run ``engine.conditioner`` on a dummy batch and take its ``"vector"``
    output, i.e. the size/crop conditioning the UNet was trained with. A zero tensor is
    presumably not equivalent: this config's conditioner includes ConcatTimestepEmbedderND
    embedders, and a sinusoidal embedding of 0 is not the zero vector.
    Fallback: a zero vector whose width is the UNet's ``adm_in_channels`` (or ``default_dim``).

    Returns a float32 CPU tensor of shape (1, D).
    """
    try:
        # NOTE(review): input keys/values assumed for unclip6.yaml's conditioner — TODO confirm.
        dummy_batch = {
            "jpg": torch.zeros(1, 3, 1, 1),
            "original_size_as_tuple": torch.ones(1, 2) * 768,
            "crop_coords_top_left": torch.zeros(1, 2),
        }
        with torch.no_grad():
            cond = engine.conditioner(dummy_batch)
        vec = cond["vector"].detach().to(device="cpu", dtype=torch.float32)
        print(f"Using vector_suffix from conditioner output, shape: {tuple(vec.shape)}")
        return vec
    except Exception as e:
        print(f"Could not compute vector_suffix from conditioner ({e}); using a zero placeholder.")
    params = (unet_config.get("params") or {}) if isinstance(unet_config, dict) else {}
    dim = params.get("adm_in_channels") or default_dim
    vec = torch.zeros(1, int(dim), device='cpu', dtype=torch.float32)
    print(f"Created placeholder vector_suffix with shape: {vec.shape}. Needs verification!")
    return vec


vector_suffix = _infer_vector_suffix(diffusion_engine, network_config)

# get all reconstructions
model.to("cpu")
model.eval().requires_grad_(False)

# --- Load Autoencoder Conditionally ---
# (Re)build the blurry-recon VAE. Any failure disables blurry recon instead of aborting.
autoenc = None
if not blurry_recon:
    print("Blurry reconstruction disabled by args, skipping Autoencoder load.")
else:
    print("Attempting to load Autoencoder for Blurry Recon...")
    try:
        from diffusers import AutoencoderKL
        autoenc = AutoencoderKL(
            down_block_types=['DownEncoderBlock2D'] * 4,
            up_block_types=['UpDecoderBlock2D'] * 4,
            block_out_channels=[128, 256, 512, 512],
            layers_per_block=2,
            sample_size=256,
        )
        ckpt_path = f'{cache_dir}/sd_image_var_autoenc.pth'
        if os.path.exists(ckpt_path):
            ckpt = torch.load(ckpt_path, map_location='cpu')
            autoenc.load_state_dict(ckpt)
            print(f"Autoencoder loaded successfully from {ckpt_path}")
            autoenc.eval().requires_grad_(False).to("cpu")  # Keep on CPU initially
            utils.count_params(autoenc)
        else:
            print(f"ERROR: Autoencoder checkpoint file not found at {ckpt_path}. Disabling blurry recon.")
            blurry_recon = False  # Turn off flag if file missing
            autoenc = None
    except Exception as e:
        print(f"ERROR loading autoencoder: {e}. Disabling blurry recon.")
        blurry_recon = False  # Turn off flag on error
        autoenc = None

# Park the large models on CPU; the loop moves diffusion_engine to `device` only while sampling.
clip_text_model.to("cpu")
diffusion_engine.to("cpu")

# Result accumulators, filled batch by batch in the loop below
all_blurryrecons = all_recons = all_clipvoxels = None
all_predcaptions = []  # one caption string per processed image

# Loop settings
minibatch_size = 1
num_samples_per_image = 1
assert num_samples_per_image == 1
if utils.is_interactive():
    plotting = False
dtype = torch.float16  # autocast dtype for the GPU decoding steps

# Report BrainNetwork's internal blurry_recon flag when it is available
if not (hasattr(model, 'backbone') and hasattr(model.backbone, 'blurry_recon')):
    print("Warning: Cannot check internal blurry_recon flag.")
else:
    print(f"Loaded model.backbone internal blurry_recon flag: {model.backbone.blurry_recon}")

# Reconstruction loop. For each unique test image (only the first 10 here, for debugging):
#   1. pick 3 test trials that showed the image (duplicating trials when fewer exist),
#   2. run ridge + backbone on each trial and average the 3 outputs,
#   3. sample the diffusion prior and generate a caption from its output,
#   4. decode an unCLIP reconstruction (diffusion_engine is moved to `device` only for this),
#   5. optionally decode a blurry reconstruction with the VAE.
# Results accumulate in all_clipvoxels, all_predcaptions, all_recons and all_blurryrecons.
# NOTE(review): a batch can be skipped (`continue`) or fail unCLIP / blurry decoding on its
# own, and no list of processed image ids is kept, so these accumulators can drift out of
# alignment with test_images_idx_unique and with each other — TODO record processed ids.
# NOTE(review): this cell repeats this whole loop verbatim right after it; running both
# appends every image's results to the same accumulators a second time.
with torch.no_grad():
    test_images_idx_unique = np.unique(test_images_idx)
    test_images_idx_unique = test_images_idx_unique[:10] # Limit for debugging

    print(f"Processing {len(test_images_idx_unique)} unique images...")
    for batch_start in tqdm(range(0, len(test_images_idx_unique), minibatch_size)):
        batch_end = batch_start + minibatch_size
        uniq_imgs = test_images_idx_unique[batch_start:batch_end]

        # --- Voxel Preparation (CPU) ---
        # For each image choose exactly 3 trial indices into test_voxels:
        # 1 trial -> repeated 3x; 2 trials -> first one duplicated; >3 trials -> first 3.
        # (len(locs) < 1 cannot occur since uniq_img comes from np.unique(test_images_idx).)
        voxel_list = []
        for uniq_img in uniq_imgs:
            locs = np.where(test_images_idx == uniq_img)[0]
            if len(locs) == 1: locs = np.repeat(locs, 3)
            elif len(locs) == 2: locs = np.concatenate([locs[[0]], locs])[:3]
            elif len(locs) > 3: locs = locs[:3]
            elif len(locs) < 1: continue
            assert len(locs) == 3
            max_voxel_idx = len(test_voxels) - 1
            valid_locs = [l for l in locs if l <= max_voxel_idx]
            if len(valid_locs) != 3: continue
            voxel_list.append(test_voxels[None, valid_locs])  # presumably (1, 3, num_voxels)
        if not voxel_list: continue
        voxel = torch.cat(voxel_list, dim=0).to("cpu")

        # --- MindEye Processing (CPU) ---
        # Run ridge + backbone once per trial (rep indexes the 3 chosen trials).
        all_backbone_reps = []
        all_clip_voxels_reps = []
        all_blurry_enc_reps = []
        for rep in range(3):
            voxel_ridge = model.ridge(voxel[:, [rep]], 0)
            backbone_out = model.backbone(voxel_ridge)
            # Robust Output Handling
            # Expected layout: (backbone tokens, clip_voxels[, blurry component]); the blurry
            # component may itself be a tuple whose first element is the latent.
            backbone_rep, clip_voxels_rep, blurry_latent_rep = None, None, None
            if isinstance(backbone_out, (list, tuple)) and len(backbone_out) >= 2:
                 backbone_rep = backbone_out[0]
                 clip_voxels_rep = backbone_out[1]
                 # --- Get Blurry Component Correctly ---
                 if len(backbone_out) >= 3:
                     blurry_component = backbone_out[2]
                     # Check if the component itself is the tensor or a tuple containing it
                     if isinstance(blurry_component, torch.Tensor):
                          blurry_latent_rep = blurry_component
                     elif isinstance(blurry_component, tuple) and len(blurry_component) > 0 and isinstance(blurry_component[0], torch.Tensor):
                          # Assuming the first element of the tuple is the desired latent
                          blurry_latent_rep = blurry_component[0]
                     else:
                          blurry_latent_rep = None
                 else:
                      blurry_latent_rep = None # No third element returned
            else:
                 print("Warning: Unexpected output structure from model.backbone.")
                 backbone_rep = backbone_out
                 clip_voxels_rep = torch.zeros_like(backbone_rep) if isinstance(backbone_rep, torch.Tensor) else None
                 blurry_latent_rep = None
            all_backbone_reps.append(backbone_rep)
            all_clip_voxels_reps.append(clip_voxels_rep)
            all_blurry_enc_reps.append(blurry_latent_rep)

        # Average - Check for None
        # Element-wise mean over the 3 trials; None if any trial lacked that output.
        backbone = torch.mean(torch.stack([t.cpu() for t in all_backbone_reps if t is not None]), dim=0) if all(t is not None for t in all_backbone_reps) else None
        clip_voxels = torch.mean(torch.stack([t.cpu() for t in all_clip_voxels_reps if t is not None]), dim=0) if all(t is not None for t in all_clip_voxels_reps) else None
        blurry_image_enc = None
        if blurry_recon and all(isinstance(t, torch.Tensor) for t in all_blurry_enc_reps):
            try:
                 blurry_tensors_for_stacking = [t.cpu() for t in all_blurry_enc_reps]
                 stacked_blurry_encs = torch.stack(blurry_tensors_for_stacking, dim=1) # Stack along dim 1 (reps)
                 blurry_image_enc = torch.mean(stacked_blurry_encs, dim=1) # Average over dim 1
                 print(f"Averaged blurry latent shape: {blurry_image_enc.shape}")
            except Exception as e: print(f"Error averaging blurry latents: {e}. Setting to None.")
        elif blurry_recon: print("Warning: Blurry latents not valid tensors.")

        if backbone is None or clip_voxels is None: print("Skipping batch due to None backbone/clip_voxels."); continue

        # Store clipvoxels
        if all_clipvoxels is None: all_clipvoxels = clip_voxels.cpu()
        else: all_clipvoxels = torch.vstack((all_clipvoxels, clip_voxels.cpu()))

        # Diffusion Prior
        # Sampled with 20 timesteps (the prior was constructed with `timesteps`, i.e. 100);
        # cond_scale=1. presumably means no classifier-free guidance.
        prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, text_cond=dict(text_embed=backbone), cond_scale=1., timesteps=20)

        # Caption Generation
        # clip_convert reshapes the prior tokens for the caption model, which receives them via
        # `pixel_values`; on failure a placeholder caption is recorded per image in the batch.
        try:
            pred_caption_emb = clip_convert(prior_out)
            generated_ids = clip_text_model.generate(pixel_values=pred_caption_emb, max_length=20)
            generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
            all_predcaptions.extend(generated_caption)
            print(f"Batch {batch_start//minibatch_size}: {generated_caption}")
        except Exception as caption_e:
            print(f"Error during caption generation: {caption_e}")
            all_predcaptions.extend(["<caption_error>"] * len(voxel))

        # --- unCLIP Reconstruction (Attempt on GPU) ---
        # diffusion_engine is moved to `device` for this batch only and always returned to CPU
        # in `finally`, trading speed for peak GPU memory.
        vector_suffix_gpu = None; batch_recons_tensor = None; ctx = None; samples = None
        reconstruction_succeeded = False
        try:
            print("Attempting reconstruction: Moving diffusion_engine to GPU...")
            diffusion_engine.to(device)
            if hasattr(diffusion_engine.denoiser, 'sigmas') and diffusion_engine.denoiser.sigmas is not None:
                 diffusion_engine.denoiser.sigmas = diffusion_engine.denoiser.sigmas.to(device)
            current_batch_size = len(voxel)
            vs_repeated = vector_suffix.repeat(current_batch_size, 1) if vector_suffix.shape[0] != current_batch_size else vector_suffix
            vector_suffix_gpu = vs_repeated.to(device)
            print("Generating reconstructions...")
            batch_recons = []
            # NOTE(review): torch.cuda.amp.autocast is deprecated in recent PyTorch;
            # torch.autocast("cuda", dtype=...) is the replacement.
            with torch.cuda.amp.autocast(dtype=dtype, enabled=(dtype != torch.float32)):
                 for i in range(len(voxel)):
                     # Zero-pad the prior's last dim from clip_emb_dim up to 1664 channels for unCLIP.
                     ctx = F.pad(prior_out[[i]], (0, 1664 - clip_emb_dim)).to(device)
                     samples = utils.unclip_recon(ctx, diffusion_engine, vector_suffix_gpu[[i]], num_samples=num_samples_per_image)
                     batch_recons.append(samples.cpu())
                     del ctx; del samples
            if batch_recons:
                 batch_recons_tensor = torch.cat(batch_recons, dim=0)
                 if all_recons is None: all_recons = batch_recons_tensor
                 else: all_recons = torch.vstack((all_recons, batch_recons_tensor))
                 reconstruction_succeeded = True
            del batch_recons;
            if batch_recons_tensor is not None: del batch_recons_tensor
        except Exception as e:
            print(f"Error during reconstruction: {e}")
            if 'out of memory' in str(e).lower(): print("\n--- CUDA Out of Memory ---\n")
            reconstruction_succeeded = False
        finally:
            print("Moving diffusion_engine back to CPU...")
            diffusion_engine.to("cpu")
            if vector_suffix_gpu is not None: del vector_suffix_gpu
            torch.cuda.empty_cache()
        if reconstruction_succeeded: print("Reconstruction successful.")

        # --- Blurry Reconstruction (Attempt on GPU IF ENABLED and Latent Available) ---
        # Decodes the averaged blurry latent with the VAE (moved to `device` only for this step).
        if blurry_recon and autoenc is not None and blurry_image_enc is not None:
             blurry_image_enc_gpu = None; blurred_image_gpu = None; blurred_image = None
             try:
                print("Attempting blurry reconstruction: Moving autoenc to GPU...")
                autoenc.to(device)
                blurry_image_enc_gpu = blurry_image_enc.to(device)
                print(f"Shape feeding into blurry decode: {blurry_image_enc_gpu.shape}")

                print("Decoding blurry images...")
                with torch.cuda.amp.autocast(dtype=dtype, enabled=(dtype != torch.float32)):
                     # Undo the latent scaling (config.scaling_factor, else 0.18215), then map the
                     # decoder output to [0, 1] assuming it lies in [-1, 1].
                     scale_factor_vae = getattr(getattr(autoenc, 'config', None), 'scaling_factor', 0.18215)
                     blurred_image_gpu = (autoenc.decode(blurry_image_enc_gpu / scale_factor_vae).sample / 2 + 0.5).clamp(0, 1)

                blurred_image = blurred_image_gpu.cpu()
                print("Blurry decoding done, images moved to CPU.")

                if all_blurryrecons is None: all_blurryrecons = blurred_image
                else: all_blurryrecons = torch.vstack((all_blurryrecons, blurred_image))

             except Exception as e:
                 print(f"Error during blurry reconstruction: {e}")
                 if 'Expected' in str(e) and 'input' in str(e): print(">>> Shape mismatch feeding into autoenc.decode.")
                 if 'out of memory' in str(e).lower(): print("\n--- CUDA Out of Memory during Blurry Recon ---\n")
             finally:
                print("Moving autoenc back to CPU...")
                autoenc.to("cpu")
                if blurry_image_enc_gpu is not None: del blurry_image_enc_gpu
                if blurred_image_gpu is not None: del blurred_image_gpu
                torch.cuda.empty_cache()
        elif blurry_recon:
             print("Skipping blurry reconstruction for this batch (autoenc or latent missing/invalid).")

        # End of loop cleanup
        # Drop per-batch tensors so they are not kept alive until the next iteration.
        del voxel, voxel_list, voxel_ridge, backbone, clip_voxels, blurry_image_enc
        del all_backbone_reps, all_clip_voxels_reps, all_blurry_enc_reps
        del prior_out

with torch.no_grad():
    test_images_idx_unique = np.unique(test_images_idx)
    test_images_idx_unique = test_images_idx_unique[:10] # Limit for debugging

    print(f"Processing {len(test_images_idx_unique)} unique images...")
    for batch_start in tqdm(range(0, len(test_images_idx_unique), minibatch_size)):
        batch_end = batch_start + minibatch_size
        uniq_imgs = test_images_idx_unique[batch_start:batch_end]

        # --- Voxel Preparation (CPU) ---
        voxel_list = []
        for uniq_img in uniq_imgs:
            locs = np.where(test_images_idx == uniq_img)[0]
            if len(locs) == 1: locs = np.repeat(locs, 3)
            elif len(locs) == 2: locs = np.concatenate([locs[[0]], locs])[:3]
            elif len(locs) > 3: locs = locs[:3]
            elif len(locs) < 1: continue
            assert len(locs) == 3
            max_voxel_idx = len(test_voxels) - 1
            valid_locs = [l for l in locs if l <= max_voxel_idx]
            if len(valid_locs) != 3: continue
            voxel_list.append(test_voxels[None, valid_locs])
        if not voxel_list: continue
        voxel = torch.cat(voxel_list, dim=0).to("cpu")

        # --- MindEye Processing (CPU) ---
        all_backbone_reps = []
        all_clip_voxels_reps = []
        all_blurry_enc_reps = []
        for rep in range(3):
            voxel_ridge = model.ridge(voxel[:, [rep]], 0)
            backbone_out = model.backbone(voxel_ridge)
            # Robust Output Handling
            backbone_rep, clip_voxels_rep, blurry_latent_rep = None, None, None
            if isinstance(backbone_out, (list, tuple)) and len(backbone_out) >= 2:
                 backbone_rep = backbone_out[0]
                 clip_voxels_rep = backbone_out[1]
                 # --- Get Blurry Component Correctly ---
                 if len(backbone_out) >= 3:
                     blurry_component = backbone_out[2]
                     # Check if the component itself is the tensor or a tuple containing it
                     if isinstance(blurry_component, torch.Tensor):
                          blurry_latent_rep = blurry_component
                     elif isinstance(blurry_component, tuple) and len(blurry_component) > 0 and isinstance(blurry_component[0], torch.Tensor):
                          # Assuming the first element of the tuple is the desired latent
                          blurry_latent_rep = blurry_component[0]
                     else:
                          blurry_latent_rep = None
                 else:
                      blurry_latent_rep = None # No third element returned
            else:
                 print("Warning: Unexpected output structure from model.backbone.")
                 backbone_rep = backbone_out
                 clip_voxels_rep = torch.zeros_like(backbone_rep) if isinstance(backbone_rep, torch.Tensor) else None
                 blurry_latent_rep = None
            all_backbone_reps.append(backbone_rep)
            all_clip_voxels_reps.append(clip_voxels_rep)
            all_blurry_enc_reps.append(blurry_latent_rep)

        # Average - Check for None
        backbone = torch.mean(torch.stack([t.cpu() for t in all_backbone_reps if t is not None]), dim=0) if all(t is not None for t in all_backbone_reps) else None
        clip_voxels = torch.mean(torch.stack([t.cpu() for t in all_clip_voxels_reps if t is not None]), dim=0) if all(t is not None for t in all_clip_voxels_reps) else None
        blurry_image_enc = None
        if blurry_recon and all(isinstance(t, torch.Tensor) for t in all_blurry_enc_reps):
            try:
                 blurry_tensors_for_stacking = [t.cpu() for t in all_blurry_enc_reps]
                 stacked_blurry_encs = torch.stack(blurry_tensors_for_stacking, dim=1) # Stack along dim 1 (reps)
                 blurry_image_enc = torch.mean(stacked_blurry_encs, dim=1) # Average over dim 1
                 print(f"Averaged blurry latent shape: {blurry_image_enc.shape}")
            except Exception as e: print(f"Error averaging blurry latents: {e}. Setting to None.")
        elif blurry_recon: print("Warning: Blurry latents not valid tensors.")

        if backbone is None or clip_voxels is None: print("Skipping batch due to None backbone/clip_voxels."); continue

        # Store clipvoxels: append this batch's clip_voxels (as a CPU tensor) to the
        # running all_clipvoxels tensor, starting it on the first batch.
        all_clipvoxels = (
            clip_voxels.cpu()
            if all_clipvoxels is None
            else torch.vstack((all_clipvoxels, clip_voxels.cpu()))
        )

        # Diffusion Prior: sample embeddings conditioned on the averaged backbone
        # (20 timesteps, cond_scale=1.0).
        # `backbone` was explicitly averaged on the CPU above, so move it to the device
        # holding the prior's weights before sampling (a no-op if the prior is on the CPU).
        prior_param = next(model.diffusion_prior.parameters(), None)
        if prior_param is not None:
            backbone = backbone.to(prior_param.device)
        prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, text_cond=dict(text_embed=backbone), cond_scale=1., timesteps=20)

        # Caption Generation (best-effort): convert the prior output with clip_convert,
        # generate token ids (max_length=20) and decode them to text.
        # On any failure, len(voxel) "<caption_error>" placeholders are recorded instead,
        # so a failed batch still contributes one entry per voxel row.
        try:
            pred_caption_emb = clip_convert(prior_out)
            generated_ids = clip_text_model.generate(
                pixel_values=pred_caption_emb,
                max_length=20,
            )
            generated_caption = processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            all_predcaptions.extend(generated_caption)
            print(f"Batch {batch_start//minibatch_size}: {generated_caption}")
        except Exception as caption_e:
            print(f"Error during caption generation: {caption_e}")
            all_predcaptions.extend(len(voxel) * ["<caption_error>"])

        # --- unCLIP Reconstruction (Attempt on GPU) ---
        # diffusion_engine is moved onto the GPU only for this step and is always moved
        # back to the CPU in `finally`.
        # NOTE(review): when this step fails, nothing is appended to all_recons for this
        # batch, so all_recons can fall out of alignment with all_predcaptions/all_clipvoxels.
        vector_suffix_gpu = None; batch_recons_tensor = None; ctx = None; samples = None
        reconstruction_succeeded = False
        try:
            print("Attempting reconstruction: Moving diffusion_engine to GPU...")
            diffusion_engine.to(device)
            # denoiser.sigmas is moved explicitly when present, in addition to .to(device).
            if hasattr(diffusion_engine.denoiser, 'sigmas') and diffusion_engine.denoiser.sigmas is not None:
                diffusion_engine.denoiser.sigmas = diffusion_engine.denoiser.sigmas.to(device)
            current_batch_size = len(voxel)
            # Repeat the vector suffix to the batch size unless its row count already matches.
            vs_repeated = vector_suffix.repeat(current_batch_size, 1) if vector_suffix.shape[0] != current_batch_size else vector_suffix
            vector_suffix_gpu = vs_repeated.to(device)
            print("Generating reconstructions...")
            batch_recons = []
            # torch.autocast("cuda", ...) is the non-deprecated form of torch.cuda.amp.autocast.
            with torch.autocast("cuda", dtype=dtype, enabled=(dtype != torch.float32)):
                for i in range(len(voxel)):
                    # Zero-pad the last dim of this sample's prior output from clip_emb_dim to 1664.
                    ctx = F.pad(prior_out[[i]], (0, 1664 - clip_emb_dim)).to(device)
                    samples = utils.unclip_recon(ctx, diffusion_engine, vector_suffix_gpu[[i]], num_samples=num_samples_per_image)
                    batch_recons.append(samples.cpu())
                    del ctx; del samples
            if batch_recons:
                batch_recons_tensor = torch.cat(batch_recons, dim=0)
                if all_recons is None: all_recons = batch_recons_tensor
                else: all_recons = torch.vstack((all_recons, batch_recons_tensor))
                reconstruction_succeeded = True
            del batch_recons
            if batch_recons_tensor is not None: del batch_recons_tensor
        except Exception as e:
            print(f"Error during reconstruction: {e}")
            if 'out of memory' in str(e).lower(): print("\n--- CUDA Out of Memory ---\n")
            reconstruction_succeeded = False
        finally:
            print("Moving diffusion_engine back to CPU...")
            diffusion_engine.to("cpu")
            if vector_suffix_gpu is not None: del vector_suffix_gpu
            # If unclip_recon raised mid-loop, ctx/samples (cell-level names) may still
            # reference GPU tensors; drop them so empty_cache() can actually release memory.
            ctx = None; samples = None
            torch.cuda.empty_cache()
        if reconstruction_succeeded: print("Reconstruction successful.")

        # --- Blurry Reconstruction (Attempt on GPU IF ENABLED and Latent Available) ---
        # Decode the averaged blurry latent with autoenc, which is moved onto the GPU
        # only for this step and is always moved back to the CPU in `finally`.
        if blurry_recon and autoenc is not None and blurry_image_enc is not None:
            blurry_image_enc_gpu = None; blurred_image_gpu = None; blurred_image = None
            try:
                print("Attempting blurry reconstruction: Moving autoenc to GPU...")
                autoenc.to(device)
                blurry_image_enc_gpu = blurry_image_enc.to(device)
                print(f"Shape feeding into blurry decode: {blurry_image_enc_gpu.shape}")

                print("Decoding blurry images...")
                # torch.autocast("cuda", ...) is the non-deprecated form of torch.cuda.amp.autocast.
                with torch.autocast("cuda", dtype=dtype, enabled=(dtype != torch.float32)):
                    # Use autoenc.config.scaling_factor when available, else 0.18215.
                    scale_factor_vae = getattr(getattr(autoenc, 'config', None), 'scaling_factor', 0.18215)
                    # Unscale the latent, decode, then rescale x/2 + 0.5 and clamp to [0, 1]
                    # (assumes the decoder outputs roughly [-1, 1] -- TODO confirm).
                    blurred_image_gpu = (autoenc.decode(blurry_image_enc_gpu / scale_factor_vae).sample / 2 + 0.5).clamp(0, 1)

                blurred_image = blurred_image_gpu.cpu()
                print("Blurry decoding done, images moved to CPU.")

                if all_blurryrecons is None: all_blurryrecons = blurred_image
                else: all_blurryrecons = torch.vstack((all_blurryrecons, blurred_image))

            except Exception as e:
                print(f"Error during blurry reconstruction: {e}")
                if 'Expected' in str(e) and 'input' in str(e): print(">>> Shape mismatch feeding into autoenc.decode.")
                if 'out of memory' in str(e).lower(): print("\n--- CUDA Out of Memory during Blurry Recon ---\n")
            finally:
                print("Moving autoenc back to CPU...")
                autoenc.to("cpu")
                if blurry_image_enc_gpu is not None: del blurry_image_enc_gpu
                if blurred_image_gpu is not None: del blurred_image_gpu
                torch.cuda.empty_cache()
        elif blurry_recon:
            print("Skipping blurry reconstruction for this batch (autoenc or latent missing/invalid).")

        # End of loop cleanup: drop this batch's inputs, averaged outputs, per-repetition
        # lists and prior samples before the next iteration rebinds them.
        del voxel, voxel_list, voxel_ridge
        del backbone, clip_voxels, blurry_image_enc
        del all_backbone_reps, all_clip_voxels_reps, all_blurry_enc_reps
        del prior_out