In [None]:
#@title 链接Google Drive
# Colab-only: mount the user's Google Drive at /content/drive so that later
# cells can read from and write to paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract the workspace archive stored on the mounted Drive into the current working directory.
!unzip /content/drive/MyDrive/workspace.zip

## 编码器部分（Encoder）

In [1]:
#@title 加密encoder

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Convolutional encoder that turns an image batch into flat feature vectors.

    Two layouts are available, selected by the ``'t'`` option:

    * ``'t'`` enabled: five stride-2 ``Downscale`` stages (``e_ch`` up to
      ``e_ch * 8``) with a ``ResidualBlock`` after the first and last stage.
    * ``'t'`` disabled: one ``DownscaleBlock`` made of four stride-2 stages.

    The ``'u'`` option L2-normalises every flattened feature vector.

    Args:
        in_ch: Number of channels of the input images.
        e_ch: Base number of encoder channels.
        opts: Option flags. Either a container of flag names (string, list,
            set, ...) or a dict mapping flag names to booleans. Defaults to
            no flags.
        use_fp16: When true, the input is cast to half precision before the
            convolutions and the result is cast back to float32 at the end.
    """

    def __init__(self, in_ch, e_ch, opts=None, use_fp16=False):
        super().__init__()
        self.in_ch = in_ch
        self.e_ch = e_ch
        self.opts = opts if opts is not None else {}
        self.use_fp16 = use_fp16

        if self._opt_enabled('t'):
            self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
            self.res1 = ResidualBlock(self.e_ch)
            self.down2 = Downscale(self.e_ch, self.e_ch * 2, kernel_size=5)
            self.down3 = Downscale(self.e_ch * 2, self.e_ch * 4, kernel_size=5)
            self.down4 = Downscale(self.e_ch * 4, self.e_ch * 8, kernel_size=5)
            self.down5 = Downscale(self.e_ch * 8, self.e_ch * 8, kernel_size=5)
            self.res5 = ResidualBlock(self.e_ch * 8)
        else:
            # This branch is only reached without 't', so the stage count is always 4.
            self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)

    def _opt_enabled(self, flag):
        """Return True when ``flag`` is switched on in ``self.opts``.

        For a dict the flag's value decides, so ``{'t': False}`` means "off";
        a plain membership test would treat the mere presence of the key as
        "on". Any other container falls back to a membership test.
        """
        opts = self.opts
        if isinstance(opts, dict):
            return bool(opts.get(flag, False))
        return flag in opts

    def forward(self, x):
        """Encode ``x`` and return one flattened feature vector per sample."""
        if self.use_fp16:
            x = x.half()

        if self._opt_enabled('t'):
            x = self.down1(x)
            x = self.res1(x)
            x = self.down2(x)
            x = self.down3(x)
            x = self.down4(x)
            x = self.down5(x)
            x = self.res5(x)
        else:
            x = self.down1(x)

        # Collapse every dimension after the batch dimension into one vector.
        x = torch.flatten(x, 1)

        if self._opt_enabled('u'):
            x = F.normalize(x, p=2, dim=-1)

        if self.use_fp16:
            x = x.float()

        return x

    def get_out_res(self, res):
        """Return the spatial size of the last feature map for a ``res`` x ``res`` input."""
        return res // (2**5 if self._opt_enabled('t') else 2**4)

    def get_out_ch(self):
        """Return the channel count of the last feature map (before flattening).

        The 't' layout ends with ``e_ch * 8`` channels. Without 't' the encoder
        builds ``DownscaleBlock(in_ch, e_ch, ...)``, which gives every stage
        ``e_ch`` output channels, so the result is ``e_ch``.
        """
        return self.e_ch * 8 if self._opt_enabled('t') else self.e_ch

# Example implementations of Downscale and ResidualBlock (adapt them to your own needs).
class Downscale(nn.Module):
    """Stride-2 convolution followed by ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Size of the square convolution kernel; the padding is
            ``kernel_size // 2``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=5):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=2,
            padding=half_kernel,
        )

    def forward(self, x):
        """Apply the strided convolution and clamp negative responses to zero."""
        features = self.conv(x)
        return F.relu(features)

class ResidualBlock(nn.Module):
    """Two same-size convolutions with an identity skip connection.

    The signature accepts ``kernel_size`` so that it matches the
    ``ResidualBlock`` defined in the decoder cell: ``DecoderSrc`` calls
    ``ResidualBlock(ch, kernel_size=3)``, and whichever definition was
    executed last is the one in effect. The default of 3 keeps
    ``ResidualBlock(ch)`` identical to the previous behaviour.

    Args:
        ch: Number of input and output channels.
        kernel_size: Size of the square convolution kernels; the padding is
            ``kernel_size // 2``, which preserves the spatial size for odd
            kernel sizes.
    """

    def __init__(self, ch, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + x)``."""
        residual = x
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return F.relu(x + residual)

class DownscaleBlock(nn.Module):
    """A chain of ``n_downscales`` stride-2 ``Downscale`` stages.

    The first stage maps ``in_ch`` to ``out_ch``; every following stage maps
    ``out_ch`` to ``out_ch``.

    Args:
        in_ch: Number of channels of the block input.
        out_ch: Number of channels produced by every stage.
        n_downscales: How many stages to chain.
        kernel_size: Kernel size handed to every ``Downscale`` stage.
    """

    def __init__(self, in_ch, out_ch, n_downscales, kernel_size=5):
        super().__init__()
        # Only the very first stage sees the block's input channel count.
        stage_inputs = [in_ch if stage == 0 else out_ch for stage in range(n_downscales)]
        self.block = nn.Sequential(
            *(Downscale(stage_in, out_ch, kernel_size) for stage_in in stage_inputs)
        )

    def forward(self, x):
        """Run ``x`` through all stages in order."""
        return self.block(x)


In [None]:
#@title 保存权重

# Example instantiation of the encoder without the 't' layout.
# The options are passed as an empty dict on purpose: the encoder tests
# `'t' in opts`, and a dict membership test only looks at the keys, so
# `opts={'t': False}` would still select the 't' layout.
model = Encoder(in_ch=3, e_ch=512, opts={}, use_fp16=False)
print(model)
# Save model weights
# torch.save(model.state_dict(), 'encoder_weights.pth')


## 解码器部分（Decoder）

In [3]:
#@title 解密decoder

import torch
import torch.nn as nn
import torch.nn.functional as F

class Upscale(nn.Module):
    """Bilinear 2x upsampling followed by a convolution and ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Size of the square convolution kernel; the padding is
            ``kernel_size // 2``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Double the spatial size of ``x``, then convolve and apply ReLU."""
        enlarged = self.upsample(x)
        return F.relu(self.conv(enlarged))

class DecoderSrc(nn.Module):
    """Decoder with an RGB image path and a single-channel mask path.

    Image path: three ``Upscale`` + ``ResidualBlock`` stages, then four
    3-channel output heads whose results are interleaved into one RGB image of
    twice the spatial size. Mask path: four ``Upscale`` stages, a 1x1
    convolution and a sigmoid.

    Args:
        in_ch: Number of channels of the latent feature map ``z``.
        d_ch: Base channel count of the image path.
        d_mask_ch: Base channel count of the mask path.
    """

    def __init__(self, in_ch, d_ch, d_mask_ch):
        super().__init__()
        self.upscale0 = Upscale(in_ch, d_ch * 8, kernel_size=3)
        self.upscale1 = Upscale(d_ch * 8, d_ch * 4, kernel_size=3)
        self.upscale2 = Upscale(d_ch * 4, d_ch * 2, kernel_size=3)
        self.res0 = ResidualBlock(d_ch * 8, kernel_size=3)
        self.res1 = ResidualBlock(d_ch * 4, kernel_size=3)
        self.res2 = ResidualBlock(d_ch * 2, kernel_size=3)

        self.upscalem0 = Upscale(in_ch, d_mask_ch * 8, kernel_size=3)
        self.upscalem1 = Upscale(d_mask_ch * 8, d_mask_ch * 4, kernel_size=3)
        self.upscalem2 = Upscale(d_mask_ch * 4, d_mask_ch * 2, kernel_size=3)

        # Four RGB heads: each one supplies one of the 2x2 sub-pixel positions.
        self.out_conv = nn.Conv2d(d_ch * 2, 3, kernel_size=1)
        self.out_conv1 = nn.Conv2d(d_ch * 2, 3, kernel_size=3, padding=1)
        self.out_conv2 = nn.Conv2d(d_ch * 2, 3, kernel_size=3, padding=1)
        self.out_conv3 = nn.Conv2d(d_ch * 2, 3, kernel_size=3, padding=1)
        # The mask path has one more upscale stage than the image path so that
        # its output matches the size of the pixel-shuffled image.
        self.upscalem3 = Upscale(d_mask_ch * 2, d_mask_ch * 1, kernel_size=3)
        self.out_convm = nn.Conv2d(d_mask_ch * 1, 1, kernel_size=1)

    def forward(self, z):
        """Decode ``z`` into ``(image, mask)``.

        Returns:
            A tuple ``(x, m)``: ``x`` is the 3-channel image, ``m`` the
            1-channel mask after a sigmoid.
        """
        # Image path
        x = self.res0(self.upscale0(z))
        x = self.res1(self.upscale1(x))
        x = self.res2(self.upscale2(x))

        # F.pixel_shuffle builds output channel c from the input channels
        # c*4 .. c*4+3. Concatenating the heads ([RGB0, RGB1, RGB2, RGB3])
        # would therefore mix colours of different heads into one output
        # channel. Stacking the heads on a new axis right after the colour
        # axis yields the order [R0..R3, G0..G3, B0..B3], so head k provides
        # the RGB value of sub-pixel position k, as depth_to_space does on a
        # channel-wise concatenation.
        heads = torch.stack(
            [
                self.out_conv(x),
                self.out_conv1(x),
                self.out_conv2(x),
                self.out_conv3(x),
            ],
            dim=2,
        )
        x = F.pixel_shuffle(heads.flatten(1, 2), upscale_factor=2)

        # Mask path
        m = self.upscalem0(z)
        m = self.upscalem1(m)
        m = self.upscalem2(m)
        m = self.upscalem3(m)
        m = torch.sigmoid(self.out_convm(m))

        return x, m

class ResidualBlock(nn.Module):
    """Residual unit: conv -> ReLU -> conv, added to the input, then ReLU.

    Args:
        ch: Number of input and output channels.
        kernel_size: Size of the square convolution kernels; the padding is
            ``kernel_size // 2``.
    """

    def __init__(self, ch, kernel_size=3):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, padding=kernel_size // 2)
        self.conv1 = nn.Conv2d(ch, ch, **conv_kwargs)
        self.conv2 = nn.Conv2d(ch, ch, **conv_kwargs)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + x)``."""
        branch = self.conv2(F.relu(self.conv1(x)))
        return F.relu(branch + x)


In [None]:
#@title 保存权重

import torch
import torch.optim as optim

# Decoder hyper-parameters: channels of the latent map fed to the decoder,
# base width of the image path and base width of the mask path.
in_ch, d_ch, d_mask_ch = 128, 64, 16

# Build the decoder and print its layer structure.
decoder = DecoderSrc(in_ch=in_ch, d_ch=d_ch, d_mask_ch=d_mask_ch)

print(decoder)



In [None]:
#@title 保存权重

# Save the decoder's parameters (its state_dict) to 'decoder_weights.pth',
# a path relative to the current working directory.
torch.save(decoder.state_dict(), 'decoder_weights.pth')


## 读取图像文件

In [None]:
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CustomImageDataset(Dataset):
    """Dataset yielding an RGB and a greyscale version of every image in a folder.

    Args:
        image_dir: Directory that is scanned (non-recursively) for images.
        transform: Optional callable applied to both the RGB and the greyscale
            image before they are returned.
        extensions: File extensions to accept, compared case-insensitively.

    Raises:
        ValueError: If ``image_dir`` contains no file with an accepted extension.
    """

    def __init__(self, image_dir, transform=None, extensions=('.jpg', '.jpeg', '.png')):
        self.image_dir = image_dir
        self.transform = transform
        self.extensions = tuple(ext.lower() for ext in extensions)
        # Sorted so that an index maps to the same file on every run;
        # os.listdir returns entries in arbitrary order.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.lower().endswith(self.extensions)
        )
        if not self.image_files:
            raise ValueError(
                f"No image files {self.extensions} found in directory {image_dir}"
            )
        print(f"Found {len(self.image_files)} image files.")  # Debug line

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(rgb, grey)`` for the image at position ``idx``."""
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        # Open the file once and close it again as soon as both conversions
        # have been made.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
            imageGrey = img.convert('L')
        if self.transform:
            image = self.transform(image)
            imageGrey = self.transform(imageGrey)
        return image, imageGrey

# Define your transformations.
# ToTensor already converts 8-bit PIL images to float tensors in [0.0, 1.0],
# so no further division by 255 is applied; dividing again would squeeze the
# data into [0, 1/255].
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
])

# Specify the path to your image directory
image_dir = '/content/data_dst/aligned'

# Create dataset and DataLoader
dataset = CustomImageDataset(image_dir=image_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# Example usage
for images, imageGrey in data_loader:
    print(images.shape, imageGrey.shape)  # Output the shape of the image batch


## 计算loss

In [None]:
#@title 计算loss

import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF

def ssim(x, y, max_val=1.0, filter_size=11):
    """Placeholder for an SSIM computation.

    Not a real implementation: ``y``, ``max_val`` and ``filter_size`` are
    ignored and a tensor of ones shaped like ``x`` is returned. Replace it
    with an actual SSIM before relying on the result.
    """
    dummy_result = torch.ones_like(x)
    return dummy_result

# Example tensors and resolution (replace with the real tensors and values).
# The four tensor dimensions are defined here so that the cell also runs on a
# fresh kernel.
resolution = 224  # Example resolution
batch_size, channels, height, width = 4, 3, resolution, resolution

gpu_target_src_masked_opt = torch.rand((batch_size, channels, height, width))
gpu_pred_src_src_masked_opt = torch.rand((batch_size, channels, height, width))
gpu_target_srcm = torch.rand((batch_size, channels, height, width))
gpu_pred_src_srcm = torch.rand((batch_size, channels, height, width))

# SSIM calculation
# NOTE(review): `ssim` in this cell is a placeholder; the variable name suggests
# a dissimilarity (DSSIM) is intended - confirm once a real SSIM is plugged in.
filter_size = int(resolution / 11.6)
dssim = ssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=filter_size)

# Compute the loss as one value per sample.
# Every term is reduced over all non-batch dimensions so that the three terms
# share the shape (batch_size,). flatten(1) makes the first reduction work
# whatever number of trailing dimensions `ssim` returns.
gpu_src_loss = torch.mean(10 * dssim.flatten(1), dim=1)
gpu_src_loss += torch.mean(10 * torch.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), dim=[1, 2, 3])
gpu_src_loss += torch.mean(10 * torch.square(gpu_target_srcm - gpu_pred_src_srcm), dim=[1, 2, 3])


## 通过keras的summary改成pytorch

In [None]:
#@title 更好的inter

import torch
import torch.nn as nn
import torch.nn.functional as F

# Default side length of the square feature map produced by Inter's dense
# layers; used when Inter is built without an explicit `dense_res`.
lowest_dense_res = 6


class Inter(nn.Module):
    """Bottleneck between encoder and decoder.

    Two linear layers map a flat feature vector to a
    ``(ae_out_ch, dense_res, dense_res)`` feature map. Unless the ``'t'``
    option is enabled, one ``Upscale`` stage is applied afterwards.

    Args:
        in_ch: Length of the flat input vectors.
        ae_ch: Width of the bottleneck (output size of the first linear layer).
        ae_out_ch: Number of channels of the produced feature map.
        opts: Option flags. Either a container of flag names or a dict mapping
            flag names to booleans. Defaults to no flags.
        use_fp16: When true, the feature map is cast to half precision before
            the optional upscale stage.
        dense_res: Side length of the feature map. Defaults to the module-level
            ``lowest_dense_res`` as it is at construction time.
    """

    def __init__(self, in_ch, ae_ch, ae_out_ch, opts=None, use_fp16=False, dense_res=None):
        super().__init__()
        self.in_ch = in_ch
        self.ae_ch = ae_ch
        self.ae_out_ch = ae_out_ch
        self.opts = opts if opts is not None else []
        self.use_fp16 = use_fp16
        # Captured once: dense2 is sized from this value, so forward() must
        # reshape with the same value even if the module-level constant is
        # rebound later.
        self.dense_res = lowest_dense_res if dense_res is None else dense_res

        self.dense1 = nn.Linear(in_ch, ae_ch)
        self.dense2 = nn.Linear(ae_ch, self.dense_res * self.dense_res * ae_out_ch)

        if not self._opt_enabled('t'):
            self.upscale1 = Upscale(ae_out_ch, ae_out_ch)

    def _opt_enabled(self, flag):
        """Return True when ``flag`` is switched on in ``self.opts``.

        For a dict the flag's value decides (``{'t': False}`` means "off");
        any other container falls back to a membership test.
        """
        opts = self.opts
        if isinstance(opts, dict):
            return bool(opts.get(flag, False))
        return flag in opts

    def forward(self, inp):
        """Map flat vectors ``inp`` to a feature map."""
        x = inp
        x = self.dense1(x)
        x = self.dense2(x)
        x = x.view(-1, self.ae_out_ch, self.dense_res, self.dense_res)

        # NOTE(review): only the activations are cast here; the upscale stage
        # presumably needs half-precision weights as well - confirm.
        if self.use_fp16:
            x = x.half()

        if not self._opt_enabled('t'):
            x = self.upscale1(x)

        return x

    def get_out_res(self):
        """Return the side length of the output feature map."""
        return self.dense_res if self._opt_enabled('t') else self.dense_res * 2

    def get_out_ch(self):
        """Return the number of channels of the output feature map."""
        return self.ae_out_ch

class Upscale(nn.Module):
    """2x upscaling: convolution to ``out_ch * 4`` channels, LeakyReLU, pixel shuffle.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels after the pixel shuffle.
        kernel_size: Size of the square convolution kernel ('same' padding).
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, 4 * out_ch, kernel_size=kernel_size, padding='same')

    def forward(self, x):
        """Return ``x`` with doubled height and width and ``out_ch`` channels."""
        activated = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        return self.pixel_shuffle(activated, upscale_factor=2)

    @staticmethod
    def pixel_shuffle(x, upscale_factor):
        """Rearrange ``(N, C*r*r, H, W)`` into ``(N, C, H*r, W*r)`` with ``r = upscale_factor``."""
        n, c_in, h, w = x.size()
        r = upscale_factor
        c_out = c_in // (r ** 2)

        # Split the channel axis into (c_out, r, r), then weave the two r axes
        # into the height and the width axis respectively.
        blocks = x.view(n, c_out, r, r, h, w)
        woven = blocks.permute(0, 1, 4, 2, 5, 3).contiguous()
        return woven.view(n, c_out, h * r, w * r)

# Build an Inter instance and push a random batch through it as a shape check.
in_channels = 18432    # length of each flat input vector
ae_channels = 32       # width of the bottleneck layer
ae_out_channels = 128  # channels of the feature map produced by Inter
opts = []              # no option flags
use_fp16 = False       # keep everything in float32

dummy_input = torch.randn(32, in_channels)

model = Inter(in_channels, ae_channels, ae_out_channels, opts=opts, use_fp16=use_fp16)
output = model(dummy_input)
print(output.shape)


In [None]:
#@title 算法

import torch
import torch.nn as nn
import torch.nn.functional as F

class Upscale(nn.Module):
    """Sub-pixel upscaling layer: conv -> LeakyReLU(0.1) -> 2x pixel shuffle.

    The convolution produces ``out_ch * 4`` channels, which the pixel shuffle
    turns into ``out_ch`` channels at twice the height and width.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        shuffled_channels = out_ch * 4
        self.conv1 = nn.Conv2d(in_ch, shuffled_channels, kernel_size=kernel_size, padding='same')

    def forward(self, x):
        """Apply convolution, activation and pixel shuffle to ``x``."""
        y = self.conv1(x)
        y = F.leaky_relu(y, negative_slope=0.1)
        return self.pixel_shuffle(y, upscale_factor=2)

    @staticmethod
    def pixel_shuffle(x, upscale_factor):
        """Move groups of ``upscale_factor**2`` channels into spatial positions."""
        batch, channels, rows, cols = x.size()
        factor = upscale_factor
        kept_channels = channels // (factor ** 2)

        grouped = x.view(batch, kept_channels, factor, factor, rows, cols)
        # (batch, C, rows, factor, cols, factor): each factor axis sits next
        # to the spatial axis it is merged into.
        reordered = grouped.permute(0, 1, 4, 2, 5, 3)
        return reordered.contiguous().view(batch, kept_channels, rows * factor, cols * factor)


# Input laid out as (batch_size, channels, height, width):
# a single 3-channel tensor of 128 x 128.
input_tensor = torch.randn(1, 3, 128, 128)

# 3 input channels and 3 output channels; internally the convolution emits
# 3 * 4 = 12 channels, which the pixel shuffle folds back into 3.
model = Upscale(3, 3, kernel_size=3)

# Run the input through the layer.
output_tensor = model(input_tensor)

# With an upscale factor of 2 the printed shape is (1, 3, 256, 256).
print(output_tensor.shape)

In [None]:
# Install the extra packages imported by the training cell below (torchmetrics and scikit-image).
!pip install torchmetrics scikit-image

In [None]:
#@title 训练

import torch.optim as optim
import torch
import torchmetrics.functional as tmf
from itertools import chain
import os
from PIL import Image
import numpy as np
import cv2
import torchmetrics as tm
from skimage.metrics import structural_similarity as ssim

# Encoder: 3 input channels, 512 base channels, default options.
encoder = Encoder(3, 512)  # Adjust parameters as needed

num_epochs = 3000

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parameters
ae_ch = 128
ae_out_ch = 64
# Inter reads this module-level value when it is built below: with 3 it emits
# 6x6 maps, which the decoder's three 2x stages plus its final pixel shuffle
# bring to 96x96, the size the data loader resizes images to.
lowest_dense_res = 3

# Inter configuration
in_channels = 18432    # length of the flattened encoder output
ae_channels = 32       # width of the bottleneck layer
ae_out_channels = 128  # channels of the feature map produced by Inter
opts = []              # no option flags
use_fp16 = False       # keep everything in float32

inter = Inter(in_channels, ae_channels, ae_out_channels, opts, use_fp16)

# Decoder configuration: its input channels equal Inter's output channels.
in_ch = 128
d_ch = 64
d_mask_ch = 16
decoder = DecoderSrc(in_ch, d_ch, d_mask_ch)

# Move all three networks to the selected device.
encoder = encoder.to(device)
inter = inter.to(device)
decoder = decoder.to(device)

# One optimizer updates the parameters of all three networks.
optimizer = optim.Adam(
    chain(encoder.parameters(), inter.parameters(), decoder.parameters()),
    lr=1e-3,
)



# Checkpoint location on the mounted Drive.
file_path = "/content/drive/MyDrive/model_epoch.pth"

# Restore a previous run when a checkpoint file is present.
if os.path.exists(file_path):

    try:
        # map_location moves the stored tensors to the device selected for this
        # session, so a checkpoint written on a GPU also loads on a CPU runtime.
        checkpoint = torch.load(file_path, map_location=device)

        # Restore the weights of the three networks.
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        inter.load_state_dict(checkpoint['inter_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        # Restore the optimizer state.
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Restore the remaining training information (epoch and loss).
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

        # The stored loss may be a tensor or a plain number.
        loss_value = loss.item() if hasattr(loss, 'item') else loss
        print(f"Model loaded from epoch {epoch} with loss {loss_value}")
        print(f"{file_path} exists.")
    except Exception as e:
        # Best effort: report the problem and continue with the current
        # weights. NOTE(review): if a later load_state_dict call fails, the
        # networks loaded before it keep their restored weights.
        print(f"An unexpected error occurred: {e}")
else:
    print(f"{file_path} does not exist.")

# Lists intended for greyscale images; nothing in this cell fills them.
gray_images_list = []
x_gray_images_list = []


def _to_pil_rgb(chw):
    """Convert one (3, H, W) float array with values in [0, 1] to an RGB PIL image."""
    hwc = np.transpose(chw, (1, 2, 0))
    return Image.fromarray(np.clip(hwc * 255, 0, 255).astype(np.uint8))


# Training loop
for epoch in range(num_epochs):
    for batch_idx, (images, imageGrey) in enumerate(data_loader):
        images = images.to(device)
        imageGrey = imageGrey.to(device)

        # Forward pass: encoder -> inter -> decoder (RGB image x, mask m).
        encoder_output = encoder(images)
        inter_output = inter(encoder_output)
        x, m = decoder(inter_output)

        # DSSIM = 1 - SSIM, weighted by 10, is the loss.
        # NOTE(review): the loss only compares the mask output `m` with the
        # greyscale input, so the RGB output `x` receives no gradient -
        # confirm whether a reconstruction term on `x` is intended.
        ssim_value = tmf.structural_similarity_index_measure(imageGrey, m, data_range=1, kernel_size=3)
        dssim_value = 1 - ssim_value
        gpu_src_loss = torch.mean(dssim_value * 10)

        # Backward pass and optimization
        optimizer.zero_grad()
        gpu_src_loss.backward()
        optimizer.step()

        # Logging, checkpointing and previews happen on the first batch of an epoch only.
        if batch_idx != 0:
            continue

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {gpu_src_loss.item()}")

        # Save a checkpoint every 3 epochs.
        if (epoch + 1) % 3 != 0:
            continue

        torch.save({
            'epoch': epoch,
            'encoder_state_dict': encoder.state_dict(),
            'inter_state_dict': inter.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Detached so the checkpoint holds the value only, not a tensor
            # that is still attached to the autograd graph.
            'loss': gpu_src_loss.detach()
        }, file_path)
        print(f"Model saved at epoch {epoch+1}")

        # Preview: input image on the left, decoder RGB output on the right.
        originals = images.detach().cpu().numpy()
        reconstructions = x.detach().cpu().numpy()

        # Create the output directory if it does not exist yet.
        output_dir = "saved_images"
        os.makedirs(output_dir, exist_ok=True)

        # Save at most 5 pairs.
        for i in range(min(5, originals.shape[0])):
            img_in = _to_pil_rgb(originals[i])
            img_out = _to_pil_rgb(reconstructions[i])

            # Both halves must share one height to be pasted side by side.
            if img_out.height != img_in.height:
                img_out = img_out.resize((img_out.width, img_in.height))

            combined = Image.new('RGB', (img_in.width + img_out.width, img_in.height))
            combined.paste(img_in, (0, 0))
            combined.paste(img_out, (img_in.width, 0))

            combined.save(os.path.join(output_dir, f"combined_image_{i+1}.png"))

            print(f"Saved combined_image_{i+1}.png")