# Model structure

<img src='https://drive.google.com/uc?id=1xq5YGrLMtNO4lxMw7sWa2VoQXgDpx-y0' width="800"/>


In [None]:
class CARE_Net(nn.Module):
    """U-Net style encoder/decoder conditioned on a sinusoidal embedding of ``t``.

    Encoder: a 3x3 stem conv + ``Double_Convnext`` block, two ``Down`` stages
    and a strided 2x2 conv into an ``8 * dim``-channel bottleneck made of three
    (ResBlock, ResBlock, Attention) groups. Decoder: three ``Up`` stages that
    receive the matching encoder feature maps as skip connections, followed by
    a 1x1 output conv. ``LinearAttention`` follows every encoder/decoder stage.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        embed_dim: size of the sinusoidal embedding of ``t``; must be even.
        dim: base channel width; stages use dim, 2*dim, 4*dim and 8*dim.
        device: kept for backward compatibility and stored as ``self.device``;
            the time embedding is built on the device of ``t`` itself.
    """

    def __init__(self, c_in=2, c_out=1, embed_dim=256, dim = 64, device="cuda"):
        super().__init__()

        self.device = device
        self.embed_dim = embed_dim

        # Stem: project c_in -> dim channels, then a time-conditioned block.
        self.inc = nn.Sequential(nn.Conv2d(c_in, dim, kernel_size=3, padding=1, bias=False),
                                LayerNorm2d(dim, elementwise_affine=False, eps=1e-6))
        self.inc_block1 = Double_Convnext(dim,embed_dim)
        self.sa0 = LinearAttention(dim)

        # Encoder stages: dim -> 2*dim -> 4*dim channels.
        self.down1 = Down(dim, dim*2, embed_dim)
        self.sa1 = LinearAttention(dim*2)

        self.down2 = Down(dim*2, dim*4, embed_dim)
        self.sa2 = LinearAttention(dim*4)

        # 2x2 stride-2 conv halves H and W and widens to 8*dim channels.
        self.down3 = nn.Sequential(
                    LayerNorm2d(dim*4, elementwise_affine=False, eps=1e-6),
                    nn.Conv2d(dim*4, dim*8, kernel_size=2, stride=2),)

        # Bottleneck: three groups of two unconditioned ResBlocks + attention.
        self.bot1_1 = ResBlock(dim*8)
        self.bot1_2 = ResBlock(dim*8)
        self.bot_sa1 = Attention(dim*8)

        self.bot2_1 = ResBlock(dim*8)
        self.bot2_2 = ResBlock(dim*8)
        self.bot_sa2 = Attention(dim*8)

        self.bot3_1 = ResBlock(dim*8)
        self.bot3_2 = ResBlock(dim*8)
        self.bot_sa3 = Attention(dim*8)

        # Decoder stages; each consumes the matching encoder output as a skip
        # (x3, x2, x1 in forward).
        self.up1 = Up(dim*8, dim*4, dim*4, embed_dim)
        self.sa4 = LinearAttention(dim*4)

        self.up2 = Up(dim*4, dim*2, dim*2, embed_dim)
        self.sa5 = LinearAttention(dim*2)

        self.up3 = Up(dim*2, dim, dim, embed_dim)
        self.sa6 = LinearAttention(dim)

        self.outc = nn.Conv2d(dim, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Return the sinusoidal embedding of ``t``.

        Args:
            t: float tensor; ``forward`` passes shape (B, 1).
            channels: embedding width; must be even.

        Returns:
            Tensor whose last dimension has ``channels`` entries: the first
            half are sines, the second half cosines, at frequencies
            ``10000 ** (-2i / channels)``.

        Raises:
            ValueError: if ``channels`` is odd (the sin/cos halves could not
                have matching widths).
        """
        if channels % 2 != 0:
            raise ValueError(f"channels must be even, got {channels}")
        # Build the frequencies on t's device rather than on self.device, so the
        # module works wherever it (and its inputs) actually live.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x,  t):
        """Run the network.

        Args:
            x: input tensor, presumably (B, c_in, H, W). H and W presumably need
                to be divisible by 8 so the three downsamplings line up with
                the skip connections. TODO: confirm against Down/Up.
            t: conditioning values, presumably (B,). They are unsqueezed on
                the last axis, cast to float and sinusoidally embedded to
                ``embed_dim`` features.

        Returns:
            Output of the final 1x1 conv with ``c_out`` channels.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.embed_dim)

        # Encoder; x1, x2, x3 are kept as skip connections for the decoder.
        x1 = self.inc(x)
        x1 = self.inc_block1(x1,t)
        x1 = self.sa0(x1)

        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)

        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)

        # Bottleneck (not conditioned on t).
        x4 = self.down3(x3)
        x4 = self.bot1_1(x4)
        x4 = self.bot1_2(x4)
        x4 = self.bot_sa1(x4)

        x4 = self.bot2_1(x4)
        x4 = self.bot2_2(x4)
        x4 = self.bot_sa2(x4)

        x4 = self.bot3_1(x4)
        x4 = self.bot3_2(x4)
        x4 = self.bot_sa3(x4)

        # Decoder, consuming skips in reverse order.
        x = self.up1(x4, t, x3)
        x = self.sa4(x)

        x = self.up2(x, t, x2)
        x = self.sa5(x)

        x = self.up3(x, t,  x1)
        x = self.sa6(x)

        output = self.outc(x)
        return output

# ConvNeXt block (`ResBlock`)

In [None]:
class ResBlock(nn.Module):
    """ConvNeXt-style residual block, optionally conditioned on a vector ``t``.

    Pipeline: depthwise conv (over ``x``, optionally concatenated with
    ``x_skip``) -> add the projected embedding of ``t`` -> LayerNorm2d ->
    channels-last MLP (Linear, GELU, GlobalResponseNorm, Dropout, Linear) ->
    add the block input as a residual.

    Args:
        c: channels of ``x`` and of the output.
        c_emb: width of the conditioning vector ``t``. If None, no projection
            is built and ``forward`` must be called without ``t``.
        c_skip: channels of the optional ``x_skip`` input. NOTE: the depthwise
            conv uses ``groups=c``, so ``c + c_skip`` must be divisible by ``c``.
        kernel_size: depthwise kernel size; padding ``kernel_size // 2``
            preserves H and W for odd sizes.
        dropout: dropout probability inside the channel MLP.
    """

    def __init__(self, c, c_emb = None, c_skip=0, kernel_size=7, dropout=0.0):
        super().__init__()
        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        # Inverted-bottleneck MLP applied per pixel on channels-last tensors.
        self.channelwise = nn.Sequential(
            nn.Linear(c, c * 4),
            nn.GELU(),
            GlobalResponseNorm(c * 4),
            nn.Dropout(dropout),
            nn.Linear(c * 4, c)
        )
        # Projects t (c_emb features) to one additive bias per channel.
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(c_emb, c)
        ) if exists(c_emb) else None

    def forward(self, x, t=None, x_skip=None):
        """Apply the block.

        Args:
            x: feature map with ``c`` channels on dim 1.
            t: optional conditioning tensor, presumably (B, c_emb); only valid
                when the block was built with ``c_emb``.
            x_skip: optional tensor concatenated to ``x`` on dim 1 before the
                depthwise conv.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``t`` is given but the block has no embedding
                projection (``c_emb`` was None).
        """
        x_res = x
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.depthwise(x)
        if t is not None:
            if self.mlp is None:
                raise ValueError(
                    "ResBlock was built with c_emb=None, so it cannot take a conditioning tensor t"
                )
            # Broadcasting the (B, c, 1, 1) bias over H and W gives the same
            # result as repeating it, without allocating a (B, c, H, W) copy.
            x = x + self.mlp(t)[:, :, None, None]
        x = self.norm(x).permute(0, 2, 3, 1)
        x = self.channelwise(x).permute(0, 3, 1, 2)
        return x + x_res
