In [1]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import torch.nn.functional as F   

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder2
from generative_models.sgm.models.diffusion import DiffusionEngine
from generative_models.sgm.util import append_dims
from omegaconf import OmegaConf

# TF32 matmuls are faster than standard float32 on GPUs that support them, but the
# flag is explicitly set to False here, so CUDA matmuls stay in full float32.
# NOTE(review): the previous comment advertised TF32's speed while the code disables
# it -- confirm that disabling is the intended setting.
torch.backends.cuda.matmul.allow_tf32 = False

# custom functions #
# project-local modules; the star import presumably supplies the unqualified model
# classes used later in the notebook (e.g. PriorNetwork, BrainDiffusionPrior)
import utils
from models import *

# Accelerate picks the compute device; models are still moved explicitly later on
accelerator = Accelerator(split_batches=True, mixed_precision='fp16')
device = accelerator.device
print("device:",device)
# checkpoint tag: the notebook later loads "{tag}.pth" from the model's train_logs folder
tag='last'



device: cuda


In [2]:
# if running this interactively, can specify jupyter_args here for argparser to use
# (the resulting token list is handed to parser.parse_args in the next cell)
if utils.is_interactive():
    model_name = "fmri_model_v1_1ses_50ep"
    print("model_name:", model_name)

    # other variables can be specified in the following string:
    # The backslashes continue the line *inside* the string literal, so the leading
    # indentation of each continued line becomes part of the string; the .split()
    # below discards that whitespace when turning it into an argv-style list.
    jupyter_args = f"--data_path={os.getcwd()}/data \
                    --cache_dir={os.getcwd()}/cache \
                    --model_name={model_name} --subj=1 \
                    --hidden_dim=1024 --n_blocks=4 --new_test"
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: fmri_model_v1_1ses_50ep
--data_path=/workspace/MindEyeV2/MindEyeV2/src/data                     --cache_dir=/workspace/MindEyeV2/MindEyeV2/src/cache                     --model_name=fmri_model_v1_1ses_50ep --subj=1                     --hidden_dim=1024 --n_blocks=4 --new_test


In [3]:
# Configuration for this inference run. Every option becomes a module-level
# variable of the same name further down, so later cells read e.g. `subj` directly.
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="fmri_model_v1_1ses_50ep",
    help="will load ckpt for model found in ../train_logs/model_name",
)
parser.add_argument(
    "--data_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8],
    help="Validate on which subject?",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
    help="Load the autoencoder used for blurry reconstructions (disable with --no-blurry_recon)",
)
# NOTE(review): n_blocks is parsed but not referenced in the visible cells --
# confirm where it is consumed.
parser.add_argument(
    "--n_blocks",type=int,default=4,
    help="Number of blocks (model depth setting)",
)
parser.add_argument(
    "--hidden_dim",type=int,default=2048,
    help="Hidden dimension shared by the ridge layer output and the backbone",
)
parser.add_argument(
    "--new_test",action=argparse.BooleanOptionalAction,default=True,
    help="Use the larger new_test split; --no-new_test selects the older test split",
)
parser.add_argument(
    "--seed",type=int,default=42,
    help="Random seed passed to utils.seed_everything",
)

# In a notebook there is no real argv, so parse the list prepared in the previous cell.
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# create global variables without the args prefix
globals().update(vars(args))

# seed all random functions
utils.seed_everything(seed)

# make output directory (makedirs also creates the parent "evals" folder)
os.makedirs(f"evals/{model_name}",exist_ok=True)

In [4]:
voxels = {}
# Load hdf5 data for betas. `[:]` copies the whole dataset into memory, so the file
# handle is no longer needed afterwards and is closed by the context manager.
with h5py.File(f'{data_path}/betas_all_subj0{subj}_fp32_renorm.hdf5', 'r') as f:
    betas = f['betas'][:]
betas = torch.Tensor(betas).to("cpu")
num_voxels = betas[0].shape[-1]
voxels[f'subj0{subj}'] = betas
print(f"num_voxels for subj0{subj}: {num_voxels}")

# Test-set size depends on the subject and on which split is used; subjects that
# are not listed in the table fall back to the split's default size.
if not new_test: # using old test set from before full dataset released (used in original MindEye paper)
    test_split, default_num_test = "test", 2770
    num_test_by_subj = {3: 2113, 4: 1985, 6: 2113, 8: 1985}
else: # using larger test set from after full dataset released
    test_split, default_num_test = "new_test", 3000
    num_test_by_subj = {3: 2371, 4: 2188, 6: 2371, 8: 2188}
num_test = num_test_by_subj.get(subj, default_num_test)
test_url = f"{data_path}/wds/subj0{subj}/{test_split}/" + "0.tar"

print(test_url)

def my_split_by_node(urls):
    """Return the shard urls unchanged (passed to WebDataset as its nodesplitter)."""
    return urls

test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
                    .decode("torch")\
                    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                    .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
# batch_size=num_test asks for the whole test set in one batch; with drop_last=True a
# shard holding fewer than num_test samples would yield no batch at all.
test_dl = torch.utils.data.DataLoader(test_data, batch_size=num_test, shuffle=False, drop_last=True, pin_memory=True)
print(f"Loaded test dl for subj{subj}!\n")

num_voxels for subj01: 15724
/workspace/MindEyeV2/MindEyeV2/src/data/wds/subj01/new_test/0.tar
Loaded test dl for subj1!



In [5]:
# Prep images but don't load them all to memory
# (the file handle is deliberately left open: `images` is read lazily from it)
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images']

# Prep test voxels and indices of test images
test_images_idx = []
test_voxels_idx = []
for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
    # behav[:,0,5] is used as the row index into the betas, behav[:,0,0] as the image index
    test_voxels = voxels[f'subj0{subj}'][behav[:,0,5].cpu().long()]
    # accumulate onto test_voxels_idx itself; the original appended to
    # test_images_idx here, which mixed image indices into the voxel indices
    # whenever more than one batch was produced
    test_voxels_idx = np.append(test_voxels_idx, behav[:,0,5].cpu().numpy())
    test_images_idx = np.append(test_images_idx, behav[:,0,0].cpu().numpy())
test_images_idx = test_images_idx.astype(int)
test_voxels_idx = test_voxels_idx.astype(int)

# test_voxels only holds the last batch, so this check passes only when the
# dataloader produced exactly one batch of num_test samples
assert (test_i+1) * num_test == len(test_voxels) == len(test_images_idx)
print(test_i, len(test_voxels), len(test_images_idx), len(np.unique(test_images_idx)))

0 3000 3000 1000


In [6]:
# Frozen CLIP image embedder configured to return only token embeddings.
clip_img_embedder = FrozenOpenCLIPImageEmbedder(
    # arch="ViT-L-14",
    arch="ViT-B-32",
    version='openai',
    output_tokens=True,
    only_tokens=True,
)
clip_img_embedder.to("cpu")
# Token count and embedding width used to size the networks below.
# NOTE(review): assumed to match the ViT-B-32 token output -- confirm.
clip_seq_dim = 49
clip_emb_dim = 768

if blurry_recon:
    
    from diffusers import AutoencoderKL
    autoenc = AutoencoderKL(
        down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
        up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=224,
    )
    # map_location='cpu' matches every other checkpoint load in this notebook and
    # avoids deserialising onto whatever device the file was saved from
    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth', map_location='cpu')
    autoenc.load_state_dict(ckpt)
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to("cpu")
    utils.count_params(autoenc)
    print("Loading blurry recon model")
    
class MindEyeModule(nn.Module):
    """Bare container module: sub-networks are attached to it as attributes after construction."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # identity -- the container itself performs no computation
        return x


model = MindEyeModule()

class RidgeRegression(torch.nn.Module):
    """One linear layer per entry of ``input_sizes``, all mapping to ``out_features``.

    The module itself is plain linear regression: make sure to add weight_decay
    when initializing the optimizer to enable regularization.
    """

    def __init__(self, input_sizes, out_features):
        super().__init__()
        self.out_features = out_features
        # build the layers in the order of input_sizes so subj_idx selects by position
        per_subject_layers = []
        for n_inputs in input_sizes:
            per_subject_layers.append(torch.nn.Linear(n_inputs, out_features))
        self.linears = torch.nn.ModuleList(per_subject_layers)

    def forward(self, x, subj_idx):
        """Apply layer ``subj_idx`` to ``x[:, 0]`` and re-insert a singleton axis at dim 1."""
        first_slice = x[:, 0]
        projected = self.linears[subj_idx](first_slice)
        return projected.unsqueeze(1)
        
# Attach the voxel -> hidden_dim linear layer (a single subject, hence one input size)
model.ridge = RidgeRegression([num_voxels], out_features=hidden_dim)

from diffusers.models.autoencoders.vae import Decoder
from models import BrainNetwork
model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=1, 
                          clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim) 
utils.count_params(model.ridge)
utils.count_params(model.backbone)
utils.count_params(model)

# setup diffusion prior network
out_dim = clip_emb_dim
depth = 6
dim_head = 52
# NOTE(review): with clip_emb_dim=768 this gives 14 heads and 14 * 52 = 728, so
# heads * dim_head does NOT equal clip_emb_dim as the earlier comment claimed.
# The values are left as-is because the checkpoint loaded below was presumably
# trained with them -- TODO confirm before changing.
heads = clip_emb_dim//dim_head
timesteps = 100

prior_network = PriorNetwork(
        dim=out_dim,
        depth=depth,
        dim_head=dim_head,
        heads=heads,
        causal=False,
        num_tokens = clip_seq_dim,
        learned_query_mode="pos_emb"
    )

model.diffusion_prior = BrainDiffusionPrior(
    net=prior_network,
    image_embed_dim=out_dim,
    condition_on_text_encodings=False,
    timesteps=timesteps,
    cond_drop_prob=0.2,
    image_embed_scale=None,
)
model.to("cpu")

utils.count_params(model.diffusion_prior)
utils.count_params(model)

# Load pretrained model ckpt
# train_logs lives one directory above the notebook's working directory
outdir = os.path.abspath(f'../train_logs/{model_name}')
print(f"\n---loading {outdir}/{tag}.pth ckpt---\n")
try:
    checkpoint = torch.load(outdir+f'/{tag}.pth', map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    # strict=False: keys missing from / unexpected in the checkpoint are ignored silently
    model.load_state_dict(state_dict, strict=False)
    del checkpoint
except Exception as ckpt_err: # probably ckpt is saved using deepspeed format
    # report why the plain load failed instead of hiding it, then fall back;
    # Exception (not a bare except) lets Ctrl-C / SystemExit propagate
    print(f"Standard checkpoint load failed ({ckpt_err!r}); trying DeepSpeed format")
    import deepspeed
    state_dict = deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir=outdir, tag=tag)
    model.load_state_dict(state_dict, strict=False)
    del state_dict
print("ckpt loaded!")

param counts:
83,653,863 total
0 trainable
Loading blurry recon model
param counts:
16,102,400 total
16,102,400 trainable
param counts:
50,076,528 total
50,076,528 trainable
param counts:
66,178,928 total
66,178,928 trainable
param counts:
55,096,640 total
55,096,624 trainable
param counts:
121,275,568 total
121,275,552 trainable

---loading /workspace/MindEyeV2/MindEyeV2/train_logs/fmri_model_v1_1ses_50ep/last.pth ckpt---

ckpt loaded!


In [7]:
# %%
# setup text caption networks
# (processor and captioning model both come from the "microsoft/git-base-coco" checkpoint)
from transformers import AutoProcessor, AutoModelForCausalLM
from modeling_git import GitForCausalLMClipEmb
processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
clip_text_model = GitForCausalLMClipEmb.from_pretrained("microsoft/git-base-coco")
clip_text_model.to("cpu") # already kept on CPU here; if memory is still tight, lowering minibatch_size (e.g. to 4) was the suggested remedy
# inference only: eval mode and no gradients
clip_text_model.eval().requires_grad_(False)
# sequence length / embedding width used as the output sizes of CLIPConverter below
clip_text_seq_dim = 49
clip_text_emb_dim = 768

class CLIPConverter(torch.nn.Module):
    """Two linear maps that resize a token tensor along its sequence axis, then its embedding axis.

    Layer sizes are read from the module-level constants clip_seq_dim,
    clip_text_seq_dim, clip_emb_dim and clip_text_emb_dim at construction time.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(clip_seq_dim, clip_text_seq_dim)
        self.linear2 = nn.Linear(clip_emb_dim, clip_text_emb_dim)

    def forward(self, x):
        """Swap the last two axes, apply linear1, swap back, then apply linear2."""
        seq_last = torch.permute(x, (0, 2, 1))
        resized_seq = self.linear1(seq_last)
        emb_last = torch.permute(resized_seq, (0, 2, 1))
        return self.linear2(emb_last)

clip_convert = CLIPConverter()

# NOTE(review): clip_convert is never given trained weights in this cell -- no
# load_state_dict call targets it, so it keeps its random initialisation.
# expected_ckpt_path is computed but not used below; it presumably was meant to
# point at the converter checkpoint in cache_dir -- TODO confirm the file name and load it.
expected_ckpt_path = os.path.join(cache_dir, "epoch10.pth")
# What actually happens next: the main model checkpoint ({tag}.pth in outdir) is
# loaded into `model` again.
print(f"\n---attempting to load {outdir}/{tag}.pth checkpoint---\n")
checkpoint_path = os.path.join(outdir, f"{tag}.pth")

# NOTE(review): a missing train_logs directory is created here rather than treated as an error
if not os.path.exists(outdir):
    print(f"WARNING: Directory {outdir} doesn't exist! Creating it...")
    os.makedirs(outdir, exist_ok=True)

if not os.path.exists(checkpoint_path):
    # no checkpoint file: warn and carry on with whatever weights `model` currently holds
    print(f"WARNING: Checkpoint file {checkpoint_path} not found!")
    print("Continuing without loading a model checkpoint. Results may not be meaningful.")
else:
    try:
        # standard format: a dict holding the weights under 'model_state_dict'
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model_state_dict']
        # strict=False: mismatched keys are ignored rather than raising
        model.load_state_dict(state_dict, strict=False)
        del checkpoint
        print("Checkpoint loaded successfully!")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        # fallback: try a DeepSpeed ZeRO checkpoint stored under outdir/tag
        try:
            import deepspeed
            if os.path.exists(os.path.join(outdir, tag)):
                state_dict = deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
                    checkpoint_dir=outdir, tag=tag)
                model.load_state_dict(state_dict, strict=False)
                del state_dict
                print("DeepSpeed checkpoint loaded successfully!")
            else:
                print(f"DeepSpeed checkpoint directory {os.path.join(outdir, tag)} not found!")
        except Exception as deep_e:
            # both loaders failed: only reported, execution continues
            print(f"DeepSpeed loading also failed: {deep_e}")

  return self.fget.__get__(instance, owner)()



---attempting to load /workspace/MindEyeV2/MindEyeV2/train_logs/fmri_model_v1_1ses_50ep/last.pth checkpoint---

Checkpoint loaded successfully!


In [8]:
# prep unCLIP

# Read the unCLIP config and pull out the pieces DiffusionEngine is built from.
config = OmegaConf.load("generative_models/configs/unclip6.yaml")
config = OmegaConf.to_container(config, resolve=True)
unclip_params = config["model"]["params"]

network_config = unclip_params["network_config"]
denoiser_config = unclip_params["denoiser_config"]
first_stage_config = unclip_params["first_stage_config"]
conditioner_config = unclip_params["conditioner_config"]
sampler_config = unclip_params["sampler_config"]
scale_factor = unclip_params["scale_factor"]
disable_first_stage_autocast = unclip_params["disable_first_stage_autocast"]
offset_noise_level = unclip_params["loss_fn_config"]["params"]["offset_noise_level"]

# Overrides applied on top of the yaml: first-stage class and number of sampler steps.
first_stage_config['target'] = 'sgm.models.autoencoder.AutoencoderKL'
sampler_config['params']['num_steps'] = 38

# Build the engine from the (overridden) config sections.
diffusion_engine = DiffusionEngine(
    network_config=network_config,
    denoiser_config=denoiser_config,
    first_stage_config=first_stage_config,
    conditioner_config=conditioner_config,
    sampler_config=sampler_config,
    scale_factor=scale_factor,
    disable_first_stage_autocast=disable_first_stage_autocast,
)

# Optional fp16 cast (left disabled):
# diffusion_engine.half()

# Inference mode, no gradients; the engine is kept on the CPU at this point.
diffusion_engine.eval().requires_grad_(False)
diffusion_engine.to("cpu")

# Locate the pretrained unCLIP weights inside cache_dir.
ckpt_path = os.path.join(cache_dir, "unclip6_epoch0_step110000.ckpt")
print(f"Looking for unCLIP checkpoint at: {ckpt_path}")

ckpt_is_present = os.path.exists(ckpt_path)
if not ckpt_is_present:
    # Explain how to fix the setup, make sure the target folder exists, then stop.
    print(f"ERROR: unCLIP checkpoint not found at {ckpt_path}")
    print("You need to download the unCLIP model checkpoint first.")
    print("This is typically available from Hugging Face or the model creator's repository.")
    print("After downloading, place it in your cache directory, which is set to:")
    print(f"  {cache_dir}")
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    raise FileNotFoundError(f"Missing required checkpoint: {ckpt_path}")

# Reaching this point means the file exists: load it on CPU and restore the weights.
print(f"Found unCLIP checkpoint at: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location='cpu')
diffusion_engine.load_state_dict(ckpt['state_dict'])



Initialized embedder #0: FrozenOpenCLIPImageEmbedder with 1909889025 params. Trainable: False
Initialized embedder #1: ConcatTimestepEmbedderND with 0 params. Trainable: False
Initialized embedder #2: ConcatTimestepEmbedderND with 0 params. Trainable: False
Looking for unCLIP checkpoint at: /workspace/MindEyeV2/MindEyeV2/src/cache/unclip6_epoch0_step110000.ckpt
Found unCLIP checkpoint at: /workspace/MindEyeV2/MindEyeV2/src/cache/unclip6_epoch0_step110000.ckpt


In [9]:
# # %%
# # recon_inference.ipynb - Main processing cell (REVISED AGAIN)

# import inspect
# import os # Make sure os is imported if not already
# import torch # Make sure torch is imported
# import numpy as np # Make sure numpy is imported if not already
# import torch.nn.functional as F # Import functional if needed by utils.unclip_recon
# from tqdm import tqdm # Ensure tqdm is imported
# from torchvision import transforms # Ensure transforms are imported

# # --- Ensure blurry_recon argument is used ---
# print(f"Blurry Reconstruction Flag (from args): {blurry_recon}") # <<< THIS LINE

# # (vector_suffix inference logic - keep the corrected version)
# try:
#     embed_dim = None; embedder_info = []
#     if hasattr(diffusion_engine, 'conditioner') and hasattr(diffusion_engine.conditioner, 'embedders'):
#         print("Available conditioner embedders:")
#         # Correctly iterate over ModuleList
#         for i, emb in enumerate(diffusion_engine.conditioner.embedders): # Use enumerate for ModuleList
#             emb_type = type(emb); embedder_info.append(f"  Index: {i}, Type: {emb_type}")
#             # Simplified checks for common types
#             if isinstance(emb, FrozenOpenCLIPEmbedder2):
#                  # Safely access attributes
#                  proj_layer = getattr(getattr(emb, 'model', None), 'text_projection', None)
#                  if proj_layer is not None and hasattr(proj_layer, 'shape'):
#                       embed_dim = proj_layer.shape[-1]
#                       embedder_info[-1] += f" -> Found embed_dim: {embed_dim} (Preferred)"
#                       break # Found preferred, stop checking
#                  else:
#                       embedder_info[-1] += " -> Warning: Could not get text_projection shape"
#             elif isinstance(emb, FrozenOpenCLIPImageEmbedder):
#                  found_dim = getattr(emb, 'output_dim', getattr(getattr(emb, 'model', None), 'output_dim', None))
#                  if found_dim and embed_dim is None: # Use if no preferred found yet
#                      embed_dim = found_dim
#                      embedder_info[-1] += f" -> Found embed_dim: {found_dim}"
#     for info in embedder_info: print(info)
#     # Ensure embed_dim is an integer before using it for shape
#     vector_suffix_shape = (1, embed_dim) if isinstance(embed_dim, int) else (1, 1024)
#     print(f"Using vector_suffix_shape: {vector_suffix_shape}")
#     vector_suffix = torch.zeros(vector_suffix_shape, device='cpu', dtype=torch.float32)
#     print(f"Created placeholder vector_suffix with shape: {vector_suffix.shape}. Needs verification!")
# except Exception as e:
#     print(f"Error inferring vector_suffix shape: {e}. Using default (1, 1024).")
#     vector_suffix = torch.zeros(1, 1024, device='cpu', dtype=torch.float32)

# # %%
# # get all reconstructions
# model.to("cpu")
# model.eval().requires_grad_(False)

# # --- Load Autoencoder Conditionally ---
# autoenc = None # Ensure initialized to None
# if blurry_recon:
#     print("Attempting to load Autoencoder for Blurry Recon...")
#     try:
#         from diffusers import AutoencoderKL
#         autoenc = AutoencoderKL(
#              down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
#              up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
#              block_out_channels=[128, 256, 512, 512], layers_per_block=2, sample_size=256,
#          )
#         ckpt_path = f'{cache_dir}/sd_image_var_autoenc.pth'
#         if not os.path.exists(ckpt_path):
#              print(f"ERROR: Autoencoder checkpoint file not found at {ckpt_path}. Disabling blurry recon.")
#              blurry_recon = False # Turn off flag if file missing
#              autoenc = None
#         else:
#              ckpt = torch.load(ckpt_path, map_location='cpu')
#              autoenc.load_state_dict(ckpt)
#              print(f"Autoencoder loaded successfully from {ckpt_path}")
#              autoenc.eval().requires_grad_(False).to("cpu") # Keep on CPU initially
#              utils.count_params(autoenc)
#     except Exception as e:
#         print(f"ERROR loading autoencoder: {e}. Disabling blurry recon.")
#         blurry_recon = False # Turn off flag on error
#         autoenc = None
# else:
#     print("Blurry reconstruction disabled by args, skipping Autoencoder load.")

# # Keep other models on CPU initially
# clip_text_model.to("cpu")
# diffusion_engine.to("cpu")

# # Initialize result accumulators
# all_blurryrecons = None
# all_recons = None
# all_predcaptions = [] # Use a standard list
# all_clipvoxels = None

# # Training Loop Settings
# minibatch_size = 1
# num_samples_per_image = 1
# assert num_samples_per_image == 1
# if utils.is_interactive(): plotting=False
# dtype = torch.float16

# # Check BrainNetwork internal flag
# if hasattr(model, 'backbone') and hasattr(model.backbone, 'blurry_recon'):
#      print(f"Loaded model.backbone internal blurry_recon flag: {model.backbone.blurry_recon}")
# else: print("Warning: Cannot check internal blurry_recon flag.")

# with torch.no_grad():
#     test_images_idx_unique = np.unique(test_images_idx)
#     test_images_idx_unique = test_images_idx_unique[:10] # Limit for debugging

#     print(f"Processing {len(test_images_idx_unique)} unique images...")
#     for batch_start in tqdm(range(0, len(test_images_idx_unique), minibatch_size)):
#         batch_end = batch_start + minibatch_size
#         uniq_imgs = test_images_idx_unique[batch_start:batch_end]

#         # --- Voxel Preparation (CPU) ---
#         voxel_list = []
#         for uniq_img in uniq_imgs:
#             locs = np.where(test_images_idx == uniq_img)[0]
#             if len(locs) == 1: locs = np.repeat(locs, 3)
#             elif len(locs) == 2: locs = np.concatenate([locs[[0]], locs])[:3]
#             elif len(locs) > 3: locs = locs[:3]
#             elif len(locs) < 1: continue
#             assert len(locs) == 3
#             max_voxel_idx = len(test_voxels) - 1
#             valid_locs = [l for l in locs if l <= max_voxel_idx]
#             if len(valid_locs) != 3: continue
#             voxel_list.append(test_voxels[None, valid_locs])
#         if not voxel_list: continue
#         voxel = torch.cat(voxel_list, dim=0).to("cpu")

#         # --- MindEye Processing (CPU) ---
#         all_backbone_reps = []
#         all_clip_voxels_reps = []
#         all_blurry_enc_reps = []
#         for rep in range(3):
#             voxel_ridge = model.ridge(voxel[:, [rep]], 0)
#             backbone_out = model.backbone(voxel_ridge)
#             # Robust Output Handling
#             backbone_rep, clip_voxels_rep, blurry_latent_rep = None, None, None
#             if isinstance(backbone_out, (list, tuple)) and len(backbone_out) >= 2:
#                  backbone_rep = backbone_out[0]
#                  clip_voxels_rep = backbone_out[1]
#                  # --- Get Blurry Component Correctly ---
#                  if len(backbone_out) >= 3:
#                      blurry_component = backbone_out[2]
#                      # Check if the component itself is the tensor or a tuple containing it
#                      if isinstance(blurry_component, torch.Tensor):
#                           blurry_latent_rep = blurry_component
#                      elif isinstance(blurry_component, tuple) and len(blurry_component) > 0 and isinstance(blurry_component[0], torch.Tensor):
#                           # Assuming the first element of the tuple is the desired latent
#                           blurry_latent_rep = blurry_component[0]
#                           # print(f"[Debug] Extracted blurry latent from tuple (rep {rep})") # Optional debug
#                      else:
#                           # Handle unexpected format or None case
#                           blurry_latent_rep = None
#                           # print(f"[Debug] Blurry component was not tensor or valid tuple (rep {rep})") # Optional debug
#                  else:
#                       blurry_latent_rep = None # No third element returned
#             else:
#                  print("Warning: Unexpected output structure from model.backbone.")
#                  backbone_rep = backbone_out
#                  clip_voxels_rep = torch.zeros_like(backbone_rep) if isinstance(backbone_rep, torch.Tensor) else None
#                  blurry_latent_rep = None
#             # --- End Blurry Component Handling ---
#             all_backbone_reps.append(backbone_rep)
#             all_clip_voxels_reps.append(clip_voxels_rep)
#             all_blurry_enc_reps.append(blurry_latent_rep)

#         # Average - Check for None
#         backbone = torch.mean(torch.stack([t.cpu() for t in all_backbone_reps if t is not None]), dim=0) if all(t is not None for t in all_backbone_reps) else None
#         clip_voxels = torch.mean(torch.stack([t.cpu() for t in all_clip_voxels_reps if t is not None]), dim=0) if all(t is not None for t in all_clip_voxels_reps) else None
#         blurry_image_enc = None
#         if blurry_recon and all(isinstance(t, torch.Tensor) for t in all_blurry_enc_reps): # Check they are Tensors
#             try:
#                  blurry_tensors_for_stacking = [t.cpu() for t in all_blurry_enc_reps]
#                  stacked_blurry_encs = torch.stack(blurry_tensors_for_stacking, dim=1) # Stack along dim 1 (reps)
#                  blurry_image_enc = torch.mean(stacked_blurry_encs, dim=1) # Average over dim 1
#                  print(f"Averaged blurry latent shape: {blurry_image_enc.shape}")
#             except Exception as e: print(f"Error averaging blurry latents: {e}. Setting to None.")
#         elif blurry_recon: print("Warning: Blurry latents not valid tensors.")

#         if backbone is None or clip_voxels is None: print("Skipping batch due to None backbone/clip_voxels."); continue

#         # Store clipvoxels
#         if all_clipvoxels is None: all_clipvoxels = clip_voxels.cpu()
#         else: all_clipvoxels = torch.vstack((all_clipvoxels, clip_voxels.cpu()))

#         # Diffusion Prior
#         prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, text_cond=dict(text_embed=backbone), cond_scale=1., timesteps=20)

#         # Caption Generation
#         try:
#             pred_caption_emb = clip_convert(prior_out)
#             generated_ids = clip_text_model.generate(pixel_values=pred_caption_emb, max_length=20)
#             generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
#             all_predcaptions.extend(generated_caption) # Add to standard list
#             print(f"Batch {batch_start//minibatch_size}: {generated_caption}")
#         except Exception as caption_e:
#             print(f"Error during caption generation: {caption_e}")
#             all_predcaptions.extend(["<caption_error>"] * len(voxel))

#         # --- unCLIP Reconstruction (Attempt on GPU) ---
#         vector_suffix_gpu = None; batch_recons_tensor = None; ctx = None; samples = None
#         reconstruction_succeeded = False
#         try:
#             print("Attempting reconstruction: Moving diffusion_engine to GPU...")
#             diffusion_engine.to(device)
#             if hasattr(diffusion_engine.denoiser, 'sigmas') and diffusion_engine.denoiser.sigmas is not None:
#                  diffusion_engine.denoiser.sigmas = diffusion_engine.denoiser.sigmas.to(device)
#             current_batch_size = len(voxel)
#             vs_repeated = vector_suffix.repeat(current_batch_size, 1) if vector_suffix.shape[0] != current_batch_size else vector_suffix
#             vector_suffix_gpu = vs_repeated.to(device)
#             print("Generating reconstructions...")
#             batch_recons = []
#             with torch.cuda.amp.autocast(dtype=dtype, enabled=(dtype != torch.float32)):
#                  for i in range(len(voxel)):
#                      ctx = F.pad(prior_out[[i]], (0, 1664 - clip_emb_dim)).to(device)
#                      # Make sure utils.unclip_recon exists and is imported
#                      samples = utils.unclip_recon(ctx, diffusion_engine, vector_suffix_gpu[[i]], num_samples=num_samples_per_image)
#                      batch_recons.append(samples.cpu())
#                      del ctx; del samples
#             if batch_recons:
#                  batch_recons_tensor = torch.cat(batch_recons, dim=0)
#                  if all_recons is None: all_recons = batch_recons_tensor
#                  else: all_recons = torch.vstack((all_recons, batch_recons_tensor))
#                  reconstruction_succeeded = True
#             del batch_recons;
#             if batch_recons_tensor is not None: del batch_recons_tensor
#         except Exception as e:
#             print(f"Error during reconstruction: {e}")
#             if 'out of memory' in str(e).lower(): print("\n--- CUDA Out of Memory ---\n")
#             reconstruction_succeeded = False
#         finally:
#             print("Moving diffusion_engine back to CPU...")
#             diffusion_engine.to("cpu")
#             if vector_suffix_gpu is not None: del vector_suffix_gpu
#             torch.cuda.empty_cache()
#         if reconstruction_succeeded: print("Reconstruction successful.")

#         # --- Blurry Reconstruction (Attempt on GPU IF ENABLED and Latent Available) ---
#         if blurry_recon and autoenc is not None and blurry_image_enc is not None:
#              blurry_image_enc_gpu = None; blurred_image_gpu = None; blurred_image = None
#              try:
#                 print("Attempting blurry reconstruction: Moving autoenc to GPU...")
#                 autoenc.to(device)
#                 blurry_image_enc_gpu = blurry_image_enc.to(device)
#                 print(f"Shape feeding into blurry decode: {blurry_image_enc_gpu.shape}")

#                 print("Decoding blurry images...")
#                 with torch.cuda.amp.autocast(dtype=dtype, enabled=(dtype != torch.float32)):
#                      # VAE Scaling: Divide by scale factor before decode, then un-normalize
#                      # Ensure autoenc.config.scaling_factor is accessible and correct
#                      scale_factor_vae = getattr(getattr(autoenc, 'config', None), 'scaling_factor', 0.18215) # Use default if not found
#                      blurred_image_gpu = (autoenc.decode(blurry_image_enc_gpu / scale_factor_vae).sample / 2 + 0.5).clamp(0, 1)

#                 blurred_image = blurred_image_gpu.cpu()
#                 print("Blurry decoding done, images moved to CPU.")

#                 # Aggregate blurry recons
#                 if all_blurryrecons is None: all_blurryrecons = blurred_image
#                 else: all_blurryrecons = torch.vstack((all_blurryrecons, blurred_image))

#              except Exception as e:
#                  print(f"Error during blurry reconstruction: {e}")
#                  if 'Expected' in str(e) and 'input' in str(e): print(">>> Shape mismatch feeding into autoenc.decode.")
#                  if 'out of memory' in str(e).lower(): print("\n--- CUDA Out of Memory during Blurry Recon ---\n")
#              finally:
#                 print("Moving autoenc back to CPU...")
#                 autoenc.to("cpu")
#                 if blurry_image_enc_gpu is not None: del blurry_image_enc_gpu
#                 if blurred_image_gpu is not None: del blurred_image_gpu
#                 torch.cuda.empty_cache()
#         elif blurry_recon:
#              print("Skipping blurry reconstruction for this batch (autoenc or latent missing/invalid).")

#         # End of loop cleanup
#         del voxel, voxel_list, voxel_ridge, backbone, clip_voxels, blurry_image_enc
#         del all_backbone_reps, all_clip_voxels_reps, all_blurry_enc_reps
#         del prior_out

# # --- Post-processing and Saving ---
# imsize = 256
# if all_recons is not None and len(all_recons) > 0:
#     print(f"Resizing {len(all_recons)} reconstructions to {imsize}x{imsize}...")
#     all_recons = transforms.Resize((imsize, imsize), antialias=True)(all_recons.float())
# else: all_recons = None

# if blurry_recon and all_blurryrecons is not None and len(all_blurryrecons) > 0:
#     print(f"Resizing {len(all_blurryrecons)} blurry reconstructions to {imsize}x{imsize}...")
#     all_blurryrecons = transforms.Resize((imsize, imsize), antialias=True)(all_blurryrecons.float())
# else: all_blurryrecons = None

# # Saving
# print("Saving outputs...")
# output_dir = f"evals/{model_name}"
# # --- Add Debugging Prints for Saving ---
# print(f"[Debug] Output directory variable: {output_dir}")
# print(f"[Debug] Absolute output directory: {os.path.abspath(output_dir)}")
# print(f"[Debug] Attempting to create directory: {os.path.abspath(output_dir)}")
# try:
#     os.makedirs(output_dir, exist_ok=True)
#     print(f"[Debug] Directory exists after makedirs: {os.path.isdir(output_dir)}")
#     if hasattr(os, 'access'): print(f"[Debug] Write permission for {output_dir}: {os.access(output_dir, os.W_OK)}")
# except Exception as e: print(f"[Debug] ERROR creating directory {output_dir}: {e}")
# # --- End Debugging ---

# if all_recons is not None:
#     recon_path = os.path.join(output_dir, f"{model_name}_all_recons.pt")
#     print(f"Shape of saved recons: {all_recons.shape}. Saving to {recon_path}...")
#     print(f"[Debug] Absolute recon path: {os.path.abspath(recon_path)}")
#     print(f"[Debug] Directory for recon path exists: {os.path.isdir(os.path.dirname(recon_path))}")
#     try:
#         torch.save(all_recons, recon_path)
#         print("Recons saved successfully.")
#     except Exception as e: print(f"!!! ERROR saving recons to {recon_path}: {e}")
# else: print("Skipping saving of image reconstructions (all_recons is None).")

# if all_blurryrecons is not None:
#      blurry_path = os.path.join(output_dir, f"{model_name}_all_blurryrecons.pt")
#      print(f"Shape of saved blurry recons: {all_blurryrecons.shape}. Saving to {blurry_path}...")
#      try:
#          torch.save(all_blurryrecons, blurry_path)
#          print("Blurry recons saved successfully.")
#      except Exception as e: print(f"!!! ERROR saving blurry recons to {blurry_path}: {e}")
# else: print("Skipping saving of blurry reconstructions.")

# # --- Corrected Caption Saving ---
# if all_predcaptions:
#      if not isinstance(all_predcaptions, list) or not all(isinstance(item, str) for item in all_predcaptions):
#           print("Warning: all_predcaptions not list of strings. Converting...")
#           try: all_predcaptions = [str(item) for item in all_predcaptions]
#           except Exception as conv_e: print(f"Error converting captions: {conv_e}. Skipping save."); all_predcaptions = None
#      if all_predcaptions:
#           caption_path = os.path.join(output_dir, f"{model_name}_all_predcaptions.pt")
#           print(f"Saving {len(all_predcaptions)} predicted captions to {caption_path}...")
#           try:
#                torch.save(all_predcaptions, caption_path)
#                print("Captions saved successfully.")
#           except Exception as save_e:
#                print(f"!!! Error saving captions with torch.save: {save_e}")
#                print("Attempting fallback to .npy...")
#                try:
#                     caption_path_npy = os.path.join(output_dir, f"{model_name}_all_predcaptions.npy")
#                     np.save(caption_path_npy, np.array(all_predcaptions, dtype=object))
#                     print(f"Captions saved as fallback to {caption_path_npy}")
#                except Exception as npy_save_e: print(f"!!! Fallback saving as .npy also failed: {npy_save_e}")
# else: print("Skipping saving of captions.")
# # -------------------------------

# if all_clipvoxels is not None:
#     clipvoxels_path = os.path.join(output_dir, f"{model_name}_all_clipvoxels.pt")
#     print(f"Shape of saved clipvoxels: {all_clipvoxels.shape}. Saving to {clipvoxels_path}...")
#     try:
#         torch.save(all_clipvoxels, clipvoxels_path)
#         print("Clipvoxels saved successfully.")
#     except Exception as e: print(f"!!! ERROR saving clipvoxels to {clipvoxels_path}: {e}")
# else: print("Skipping saving of clipvoxels (all_clipvoxels is None).")


# print(f"Saved outputs for {model_name} in {output_dir}")

# if not utils.is_interactive():
#     sys.exit(0)

# # %%

In [10]:
# %%
# %%
# recon_inference.ipynb - Main processing cell (REVISED AGAIN)

import inspect
import os # Make sure os is imported if not already
import torch # Make sure torch is imported
import numpy as np # Make sure numpy is imported if not already
import torch.nn.functional as F # Import functional if needed by utils.unclip_recon
from tqdm import tqdm # Ensure tqdm is imported
from torchvision import transforms # Ensure transforms are imported

# --- Ensure blurry_recon argument is used ---

# (vector_suffix inference logic - keep the corrected version)
# Build an all-zero placeholder `vector_suffix`, guessing its width from the
# diffusion engine's conditioner embedders. A FrozenOpenCLIPEmbedder2 text
# projection wins outright; otherwise the first image embedder that exposes an
# `output_dim` is used; otherwise the width defaults to 1024.
# NOTE(review): zeros are only a stand-in here (the cell itself prints
# "Needs verification!") — confirm what the unCLIP sampler expects.
try:
    embed_dim = None
    embedder_info = []
    has_embedders = hasattr(diffusion_engine, 'conditioner') and hasattr(diffusion_engine.conditioner, 'embedders')
    if has_embedders:
        print("Available conditioner embedders:")
        for idx, embedder in enumerate(diffusion_engine.conditioner.embedders):
            # One report line per embedder; findings are appended to it below.
            embedder_info.append(f"  Index: {idx}, Type: {type(embedder)}")

            if isinstance(embedder, FrozenOpenCLIPEmbedder2):
                text_proj = getattr(getattr(embedder, 'model', None), 'text_projection', None)
                if text_proj is None or not hasattr(text_proj, 'shape'):
                    embedder_info[-1] += " -> Warning: Could not get text_projection shape"
                    continue
                embed_dim = text_proj.shape[-1]
                embedder_info[-1] += f" -> Found embed_dim: {embed_dim} (Preferred)"
                break  # preferred source found; stop scanning

            if isinstance(embedder, FrozenOpenCLIPImageEmbedder):
                # Look on the embedder first, then on its wrapped `model`.
                nested_dim = getattr(getattr(embedder, 'model', None), 'output_dim', None)
                image_dim = getattr(embedder, 'output_dim', nested_dim)
                if image_dim and embed_dim is None:
                    embed_dim = image_dim
                    embedder_info[-1] += f" -> Found embed_dim: {image_dim}"

    for report_line in embedder_info:
        print(report_line)

    # Only a plain int is trusted as a dimension; anything else falls back to 1024.
    if isinstance(embed_dim, int):
        vector_suffix_shape = (1, embed_dim)
    else:
        vector_suffix_shape = (1, 1024)
    print(f"Using vector_suffix_shape: {vector_suffix_shape}")
    vector_suffix = torch.zeros(vector_suffix_shape, device='cpu', dtype=torch.float32)
    print(f"Created placeholder vector_suffix with shape: {vector_suffix.shape}. Needs verification!")
except Exception as err:
    print(f"Error inferring vector_suffix shape: {err}. Using default (1, 1024).")
    vector_suffix = torch.zeros(1, 1024, device='cpu', dtype=torch.float32)

# %%
# get all reconstructions
# The MindEye model is parked on the CPU, switched to eval mode and frozen.
model.to("cpu")
model.eval().requires_grad_(False)

# --- Optional autoencoder ---
# `autoenc` stays None unless blurry_recon is enabled AND the checkpoint loads cleanly.
autoenc = None
if not blurry_recon:
    print("Blurry reconstruction disabled by args, skipping Autoencoder load.")
else:
    print("Attempting to load Autoencoder for Blurry Recon (Training Loss Component)...")
    try:
        from diffusers import AutoencoderKL

        # Four-stage encoder/decoder; the architecture has to match the checkpoint's state dict.
        autoenc = AutoencoderKL(
            down_block_types=['DownEncoderBlock2D'] * 4,
            up_block_types=['UpDecoderBlock2D'] * 4,
            block_out_channels=[128, 256, 512, 512],
            layers_per_block=2,
            sample_size=256,
        )
        ckpt_path = f'{cache_dir}/sd_image_var_autoenc.pth'
        if os.path.exists(ckpt_path):
            ckpt = torch.load(ckpt_path, map_location='cpu')
            autoenc.load_state_dict(ckpt)
            print(f"Autoencoder loaded successfully from {ckpt_path}")
            # Frozen, eval mode, kept on the CPU for now.
            autoenc.eval()
            autoenc.requires_grad_(False)
            autoenc.to("cpu")
            utils.count_params(autoenc)
        else:
            print(f"ERROR: Autoencoder checkpoint file not found at {ckpt_path}. Blurry recon components might be affected.")
            # blurry_recon itself is left untouched; only the autoencoder handle is dropped.
            autoenc = None
    except Exception as e:
        print(f"ERROR loading autoencoder: {e}. Blurry recon components might be affected.")
        autoenc = None


# Keep the captioning model and the diffusion engine on the CPU until they are needed.
clip_text_model.to("cpu")
diffusion_engine.to("cpu")

# --- Result accumulators ---
# Every accumulator is reset here so that re-running this cell never picks up
# values left in the kernel by an earlier run.
# `all_blurryrecons` must be defined even though blurry decoding is disabled in
# this cell: the post-processing code below reads it, and leaving it undefined
# raises NameError on a fresh kernel when blurry_recon is enabled.
all_blurryrecons = None
all_recons = None
all_predcaptions = [] # Use a standard list
all_clipvoxels = None
all_backbones = None # Backbone outputs, kept for UMAP/evals
all_prior_out = None # Diffusion-prior outputs, kept for UMAP/evals


# --- Inference loop settings ---
minibatch_size = 1
num_samples_per_image = 1
assert num_samples_per_image == 1
if utils.is_interactive(): plotting=False
# FP16 is the autocast dtype used for the GPU reconstruction step.
dtype = torch.float16

# Report the backbone's own blurry_recon flag, when it exposes one.
if hasattr(model, 'backbone') and hasattr(model.backbone, 'blurry_recon'):
     print(f"Loaded model.backbone internal blurry_recon flag: {model.backbone.blurry_recon}")
else: print("Warning: Cannot check internal blurry_recon flag.")

# Main inference pass. For each unique test image: gather up to three fMRI
# repetitions, run ridge + backbone on each and average, sample the diffusion
# prior, generate a caption, then attempt an unCLIP reconstruction on the GPU.
# Heavy models are moved to the GPU only for the step that needs them and are
# moved back to the CPU in `finally` blocks.
with torch.no_grad():
    test_images_idx_unique = np.unique(test_images_idx)
    test_images_idx_unique = test_images_idx_unique[:10] # Limit for debugging

    print(f"Processing {len(test_images_idx_unique)} unique images...")
    for batch_start in tqdm(range(0, len(test_images_idx_unique), minibatch_size)):
        batch_end = batch_start + minibatch_size
        uniq_imgs = test_images_idx_unique[batch_start:batch_end]

        # --- Voxel Preparation (CPU) ---
        voxel_list = []
        for uniq_img in uniq_imgs:
            locs = np.where(test_images_idx == uniq_img)[0]
            if len(locs) < 1: continue # No fMRI samples for this image
            # Always end up with exactly three sample indices: repeat when
            # there are fewer, truncate when there are more.
            if len(locs) == 1: locs = np.repeat(locs, 3)
            elif len(locs) == 2: locs = np.concatenate([locs, locs[:1]]) # Repeat first element
            elif len(locs) > 3: locs = locs[:3]
            # Drop indices that fall outside test_voxels before indexing it.
            max_voxel_sample_idx = len(test_voxels) - 1
            valid_locs = [l for l in locs if l >= 0 and l <= max_voxel_sample_idx]
            if len(valid_locs) != 3:
                print(f"Warning: Skipping unique image {uniq_img} due to invalid or insufficient voxel sample indices ({len(valid_locs)} found).")
                continue # Skip this image only

            voxel_list.append(test_voxels[None, valid_locs])

        if not voxel_list: continue # Nothing usable in this batch
        voxel = torch.cat(voxel_list, dim=0).to("cpu")

        # --- MindEye Processing (CPU) ---
        all_backbone_reps = []
        all_clip_voxels_reps = []
        all_blurry_enc_reps = [] # Stays empty while the append below is commented out
        for rep in range(3):
            voxel_ridge = model.ridge(voxel[:, [rep]], 0)
            backbone_out = model.backbone(voxel_ridge)

            # The backbone is expected to return a sequence
            # (backbone, clip_voxels[, blurry component]); anything else is
            # treated as a bare backbone tensor.
            backbone_rep, clip_voxels_rep, blurry_latent_rep = None, None, None
            if isinstance(backbone_out, (list, tuple)) and len(backbone_out) >= 2:
                backbone_rep = backbone_out[0]
                clip_voxels_rep = backbone_out[1]
                if len(backbone_out) >= 3:
                    blurry_component = backbone_out[2]
                    # The third element may be the latent itself or a tuple
                    # whose first item is the latent.
                    if isinstance(blurry_component, torch.Tensor):
                        blurry_latent_rep = blurry_component
                    elif isinstance(blurry_component, tuple) and len(blurry_component) > 0 and isinstance(blurry_component[0], torch.Tensor):
                        blurry_latent_rep = blurry_component[0]
                    else:
                        blurry_latent_rep = None # Unexpected format or None
                else:
                    blurry_latent_rep = None # No third element returned
            else:
                print("Warning: Unexpected output structure from model.backbone.")
                backbone_rep = backbone_out
                clip_voxels_rep = torch.zeros_like(backbone_rep) if isinstance(backbone_rep, torch.Tensor) else None

            all_backbone_reps.append(backbone_rep)
            all_clip_voxels_reps.append(clip_voxels_rep)
            # NOTE(review): blurry latents are deliberately not collected here,
            # so the averaging branch below always takes its warning path when
            # blurry_recon is on. Re-enable the next line if they are needed.
            # if blurry_latent_rep is not None: all_blurry_enc_reps.append(blurry_latent_rep)

        # Every repetition must have produced a tensor before averaging.
        if not all(t is not None for t in all_backbone_reps) or not all(t is not None for t in all_clip_voxels_reps):
            print(f"Skipping batch {batch_start//minibatch_size} due to None backbone/clip_voxels outputs from reps.")
            continue # Skip this batch

        # Average the three repetitions.
        backbone = torch.mean(torch.stack([t.cpu() for t in all_backbone_reps]), dim=0)
        clip_voxels = torch.mean(torch.stack([t.cpu() for t in all_clip_voxels_reps]), dim=0)

        # Average the blurry latent across repetitions when it was collected.
        blurry_image_enc = None
        if blurry_recon and all_blurry_enc_reps and all(isinstance(t, torch.Tensor) for t in all_blurry_enc_reps):
            try:
                # Stack along dim 1 (reps), then average
                stacked_blurry_encs = torch.stack([t.cpu() for t in all_blurry_enc_reps], dim=1)
                blurry_image_enc = torch.mean(stacked_blurry_encs, dim=1)
                print(f"Averaged blurry latent shape: {blurry_image_enc.shape}")
            except Exception as e: print(f"Error averaging blurry latents: {e}. Setting to None.")
        elif blurry_recon: print("Warning: Blurry latents not valid tensors from all reps.")

        # Accumulate backbone / clip_voxels for UMAP and evals.
        if all_backbones is None: all_backbones = backbone.cpu()
        else: all_backbones = torch.vstack((all_backbones, backbone.cpu()))
        if all_clipvoxels is None: all_clipvoxels = clip_voxels.cpu()
        else: all_clipvoxels = torch.vstack((all_clipvoxels, clip_voxels.cpu()))

        # Diffusion prior, conditioned on the averaged backbone output.
        prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, text_cond=dict(text_embed=backbone), cond_scale=1., timesteps=20)

        if all_prior_out is None: all_prior_out = prior_out.cpu()
        else: all_prior_out = torch.vstack((all_prior_out, prior_out.cpu()))

        # --- Cleanup before Captioning ---
        del backbone, clip_voxels # Already stored / consumed
        del all_backbone_reps, all_clip_voxels_reps
        torch.cuda.empty_cache()

        # --- Caption Generation ---
        # One placeholder per image; overwritten only when generation succeeds.
        generated_caption = ["<caption_error>"] * len(voxel)
        try:
            # The embedding conversion runs on the CPU.
            clip_convert.to("cpu")
            pred_caption_emb_cpu = clip_convert(prior_out.cpu())

            print("Moving clip_text_model to GPU (FP16) for caption generation...")
            clip_text_model.half().to(device) # Use FP16

            # Contiguous FP16 input on the same device as the model.
            pixel_values_input = pred_caption_emb_cpu.half().to(device).contiguous()

            # Greedy decoding (single beam, no sampling).
            with torch.cuda.amp.autocast(dtype=torch.float16, enabled=True):
                generated_ids = clip_text_model.generate(
                    pixel_values=pixel_values_input,
                    max_length=20,
                    num_beams=1, # Use greedy search
                    do_sample=False # Disable sampling
                )

            # Release the GPU as soon as generation is done.
            clip_text_model.to("cpu").float() # Move to CPU and restore FP32

            generated_caption = processor.batch_decode(generated_ids.cpu(), skip_special_tokens=True)
            print(f"Batch {batch_start//minibatch_size}: {generated_caption}")

            del pixel_values_input, generated_ids, pred_caption_emb_cpu

        except Exception as caption_e:
            print(f"Error during caption generation for batch {batch_start//minibatch_size}: {caption_e}")
            # generated_caption keeps its "<caption_error>" placeholders
        finally:
            # The model returns to CPU/FP32 whether or not generation failed.
            clip_text_model.to("cpu").float()
            torch.cuda.empty_cache()

        all_predcaptions.extend(generated_caption) # Result or error placeholder

        # --- Explicit Cleanup Before Reconstruction ---
        print("Clearing cache before reconstruction step...")
        torch.cuda.empty_cache()

        # --- unCLIP Reconstruction (Attempt on GPU) ---
        vector_suffix_gpu = None; batch_recons_tensor = None; ctx = None; samples = None
        reconstruction_succeeded = False
        try:
            print("Converting diffusion_engine to FP16 (half precision)...")
            diffusion_engine.half() # Convert model parameters to FP16

            # Keep the first-stage autoencoder in FP32. With the whole engine in
            # FP16 every reconstruction failed with "Input type (float) and bias
            # type (c10::Half) should be the same", i.e. some layer received
            # FP32 input outside autocast.
            # NOTE(review): presumably that layer is the first-stage decode
            # running with autocast disabled — confirm against
            # utils.unclip_recon and DiffusionEngine.decode_first_stage.
            first_stage = getattr(diffusion_engine, "first_stage_model", None)
            if first_stage is not None:
                first_stage.float()

            print("Attempting reconstruction: Moving diffusion_engine (FP16) to GPU...")
            diffusion_engine.to(device)

            # The denoiser's sigmas follow the engine when it exposes them.
            if hasattr(diffusion_engine.denoiser, 'sigmas') and diffusion_engine.denoiser.sigmas is not None:
                diffusion_engine.denoiser.sigmas = diffusion_engine.denoiser.sigmas.to(device=device, dtype=torch.float16)

            # One vector_suffix row per image in the batch.
            current_batch_size = len(voxel)
            vs_repeated = vector_suffix.float().repeat(current_batch_size, 1) if vector_suffix.shape[0] != current_batch_size else vector_suffix.float()
            vector_suffix_gpu = vs_repeated.to(device).half() # GPU + FP16

            print("Generating reconstructions...")
            batch_recons = []
            with torch.cuda.amp.autocast(dtype=dtype, enabled=True): # dtype is torch.float16
                for i in range(len(voxel)):
                    # Right-pad the last dimension up to 1664.
                    # NOTE(review): presumably the unCLIP context width — confirm.
                    ctx = F.pad(prior_out[[i]].cpu(), (0, 1664 - prior_out.shape[-1])).to(device).half()
                    current_vector_suffix = vector_suffix_gpu[[i]]

                    samples = utils.unclip_recon(ctx, diffusion_engine, current_vector_suffix, num_samples=num_samples_per_image)
                    batch_recons.append(samples.float().cpu()) # Back to FP32 on CPU
                    del ctx; del samples; del current_vector_suffix
                    torch.cuda.empty_cache()

            if batch_recons:
                batch_recons_tensor = torch.cat(batch_recons, dim=0)
                if all_recons is None: all_recons = batch_recons_tensor
                else: all_recons = torch.vstack((all_recons, batch_recons_tensor))
                reconstruction_succeeded = True

            if batch_recons_tensor is not None: del batch_recons_tensor

        except Exception as e:
            print(f"Error during reconstruction: {e}")
            if 'out of memory' in str(e).lower():
                # A real newline here; the escaped "\\n" used to print a literal backslash-n.
                print("\n--- CUDA Out of Memory ---")
                print(f"Current Memory Allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
                print(f"Max Memory Allocated: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
            reconstruction_succeeded = False
        finally:
            print("Moving diffusion_engine back to CPU...")
            diffusion_engine.to("cpu").float() # Back to CPU and FP32

            if hasattr(diffusion_engine.denoiser, 'sigmas') and diffusion_engine.denoiser.sigmas is not None:
                diffusion_engine.denoiser.sigmas = diffusion_engine.denoiser.sigmas.cpu()
            if vector_suffix_gpu is not None: del vector_suffix_gpu
            torch.cuda.empty_cache()

        # NOTE(review): a failed reconstruction adds no row to all_recons while
        # the other accumulators still grow, so saved outputs can be misaligned.
        if reconstruction_succeeded: print("Reconstruction successful.")
        else: print("Reconstruction failed or was skipped.")

        # prior_out is no longer needed once reconstruction has been attempted.
        del prior_out
        torch.cuda.empty_cache()

# --- Post-processing and Saving ---
def _save_tensor_output(tensor, label, filename, skip_message, output_dir, saved_paths):
    """Save one accumulated tensor with torch.save, reporting progress on stdout.

    Args:
        tensor: Object to save; ``None`` means "nothing accumulated" and is skipped.
        label: Lower-case name used in the progress messages (e.g. ``"recons"``).
        filename: File name created inside ``output_dir``.
        skip_message: Message printed when ``tensor`` is ``None``.
        output_dir: Destination directory.
        saved_paths: List that receives the path of every file actually written.

    A failing save is reported and swallowed on purpose, so one bad output does
    not prevent the remaining outputs from being written.
    """
    if tensor is None:
        print(skip_message)
        return
    path = os.path.join(output_dir, filename)
    print(f"Shape of saved {label}: {tensor.shape}. Saving to {path}...")
    try:
        torch.save(tensor, path)
    except Exception as e:
        print(f"!!! ERROR saving {label} to {path}: {e}")
        return
    print(f"{label[0].upper()}{label[1:]} saved successfully.")
    saved_paths.append(path)


imsize = 256
if all_recons is not None and len(all_recons) > 0:
    print(f"Resizing {len(all_recons)} reconstructions to {imsize}x{imsize}...")
    # antialias=True gives better downsampling quality
    all_recons = transforms.Resize((imsize, imsize), antialias=True)(all_recons.float())
else: all_recons = None

# Blurry recons are only resized when blurry_recon is on and something was accumulated.
if blurry_recon and all_blurryrecons is not None and len(all_blurryrecons) > 0:
    print(f"Resizing {len(all_blurryrecons)} blurry reconstructions to {imsize}x{imsize}...")
    all_blurryrecons = transforms.Resize((imsize, imsize), antialias=True)(all_blurryrecons.float())
else: all_blurryrecons = None

# Saving
print("Saving outputs...")
output_dir = f"evals/{model_name}"
print(f"Saving outputs to directory: {output_dir}")
try:
    os.makedirs(output_dir, exist_ok=True)
except Exception as e:
    # Best effort: the individual saves below will report their own failures.
    print(f"ERROR creating output directory {output_dir}: {e}")

saved_paths = []  # files that were actually written, for the final summary

_save_tensor_output(
    all_recons, "recons", f"{model_name}_all_recons.pt",
    "Skipping saving of image reconstructions (all_recons is None).",
    output_dir, saved_paths,
)
_save_tensor_output(
    all_blurryrecons, "blurry recons", f"{model_name}_all_blurryrecons.pt",
    "Skipping saving of blurry reconstructions.",
    output_dir, saved_paths,
)

# --- Caption Saving ---
# Captions must be a list of plain strings before they are saved.
if all_predcaptions:
    if not isinstance(all_predcaptions, list) or not all(isinstance(item, str) for item in all_predcaptions):
        print("Warning: all_predcaptions not list of strings. Attempting conversion...")
        try: all_predcaptions = [str(item) for item in all_predcaptions]
        except Exception as conv_e: print(f"Error converting captions: {conv_e}. Skipping save."); all_predcaptions = None
    if all_predcaptions:
        caption_path = os.path.join(output_dir, f"{model_name}_all_predcaptions.pt")
        print(f"Saving {len(all_predcaptions)} predicted captions to {caption_path}...")
        try:
            torch.save(all_predcaptions, caption_path)
            print("Captions saved successfully.")
            saved_paths.append(caption_path)
        except Exception as save_e:
            print(f"!!! Error saving captions with torch.save: {save_e}")
            print("Attempting fallback to .npy...")
            try:
                caption_path_npy = os.path.join(output_dir, f"{model_name}_all_predcaptions.npy")
                np.save(caption_path_npy, np.array(all_predcaptions, dtype=object))
                print(f"Captions saved as fallback to {caption_path_npy}")
                saved_paths.append(caption_path_npy)
            except Exception as npy_save_e: print(f"!!! Fallback saving as .npy also failed: {npy_save_e}")
else: print("Skipping saving of captions (list is empty).")

# --- Backbone, prior_out and clipvoxels (for UMAP/evals) ---
_save_tensor_output(
    all_backbones, "backbones", f"{model_name}_all_backbones.pt",
    "Skipping saving of backbones (all_backbones is None).",
    output_dir, saved_paths,
)
_save_tensor_output(
    all_prior_out, "prior_out", f"{model_name}_all_prior_out.pt",
    "Skipping saving of prior_out (all_prior_out is None).",
    output_dir, saved_paths,
)
_save_tensor_output(
    all_clipvoxels, "clipvoxels", f"{model_name}_all_clipvoxels.pt",
    "Skipping saving of clipvoxels (all_clipvoxels is None).",
    output_dir, saved_paths,
)

# Claim success only when at least one file was really written.
if saved_paths:
    print(f"Saved outputs for {model_name} in {output_dir}")
else:
    print(f"WARNING: no outputs were saved for {model_name} in {output_dir} (see messages above).")

if not utils.is_interactive():
    sys.exit(0)

# %%

Available conditioner embedders:
  Index: 0, Type: <class 'sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder'>
  Index: 1, Type: <class 'sgm.modules.encoders.modules.ConcatTimestepEmbedderND'>
  Index: 2, Type: <class 'sgm.modules.encoders.modules.ConcatTimestepEmbedderND'>
Using vector_suffix_shape: (1, 1024)
Created placeholder vector_suffix with shape: torch.Size([1, 1024]). Needs verification!
Attempting to load Autoencoder for Blurry Recon (Training Loss Component)...
Autoencoder loaded successfully from /workspace/MindEyeV2/MindEyeV2/src/cache/sd_image_var_autoenc.pth
param counts:
83,653,863 total
0 trainable
Loaded model.backbone internal blurry_recon flag: True
Processing 10 unique images...


  0%|          | 0/10 [00:00<?, ?it/s]



sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

Moving clip_text_model to GPU (FP16) for caption generation...
Error during caption generation for batch 0: output with shape [1, 1, 1, 1] doesn't match the broadcast shape [1, 1, 1, 2]
Clearing cache before reconstruction step...
Converting diffusion_engine to FP16 (half precision)...
Attempting reconstruction: Moving diffusion_engine (FP16) to GPU...
Generating reconstructions...




Error during reconstruction: Input type (float) and bias type (c10::Half) should be the same
Moving diffusion_engine back to CPU...


 10%|█         | 1/10 [00:17<02:40, 17.81s/it]

Reconstruction failed or was skipped.


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

Moving clip_text_model to GPU (FP16) for caption generation...
Error during caption generation for batch 1: output with shape [1, 1, 1, 1] doesn't match the broadcast shape [1, 1, 1, 2]
Clearing cache before reconstruction step...
Converting diffusion_engine to FP16 (half precision)...
Attempting reconstruction: Moving diffusion_engine (FP16) to GPU...
Generating reconstructions...
Error during reconstruction: Input type (float) and bias type (c10::Half) should be the same
Moving diffusion_engine back to CPU...


 20%|██        | 2/10 [00:33<02:11, 16.48s/it]

Reconstruction failed or was skipped.


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

Moving clip_text_model to GPU (FP16) for caption generation...
Error during caption generation for batch 2: output with shape [1, 1, 1, 1] doesn't match the broadcast shape [1, 1, 1, 2]
Clearing cache before reconstruction step...
Converting diffusion_engine to FP16 (half precision)...
Attempting reconstruction: Moving diffusion_engine (FP16) to GPU...
Generating reconstructions...
Error during reconstruction: Input type (float) and bias type (c10::Half) should be the same
Moving diffusion_engine back to CPU...


 30%|███       | 3/10 [00:48<01:52, 16.08s/it]

Reconstruction failed or was skipped.


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

Moving clip_text_model to GPU (FP16) for caption generation...
Error during caption generation for batch 3: output with shape [1, 1, 1, 1] doesn't match the broadcast shape [1, 1, 1, 2]
Clearing cache before reconstruction step...
Converting diffusion_engine to FP16 (half precision)...
Attempting reconstruction: Moving diffusion_engine (FP16) to GPU...
Generating reconstructions...
Error during reconstruction: Input type (float) and bias type (c10::Half) should be the same
Moving diffusion_engine back to CPU...


 40%|████      | 4/10 [01:04<01:34, 15.68s/it]

Reconstruction failed or was skipped.


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

Moving clip_text_model to GPU (FP16) for caption generation...
Error during caption generation for batch 4: output with shape [1, 1, 1, 1] doesn't match the broadcast shape [1, 1, 1, 2]
Clearing cache before reconstruction step...
Converting diffusion_engine to FP16 (half precision)...
Attempting reconstruction: Moving diffusion_engine (FP16) to GPU...
Generating reconstructions...
Moving diffusion_engine back to CPU...


 40%|████      | 4/10 [01:10<01:46, 17.72s/it]


KeyboardInterrupt: 