In [1]:
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from typing import Optional, Any
from inspect import isfunction

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False
    print("No module 'xformers'. Proceeding without it.")


# Test

In [4]:
def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)
        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # timestep embedding
        print("Down is:")
        #print(self.down)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        print(f"After down sampling, hs is {len(hs)} {hs[0].shape}")
        for i_level in range(self.num_resolutions):
            print(f"level {i_level+1}/{self.num_resolutions}")
            for i_block in range(self.num_res_blocks):
                print(f"block {i_block+1}/{self.num_res_blocks}")
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        print(f"Now computing middle block1->attn1->block2")
        # middle
        print(f"h shape is {len(h)} {h[0].shape}")
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        print(f"h shape is {len(h)} {h[0].shape}")
        print(f"Now computing end norn_out->attn1->conv_out")
        # end
        print(f"h shape is {len(h)} {h[0].shape}")
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        print(f"h shape is {len(h)} {h[0].shape}")
        return h

In [5]:
from sdvae import MemoryEfficientAttnBlock,Normalize,Downsample,make_attn,ResnetBlock

encoder = Encoder(
    ch=2,  # Base channel dimensionality
    out_ch=None,  # This parameter might be unused based on your class definition
    ch_mult=(1, 2, 4, 8),  # Channel multipliers for each level
    num_res_blocks=2,  # Number of residual blocks per level
    attn_resolutions=(20, 40),  # Resolutions to apply attention, adjust based on your needs
    dropout=0.1,  # Dropout rate for regularization
    resamp_with_conv=True,  # Use convolutional downsampling
    in_channels=1,  # Match the input channel size
    resolution=640,  # Input resolution (for understanding, not a direct parameter)
    z_channels=128,  # Half the desired latent space dimensionality if double_z=True
    double_z=True,  # Output both mean and variance for the latent distribution
    use_linear_attn=False,  # Based on preference and model complexity
    attn_type="vanilla"  # Type of attention, if used
)
x = torch.randn(10, 1, 640, 640)
encoder(x)

making attention of type 'vanilla-xformers' with 16 in_channels
building MemoryEfficientAttnBlock with 16 in_channels...
Down is:
After down sampling, hs is 1 torch.Size([10, 2, 640, 640])
level 1/4
block 1/2
Now in a ResNet. We do norm1->nonlinear->conv1->norm2->nonlinear->dropout->conv2->x+h
Input shape is torch.Size([10, 2, 640, 640])
Output shape is torch.Size([10, 2, 640, 640])
block 2/2
Now in a ResNet. We do norm1->nonlinear->conv1->norm2->nonlinear->dropout->conv2->x+h
Input shape is torch.Size([10, 2, 640, 640])
Output shape is torch.Size([10, 2, 640, 640])
level 2/4
block 1/2
Now in a ResNet. We do norm1->nonlinear->conv1->norm2->nonlinear->dropout->conv2->x+h
Input shape is torch.Size([10, 2, 320, 320])
Output shape is torch.Size([10, 4, 320, 320])
block 2/2
Now in a ResNet. We do norm1->nonlinear->conv1->norm2->nonlinear->dropout->conv2->x+h
Input shape is torch.Size([10, 4, 320, 320])
Output shape is torch.Size([10, 4, 320, 320])
level 3/4
block 1/2
Now in a ResNet. We do 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (12800x80 and 16x512)