In [1]:
import argparse
import random
import sys
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib inline

sys.path.append("../learnable_encryption_robustness/encryption/")
from ELE import blockwise_scramble_ele, block_location_shuffle

In [2]:
# Root directory handed to torchvision.datasets.CIFAR10 when loading the data.
DATA_PATH = "../data/"
# Encryption scheme name; used as a sub-directory of ../encrypted_images/.
ENCRYPTION = "ele"

In [3]:
# Transform applied to every dataset image: only the tensor conversion, no
# augmentation or normalisation.
transform = transforms.Compose([transforms.ToTensor()])

In [4]:
# Converts an encrypted image tensor back to a PIL image so it can be saved.
to_pil = transforms.ToPILImage()

In [5]:
# Create the output tree: one directory per (dataset split, key condition)
# pair under ../encrypted_images/<ENCRYPTION>/. Existing directories are kept.
for data_type in ("train", "test"):
    for key_condition in ("same_key", "diff_key"):
        out_dir = f"../encrypted_images/{ENCRYPTION}/{data_type}/{key_condition}/"
        os.makedirs(out_dir, exist_ok=True)

## Same key: every image is encrypted with the same seed (0)

In [6]:
# Output sub-directory for the cells that follow, which pass seed=0 to both
# encryption steps for every image.
KEY_COND = "same_key"

In [7]:
# CIFAR-10 training split (train=True); `transform` is applied to each image.
# `download` is not passed, so the data is presumably expected to already be
# present under DATA_PATH.
ds = torchvision.datasets.CIFAR10(
    root=DATA_PATH, train=True, transform=transform)

In [8]:
# Encrypt every image of the current dataset with one shared key (seed 0 for
# both steps) and save each result as a zero-padded, index-named PNG.
for idx in tqdm(range(len(ds))):
    # Image only (label dropped): add a leading batch axis, convert to float32.
    batch = ds[idx][0].unsqueeze(0).numpy().astype("float32")
    # Step 1: block-wise scrambling.
    x_stack = blockwise_scramble_ele(batch, seed=0)
    # Step 2: shuffle block locations. Axis 3 is moved to position 1 first
    # (presumably channels-last -> channels-first; confirm against ELE).
    shuffled = block_location_shuffle(np.transpose(x_stack, (0, 3, 1, 2)), seed=0)
    # Drop the batch axis again and convert to a PIL image for saving.
    img = to_pil(torch.from_numpy(shuffled)[0])
    img.save(f"../encrypted_images/{ENCRYPTION}/train/{KEY_COND}/{idx:05}.png")

100%|██████████| 50000/50000 [51:22<00:00, 16.22it/s]  


In [9]:
# CIFAR-10 test split (train=False); `transform` is applied to each image.
# Rebinds `ds`, so the loop in the next cell runs over the test images.
ds = torchvision.datasets.CIFAR10(
    root=DATA_PATH, train=False, transform=transform)

In [10]:
# Encrypt every image of the current dataset with one shared key (seed 0 for
# both steps) and save each result as a zero-padded, index-named PNG in the
# test output directory.
for idx in tqdm(range(len(ds))):
    # Image only (label dropped): add a leading batch axis, convert to float32.
    batch = ds[idx][0].unsqueeze(0).numpy().astype("float32")
    # Step 1: block-wise scrambling.
    x_stack = blockwise_scramble_ele(batch, seed=0)
    # Step 2: shuffle block locations. Axis 3 is moved to position 1 first
    # (presumably channels-last -> channels-first; confirm against ELE).
    shuffled = block_location_shuffle(np.transpose(x_stack, (0, 3, 1, 2)), seed=0)
    # Drop the batch axis again and convert to a PIL image for saving.
    img = to_pil(torch.from_numpy(shuffled)[0])
    img.save(f"../encrypted_images/{ENCRYPTION}/test/{KEY_COND}/{idx:05}.png")

100%|██████████| 10000/10000 [09:38<00:00, 17.29it/s]


## Different keys: each training image is encrypted with its own seed (its dataset index)

In [17]:
# Output sub-directory for the cells that follow, where the seed passed to the
# encryption steps varies per image instead of being fixed.
KEY_COND = "diff_key"

In [18]:
# CIFAR-10 training split again (train=True); rebinds `ds`, which at this
# point in the notebook refers to the test split.
ds = torchvision.datasets.CIFAR10(
    root=DATA_PATH, train=True, transform=transform)

In [19]:
# Encrypt every image of the current dataset with its own key: the dataset
# index is used as the seed for both steps. Each result is saved as a
# zero-padded, index-named PNG.
for idx in tqdm(range(len(ds))):
    # Image only (label dropped): add a leading batch axis, convert to float32.
    batch = ds[idx][0].unsqueeze(0).numpy().astype("float32")
    # Step 1: block-wise scrambling, seeded per image.
    x_stack = blockwise_scramble_ele(batch, seed=idx)
    # Step 2: shuffle block locations with the same per-image seed. Axis 3 is
    # moved to position 1 first (presumably channels-last -> channels-first;
    # confirm against ELE).
    shuffled = block_location_shuffle(np.transpose(x_stack, (0, 3, 1, 2)), seed=idx)
    # Drop the batch axis again and convert to a PIL image for saving.
    img = to_pil(torch.from_numpy(shuffled)[0])
    img.save(f"../encrypted_images/{ENCRYPTION}/train/{KEY_COND}/{idx:05}.png")

100%|██████████| 50000/50000 [59:30<00:00, 14.00it/s]   


In [14]:
# NOTE(review): disabled cell. It would load the CIFAR-10 test split for the
# (also disabled) diff-key test encryption in the next cell.
# ds = torchvision.datasets.CIFAR10(
#     root=DATA_PATH, train=False, transform=transform)

In [15]:
# NOTE(review): disabled cell. While it stays commented out, the test/diff_key
# output directory is created by the mkdir cell but this section writes no
# images into it.
# Before re-enabling, check the seeds: the scramble step uses idx+50000 but
# block_location_shuffle uses seed=0, whereas the diff-key training loop passes
# the same per-image seed to both steps. Confirm which is intended.
# for idx in tqdm(range(len(ds))):
#     img = torch.unsqueeze(ds[idx][0], 0).numpy().astype("float32")
#     x_stack = blockwise_scramble_ele(img, seed=idx+50000)
#     img = np.transpose(x_stack, (0, 3, 1, 2))
#     img = block_location_shuffle(img, seed=0)
#     img = torch.from_numpy(img)
#     img = to_pil(img[0])
#     img.save(f"../encrypted_images/{ENCRYPTION}/test/{KEY_COND}/{idx:05}.png")