from https://arxiv.org/pdf/2410.10733

DEEP COMPRESSION AUTOENCODER FOR EFFICIENT HIGH-RESOLUTION DIFFUSION MODELS

Junyu Chen$^{1,2,*}$, Han Cai$^{3,*,\dagger}$, Junsong Chen$^{3}$, Enze Xie$^{3}$, Shang Yang$^{1}$, Haotian Tang$^{1}$, Muyang Li$^{1}$, Yao Lu$^{3}$, Song Han$^{1,3}$

$^{1}$MIT, $^{2}$Tsinghua University, $^{3}$NVIDIA

https://github.com/mit-han-lab/efficientvit

In [None]:
from init_notebook import *

In [None]:
class ResampleBlock(nn.Module):
    """
    2x spatial resampling block: a strided convolution plus a parameter-free shortcut.

    The output is ``conv(x) + shortcut(x)`` in both directions.

    Down mode (``down=True``):
        input ``(B, num_channels, H, W)`` -> output ``(B, num_channels * 2, H / 2, W / 2)``.
        Shortcut: ``pixel_unshuffle`` by 2 (giving ``4 * num_channels`` channels),
        then the first and second half of the channels are averaged.

    Up mode (``down=False``):
        input ``(B, num_channels * 2, H, W)`` -> output ``(B, num_channels, H * 2, W * 2)``.
        Shortcut: ``pixel_shuffle`` by 2 (giving ``num_channels / 2`` channels),
        then the result is duplicated along the channel axis.

    An up block therefore inverts the shape change of a down block built with
    the same ``num_channels``.

    Args:
        num_channels: Input channels in down mode, output channels in up mode.
        kernel_size: Odd kernel size of the (transposed) convolution.
        down: ``True`` to halve the resolution, ``False`` to double it.
    """

    def __init__(
            self,
            num_channels: int,
            kernel_size: int = 3,
            down: bool = True,
    ):
        assert kernel_size % 2 == 1, f"Must have odd `kernel_size`, got {kernel_size}"
        super().__init__()
        self._down = down
        padding = kernel_size // 2
        if self._down:
            self.conv = nn.Conv2d(
                num_channels, num_channels * 2, kernel_size,
                padding=padding, stride=2,
            )
        else:
            # Mirror of the down conv. With stride 2 and padding k // 2 a
            # transposed conv yields 2H - 1 pixels; `output_padding=1` makes it
            # exactly 2H so it can be added to the pixel-shuffle shortcut.
            self.conv = nn.ConvTranspose2d(
                num_channels * 2, num_channels, kernel_size,
                padding=padding, stride=2, output_padding=1,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample `x` (a 4-D ``(B, C, H, W)`` tensor) by a factor of 2."""
        assert x.ndim == 4, x.shape
        C = x.shape[1]
        assert C % 2 == 0, x.shape

        if self._down:
            # (B, C, H, W) -> (B, 4C, H/2, W/2) -> average both halves -> (B, 2C, H/2, W/2)
            r = F.pixel_unshuffle(x, 2)
            r = (r[:, :C * 2] + r[:, C * 2:]) / 2.
        else:
            # (B, C, H, W) -> (B, C/4, 2H, 2W) -> duplicate channels -> (B, C/2, 2H, 2W)
            r = F.pixel_shuffle(x, 2)
            r = torch.concat([r, r], dim=-3)

        # Learned branch plus shortcut, in both directions.
        return self.conv(x) + r


# Shape/parameter smoke test: three stacked down blocks, then one up block.
ch = 4
ks = 5
b1, b2, b3 = ResampleBlock(ch, ks), ResampleBlock(ch * 2, ks), ResampleBlock(ch * 4, ks)
print(f"params: {num_module_parameters(b1):,} -> {num_module_parameters(b2):,} -> {num_module_parameters(b3):,}")
y = b1(torch.ones((1, ch, 16, 16)))
print(y.shape)
y = b2(y)
print(y.shape)
y = b3(y)
print(y.shape)

print("--up--")
# `y` has ch * 8 channels here, so the up block that undoes `b3` is built with
# num_channels = ch * 4 (it consumes 2 * num_channels and emits num_channels).
bu = ResampleBlock(ch * 4, down=False)
y = bu(y)
print(y.shape)


In [None]:
class ResAE(nn.Module):
    """
    Convolutional autoencoder assembled from `ResampleBlock` stages.

    Encoder: a 1x1 conv lifts `input_channels` to `num_channels`, followed by
    `num_layers` stages. Each stage is a down-sampling `ResampleBlock` and,
    while channel doubling is switched off (the current setting), a 1x1 conv
    that maps the doubled channel count back.

    Decoder: the mirrored stages in reverse order (1x1 transposed conv, then an
    up-sampling `ResampleBlock`), followed by a 1x1 conv back to
    `input_channels`.

    Args:
        input_channels: Channels of the data fed to `encode` / returned by `decode`.
        num_channels: Internal channel width.
        kernel_size: Kernel size passed to every `ResampleBlock`.
        num_layers: Number of down (and up) stages.
        verbose: If true, print tensor shapes and block descriptions while running.
    """

    def __init__(
            self,
            input_channels: int,
            num_channels: int,
            kernel_size: int = 3,
            num_layers: int = 4,
            verbose: bool = False,
    ):
        super().__init__()
        self.verbose = verbose
        self.input = nn.Conv2d(input_channels, num_channels, 1)
        self.down_blocks = nn.ModuleList()
        self.up_blocks = nn.ModuleList()

        width = num_channels
        for layer_idx in range(num_layers):
            # Channel doubling is disabled; an alternative schedule that was
            # tried is `layer_idx % 2 == 0`.
            double_channels = False

            self.down_blocks.append(ResampleBlock(width, kernel_size, down=True))
            if double_channels:
                # Keep the doubled width for the following stages.
                self.up_blocks.insert(0, ResampleBlock(width, kernel_size, down=False))
                width *= 2
            else:
                # Squeeze back to `width` after down-sampling; the decoder
                # mirrors this with a 1x1 transposed conv placed in front of
                # its up-sampling block (inserting at index 0 reverses order).
                self.down_blocks.append(nn.Conv2d(width * 2, width, 1))
                self.up_blocks.insert(0, ResampleBlock(width, kernel_size, down=False))
                self.up_blocks.insert(0, nn.ConvTranspose2d(width, width * 2, 1))

        self.output = nn.Conv2d(num_channels, input_channels, 1)

    def _apply_blocks(self, blocks: nn.ModuleList, y: torch.Tensor) -> torch.Tensor:
        """Run `y` through `blocks` in order, logging shapes when verbose."""
        for block in blocks:
            if self.verbose:
                description = " ".join(str(block).split("\n"))
                print(f"{y.shape} -> {description}")
            y = block(y)
        if self.verbose:
            print(y.shape)
        return y

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map an input tensor to its latent code."""
        return self._apply_blocks(self.down_blocks, self.input(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map a latent code back to a tensor with `input_channels` channels."""
        return self.output(self._apply_blocks(self.up_blocks, x))

    def forward(self, x):
        """Encode then decode `x`."""
        latent = self.encode(x)
        return self.decode(latent)

# Smoke test for ResAE: build a small model, report its parameter count and
# run one encode/decode round trip on a constant image.
ch = 3
ae = ResAE(ch, 16, num_layers=4)
#display(ae)
print(f"params: {num_module_parameters(ae):,}")

# Single 3-channel 32x32 input filled with ones.
x = torch.ones(1, ch, 32, 32)
c = ae.encode(x)
print("decode")
y = ae.decode(c)
print("result")
# Compression ratio = number of input elements / number of latent elements.
print(f"{x.shape} -> {c.shape} -> {y.shape}, compression ratio: {math.prod(x.shape) / math.prod(c.shape)}")

# Show the reconstruction: first sample, first three channels, clamped to [0, 1].
# NOTE(review): `VF` comes from `init_notebook`; presumably
# torchvision.transforms.functional — confirm.
display(VF.to_pil_image(y[0, :3].clamp(0, 1)))

# Print the module structure.
display(ae)

In [None]:
# Element counts of a 3x32x32 tensor versus a 512x2x2 tensor (3072 vs 2048).
# NOTE(review): presumably image size vs. a candidate latent size; 512x2x2 is
# not the latent shape produced by the ResAE cell above — confirm intent.
3*32*32, 512*2*2

In [None]:
# Compare exact GELU with its tanh approximation on 100 points in [-5, 5].
df = pd.DataFrame()
df["x"] = np.linspace(-5, 5, 100)
df["gelu"] = nn.GELU()(torch.tensor(df["x"])).numpy()
df["gelu-tanh"] = nn.GELU(approximate="tanh")(torch.tensor(df["x"])).numpy()
# Pointwise difference between the two variants.
df["diff"] = df.gelu - df["gelu-tanh"]
# Plot all three curves against x.
# NOTE(review): `px` comes from `init_notebook`; presumably plotly.express — confirm.
px.line(df.set_index("x"), height=1000)