## 编码器部分（Encoder）

In [1]:
#@title 加密encoder

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Convolutional encoder that downsamples an image and flattens the result.

    The architecture is selected by the ``'t'`` option:

    * ``'t'`` enabled: five stride-2 ``Downscale`` layers with channel widths
      e_ch -> 2e -> 4e -> 8e -> 8e, plus a ``ResidualBlock`` after the first
      and after the last downscale. Spatial size shrinks by 2**5.
    * ``'t'`` disabled: a single ``DownscaleBlock`` with 4 downscales.
      Spatial size shrinks by 2**4.

    The feature map is flattened to ``(batch, -1)``. When the ``'u'`` option
    is enabled, it is L2-normalised along the last dimension.

    Args:
        in_ch: number of input channels.
        e_ch: base channel width.
        opts: options container. It is either a string/sequence/set of option
            letters (e.g. ``'ut'``) or a dict mapping option letters to
            booleans (e.g. ``{'t': True}``). For dicts only truthy values
            enable an option, so ``{'t': False}`` leaves ``'t'`` disabled.
        use_fp16: if True, the input and all weights are cast to half
            precision. The flattened output is cast back to float32.
    """

    def __init__(self, in_ch, e_ch, opts=None, use_fp16=False):
        super(Encoder, self).__init__()
        self.in_ch = in_ch
        self.e_ch = e_ch
        self.opts = opts if opts is not None else {}
        self.use_fp16 = use_fp16

        if self._has_opt('t'):
            self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
            self.res1 = ResidualBlock(self.e_ch)
            self.down2 = Downscale(self.e_ch, self.e_ch * 2, kernel_size=5)
            self.down3 = Downscale(self.e_ch * 2, self.e_ch * 4, kernel_size=5)
            self.down4 = Downscale(self.e_ch * 4, self.e_ch * 8, kernel_size=5)
            self.down5 = Downscale(self.e_ch * 8, self.e_ch * 8, kernel_size=5)
            self.res5 = ResidualBlock(self.e_ch * 8)
        else:
            # This branch only runs when 't' is disabled, so the depth is
            # always 4; the 5-level variant is the explicit branch above.
            self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)

        if self.use_fp16:
            # forward() casts its input to half. The conv weights must have
            # the same dtype, otherwise conv2d rejects the mixed inputs.
            self.half()

    def _has_opt(self, name):
        """Return True if option ``name`` is enabled in ``self.opts``.

        Dicts are checked by value (``{'t': False}`` -> False). Any other
        container (string, list, set) is checked by membership.
        """
        opts = self.opts
        if isinstance(opts, dict):
            return bool(opts.get(name, False))
        return name in opts

    def forward(self, x):
        """Encode ``x`` and return flattened features of shape ``(batch, -1)``.

        ``x`` is presumably an NCHW image batch with ``in_ch`` channels
        (TODO confirm against callers).
        """
        if self.use_fp16:
            x = x.half()

        if self._has_opt('t'):
            x = self.down1(x)
            x = self.res1(x)
            x = self.down2(x)
            x = self.down3(x)
            x = self.down4(x)
            x = self.down5(x)
            x = self.res5(x)
        else:
            x = self.down1(x)

        x = torch.flatten(x, 1)

        if self._has_opt('u'):
            x = F.normalize(x, p=2, dim=-1)

        if self.use_fp16:
            x = x.float()

        return x

    def get_out_res(self, res):
        """Return the spatial size after encoding an input of side ``res``."""
        return res // (2**5 if self._has_opt('t') else 2**4)

    def get_out_ch(self):
        """Return ``e_ch * 8``.

        This is the final feature-map width of the 't' path (down5/res5). The
        non-'t' path matches it only if ``DownscaleBlock`` ends at
        ``8 * e_ch`` channels.
        """
        return self.e_ch * 8

# 下面是 Downscale、ResidualBlock 和 DownscaleBlock 的示例实现（需要根据你的情况具体实现）
class Downscale(nn.Module):
    """Stride-2 convolution followed by ReLU.

    Maps ``in_ch`` channels to ``out_ch`` and, with ``kernel_size // 2``
    padding and an odd kernel, halves the spatial size (rounding up).
    """

    def __init__(self, in_ch, out_ch, kernel_size=5):
        super().__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=2,
            padding=pad,
        )

    def forward(self, x):
        out = self.conv(x)
        return F.relu(out)

class ResidualBlock(nn.Module):
    """Two same-padded convolutions with an identity skip connection.

    Computes ``relu(conv2(relu(conv1(x))) + x)``. Channel count is preserved,
    and so is spatial size for odd ``kernel_size``.

    Args:
        ch: number of input and output channels.
        kernel_size: convolution kernel size. The default of 3 reproduces the
            previous fixed behaviour. Accepting the argument keeps this class
            call-compatible with the ``ResidualBlock(ch, kernel_size=3)``
            calls made by ``Decoder``.
    """

    def __init__(self, ch, kernel_size=3):
        super(ResidualBlock, self).__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        residual = x
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return F.relu(x + residual)

class DownscaleBlock(nn.Module):
    """Chain of ``n_downscales`` stride-2 ``Downscale`` layers.

    Layer ``i`` produces ``out_ch * min(2**i, 8)`` channels. The width starts
    at ``out_ch``, doubles at each step and is capped at ``8 * out_ch``. With
    ``n_downscales=4`` the block therefore ends at ``out_ch * 8`` channels,
    the width ``Encoder.get_out_ch`` reports. Each ``Downscale`` has stride 2,
    so every layer halves the spatial size.

    Args:
        in_ch: channels of the input tensor.
        out_ch: output channels of the first layer (the base width).
        n_downscales: number of stride-2 layers. 0 yields an identity block.
        kernel_size: kernel size forwarded to every ``Downscale``.
    """

    def __init__(self, in_ch, out_ch, n_downscales, kernel_size=5):
        super(DownscaleBlock, self).__init__()
        layers = []
        last_ch = in_ch
        for i in range(n_downscales):
            cur_ch = out_ch * min(2 ** i, 8)
            layers.append(Downscale(last_ch, cur_ch, kernel_size))
            last_ch = cur_ch
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


In [6]:
#@title Save encoder weights

# Example instantiation. Use an empty options container to select the default
# (non-'t') architecture, rather than {'t': False}. A key-membership check
# such as `'t' in opts` treats any present 't' key as enabled, whatever its
# value.
model = Encoder(in_ch=3, e_ch=64, opts={}, use_fp16=False)
# Save only the learned parameters (state_dict), not the module object.
torch.save(model.state_dict(), 'encoder_weights.pth')


## 解码器部分（Decoder）

In [3]:
#@title 解密decoder

import torch
import torch.nn as nn
import torch.nn.functional as F

class Upscale(nn.Module):
    """Bilinear 2x upsampling followed by a same-padded convolution and ReLU."""

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                              padding=kernel_size // 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                    align_corners=False)

    def forward(self, x):
        return F.relu(self.conv(self.upsample(x)))

class Decoder(nn.Module):
    """Two-headed upsampling decoder that produces an image map and a mask.

    Image path:
        Three ``Upscale`` + ``ResidualBlock`` stages take the channels from
        ``in_ch`` to 8*d_ch, then 4*d_ch, then 2*d_ch. Four parallel
        3-channel convolutions follow. Their outputs are concatenated into
        12 channels and rearranged with ``pixel_shuffle(..., 2)`` into a
        3-channel map at twice the resolution. No activation is applied to
        this output.

    Mask path:
        Four ``Upscale`` stages take the channels from ``in_ch`` to
        8*d_mask_ch, 4*d_mask_ch, 2*d_mask_ch and then d_mask_ch. A 1x1
        convolution and a sigmoid follow, giving a 1-channel map in (0, 1).

    Every ``Upscale`` doubles the resolution, so both outputs are 16x the
    input's spatial size. ``forward`` returns ``(x, m)``.
    """

    def __init__(self, in_ch, d_ch, d_mask_ch):
        super().__init__()
        # Channel widths of the image path.
        c8, c4, c2 = d_ch * 8, d_ch * 4, d_ch * 2
        # Channel widths of the mask path.
        m8, m4, m2, m1 = d_mask_ch * 8, d_mask_ch * 4, d_mask_ch * 2, d_mask_ch * 1

        # Submodules are registered in the same order as before, so
        # state_dict keys and parameter initialisation order are unchanged.
        self.upscale0 = Upscale(in_ch, c8, kernel_size=3)
        self.upscale1 = Upscale(c8, c4, kernel_size=3)
        self.upscale2 = Upscale(c4, c2, kernel_size=3)
        self.res0 = ResidualBlock(c8, kernel_size=3)
        self.res1 = ResidualBlock(c4, kernel_size=3)
        self.res2 = ResidualBlock(c2, kernel_size=3)

        self.upscalem0 = Upscale(in_ch, m8, kernel_size=3)
        self.upscalem1 = Upscale(m8, m4, kernel_size=3)
        self.upscalem2 = Upscale(m4, m2, kernel_size=3)

        # Four 3-channel heads. Concatenated they give the 12 = 3 * 2**2
        # channels that pixel_shuffle(2) turns into 3 channels at 2x size.
        self.out_conv = nn.Conv2d(c2, 3, kernel_size=1)
        self.out_conv1 = nn.Conv2d(c2, 3, kernel_size=3, padding=1)
        self.out_conv2 = nn.Conv2d(c2, 3, kernel_size=3, padding=1)
        self.out_conv3 = nn.Conv2d(c2, 3, kernel_size=3, padding=1)
        self.upscalem3 = Upscale(m2, m1, kernel_size=3)
        self.out_convm = nn.Conv2d(m1, 1, kernel_size=1)

    def forward(self, z):
        # Image path: each stage upsamples, then refines with a residual block.
        x = z
        stages = (
            (self.upscale0, self.res0),
            (self.upscale1, self.res1),
            (self.upscale2, self.res2),
        )
        for up, res in stages:
            x = res(up(x))

        heads = (self.out_conv, self.out_conv1, self.out_conv2, self.out_conv3)
        stacked = torch.cat([head(x) for head in heads], dim=1)
        x = F.pixel_shuffle(stacked, upscale_factor=2)  # depth_to_space equivalent

        # Mask path: four plain upscales, then a sigmoid-activated 1x1 conv.
        m = z
        for up in (self.upscalem0, self.upscalem1, self.upscalem2, self.upscalem3):
            m = up(m)
        m = torch.sigmoid(self.out_convm(m))

        return x, m

class ResidualBlock(nn.Module):
    """Residual unit computing ``relu(conv2(relu(conv1(x))) + x)``.

    Both convolutions keep ``ch`` channels and use ``kernel_size // 2``
    padding, so the skip connection lines up for odd kernel sizes.
    """

    def __init__(self, ch, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=pad)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=pad)

    def forward(self, x):
        branch = self.conv2(F.relu(self.conv1(x)))
        return F.relu(branch + x)


In [None]:
#@title Decoder smoke test

import torch
import torch.optim as optim

# Initialize the model
in_ch = 3
d_ch = 64
d_mask_ch = 32
decoder = Decoder(in_ch, d_ch, d_mask_ch)

# Dummy input: a batch of one 3-channel 64x64 tensor.
dummy_input = torch.randn(1, in_ch, 64, 64)

# Forward pass under no_grad. This cell only checks shapes, and recording an
# autograd graph would keep every intermediate activation alive. The decoder
# upsamples 16x (to 1024x1024 here), so those activations are large.
with torch.no_grad():
    x, m = decoder(dummy_input)

print(x.shape)  # image path, expected (1, 3, 1024, 1024)
print(m.shape)  # mask path, expected (1, 1, 1024, 1024)


In [5]:
#@title Save decoder weights

# Persist only the decoder's learned parameters (its state_dict).
decoder_state = decoder.state_dict()
torch.save(decoder_state, 'decoder_weights.pth')


## 中间层部分（Inter）

In [7]:
#@title inner模型

import torch
import torch.nn as nn
import torch.nn.functional as F

class Upscale(nn.Module):
    """Double the spatial size (bilinear), then apply a same-padded conv and ReLU."""

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super(Upscale, self).__init__()
        same_pad = kernel_size // 2
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=same_pad)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        upsampled = self.upsample(x)
        features = self.conv(upsampled)
        return features.relu()

class Inter(nn.Module):
    """Dense bottleneck that maps a flat code back to a small feature map.

    The input goes through two linear layers with no activation between them.
    It is then reshaped to
    ``(batch, ae_out_ch, lowest_dense_res, lowest_dense_res)``. Unless option
    ``'t'`` is enabled, it is finally upsampled 2x by ``Upscale``.

    Args:
        in_ch: input feature size of the first linear layer.
        ae_ch: hidden (bottleneck) size.
        ae_out_ch: channels of the output feature map.
        lowest_dense_res: side length of the square map produced by the reshape.
        opts: options container. It is either a string/sequence/set of option
            letters (e.g. ``'t'``) or a dict of booleans (e.g. ``{'t': True}``).
            For dicts only truthy values enable an option.
        use_fp16: if True, the reshaped tensor is cast to half precision
            before the optional upscale, and the upscale's weights are cast
            too so the dtypes match. The returned tensor is half precision.
    """

    def __init__(self, in_ch, ae_ch, ae_out_ch, lowest_dense_res, opts=None, use_fp16=False):
        super(Inter, self).__init__()
        self.in_ch = in_ch
        self.ae_ch = ae_ch
        self.ae_out_ch = ae_out_ch
        self.lowest_dense_res = lowest_dense_res
        self.opts = opts if opts is not None else {}
        self.use_fp16 = use_fp16

        self.dense1 = nn.Linear(in_ch, ae_ch)
        self.dense2 = nn.Linear(ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch)

        if not self._has_opt('t'):
            self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
            if self.use_fp16:
                # forward() feeds a half tensor into upscale1, so its conv
                # weights must be half as well.
                self.upscale1.half()

    def _has_opt(self, name):
        """Return True if option ``name`` is enabled in ``self.opts``.

        Dicts are checked by value (``{'t': False}`` -> False). Any other
        container (string, list, set) is checked by membership.
        """
        opts = self.opts
        if isinstance(opts, dict):
            return bool(opts.get(name, False))
        return name in opts

    def forward(self, inp):
        """Map ``inp`` (presumably ``(batch, in_ch)``; TODO confirm) to a feature map."""
        x = self.dense1(inp)
        x = self.dense2(x)

        # (batch, C*R*R) -> (batch, C, R, R)
        x = x.view(-1, self.ae_out_ch, self.lowest_dense_res, self.lowest_dense_res)

        if self.use_fp16:
            x = x.half()

        if not self._has_opt('t'):
            x = self.upscale1(x)

        return x

    def get_out_res(self):
        """Return the output side length: doubled unless ``'t'`` is enabled."""
        return self.lowest_dense_res if self._has_opt('t') else self.lowest_dense_res * 2

    def get_out_ch(self):
        """Return the number of output channels (``ae_out_ch``)."""
        return self.ae_out_ch


In [8]:
#@title Debug the Inter model

# Hyper-parameters for the debug run.
in_ch, ae_ch, ae_out_ch = 256, 128, 64
lowest_dense_res = 16
opts = {}  # set {'t': True} to skip the final 2x upscale
use_fp16 = False

inter = Inter(
    in_ch,
    ae_ch,
    ae_out_ch,
    lowest_dense_res,
    opts,
    use_fp16,
)

# A single flattened sample (batch size 1).
dummy_input = torch.randn(1, in_ch)

output = inter(dummy_input)
print(output.shape)  # with the default opts: torch.Size([1, 64, 32, 32])


torch.Size([1, 64, 32, 32])


In [9]:
#@title Save Inter weights

# Write only the Inter model's learned parameters (its state_dict) to disk.
inter_state = inter.state_dict()
torch.save(inter_state, 'inter_weights.pth')
