In [1]:
import torch
from src.models.autoencoder_kl import AutoencoderKL
from src.models.emasc import EMASC

In [2]:
# Channel counts of the VAE encoder's intermediate feature maps, in the order vae.encode()
# returns them; these are the tensors EMASC consumes.
input_channels_list = [128, 128, 128, 256, 512]
# Channel counts EMASC emits for each of those feature maps.
output_channels_list = [128, 256, 512, 512, 512]
emasc = EMASC(input_channels_list, output_channels_list)
emasc = emasc.to('cuda', dtype=torch.float16)

In [3]:
# Load the VAE sub-model of the inpainting checkpoint in float16, then move it to the GPU.
# NOTE(review): use_safetensors=False skips safetensors and loads pickle-based weights, which can
# execute arbitrary code for an untrusted checkpoint; consider use_safetensors=True if the repo ships them.
vae = AutoencoderKL.from_pretrained(
    'stable-diffusion-v1-5/stable-diffusion-inpainting',
    subfolder='vae',
    use_safetensors=False,
    torch_dtype=torch.float16,
)
vae = vae.to('cuda')

In [4]:
# Shape of the dummy image batch pushed through the VAE: (batch size, channels, height, width).
bs, c, h, w = 16, 3, 512, 384
# Seed torch's RNGs (CPU and CUDA) so this input, and subsequent random draws, are reproducible.
torch.manual_seed(0)
# Allocate directly on the GPU in float16 rather than building a float32 CPU tensor and then
# copying + casting it, which wastes host memory and a host-to-device transfer.
x = torch.randn(bs, c, h, w, device='cuda', dtype=torch.float16)

In [16]:
with torch.no_grad():
    # This project's AutoencoderKL.encode returns the posterior together with the encoder's
    # intermediate feature maps (which are later fed to EMASC).
    posterior, intermediate_features = vae.encode(x)
    print(posterior.latent_dist.sample().shape)
    print()
    print('Intermediate features at each of Downblock Encoder of VAE:')
    # The last axis is the spatial width (the original label mistakenly said "weight").
    print('([batch size, channel, height, width])')
    for in_feats in intermediate_features:
        print(f'  {in_feats.shape}')

torch.Size([16, 4, 64, 48])

Intermediate features at each of Downblock Encoder of VAE:
([batch size, channel, height, weight])
  torch.Size([16, 128, 512, 384])
  torch.Size([16, 128, 512, 384])
  torch.Size([16, 128, 256, 192])
  torch.Size([16, 256, 128, 96])
  torch.Size([16, 512, 64, 48])


In [6]:
# EMASC outputs. Run under no_grad: assuming EMASC is an nn.Module with trainable parameters,
# autograd would otherwise keep the intermediate activations of these large fp16 feature maps
# alive for a backward pass that never happens.
with torch.no_grad():
    emasc_outputs = emasc(intermediate_features)
for out_feats in emasc_outputs:
    print(out_feats.shape)

torch.Size([16, 128, 512, 384])
torch.Size([16, 256, 512, 384])
torch.Size([16, 512, 256, 192])
torch.Size([16, 512, 128, 96])
torch.Size([16, 512, 64, 48])


In [7]:
# Draw a latent sample from the encoder posterior (presumably a fresh random draw, distinct from
# the one printed in the encode cell) and display its shape.
latents = posterior.latent_dist.sample()
latents.size()

torch.Size([16, 4, 64, 48])

In [8]:
# Decode the latents with the EMASC-refined encoder features passed in as extra inputs.
# The return value is discarded; the shape pairs in the output are presumably printed from
# inside this project's vae.decode, since this cell prints nothing itself.
with torch.no_grad():
    vae.decode(intermediate_features=emasc_outputs, z=latents)

torch.Size([16, 512, 64, 48])	torch.Size([16, 512, 64, 48])
torch.Size([16, 512, 128, 96])	torch.Size([16, 512, 128, 96])
torch.Size([16, 512, 256, 192])	torch.Size([16, 512, 256, 192])
torch.Size([16, 256, 512, 384])	torch.Size([16, 256, 512, 384])
torch.Size([16, 128, 512, 384])	torch.Size([16, 128, 512, 384])
