In [1]:
#@title 链接Google Drive
# Mount Google Drive at /content/drive so later cells can read files stored there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Shell command: extract the workspace archive stored on the mounted Drive.
!unzip /content/drive/MyDrive/workspace.zip

## 编码部分（Encoder）

In [None]:
#@title 加密encoder

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Convolutional encoder that downsamples an image batch and flattens the result.

    Two layouts are built depending on ``opts``:

    * ``'t' in opts``: five stride-2 ``Downscale`` stages (``e_ch`` up to
      ``e_ch * 8``) with a ``ResidualBlock`` after the first and the last stage.
    * otherwise: one ``DownscaleBlock`` made of four ``Downscale`` stages.

    Flags are detected with a membership test (``'t' in opts``). For a dict only
    the presence of the key matters, so ``{'t': False}`` still selects the 't'
    layout; pass ``None``/``{}`` (or any container without ``'t'``) to disable it.

    Args:
        in_ch: Number of channels of the input tensor.
        e_ch: Base number of encoder channels.
        opts: Container of option flags. ``'t'`` selects the layout and ``'u'``
            L2-normalises the flattened output. ``None`` means no flags.
        use_fp16: When true, ``forward`` casts its input to half precision and
            casts the result back to float32.
    """

    def __init__(self, in_ch, e_ch, opts=None, use_fp16=False):
        super(Encoder, self).__init__()
        self.in_ch = in_ch
        self.e_ch = e_ch
        self.opts = opts if opts is not None else {}
        self.use_fp16 = use_fp16

        if 't' in self.opts:
            self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
            self.res1 = ResidualBlock(self.e_ch)
            self.down2 = Downscale(self.e_ch, self.e_ch * 2, kernel_size=5)
            self.down3 = Downscale(self.e_ch * 2, self.e_ch * 4, kernel_size=5)
            self.down4 = Downscale(self.e_ch * 4, self.e_ch * 8, kernel_size=5)
            # Bug fix: this line used the undefined name `selfa`, which raised
            # NameError whenever the 't' layout was requested.
            self.down5 = Downscale(self.e_ch * 8, self.e_ch * 8, kernel_size=5)
            self.res5 = ResidualBlock(self.e_ch * 8)
        else:
            # This branch is only reached without 't', so the stage count is always 4.
            self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)

    def forward(self, x):
        """Encode ``x`` and return the features flattened to ``(batch, features)``."""
        # NOTE(review): only the input is cast here; the module's weights are not
        # converted, so fp16 presumably requires the caller to call .half() on the
        # model as well -- confirm before enabling use_fp16.
        if self.use_fp16:
            x = x.half()

        if 't' in self.opts:
            x = self.down1(x)
            x = self.res1(x)
            x = self.down2(x)
            x = self.down3(x)
            x = self.down4(x)
            x = self.down5(x)
            x = self.res5(x)
        else:
            x = self.down1(x)

        # Keep the batch dimension, merge everything else into one feature axis.
        x = torch.flatten(x, 1)

        if 'u' in self.opts:
            # Scale every sample's feature vector to unit L2 norm.
            x = F.normalize(x, p=2, dim=-1)

        if self.use_fp16:
            x = x.float()

        return x

    def get_out_res(self, res):
        """Return ``res`` divided by the total stride: 2**5 with 't', 2**4 without."""
        return res // (2**4 if 't' not in self.opts else 2**5)

    def get_out_ch(self):
        """Return ``e_ch * 8``, the channel count of the last 't'-layout stage."""
        # NOTE(review): the DownscaleBlock in this notebook passes the same out_ch
        # to every stage, so the non-'t' layout appears to end with e_ch channels,
        # not e_ch * 8 -- confirm which of the two is intended.
        return self.e_ch * 8

# Example implementations of Downscale and ResidualBlock follow (adapt them to your own setup).
class Downscale(nn.Module):
    """One stride-2 convolution followed by ReLU.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        kernel_size: Side of the square convolution kernel; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=5):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=2,
            padding=half_kernel,
        )

    def forward(self, x):
        """Apply the strided convolution, then clamp negatives with ReLU."""
        features = self.conv(x)
        return torch.relu(features)

class ResidualBlock(nn.Module):
    """Two same-padded convolutions with a skip connection and ReLU activations.

    ``kernel_size`` is accepted (default 3, the value that used to be hard-coded)
    so this definition has the same signature as the ``ResidualBlock`` defined
    later in the notebook and can be called as ``ResidualBlock(ch, kernel_size=3)``.

    Args:
        ch: Channel count; input and output have the same number of channels.
        kernel_size: Side of the square kernels; padding is ``kernel_size // 2``.
    """

    def __init__(self, ch, kernel_size=3):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=kernel_size // 2)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + x)``."""
        residual = x
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        # The skip connection is added before the final activation.
        return F.relu(x + residual)

class DownscaleBlock(nn.Module):
    """A sequential stack of ``n_downscales`` ``Downscale`` stages.

    The first stage maps ``in_ch`` to ``out_ch``; every following stage maps
    ``out_ch`` to ``out_ch``.

    Args:
        in_ch: Channel count of the input tensor.
        out_ch: Channel count produced by every stage.
        n_downscales: Number of ``Downscale`` stages to chain.
        kernel_size: Kernel size forwarded to each ``Downscale``.
    """

    def __init__(self, in_ch, out_ch, n_downscales, kernel_size=5):
        super().__init__()
        stages = [
            Downscale(in_ch if stage_idx == 0 else out_ch, out_ch, kernel_size)
            for stage_idx in range(n_downscales)
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through all stages in order."""
        out = self.block(x)
        return out


In [None]:
#@title 保存权重

# Example instantiation.
# Note: Encoder tests its flags with `'t' in opts`, so supplying the 't' key selects
# the 't' layout regardless of the value stored under it (False included).
model = Encoder(in_ch=3, e_ch=64, opts={'t': False}, use_fp16=False)
# Save the weights of the freshly constructed (untrained) model.
torch.save(model.state_dict(), 'encoder_weights.pth')


## 解码部分（Decoder）

In [None]:
#@title 解密decoder

import torch
import torch.nn as nn
import torch.nn.functional as F

class Upscale(nn.Module):
    """Bilinear 2x upsampling followed by a same-padded convolution and ReLU.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        kernel_size: Side of the square kernel; padding is ``kernel_size // 2``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=pad)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Upsample first, then convolve and activate."""
        enlarged = self.upsample(x)
        return torch.relu(self.conv(enlarged))

class Decoder(nn.Module):
    """Decoder with an image branch and a mask branch that share the input ``z``.

    Image branch: three ``Upscale`` stages, each followed by a ``ResidualBlock``,
    then four 3-channel convolutions whose outputs are concatenated (12 channels)
    and rearranged by a 2x pixel shuffle into a 3-channel map.
    Mask branch: four ``Upscale`` stages and a 1x1 convolution with a sigmoid.

    Args:
        in_ch: Channel count of the latent feature map ``z``.
        d_ch: Base channel count of the image branch.
        d_mask_ch: Base channel count of the mask branch.
    """

    def __init__(self, in_ch, d_ch, d_mask_ch):
        super().__init__()
        img_wide, img_mid, img_narrow = d_ch * 8, d_ch * 4, d_ch * 2
        mask_wide, mask_mid, mask_narrow = d_mask_ch * 8, d_mask_ch * 4, d_mask_ch * 2

        # Image branch: upscales and their residual blocks.
        self.upscale0 = Upscale(in_ch, img_wide, kernel_size=3)
        self.upscale1 = Upscale(img_wide, img_mid, kernel_size=3)
        self.upscale2 = Upscale(img_mid, img_narrow, kernel_size=3)
        self.res0 = ResidualBlock(img_wide, kernel_size=3)
        self.res1 = ResidualBlock(img_mid, kernel_size=3)
        self.res2 = ResidualBlock(img_narrow, kernel_size=3)

        # Mask branch: first three upscales.
        self.upscalem0 = Upscale(in_ch, mask_wide, kernel_size=3)
        self.upscalem1 = Upscale(mask_wide, mask_mid, kernel_size=3)
        self.upscalem2 = Upscale(mask_mid, mask_narrow, kernel_size=3)

        # Four output heads; together they supply the 12 channels consumed by
        # the 2x pixel shuffle in forward().
        self.out_conv = nn.Conv2d(img_narrow, 3, kernel_size=1)
        self.out_conv1 = nn.Conv2d(img_narrow, 3, kernel_size=3, padding=1)
        self.out_conv2 = nn.Conv2d(img_narrow, 3, kernel_size=3, padding=1)
        self.out_conv3 = nn.Conv2d(img_narrow, 3, kernel_size=3, padding=1)

        # Mask branch: last upscale and the single-channel output head.
        self.upscalem3 = Upscale(mask_narrow, d_mask_ch * 1, kernel_size=3)
        self.out_convm = nn.Conv2d(d_mask_ch * 1, 1, kernel_size=1)

    def forward(self, z):
        """Return ``(x, m)``: the image-branch output and the sigmoid mask."""
        # Image branch.
        x = z
        image_stages = (
            (self.upscale0, self.res0),
            (self.upscale1, self.res1),
            (self.upscale2, self.res2),
        )
        for upscale, residual in image_stages:
            x = residual(upscale(x))

        heads = (self.out_conv, self.out_conv1, self.out_conv2, self.out_conv3)
        stacked = torch.cat([head(x) for head in heads], dim=1)
        # Pixel shuffle plays the role of depth_to_space: 12 channels -> 3 at 2x size.
        x = F.pixel_shuffle(stacked, upscale_factor=2)

        # Mask branch.
        m = z
        for upscale in (self.upscalem0, self.upscalem1, self.upscalem2, self.upscalem3):
            m = upscale(m)
        m = torch.sigmoid(self.out_convm(m))

        return x, m

class ResidualBlock(nn.Module):
    """Residual unit: conv -> ReLU -> conv, add the input, then ReLU.

    Args:
        ch: Channel count, identical for input and output.
        kernel_size: Side of the square kernels; padding is ``kernel_size // 2``.
    """

    def __init__(self, ch, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=pad)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=pad)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + x)``."""
        branch = self.conv2(torch.relu(self.conv1(x)))
        return torch.relu(branch + x)


In [None]:
#@title 调试decoder

import torch
import torch.optim as optim

# Build a decoder with example channel sizes.
in_ch = 3
d_ch = 64
d_mask_ch = 32
decoder = Decoder(in_ch, d_ch, d_mask_ch)

# Dummy latent input: batch of 1, `in_ch` channels, 64x64 spatial size.
dummy_input = torch.randn(1, in_ch, 64, 64)  # Batch size of 1, 3 channels, 64x64

# Forward pass; the decoder returns the image-branch output and the mask.
x, m = decoder(dummy_input)

# Both branches enlarge the input 16x, so a 64x64 input should give 1024x1024 maps.
print(x.shape)  # Output shape of the main decoder path (3 channels after pixel shuffle)
print(m.shape)  # Output shape of the mask path (1 channel)


In [None]:
#@title 保存权重

# Save the decoder's state_dict (weights only, not the class definition) to disk.
torch.save(decoder.state_dict(), 'decoder_weights.pth')


## 处理inter

In [None]:
#@title inter模型

import torch
import torch.nn as nn
import torch.nn.functional as F

class Upscale(nn.Module):
    """Double the spatial size bilinearly, then apply a convolution and ReLU.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        kernel_size: Side of the square kernel; padding is ``kernel_size // 2``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Return ``relu(conv(upsample(x)))``."""
        return torch.relu(self.conv(self.upsample(x)))

class Inter(nn.Module):
    """Bottleneck: two linear layers, a reshape to a feature map, optional upscale.

    Args:
        in_ch: Number of input features per sample.
        ae_ch: Width of the first linear layer.
        ae_out_ch: Channel count of the reshaped feature map.
        lowest_dense_res: Height and width of the reshaped feature map.
        opts: Container of option flags; when it contains ``'t'`` the final
            ``Upscale`` is neither created nor applied. ``None`` means no flags.
        use_fp16: When true, the reshaped map is cast to half precision.
    """

    def __init__(self, in_ch, ae_ch, ae_out_ch, lowest_dense_res, opts=None, use_fp16=False):
        super().__init__()
        self.in_ch = in_ch
        self.ae_ch = ae_ch
        self.ae_out_ch = ae_out_ch
        self.lowest_dense_res = lowest_dense_res
        self.opts = {} if opts is None else opts
        self.use_fp16 = use_fp16

        # The second layer emits exactly the elements of one feature map.
        map_elements = lowest_dense_res * lowest_dense_res * ae_out_ch
        self.dense1 = nn.Linear(in_ch, ae_ch)
        self.dense2 = nn.Linear(ae_ch, map_elements)

        if 't' not in self.opts:
            self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

    def forward(self, inp):
        """Map ``inp`` to a 4D feature map (batch, ae_out_ch, H, W)."""
        hidden = self.dense2(self.dense1(inp))

        # Reshape the flat vector into (batch_size, channels, height, width).
        side = self.lowest_dense_res
        fmap = hidden.view(-1, self.ae_out_ch, side, side)

        if self.use_fp16:
            fmap = fmap.half()

        if 't' in self.opts:
            return fmap
        return self.upscale1(fmap)

    def get_out_res(self):
        """Return the output side length: doubled unless 't' is set."""
        if 't' in self.opts:
            return self.lowest_dense_res
        return self.lowest_dense_res * 2

    def get_out_ch(self):
        """Return the channel count of the output feature map."""
        return self.ae_out_ch


In [None]:
#@title 调试inter

# Example hyper-parameters for the Inter block.
in_ch = 256
ae_ch = 128
ae_out_ch = 64
lowest_dense_res = 16
opts = {}  # an empty dict keeps the final Upscale; a dict containing the key 't' skips it
use_fp16 = False

# Create model
inter = Inter(in_ch, ae_ch, ae_out_ch, lowest_dense_res, opts, use_fp16)

# Dummy input: one sample with `in_ch` flattened features.
dummy_input = torch.randn(1, in_ch)  # Batch size of 1, flattened input

# Forward pass
output = inter(dummy_input)

# With the values above: 64 channels at 2 * lowest_dense_res = 32 pixels per side.
print(output.shape)  # Output shape


torch.Size([1, 64, 32, 32])


In [None]:
#@title 保存权重

# Save the Inter block's state_dict (weights only) to disk.
torch.save(inter.state_dict(), 'inter_weights.pth')


## 读取图片文件

In [3]:
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CustomImageDataset(Dataset):
    """Dataset that serves every image file found directly inside a directory.

    Changes from the previous version: the file filter only matched ``.jpg``
    while the messages talked about PNG files. The accepted extensions are now a
    parameter (JPEG and PNG by default), matching is case-insensitive, the file
    list is sorted so indices are stable between runs, and the messages describe
    what was actually searched for.

    Args:
        image_dir: Directory that contains the image files.
        transform: Optional callable applied to each loaded RGB image.
        extensions: File-name suffix or iterable of suffixes to accept.

    Raises:
        ValueError: If no file in ``image_dir`` has one of the accepted suffixes.
    """

    def __init__(self, image_dir, transform=None, extensions=('.jpg', '.jpeg', '.png')):
        self.image_dir = image_dir
        self.transform = transform

        # Accept a single suffix given as a plain string as well as an iterable.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        # os.listdir order is arbitrary; sort so that idx -> file is reproducible.
        self.image_files = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(image_dir, name))
        )
        if not self.image_files:
            raise ValueError(
                f"No image files ({', '.join(suffixes)}) found in directory {image_dir}"
            )
        print(f"Found {len(self.image_files)} image files.")  # Debug line

    def __len__(self):
        """Return the number of image files discovered at construction time."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed when requested."""
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        # Image.open is lazy and keeps the file open; convert() loads the pixels
        # into a new image, so the handle can be closed right after it.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Preprocessing: resize every image to 128x128, then convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Directory that holds the images to load.
image_dir = '/content/data_dst/aligned'

# Create the dataset and a shuffling DataLoader (batches of 32, 4 worker processes).
dataset = CustomImageDataset(image_dir=image_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# Example usage: iterate once over the loader and show each batch's shape.
for images in data_loader:
    print(images.shape)  # Output the shape of the image batch


Found 72 PNG files.




torch.Size([32, 3, 128, 128])
torch.Size([32, 3, 128, 128])
torch.Size([8, 3, 128, 128])


## 计算loss

In [None]:
#@title 计算loss

import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF

def ssim(x, y, max_val=1.0, filter_size=11):
    """Placeholder for an SSIM computation.

    This stub does not compare anything: ``y``, ``max_val`` and ``filter_size``
    are ignored and a tensor of ones shaped like ``x`` is returned.

    Args:
        x: Reference tensor; only its shape, dtype and device are used.
        y: Unused.
        max_val: Unused.
        filter_size: Unused.

    Returns:
        A tensor of ones with the same shape, dtype and device as ``x``.
    """
    # TODO: replace this dummy value with a real SSIM implementation.
    placeholder = torch.ones_like(x)
    return placeholder

# Example inputs for the loss below. Replace them with the real target/prediction
# tensors. The size names used to be referenced without being defined in this
# cell, which raised NameError on a fresh kernel.
resolution = 224  # Example resolution
batch_size = 4
channels = 3
height = width = resolution

gpu_target_src_masked_opt = torch.rand((batch_size, channels, height, width))
gpu_pred_src_src_masked_opt = torch.rand((batch_size, channels, height, width))
gpu_target_srcm = torch.rand((batch_size, channels, height, width))
gpu_pred_src_srcm = torch.rand((batch_size, channels, height, width))

# SSIM calculation
filter_size = int(resolution / 11.6)
dssim = ssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=filter_size)

# Compute the per-sample loss, one value per batch element.
# The structural term is averaged over every non-batch dimension so that it has
# shape (batch,) like the two squared-error terms. Averaging over dim=1 only
# left a (batch, H, W) tensor that could not take the in-place additions below.
# The reshape also keeps this working if ssim() returns one value per sample.
gpu_src_loss = torch.mean(10 * dssim.reshape(dssim.shape[0], -1), dim=1)
# Pixel-wise squared error between target and prediction.
gpu_src_loss += torch.mean(10 * torch.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), dim=[1, 2, 3])
# Squared error between target and predicted masks.
gpu_src_loss += torch.mean(10 * torch.square(gpu_target_srcm - gpu_pred_src_srcm), dim=[1, 2, 3])


## 训练

In [None]:
#@title 训练

import torch
import torch.nn.functional as F
import torch.optim as optim

# This cell previously could not run: the loss function and its target were
# undefined, the flattened encoder output was fed straight into the convolutional
# decoder, the Inter block was skipped, only the encoder was optimised and no
# model was moved to the selected device. It now trains the three blocks as a
# plain reconstruction auto-encoder.

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
resolution = 128  # must match the Resize((128, 128)) used to build data_loader

# Encoder flags are tested with `'t' in opts`, so an empty dict is the way to
# leave the 't' layout disabled ({'t': False} would still enable it).
encoder = Encoder(in_ch=3, e_ch=64, opts={}).to(device)

# Probe the encoder once to learn the size of its flattened output.
with torch.no_grad():
    probe = torch.zeros(1, 3, resolution, resolution, device=device)
    encoder_out_dim = encoder(probe).shape[1]

# Inter parameters
in_ch = encoder_out_dim
ae_ch = 128
ae_out_ch = 64
# Inter (without 't') doubles the resolution once and the Decoder enlarges its
# input 16x (three Upscale stages plus a 2x pixel shuffle), i.e. 32x in total.
lowest_dense_res = resolution // 32
opts = {}  # or a dict containing the key 't' to skip Inter's upscale
use_fp16 = False

inter = Inter(in_ch, ae_ch, ae_out_ch, lowest_dense_res, opts, use_fp16).to(device)

# Decoder: its input channels are the channels produced by Inter.
d_ch = 64
d_mask_ch = 32
decoder = Decoder(inter.get_out_ch(), d_ch, d_mask_ch).to(device)

# Optimise all three blocks together.
trainable_params = (
    list(encoder.parameters())
    + list(inter.parameters())
    + list(decoder.parameters())
)
optimizer = optim.Adam(trainable_params, lr=1e-3)

# Training loop
for epoch in range(num_epochs):
    for batch in data_loader:
        images = batch.to(device)  # Move images to the appropriate device (e.g., GPU)

        # Forward pass: image -> flat code -> feature map -> reconstruction + mask
        encoder_output = encoder(images)
        inter_output = inter(encoder_output)
        x, m = decoder(inter_output)

        # Reconstruction loss against the input images.
        # NOTE(review): the data loader yields images only, so the mask output `m`
        # has no target here; add a mask term once mask targets are available.
        loss = F.mse_loss(x, images)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")


## 通过keras的summary改成pytorch

In [None]:
#@title 用pytorch写encoder

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Four 3x3, stride-1, same-padded convolutions (3 -> 64 -> 128 -> 256 -> 512).

    Every convolution is followed by ReLU; the spatial size is left unchanged.
    Note that this class reuses the name ``Encoder``, so running this cell
    replaces any ``Encoder`` defined earlier in the same session.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Apply the four convolutions in order, each followed by ReLU."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = torch.relu(conv(x))
        return x

# Create a model instance
model = Encoder()

# Print a layer-by-layer summary of the model
from torchsummary import summary
summary(model, input_size=(3, 128, 128))  # summary for a 3-channel 128x128 input; adjust as needed


In [None]:
#@title 改成inter

import torch
import torch.nn as nn
from torchsummary import summary

class InterModel(nn.Module):
    """Convolution followed by two fully connected layers.

    The previous version sized the linear layers as 2359424 and 148608 features.
    Those values match the parameter counts of two dense layers
    (18432 * 128 + 128 = 2359424 and 128 * 1152 + 1152 = 148608), so they were
    presumably copied from the "Param #" column of the Keras summary instead of
    the output shapes. Used as feature sizes they would require about 9.5e11
    weights in ``fc1`` alone, which cannot be allocated. The layer widths are
    now 128 and 1152 and every size is a constructor argument.

    Args:
        in_ch: Channels of the input feature map.
        conv_ch: Channels produced by the convolution.
        in_res: Height and width of the input feature map. The convolution keeps
            the spatial size, so ``fc1`` receives ``conv_ch * in_res * in_res``
            features.
        hidden_features: Output width of the first linear layer.
        out_features: Output width of the second linear layer.
    """

    def __init__(self, in_ch=128, conv_ch=512, in_res=28, hidden_features=128, out_features=1152):
        super(InterModel, self).__init__()
        # Convolution layer (3x3, stride 1, padding 1: spatial size is preserved).
        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=conv_ch, kernel_size=3, stride=1, padding=1)

        # Fully connected layers.
        # With the defaults the flattened input has 512 * 28 * 28 = 401,408 features.
        self.fc1 = nn.Linear(conv_ch * in_res * in_res, hidden_features)  # first dense layer
        self.fc2 = nn.Linear(hidden_features, out_features)               # second dense layer

    def forward(self, x):
        """Return a ``(batch, out_features)`` tensor for a ``(batch, in_ch, in_res, in_res)`` input."""
        x = torch.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # Flatten everything except the batch dimension
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create a model instance
model = InterModel()

# Print a layer-by-layer summary of the model
summary(model, input_size=(128, 28, 28))  # input shape (128, 28, 28): 128 channels on an assumed 28x28 feature map

# # Count the total number of parameters
# total_params = sum(p.numel() for p in model.parameters())
# print(f"Total params count: {total_params}")


In [None]:
#@title 改成decoder

import torch
import torch.nn as nn

class DecoderSrc(nn.Module):
    """Decoder with an image branch and a mask branch, built from 18 convolutions.

    The previous ``forward`` fed the layers into each other in numeric order,
    which fails at the second layer: ``conv1`` emits 2048 channels while
    ``conv2`` expects 512. The channel counts only line up when each upscaling
    convolution is followed by a 2x pixel shuffle (channels / 4) and the layers
    are grouped as below, which is the wiring implemented now:

    * image branch: ``conv1`` -> (``conv4``, ``conv5``) -> ``conv2`` ->
      (``conv6``, ``conv7``) -> ``conv3`` -> (``conv8``, ``conv9``) ->
      ``conv13``..``conv16`` concatenated and pixel-shuffled to 3 channels;
    * mask branch: ``conv10`` -> ``conv11`` -> ``conv12`` -> ``conv17`` ->
      ``conv18`` with a sigmoid.

    The layer definitions (names, shapes, creation order) are unchanged, so
    existing state_dict keys stay valid. ``forward`` returns ``(x, m)``, which is
    how the training cell of this notebook unpacks the decoder output.

    NOTE(review): activations are not recoverable from layer shapes; ReLU is used
    for upscales and residual pairs and no activation on the image output, in
    line with the other decoder of this notebook -- confirm against the source
    model.
    """

    def __init__(self):
        super(DecoderSrc, self).__init__()

        # Image-branch upscaling convolutions (output channels = 4x the channels
        # that remain after the pixel shuffle).
        self.conv1 = nn.Conv2d(in_channels=128, out_channels=2048, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        # Residual pairs at 512, 256 and 128 channels.
        self.conv4 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.conv9 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        # Mask-branch upscaling convolutions.
        self.conv10 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=3, padding=1)
        self.conv11 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        # Four 3-channel image heads (12 channels in total before the shuffle).
        self.conv13 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1)
        self.conv14 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1)
        self.conv15 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1)
        self.conv16 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1)
        # Last mask upscale and the single-channel mask head.
        self.conv17 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv18 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

    @staticmethod
    def _upscale(conv, t):
        """Convolve, activate, then trade 4x channels for 2x resolution."""
        return nn.functional.pixel_shuffle(torch.relu(conv(t)), 2)

    @staticmethod
    def _residual(conv_a, conv_b, t):
        """Two convolutions with a skip connection around them."""
        return torch.relu(conv_b(torch.relu(conv_a(t))) + t)

    def forward(self, x):
        """Decode a ``(batch, 128, H, W)`` map into ``(image, mask)``.

        Returns:
            ``x`` of shape ``(batch, 3, 16H, 16W)`` and ``m`` of shape
            ``(batch, 1, 16H, 16W)`` with values in ``[0, 1]``.
        """
        z = x

        # Image branch.
        x = self._upscale(self.conv1, z)                  # 128 -> 2048 -> 512, 2x
        x = self._residual(self.conv4, self.conv5, x)
        x = self._upscale(self.conv2, x)                  # 512 -> 1024 -> 256, 4x
        x = self._residual(self.conv6, self.conv7, x)
        x = self._upscale(self.conv3, x)                  # 256 -> 512 -> 128, 8x
        x = self._residual(self.conv8, self.conv9, x)
        heads = (self.conv13, self.conv14, self.conv15, self.conv16)
        x = torch.cat([head(x) for head in heads], dim=1)  # 4 x 3 = 12 channels
        x = nn.functional.pixel_shuffle(x, 2)              # 12 -> 3, 16x

        # Mask branch.
        m = self._upscale(self.conv10, z)                 # 128 -> 512 -> 128, 2x
        m = self._upscale(self.conv11, m)                 # 128 -> 256 -> 64, 4x
        m = self._upscale(self.conv12, m)                 # 64 -> 128 -> 32, 8x
        m = self._upscale(self.conv17, m)                 # 32 -> 64 -> 16, 16x
        m = torch.sigmoid(self.conv18(m))

        return x, m

# Example usage: build the decoder and print its layer listing.
model = DecoderSrc()
print(model)


In [None]:
#@title 训练

import torch.optim as optim

# Training scaffold for the models of this section.
# NOTE(review): this cell is not runnable as written; see the notes below.

encoder = Encoder()  # Adjust parameters as needed

# Example optimizer
# NOTE(review): only the encoder's parameters are handed to the optimizer, so the
# decoder created below would not be updated -- confirm whether that is intended.
optimizer = optim.Adam(encoder.parameters(), lr=1e-3)
num_epochs = 10
# Use the GPU when one is available.
# NOTE(review): batches are moved to `device` in the loop, but no model is moved
# in this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Parameters
in_ch = 256
ae_ch = 128
ae_out_ch = 64
lowest_dense_res = 16
opts = {}  # or {'t': True} to modify behavior
use_fp16 = False

# Create model
# inter = Inter()

# Initialize the model
in_ch = 3
d_ch = 64
d_mask_ch = 32
decoder = DecoderSrc()

# Create a dummy input tensor (e.g., batch of images with 3 channels and 64x64 size)
# dummy_input = torch.randn(1, in_ch, 64, 64)  # Batch size of 1, 3 channels, 64x64

# # Forward pass
# x, m = decoder(dummy_input)

# print(x.shape)  # Output shape of the main decoder path
# print(m.shape)  # Output shape of the mask path


# Training loop
for epoch in range(num_epochs):
    for batch in data_loader:
        images = batch.to(device)  # Move images to the appropriate device (e.g., GPU)

        # Forward pass through the encoder
        encoder_output = encoder(images)
        print("encoder shape:", encoder_output.shape)

        # Forward pass through the decoder (the inter step is commented out)
        # NOTE(review): the Encoder of this section ends with 512 channels while
        # DecoderSrc.conv1 expects 128 -- presumably the disabled inter step is
        # meant to bridge the two; confirm.
        # NOTE(review): two values are unpacked here; confirm that
        # DecoderSrc.forward returns an (image, mask) pair.
        # inter_output = inter(encoder_output)
        x, m = decoder(encoder_output)

        # Example loss computation (Replace with your actual loss function)
        # NOTE(review): `some_loss_function` and `target` are not defined anywhere
        # in the visible notebook, so this line raises NameError until replaced.
        loss = some_loss_function(encoder_output, target)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")