In [1]:
import os
import pickle
import numpy as np
import tqdm
from matplotlib import pyplot as plt #Use for image debugging
import torch 
from torchvision import models, transforms #Need this to get VGG-11
from sklearn.model_selection import train_test_split

In [7]:
# Load the processed batches (based on James' code)
# The original data is organize slightly differently so it's somewhat messy. Sorry. 

def load_processed_batches(path, test = 0):
    #Path is the directory where the files of interest are
    
    data = []
    labels = []

    print('Loading data...')
    if path == 'cifar10_dataset/cifar-10-batches-py':
        #The original data is in its own folder with some extra files
        #So we only loop through the ones that we care about 
        files = os.listdir(path)
        files2 = []
        for i in range(len(files)): #If we want train data
            if test == 0: 
                if 'data_batch' in files[i]: 
                    files2.append(files[i])
            else: 
                if 'test_batch' in files[i]: 
                    files2.append(files[i])
        for file in tqdm.tqdm(files2): #If we want test data
            with open(os.path.join(path, file), 'rb') as f:
                processed_batch_dict = pickle.load(f, encoding='bytes')
                data.append(processed_batch_dict[b'data'])
                labels.append(processed_batch_dict[b'labels'])
        
        #Store the data and labels 
        data = np.concatenate(data)
        data = data.astype(np.float32) / 255 #Divide by 255 to get into 0 to 1 range 
        labels = np.concatenate(labels)
        
        #Reshape to the same dimensions as the processed data
        data = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    else: #If we're using black boxes or Gaussian noise data 
        if test == 1: path = path + '_test' 
        files = os.listdir(path)
        for file in tqdm.tqdm(files):
            with open(os.path.join(path, file), 'rb') as f:
                processed_batch_dict = pickle.load(f, encoding='bytes')
                data.append(processed_batch_dict['data'])
                labels.append(processed_batch_dict['labels'])
        #Store the data and labels 
        data = np.concatenate(data)
        data = data.astype(np.float32) / 255 #Divide by 255 to get into 0 to 1 range 

        labels = np.concatenate(labels)
        labels = np.repeat(labels,9) #assume that the same image is repeated 9 times for each of the superpixels (3-by-3)

    return data, labels

In [3]:
# Set up the transformation for the data 

# Normalize images via the statistics of the original dataset
path = 'cifar10_dataset/cifar-10-batches-py'
processed_images, _ = load_processed_batches(path,test=0)

imMean = np.mean(processed_images.reshape(-1,3),axis=0)
imStd = np.std(processed_images.reshape(-1,3),axis=0)

print('Mean:',imMean)
print('Std Dev:', imStd)

#Set-up normalization
normalize = transforms.Normalize(mean=imMean,std=imStd)

Loading data...


100%|██████████| 5/5 [00:00<00:00, 51.52it/s]


Mean: [0.32768 0.32768 0.32768]
Std Dev: [0.27755317 0.26929596 0.26811677]


In [8]:
#Get data and format it for training 
path = 'processed_batches_boxes' #cifar10_dataset/cifar-10-batches-py'
processed_images, processed_labels = load_processed_batches(path,test=0)
processed_images = processed_images.transpose(0,3,1,2) #Get into the appropriate shape for training

trainData, valData, trainLabel, valLabel = train_test_split(processed_images, processed_labels, test_size=0.2, random_state=42)
train_set = torch.utils.data.TensorDataset(normalize(torch.tensor(trainData)),torch.tensor(trainLabel).type(torch.LongTensor))
val_set = torch.utils.data.TensorDataset(normalize(torch.tensor(valData)),torch.tensor(valLabel).type(torch.LongTensor))

#Get data and format it for testing 
processed_images, processed_labels = load_processed_batches(path,test=1)
processed_images = processed_images.transpose(0,3,1,2) #Get into the appropriate shape for training
test_set = torch.utils.data.TensorDataset(normalize(torch.tensor(processed_images)),torch.tensor(processed_labels).type(torch.LongTensor))


Loading data...


100%|██████████| 5/5 [00:00<00:00, 11.60it/s]


Loading data...


100%|██████████| 5/5 [00:00<00:00, 31.89it/s]


In [21]:
# Parameters
batch_size = 64 
num_classes = 10 
momentum = 0.9 
learning_rate = 0.005 
weight_decay = 0.005 
num_epochs = 20
num_workers = 2

# Starting parameters from: https://blog.paperspace.com/alexnet-pytorch/ 

#Set-up dataloaders
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers = num_workers)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers = num_workers)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers = num_workers)

#Get model
mod = models.vgg11(weights=None)
mod.classifier[6].out_features = num_classes #Adjust final layer to have the right number of classes 

#Get device 
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
mod.to(device)

if torch.cuda.is_available(): 
    print(torch.cuda.get_device_name(0))

NVIDIA GeForce RTX 4090


In [22]:
#Training 
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(mod.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)

total_step = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = mod(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ('Epoch [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, loss.item()))
            
    # Validation
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = mod(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs
    
        print('Accuracy of the network on validation images: {} %'.format(100 * correct / total)) 

Epoch [1/20], Loss: 5625.0000
Accuracy of the network on validation images: 68.0911111111111 %
Epoch [2/20], Loss: 5625.0000
Accuracy of the network on validation images: 77.04666666666667 %
Epoch [3/20], Loss: 5625.0000
Accuracy of the network on validation images: 81.78222222222222 %
Epoch [4/20], Loss: 5625.0000
Accuracy of the network on validation images: 79.35111111111111 %
Epoch [5/20], Loss: 5625.0000
Accuracy of the network on validation images: 84.82222222222222 %
Epoch [6/20], Loss: 5625.0000
Accuracy of the network on validation images: 85.41444444444444 %
Epoch [7/20], Loss: 5625.0000
Accuracy of the network on validation images: 84.30555555555556 %
Epoch [8/20], Loss: 5625.0000
Accuracy of the network on validation images: 84.26555555555555 %
Epoch [9/20], Loss: 5625.0000
Accuracy of the network on validation images: 84.00222222222222 %
Epoch [10/20], Loss: 5625.0000
Accuracy of the network on validation images: 86.87777777777778 %
Epoch [11/20], Loss: 5625.0000
Accuracy 

In [23]:
# Testing
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = mod(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        del images, labels, outputs

    print('Accuracy of the network on test images: {} %'.format(100 * correct / total)) 

Accuracy of the network on test images: 75.76666666666667 %
