In [3]:
#coding=utf8
import numpy as np
import torch
import PIL
from PIL import Image


#载入图片
def load_image(path):
    """Load an image file as a normalised float tensor for the VAE.

    Args:
        path: Path to an image file readable by PIL; any mode is converted to RGB.

    Returns:
        torch.float32 tensor of shape (1, 3, H, W) with values in [-1, 1]
        (e.g. (1, 3, 512, 512) for a 512x512 image).
    """
    # The context manager guarantees the file handle is closed (PIL opens
    # lazily and may keep the file open); convert() returns an independent,
    # fully-loaded image, so it stays valid after the `with` block exits.
    with Image.open(path) as pil_image:
        rgb = pil_image.convert("RGB")
    pixels = np.array(rgb).astype(np.float32) / 255.0   # (H, W, 3) in [0, 1]
    pixels = pixels[None].transpose(0, 3, 1, 2)         # (1, 3, H, W)
    # Rescale [0, 1] -> [-1, 1].
    return 2. * torch.from_numpy(pixels) - 1.

#保存图片
def save_image(samples, path):
    """Write the first image of a batch of [-1, 1] tensors to ``path``.

    Args:
        samples: Tensor of shape (B, 3, H, W) with values nominally in [-1, 1];
            may live on any device and may require grad.
        path: Output file path; the format is chosen by PIL from the extension.

    Only ``samples[0]`` is saved.
    """
    # Map [-1, 1] -> [0, 255]; clamp guards against values slightly out of range.
    samples = 255 * (samples / 2 + 0.5).clamp(0, 1)    # (B, 3, H, W)
    # detach() drops autograd history; cpu() is required before .numpy()
    # when the tensor lives on a GPU.
    samples = samples.detach().cpu().numpy()
    samples = samples.transpose(0, 2, 3, 1)           # (B, H, W, 3)
    # Round before casting: astype(uint8) truncates, so float error such as
    # 127.99999 would otherwise become 127.
    image = np.rint(samples[0]).astype(np.uint8)      # (H, W, 3)
    Image.fromarray(image).save(path)

def test_load_and_save_img():
    """Round-trip sanity check: read ``girl.jpg`` and write it back out as ``girl2.jpg``."""
    save_image(load_image("girl.jpg"), "girl2.jpg")

# Run the load/save round trip: girl.jpg -> tensor -> girl2.jpg.
test_load_and_save_img()

### VAE模型
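下列代码展示了VAE模型的使用方法：`load_vae` 先根据配置 `init_config` 初始化模型，再从预训练模型 `model.ckpt` 中读取参数；预训练模型中的 `first_stage_model` 即指代VAE模型。运行前需确保 `ldm` 包可以被导入（它来自 CompVis 的 latent-diffusion / stable-diffusion 代码库，例如可将该仓库根目录加入 `sys.path`），否则会报 `ModuleNotFoundError: No module named 'ldm'`。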

In [4]:
from ldm.models.autoencoder import AutoencoderKL
#VAE模型
def load_vae(ckpt_path="model.ckpt"):
    """Build the AutoencoderKL (VAE) and load its pretrained weights.

    The model is constructed from a fixed config, then every parameter/buffer
    is filled from the checkpoint's ``state_dict`` entry of the same name with
    the ``first_stage_model.`` prefix (the VAE part of the full checkpoint).

    Args:
        ckpt_path: Path to the checkpoint file. Defaults to ``"model.ckpt"``,
            the path previously hard-coded here.

    Returns:
        The AutoencoderKL, in eval mode, with weights on CPU.

    Raises:
        KeyError: if the checkpoint lacks a ``first_stage_model.*`` entry for
            one of the model's state-dict keys.
    """
    # Initialise the model.
    init_config = {
        "embed_dim": 4,
        "monitor": "val/rec_loss",
        "ddconfig": {
            "double_z": True,
            "z_channels": 4,
            "resolution": 256,
            "in_channels": 3,
            "out_ch": 3,
            "ch": 128,
            "ch_mult": [1, 2, 4, 4],
            "num_res_blocks": 2,
            "attn_resolutions": [],
            "dropout": 0.0,
        },
        "lossconfig": {
            "target": "torch.nn.Identity"
        }
    }
    vae = AutoencoderKL(**init_config)

    # Load the pretrained parameters.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. Newer torch versions may default to weights_only=True, which
    # can reject full training checkpoints; confirm against the installed version.
    pl_sd = torch.load(ckpt_path, map_location="cpu")
    sd = pl_sd["state_dict"]
    prefix = "first_stage_model."
    # Pick exactly the VAE's keys out of the full checkpoint; a missing key
    # raises KeyError. Because the dict covers every model key, a strict load
    # succeeds and also rejects any unexpected key.
    vae_sd = {key: sd[prefix + key] for key in vae.state_dict()}
    vae.load_state_dict(vae_sd)

    vae.eval()
    return vae

#测试vae模型
def test_vae():
    """Reconstruct ``girl_and_horse.png`` through the VAE and save it as ``vae.png``.

    The image is encoded, a latent is drawn with ``.sample()`` on the encoder
    output (presumably the diagonal-Gaussian posterior, so the result is
    stochastic unless a seed is set), decoded back, and written to disk.
    Encoding/decoding runs under ``torch.no_grad()``: this is pure inference,
    and with autograd enabled every intermediate activation would be kept alive.
    """
    vae = load_vae()
    img = load_image("girl_and_horse.png")      # (1, 3, 512, 512) for a 512x512 image
    with torch.no_grad():
        latent = vae.encode(img).sample()       # (1, 4, 64, 64)
        samples = vae.decode(latent)            # (1, 3, 512, 512)
    save_image(samples, "vae.png")

# Encode/decode girl_and_horse.png through the VAE and write the reconstruction to vae.png.
test_vae()

ModuleNotFoundError: No module named 'ldm'