In [12]:
# Reproducibility
import torch, random, numpy as np
from huggingface_hub import hf_hub_download
from diffusers import AutoencoderKL, DiffusionPipeline
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from pathlib import Path
from PIL import Image
import einops
import matplotlib.pyplot as plt
from datetime import datetime

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

device = torch.device('mps') if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() else torch.device('cpu')
print('Device:', device)

OUT_DIR = Path('../data/outputs/inference')
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Load SD3 VAE
sd3_vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3.5-large", subfolder="vae")

# Load PixCell-1024 pipeline
pipeline = DiffusionPipeline.from_pretrained(
    "StonyBrook-CVLab/PixCell-1024",
    vae=sd3_vae,
    custom_pipeline="StonyBrook-CVLab/PixCell-pipeline",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
pipeline.to(device);

# Load UNI-2h for conditioning
timm_kwargs = {
            'img_size': 224,
            'patch_size': 14,
            'depth': 24,
            'num_heads': 24,
            'init_values': 1e-5,
            'embed_dim': 1536,
            'mlp_ratio': 2.66667*2,
            'num_classes': 0,
            'no_embed_class': True,
            'mlp_layer': timm.layers.SwiGLUPacked,
            'act_layer': torch.nn.SiLU,
            'reg_tokens': 8,
            'dynamic_img_size': True
        }
uni_model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))
uni_model.eval()
uni_model.to(device);

# Unconditional generation
N = 4
LATENT_DIM = 256
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
print('Intended outputs ->', OUT_DIR)

uncond = pipeline.get_unconditional_embedding(N)
with torch.amp.autocast(device.type):
    samples = pipeline(uni_embeds=uncond, negative_uni_embeds=None, guidance_scale=1.0).images

# Save and display images
fig, axes = plt.subplots(1, N, figsize=(N * 4, 4))
for i, img in enumerate(samples):
    img_path = OUT_DIR / f'unconditional_sample_{timestamp}_{i:02d}.png'
    img.save(img_path)
    axes[i].imshow(img)
    axes[i].axis('off')
    axes[i].set_title(f'Sample {i+1}')
plt.tight_layout()
plt.show()

# Conditional generation (example from user's input, adapted for this notebook)
# This part is for demonstration and might be more relevant for 02_dataset_prep.ipynb
# path = hf_hub_download(repo_id="StonyBrook-CVLab/PixCell-1024", filename="test_image.png")
# image = Image.open(path).convert("RGB")

# uni_patches = np.array(image)
# uni_patches = einops.rearrange(uni_patches, '(d1 h) (d2 w) c -> (d1 d2) h w c', d1=4, d2=4)
# uni_input = torch.stack([transform(Image.fromarray(item)) for item in uni_patches])

# with torch.inference_mode():
#     uni_emb = uni_model(uni_input.to(device))

# uni_emb = uni_emb.unsqueeze(0)
# print("Extracted UNI:", uni_emb.shape)

# uncond = pipeline.get_unconditional_embedding(uni_emb.shape[0])
# with torch.amp.autocast(device.type):
#     samples = pipeline(uni_embeds=uni_emb, negative_uni_embeds=uncond, guidance_scale=1.5).images

Device: mps


Keyword arguments {'trust_remote_code': True} are not expected by PixCellPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]

The config attributes {'double_self_attention': False, 'num_vector_embeds': None, 'only_cross_attention': False, 'use_linear_projection': False} were passed to PixCellTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Some weights of PixCellTransformer2DModel were not initialized from the model checkpoint at /Users/nicorosen/.cache/huggingface/hub/models--StonyBrook-CVLab--PixCell-1024/snapshots/5bb47f246858bac6ba8cc1aadad595e08143bf03/transformer and are newly initialized: ['y_pos_embed.y_pos_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Expected types for transformer: (<class 'diffusers_modules.local.StonyBrook-CVLab--PixCell-pipeline.2730c93b7b1c71aa99f09c6dc4623d80caaa9bec.pixcell_transformer_2d.PixCellTransformer2DModel'>,), got <class 'diffusers_modules.local.pixcell_transformer_2d.PixCellTransformer2DModel'>.


config.json:   0%|          | 0.00/587 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.73G [00:00<?, ?B/s]

Intended outputs -> ../data/outputs/inference


  0%|          | 0/20 [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 18.06 GB, other allocations: 2.83 MB, max allowed: 20.40 GB). Tried to allocate 4.00 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

Next: Proceed to dataset prep and reconstructions: [02_dataset_prep.ipynb](02_dataset_prep.ipynb:1).