In [7]:

import os
import numpy as np
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, Lambda, ToPILImage
from diffusers import AutoencoderKL

# Prefer the GPU when one is visible to torch; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The pre-trained VAE weights are read from the local "MAT/vae" directory and
# the model is moved to the selected device in one step.
vae_path = os.path.join("MAT", "vae")
vae = AutoencoderKL.from_pretrained(vae_path).to(device)

# Image preprocessing pipeline: resize to the square size given by
# vae.config.sample_size, convert to a tensor, then apply (x - 0.5) / 0.5 to
# each of the three channels.
image_transform = Compose([
    Resize((vae.config.sample_size,) * 2),
    ToTensor(),
    Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Mask preprocessing pipeline: resize to the square size given by
# vae.config.sample_size, convert to a tensor, then stack three copies of the
# tensor along the channel axis (a 1-channel mask becomes a 3-channel tensor).
mask_transform = Compose([
    Resize((vae.config.sample_size,) * 2),
    ToTensor(),
    Lambda(lambda m: torch.cat([m] * 3, 0)),
])

def apply_mask_transform(tensor_mask):
    """Run the module-level ``mask_transform`` pipeline on ``tensor_mask``.

    NOTE(review): despite the parameter name, the pipeline contains
    ``ToTensor``, so the argument is presumably a PIL image (or ndarray)
    rather than a tensor -- confirm against callers.

    Returns:
        Whatever ``mask_transform`` returns for the given input.
    """
    transformed = mask_transform(tensor_mask)
    return transformed

def apply_image_transform(tensor_img):
    """Run the module-level ``image_transform`` pipeline on ``tensor_img``.

    NOTE(review): despite the parameter name, the pipeline contains
    ``ToTensor``, so the argument is presumably a PIL image (or ndarray)
    rather than a tensor -- confirm against callers.

    Returns:
        Whatever ``image_transform`` returns for the given input.
    """
    transformed = image_transform(tensor_img)
    return transformed

def inverse_mask_transform(transformed_mask, original_size):
    """Turn a multi-channel mask tensor back into a PIL image.

    Args:
        transformed_mask: Mask tensor indexed as (channel, row, column); only
            channel 0 is used.
        original_size: Target size, passed unchanged to ``Resize``.

    Returns:
        The PIL image produced by ``ToPILImage`` from the resized
        single-channel mask.
    """
    # Keep only the first channel, but retain a leading channel axis of size 1.
    first_channel = torch.unsqueeze(transformed_mask[0, :, :], 0)

    # Resize, then hand the tensor to the PIL converter.
    resized = Resize(original_size)(first_channel)
    to_pil = ToPILImage()
    return to_pil(resized)

def inverse_image_transform(transformed_tensor, original_size):
    """Undo the mean-0.5 / std-0.5 normalisation, resize, and return a PIL image.

    Args:
        transformed_tensor: 3 x H x W tensor that was normalised with mean 0.5
            and std 0.5 on each channel. May live on any device.
        original_size: Target size, passed unchanged to ``Resize``.

    Returns:
        The PIL image produced by ``ToPILImage`` from the de-normalised,
        resized tensor.
    """
    # Build the constants on the input's device so that a tensor living on the
    # GPU does not trigger a device-mismatch error in the arithmetic below.
    target_device = transformed_tensor.device
    mean = torch.tensor([0.5, 0.5, 0.5], device=target_device).view(3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], device=target_device).view(3, 1, 1)

    # Inverse of Normalize: x * std + mean.
    image = transformed_tensor * std + mean

    # Resize works directly on a (C, H, W) tensor; no batch axis is needed.
    image = Resize(original_size)(image)

    # ToPILImage scales float tensors by 255 and casts to uint8, so values
    # outside [0, 1] (e.g. slight overshoot from a decoder) would overflow
    # instead of saturating. Clamp to keep such pixels at black/white.
    image = image.clamp(0.0, 1.0)

    return ToPILImage()(image)

def transform_to_3_512_512(tensor):
    """Pack a 4-channel tensor into a zero-padded 3-channel tensor.

    The first three channels are padded with 416 zeros on the right and on the
    bottom (so a 4 x 96 x 96 input yields 3 x 512 x 512). The fourth channel is
    cut into three 32-column strips; strip ``k`` is written into rows 0..95,
    columns 480..511 of output channel ``k``.

    Args:
        tensor: Tensor indexed as (channel, row, column) with at least four
            channels; the slicing below is written for 96 x 96 planes.

    Returns:
        A new padded tensor; ``tensor`` itself is not modified.
    """
    pad = torch.nn.functional.pad
    canvas = pad(tensor[:3], (0, 416, 0, 416), mode='constant', value=0)

    # Stash the fourth channel, one 32-column strip per output channel.
    extra = tensor[3]
    for channel, start in enumerate(range(0, 96, 32)):
        canvas[channel, :96, 480:512] = extra[:, start:start + 32]

    return canvas

def revert_back_to_4_96_96(tensor):
    """Unpack a padded 3-channel tensor back into a 4 x 96 x 96 tensor.

    Reads the layout in which the top-left 96 x 96 corner of each of the three
    channels holds the first three output channels, and rows 0..95, columns
    480..511 of channel ``k`` hold the ``k``-th 32-column strip of the fourth
    output channel.

    The fourth channel is assembled directly from slices of ``tensor``, so the
    result keeps the input's dtype and device (a separately allocated float32
    CPU buffer would silently round float64/integer data and would not
    concatenate with a tensor on another device).

    Args:
        tensor: Tensor indexed as (channel, row, column) with at least three
            channels, at least 96 rows and at least 512 columns.

    Returns:
        A new 4 x 96 x 96 tensor with the same dtype and device as ``tensor``.
    """
    # Collect the three 96 x 32 strips and lay them side by side (96 x 96),
    # then add the leading channel axis.
    strips = [tensor[i, :96, 480:512] for i in range(3)]
    reconstructed_channel = torch.cat(strips, dim=1).unsqueeze(0)

    # Drop the padding from the first three channels and append the fourth.
    reverted_tensor = torch.cat((tensor[:3, :96, :96], reconstructed_channel), dim=0)

    return reverted_tensor

def vae_encode(vae, tensor):
    """Encode ``tensor`` with ``vae`` and return a latent sample on the CPU.

    Args:
        vae: Model exposing ``encode(x)`` whose result has a ``latent_dist``
            attribute with a ``sample()`` method.
        tensor: Input passed to ``vae.encode`` after being moved to the
            module-level ``device``.

    Returns:
        The value of ``latent_dist.sample()``, detached and moved to the CPU.
    """
    on_device = tensor.to(device)
    # No gradients are recorded for the forward pass.
    with torch.no_grad():
        posterior = vae.encode(on_device).latent_dist
        sampled = posterior.sample()
        latent = sampled.detach().cpu()
    return latent

def vae_decode(vae, tensor):
    """Decode ``tensor`` with ``vae`` and return the result on the CPU.

    Args:
        vae: Model exposing ``decode(z)`` whose result has a ``sample``
            attribute.
        tensor: Latent passed to ``vae.decode`` after being moved to the
            module-level ``device``.

    Returns:
        The ``sample`` attribute of the decoder output, detached and moved to
        the CPU.
    """
    on_device = tensor.to(device)
    # No gradients are recorded for the forward pass.
    with torch.no_grad():
        decoded = vae.decode(on_device)
        reconstructed_image = decoded.sample.detach().cpu()
    return reconstructed_image



In [8]:
from PIL import Image

In [12]:
# Round-trip demo: load a mask, preprocess it, encode it with the VAE, pack the
# latent into the 3-channel 512x512 layout, unpack it again and decode.
# The commented-out lines are the equivalent calls for an RGB test image.
img = Image.open(os.path.join("MAT", "test_sets", "Places", "masks", "mask1.png"))
# img = Image.open(os.path.join("MAT", "test_sets", "Places", "images", "test1.jpg"))
transformed_image = apply_mask_transform(img)
# transformed_image = apply_image_transform(img)


# vae_encode expects a leading batch axis, hence unsqueeze(0); [0] drops it again.
# assumes the resulting latent is 4 x 96 x 96, which is what the packing helper
# below is written for -- TODO confirm against vae.config.sample_size.
latent_image = vae_encode(vae, transformed_image.unsqueeze(0))[0]
latent_image_512 = transform_to_3_512_512(latent_image)
# MAT NETWORK
# NOTE(review): nothing runs at the marker above, so the packed tensor is
# unpacked unchanged; presumably the MAT network call belongs here.
reconstructed_image_96 = revert_back_to_4_96_96(latent_image_512)
reconstructed_image = vae_decode(vae, reconstructed_image_96.unsqueeze(0))


# NOTE(review): img2 is built from transformed_image (the preprocessed input),
# not from reconstructed_image, so the VAE round trip above is not what gets
# converted back to a PIL image here -- confirm this is intended.
# img2 = inverse_image_transform(transformed_image, 512)
img2 = inverse_mask_transform(transformed_image, 512)