In [1]:
import numpy as np
from PIL import Image
import os
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Resnet(torch.nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip path.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels.

    The 3x3 convolutions use stride 1 and padding 1, so spatial size is kept.
    When ``dim_in != dim_out`` the skip path is a 1x1 convolution projecting the
    input to ``dim_out`` channels; otherwise the input is added unchanged.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # Build both pre-activation stages in order; the resulting Sequential
        # indices (0..5) are what the checkpoint's state_dict keys refer to.
        stages = []
        for in_ch, out_ch in ((dim_in, dim_out), (dim_out, dim_out)):
            stages.extend([
                torch.nn.GroupNorm(num_groups=32,
                                   num_channels=in_ch,
                                   eps=1e-6,
                                   affine=True),
                torch.nn.SiLU(),
                torch.nn.Conv2d(in_ch,
                                out_ch,
                                kernel_size=3,
                                stride=1,
                                padding=1),
            ])
        self.s = torch.nn.Sequential(*stages)

        # 1x1 projection on the skip path, only needed when channel counts differ.
        if dim_in == dim_out:
            self.res = None
        else:
            self.res = torch.nn.Conv2d(dim_in,
                                       dim_out,
                                       kernel_size=1,
                                       stride=1,
                                       padding=0)

    def forward(self, x):
        # e.g. with dim_in=128, dim_out=256: [1, 128, 10, 10] -> [1, 256, 10, 10]
        shortcut = x if self.res is None else self.res(x)
        return shortcut + self.s(x)

In [3]:
class Atten(torch.nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input ``[B, C, H, W]`` is group-normalised, flattened into ``H * W``
    tokens of width ``C``, passed through scaled dot-product attention with
    learned q/k/v/out linear projections, reshaped back to ``[B, C, H, W]`` and
    added to the input (residual connection).

    Args:
        channels: number of channels ``C`` (default 128, as used by ``VAE``).
        num_groups: GroupNorm group count; must divide ``channels``.
        scale: multiplier applied to the q·k logits. The default equals
            1 / sqrt(512), NOT 1 / sqrt(channels) = 1 / sqrt(128).
            NOTE(review): kept unchanged because existing checkpoints were
            presumably trained with this value; changing it changes outputs.
    """

    def __init__(self, channels=128, num_groups=32, scale=0.044194173824159216):
        super().__init__()
        self.scale = scale
        self.norm = torch.nn.GroupNorm(num_channels=channels,
                                       num_groups=num_groups,
                                       eps=1e-6,
                                       affine=True)

        self.q = torch.nn.Linear(channels, channels)
        self.k = torch.nn.Linear(channels, channels)
        self.v = torch.nn.Linear(channels, channels)
        self.out = torch.nn.Linear(channels, channels)

    def forward(self, x):
        # x: [B, C, H, W]; the output has the same shape.
        b, c, h, w = x.shape
        res = x

        # Normalisation keeps the shape: [B, C, H, W]
        x = self.norm(x)

        # [B, C, H, W] -> [B, C, N] -> [B, N, C], with N = H * W
        x = x.flatten(start_dim=2).transpose(1, 2)

        # Linear projections keep the shape: [B, N, C]
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # [B, N, C] -> [B, C, N]
        k = k.transpose(1, 2)

        # [B, N, C] @ [B, C, N] -> [B, N, N], scaled by self.scale.
        # With beta=0 baddbmm ignores the values of its first argument (NaN/inf
        # in it are not propagated), so an uninitialised buffer is enough; it
        # only has to match the output shape and q's dtype/device. baddbmm is
        # kept instead of `q.bmm(k) * scale` because the two differ by rounding.
        n = q.shape[1]
        atten = torch.baddbmm(q.new_empty(b, n, n),
                              q,
                              k,
                              beta=0,
                              alpha=self.scale)

        atten = torch.softmax(atten, dim=2)

        # [B, N, N] @ [B, N, C] -> [B, N, C]
        atten = atten.bmm(v)

        # Linear projection keeps the shape: [B, N, C]
        atten = self.out(atten)

        # [B, N, C] -> [B, C, N] -> [B, C, H, W] (uses the input's real geometry)
        atten = atten.transpose(1, 2).reshape(b, c, h, w)

        # Residual connection, shape unchanged: [B, C, H, W]
        atten = atten + res

        return atten

In [4]:
class Pad(torch.nn.Module):
    """Zero-pad the last two dimensions by one element on the right and bottom only.

    In this notebook it precedes a stride-2, padding-0 3x3 ``Conv2d`` in the
    VAE encoder's downsampling stages (asymmetric padding).
    """

    # (left, right, top, bottom) for the last two dimensions.
    _PADDING = (0, 1, 0, 1)

    def forward(self, x):
        padded = torch.nn.functional.pad(x, self._PADDING, mode='constant', value=0)
        return padded

In [5]:
class VAE(torch.nn.Module):
    """Convolutional VAE: 3-channel image <-> 4-channel latent at 1/8 resolution.

    The encoder ends in 8 channels, interpreted by ``sample`` as 4 mean and
    4 log-variance channels of a diagonal Gaussian. The decoder maps a
    4-channel latent back to 3 channels.

    NOTE: the nesting/order of the Sequential containers below defines the
    state_dict key names (e.g. ``encoder.1.2.1.weight``). Changing it breaks
    loading existing checkpoints.
    """

    def __init__(self):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            # in
            torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),

            # down (Pad + stride-2 conv halves the spatial size)
            torch.nn.Sequential(
                Resnet(32, 32),
                Resnet(32, 32),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(32, 32, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                Resnet(32, 64),
                Resnet(64, 64),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(64, 64, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                Resnet(64, 128),
                Resnet(128, 128),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(128, 128, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                Resnet(128, 128),
                Resnet(128, 128),
            ),

            # mid
            torch.nn.Sequential(
                Resnet(128, 128),
                Atten(),
                Resnet(128, 128),
            ),

            # out
            torch.nn.Sequential(
                torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6),
                torch.nn.SiLU(),
                torch.nn.Conv2d(128, 8, 3, padding=1),
            ),

            # 1x1 conv over the 8 distribution channels (mean / log-variance,
            # split by `sample`)
            torch.nn.Conv2d(8, 8, 1),
        )

        self.decoder = torch.nn.Sequential(
            # 1x1 conv applied to the 4-channel latent before decoding
            torch.nn.Conv2d(4, 4, 1),

            # in
            torch.nn.Conv2d(4, 128, kernel_size=3, stride=1, padding=1),

            # middle
            torch.nn.Sequential(Resnet(128, 128), Atten(), Resnet(128, 128)),

            # up (unlike the encoder, Upsample + conv are NOT nested in an
            # inner Sequential; keep it that way for checkpoint compatibility)
            torch.nn.Sequential(
                Resnet(128, 128),
                Resnet(128, 128),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(128, 128, kernel_size=3, padding=1),
            ),
            torch.nn.Sequential(
                Resnet(128, 128),
                Resnet(128, 128),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(128, 128, kernel_size=3, padding=1),
            ),
            torch.nn.Sequential(
                Resnet(128, 64),
                Resnet(64, 64),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),
            ),
            torch.nn.Sequential(
                Resnet(64, 32),
                Resnet(32, 32),
            ),

            # out
            torch.nn.Sequential(
                torch.nn.GroupNorm(num_channels=32, num_groups=32, eps=1e-6),
                torch.nn.SiLU(),
                torch.nn.Conv2d(32, 3, 3, padding=1),
            ),
        )

    @staticmethod
    def _split_moments(h):
        """Split encoder output into (mean, log-variance): first 4 channels vs. the rest."""
        return h[:, :4], h[:, 4:]

    def sample(self, h):
        """Draw a latent with the reparameterisation trick.

        Args:
            h: encoder output ``[B, 8, H, W]``; channels 0-3 are the mean, the
                remaining channels the log-variance.

        Returns:
            ``(z, mean, logvar)`` where ``z = mean + exp(logvar / 2) * eps`` and
            ``eps ~ N(0, I)``; all of shape ``[B, 4, H, W]``.
        """
        mean, logvar = self._split_moments(h)
        # exp(0.5 * logvar) == sqrt(exp(logvar)), but it does not overflow to
        # inf until logvar is twice as large.
        std = torch.exp(0.5 * logvar)

        # randn_like keeps mean's dtype and device (e.g. fp16/bf16 latents).
        eps = torch.randn_like(mean)
        z = mean + std * eps

        return z, mean, logvar

    def forward(self, x, sample_posterior=True):
        """Encode, take a posterior sample (or the mean), and decode.

        Args:
            x: image batch ``[B, 3, H, W]``. The encoder halves the spatial size
                three times, so the reconstruction matches the input size when
                H and W are multiples of 8 (e.g. 512 -> latent 64).
            sample_posterior: if True (default, original behaviour) decode a
                random posterior sample; if False decode the mean. Note that
                ``eval()`` alone does not disable sampling.

        Returns:
            ``(reconstruction, mean, logvar)``.
        """
        # [B, 3, H, W] -> [B, 8, H/8, W/8]
        h = self.encoder(x)

        # [B, 8, H/8, W/8] -> [B, 4, H/8, W/8]
        if sample_posterior:
            z, mu, log_var = self.sample(h)
        else:
            mu, log_var = self._split_moments(h)
            z = mu

        # [B, 4, H/8, W/8] -> [B, 3, H, W]
        return self.decoder(z), mu, log_var
# Shape smoke test on a random 256x256 image. no_grad avoids building an
# autograd graph through the whole network just to read output shapes.
with torch.no_grad():
    pred, mu, log_var = VAE()(torch.randn(1, 3, 256, 256))
assert pred.shape == (1, 3, 256, 256), "reconstruction should match the input size"
assert mu.shape == log_var.shape
pred.shape, mu.shape, log_var.shape

(torch.Size([1, 3, 256, 256]),
 torch.Size([1, 4, 32, 32]),
 torch.Size([1, 4, 32, 32]))

In [6]:
# Load model weights, adapting DataParallel (multi-GPU) checkpoints for single-GPU/CPU use.
def load_model(model, model_path, map_location="cpu", strict=True):
    """Load a saved state_dict into ``model``, stripping a leading ``module.`` prefix.

    Checkpoints saved from a ``DataParallel``/``DistributedDataParallel``
    wrapper prefix every key with ``"module."``. Only that leading prefix is
    removed; keys that merely contain the substring (e.g. ``submodule.weight``)
    are left intact.

    Args:
        model: the bare (unwrapped) module to load the weights into.
        model_path: path to a file produced by ``torch.save(state_dict, ...)``.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved on GPU also loads on a CPU-only machine;
            ``load_state_dict`` then copies the values into ``model``'s own
            parameters on whatever device they live on.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The value returned by ``model.load_state_dict`` (missing/unexpected keys).
    """
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints
    # (newer PyTorch versions support weights_only=True to restrict this).
    state_dict = torch.load(model_path, map_location=map_location)

    prefix = "module."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    return model.load_state_dict(new_state_dict, strict=strict)

# Load the trained model for inference. The checkpoint path can be overridden
# with the VAE_MODEL_PATH environment variable instead of editing the notebook.
model_path = os.environ.get("VAE_MODEL_PATH", "/data/run01/scz0ruj/model/best_model500.pth")
vae = VAE()
load_model(vae, model_path)
vae.eval()

VAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Sequential(
      (0): Resnet(
        (s): Sequential(
          (0): GroupNorm(32, 32, eps=1e-06, affine=True)
          (1): SiLU()
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): GroupNorm(32, 32, eps=1e-06, affine=True)
          (4): SiLU()
          (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (1): Resnet(
        (s): Sequential(
          (0): GroupNorm(32, 32, eps=1e-06, affine=True)
          (1): SiLU()
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): GroupNorm(32, 32, eps=1e-06, affine=True)
          (4): SiLU()
          (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2): Sequential(
        (0): Pad()
        (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
      