In [1]:
import gc
import os
import sys
import torch

base_directory = "../"
sys.path.insert(0, base_directory)

from stable_diffusion2.latent_diffusion import LatentDiffusion
from stable_diffusion2.utils.model import *
from stable_diffusion2.utils.utils import SectionManager as section
from stable_diffusion2.model.clip.clip_embedder import CLIPTextEmbedder

from stable_diffusion2.model.vae.autoencoder import Autoencoder
from stable_diffusion2.model.vae.encoder import Encoder
from stable_diffusion2.model.vae.decoder import Decoder

from stable_diffusion2.model.unet.unet import UNetModel

from pathlib import Path

In [2]:
# Single root for every checkpoint below. The relative path is resolved against
# the current working directory; NOTE(review): assumes the notebook is started
# from its own folder.
MODEL_DIR = os.path.abspath('../input/model')

# Full model checkpoint.
CHECKPOINT_PATH = os.path.join(MODEL_DIR, 'v1-5-pruned-emaonly.ckpt')

# CLIP text-embedder pieces.
EMBEDDER_PATH = os.path.join(MODEL_DIR, 'clip', 'clip_embedder.ckpt')
TOKENIZER_PATH = os.path.join(MODEL_DIR, 'clip', 'clip_tokenizer.ckpt')
TRANSFORMER_PATH = os.path.join(MODEL_DIR, 'clip', 'clip_transformer.ckpt')

# U-Net.
UNET_PATH = os.path.join(MODEL_DIR, 'unet', 'unet.ckpt')

# Autoencoder (VAE) and its separately saved encoder/decoder submodels.
AUTOENCODER_PATH = os.path.join(MODEL_DIR, 'autoencoder', 'autoencoder.ckpt')
ENCODER_PATH = os.path.join(MODEL_DIR, 'autoencoder', 'encoder.ckpt')
DECODER_PATH = os.path.join(MODEL_DIR, 'autoencoder', 'decoder.ckpt')

# Assembled latent-diffusion model.
LATENT_DIFFUSION_PATH = os.path.join(MODEL_DIR, 'latent_diffusion', 'latent_diffusion.ckpt')

In [3]:
def initialize_encoder(device = None,
                       z_channels=4,
                       in_channels=3,
                       channels=128,
                       channel_multipliers=None,
                       n_resnet_blocks=2) -> Encoder:
    """Construct a new ``Encoder`` and move it to the resolved device.

    Args:
        device: Target device; resolved via ``check_device`` (``None`` lets it choose).
        z_channels: Latent channel count forwarded to ``Encoder``.
        in_channels: Input image channel count forwarded to ``Encoder``.
        channels: Base channel count forwarded to ``Encoder``.
        channel_multipliers: Per-stage channel multipliers forwarded to ``Encoder``.
            Defaults to ``[1, 2, 4, 4]``. A fresh list is built on every call,
            so the default cannot be shared or mutated across calls.
        n_resnet_blocks: ResNet block count forwarded to ``Encoder``.

    Returns:
        The constructed ``Encoder`` on the resolved device.
    """
    if channel_multipliers is None:
        channel_multipliers = [1, 2, 4, 4]

    device = check_device(device)
    # Timed/logged section around construction.
    with section('encoder initialization'):
        encoder = Encoder(z_channels=z_channels,
                          in_channels=in_channels,
                          channels=channels,
                          # Copy so Encoder never aliases the caller's list.
                          channel_multipliers=list(channel_multipliers),
                          n_resnet_blocks=n_resnet_blocks).to(device)
    return encoder

In [4]:
encoder = initialize_encoder()
# Create the checkpoint directory first so a fresh checkout can save.
os.makedirs(os.path.dirname(ENCODER_PATH), exist_ok=True)
encoder.save(ENCODER_PATH)

Using cuda:0: NVIDIA GeForce RTX 3080 Ti. Slow on CPU.
Starting encoder initialization...
Finished encoder initialization in 0.27 seconds


In [5]:
def initialize_decoder(device = None,
                       out_channels=3,
                       z_channels=4,
                       channels=128,
                       channel_multipliers=None,
                       n_resnet_blocks=2) -> Decoder:
    """Construct a new ``Decoder`` and move it to the resolved device.

    Args:
        device: Target device; resolved via ``check_device`` (``None`` lets it choose).
        out_channels: Output image channel count forwarded to ``Decoder``.
        z_channels: Latent channel count forwarded to ``Decoder``.
        channels: Base channel count forwarded to ``Decoder``.
        channel_multipliers: Per-stage channel multipliers forwarded to ``Decoder``.
            Defaults to ``[1, 2, 4, 4]``. A fresh list is built on every call,
            so the default cannot be shared or mutated across calls.
        n_resnet_blocks: ResNet block count forwarded to ``Decoder``.

    Returns:
        The constructed ``Decoder`` on the resolved device.
    """
    if channel_multipliers is None:
        channel_multipliers = [1, 2, 4, 4]

    device = check_device(device)
    # Timed/logged section around construction.
    with section('decoder initialization'):
        decoder = Decoder(out_channels=out_channels,
                          z_channels=z_channels,
                          channels=channels,
                          # Copy so Decoder never aliases the caller's list.
                          channel_multipliers=list(channel_multipliers),
                          n_resnet_blocks=n_resnet_blocks).to(device)
    return decoder

In [6]:
decoder = initialize_decoder()
# Create the checkpoint directory first so a fresh checkout can save.
os.makedirs(os.path.dirname(DECODER_PATH), exist_ok=True)
decoder.save(DECODER_PATH)

Using cuda:0: NVIDIA GeForce RTX 3080 Ti. Slow on CPU.
Starting decoder initialization...
Finished decoder initialization in 0.28 seconds


In [7]:
def initialize_autoencoder(device = None, encoder = None, decoder = None, emb_channels = 4, z_channels = 4) -> Autoencoder:
    """Assemble an ``Autoencoder`` from an encoder and a decoder, building whichever is missing.

    Args:
        device: Target device; resolved via ``check_device``.
        encoder: Existing encoder to reuse. When ``None``, one is built with
            ``initialize_encoder`` on the resolved device.
        decoder: Existing decoder to reuse. When ``None``, one is built with
            ``initialize_decoder`` on the resolved device.
        emb_channels: Embedding channel count forwarded to ``Autoencoder``.
        z_channels: Latent channel count, used both for ``Autoencoder`` and for
            any encoder/decoder built here.

    Returns:
        The assembled ``Autoencoder`` moved to the resolved device.
    """
    target = check_device(device)
    with section('autoencoder initialization'):
        # Reuse the caller's submodels; build only what was not supplied.
        enc = encoder if encoder is not None else initialize_encoder(device=target, z_channels=z_channels)
        dec = decoder if decoder is not None else initialize_decoder(device=target, z_channels=z_channels)

        assembled = Autoencoder(
            emb_channels=emb_channels,
            encoder=enc,
            decoder=dec,
            z_channels=z_channels,
        ).to(target)
    return assembled

In [8]:
# Reuse the encoder/decoder constructed above rather than building new ones.
autoencoder = initialize_autoencoder(decoder=decoder, encoder=encoder)

Using cuda:0: NVIDIA GeForce RTX 3080 Ti. Slow on CPU.
Starting autoencoder initialization...
Finished autoencoder initialization in 0.00 seconds


In [9]:
# Resolve the device the same way the model builders do (check_device) instead
# of hard-coding "cuda:0", which fails on machines without that GPU.
img = load_img("test_img.jpg").to(check_device(None))

In [10]:
# Inference only. Without no_grad the result carries an autograd graph whose
# saved tensors (including the encoder weights) stay alive even after the model
# is deleted.
with torch.no_grad():
    encoded_img = autoencoder.encoder(img)

In [11]:
# Drop the models, collect any reference cycles, then release PyTorch's cached
# CUDA blocks. Note: any tensor still holding an autograd graph (grad_fn) from
# these modules keeps that graph's saved tensors alive regardless.
del autoencoder, decoder, encoder
gc.collect()
torch.cuda.empty_cache()

In [11]:
with section("autoencoder initialization from saved submodels"):
    # Construct without explicit encoder/decoder arguments, then load both
    # submodels from the checkpoints saved earlier in this notebook.
    autoencoder_3 = Autoencoder(z_channels=4, emb_channels=4)
    autoencoder_3.load_submodels(
        decoder_path=DECODER_PATH,
        encoder_path=ENCODER_PATH,
    )

Starting autoencoder initialization from saved submodels...
Finished autoencoder initialization from saved submodels in 0.25 seconds


In [12]:
# Inference only: avoid building an autograd graph that would pin the reloaded
# encoder's weights and activations in memory.
with torch.no_grad():
    encoded_img_3 = autoencoder_3.encoder(img)

In [13]:
# Latent shapes from the original encoder and the encoder reloaded from disk.
(encoded_img.shape, encoded_img_3.shape)

(torch.Size([1, 8, 64, 64]), torch.Size([1, 8, 64, 64]))

In [14]:
# Fail loudly (instead of relying on eyeballing) if the reloaded encoder does not
# reproduce the original latents; assert_close tolerates tiny float noise.
# Then display the difference norm.
torch.testing.assert_close(encoded_img_3, encoded_img)
torch.norm(encoded_img_3 - encoded_img)

tensor(0., device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)