# EDS

In [3]:
import torch
from torch import nn, cuda
from tqdm.auto import trange, tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split, ConcatDataset
from torchvision import datasets, models, transforms, utils
from torchvision.transforms.functional import to_pil_image
import numpy as np
from PIL import Image
import skimage
from skimage import metrics
import os
import sklearn
from sklearn.datasets import fetch_lfw_people
from pytorch_msssim import SSIM, MS_SSIM
import shutil

%pylab inline
matplotlib.use('Agg') # prevent plt memory leak


Populating the interactive namespace from numpy and matplotlib


# For reproducibility


In [4]:
# `%pylab inline` star-imports numpy, which rebinds the bare name `random`
# to `numpy.random`; import the stdlib module under its own name so that
# Python's RNG is actually seeded.
import random as py_random

random_seed = 0

# PyTorch CPU and CUDA RNGs
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)  # if use multi-GPU
# Deterministic cuDNN kernels, no autotuning
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NumPy's global RNG
np.random.seed(random_seed)
# Python's stdlib RNG
py_random.seed(random_seed)


# Data


## CIFAR10
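- Official split: 50000 train + 10000 test = 60000 images
- **Re-split => 40000 train + (10000 held out from the official train set + 10000 official test) test**
- train_size: 40000
- test_size: 20000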


In [5]:
def build_cifar10_dataloader(batch_size, seed=0):
    """Build cover/payload DataLoaders for CIFAR-10 resized to 64x64.

    The official 50000-image train split is divided (fixed generator seed 0)
    into 40000 train images and 10000 extra test images. The extra images are
    concatenated with the official 10000 test images, giving 20000 test images.
    Train and test sets are then each split into two halves: one half
    supplies cover images, the other payload images.

    Args:
        batch_size (int): batch size of all four loaders (drop_last=True).
        seed (int): seed for the cover/payload half splits. A dedicated
            generator makes the pairing independent of how much of the global
            torch RNG was consumed before this call (e.g. by training), so
            every model is evaluated on the same cover/payload test pairs.

    Returns:
        (train_c_loader, train_p_loader, test_c_loader, test_p_loader);
        the train loaders shuffle, the test loaders do not.
    """
    root = 'data/cifar10'
    # Train and test use identical preprocessing (no augmentation).
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((64, 64)),
            # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )

    train_data = datasets.CIFAR10(
        root=root, train=True, transform=transform, download=True
    )
    test_data = datasets.CIFAR10(
        root=root, train=False, transform=transform, download=True
    )

    # Move 10000 official train images into the test set. The fixed seed keeps
    # the train/test boundary identical on every call.
    n_official_train = len(train_data)
    train_data, extra_test_data = random_split(
        train_data, [n_official_train - 10000, 10000], generator=torch.Generator().manual_seed(0)
    )
    test_data = ConcatDataset((extra_test_data, test_data))

    train_size = len(train_data)
    test_size = len(test_data)
    print(f"[+] load cifar10: train_size = {train_size}, test_size = {test_size}")

    # Cover/payload halves. For an odd size the second half takes the extra item.
    split_gen = torch.Generator().manual_seed(seed)
    train_c_data, train_p_data = random_split(
        train_data, [train_size // 2, train_size - train_size // 2], generator=split_gen
    )
    test_c_data, test_p_data = random_split(
        test_data, [test_size // 2, test_size - test_size // 2], generator=split_gen
    )

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True
        )

    return (
        make_loader(train_c_data, True),
        make_loader(train_p_data, True),
        make_loader(test_c_data, False),
        make_loader(test_p_data, False),
    )

## CelebA in Kaggle
- train_size: 182598 with `full=True`; limited to 40000 by default
- test_size: 20000

In [6]:
def build_celeba_dataloader(batch_size, full=False, seed=0):
    """Build cover/payload DataLoaders for the Kaggle CelebA aligned images, resized to 64x64.

    All image files under `root` are sorted and split (random_state=0) into
    20000 test images and the remaining training images. Unless `full` is
    True, the training list is capped at 40000 images, and it is trimmed to
    an even length. Train and test sets are then each split in half into
    cover and payload subsets.

    Args:
        batch_size (int): batch size of all four loaders (drop_last=True).
        full (bool): use every training image instead of the first 40000.
        seed (int): seed for the cover/payload half splits, so the pairing
            does not depend on the global torch RNG state at call time.

    Returns:
        (train_c_loader, train_p_loader, test_c_loader, test_p_loader);
        the train loaders shuffle, the test loaders do not.
    """
    import glob
    from sklearn.model_selection import train_test_split

    root = "data/celeba-dataset/img_align_celeba/img_align_celeba"

    class ImageDataset(torch.utils.data.Dataset):
        """Image file paths -> (3, 64, 64) float tensors, paired with a dummy label 0."""

        def __init__(self, files):
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize((64, 64)),
                    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                ]
            )
            self.files = files

        def __getitem__(self, index):
            # `with` closes the file handle. convert('RGB') guards against
            # grayscale/RGBA files (a no-op copy for RGB ones).
            with Image.open(self.files[index % len(self.files)]) as img:
                tensor = self.transform(img.convert('RGB'))
            return tensor, 0

        def __len__(self):
            return len(self.files)

    train_paths, test_paths = train_test_split(
        sorted(glob.glob(root + "/*.*")), test_size=20000, random_state=0, shuffle=True
    )

    if not full:
        train_paths = train_paths[:40001]  #! limit train size
    # Keep the training set even so it splits into two equal halves.
    if len(train_paths) % 2:
        train_paths = train_paths[:-1]
    train_size = len(train_paths)
    test_size = len(test_paths)
    print(f"[+] load celeba: train_size = {train_size}, test_size = {test_size}")

    train_dataset = ImageDataset(train_paths)
    test_dataset = ImageDataset(test_paths)

    # Cover/payload halves. For an odd size the second half takes the extra item.
    split_gen = torch.Generator().manual_seed(seed)
    train_c_dataset, train_p_dataset = random_split(
        train_dataset, [train_size // 2, train_size - train_size // 2], generator=split_gen
    )
    test_c_dataset, test_p_dataset = random_split(
        test_dataset, [test_size // 2, test_size - test_size // 2], generator=split_gen
    )

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True
        )

    return (
        make_loader(train_c_dataset, True),
        make_loader(train_p_dataset, True),
        make_loader(test_c_dataset, False),
        make_loader(test_p_dataset, False),
    )

## ImageNet in Kaggle
- train_size: 115000 with `full=True`; limited to 40000 by default
- test_size: 20000

In [7]:
def build_imagenet_dataloader(batch_size, full=False, seed=0):
    """Build cover/payload DataLoaders for ImageNet100 (Kaggle sample), resized to 64x64.

    All image files matching `root/*/*/*.*` are sorted and split
    (random_state=0) into 20000 test images and the remaining training
    images. Unless `full` is True, the training list is capped at 40000
    images, and it is trimmed to an even length. Train and test sets are then
    each split in half into cover and payload subsets.

    Args:
        batch_size (int): batch size of all four loaders (drop_last=True).
        full (bool): use every training image instead of the first 40000.
        seed (int): seed for the cover/payload half splits, so the pairing
            does not depend on the global torch RNG state at call time.

    Returns:
        (train_c_loader, train_p_loader, test_c_loader, test_p_loader);
        the train loaders shuffle, the test loaders do not.
    """
    import glob
    from sklearn.model_selection import train_test_split

    # ImageNet100 - A Sample of ImageNet Classes
    root = "data/imagenet100"

    class ImageDataset(torch.utils.data.Dataset):
        """Image file paths -> (3, 64, 64) float tensors, paired with a dummy label 0."""

        def __init__(self, files):
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize((64, 64)),
                    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                ]
            )
            self.files = files

        def __getitem__(self, index):
            # `with` closes the file handle. Some files are not RGB
            # (grayscale/CMYK), so always convert.
            with Image.open(self.files[index % len(self.files)]) as img:
                tensor = self.transform(img.convert('RGB'))
            return tensor, 0

        def __len__(self):
            return len(self.files)

    train_paths, test_paths = train_test_split(
        sorted(glob.glob(root + "/*/*/*.*")), test_size=20000, random_state=0, shuffle=True
    )

    if not full:
        train_paths = train_paths[:40001]  #! limit train size
    # Keep the training set even so it splits into two equal halves.
    if len(train_paths) % 2:
        train_paths = train_paths[:-1]
    train_size = len(train_paths)
    test_size = len(test_paths)
    print(f"[+] load imagenet: train_size = {train_size}, test_size = {test_size}")

    train_dataset = ImageDataset(train_paths)
    test_dataset = ImageDataset(test_paths)

    # Cover/payload halves. For an odd size the second half takes the extra item.
    split_gen = torch.Generator().manual_seed(seed)
    train_c_dataset, train_p_dataset = random_split(
        train_dataset, [train_size // 2, train_size - train_size // 2], generator=split_gen
    )
    test_c_dataset, test_p_dataset = random_split(
        test_dataset, [test_size // 2, test_size - test_size // 2], generator=split_gen
    )

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True
        )

    return (
        make_loader(train_c_dataset, True),
        make_loader(train_p_dataset, True),
        make_loader(test_c_dataset, False),
        make_loader(test_p_dataset, False),
    )

# Define model


In [8]:
class EDS(nn.Module):
    """Encoder/decoder steganography network.

    The encoder hides a 1-channel payload image inside a 3-channel cover
    image. It has a "host" branch on the cover and a "guest" branch on the
    payload, and the guest features are concatenated into the host branch
    every second layer. The decoder recovers the payload from the encoder
    output.

    forward(c, p):
        c: cover batch (3 channels), input of the host branch
        p: payload batch (1 channel), input of the guest branch
    Returns:
        (encoder_output, decoder_output): the 3-channel stego image and the
        1-channel decoded payload.
    """

    def __init__(self):
        super().__init__()
        self.name = "EDS"
        # Build the encoder before the decoder: module construction order fixes
        # the order in which default weight inits draw from the RNG.
        self.define_encoder()
        self.define_decoder()

    def make_seq(self, in_channels=16, out_channels=16, kernel_size=3, padding=1):
        """Conv2d followed by ReLU."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        return nn.Sequential(conv, nn.ReLU())

    def define_decoder(self):
        """3 -> 16 -> 16 -> 8 -> 8 -> 3 -> 3 channels (conv+ReLU), then a 1x1 conv to 1 channel."""
        specs = [
            dict(in_channels=3),
            dict(),
            dict(out_channels=8),
            dict(in_channels=8, out_channels=8),
            dict(in_channels=8, out_channels=3),
            dict(in_channels=3, out_channels=3),
        ]
        layers = [self.make_seq(**spec) for spec in specs]
        layers.append(nn.Conv2d(3, 1, kernel_size=1))
        self.decoder = nn.ModuleList(layers)

    def define_encoder(self):
        """Create host_branch (10 layers) and guest_branch (7 layers)."""
        # Host branch: the layers at indices 1, 3, 5, 7 consume host+guest
        # features (16 + 16 = 32 channels).
        host_specs = [dict(in_channels=3)]
        for _ in range(3):
            host_specs.append(dict(in_channels=32))
            host_specs.append(dict())
        host_specs.append(dict(in_channels=32, kernel_size=1, padding=0))
        host_specs.append(dict(out_channels=8, kernel_size=1, padding=0))
        host_layers = [self.make_seq(**spec) for spec in host_specs]
        host_layers.append(nn.Conv2d(8, 3, kernel_size=1, padding=0))
        self.host_branch = nn.ModuleList(host_layers)

        # Guest branch: 1 -> 16 channels, then six 16 -> 16 layers.
        guest_layers = [self.make_seq(in_channels=1)]
        guest_layers += [self.make_seq() for _ in range(6)]
        self.guest_branch = nn.ModuleList(guest_layers)

    def forward(self, c, p):
        host, guest = self.host_branch, self.guest_branch

        ##### Encoder #####
        # Layers 1-7 run both branches side by side. The output of every odd
        # layer (1, 3, 5, 7) is concatenated with the guest features before it
        # enters the next host layer.
        h = host[0](c)
        g = guest[0](p)
        for k in range(1, 7):
            h_in = torch.cat((h, g), 1) if k % 2 == 1 else h
            g = guest[k](g)
            h = host[k](h_in)
        # Layers 8-10: host branch only, starting from the last concatenation.
        h = host[7](torch.cat((h, g), 1))
        h = host[8](h)
        encoder_output = host[9](h)

        ##### Decoder #####
        decoder_output = encoder_output
        for layer in self.decoder:
            decoder_output = layer(decoder_output)

        return encoder_output, decoder_output


## Xavier initialization


In [9]:
def weight_init_xavier_uniform(submodule):
    """Re-initialise a Conv2d's weight with Xavier/Glorot uniform; other modules are left untouched.

    Meant for `model.apply(weight_init_xavier_uniform)`. Biases keep their
    default initialisation (filling them with 0.01 was tried and disabled).
    """
    if not isinstance(submodule, torch.nn.Conv2d):
        return
    torch.nn.init.xavier_uniform_(submodule.weight)

# model.apply(weight_init_xavier_uniform)

# Train


In [10]:
def get_MSE_from_moduleList(moduleList):
    """Mean squared value of every parameter element in `moduleList`.

    Computed as sum(w ** 2 over all parameter elements) / (number of
    parameter elements), so each weight counts once regardless of which
    tensor it lives in.

    Args:
        moduleList (iterable of nn.Module): e.g. an nn.ModuleList of layers.

    Returns:
        0-dim tensor on the parameters' device, differentiable w.r.t. them.
        Returns 0.0 if there are no parameters.
    """
    params = [w for layer in moduleList for w in layer.parameters()]
    if not params:
        return torch.tensor(0.0, dtype=torch.get_default_dtype())
    # Sum of squares (not of per-tensor means): the count is divided out once.
    sq_sum = sum(torch.sum(torch.square(w)) for w in params)
    n_elems = sum(w.nelement() for w in params)
    return sq_sum / n_elems

In [11]:
def get_loss(model, coefficients, c, p, e_out, d_out):
    """Total EDS training loss.

    loss = alpha * MSE(c, e_out) + beta * MSE(p, d_out) + lambd * w_loss
    where w_loss = (W_host + W_guest) / 2 + W_decoder, and each W_* is
    get_MSE_from_moduleList() of the corresponding sub-network.

    Args:
        model (EDS): model whose host_branch, guest_branch and decoder
            parameters enter the weight term.
        coefficients (dict): weights with keys "alpha", "beta", "lambd".
        c (B, 3, H, W): cover image batch
        p (B, 1, H, W): payload image batch
        e_out (B, 3, H, W): encoded (stego) image batch
        d_out (B, 1, H, W): decoded payload batch

    Returns:
        0-dim loss tensor.
    """
    # Stego image should look like the cover; decoded image like the payload.
    e_loss = F.mse_loss(e_out, c)
    d_loss = F.mse_loss(d_out, p)

    # Weight regulariser: the encoder's two branches are averaged.
    w_d = get_MSE_from_moduleList(model.decoder)
    w_h = get_MSE_from_moduleList(model.host_branch)
    w_g = get_MSE_from_moduleList(model.guest_branch)
    w_loss = (w_h + w_g) / 2 + w_d

    return (
        coefficients["alpha"] * e_loss
        + coefficients["beta"] * d_loss
        + coefficients["lambd"] * w_loss
    )

In [12]:
def save_weight(
    model, path, weight_file_name, loss_file_name, train_losses
):
    """Save `model.state_dict()` and the loss history into directory `path`.

    Creates `path` (and parents) if needed. Writes the state dict to
    `{path}/{weight_file_name}` with torch.save, and writes `train_losses`
    one value per line to `{path}/{loss_file_name}`, overwriting that file.

    Raises:
        OSError: if `path` cannot be created (e.g. it exists as a file).
    """
    os.makedirs(path, exist_ok=True)

    print(f"[+] Saving {path}/{weight_file_name}")
    torch.save(model.state_dict(), f"{path}/{weight_file_name}")
    with open(f"{path}/{loss_file_name}", "w") as f:
        f.writelines(f"{loss}\n" for loss in train_losses)


In [13]:
def train_and_save(
    model,
    optimizer,
    train_c_loader,
    train_p_loader,
    epochs,
    period,
    coefficients,
    weight_path,
):
    """Train `model` on (cover, payload) batch pairs and checkpoint every `period` epochs.

    Each step takes one cover batch from `train_c_loader` and one payload
    batch from `train_p_loader`, converts the payload to grayscale and
    minimises get_loss(). When `weight_path` is non-empty and `period` is
    non-zero, every `period`-th epoch saves `{epoch}.pickle` (state dict) and
    `{epoch}.txt` (loss history so far) into `weight_path` via save_weight().

    Returns:
        list[float]: mean training loss of each epoch.

    Raises:
        ValueError: if either loader yields no batches.
    """
    model.train()
    # Send batches to wherever the model lives instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = "cuda" if cuda.is_available() else "cpu"

    batchs = min(len(train_c_loader), len(train_p_loader))
    if batchs == 0:
        raise ValueError(
            "train_and_save: a train loader yields no batches "
            "(dataset smaller than batch_size with drop_last=True?)"
        )

    train_losses = []
    for epoch in trange(1, epochs + 1, desc="Total"):
        train_loss = 0.0
        # zip() stops at the shorter loader, i.e. after exactly `batchs` steps.
        for c_batch, p_batch in zip(train_c_loader, train_p_loader):
            c = c_batch[0].to(device)
            # RGB payload -> (B, 1, H, W) grayscale
            p = transforms.functional.rgb_to_grayscale(p_batch[0].to(device))

            ### train step
            optimizer.zero_grad()
            e_out, d_out = model(c, p)
            loss = get_loss(model, coefficients, c, p, e_out, d_out)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_losses.append(train_loss / batchs)

        # checkpoint
        if weight_path and period and epoch % period == 0:
            save_weight(
                model,
                weight_path,
                f"{epoch}.pickle",
                f"{epoch}.txt",
                train_losses,
            )

    return train_losses


# Test (Infer)


In [14]:
def save_pil_image(pil_image, path, name):
    """Save `pil_image` to `{path}/{name}`, creating `path` (and parents) if needed.

    Args:
        pil_image: image object with a PIL-style `save(filename)` method.
        path (str): output directory.
        name (str): output file name.

    Raises:
        OSError: if `path` cannot be created (e.g. it exists as a file).
    """
    os.makedirs(path, exist_ok=True)
    pil_image.save(f"{path}/{name}")


In [15]:
def save_merged_pil_image(pil_imgs, path, name):
    """Save cover / encoded / payload / decoded images side by side in one PNG.

    Args:
        pil_imgs: [c_img, p_img, e_o_img, d_o_img] (cover, payload, encoded
            stego image, decoded payload). The panels are drawn left to right
            as cover, encoded, payload, decoded. The two payload panels use a
            gray colormap.
        path (str): output directory, created if missing.
        name (str): output file name.
    """
    c, p, e_o, d_o = pil_imgs
    # (image, colormap) in display order. Titles are intentionally omitted.
    panels = [(c, None), (e_o, None), (p, "gray"), (d_o, "gray")]

    fig = plt.figure(tight_layout=True)
    for idx, (img, cmap) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, len(panels), idx)
        ax.imshow(img, cmap=cmap)
        ax.axis("off")

    os.makedirs(path, exist_ok=True)
    fig.savefig(f"{path}/{name}", transparent=True, bbox_inches='tight', pad_inches=0)
    # Close exactly this figure to avoid accumulating figures. (plt.clf()
    # after plt.close('all') would make pyplot create a new empty figure.)
    plt.close(fig)


In [16]:
def _to_uint8_hwc(img):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 numpy array (round to nearest, clamp to [0, 255])."""
    return img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()


def _safe_mean(total, count):
    """Return total / count, or NaN when count is 0 (e.g. every PSNR in a batch was infinite)."""
    return total / count if count else float("nan")


def infer_from_saved_weight(
    model, test_c_loader, test_p_loader, weight_file, result_path
):
    """Load a checkpoint into `model`, run it on paired test batches and save images and metrics.

    For every batch pair (cover from `test_c_loader`; payload from
    `test_p_loader`, converted to grayscale) the model produces a stego image
    e_o and a decoded payload d_o. For every image it writes
    `{result_path}/merged/{batch}-{i}.png` (cover | encoded | payload | decoded)
    and `{result_path}/e_o/{batch}-{i}.png`, and accumulates:
      - PSNR(cover, e_o) and PSNR(payload, d_o), skipping infinite values;
      - SSIM(cover, e_o) over the 3 colour channels and SSIM(payload, d_o).
    Per-batch means go to `metrics-per-batch.txt` (rewritten on each call)
    and overall means to `metrics-total.txt`, both in `result_path`. A mean
    with nothing to average is written as `nan`.

    Args:
        model (EDS): model to load the weights into. Inputs are moved to the
            device of its parameters.
        test_c_loader, test_p_loader (DataLoader): cover / payload loaders.
        weight_file (str): state-dict file saved by save_weight().
        result_path (str): output directory (created if missing).
    """
    device = "cuda" if cuda.is_available() else "cpu"
    os.makedirs(result_path, exist_ok=True)
    per_batch_path = f"{result_path}/metrics-per-batch.txt"
    # Start fresh so that re-running a configuration does not append to stale results.
    open(per_batch_path, "w").close()

    with torch.no_grad():
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(weight_file, map_location=device))
        model.eval()
        model_device = next(model.parameters()).device

        batchs = min(len(test_c_loader), len(test_p_loader))
        iter_c = iter(test_c_loader)
        iter_p = iter(test_p_loader)

        total_e_psnr = total_d_psnr = 0.0
        total_e_ssim = total_d_ssim = 0.0
        total_e = total_d = 0  # number of finite PSNR values
        total_images = 0  # images actually evaluated (drop_last may skip a tail)
        for batch in trange(batchs, desc="Total"):
            c = next(iter_c)[0].to(model_device)
            p = next(iter_p)[0].to(model_device)
            p = transforms.functional.rgb_to_grayscale(p)
            e_o, d_o = model(c, p)  # infer

            # Use the real batch length, not a global batch size.
            n_images = min(c.size(0), p.size(0))
            e_psnr = d_psnr = e_ssim = d_ssim = 0.0
            e_cnt = d_cnt = 0
            for i in range(n_images):
                c_np = _to_uint8_hwc(c[i])
                p_np = _to_uint8_hwc(p[i]).squeeze()
                e_o_np = _to_uint8_hwc(e_o[i])
                d_o_np = _to_uint8_hwc(d_o[i]).squeeze()
                c_img = Image.fromarray(c_np)
                p_img = Image.fromarray(p_np)
                e_o_img = Image.fromarray(e_o_np)
                d_o_img = Image.fromarray(d_o_np)

                # save result images
                save_merged_pil_image(
                    [c_img, p_img, e_o_img, d_o_img], f"{result_path}/merged", f"{batch}-{i}.png"
                )
                ## c, p, d_o can be saved separately the same way if needed
                save_pil_image(e_o_img, f"{result_path}/e_o", f"{batch}-{i}.png")  # encoded cover (stego) image

                # PSNR is infinite for identical images; those are excluded from the mean.
                psnr = skimage.metrics.peak_signal_noise_ratio(c_np, e_o_np)
                if psnr != np.inf:
                    e_psnr += psnr
                    e_cnt += 1
                psnr = skimage.metrics.peak_signal_noise_ratio(p_np, d_o_np)
                if psnr != np.inf:
                    d_psnr += psnr
                    d_cnt += 1
                # channel_axis=2: HWC colour image. (`multichannel` is deprecated
                # in favour of channel_axis and is not needed.)
                e_ssim += skimage.metrics.structural_similarity(c_np, e_o_np, channel_axis=2)
                d_ssim += skimage.metrics.structural_similarity(p_np, d_o_np)

            with open(per_batch_path, "a") as f:
                f.write(f"# {batch}\n")
                f.write(f"e_psnr: {_safe_mean(e_psnr, e_cnt)}\n")
                f.write(f"d_psnr: {_safe_mean(d_psnr, d_cnt)}\n")
                f.write(f"e_ssim: {_safe_mean(e_ssim, n_images)}\n")
                f.write(f"d_ssim: {_safe_mean(d_ssim, n_images)}\n")

            total_e += e_cnt
            total_d += d_cnt
            total_images += n_images
            total_e_psnr += e_psnr
            total_d_psnr += d_psnr
            total_e_ssim += e_ssim
            total_d_ssim += d_ssim

    with open(f"{result_path}/metrics-total.txt", "w") as f:
        f.write(f"total_e_psnr: {_safe_mean(total_e_psnr, total_e)}\n")
        f.write(f"total_d_psnr: {_safe_mean(total_d_psnr, total_d)}\n")
        # Divide by the images actually evaluated, not len(dataset).
        f.write(f"total_e_ssim: {_safe_mean(total_e_ssim, total_images)}\n")
        f.write(f"total_d_ssim: {_safe_mean(total_d_ssim, total_images)}\n")


# Run


## Tree
- weight/{model.name}/{train_set}/{epoch}.pickle
    - cifar10
        - 50.pickle
        - 100.pickle
        - 150.pickle
    - celeba
        - 50.pickle
        - 100.pickle
        - 150.pickle
    - imagenet
        - 50.pickle
        - 100.pickle
        - 150.pickle
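- result/{model.name}/{train_set}/{test_set}/{epoch}/{[merged, e_o, c, p, d_o]}/*.png
    - cifar10
        - cifar10
            - 150
                - merged (cover | encoded | payload | decoded, side by side)
                - e_o
                - c, p, d_o (only written if the corresponding `save_pil_image` calls are enabled)
                - metrics-per-batch.txt, metrics-total.txt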

## Global variable

In [17]:
# Hyper-parameters, read as globals by train() and infer().
batch_size, lr = 32, 1e-4
# Train for `epochs` epochs; checkpoint (and later evaluate) every `period` epochs.
epochs, period = 150, 50  #!!!
# loss = alpha * cover MSE + beta * payload MSE + lambd * weight term
coefficients = dict(alpha=1, beta=1, lambd=1e-4)

## Train function (CIFAR10, CelebA, ImageNet)

In [18]:
def train(train_set):
    """Train a fresh Xavier-initialised EDS model on `train_set` and checkpoint it.

    `train_set` must be 'cifar10', 'celeba' or 'imagenet'; any other value
    prints an error and returns. Reads the notebook globals epochs, period,
    coefficients, lr, batch_size and device. Checkpoints are written to
    ./weight/{model.name}/{train_set} by train_and_save().
    """
    loader_builders = {
        'cifar10': build_cifar10_dataloader,
        'celeba': build_celeba_dataloader,
        'imagenet': build_imagenet_dataloader,
    }
    if train_set not in loader_builders:
        print("[-] invalid train_set")
        return

    # Only the training cover/payload loaders are needed here.
    loaders = loader_builders[train_set](batch_size)
    train_c_loader, train_p_loader = loaders[0], loaders[1]

    model = EDS().to(device)
    model.apply(weight_init_xavier_uniform)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_and_save(
        model,
        optimizer,
        train_c_loader,
        train_p_loader,
        epochs,
        period,
        coefficients,
        f"./weight/{model.name}/{train_set}",
    )
    del train_c_loader, train_p_loader, model
    print(f"[+] train {train_set} done (epochs: {epochs})")
    

## Infer function (CIFAR10, CelebA, ImageNet)

In [19]:
def infer(train_set, infer_set):
    """Evaluate every saved checkpoint of the EDS model trained on `train_set` on `infer_set`'s test data.

    For each epoch e in period, 2*period, ..., epochs, this loads
    ./weight/{model.name}/{train_set}/{e}.pickle and writes images and metrics
    to ./result/{model.name}/{train_set}/{infer_set}/{e}. It then zips that
    directory to {model.name}-{train_set}-{infer_set}-{e}.zip in the working
    directory. Any other `infer_set` prints an error and returns. Reads the
    notebook globals epochs, period, batch_size and device.
    """
    loader_builders = {
        'cifar10': build_cifar10_dataloader,
        'celeba': build_celeba_dataloader,
        'imagenet': build_imagenet_dataloader,
    }
    if infer_set not in loader_builders:
        print("[-] invalid test_set")
        return

    # Only the test cover/payload loaders are needed here.
    test_c_loader, test_p_loader = loader_builders[infer_set](batch_size)[2:]

    model = EDS().to(device)
    weight_dir = f"./weight/{model.name}/{train_set}"
    for e in range(period, epochs + 1, period):
        result_dir = f"./result/{model.name}/{train_set}/{infer_set}/{e}"
        infer_from_saved_weight(
            model, test_c_loader, test_p_loader, f"{weight_dir}/{e}.pickle", result_dir
        )
        shutil.make_archive(f"{model.name}-{train_set}-{infer_set}-{e}", "zip", result_dir)
        print(f"[+] infer {infer_set} with {model.name}-{train_set} (epochs: {e}) done")
    del test_c_loader, test_p_loader, model
    
    

In [20]:
def train_and_infer():
    """Train one EDS model per dataset, then evaluate every trained model on every dataset."""
    # Named so as not to shadow the torchvision `datasets` module.
    dataset_names = ('cifar10', 'celeba', 'imagenet')
    for name in dataset_names:
        train(name)
    for train_set in dataset_names:
        for infer_set in dataset_names:
            infer(train_set, infer_set)
    # Printed after inference as well; message kept as-is.
    print("[+] weight saved")

In [21]:
# Global device, read by train() and infer().
device = "cpu"
if cuda.is_available():
    device = "cuda"
print("[+] Device:", device)

train_and_infer()
print('[+] done')

[+] Device: cuda
Files already downloaded and verified
Files already downloaded and verified
[+] load cifar10: train_size = 40000, test_size = 20000


Total:   0%|          | 0/150 [00:00<?, ?it/s]

[+] Saving ./weight/EDS/cifar10/50.pickle
[+] Saving ./weight/EDS/cifar10/100.pickle
[+] Saving ./weight/EDS/cifar10/150.pickle
[+] train cifar10 done (epochs: 150)
[+] load celeba: train_size = 40000, test_size = 20000


Total:   0%|          | 0/150 [00:00<?, ?it/s]

[+] Saving ./weight/EDS/celeba/50.pickle
[+] Saving ./weight/EDS/celeba/100.pickle
[+] Saving ./weight/EDS/celeba/150.pickle
[+] train celeba done (epochs: 150)
[+] load imagenet: train_size = 40000, test_size = 20000


Total:   0%|          | 0/150 [00:00<?, ?it/s]

[+] Saving ./weight/EDS/imagenet/50.pickle
[+] Saving ./weight/EDS/imagenet/100.pickle
[+] Saving ./weight/EDS/imagenet/150.pickle
[+] train imagenet done (epochs: 150)
Files already downloaded and verified
Files already downloaded and verified
[+] load cifar10: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-cifar10 (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-cifar10 (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-cifar10 (epochs: 150) done
[+] load celeba: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-cifar10 (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-cifar10 (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-cifar10 (epochs: 150) done
[+] load imagenet: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-cifar10 (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-cifar10 (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-cifar10 (epochs: 150) done
Files already downloaded and verified
Files already downloaded and verified
[+] load cifar10: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-celeba (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-celeba (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-celeba (epochs: 150) done
[+] load celeba: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-celeba (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-celeba (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-celeba (epochs: 150) done
[+] load imagenet: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-celeba (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-celeba (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-celeba (epochs: 150) done
Files already downloaded and verified
Files already downloaded and verified
[+] load cifar10: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-imagenet (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-imagenet (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer cifar10 with EDS-imagenet (epochs: 150) done
[+] load celeba: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-imagenet (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-imagenet (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer celeba with EDS-imagenet (epochs: 150) done
[+] load imagenet: train_size = 40000, test_size = 20000


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-imagenet (epochs: 50) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-imagenet (epochs: 100) done


Total:   0%|          | 0/312 [00:00<?, ?it/s]

[+] infer imagenet with EDS-imagenet (epochs: 150) done
[+] weight saved
[+] done
