# EDS

In [1]:
# import torch

# torch.cuda.get_device_name(0)

In [2]:
# !pip install pytorch_msssim

In [3]:
import glob
import os
import shutil

import cv2
import numpy as np
import skimage
import sklearn
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from pytorch_msssim import MS_SSIM, SSIM
from skimage import metrics
from sklearn.datasets import fetch_lfw_people
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    classification_report,
    confusion_matrix,
)
from sklearn.model_selection import train_test_split
from torch import cuda, nn
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, random_split
from torchvision import models, transforms, utils
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm, trange

%pylab inline
matplotlib.use("Agg")  # prevent plt memory leak

Populating the interactive namespace from numpy and matplotlib


# MISC

### Reproducibility

In [4]:
def set_seed(random_seed: int = 0):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global generator and
    Python's built-in ``random`` module, and puts cuDNN into deterministic
    mode with benchmarking disabled.

    Args:
        random_seed: Seed value applied to every generator.
    """
    # Imported locally under an alias: this notebook never imports the
    # standard-library `random`; the bare name `random` is supplied by
    # `%pylab inline` (NumPy's), so `random.seed(...)` would not seed
    # Python's own generator.
    import random as py_random

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # covers multi-GPU setups
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    py_random.seed(random_seed)


# Seed shared by every dataloader builder (passed to set_seed).
# A `global` statement at module level has no effect, so it is omitted.
SEED = 0

### save result

In [5]:
def save_weight(model, path: str, weight_file_name: str):
    """Write ``model.state_dict()`` to ``{path}/{weight_file_name}``.

    The directory ``path`` is created when missing, and a confirmation line
    is printed after the file has been written.
    """
    destination = f"{path}/{weight_file_name}"
    os.makedirs(path, exist_ok=True)
    torch.save(model.state_dict(), destination)
    print(f"[+] Saved {destination}")


def save_loss(losses, path: str, loss_file_name: str):
    """Write each entry of ``losses`` on its own line to ``{path}/{loss_file_name}``.

    The directory ``path`` is created when missing; an existing file is
    overwritten. A confirmation line is printed afterwards.
    """
    os.makedirs(path, exist_ok=True)
    destination = f"{path}/{loss_file_name}"
    with open(destination, "w+") as handle:
        handle.writelines("%s\n" % value for value in losses)
    print(f"[+] Saved {destination}")


def save_pil_image(pil_image, path, name):
    """Save ``pil_image`` as ``{path}/{name}``, creating ``path`` if needed."""
    os.makedirs(path, exist_ok=True)
    destination = f"{path}/{name}"
    pil_image.save(destination)


def save_merged_pil_image(pil_imgs, path, name, titles=None):
    """Render several PIL images side by side and save them as one figure.

    Args:
        pil_imgs: Sequence of PIL images, drawn left to right in a single row.
        path: Output directory; created if it does not exist.
        name: File name of the merged figure inside ``path``.
        titles: Optional sequence of per-image titles, indexed like ``pil_imgs``.
    """
    fontsize = 15
    fig = plt.figure(tight_layout=True)
    try:
        for i, pil_img in enumerate(pil_imgs):
            ax = fig.add_subplot(1, len(pil_imgs), i + 1)
            if titles:
                ax.set_title(titles[i], fontsize=fontsize)
            # Every mode other than RGB is drawn with a gray colormap.
            if pil_img.mode != "RGB":
                ax.imshow(pil_img, cmap="gray")
            else:
                ax.imshow(pil_img)
            ax.axis("off")

        os.makedirs(path, exist_ok=True)
        fig.savefig(
            f"{path}/{name}", transparent=True, bbox_inches="tight", pad_inches=0
        )
    finally:
        # Close in `finally` so a failed save cannot leave the figure open.
        # No plt.clf() afterwards: with no figure left, clf() would make
        # pyplot create a new empty one just to clear it.
        plt.close("all")  # prevent plt memory leak

# Data


## CIFAR10
- 50000 + 10000 = 60000
- **=> 40000 + (10000+10000)**
- train_size: 40000
- test_size: 20000


In [6]:
def build_cifar10_dataloader(batch_size):
    """Build cover/payload DataLoaders over a re-split CIFAR10.

    10000 images are moved at random from the official training split into
    the test split. Each resulting split is then halved at random into a
    cover (``c``) half and a payload (``p``) half.

    Args:
        batch_size: Batch size shared by all four loaders.

    Returns:
        Tuple ``(train_c_loader, train_p_loader, test_c_loader, test_p_loader)``.
        The training loaders shuffle, the test loaders do not, and no loader
        drops its last incomplete batch.
    """
    set_seed(SEED)
    data_root = "data/cifar10"
    to_tensor_64 = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize((64, 64))]
    )
    # Official train split first, then the official test split.
    official_train, official_test = (
        torchvision.datasets.CIFAR10(
            root=data_root, train=is_train, transform=to_tensor_64, download=True
        )
        for is_train in (True, False)
    )

    # Carve 10000 training images off and append them to the test data.
    moved_to_test = 10000
    train_data, extra_test_data = random_split(
        official_train, [len(official_train) - moved_to_test, moved_to_test]
    )
    test_data = ConcatDataset((extra_test_data, official_test))

    train_size, test_size = len(train_data), len(test_data)
    print(f"[+] load cifar10: train_size = {train_size}, test_size = {test_size}")

    # Halve each split into cover / payload subsets (train first, then test).
    train_halves = random_split(train_data, [train_size // 2, train_size // 2])
    test_halves = random_split(test_data, [test_size // 2, test_size // 2])

    train_loaders = [
        DataLoader(dataset=half, batch_size=batch_size, shuffle=True, drop_last=False)
        for half in train_halves
    ]
    test_loaders = [
        DataLoader(dataset=half, batch_size=batch_size, shuffle=False, drop_last=False)
        for half in test_halves
    ]
    return (*train_loaders, *test_loaders)

In [7]:
class ImageDataset(TensorDataset):
    """Dataset that loads image files on demand as 64x64 RGB tensors.

    Each item is ``(image_tensor, stem)``, where ``stem`` is the file name up
    to its first dot.

    Args:
        files: Sequence of image file paths.
    """

    def __init__(self, files):
        # TensorDataset.__init__ is not called: this class stores file paths
        # rather than tensors and overrides __getitem__ / __len__ below.
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize((64, 64))]
        )
        self.files = files

    @staticmethod
    def _file_stem(path):
        """Return the file name of ``path`` up to its first dot.

        Forward and backward slashes are both treated as directory
        separators, so the result is the same for POSIX-style paths and for
        the mixed-separator paths produced on Windows.
        """
        return path.replace("\\", "/").split("/")[-1].split(".")[0]

    def __getitem__(self, index):
        """Load, convert to RGB if needed, and transform the image at ``index``.

        ``index`` wraps around modulo the number of files.
        """
        path = self.files[index % len(self.files)]
        # The context manager closes the underlying file once the pixel data
        # has been turned into a tensor.
        with Image.open(path) as img:
            if img.mode != "RGB":
                img = img.convert("RGB")
            tensor = self.transform(img)
        return tensor, self._file_stem(path)

    def __len__(self):
        """Number of image files in the dataset."""
        return len(self.files)

## CelebA in Kaggle
- train_size: 182598
- test_size: 20000

In [8]:
def build_celeba_dataloader(batch_size, full=False):
    """Build cover/payload DataLoaders over the CelebA image folder.

    The sorted file list is split into train and test paths (20000 test
    paths, ``random_state=0``). Each split is then halved at random into a
    cover (``c``) half and a payload (``p``) half.

    Args:
        batch_size: Batch size shared by all four loaders.
        full: When False, only the first 40001 training paths are kept.

    Returns:
        Tuple ``(train_c_loader, train_p_loader, test_c_loader, test_p_loader)``.
        The training loaders shuffle, the test loaders do not, and every
        loader drops its last incomplete batch.
    """
    set_seed(SEED)
    image_dir = "data/celeba-dataset/img_align_celeba/img_align_celeba"
    all_paths = sorted(glob.glob(image_dir + "/*.*"))

    train_paths, test_paths = train_test_split(
        all_paths, test_size=20000, random_state=0, shuffle=True
    )

    if not full:
        train_paths = train_paths[:40001]  #! limit train size
    # An odd-length training list loses its last path so it halves evenly.
    if len(train_paths) % 2:
        train_paths = train_paths[:-1]
    train_size, test_size = len(train_paths), len(test_paths)
    print(f"[+] load celeba: train_size = {train_size}, test_size = {test_size}")

    train_c_dataset, train_p_dataset = random_split(
        ImageDataset(train_paths), [train_size // 2] * 2
    )
    test_c_dataset, test_p_dataset = random_split(
        ImageDataset(test_paths), [test_size // 2] * 2
    )

    def make_loader(subset, shuffle):
        # All four loaders differ only in their subset and shuffle flag.
        return DataLoader(
            dataset=subset, batch_size=batch_size, shuffle=shuffle, drop_last=True
        )

    return (
        make_loader(train_c_dataset, True),
        make_loader(train_p_dataset, True),
        make_loader(test_c_dataset, False),
        make_loader(test_p_dataset, False),
    )

## ImageNet in Kaggle
- train_size: 115000
- test_size: 20000

In [9]:
def build_imagenet_dataloader(batch_size, full=False):
    """Build cover/payload DataLoaders over the ImageNet100 image folders.

    Files matching ``data/imagenet100/*/*/*.*`` are sorted and split into
    train and test paths (20000 test paths, ``random_state=0``). Each split
    is then halved at random into a cover (``c``) half and a payload (``p``)
    half.

    Args:
        batch_size: Batch size shared by all four loaders.
        full: When False, only the first 40001 training paths are kept.

    Returns:
        Tuple ``(train_c_loader, train_p_loader, test_c_loader, test_p_loader)``.
        The training loaders shuffle, the test loaders do not, and every
        loader drops its last incomplete batch.
    """
    set_seed(SEED)
    # ImageNet100 - A Sample of ImageNet Classes
    dataset_root = "data/imagenet100"
    candidate_paths = sorted(glob.glob(dataset_root + "/*/*/*.*"))

    train_paths, test_paths = train_test_split(
        candidate_paths, test_size=20000, random_state=0, shuffle=True
    )

    if not full:
        train_paths = train_paths[:40001]  #! limit train size
    # Keep the training list even-sized so it halves without a remainder.
    if len(train_paths) % 2:
        train_paths = train_paths[:-1]
    train_size = len(train_paths)
    test_size = len(test_paths)
    print(f"[+] load imagenet: train_size = {train_size}, test_size = {test_size}")

    train_halves = random_split(
        ImageDataset(train_paths), [train_size // 2, train_size // 2]
    )
    test_halves = random_split(
        ImageDataset(test_paths), [test_size // 2, test_size // 2]
    )

    loader_options = {"batch_size": batch_size, "drop_last": True}
    train_c_loader, train_p_loader = (
        DataLoader(dataset=half, shuffle=True, **loader_options)
        for half in train_halves
    )
    test_c_loader, test_p_loader = (
        DataLoader(dataset=half, shuffle=False, **loader_options)
        for half in test_halves
    )

    return train_c_loader, train_p_loader, test_c_loader, test_p_loader

# Define model


In [10]:
class EDS(nn.Module):
    """Encoder-decoder network that embeds a payload image into a cover image.

    ``forward(c, p)`` takes a cover batch ``c`` (3 channels) and a payload
    batch ``p`` (1 channel) and returns ``(encoder_output, decoder_output)``:
    the 3-channel output of the encoder and the 1-channel output the decoder
    computes from it.
    """

    def __init__(self):
        super().__init__()
        self.name = "EDS"
        # Encoder first, then decoder, then weight init: this order fixes
        # both the state_dict key order and the sequence of random draws.
        self.define_encoder()
        self.define_decoder()
        self.weight_init_xavier_uniform()

    def weight_init_xavier_uniform(self):
        """Apply Xavier-uniform initialisation to every Conv2d weight (biases untouched)."""
        conv_layers = (
            module for module in self.modules() if isinstance(module, torch.nn.Conv2d)
        )
        for conv in conv_layers:
            torch.nn.init.xavier_uniform_(conv.weight)

    def make_seq(self, in_channels=16, out_channels=16, kernel_size=3, padding=1):
        """Return ``Sequential(Conv2d, ReLU)`` with the given convolution settings."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        return nn.Sequential(conv, nn.ReLU())

    def define_decoder(self):
        """Create ``self.decoder``: six conv+ReLU stages and a final 1x1 conv to 1 channel."""
        channel_plan = [(3, 16), (16, 16), (16, 8), (8, 8), (8, 3), (3, 3)]
        stages = [
            self.make_seq(in_channels=n_in, out_channels=n_out)
            for n_in, n_out in channel_plan
        ]
        stages.append(nn.Conv2d(3, 1, kernel_size=1))
        self.decoder = nn.ModuleList(stages)

    def define_encoder(self):
        """Create ``self.host_branch`` (cover path) and ``self.guest_branch`` (payload path)."""
        # Host branch: stages fed by a concatenation take 32 input channels.
        host_inputs = (3, 32, 16, 32, 16, 32, 16)
        host_stages = [self.make_seq(in_channels=n_in) for n_in in host_inputs]
        host_stages.append(self.make_seq(in_channels=32, kernel_size=1, padding=0))
        host_stages.append(self.make_seq(out_channels=8, kernel_size=1, padding=0))
        host_stages.append(nn.Conv2d(8, 3, kernel_size=1, padding=0))
        self.host_branch = nn.ModuleList(host_stages)

        # Guest branch: one 1->16 stage followed by six 16->16 stages.
        guest_stages = [self.make_seq(1)] + [self.make_seq() for _ in range(6)]
        self.guest_branch = nn.ModuleList(guest_stages)

    def forward(self, c, p):
        """Encode ``p`` into ``c`` and decode it back; returns ``(encoder_output, decoder_output)``."""
        ##### Encoder #####
        host_feat, guest_feat = c, p
        # The first seven stages advance both branches together. After each
        # even-indexed stage the guest features are concatenated onto the
        # host features (16 + 16 = 32 channels) as input to the next host stage.
        for stage in range(7):
            guest_feat = self.guest_branch[stage](guest_feat)
            host_feat = self.host_branch[stage](host_feat)
            if stage % 2 == 0:
                host_feat = torch.cat((host_feat, guest_feat), 1)
        # Remaining host stages (1x1 convolutions) reduce 32 -> 16 -> 8 -> 3.
        for layer in self.host_branch[7:]:
            host_feat = layer(host_feat)
        encoder_output = host_feat

        ##### Decoder #####
        decoder_output = encoder_output
        for layer in self.decoder:
            decoder_output = layer(decoder_output)

        return encoder_output, decoder_output

# Loss functions

In [11]:
# For L2 regularization
# Instead of this, we can use weight_decay in optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
def get_MSE_from_moduleList(moduleList):
    """Return the mean squared value over all parameters of ``moduleList``.

    Computed as (sum of every squared parameter element) divided by (total
    number of parameter elements). The result lives on the parameters' own
    device and stays attached to the autograd graph.

    Args:
        moduleList: Iterable of modules whose ``parameters()`` are pooled.

    Returns:
        0-dim tensor. When there are no parameters at all, a zero tensor on
        CUDA if available, otherwise on CPU.
    """
    params = [param for layer in moduleList for param in layer.parameters()]
    if not params:
        # Nothing to average; avoid a 0 / 0 division.
        device = "cuda" if cuda.is_available() else "cpu"
        return torch.zeros((), dtype=torch.get_default_dtype(), device=device)

    # Sum raw squares (not per-tensor means) so that dividing by the element
    # count normalises exactly once, and build the sum from the parameters
    # themselves so no separate device guess is needed.
    squared_total = sum(torch.sum(torch.square(param)) for param in params)
    element_count = sum(param.nelement() for param in params)
    return squared_total / element_count

In [12]:
def get_loss(coefficients, c, p, e_out, d_out):
    """Weighted sum of the cover and payload reconstruction errors.

    Args:
        coefficients (dict): weights, read under the keys "alpha" and "beta"
        c (B 3 W H): cover image batch
        p (B 1 W H): payload image batch
        e_out (B 3 W H): encoded (stego) image batch, compared against ``c``
        d_out (B 1 W H): decoded image batch, compared against ``p``

    Returns:
        ``alpha * MSE(c, e_out) + beta * MSE(p, d_out)``
    """
    cover_error = F.mse_loss(c, e_out)
    payload_error = F.mse_loss(p, d_out)
    # There is no explicit weight-penalty term here: the L2 penalty that
    # get_MSE_from_moduleList could provide is left to the optimizer's
    # weight_decay instead.
    return coefficients["alpha"] * cover_error + coefficients["beta"] * payload_error

# Train


In [13]:
def train_and_save(
    model,
    optimizer,
    train_c_dataloader,
    train_p_dataloader,
    epochs,
    period,
    coefficients,
    weight_path,
):
    """Train ``model`` on paired cover/payload batches, checkpointing periodically.

    Every epoch draws the same number of batches from both loaders (the
    shorter loader decides how many). Payload batches are converted to
    grayscale before being fed to the model. Tensors are moved to the
    module-level ``device``.

    Args:
        model: Network called as ``model(c, p)`` returning ``(e_out, d_out)``.
        optimizer: Optimizer over ``model``'s parameters.
        train_c_dataloader: Loader of cover batches (first item = images).
        train_p_dataloader: Loader of payload batches (first item = images).
        epochs: Number of epochs to run.
        period: Checkpoint interval; weights and the loss history are saved
            whenever ``epoch % period == 0``.
        coefficients: Loss weights forwarded to ``get_loss``.
        weight_path: Checkpoint directory; an empty value disables saving.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: If either loader yields no batches.
    """
    batchs = min(len(train_c_dataloader), len(train_p_dataloader))
    if batchs == 0:
        # Fail early with a clear message instead of a ZeroDivisionError
        # when the epoch mean is computed.
        raise ValueError("train_and_save: a training dataloader yields no batches")

    model.train()
    losses = []

    for epoch in trange(1, epochs + 1, desc="Total"):
        loss_item = 0.0
        iter_c = iter(train_c_dataloader)
        iter_p = iter(train_p_dataloader)
        for _ in range(batchs):
            c = next(iter_c)[0].to(device)
            p = next(iter_p)[0].to(device)
            # (B, 1, H, W); verified to render correctly via to_pil_image
            p = transforms.functional.rgb_to_grayscale(p)

            ### train step
            optimizer.zero_grad()
            # Call the module itself (not .forward) so registered hooks run.
            e_out, d_out = model(c, p)
            loss = get_loss(coefficients, c, p, e_out, d_out)
            loss.backward()
            optimizer.step()
            loss_item += loss.item()

        losses.append(loss_item / batchs)
        # checkpoint
        if weight_path and epoch % period == 0:
            save_weight(model, weight_path, f"{epoch}.pickle")
            save_loss(losses, weight_path, f"{epoch}.txt")

    return losses

# Test (Infer)


In [14]:
def _to_uint8_image(tensor):
    """Convert a (C, H, W) tensor to an (H, W, C) uint8 NumPy array on the CPU.

    Values are scaled by 255, offset by 0.5, clamped to [0, 255] and then
    cast to uint8. ``mul`` returns a new tensor, so the in-place steps never
    touch the input.
    """
    return (
        tensor.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float("nan")


def infer_from_saved_weight(
    model, test_c_dataloader, test_p_dataloader, weight_file, result_path
):
    """Load a checkpoint, run inference on paired test batches, save images and metrics.

    For every processed image the following are written under ``result_path``:
    ``merged/{batch}-{i}.png`` (cover, payload, encoder output and decoder
    output side by side), ``c/{batch}-{i}.png`` and ``e_o/{batch}-{i}.png``.
    Mean PSNR/SSIM per batch are appended to ``metrics-per-batch.txt`` and
    overall means are written to ``metrics-total.txt``.

    PSNR means skip infinite values (identical image pairs). SSIM means are
    taken over the images actually processed. A mean over zero values is
    reported as NaN.

    Args:
        model: Network whose weights are replaced by those in ``weight_file``.
        test_c_dataloader: Loader of cover batches (first item = images).
        test_p_dataloader: Loader of payload batches (first item = images).
        weight_file: Path of the saved ``state_dict``.
        result_path: Output directory; created if missing.
    """
    with torch.no_grad():
        model.load_state_dict(torch.load(weight_file, map_location=device))
        model.eval()

        # The metric files live directly in result_path, so make sure it
        # exists even when no image is ever saved.
        os.makedirs(result_path, exist_ok=True)

        batchs = min(len(test_c_dataloader), len(test_p_dataloader))

        iter_c = iter(test_c_dataloader)
        iter_p = iter(test_p_dataloader)
        total_e_psnr = 0.0
        total_d_psnr = 0.0
        total_e_ssim = 0.0
        total_d_ssim = 0.0
        total_e = 0  # images with a finite cover/stego PSNR
        total_d = 0  # images with a finite payload/decoded PSNR
        total_imgs = 0  # images actually processed
        for batch in trange(batchs, desc="Total"):
            e_psnr = 0.0
            e_cnt = 0
            d_psnr = 0.0
            d_cnt = 0
            e_ssim = 0.0
            d_ssim = 0.0
            c = next(iter_c)[0].to(device)
            p = next(iter_p)[0].to(device)
            # verified to render correctly via to_pil_image
            p = transforms.functional.rgb_to_grayscale(p)
            e_o, d_o = model(c, p)  # infer
            # Size of this batch; the last one may be shorter than batch_size.
            n_imgs = c.shape[0]
            for i in range(n_imgs):
                c_np = _to_uint8_image(c[i])
                p_np = _to_uint8_image(p[i]).squeeze()
                e_o_np = _to_uint8_image(e_o[i])
                d_o_np = _to_uint8_image(d_o[i]).squeeze()
                c_img = Image.fromarray(c_np)
                p_img = Image.fromarray(p_np)
                e_o_img = Image.fromarray(e_o_np)
                d_o_img = Image.fromarray(d_o_np)
                # save result images
                save_merged_pil_image(
                    [c_img, p_img, e_o_img, d_o_img],
                    f"{result_path}/merged",
                    f"{batch}-{i}.png",
                )  ## merged image
                ## c and e_o separately (p and d_o are not saved)
                save_pil_image(c_img, f"{result_path}/c", f"{batch}-{i}.png")
                save_pil_image(
                    e_o_img, f"{result_path}/e_o", f"{batch}-{i}.png"
                )  # encoded cover (stego) image

                # calc psnr, ssim
                psnr = skimage.metrics.peak_signal_noise_ratio(c_np, e_o_np)
                if psnr != np.inf:
                    e_psnr += psnr
                    e_cnt += 1
                psnr = skimage.metrics.peak_signal_noise_ratio(p_np, d_o_np)
                if psnr != np.inf:
                    d_psnr += psnr
                    d_cnt += 1
                # NOTE(review): `multichannel` looks like the legacy spelling
                # of `channel_axis`; confirm against the installed
                # scikit-image before dropping either argument.
                e_ssim += skimage.metrics.structural_similarity(
                    c_np,
                    e_o_np,
                    multichannel=True,
                    channel_axis=2,  # same value as averaging the 3 per-channel SSIMs
                )
                d_ssim += skimage.metrics.structural_similarity(p_np, d_o_np)
            total_e += e_cnt
            total_d += d_cnt
            total_imgs += n_imgs
            # save psnr, ssim
            with open(f"{result_path}/metrics-per-batch.txt", "a+") as f:
                f.write(f"# {batch}\n")
                f.write(f"e_psnr: {_mean_or_nan(e_psnr, e_cnt)}\n")
                f.write(f"d_psnr: {_mean_or_nan(d_psnr, d_cnt)}\n")
                f.write(f"e_ssim: {_mean_or_nan(e_ssim, n_imgs)}\n")
                f.write(f"d_ssim: {_mean_or_nan(d_ssim, n_imgs)}\n")
            total_e_psnr += e_psnr
            total_d_psnr += d_psnr
            total_e_ssim += e_ssim
            total_d_ssim += d_ssim
        with open(result_path + "/metrics-total.txt", "w+") as f:
            f.write(f"total_e_psnr: {_mean_or_nan(total_e_psnr, total_e)}\n")
            f.write(f"total_d_psnr: {_mean_or_nan(total_d_psnr, total_d)}\n")
            f.write(f"total_e_ssim: {_mean_or_nan(total_e_ssim, total_imgs)}\n")
            f.write(f"total_d_ssim: {_mean_or_nan(total_d_ssim, total_imgs)}\n")

# Tree

- weight/{model.name}/{train_set}/{epoch}.pickle
    - cifar10
        - 50.pickle
        - 100.pickle
        - 150.pickle
    - celeba
        - 50.pickle
        - 100.pickle
        - 150.pickle
    - imagenet
        - 50.pickle
        - 100.pickle
        - 150.pickle
- result/{model.name}/{train_set}/{test_set}/{epoch}/{[c, e_o, merged]}/*.png (p and d_o images are not saved by default)
    - cifar10
        - cifar10
            - 150
                - c
                - e_o
                - merged

# Train and infer with dataset (CIFAR10, CelebA, ImageNet)

In [15]:
def train(train_set):
    """Train a fresh EDS model on ``train_set`` and save its checkpoints.

    ``train_set`` must be "cifar10", "celeba" or "imagenet"; any other value
    prints an error and returns without training. Hyper-parameters
    (``batch_size``, ``lr``, ``lambd``, ``epochs``, ``period``,
    ``coefficients``) and ``device`` are read from module-level globals.
    Checkpoints go to ``./weight/{model.name}/{train_set}``.
    """
    loader_builders = {
        "cifar10": build_cifar10_dataloader,
        "celeba": build_celeba_dataloader,
        "imagenet": build_imagenet_dataloader,
    }
    if train_set not in loader_builders:
        print("[-] invalid train_set")
        return

    # Only the two training loaders (cover, payload) are needed here.
    train_c_dataloader, train_p_dataloader = loader_builders[train_set](batch_size)[:2]

    model = EDS().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=lambd)
    weight_path = f"./weight/{model.name}/{train_set}"

    train_and_save(
        model,
        optimizer,
        train_c_dataloader,
        train_p_dataloader,
        epochs,
        period,
        coefficients,
        weight_path,
    )
    del train_c_dataloader, train_p_dataloader, model
    print(f"[+] train {train_set} done (epochs: {epochs})")

In [16]:
def infer(train_set, infer_set):
    """Evaluate every checkpoint trained on ``train_set`` against ``infer_set``.

    ``infer_set`` must be "cifar10", "celeba" or "imagenet"; any other value
    prints an error and returns. Checkpoints are read from
    ``./weight/{model.name}/{train_set}/{epoch}.pickle`` and results are
    written to ``./result/{model.name}/{train_set}/{infer_set}/{epoch}``.
    ``batch_size``, ``epochs``, ``period`` and ``device`` are read from
    module-level globals.
    """
    loader_builders = {
        "cifar10": build_cifar10_dataloader,
        "celeba": build_celeba_dataloader,
        "imagenet": build_imagenet_dataloader,
    }
    if infer_set not in loader_builders:
        print("[-] invalid test_set")
        return

    # Only the two test loaders (cover, payload) are needed here.
    test_c_dataloader, test_p_dataloader = loader_builders[infer_set](batch_size)[2:]

    model = EDS().to(device)
    weight_path = f"./weight/{model.name}/{train_set}"
    # Checkpoints are written only at epochs that are multiples of `period`,
    # so start from the largest such multiple not exceeding `epochs` and walk
    # down. Starting from `epochs` itself would name files that were never
    # saved whenever `epochs` is not a multiple of `period`.
    last_checkpoint = epochs - epochs % period
    for e in range(last_checkpoint, 0, -period):
        result_path = f"./result/{model.name}/{train_set}/{infer_set}/{e}"
        infer_from_saved_weight(
            model,
            test_c_dataloader,
            test_p_dataloader,
            f"{weight_path}/{e}.pickle",
            result_path,
        )
        print(f"[+] infer {infer_set} with {model.name}-{train_set} (epochs: {e}) done")
    del test_c_dataloader, test_p_dataloader, model

In [17]:
def train_and_infer():
    """Train on each configured dataset, then evaluate it on all of them.

    For every name in the module-level ``datasets`` list a model is trained
    on it and inference is then run against every dataset in the list,
    including the one it was trained on.
    """
    for trained_on in datasets:
        train(trained_on)
        for evaluated_on in datasets:
            infer(trained_on, evaluated_on)
    print("[+] weight saved")

# Global variable

In [18]:
# Run-wide configuration, read as module-level globals by the functions above.
# (A `global` statement at module level has no effect, so none is used.)
device = "cuda" if cuda.is_available() else "cpu"
batch_size = 32
epochs = 150  # total training epochs
period = 50  # checkpoint / evaluation interval in epochs
lr = 1e-4  # Adam learning rate
lambd = 1e-4  # passed to Adam as weight_decay
coefficients = {
    "alpha": 1,
    "beta": 1,
}

datasets = ["cifar10", "celeba", "imagenet"]

In [None]:
# Run: report the selected device, then train on each configured dataset and
# evaluate the saved checkpoints against every dataset (see train_and_infer).
print("[+] Device:", device)
train_and_infer()
print("[+] done")

[+] Device: cuda
Files already downloaded and verified
Files already downloaded and verified
[+] load cifar10: train_size = 40000, test_size = 20000


Total:   0%|          | 0/313 [00:00<?, ?it/s]