In [None]:
!pip install transformers accelerate datasets diffusers Pillow==9.4.0

# Load MNIST

In [None]:
from datasets import load_dataset, Dataset, DatasetDict

# Load MNIST by its Hugging Face Hub id; later cells index the result by split name
# ("train" / "test") and read the "image" and "label" fields of each example.
ds = load_dataset("ylecun/mnist")

In [None]:
# Display the train split (the cell's last expression is rendered as its output).
ds["train"]

In [None]:
# Keep the first training example around as a sample input; a later cell encodes d["image"].
d=ds["train"][0]
d

# Load DCAE

In [None]:
import torch
from diffusers import AutoencoderDC
from transformers import Gemma2Model, GemmaTokenizerFast

# Hub repo of the Sana pipeline; only its "vae" subfolder (the DC-AE) is loaded below.
model = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dtype = torch.bfloat16

# Pick the best available accelerator. `torch.backends.mps.is_available()` is used for
# the Apple-silicon check because `torch.mps.is_available()` is not present in every
# PyTorch release and would raise AttributeError where it is missing.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the autoencoder weights in `dtype` and move the module to the selected device.
dcae = AutoencoderDC.from_pretrained(model, subfolder="vae", torch_dtype=dtype).to(device)

# PIL to latent

In [None]:
import torchvision.transforms as T

def encode_pil(image, ae):
    """Encode a single PIL image into a latent using the given autoencoder.

    The image is converted to RGB, resized with ``T.Resize(256)``, turned into a
    tensor, normalized with mean 0.5 / std 0.5 per channel, and then moved to the
    autoencoder's own device and dtype before being passed to ``ae.encode``.

    Args:
        image: PIL image in any mode (it is converted to RGB here).
        ae: ``torch.nn.Module`` autoencoder exposing ``encode(x)`` whose result
            has a ``.latent`` attribute.

    Returns:
        The ``.latent`` attribute of the encoder output for a batch holding this
        one image.
    """
    # MNIST inputs are grayscale/BW
    rgb_image = image.convert("RGB")
    preprocess = T.Compose([
        T.Resize(256, antialias=True),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Follow the autoencoder's actual placement and precision instead of relying on
    # notebook-level globals or a hard-coded bfloat16, so the input always matches
    # the weights it is fed to.
    reference_param = next(ae.parameters())
    batch = preprocess(rgb_image).unsqueeze(0).to(
        device=reference_param.device, dtype=reference_param.dtype
    )

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        encoded = ae.encode(batch)
    return encoded.latent

# Smoke-test the encoder on the sample example `d` and display the shape of the resulting latent.
latent = encode_pil(d["image"], dcae)
latent.shape

# Process MNIST and upload

In [None]:
from tqdm import tqdm

# Encode every image of every split into a latent and collect {"label", "latent"} records.
dataset_latents = {}
splits = ["train", "test"]

for split in splits:
    print(split)
    dataset_latents[split] = []

    # The loop locals are named `sample` / `sample_latent` on purpose: reusing `d` and
    # `latent` here would overwrite the notebook globals set by earlier cells, so
    # re-running those cells afterwards would silently operate on the last test example.
    for sample in tqdm(ds[split]):
        # The autoencoder runs in bfloat16, which NumPy cannot represent, so the
        # latent is upcast to float32 and moved to the CPU before `.numpy()`.
        sample_latent = encode_pil(sample["image"], dcae).float().cpu()
        dataset_latents[split].append({
            "label": sample["label"],
            "latent": sample_latent.numpy()
        })

In [None]:
# Turn each split's list of {"label", "latent"} records into a Dataset and
# gather the splits in a DatasetDict keyed by split name.
dataset = DatasetDict()
for split in splits:
    dataset[split] = Dataset.from_list(dataset_latents[split])
dataset

In [None]:
# Upload all splits to a private dataset repo on the Hub. The source model id is passed
# as the commit message, which records which checkpoint the latents were produced with.
# NOTE(review): presumably requires prior Hub authentication — not shown in this notebook.
dataset.push_to_hub("g-ronimo/MNIST-latents_dc-ae-f32c32-sana-1.0", private=True, commit_message=model)