In [10]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
import zarr
import os
import einops
from einops import rearrange
%load_ext autoreload

from transformers import ViTForImageClassification, ViTImageProcessor, ViTFeatureExtractor, ViTConfig, ViTModel, pipeline
from diffusers import AutoencoderKL
from torch.utils.data import DataLoader

from utils import get_directories, get_imgs, get_device
from data_preprocessing import preprocess_images,XrdDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Load data
path = 'data'
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
dataset = XrdDataset(data_dir=path, feature_extractor=feature_extractor)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
sample = next(iter(dataloader))
# Reshape, here sample shape is (40, 3, 224, 224)
sample = einops.rearrange(sample, 'f b c w h -> (f b) c w h')



Loading data/mfxl1025422_r0313_peaknet.0035.zarr/images


In [15]:
# Load the autoencoder model.
device = get_device()
path = "vae_model"
vae = AutoencoderKL.from_pretrained("vae_model")
vae.eval()
vae.to(device)
# Convert into float32 to fit pytorch model
sample = sample.to(device).float()

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


In [16]:
# Reconstruct the image
with torch.no_grad():
    # Encode the images
    latent_dist = vae.encode(sample).latent_dist
    z = latent_dist.sample()
    # Decode the latent representation
    recon = vae.decode(z).sample
    

RuntimeError: MPS backend out of memory (MPS allocated: 8.41 GB, other allocations: 3.09 MB, max allowed: 9.07 GB). Tried to allocate 980.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# Visualization
def show_image(tensor, title="Image"):
    # Take out the first image in the batch
    tensor = tensor[0]
    # Take out the first channel
    tensor = tensor[0]
    image = tensor.cpu().numpy()
    plt.imshow(image, cmap="gray")
    plt.title(title)
    plt.axis("off")

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
show_image(sample, "Original")
plt.subplot(1, 2, 2)
show_image(recon, "Reconstructed")
plt.show()
