# Import required libraries and packages

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.transforms.functional as F
from scipy import ndimage
from scipy.ndimage import gaussian_filter
from numpy import save

# for simclr
from simclr_feature_extraction import get_features
import argparse
from simclr import SimCLR
from simclr.modules import get_resnet
from pprint import pprint
from utils import yaml_config_hook

# Download CIFAR-10 data from torchvision

In [2]:

# Build the training split first, then the test split, both rooted at the same
# local directory and downloaded on demand.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10('~/datasets/cifar', train=is_train, download=True)
    for is_train in (True, False)
)


Files already downloaded and verified
Files already downloaded and verified


# Useful functions 

In [3]:
def apply_contrast(images_data, contrast_factor):
    """Adjust the contrast of a collection of images.

    Each image is converted with ``transforms.ToTensor()``, the tensors are
    stacked into a single batch, and the batch is passed through
    ``F.adjust_contrast``. Every adjusted tensor is then converted back with
    ``transforms.ToPILImage()`` and the results are stacked into one NumPy
    array.

    Args:
        images_data: Iterable of images accepted by ``transforms.ToTensor()``.
            All images must share one shape so they can be stacked.
        contrast_factor: Factor forwarded unchanged to ``F.adjust_contrast``.

    Returns:
        NumPy array holding the contrast-adjusted images, one per input image.
    """
    as_tensor = transforms.ToTensor()
    as_pil = transforms.ToPILImage()

    # Adjust the whole batch in a single call rather than image by image.
    batch = torch.stack([as_tensor(picture) for picture in images_data])
    adjusted_batch = F.adjust_contrast(batch, contrast_factor)

    pil_pictures = [as_pil(picture) for picture in adjusted_batch]
    return np.array(np.stack(pil_pictures))

def apply_rotation(images_data, angle=30):
    """Rotate every image in ``images_data`` by the same angle.

    Each image is passed to ``ndimage.rotate`` with ``reshape=False``, so every
    rotated image keeps the shape of its input.

    Args:
        images_data: Iterable of image arrays.
        angle: Rotation angle forwarded to ``ndimage.rotate`` (default 30).

    Returns:
        NumPy array of the rotated images, in the same order as the input.
    """
    return np.array(
        [ndimage.rotate(picture, angle, reshape=False) for picture in images_data]
    )

def blur_images(images_data, sigma=1):
    """Apply a Gaussian blur to every image in ``images_data``.

    A scalar ``sigma`` passed straight to ``gaussian_filter`` is applied along
    every axis of the array. For a 3-D image (height, width, channels) that
    would also smooth across the channel axis and bleed colour channels into
    each other. To keep the blur purely spatial, a scalar ``sigma`` is expanded
    to ``(sigma, sigma, 0)`` for 3-D images, so the last axis is left untouched.

    Args:
        images_data: Iterable of image arrays, either 2-D (height, width) or
            3-D with the channel axis last.
        sigma: Standard deviation of the Gaussian kernel (default 1). A scalar
            is applied to the two leading axes of 3-D images and to all axes
            otherwise. A sequence is forwarded to ``gaussian_filter`` unchanged,
            giving the caller full per-axis control.

    Returns:
        NumPy array of the blurred images, in the same order as the input.
    """
    scalar_sigma = np.ndim(sigma) == 0

    blurred_images = []
    for img in images_data:
        img = np.asarray(img)
        if scalar_sigma and img.ndim == 3:
            # Blur height and width only; sigma 0 disables filtering on the
            # channel axis.
            per_axis_sigma = (sigma, sigma, 0)
        else:
            per_axis_sigma = sigma
        blurred_images.append(gaussian_filter(img, per_axis_sigma))
    return np.array(blurred_images)


def apply_saturation(images_data, sat_factor):
    """Adjust the colour saturation of a collection of images.

    Each image is converted with ``transforms.ToTensor()``, the tensors are
    stacked into a single batch, and the batch is passed through
    ``F.adjust_saturation``. Every adjusted tensor is then converted back with
    ``transforms.ToPILImage()`` and the results are stacked into one NumPy
    array.

    Args:
        images_data: Iterable of images accepted by ``transforms.ToTensor()``.
            All images must share one shape so they can be stacked.
        sat_factor: Factor forwarded unchanged to ``F.adjust_saturation``.

    Returns:
        NumPy array holding the saturation-adjusted images, one per input image.
    """
    as_tensor = transforms.ToTensor()
    as_pil = transforms.ToPILImage()

    # Adjust the whole batch in a single call rather than image by image.
    batch = torch.stack([as_tensor(picture) for picture in images_data])
    adjusted_batch = F.adjust_saturation(batch, sat_factor)

    pil_pictures = [as_pil(picture) for picture in adjusted_batch]
    return np.array(np.stack(pil_pictures))


# Transform images and extract their features

In [4]:
# Build the run configuration: every entry of config.yaml becomes an argparse
# option whose default (and type) is taken from the value read from the file.
config = yaml_config_hook("./config.yaml")
parser = argparse.ArgumentParser(description="SimCLR")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))

# Parse an empty argument list so that only the defaults from the file apply.
args = parser.parse_args([])

# Use the GPU when one is available, otherwise fall back to the CPU.
# (To force CPU execution, assign torch.device("cpu") to args.device instead.)
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Override some configuration parameters.
# (pprint(vars(args)) can be used here to inspect the final configuration.)
args.batch_size = 64

# Build the encoder without loading a pre-trained model from the PyTorch repo.
encoder = get_resnet(args.resnet, pretrained=False)

# Input dimension of the encoder's fc layer.
n_features = encoder.fc.in_features

# Load the pre-trained SimCLR model from a checkpoint - pre-trained on CIFAR10.
model_fp = './saved_models/checkpoint_100_simclr_original.tar'
simclr_model = SimCLR(encoder, args.projection_dim, n_features)
simclr_model.load_state_dict(torch.load(model_fp, map_location=args.device.type))

# Move the model to the selected device and switch it to evaluation mode.
simclr_model = simclr_model.to(args.device)
simclr_model.eval()

print("Done.")

Done.


In [5]:
# Transform images

# Strength of each transformation
angle = 15              # rotation angle passed to apply_rotation
contrast_factor = 0.3   # factor passed to apply_contrast
sigma = 0.5             # Gaussian sigma passed to blur_images
saturation_factor = 5   # factor passed to apply_saturation

# Class labels of both splits as NumPy arrays
train_labels = np.array(train_dataset.targets)
test_labels = np.array(test_dataset.targets)

# Untransformed images of both splits
X_train_og = train_dataset.data
X_test_og = test_dataset.data

# Each transformation is applied with the same parameters to both splits.
X_train_rotated = apply_rotation(X_train_og, angle)
X_test_rotated = apply_rotation(X_test_og, angle)

X_train_contrasted = apply_contrast(X_train_og, contrast_factor)
X_test_contrasted = apply_contrast(X_test_og, contrast_factor)

X_train_blurred = blur_images(X_train_og, sigma)
X_test_blurred = blur_images(X_test_og, sigma)

X_train_saturated = apply_saturation(X_train_og, saturation_factor)
X_test_saturated = apply_saturation(X_test_og, saturation_factor)


In [6]:
# Extract image features
def _encode(images):
    """Call ``get_features`` on ``images`` with the notebook's SimCLR model, batch size and device."""
    return get_features(images, simclr_model, args.batch_size, args.device)


print("### Extracting features using SimCLR model pre-trained on CIFAR10 ###")

# Training split: original images followed by each transformed variant.
# Only the first value returned by get_features is kept.
Z_train_og, _ = _encode(X_train_og)
Z_train_rotated, _ = _encode(X_train_rotated)
Z_train_contrasted, _ = _encode(X_train_contrasted)
Z_train_blurred, _ = _encode(X_train_blurred)
Z_train_saturated, _ = _encode(X_train_saturated)

# Test split, in the same order as the training split.
Z_test_og, _ = _encode(X_test_og)
Z_test_rotated, _ = _encode(X_test_rotated)
Z_test_contrasted, _ = _encode(X_test_contrasted)
Z_test_blurred, _ = _encode(X_test_blurred)
Z_test_saturated, _ = _encode(X_test_saturated)


### Extracting features using SimCLR model pre-trained on CIFAR10 ###
h features shape (50000, 512)
h features shape (50000, 512)
h features shape (50000, 512)
h features shape (50000, 512)
h features shape (50000, 512)
h features shape (10000, 512)
h features shape (10000, 512)
h features shape (10000, 512)
h features shape (10000, 512)
h features shape (10000, 512)


In [7]:
# Save extracted features and image labels as numpy files.
# Each entry pairs a destination file with the array written to it; the files
# are written in the order listed.
for _target_path, _payload in (
    ('../data/Z_train_og_cifar10_simclr.npy', Z_train_og),
    ('../data/Z_train_rotated_cifar10_simclr.npy', Z_train_rotated),
    ('../data/Z_train_contrasted_cifar10_simclr.npy', Z_train_contrasted),
    ('../data/Z_train_blurred_cifar10_simclr.npy', Z_train_blurred),
    ('../data/Z_train_saturated_cifar10_simclr.npy', Z_train_saturated),
    ('../data/Z_test_og_cifar10_simclr.npy', Z_test_og),
    ('../data/Z_test_rotated_cifar10_simclr.npy', Z_test_rotated),
    ('../data/Z_test_contrasted_cifar10_simclr.npy', Z_test_contrasted),
    ('../data/Z_test_blurred_cifar10_simclr.npy', Z_test_blurred),
    ('../data/Z_test_saturated_cifar10_simclr.npy', Z_test_saturated),
    ('../data/train_labels_cifar10.npy', train_labels),
    ('../data/test_labels_cifar10.npy', test_labels),
):
    save(_target_path, _payload)