# 02 — Dataset Prep and Reconstructions
Download a few public, non-sensitive histology tiles, preprocess them to the model's input size (assumed 1024×1024; verify on the model card), and run encode→decode reconstructions. See [README.md](../README.md).

## Download sample tiles
- The default URLs are placeholder photos (not histology); replace them with public sample tiles you are allowed to use.
- Tiles are saved to [data/inputs/](../data/inputs/).

In [None]:
from pathlib import Path
import requests

INPUT_DIR = Path('../data/inputs')
INPUT_DIR.mkdir(parents=True, exist_ok=True)

SAMPLE_URLS = [
    # Placeholder non-sensitive images (replace with histology tiles you can use)
    'https://picsum.photos/seed/histo1/1024/1024',
    'https://picsum.photos/seed/histo2/1024/1024',
    'https://picsum.photos/seed/histo3/1024/1024',
]

# Save each URL as INPUT_DIR/sample_<n>.png (n is 1-based).
# The response bytes are written unchanged, so the file may not actually be
# PNG-encoded. The preprocess step re-saves it, and PIL picks the output format
# from the .png extension.
# A failed URL is reported and skipped so the remaining downloads still run.
for i, url in enumerate(SAMPLE_URLS, 1):
    dest = INPUT_DIR / f'sample_{i}.png'
    try:
        r = requests.get(url, timeout=30)
        r.raise_for_status()  # turn 4xx/5xx responses into an exception
        dest.write_bytes(r.content)
    except (requests.RequestException, OSError) as e:
        # RequestException: network errors, timeouts, bad URLs, HTTP error status.
        # OSError: the file could not be written.
        print('Failed:', url, e)
    else:
        print('Downloaded:', dest)

## Preprocess (resize/crop to 1024)
Verify the required input size on the model card; this notebook assumes 1024×1024. Preprocessed images overwrite the downloaded files in place, so change the save path if you want to keep the originals.

In [None]:
from PIL import Image
import numpy as np

TARGET = 1024

# Bring every tile in INPUT_DIR to TARGET x TARGET RGB, overwriting it in place.
# Sorted so the processing order (and the log) is deterministic.
for p in sorted(INPUT_DIR.glob('*.png')):
    try:
        # Close the source file before overwriting the same path below.
        with Image.open(p) as src:
            img = src.convert('RGB')
        # Center-crop to the largest square, then resize. A plain resize would
        # stretch non-square tiles and distort tissue morphology. For square
        # inputs the box is the whole image, i.e. identical to a plain resize.
        w, h = img.size
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        img = img.resize(
            (TARGET, TARGET),
            Image.BICUBIC,
            box=(left, top, left + side, top + side),
        )
        img.save(p)
        print('Preprocessed:', p)
    except Exception as e:  # best-effort batch: report the bad file and continue
        print('Failed preprocess:', p, e)

## TODO: Encode → Decode with PixCell-1024
Check the encode/decode calls below against the official API on the model card: https://huggingface.co/StonyBrook-CVLab/PixCell-1024

In [None]:
import torch
from huggingface_hub import hf_hub_download
from diffusers import AutoencoderKL, DiffusionPipeline
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from pathlib import Path
from PIL import Image
import einops
import matplotlib.pyplot as plt

OUT_RECONS = Path('../data/outputs/reconstructions')
OUT_RECONS.mkdir(parents=True, exist_ok=True)

# Pick the best available accelerator. `device` must be a torch.device, not a
# str, because decode() reads `device.type` to configure autocast.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# float16 saves memory on accelerators, but PyTorch's half-precision support on
# CPU is limited, so the pipeline runs in float32 there.
pipeline_dtype = torch.float32 if device.type == 'cpu' else torch.float16

# Load SD3 VAE
sd3_vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3.5-large", subfolder="vae")

# Load PixCell-1024 pipeline. custom_pipeline names a Hub repo that holds the
# pipeline code, which is why trust_remote_code is enabled.
pipeline = DiffusionPipeline.from_pretrained(
    "StonyBrook-CVLab/PixCell-1024",
    vae=sd3_vae,
    custom_pipeline="StonyBrook-CVLab/PixCell-pipeline",
    trust_remote_code=True,
    torch_dtype=pipeline_dtype,
)
pipeline.to(device)

# Load UNI-2h for conditioning. These architecture kwargs must match the
# MahmoodLab/UNI2-h checkpoint. num_classes=0 drops the classifier head, so the
# model returns embeddings of size embed_dim.
timm_kwargs = {
    'img_size': 224,
    'patch_size': 14,
    'depth': 24,
    'num_heads': 24,
    'init_values': 1e-5,
    'embed_dim': 1536,
    'mlp_ratio': 2.66667 * 2,
    'num_classes': 0,
    'no_embed_class': True,
    'mlp_layer': timm.layers.SwiGLUPacked,
    'act_layer': torch.nn.SiLU,
    'reg_tokens': 8,
    'dynamic_img_size': True,
}
uni_model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
# Preprocessing (resize/normalize) derived from the checkpoint's pretrained config.
transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))
uni_model.eval()
uni_model.to(device)

def encode(image):
    """Encode a 1024x1024 RGB tile into UNI-2h patch embeddings.

    The tile is split into a 4x4 grid of 256x256 patches (row-major). Each patch
    goes through the UNI ``transform`` and ``uni_model`` (module globals).

    Args:
        image: RGB image (e.g. a PIL image) whose ``np.array`` form has shape
            (1024, 1024, 3).

    Returns:
        Tensor of shape (1, 16, D): one embedding per patch, plus a leading
        batch dimension for ``decode``.

    Raises:
        ValueError: if the image is not 1024x1024 with 3 channels.
    """
    uni_patches = np.array(image)
    # Fail early with an actionable message instead of an opaque einops error.
    if uni_patches.shape != (1024, 1024, 3):
        raise ValueError(
            f'encode expects a 1024x1024 RGB image, got array of shape '
            f'{uni_patches.shape}; run the preprocess step first'
        )
    # Rearrange the 1024x1024 image into 16 patches of 256x256.
    uni_patches = einops.rearrange(uni_patches, '(d1 h) (d2 w) c -> (d1 d2) h w c', d1=4, d2=4)
    uni_input = torch.stack([transform(Image.fromarray(item)) for item in uni_patches])

    # Extract UNI embeddings (no autograd bookkeeping needed).
    with torch.inference_mode():
        uni_emb = uni_model(uni_input.to(device))

    # (16, D) -> (1, 16, D), i.e. (bs, 16, D)
    return uni_emb.unsqueeze(0)

def decode(latent, guidance_scale=1.5):
    """Generate an image from UNI embeddings with the PixCell pipeline.

    Args:
        latent: UNI embeddings shaped (bs, 16, D), as returned by ``encode``.
        guidance_scale: classifier-free guidance strength passed to the pipeline.

    Returns:
        The first entry of ``pipeline(...).images``. Only one image is returned
        even when ``latent`` holds a batch.
    """
    # Unconditional embedding for classifier-free guidance, one per batch item.
    uncond = pipeline.get_unconditional_embedding(latent.shape[0])
    # torch.device() accepts either a str or a torch.device, so this works
    # whichever way the global `device` was defined. A plain str has no `.type`.
    device_type = torch.device(device).type
    with torch.amp.autocast(device_type):
        samples = pipeline(
            uni_embeds=latent,
            negative_uni_embeds=uncond,
            guidance_scale=guidance_scale,
        ).images
    return samples[0]  # return single image

# Example loop: reconstruct every tile in INPUT_DIR and save it to
# OUT_RECONS/<stem>_recon.png. Sorted for a deterministic processing order.
for p in sorted(INPUT_DIR.glob('*.png')):
    rec_path = OUT_RECONS / f'{p.stem}_recon.png'
    try:
        # Close the input file as soon as its pixels are loaded.
        with Image.open(p) as src:
            img = src.convert('RGB')
        uni_emb = encode(img)
        rec = decode(uni_emb)
        rec.save(rec_path)
    except Exception as e:  # top-level boundary: report and move on to the next tile
        # !r includes the exception type, which a bare message often omits.
        print(f"Error processing {p}: {e!r}")
    else:
        print('Saved recon:', rec_path)