In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [29]:
class MyBlock(nn.Module):
    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super(MyBlock, self).__init__()
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
        self.activation = nn.SiLU() if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        out = self.ln(x) if self.normalize else x
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.activation(out)
        return out

def sinusoidal_embedding(n, d):
    # Returns the standard positional embedding
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    embedding[:,::2] = torch.sin(t * wk[:,::2])
    embedding[:,1::2] = torch.cos(t * wk[:,::2])

    return embedding

def _make_te(self, dim_in, dim_out):
    return nn.Sequential(
    nn.Linear(dim_in, dim_out),
    nn.SiLU(),
    nn.Linear(dim_out, dim_out)
    )

In [52]:
class UNet(nn.Module):
    def __init__(self, input_channels=1, shape=28, dims=[10, 20, 40, 80], n_steps=1000, time_emb_dim=100):
        super(UNet, self).__init__()
        self.output_channels = input_channels
        self.dims = dims

        # Sinusoidal embedding
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # Define layers for each block
        self.te = nn.ModuleList()
        self.blocks = nn.ModuleList()
        self.downs = nn.ModuleList()
        dims = [input_channels]+dims
        # First half
        for i in range(len(dims) - 1):
            self.te.append(self._make_te(time_emb_dim, dims[i]))
            self.blocks.append(nn.Sequential(
                MyBlock((dims[i], shape // (2 ** i), shape // (2 ** i)), dims[i], dims[i + 1]),
                MyBlock((dims[i + 1], shape // (2 ** i), shape // (2 ** i)), dims[i + 1], dims[i + 1]),
                MyBlock((dims[i + 1], shape // (2 ** i), shape // (2 ** i)), dims[i + 1], dims[i + 1])
            ))
            self.downs.append(nn.Conv2d(dims[i + 1], dims[i + 1], 4, 2, 1))

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, dims[-1])
        self.b_mid = nn.Sequential(
            MyBlock((dims[-1], shape // (2 ** (len(dims) - 1)), shape // (2 ** (len(dims) - 1))), dims[-1], dims[-2]),
            MyBlock((dims[-2], shape // (2 ** (len(dims) - 1)), shape // (2 ** (len(dims) - 1))), dims[-2], dims[-2]),
            MyBlock((dims[-2], shape // (2 ** (len(dims) - 1)), shape // (2 ** (len(dims) - 1))), dims[-2], dims[-1])
        )

        # Second half
        self.ups = nn.ModuleList()
        self.te_out = nn.ModuleList()
        self.blocks_out = nn.ModuleList()

        for i in range(len(dims) - 1, 0, -1):
            self.ups.append(nn.ConvTranspose2d(dims[i], dims[i], 4, 2, 1))
            self.te_out.append(self._make_te(time_emb_dim, dims[i - 1]))
            self.blocks_out.append(nn.Sequential(
                MyBlock((dims[i], shape // (2 ** (len(dims) - i)), shape // (2 ** (len(dims) - i))), dims[i], dims[i - 1]),
                MyBlock((dims[i - 1], shape // (2 ** (len(dims) - i)), shape // (2 ** (len(dims) - i))), dims[i - 1], dims[i - 1]),
                MyBlock((dims[i - 1], shape // (2 ** (len(dims) - i)), shape // (2 ** (len(dims) - i))), dims[i - 1], dims[i - 1], normalize=False)
            ))

        self.conv_out = nn.Conv2d(dims[0], self.output_channels, 3, 1, 1)

    def forward(self, x, t):
        # x is (N, 2, 28, 28) (image with positional embedding stacked on channel dimension)
        t = self.time_embed(t)
        n = len(x)
        out = x
        down_outputs = []

        # First half
        for i in range(len(self.blocks)):
            if i == 0 :
                 out = self.blocks[i](out + self.te[i](t).reshape(n, -1, 1, 1))
            else:
                print(f"before {i}", out.shape)
                out = self.blocks[i](self.downs[i-1](out) + self.te[i](t).reshape(n, -1, 1, 1))
                print(f"after {i}", out.shape)
            down_outputs.append(out)

        # Bottleneck
        out_mid = self.b_mid(self.downs[-1](down_outputs[-1]) + self.te_mid(t).reshape(n, -1, 1, 1))
        print(f"out_mid {out_mid.shape}")
        # Second half
        for i in range(len(self.blocks_out)):
            out = torch.cat((down_outputs[-(i + 1)], self.ups[i](out_mid)), dim=1)
            out = self.blocks_out[i](out + self.te_out[i](t).reshape(n, -1, 1, 1))

        out = self.conv_out(out)

        # Store feature maps during forward pass
        self.feature_maps = {
            'out_mid': out_mid,
            'out': out
        }

        return out

    def feature_extract(self, feature_maps, original_size):
        # Upsample feature maps to the original input size
        original_size = original_size[1:]
        upsampled_out_mid = F.interpolate(feature_maps['out_mid'], size=original_size, mode='bilinear', align_corners=False)
        upsampled_out = F.interpolate(feature_maps['out'], size=original_size, mode='bilinear', align_corners=False)

        # Concatenate the upsampled feature maps along the channel dimension
        concatenated_feature_maps = torch.cat([upsampled_out_mid, upsampled_out], dim=1)

        return concatenated_feature_maps

    def _make_te(self, dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out))

In [53]:
unet = UNet(input_channels=4, shape=100, dims = [10,20,40])
print(unet)

UNet(
  (time_embed): Embedding(1000, 100)
  (te): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=100, out_features=4, bias=True)
      (1): SiLU()
      (2): Linear(in_features=4, out_features=4, bias=True)
    )
    (1): Sequential(
      (0): Linear(in_features=100, out_features=10, bias=True)
      (1): SiLU()
      (2): Linear(in_features=10, out_features=10, bias=True)
    )
    (2): Sequential(
      (0): Linear(in_features=100, out_features=20, bias=True)
      (1): SiLU()
      (2): Linear(in_features=20, out_features=20, bias=True)
    )
  )
  (blocks): ModuleList(
    (0): Sequential(
      (0): MyBlock(
        (ln): LayerNorm((4, 100, 100), eps=1e-05, elementwise_affine=True)
        (conv1): Conv2d(4, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (activation): SiLU()
      )
      (1): MyBlock(
        (ln): LayerNorm((10, 100, 100), eps=1e-05, elementwise_a

In [54]:
# Set the seed for reproducibility
torch.manual_seed(42)

# Define the input dimensions
batch_size = 2
channels = 4
height = 100
width = 100

# Generate random inputs for testing
input_data = torch.randn(batch_size, channels, height, width)
n_steps = 1000
time_steps = torch.randint(0, n_steps, (batch_size,))
output = unet(input_data, time_steps)

before 1 torch.Size([2, 10, 100, 100])
after 1 torch.Size([2, 20, 50, 50])
before 2 torch.Size([2, 20, 50, 50])
after 2 torch.Size([2, 40, 25, 25])
out_mid torch.Size([2, 40, 12, 12])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 25 but got size 24 for tensor number 1 in the list.