In [1]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
import gc
from tqdm import tqdm
import webdataset as wds

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder2
from generative_models.sgm.models.diffusion import DiffusionEngine
from generative_models.sgm.util import append_dims
from omegaconf import OmegaConf

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils
from models import *

### Multi-GPU config ###
local_rank = os.getenv('RANK')
if local_rank is None: 
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)  

accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
device = accelerator.device
print("device:",device)

LOCAL RANK  0
device: cuda


In [2]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    model_name = "testing_session_1_to_39"
    print("model_name:", model_name)

    # other variables can be specified in the following string:
    jupyter_args = f"--data_path=/home/ubuntu/MindEyeV2/mindeyev2_dataset \
                    --cache_dir=/home/ubuntu/MindEyeV2/cache \
                    --model_name={model_name} --subj=1 --num_sessions=2\
                    --hidden_dim=512 --n_blocks=4 --new_test"
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: testing_session_1_to_39
--data_path=/home/ubuntu/MindEyeV2/mindeyev2_dataset                     --cache_dir=/home/ubuntu/MindEyeV2/cache                     --model_name=testing_session_1_to_39 --subj=1 --num_sessions=2                    --hidden_dim=512 --n_blocks=4 --new_test


In [3]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="will load ckpt for model found in ../train_logs/model_name",
)
parser.add_argument(
    "--data_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8],
    help="Validate on which subject?",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--n_blocks",type=int,default=4,
)
parser.add_argument(
    "--hidden_dim",type=int,default=2048,
)
parser.add_argument(
    "--new_test",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--seq_len",type=int,default=1,
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--num_sessions", type=int, default=1,
    help="Number of training sessions to include",
)
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)
    
# seed all random functions
utils.seed_everything(seed)

# make output directory
os.makedirs("evals",exist_ok=True)
os.makedirs(f"evals/{model_name}",exist_ok=True)

print(num_sessions)

2


In [4]:
voxels = {}
num_voxels = {}
num_voxels_list = []
# Load hdf5 data for betas
f = h5py.File(f'{data_path}/betas_all_subj01_fp32_renorm.hdf5', 'r')
betas = f['betas'][:]
betas = torch.Tensor(betas).to("cpu")
num_voxels_list.append(betas[0].shape[-1])
num_voxels[f'subj01'] = betas[0].shape[-1]
voxels[f'subj01'] = betas
print(f"num_voxels for subj01: {num_voxels[f'subj01']}")

if not new_test: # using old test set from before full dataset released (used in original MindEye paper)
    if subj==3:
        num_test=2113
    elif subj==4:
        num_test=1985
    elif subj==6:
        num_test=2113
    elif subj==8:
        num_test=1985
    else:
        num_test=2770
    test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
else: # using larger test set from after full dataset released
    if subj==3:
        num_test=2371
    elif subj==4:
        num_test=2188
    elif subj==6:
        num_test=2371
    elif subj==8:
        num_test=2188
    else:
        num_test=3000
    test_url = f"{data_path}/wds/subj0{subj}/new_test/" + "0.tar"

train_url = f"{data_path}/wds/subj01/train/" + "{1.." + f"{num_sessions}" + "}.tar"
    
print(train_url)
def my_split_by_node(urls): return urls
train_data = {}
train_dl = {}
s = 1
train_data[f'subj0{s}'] = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
                        .decode("torch")\
                        .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                        .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
train_dl[f'subj0{s}'] = torch.utils.data.DataLoader(train_data[f'subj0{s}'], batch_size=512, shuffle=False, drop_last=True, pin_memory=True)
print(f"Loaded train dl for subj{s}\n")

num_voxels for subj01: 15724
/home/ubuntu/MindEyeV2/mindeyev2_dataset/wds/subj01/train/{1..2}.tar
Loaded train dl for subj1



In [5]:
clip_img_embedder = FrozenOpenCLIPImageEmbedder(
    arch="ViT-bigG-14",
    version="laion2b_s39b_b160k",
    output_tokens=True,
    only_tokens=True,
)
clip_img_embedder.to(device)
clip_seq_dim = 256
clip_emb_dim = 1664

In [6]:
blurry_recon = False

In [7]:
if blurry_recon:
    from diffusers import AutoencoderKL
    autoenc = AutoencoderKL(
        down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
        up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=256,
    )
    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth')
    autoenc.load_state_dict(ckpt)
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to(device)
    utils.count_params(autoenc)

In [8]:
class MindEyeModule(nn.Module):
    def __init__(self):
        super(MindEyeModule, self).__init__()
    def forward(self, x):
        return x
        
model = MindEyeModule()

In [9]:
class RidgeRegression(torch.nn.Module):
    # make sure to add weight_decay when initializing optimizer
    def __init__(self, input_sizes, out_features, seq_len): 
        super(RidgeRegression, self).__init__()
        self.out_features = out_features
        self.linears = torch.nn.ModuleList([
                torch.nn.Linear(input_size, out_features) for input_size in input_sizes
            ])
    def forward(self, x, subj_idx):
        out = torch.cat([self.linears[subj_idx](x[:,seq]).unsqueeze(1) for seq in range(seq_len)], dim=1)
        return out


print(len(num_voxels_list))
model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim, seq_len=seq_len)
utils.count_params(model.ridge)
utils.count_params(model)

# test on subject 1 with fake data
b = torch.randn((2,seq_len,num_voxels_list[0]))
print(b.shape, model.ridge(b,0).shape)

1
param counts:
8,051,200 total
8,051,200 trainable
param counts:
8,051,200 total
8,051,200 trainable
torch.Size([2, 1, 15724]) torch.Size([2, 1, 512])


In [10]:
from diffusers.models.vae import Decoder
class BrainNetwork(nn.Module):
    def __init__(self, h=hidden_dim, in_dim=hidden_dim, out_dim=clip_emb_dim*clip_seq_dim, seq_len=1, n_blocks=n_blocks, drop=.15, 
                 clip_size=768):
        super().__init__()
        self.seq_len = seq_len
        self.h = h
        self.clip_size = clip_size
        
        self.mixer_blocks1 = nn.ModuleList([
            self.mixer_block1(h, drop) for _ in range(n_blocks)
        ])
        self.mixer_blocks2 = nn.ModuleList([
            self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
        ])
        
        # Output linear layer
        self.backbone_linear = nn.Linear(h * seq_len, out_dim, bias=True) 
        self.clip_proj = self.projector(clip_size, clip_size, h=clip_size)
        
        if blurry_recon:
            self.blin1 = nn.Linear(h*seq_len,4*28*28,bias=True)
            self.bdropout = nn.Dropout(.3)
            self.bnorm = nn.GroupNorm(1, 64)
            self.bupsampler = Decoder(
                in_channels=64,
                out_channels=4,
                up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
                block_out_channels=[32, 64, 128],
                layers_per_block=1,
            )
            self.b_maps_projector = nn.Sequential(
                nn.Conv2d(64, 512, 1, bias=False),
                nn.GroupNorm(1,512),
                nn.ReLU(True),
                nn.Conv2d(512, 512, 1, bias=False),
                nn.GroupNorm(1,512),
                nn.ReLU(True),
                nn.Conv2d(512, 512, 1, bias=True),
            )
            
    def projector(self, in_dim, out_dim, h=2048):
        return nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.GELU(),
            nn.Linear(in_dim, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Linear(h, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Linear(h, out_dim)
        )
    
    def mlp(self, in_dim, out_dim, drop):
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(out_dim, out_dim),
        )
    
    def mixer_block1(self, h, drop):
        return nn.Sequential(
            nn.LayerNorm(h),
            self.mlp(h, h, drop),  # Token mixing
        )

    def mixer_block2(self, seq_len, drop):
        return nn.Sequential(
            nn.LayerNorm(seq_len),
            self.mlp(seq_len, seq_len, drop)  # Channel mixing
        )
        
    def forward(self, x):
        # make empty tensors
        c,b,t = torch.Tensor([0.]), torch.Tensor([[0.],[0.]]), torch.Tensor([0.])
        
        # Mixer blocks
        residual1 = x
        residual2 = x.permute(0,2,1)
        for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
            x = block1(x) + residual1
            residual1 = x
            x = x.permute(0,2,1)
            
            x = block2(x) + residual2
            residual2 = x
            x = x.permute(0,2,1)
            
        x = x.reshape(x.size(0), -1)
        backbone = self.backbone_linear(x).reshape(len(x), -1, self.clip_size)
        c = self.clip_proj(backbone)

        if blurry_recon:
            b = self.blin1(x)
            b = self.bdropout(b)
            b = b.reshape(b.shape[0], -1, 7, 7).contiguous()
            b = self.bnorm(b)
            b_aux = self.b_maps_projector(b).flatten(2).permute(0,2,1)
            b_aux = b_aux.view(len(b_aux), 49, 512)
            b = (self.bupsampler(b), b_aux)
        
        return backbone, c, b

print (f"clip size {clip_emb_dim}")
model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=seq_len, 
                          clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim) 
utils.count_params(model.ridge)
utils.count_params(model.backbone)
utils.count_params(model)

clip size 1664
param counts:
8,051,200 total
8,051,200 trainable
param counts:
228,956,824 total
228,956,824 trainable
param counts:
237,008,024 total
237,008,024 trainable


237008024

In [11]:
# setup diffusion prior network
out_dim = clip_emb_dim
depth = 6
dim_head = 52
heads = clip_emb_dim//52 # heads * dim_head = clip_emb_dim
timesteps = 100

prior_network = PriorNetwork(
        dim=out_dim,
        depth=depth,
        dim_head=dim_head,
        heads=heads,
        causal=False,
        num_tokens = clip_seq_dim,
        learned_query_mode="pos_emb"
    )

model.diffusion_prior = BrainDiffusionPrior(
    net=prior_network,
    image_embed_dim=out_dim,
    condition_on_text_encodings=False,
    timesteps=timesteps,
    cond_drop_prob=0.2,
    image_embed_scale=None,
)
model.to(device)

utils.count_params(model.diffusion_prior)
utils.count_params(model)

param counts:
259,865,216 total
259,865,200 trainable
param counts:
496,873,240 total
496,873,224 trainable


496873224

In [12]:
# Load pretrained model ckpt
tag='last.pth'
outdir = os.path.abspath(f'/home/ubuntu/MindEyeV2/train_logs/testing_session_1_to_39/')
print(f"\n---loading {outdir}/{tag} ckpt---\n")

named_layers = dict(model.ridge.named_modules())

print(named_layers)
try:
    checkpoint = torch.load(outdir+f'/{tag}', map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    state_dict.pop('ridge.linears.1.weight',None)
    state_dict.pop('ridge.linears.1.bias',None)
    model.load_state_dict(state_dict, strict=True)
    del checkpoint
except: # probably ckpt is saved using deepspeed format
    import deepspeed
    state_dict = deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir=outdir, tag = tag)
    model.load_state_dict(state_dict, strict=False)
    del state_dict
print("ckpt loaded!")


---loading /home/ubuntu/MindEyeV2/train_logs/testing_session_1_to_39/last.pth ckpt---

{'': RidgeRegression(
  (linears): ModuleList(
    (0): Linear(in_features=15724, out_features=512, bias=True)
  )
), 'linears': ModuleList(
  (0): Linear(in_features=15724, out_features=512, bias=True)
), 'linears.0': Linear(in_features=15724, out_features=512, bias=True)}
ckpt loaded!


In [13]:
# setup text caption networks
from transformers import AutoProcessor, AutoModelForCausalLM
from modeling_git import GitForCausalLMClipEmb
processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
clip_text_model = GitForCausalLMClipEmb.from_pretrained("microsoft/git-large-coco")
clip_text_model.to(device) # if you get OOM running this script, you can switch this to cpu and lower minibatch_size to 4
clip_text_model.eval().requires_grad_(False)
clip_text_seq_dim = 257
clip_text_emb_dim = 1024

class CLIPConverter(torch.nn.Module):
    def __init__(self):
        super(CLIPConverter, self).__init__()
        self.linear1 = nn.Linear(clip_seq_dim, clip_text_seq_dim)
        self.linear2 = nn.Linear(clip_emb_dim, clip_text_emb_dim)
    def forward(self, x):
        x = x.permute(0,2,1)
        x = self.linear1(x)
        x = self.linear2(x.permute(0,2,1))
        return x
        
clip_convert = CLIPConverter()
state_dict = torch.load(f"{cache_dir}/bigG_to_L_epoch8.pth", map_location='cpu')['model_state_dict']
clip_convert.load_state_dict(state_dict, strict=True)
clip_convert.to(device) # if you get OOM running this script, you can switch this to cpu and lower minibatch_size to 4
del state_dict

In [14]:
# test_url = f"{data_path}/wds/subj0{subj}/new_test/" + "0.tar"
test_url = f"{data_path}/wds/subj0{subj}/train/" + "{1..2}.tar"

print(test_url)
batch_size = 1360
def my_split_by_node(urls): return urls
test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
                    .decode("torch")\
                    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                    .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
print(f"Loaded test dl for subj{subj}!\n")

# Prep images but don't load them all to memory
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images']

# Prep test voxels and indices of test images
test_images_idx = []
test_voxels_idx = []
shape = (batch_size, num_voxels[f'subj01'])
test_voxels = torch.empty(shape)

for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
    # test_voxels = torch.vstack((test_voxels, voxels[f'subj01'][behav[:,0,5].cpu().long()]))
    test_voxels_idx = np.append(test_images_idx, behav[:,0,5].cpu().numpy())
    test_images_idx = np.append(test_images_idx, behav[:,0,0].cpu().numpy())

print(voxels[f'subj01'][behav[:,0,5].cpu().long()][0])
print(type(test_voxels), test_voxels.shape)
test_images_idx = test_images_idx.astype(int)
test_voxels_idx = test_voxels_idx.astype(int)

# assert (test_i+1) * num_test == len(test_voxels) == len(test_images_idx)
print(test_i, len(test_voxels), len(test_images_idx), len(np.unique(test_images_idx)))

/home/ubuntu/MindEyeV2/mindeyev2_dataset/wds/subj01/train/{1..2}.tar
Loaded test dl for subj1!

tensor([-0.0663, -1.2783,  1.1900,  ..., -0.5472, -0.7142, -1.0364])
<class 'torch.Tensor'> torch.Size([1360, 15724])
0 1360 1360 978


In [16]:
# Prep images but don't load them all to memory
import gc
torch.cuda.empty_cache()
gc.collect()
# get all reconstructions
model.to(device)
model.eval().requires_grad_(False)

f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images']

# Prep test voxels and indices of test images
num_voxels_list.append(betas[0].shape[-1])
num_voxels[f'subj01'] = betas[0].shape[-1]
voxels[f'subj01'] = betas
test_voxel = None

all_blurryrecons = None
all_recons = None
all_predcaptions = []
all_clipvoxels = None
all_prior_out = None

minibatch_size = 8
num_samples_per_image = 1
assert num_samples_per_image == 1
plotting = False
print (len(np.unique(test_images_idx)))

with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
    for batch in tqdm(range(0,len(np.unique(test_images_idx)))):
        voxel = voxels[f'subj01'][behav[:,0,5].cpu().long()]
        if seq_len==1:
            voxel = voxel.unsqueeze(1)
        print (voxel.shape)
        uniq_imgs = np.unique(test_images_idx)
        
        for uniq_img in uniq_imgs:
            locs = np.where(test_images_idx==uniq_img)[0]
            if len(locs)==1:
                locs = locs.repeat(3)
            elif len(locs)==2:
                locs = locs.repeat(2)[:3]
            assert len(locs)==3

            if test_voxel is None:
                test_voxel = voxel[locs][None] # 1, num_image_repetitions, num_voxels
            else:
                test_voxel = torch.vstack((test_voxel, voxel[locs][None]))
        test_voxel = test_voxel.to(device)


        test_indices = torch.arange(len(test_voxel))[:300]
        voxel = test_voxel[test_indices].to(device)

        for rep in range(3):
            voxel_ridge = model.ridge(voxel[:,rep],0) # 0th index of subj_list
            print (voxel_ridge.shape)
            backbone0, clip_voxels0, blurry_image_enc0 = model.backbone(voxel_ridge)
            if rep==0:
                clip_voxels = clip_voxels0
                backbone = backbone0
                blurry_image_enc = blurry_image_enc0[0]
            else:
                clip_voxels += clip_voxels0
                backbone += backbone0
                blurry_image_enc += blurry_image_enc0[0]
        clip_voxels /= 3
        backbone /= 3
        blurry_image_enc /= 3
        
        # Feed voxels through OpenCLIP-bigG diffusion prior
        prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, 
                        text_cond = dict(text_embed = backbone), 
                        cond_scale = 1., timesteps = 20).cpu()
        
        if all_prior_out is None:
            all_prior_out = prior_out
        else:
            all_prior_out = torch.vstack((all_prior_out, prior_out))
        

torch.save(all_prior_out,f"evals/{model_name}/{model_name}_all_prior_out.pt")

978


  0%|          | 0/978 [00:00<?, ?it/s]

torch.Size([1360, 1, 15724])
torch.Size([300, 1, 512])
torch.Size([300, 1, 512])
torch.Size([300, 1, 512])


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/978 [00:15<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 978.00 MiB. GPU 0 has a total capacty of 39.39 GiB of which 917.94 MiB is free. Including non-PyTorch memory, this process has 38.48 GiB memory in use. Of the allocated memory 37.07 GiB is allocated by PyTorch, and 940.02 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
import gc
torch.cuda.empty_cache()
gc.collect()
# get all reconstructions
model.to(device)
model.eval().requires_grad_(False)

# all_images = None
all_blurryrecons = None
all_recons = None
all_predcaptions = []
all_clipvoxels = None
all_prior_out = None

minibatch_size = 8
num_samples_per_image = 1
assert num_samples_per_image == 1
plotting = False
print (len(np.unique(test_images_idx)))

with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
    # for batch in tqdm(range(0,len(np.unique(test_images_idx)),minibatch_size)):
    #     uniq_imgs = np.unique(test_images_idx)[batch:batch+minibatch_size]
    for batch in tqdm(range(0,len(np.unique(test_images_idx)))):
        uniq_imgs = np.unique(test_images_idx)
        print(voxels[f'subj01'].shape)
        if seq_len==1:
            test_voxels = test_voxels.unsqueeze(1)

        for uniq_img in uniq_imgs:
            locs = np.where(test_images_idx==uniq_img)[0]
            # print(f'locs before {locs}')
            if len(locs)==1:
                locs = locs.repeat(3)
            elif len(locs)==2:
                locs = locs.repeat(2)[:3]
            # print(test_voxels.shape, locs)
            assert len(locs)==3
            if voxel is None:
                test_voxel = test_voxels[locs][None] # 1, num_image_repetitions, num_voxels
            else:
                test_voxel = torch.vstack((test_voxel, test_voxels[locs][None]))
        voxel = voxel.to(device)
        
        for rep in range(3):
            voxel_ridge = model.ridge(voxel[:,[rep]],0) # 0th index of subj_list
            backbone0, clip_voxels0, blurry_image_enc0 = model.backbone(voxel_ridge)
            if rep==0:
                clip_voxels = clip_voxels0
                backbone = backbone0
                blurry_image_enc = blurry_image_enc0[0]
            else:
                clip_voxels += clip_voxels0
                backbone += backbone0
                blurry_image_enc += blurry_image_enc0[0]
        clip_voxels /= 3
        backbone /= 3
        blurry_image_enc /= 3
        
        # Feed voxels through OpenCLIP-bigG diffusion prior
        prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, 
                        text_cond = dict(text_embed = backbone), 
                        cond_scale = 1., timesteps = 20).cpu()
        
        if all_prior_out is None:
            all_prior_out = prior_out
        else:
            all_prior_out = torch.vstack((all_prior_out, prior_out))
        

torch.save(all_prior_out,f"evals/{model_name}/{model_name}_all_prior_out.pt")

if not utils.is_interactive():
    sys.exit(0)