# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Convolutional encoder that downsamples a batch and flattens the result.

    The layer layout is selected by ``opts``:

    * ``'t'`` in ``opts``: five ``Downscale`` stages with widths ``e_ch``,
      ``2 * e_ch``, ``4 * e_ch``, ``8 * e_ch``, ``8 * e_ch`` and a
      ``ResidualBlock`` after the first and after the last stage.
    * otherwise: a single ``DownscaleBlock`` with four stages, built with
      ``out_ch=e_ch``.

    If ``'u'`` is in ``opts`` the flattened features are L2-normalised along
    the last dimension.

    Args:
        in_ch: Channel count of the input tensor.
        e_ch: Base channel width of the encoder.
        opts: Container of option flags, only ever queried with ``in``
            (so a string such as ``'ut'`` or a dict both work). ``None``
            means "no options".
        use_fp16: When true, ``forward`` casts its input to half precision
            and casts the result back to float32.
    """

    def __init__(self, in_ch, e_ch, opts=None, use_fp16=False):
        super().__init__()
        self.in_ch = in_ch
        self.e_ch = e_ch
        self.opts = opts if opts is not None else {}
        self.use_fp16 = use_fp16

        if 't' in self.opts:
            self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
            self.res1 = ResidualBlock(self.e_ch)
            self.down2 = Downscale(self.e_ch, self.e_ch * 2, kernel_size=5)
            self.down3 = Downscale(self.e_ch * 2, self.e_ch * 4, kernel_size=5)
            self.down4 = Downscale(self.e_ch * 4, self.e_ch * 8, kernel_size=5)
            self.down5 = Downscale(self.e_ch * 8, self.e_ch * 8, kernel_size=5)
            self.res5 = ResidualBlock(self.e_ch * 8)
        else:
            # This branch is only reached without 't', so the stage count is
            # always 4 here; it comes from the same helper get_out_res() uses.
            self.down1 = DownscaleBlock(
                self.in_ch,
                self.e_ch,
                n_downscales=self._n_downscales(),
                kernel_size=5,
            )

    def _n_downscales(self):
        """Return how many halving stages the selected layout contains."""
        return 5 if 't' in self.opts else 4

    def forward(self, x):
        """Encode ``x`` and return the features flattened from dim 1 onward.

        NOTE(review): with ``use_fp16`` only the *input* is cast to half here;
        this class never converts its own weights, so the caller presumably
        has to put the module in half precision itself - confirm.
        """
        if self.use_fp16:
            x = x.half()

        if 't' in self.opts:
            x = self.down1(x)
            x = self.res1(x)
            x = self.down2(x)
            x = self.down3(x)
            x = self.down4(x)
            x = self.down5(x)
            x = self.res5(x)
        else:
            x = self.down1(x)

        x = torch.flatten(x, 1)

        if 'u' in self.opts:
            x = F.normalize(x, p=2, dim=-1)

        if self.use_fp16:
            x = x.float()

        return x

    def get_out_res(self, res):
        """Return ``res`` floor-divided by 2 once per downscale stage."""
        return res // 2 ** self._n_downscales()

    def get_out_ch(self):
        """Return the channel count produced before flattening.

        The 't' layout ends at ``e_ch * 8``. The default layout hands
        ``out_ch=e_ch`` to ``DownscaleBlock``, which as defined in this file
        keeps that width for every stage, so the width there is ``e_ch``.

        NOTE(review): if ``DownscaleBlock`` is ever changed to widen its
        stages, this value must be updated to match.
        """
        if 't' in self.opts:
            return self.e_ch * 8
        return self.e_ch

# Example implementations of Downscale and ResidualBlock follow (adapt them to your own use case).
class Downscale(nn.Module):
    """Single stride-2 convolution followed by a ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Side length of the square kernel; the padding is
            ``kernel_size // 2``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=5):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=2,
            padding=half_kernel,
        )

    def forward(self, x):
        """Apply the strided convolution, then clamp negatives to zero."""
        convolved = self.conv(x)
        return torch.relu(convolved)

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection around them.

    Computes ``relu(conv2(relu(conv1(x))) + x)``. Both convolutions map
    ``ch`` channels to ``ch`` channels with padding 1.

    Args:
        ch: Channel count of both the input and the output.
    """

    def __init__(self, ch):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(ch, ch, **conv_kwargs)
        self.conv2 = nn.Conv2d(ch, ch, **conv_kwargs)

    def forward(self, x):
        """Return the activated sum of the convolution branch and the input."""
        branch = self.conv2(torch.relu(self.conv1(x)))
        return torch.relu(branch + x)

class DownscaleBlock(nn.Module):
    """Chain of ``n_downscales`` ``Downscale`` layers run in sequence.

    The first layer maps ``in_ch`` to ``out_ch``; every later layer maps
    ``out_ch`` to ``out_ch``, so the channel width does not grow across
    stages.

    Args:
        in_ch: Channel count fed to the first layer.
        out_ch: Channel count produced by every layer.
        n_downscales: Number of ``Downscale`` layers to stack.
        kernel_size: Kernel size forwarded to each ``Downscale``.
    """

    def __init__(self, in_ch, out_ch, n_downscales, kernel_size=5):
        super().__init__()
        # Channel width entering each stage, plus the final output width.
        widths = [in_ch] + [out_ch] * n_downscales
        stages = [
            Downscale(src_ch, dst_ch, kernel_size)
            for src_ch, dst_ch in zip(widths, widths[1:])
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through every stage in order."""
        return self.block(x)
