In [1]:
import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from numpy import save
from scipy import ndimage
from scipy.ndimage import gaussian_filter

from resnet_feature_extraction import Img2Vec


# Download the CIFAR-10 dataset from torchvision

In [2]:

# Download (or reuse an already-downloaded copy of) the CIFAR-10 train and test splits.
CIFAR_ROOT = '~/datasets/cifar'
train_dataset = torchvision.datasets.CIFAR10(root=CIFAR_ROOT, train=True, download=True)
test_dataset = torchvision.datasets.CIFAR10(root=CIFAR_ROOT, train=False, download=True)


Files already downloaded and verified
Files already downloaded and verified


# Helper functions: image transformations and ResNet feature extraction

In [3]:
def apply_contrast(images_data, contrast_factor):
    """Adjust the contrast of every image with ``F.adjust_contrast``.

    Args:
        images_data: Iterable of images accepted by ``transforms.ToTensor``
            (in this notebook, the CIFAR-10 ``.data`` array).
        contrast_factor: Contrast factor forwarded to ``F.adjust_contrast``.

    Returns:
        numpy uint8 array stacking the adjusted images, one per input image,
        in the same layout ``np.asarray`` produces from a PIL image.
    """
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()
    images = torch.stack([to_tensor(im) for im in images_data])
    contrasted = F.adjust_contrast(images, contrast_factor)
    # ToPILImage converts float tensors with mul(255).byte(), which truncates and
    # biases every pixel downward. Round to the nearest integer explicitly and hand
    # ToPILImage a uint8 tensor so it performs no further scaling.
    contrasted_uint8 = contrasted.mul(255).round().clamp(0, 255).to(torch.uint8)
    return np.stack([np.asarray(to_pil(im)) for im in contrasted_uint8])

def apply_rotation(images_data, angle=30):
    """Rotate every image by ``angle`` using ``scipy.ndimage.rotate``.

    ``reshape=False`` keeps each rotated image the same shape as its input.
    Rotation uses ndimage's default ``axes`` (the first two axes of each image).

    Returns:
        numpy array built from the per-image rotation results.
    """
    return np.array([ndimage.rotate(image, angle, reshape=False)
                     for image in images_data])

def blur_images(images_data, sigma=1):
    """Apply a Gaussian blur to every image without mixing colour channels.

    Args:
        images_data: Iterable of images. Assumes channel-last layout, i.e. each
            image is (H, W) or (H, W, C) — TODO confirm for non-CIFAR inputs.
        sigma: Gaussian standard deviation. A scalar is applied to the first two
            (spatial) axes only; any trailing axes (e.g. RGB channels) get sigma 0
            and are left unblurred. A sequence is passed to ``gaussian_filter``
            unchanged, for callers that want explicit per-axis control.

    Returns:
        numpy array stacking the blurred images; each keeps its input dtype.
    """
    blurred_images = []
    for img in images_data:
        img = np.asarray(img)
        if np.isscalar(sigma) and img.ndim > 2:
            # A scalar sigma would also blur along the channel axis, bleeding
            # colour between R/G/B. Restrict the blur to the spatial axes.
            img_sigma = (sigma, sigma) + (0,) * (img.ndim - 2)
        else:
            img_sigma = sigma
        blurred_images.append(gaussian_filter(img, img_sigma))
    return np.array(blurred_images)


def apply_saturation(images_data, sat_factor):
    """Adjust the colour saturation of every image with ``F.adjust_saturation``.

    Args:
        images_data: Iterable of images accepted by ``transforms.ToTensor``
            (in this notebook, the CIFAR-10 ``.data`` array).
        sat_factor: Saturation factor forwarded to ``F.adjust_saturation``.

    Returns:
        numpy uint8 array stacking the adjusted images, one per input image,
        in the same layout ``np.asarray`` produces from a PIL image.
    """
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()
    images = torch.stack([to_tensor(im) for im in images_data])
    saturated = F.adjust_saturation(images, sat_factor)
    # ToPILImage converts float tensors with mul(255).byte(), which truncates and
    # biases every pixel downward. Round to the nearest integer explicitly and hand
    # ToPILImage a uint8 tensor so it performs no further scaling.
    saturated_uint8 = saturated.mul(255).round().clamp(0, 255).to(torch.uint8)
    return np.stack([np.asarray(to_pil(im)) for im in saturated_uint8])


# Function to extract image features using resnet
def get_features(images, batch_size, img2vec=None):
    """Extract feature vectors for ``images`` in mini-batches.

    Args:
        images: Sliceable, sized collection of images in the format
            ``Img2Vec.get_vec`` accepts (presumably an (N, H, W, C) uint8 array
            as passed in this notebook — TODO confirm against
            resnet_feature_extraction).
        batch_size: Number of images sent to ``get_vec`` per call; must be >= 1.
        img2vec: Optional extractor exposing ``get_vec(batch)``. When omitted, a
            default ``Img2Vec()`` is built, as before. Pass one in to reuse a
            single model across several calls instead of reloading it each time.

    Returns:
        numpy array stacking the per-batch ``get_vec`` outputs row-wise via
        ``np.vstack``. The shape is also printed.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``images`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(images) == 0:
        raise ValueError("images is empty; nothing to extract features from")
    if img2vec is None:
        img2vec = Img2Vec()
    Z_list = [img2vec.get_vec(images[first:first + batch_size])
              for first in range(0, len(images), batch_size)]
    Z = np.vstack(Z_list)
    print("Z features shape", Z.shape)
    return Z


# Transform images and extract their features

In [4]:
# Transform images

# Transformation parameters
angle = 15              # for rotation (apply_rotation)
contrast_factor = 0.3   # for contrast (apply_contrast)
sigma = 0.5             # for blur (blur_images)
saturation_factor = 5   # for saturation (apply_saturation)

batch_size = 64         # images per feature-extraction batch, used in the next cell

train_labels = np.array(train_dataset.targets)
test_labels = np.array(test_dataset.targets)

# Training split: the original images plus one perturbed copy per transformation.
train_images = train_dataset.data
X_train_og = train_images
X_train_rotated = apply_rotation(train_images, angle)
X_train_contrasted = apply_contrast(train_images, contrast_factor)
X_train_blurred = blur_images(train_images, sigma)
X_train_saturated = apply_saturation(train_images, saturation_factor)

# Test split: the same transformations with the same parameters.
test_images = test_dataset.data
X_test_og = test_images
X_test_rotated = apply_rotation(test_images, angle)
X_test_contrasted = apply_contrast(test_images, contrast_factor)
X_test_blurred = blur_images(test_images, sigma)
X_test_saturated = apply_saturation(test_images, saturation_factor)


In [5]:
# Extract image features
print("### Extracting features using a pre-trained resnet 18 ###")

# One get_features call per image set (each prints its feature-matrix shape):
# all training variants first, then all test variants.
(Z_train_og, Z_train_rotated, Z_train_contrasted,
 Z_train_blurred, Z_train_saturated) = [
    get_features(image_set, batch_size)
    for image_set in (X_train_og, X_train_rotated, X_train_contrasted,
                      X_train_blurred, X_train_saturated)
]

(Z_test_og, Z_test_rotated, Z_test_contrasted,
 Z_test_blurred, Z_test_saturated) = [
    get_features(image_set, batch_size)
    for image_set in (X_test_og, X_test_rotated, X_test_contrasted,
                      X_test_blurred, X_test_saturated)
]


### Extracting features using a pre-trained resnet 18 ###
Z features shape (50000, 512)
Z features shape (50000, 512)
Z features shape (50000, 512)
Z features shape (50000, 512)
Z features shape (50000, 512)
Z features shape (10000, 512)
Z features shape (10000, 512)
Z features shape (10000, 512)
Z features shape (10000, 512)
Z features shape (10000, 512)


In [6]:
# Save extracted features and image labels as numpy files

DATA_DIR = '../data'
# np.save does not create missing parent directories; create the output folder
# up front so a fresh checkout does not fail here after the expensive cells.
os.makedirs(DATA_DIR, exist_ok=True)

# File stem -> array; each one is written to f'{DATA_DIR}/{stem}_cifar10_resnet.npy'.
features_to_save = {
    'Z_train_og': Z_train_og,
    'Z_train_rotated': Z_train_rotated,
    'Z_train_contrasted': Z_train_contrasted,
    'Z_train_blurred': Z_train_blurred,
    'Z_train_saturated': Z_train_saturated,
    'Z_test_og': Z_test_og,
    'Z_test_rotated': Z_test_rotated,
    'Z_test_contrasted': Z_test_contrasted,
    'Z_test_blurred': Z_test_blurred,
    'Z_test_saturated': Z_test_saturated,
}
for stem, features in features_to_save.items():
    save(f'{DATA_DIR}/{stem}_cifar10_resnet.npy', features)

save(f'{DATA_DIR}/train_labels_cifar10.npy', train_labels)
save(f'{DATA_DIR}/test_labels_cifar10.npy', test_labels)