In [1]:
import gzip
import numpy as np
import os
import pickle
import random

from argparse import ArgumentParser
from copy import deepcopy
from sklearn.model_selection import train_test_split
from urllib.request import urlretrieve
from PIL import Image, ImageOps
from torchvision.transforms import functional as TF


In [2]:
# Single global seed: every np.random-based shuffle in this notebook draws from this state,
# and the same value is reused as the fixed random_state of the train/val split.
SEED = 0
np.random.seed(seed=SEED)


In [3]:
def unpickle(file):
    """Load and return the object pickled in ``file``.

    ``encoding='bytes'`` makes Python-2 ``str`` objects load as ``bytes``; this is why the
    CIFAR batch fields are accessed with byte keys such as ``b"data"`` / ``b"labels"``.

    NOTE: ``pickle.load`` can execute arbitrary code -- only call this on trusted files
    (e.g. the official CIFAR-10 archive).
    """
    with open(file, mode="rb") as handle:
        return pickle.load(handle, encoding="bytes")

def save_data(images, labels, dir, name, split, center):
    """Save one split of one center as ``<dir>/<name lowercased>_<split>_<center>.npz``.

    The archive holds two arrays, stored under the keys ``images`` and ``labels``.
    ``dir`` must already exist; it is not created here.
    """
    filename = "{}_{}_{}.npz".format(name.lower(), split, center)
    target = os.path.join(dir, filename)
    np.savez(target, images=images, labels=labels)

In [4]:
# CIFAR-10 "python version" batches; they must already be extracted under `root`
# (this cell does not download anything).
root = "data/CIFAR10"
image_size = 32
number_of_labels = 10
patch_size = 4  # NOTE(review): not used by any code shown in this notebook section
labels = list(range(number_of_labels))  # NOTE(review): also unused in the code shown here

batches_dir = os.path.join(root, "cifar-10-batches-py")
train_files = [os.path.join(batches_dir, f"data_batch_{i}") for i in range(1, 6)]
test_file = os.path.join(batches_dir, "test_batch")


def _load_batch(path):
    """Return (images reshaped to (N, 3, image_size, image_size), 1-D label array) for one batch file."""
    # Unpickle each file once and take both fields from the same loaded dict.
    batch = unpickle(path)
    batch_images = np.asarray(batch[b"data"]).reshape(-1, 3, image_size, image_size)
    return batch_images, np.asarray(batch[b"labels"])


train_batches = [_load_batch(path) for path in train_files]
train_images = np.concatenate([batch_images for batch_images, _ in train_batches])
train_labels = np.concatenate([batch_labels for _, batch_labels in train_batches])
del train_batches  # drop the per-batch arrays; the concatenated copies are all that is needed

test_images, test_labels = _load_batch(test_file)

# Cheap sanity checks so a truncated/corrupt download fails here rather than downstream.
assert len(train_images) == len(train_labels)
assert len(test_images) == len(test_labels)


In [5]:
def get_shuffled_subset(images, labels, centers_per_group):
    """Jointly shuffle ``images``/``labels`` and cut both into ``centers_per_group`` chunks.

    Chunks come from ``np.array_split``, so their sizes differ by at most one when the
    length is not divisible. Returns ``(image_chunks, label_chunks)``, two lists whose
    i-th entries stay aligned.
    """
    mixed_images, mixed_labels = shuffle(images, labels)
    image_chunks = np.array_split(mixed_images, centers_per_group)
    label_chunks = np.array_split(mixed_labels, centers_per_group)
    return image_chunks, label_chunks

def shuffle(images, labels):
    """Reorder ``images`` and ``labels`` with one shared permutation drawn from ``np.random``.

    The permutation length is ``len(labels)``; ``images`` is indexed with the same
    permutation, so image/label pairs stay aligned. Returns ``(images, labels)``.
    """
    order = np.random.permutation(len(labels))
    return images[order], labels[order]


def rotate(dataset, rotation_angle, border_size=0):
    """Rotate every image of ``dataset`` by ``rotation_angle`` degrees.

    Args:
        dataset: iterable of channel-first ``(C, H, W)`` images. Each is transposed to
            ``(H, W, C)`` and handed to PIL. It is assumed to be uint8 with 3 channels
            (as for the CIFAR-10 arrays used here).
        rotation_angle: angle in degrees, passed to torchvision's ``TF.rotate``.
        border_size: width in pixels of a black border added on every side before rotating.

    Returns:
        ``np.ndarray`` of rotated images in channel-LAST ``(N, H', W', C)`` layout; note the
        layout differs from the input. Because ``expand=True``, ``H'``/``W'`` grow for angles
        that are not multiples of 90 degrees, so all images must share the same angle for
        the results to stack. An empty ``dataset`` yields an empty 1-D array.
    """
    rotated = []
    for chw_image in dataset:
        # PIL expects (H, W, C). The mode is inferred from the array (uint8 x 3 channels -> RGB),
        # so no explicit mode= is passed (that argument is deprecated in recent Pillow releases).
        pil_image = Image.fromarray(chw_image.transpose(1, 2, 0))
        pil_image = ImageOps.expand(pil_image, border=border_size, fill='black')

        # Given a PIL image, TF.rotate returns a PIL image; expand=True keeps the whole
        # rotated content instead of cropping it to the original size.
        rotated_image = TF.rotate(pil_image, rotation_angle, expand=True)
        rotated.append(np.asarray(rotated_image))
    return np.array(rotated)

In [6]:
# Rotation angles (degrees) given to the four groups of centers, per heterogeneity case.
# Alternative angle sets tried earlier: 'max' -> [45, 135, 225, 315], 'min' -> [42, 44, 46, 48].
ROTATIONS_BY_CASE = {
    'max': [0, 90, 180, 270],
    'no': [0, 0, 0, 0],
    'min': [-3, -1, 1, 3],
    'two': [-3, 3, 177, 183],
}
# Black border (pixels) added before rotating; only the 'min' case uses one.
BORDER_BY_CASE = {'min': 2}

# Cases / center counts generated in this run (every case: list(ROTATIONS_BY_CASE)).
all_cases = ['max']
all_centers_per_group = [10]

# Validate up front so a typo fails before any file is written
# (calling exit() inside the loop would stop the notebook kernel instead).
unknown_cases = [case for case in all_cases if case not in ROTATIONS_BY_CASE]
if unknown_cases:
    raise ValueError(f"Unknown case(s) {unknown_cases}; expected one of {list(ROTATIONS_BY_CASE)}")

for case in all_cases:
    rot_groups = ROTATIONS_BY_CASE[case]
    border_size = BORDER_BY_CASE.get(case, 0)

    for centers_per_group in all_centers_per_group:
        dest_dir = os.path.join(root, f'rotcifar10hard{case}{centers_per_group}c')
        if os.path.exists(dest_dir):
            print("Warning, existing files might be overwritten")
        os.makedirs(dest_dir, exist_ok=True)

        # One independently shuffled copy of the full train/test set per rotation group, so the
        # data is duplicated len(rot_groups) times. NOTE: the order of the np.random calls (these
        # shuffles, then one train and one test re-shuffle per group below) determines the
        # generated splits for a given SEED -- keep it unchanged to reproduce existing datasets.
        shuffled_trains = [shuffle(train_images, train_labels) for _ in rot_groups]
        shuffled_tests = [shuffle(test_images, test_labels) for _ in rot_groups]

        for i, rot_group in enumerate(rot_groups):
            # Split this group's copy into centers_per_group centers; center ids are global
            # across groups (group i owns ids i*centers_per_group ... (i+1)*centers_per_group - 1).
            grouped_train_images, grouped_train_labels = shuffled_trains[i]
            split_train_images, split_train_labels = get_shuffled_subset(
                grouped_train_images, grouped_train_labels, centers_per_group)
            for j, (imgs, lbls) in enumerate(zip(split_train_images, split_train_labels)):
                center_num = i * centers_per_group + j
                imgs = rotate(imgs, rot_group, border_size=border_size)
                # 90/10 train/val split per center; the integer random_state makes it independent
                # of the global np.random state.
                X_train, X_val, y_train, y_val = train_test_split(
                    imgs, lbls, test_size=0.1, random_state=SEED)
                save_data(X_train, y_train, dest_dir, "rotcifar10", "train", center_num)
                save_data(X_val, y_val, dest_dir, "rotcifar10", "val", center_num)
                print("center:", center_num, "train", len(y_train), "val", len(y_val))

            # Same for the test set, without a validation split.
            grouped_test_images, grouped_test_labels = shuffled_tests[i]
            split_test_images, split_test_labels = get_shuffled_subset(
                grouped_test_images, grouped_test_labels, centers_per_group)
            for j, (imgs, lbls) in enumerate(zip(split_test_images, split_test_labels)):
                center_num = i * centers_per_group + j
                imgs = rotate(imgs, rot_group, border_size=border_size)
                save_data(imgs, lbls, dest_dir, "rotcifar10", "test", center_num)
                print("center:", center_num, "test", len(lbls))

        # log.txt holds [rotation of every center in center-id order, rotation of every group].
        duplicated_rot_groups = [group for group in rot_groups for _ in range(centers_per_group)]
        log = [duplicated_rot_groups, list(rot_groups)]
        with open(os.path.join(dest_dir, "log.txt"), "w") as f:
            f.write(str(log))

center: 0 train 4500 val 500
center: 1 train 4500 val 500
center: 2 train 4500 val 500
center: 3 train 4500 val 500
center: 4 train 4500 val 500
center: 5 train 4500 val 500
center: 6 train 4500 val 500
center: 7 train 4500 val 500
center: 8 train 4500 val 500
center: 9 train 4500 val 500
center: 0 test 1000
center: 1 test 1000
center: 2 test 1000
center: 3 test 1000
center: 4 test 1000
center: 5 test 1000
center: 6 test 1000
center: 7 test 1000
center: 8 test 1000
center: 9 test 1000
center: 10 train 4500 val 500
center: 11 train 4500 val 500
center: 12 train 4500 val 500
center: 13 train 4500 val 500
center: 14 train 4500 val 500
center: 15 train 4500 val 500
center: 16 train 4500 val 500
center: 17 train 4500 val 500
center: 18 train 4500 val 500
center: 19 train 4500 val 500
center: 10 test 1000
center: 11 test 1000
center: 12 test 1000
center: 13 test 1000
center: 14 test 1000
center: 15 test 1000
center: 16 test 1000
center: 17 test 1000
center: 18 test 1000
center: 19 test 1000
