In [None]:
import datetime
import glob
import os
import pickle
import random

from pathlib import Path

Code in this notebook is based on <https://github.com/MBMS80/Writing-Cifar10-dataset-to-image-files-as-.tif-or-.jpg->  
( original author:
[Mehdi Maboudi](https://www.tu-braunschweig.de/igp/mitarbeiter/maboudi/), September 2019,
 Technical University of Braunschweig )
 
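Implements:
- Writing the CIFAR-10 dataset to image files (`.png` by default; `.tif` or `.jpg` also possible)  
- Reading the image files back into numpy arrays compatible with the standard CIFAR-10 dataset
- Training and evaluating small PyTorch classifiers on a two-class (airplane vs. bird) subset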

In [None]:
!pip install numpy
!pip install matplotlib
!pip install tqdm
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
!pip install imageio
import imageio #Python library for reading and writing image data

In [None]:
def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return the unpickled object.

    ``encoding='iso-8859-1'`` tells pickle how to decode byte strings written by
    Python 2, so the batch dictionaries come back with ``str`` keys.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on files
    from a trusted source (here: the CIFAR-10 archive).
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='iso-8859-1')

In [None]:
!pip install torch
!pip install torchvision
import torchvision

# Fetch (if not already present) the CIFAR-10 training and test splits into ./data
cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
cifar10_val = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)

In [None]:
cifar10  # display the torchvision training-split dataset object (its repr summarizes the dataset)

The [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) consists of 60000 32x32 colour images in 10 classes, with 6000 images per class.  
There are 50000 training images and 10000 test images.  
***
The CIFAR-10 images will be saved as __png__ files in one subdirectory per class label (10 in total) under each of the __train__ and __test__ directories, i.e. `Cifar10_images/<train|test>/<label>/`.


In [None]:
cifar10_val  # display the torchvision test-split dataset object

In [None]:
# Load cifar10 from local files: the python-pickle batches, presumably extracted
# under DATA_DIR by the torchvision download above.
DATA_DIR = './data/'
CIFAR10_DIR = f'{DATA_DIR}cifar-10-batches-py/'

In [None]:
# Locations of the pickled batches. Sorting fixes the order of the training
# batches (data_batch_1, data_batch_2, ...) regardless of directory listing order.
training_batch_pickle_files = sorted(glob.glob(f'{CIFAR10_DIR}data_batch_*'))
test_batch_pickle_file = f'{CIFAR10_DIR}test_batch'
meta_data_pickle_file = f'{CIFAR10_DIR}batches.meta'

print(training_batch_pickle_files)

In [None]:
# Read the dataset metadata (including the class label names) from batches.meta.
meta_data = unpickle(meta_data_pickle_file)
# Expected contents, as shown by print(meta_data):
# {'num_cases_per_batch': 10000,
# 'label_names': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
# 'num_vis': 3072}

In [None]:
# class_names = ['airplane',
# 'automobile',
# 'bird',
# 'cat',
# 'deer',
# 'dog',
# 'frog',
# 'horse',
# 'ship',
# 'truck']
# nb_classes = len(class_names)

In [None]:
# Label names from batches.meta; an integer label k in the batches refers to class_names[k].
class_names = meta_data['label_names']
nb_classes = len(class_names)
print(nb_classes, class_names)

In [None]:
# Load every training batch and the test batch from the local pickle files,
# checking that each image has exactly one label and one source filename.
cifar_data = []
for pickle_file in training_batch_pickle_files:
    data_dict = unpickle(pickle_file)
    print(f"{data_dict['batch_label']}: len={len(data_dict['labels'])} {data_dict.keys()}")
    assert len(data_dict['labels']) == len(data_dict['data'])
    assert len(data_dict['labels']) == len(data_dict['filenames'])
    cifar_data.append(data_dict)
# Later cells index cifar_data[4]; fail here with a clear message if batches are missing.
assert len(cifar_data) == 5, f"expected 5 training batches in {CIFAR10_DIR}, found {len(cifar_data)}"

print()
test_batch = unpickle(test_batch_pickle_file)
print(f"{test_batch['batch_label']}: len={len(test_batch['labels'])} {test_batch.keys()}")
# Apply the same consistency checks to the test batch as to the training batches.
assert len(test_batch['labels']) == len(test_batch['data'])
assert len(test_batch['labels']) == len(test_batch['filenames'])


In [None]:
# Inspect the fifth training batch. It is referenced explicitly through cifar_data[4]
# rather than through the `data_dict` loop variable left over from the loading cell,
# whose value depends on which batch happened to be loaded last.
batch5 = cifar_data[4]
print(batch5['batch_label'], len(batch5['filenames']), batch5['filenames'][0:5])
cifar10_batch5 = batch5['data']
print("dtype:", cifar10_batch5.dtype, "shape:", cifar10_batch5.shape)
print(cifar10_batch5)

In [None]:
def reshape_cifar_img_data(imgdata_batch):
    """Convert a batch of flat CIFAR-10 rows into channels-last images.

    Each row is split into three 32x32 channel planes, giving (N, 3, 32, 32).
    The channel axis is then moved last, giving (N, 32, 32, 3), so the images
    can be passed to imshow or to image writers directly.
    """
    n_images = imgdata_batch.shape[0]
    channels_first = np.reshape(imgdata_batch, (n_images, 3, 32, 32))
    # (N, C, H, W) -> (N, H, W, C)
    return channels_first.transpose(0, 2, 3, 1)

def reshape_cifar_img(imgdata):
    """Convert one flat CIFAR-10 row (3072 values) into a (32, 32, 3) image.

    This is the single-image counterpart of ``reshape_cifar_img_data``: the row
    is split into three 32x32 channel planes, and the channel axis is then moved last.
    """
    planes = np.reshape(imgdata, (3, 32, 32))
    return np.moveaxis(planes, 0, -1)

# def reshape_imgdata(imgdata, flatten=0):
#     assert len(imgdata.shape) == 3, imgdata.shape # (Width, Height, Channels=3)
#     reshaped = np.transpose(imgdata,axes=(2,0,1))
#     if flatten == 1:
#         outshape = (imgdata.shape[2]*imgdata.shape[0]*imgdata.shape[1],) # completely flatten to 1D
#     elif flatten == 2:
#         outshape = (imgdata.shape[2],imgdata.shape[0]*imgdata.shape[1]) # C=3 channels, W x H values 
#     elif flatten:
#         assert False, f"Unsupported arg: flatten={flatten}"
#         flatten = 0
#     if flatten:
#         reshaped = np.reshape(reshaped,outshape)
#     return reshaped

In [None]:
# VISUALIZE IMAGES

def plotImages_categories(images, labels, n_rows=5, n_cols=4, figsize=(10, 10), label_names=None):
    """Show the first ``n_rows * n_cols`` images in a grid, each titled with its class name.

    Args:
        images: indexable sequence of images that ``imshow`` accepts, e.g. the
            output of ``reshape_cifar_img_data``.
        labels: class index of each image, in the same order as ``images``.
        n_rows, n_cols: grid size.
        figsize: matplotlib figure size.
        label_names: index -> name lookup for the titles; defaults to the
            global ``class_names``.

    Grid cells beyond the number of available images are left blank instead
    of raising IndexError.
    """
    if label_names is None:
        label_names = class_names
    # squeeze=False always returns a 2-D array of Axes, so 1x1 and 1xN grids work too.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    n_available = min(len(images), len(labels))
    for i, ax in enumerate(axes.ravel()):
        ax.set_xticks(())
        ax.set_yticks(())
        if i >= n_available:
            ax.axis('off')
            continue
        ax.imshow(images[i])
        ax.set_title(label_names[labels[i]], fontdict={'family': 'monospace'}, loc='left')
    plt.tight_layout()
    plt.show()

# Preview the first 25 images of the fifth training batch together with their class names.
reshaped_batch5 = reshape_cifar_img_data(cifar_data[4]['data'])
# print(reshaped_batch5[0])
plotImages_categories(images=reshaped_batch5, labels=cifar_data[4]['labels'], n_rows=5, n_cols=5, figsize=(7, 7))

### Helper functions to create the output directories and write the images

In [None]:
# Output locations: exported images go to IMAGES_DIR/<split>/<class name>/,
# and model checkpoints go to CHECKPOINT_DIR.
IMAGES_DIR = DATA_DIR + 'Cifar10_images'
im_format = '.png'  # or '.tif'; a lossless format keeps exported pixels identical to the source data

CHECKPOINT_DIR = DATA_DIR + 'checkpoints'
# exist_ok=True makes the cell safe to re-run without a separate exists() check.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


In [None]:
def make_directories(images_dir, split):
    """Create ``images_dir/split/<class name>/`` for every entry of the global ``class_names``.

    Directories that already exist are left untouched, so repeated calls are safe.
    """
    for dir_name in class_names:
        # exist_ok=True replaces the check-then-create pattern (no race, no extra stat call).
        os.makedirs(os.path.join(images_dir, split, dir_name), exist_ok=True)

def write_images_to_split_directory(batch,split,path,write_data=False,class_counter=None):
    """Export one CIFAR-10 batch as image files under ``path/split/<class name>/``.

    Each image is named ``<NNNN>_<original file stem><im_format>``. NNNN is the
    image's running index within its class. ``class_counter`` carries those
    running indices across calls: it is updated in place and returned, so
    several batches can be written into one split without name collisions.

    Args:
        batch: dict unpickled from a CIFAR-10 batch file (its 'data', 'labels'
            and 'filenames' entries are used).
        split: sub-directory name, e.g. 'train' or 'test'.
        path: root output directory.
        write_data: if False, only count the images per class (dry run).
        class_counter: per-class counts from earlier batches, or None to start at zero.

    Returns:
        The per-class counter (numpy int array of length ``nb_classes``).
    """
    counter = np.zeros(nb_classes, dtype=int) if class_counter is None else class_counter

    if write_data:
        make_directories(path, split=split)

    images = reshape_cifar_img_data(batch['data'])
    labels = batch['labels']
    filenames = batch['filenames']
    n_images = images.shape[0]
    # Every image must have exactly one class label and one source filename.
    assert n_images == len(labels)
    assert n_images == len(filenames)

    for image, class_index, source_name in zip(images, labels, filenames):
        class_ = class_names[class_index]
        destination_dir = os.path.join(path, split, class_)
        outputname = f"{counter[class_index]:04d}_{Path(source_name).stem}{im_format}"
        if write_data:
            imageio.imwrite(f"{destination_dir}/{outputname}", image, im_format)
        counter[class_index] += 1

    print('classnames', class_names)
    print('images/class = ', counter)
    return counter


In [None]:
class_counter=np.zeros(nb_classes,dtype=int)  # NOTE(review): not used by the export below, which starts its own counters
WRITE_DATA = True  # set to False for a dry run that only counts images per class

def export_cifar10_to_image_dirs(train_data, test_data, data_dir=IMAGES_DIR,
                                 class_counter=None, write_data=WRITE_DATA):
    """Write all training batches and the test batch as image files under ``data_dir``.

    Training images go to ``data_dir/train/<class>/`` and test images go to
    ``data_dir/test/<class>/``.

    Args:
        train_data: list of training-batch dicts, written in list order.
        test_data: the test-batch dict.
        data_dir: root output directory for both splits.
        class_counter: optional starting per-class counts for the training split.
        write_data: if False, only count the images (dry run).

    Returns:
        (train_counts, test_counts): number of images per class in each split.
    """
    split = 'train'
    for batch in train_data:
        print(f"\nWriting {batch['batch_label']}: {len(batch['labels'])} images to {data_dir}/{split}")
        # One counter is shared across all training batches, so per-class file numbers keep increasing.
        class_counter = write_images_to_split_directory(batch, split=split, path=data_dir,
                                                        write_data=write_data, class_counter=class_counter)

    split = 'test'
    print(f"\nWriting {test_data['batch_label']}: {len(test_data['labels'])} images to {data_dir}/{split}")
    # The test split goes under the same root as the training split (the `data_dir` argument).
    test_counter = write_images_to_split_directory(test_data, split=split, path=data_dir,
                                                   write_data=write_data, class_counter=None)
    print()
    return class_counter, test_counter

In [None]:
# Export all loaded training batches and the test batch as image files under IMAGES_DIR.
export_cifar10_to_image_dirs(cifar_data, test_batch, data_dir=IMAGES_DIR)

In [None]:
# Check results visually: compare one exported image file with the same image
# decoded directly from the pickled batch.
sample_idx = 7  # any index into the first training batch works
sample_batch = cifar_data[0]
sample_class = sample_batch['labels'][sample_idx]
sample_origdata = sample_batch['data'][sample_idx]
print(class_names[sample_class])

# The first training batch is exported first, with per-class counters starting at
# zero. The exported file number is therefore the count of earlier images in this
# batch that share the sample's class.
class_rank = sum(1 for label in sample_batch['labels'][:sample_idx] if label == sample_class)
sample_stem = Path(sample_batch['filenames'][sample_idx]).stem
sample_file = f"{IMAGES_DIR}/train/{class_names[sample_class]}/{class_rank:04d}_{sample_stem}{im_format}"
print(sample_file)
assert os.path.exists(sample_file), f"exported image not found: {sample_file}"

sample = imageio.v2.imread(sample_file)
sample_img = reshape_cifar_img(sample_origdata)
print(sample.shape)
print(sample_origdata.shape)

fig, (ax_file, ax_orig) = plt.subplots(1, 2)
ax_file.imshow(sample)
ax_file.set_title('exported file')
ax_orig.imshow(sample_img)
ax_orig.set_title('pickled batch')
plt.show()

# Check results numerically
print('sample min & max:',sample.min(),sample.max())
print('source min & max:',sample_origdata.min(),sample_origdata.max())
print('\nsample patch:\n',sample[0:5,0:5,1])
print('source patch:\n',sample_img[0:5,0:5,1])
# A lossless format (.png / .tif) should reproduce the source pixels exactly.
print('pixel-exact match:', np.array_equal(sample, sample_img))

In [None]:
def read_image_data_(data_path, split, class_labels=None):
    """Read exported images back into one numpy array per class.

    Expects the layout written by ``write_images_to_split_directory``:
    ``data_path/split/<class name>/*<im_format>``.

    Args:
        data_path: root directory of the exported images.
        split: sub-directory to read, e.g. 'train' or 'test'.
        class_labels: optional class directory names to read, in the order to
            return them. If None, every sub-directory is read in sorted
            (alphabetical) order. NOTE: downstream code indexes the result by
            CIFAR-10 class index. The sorted order matches ``class_names`` only
            because the CIFAR-10 label names happen to be alphabetical; pass
            ``class_labels=class_names`` to make this explicit.

    Returns:
        (img_tensors, class_labels, class_counter).
        ``img_tensors`` is a list with one array per class, each stacking that
        class's images along a new first axis, in file-name order.
        ``class_labels`` names the classes in the same order.
        ``class_counter`` holds the number of images per class, or is None if
        no class directories were found.

    Raises:
        FileNotFoundError: if ``data_path/split`` does not exist.
        ValueError: if a class directory contains no images.
    """
    split_dir = Path(data_path) / split
    if not split_dir.exists():
        raise FileNotFoundError("Path does not exist: " + str(split_dir))

    if class_labels is None:
        # One class per sub-directory, in sorted (alphabetical) order.
        img_dirs = sorted(d for d in split_dir.iterdir() if d.is_dir())
        class_labels = [d.name for d in img_dirs]
        print("class_labels (from directory names):", class_labels)
    else:
        class_labels = list(class_labels)
        img_dirs = [split_dir / name for name in class_labels]
        print("class_labels (as given):", class_labels)
    class_counter = np.zeros(len(class_labels), dtype=int) if class_labels else None

    img_tensors = []
    for iclass, imgdir in enumerate(img_dirs):
        # Exported file names start with the zero-padded per-class counter, so sorting restores export order.
        img_files = sorted(imgdir.glob('*' + im_format))
        if not img_files:
            raise ValueError(f"No '*{im_format}' images found in {imgdir}")
        print([f.stem for f in img_files[0:5]])
        imgrecs = [imageio.v2.imread(f) for f in img_files]
        npdata = np.stack(imgrecs)  # stacks the images along a new first axis (one entry per file)
        print(npdata.shape)
        class_counter[iclass] = len(imgrecs)
        img_tensors.append(npdata)

    print()
    print('classnames',class_labels)
    print('images/class = ',class_counter)
    return img_tensors, class_labels, class_counter

In [None]:
IMAGES_DIR   # display the export root that the next cells read the images back from

In [None]:
# Read the exported training images back: one stacked array per class, in sorted class-directory order.
training_data, _labels, class_counts = read_image_data_(IMAGES_DIR, 'train', class_labels=None)

In [None]:
# Same for the exported test split, which is used as the validation set below.
validation_data, _labels, class_counts = read_image_data_(IMAGES_DIR, 'test', class_labels=None)


In [None]:
import torch
from torchvision.transforms import v2

# Turn each numpy image into a float32 tensor image scaled to [0, 1], then
# standardize every channel. The mean/std values are presumably the per-channel
# statistics of the CIFAR-10 training images (TODO confirm their source).
_cifar10_channel_mean = (0.4915, 0.4823, 0.4468)
_cifar10_channel_std = (0.2470, 0.2435, 0.2616)
transform_to_normalized_tensors = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=_cifar10_channel_mean, std=_cifar10_channel_std),
    ]
)

In [None]:
# cifar10x = datasets.CIFAR10(
#     data_path, train=True, download=False, transform=transform_to_normalized_tensors
# )

# cifar10x_val = datasets.CIFAR10(
#     data_path, train=False, download=False,
#     transform=transform_to_normalized_tensors
# )

In [None]:
def filter_cifar_data(cifar_data, label_map=None, class_names=class_names):
    """Build a list of (tensor, label) samples from selected classes.

    Args:
        cifar_data: per-class image collections indexed by CIFAR-10 class
            index, e.g. the list returned by ``read_image_data_``.
        label_map: ``{new_label: cifar10_class_index}``. Images of class
            ``cifar10_class_index`` are relabelled as ``new_label``.
            Defaults to ``{0: 0, 1: 1}``.
        class_names: index -> name lookup, used only for the printed counts.

    Returns:
        List of ``(normalized image tensor, new_label)`` tuples, grouped by
        class in ``label_map`` iteration order.
    """
    if label_map is None:
        label_map = {0: 0, 1: 1}
    samples = []
    counts = {}
    for new_label, source_idx in label_map.items():
        class_name = class_names[source_idx]
        class_images = cifar_data[source_idx]
        samples.extend((transform_to_normalized_tensors(img), new_label) for img in class_images)
        counts[class_name] = len(class_images)

    print(counts)
    return samples

def load_val_data(label_map=None, class_names=class_names):
    """Apply ``filter_cifar_data`` to the global ``validation_data`` (the images read from the test split)."""
    return filter_cifar_data(validation_data, label_map, class_names)

# Two-class airplane-vs-bird task: CIFAR-10 class 0 (airplane) becomes label 0,
# and class 2 (bird) becomes label 1.
cifar2 = filter_cifar_data(training_data, label_map={0: 0, 1: 2})  # airplane, bird
cifar2_val = filter_cifar_data(validation_data, label_map={0: 0, 1: 2})  # airplane, bird

print(len(cifar2), cifar2[0][0].shape, cifar2[0][0].dtype)
print(len(cifar2_val), cifar2_val[0][0].shape, cifar2_val[0][0].dtype)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Net(nn.Module):
    """One-hidden-layer MLP (3072 -> 32 -> 2, ReLU) on flattened 3x32x32 inputs.

    Returns per-class log-probabilities (LogSoftmax), to be paired with NLLLoss.
    """

    def __init__(self):
        super().__init__()
        # Attribute names form the state_dict keys of saved checkpoints; keep them stable.
        self.fc1 = nn.Linear(3 * 32 * 32, 32)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(32, 2)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        flat = x.view(-1, 3 * 32 * 32)
        hidden = self.relu1(self.fc1(flat))
        return self.softmax(self.fc2(hidden))


model = Net()

In [None]:
class Net2(nn.Module):
    """Two-hidden-layer MLP (3072 -> 512 -> 32 -> 2) with GELU activations.

    Returns per-class log-probabilities, to be paired with NLLLoss.
    """

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(3 * 32 * 32, 512)
        self.fc1 = nn.Linear(512, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = x.view(-1, 3 * 32 * 32)
        for hidden_layer in (self.fc0, self.fc1):
            out = F.gelu(hidden_layer(out))
        return F.log_softmax(self.fc2(out), dim=1)


model2 = Net2()

In [None]:
class Net3(nn.Module):
    """One-hidden-layer MLP (3072 -> 32 -> 2) with GELU; returns log-probabilities.

    The commented-out lines sketch a planned convolutional variant (conv ->
    GELU -> 2x2 max-pool, twice). Its placeholder sizes C2/C3/W3/H3 are not
    defined yet.
    """

    def __init__(self):
        super().__init__()
        # self.conv1 = nn.Conv2d(C, C2, kernel_size=3, padding=1)  # IN_CHANNELS(C=3),W=32,H=32 -> OUT(=C2),W,H
        # self.conv2 = nn.Conv2d(C2, C3, kernel_size=3, padding=1) # C2,W2,H2 -> C3,W2,H2
        # self.fc1 = nn.Linear(W3xH3xC3, 32)
        self.fc1 = nn.Linear(3 * 32 * 32, 32)
        self.fc2 = nn.Linear(32, 2)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # out = F.max_pool2d(F.gelu(self.conv1(x)), 2)     # C=3,W=32,H=32 -> C2,W2,H2
        # out = F.max_pool2d(F.gelu(self.conv2(out)), 2)   # C2,W2,H2 -> C3,W3,H3
        # out = out.view(-1, W3*H3*C3)
        hidden = F.gelu(self.fc1(x.view(-1, 3 * 32 * 32)))
        return self.softmax(self.fc2(hidden))  # same as F.log_softmax(..., dim=1)


model3 = Net3()

### DataLoaders for training and validation datasets
Read about DataLoaders here:
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html


In [None]:
# Mini-batches of 64 samples: the training order is reshuffled each epoch,
# while the validation order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=cifar2, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=cifar2_val, batch_size=64, shuffle=False)


In [None]:
def eval_accuracy(model, val_loader, which="Validation"):
    """Compute the classification accuracy of ``model`` over every batch of a loader.

    The prediction is the index of the largest output along dim 1, so the
    model may return either logits or log-probabilities.

    Args:
        model: torch module mapping a batch of inputs to per-class scores
            (classes along dim 1).
        val_loader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        which: name shown in the printed summary.

    Returns:
        ``(accuracy, correct, total)``. ``accuracy`` is NaN when the loader
        yields no samples.

    The model is put in eval mode for the evaluation and restored to its
    previous mode afterwards. Layers such as dropout or batch-norm therefore
    behave deterministically here without disturbing an ongoing training loop.
    """
    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, labels in val_loader:
                outputs = model(imgs)
                _, predicted = torch.max(outputs, dim=1)
                total += labels.shape[0]
                correct += int((predicted == labels).sum())
    finally:
        model.train(was_training)

    accuracy = correct / total if total else float('nan')
    print(f"{which} Accuracy (correct: {correct} / total: {total}) = {accuracy}")
    return accuracy, correct, total

### Negative log-likelihood and cross-entropy (loss functions for classification)
For more background, see:
https://towardsdatascience.com/cross-entropy-negative-log-likelihood-and-all-that-jazz-47a95bd2e81 

In [None]:
# loss_fn = nn.CrossEntropyLoss()  # Works with unnormalized logits
# Net, Net2 and Net3 all end in a log-softmax, so NLLLoss is the matching loss here.
loss_fn = nn.NLLLoss()  # Negative Log Likelihood, assumes model outputs log probabilities (e.g. from LogSoftmax)


In [None]:
# Default training hyper-parameters used by train_model
LEARNING_RATE = 0.01   # SGD step size
NUM_EPOCHS = 100       # number of training epochs requested
WEIGHT_DECAY = 2e-3    # L2 regularization strength passed to the SGD optimizer

In [None]:
def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch RNGs so that runs are reproducible.

    Also makes cuDNN select deterministic algorithms. Call this before creating
    a model (weight initialization) and before iterating a shuffling DataLoader.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU; this is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: setting PYTHONHASHSEED here does not change hash randomization of the
    # already-running interpreter; only child processes started afterwards inherit it.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


In [None]:
def save_checkpoint(model, checkpoint_dir, chkpt_filename):
    """Save ``model.state_dict()`` to ``checkpoint_dir/chkpt_filename`` and return that path.

    The checkpoint directory is created if it does not exist yet.
    """
    chkpt_path = Path(checkpoint_dir) / chkpt_filename
    chkpt_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), chkpt_path)
    return chkpt_path

def load_from_checkpoint(model, checkpoint_dir, chkpt_filename, map_location=None):
    """Load a state dict saved by ``save_checkpoint`` into ``model`` and return ``model``.

    Args:
        model: module whose architecture matches the saved state dict.
        checkpoint_dir: directory containing the checkpoint.
        chkpt_filename: checkpoint file name.
        map_location: forwarded to ``torch.load``, e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine.

    ``weights_only=True`` restricts unpickling to tensors and plain containers.
    That is all a state dict needs, and it avoids executing arbitrary pickled code.
    """
    chkpt_path = Path(checkpoint_dir) / chkpt_filename
    state_dict = torch.load(chkpt_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    return model
    
def train_model(model, data_loader,
                n_epochs=NUM_EPOCHS,
                learning_rate=LEARNING_RATE,
                loss_fn=loss_fn,
                val_loader=None,
                model_name=None,
                weight_decay=WEIGHT_DECAY,
                checkpoint_dir=CHECKPOINT_DIR,
               ):
    """Train ``model`` with SGD and checkpoint the weights with the best validation accuracy.

    Args:
        model: torch module to train (updated in place).
        data_loader: iterable of ``(inputs, labels)`` training batches.
        n_epochs: number of passes over ``data_loader``; epochs are numbered 1..n_epochs.
        learning_rate: SGD learning rate.
        loss_fn: ``callable(outputs, labels)`` returning a scalar loss tensor.
        val_loader: optional validation batches. When given, training and
            validation accuracy are evaluated after epoch 1, after every 5th
            epoch and after the last epoch. Whenever validation accuracy
            improves, the weights are saved as ``<model_name>_best.pt``.
        model_name: checkpoint file prefix; defaults to the model's class name.
        weight_decay: L2 penalty passed to the SGD optimizer.
        checkpoint_dir: directory for the checkpoint files.

    After training, the final weights are saved as
    ``<model_name>_checkpoint_<n_epochs>.pt``.
    """
    if model_name is None:
        model_name = type(model).__name__
    n_batches = len(data_loader)
    best_validation_accuracy = 0.0
    print(f"----- Training model {model_name} {model} #params={sum(p.numel() for p in model.parameters())} with batches/epoch={n_batches}")
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    for epoch in range(1, n_epochs + 1):   # exactly n_epochs passes, numbered from 1
        model.train()  # evaluation below may have switched the model to eval mode
        loss_train = 0.0
        for imgs, labels in data_loader:
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()

        report_epoch = epoch == 1 or epoch % 5 == 0 or epoch == n_epochs
        if report_epoch:
            print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} Epoch: {epoch}, Training loss {loss_train / max(n_batches, 1):0.6f}')
        if val_loader is not None and report_epoch:
            # Use the loader passed in (not a global) so the training accuracy matches this run.
            eval_accuracy(model, data_loader, which="Training")
            accuracy, _, _ = eval_accuracy(model, val_loader, which="Validation")
            if accuracy > best_validation_accuracy:
                print(f"+++++ New best validation accuracy {accuracy:0.3f} at epoch {epoch}")
                best_validation_accuracy = accuracy
                save_checkpoint(model, checkpoint_dir, f"{model_name}_best.pt")

    # save the final weights after training (n_epochs is used so this also works when n_epochs == 0)
    save_checkpoint(model, checkpoint_dir, f"{model_name}_checkpoint_{n_epochs}.pt")
          
    

In [None]:
# Train the ReLU MLP (Net) on airplane-vs-bird. The best validation checkpoint is saved as Net_best.pt.
set_seed(42)   # specify explicit seed for random number generators, for reproducible results
model = Net()
# rebuild training dataloader (for reproducibility using our explicit random seed)
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)
train_model(model, train_loader, n_epochs=NUM_EPOCHS, learning_rate=1e-2, val_loader=val_loader)

In [None]:
# Accuracy of the final (last-epoch) Net weights on the training and validation sets.
#val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)
#train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)

eval_accuracy(model, train_loader, "Training")
eval_accuracy(model, val_loader, "Validation")


In [None]:
# Reload the Net weights that had the best validation accuracy during training, and re-evaluate them.
model1_best = Net()
load_from_checkpoint(model1_best, CHECKPOINT_DIR, "Net_best.pt")

eval_accuracy(model1_best, train_loader, "TRAINING")
eval_accuracy(model1_best, val_loader, "VALIDATION")


In [None]:
# Train the deeper GELU MLP (Net2) with the same seed and data. The best checkpoint is saved as Net2_best.pt.
set_seed(42)   # specify explicit seed for random number generators, for reproducible results
model2 = Net2()
# rebuild training dataloader (for reproducibility using our explicit random seed)
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)
train_model(model2, train_loader, n_epochs=NUM_EPOCHS, val_loader=val_loader)

In [None]:
# Reload and re-evaluate the best-validation Net2 checkpoint.
model2_best = Net2()
load_from_checkpoint(model2_best, CHECKPOINT_DIR, "Net2_best.pt")

eval_accuracy(model2_best, train_loader, "TRAINING")
eval_accuracy(model2_best, val_loader, "VALIDATION")


In [None]:
# Accuracy of the final (last-epoch) Net2 weights. val_loader is rebuilt with the same settings as before.
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)
eval_accuracy(model2, train_loader, "TRAINING")
eval_accuracy(model2, val_loader, "VALIDATION")

In [None]:
# Sanity check with swapped labels: birds (class 2) get label 0 and airplanes (class 0) get label 1,
# the opposite of training. Accuracy should therefore be roughly 1 - the normal validation accuracy.
label_map = {0: 2, 1: 0}
cifar2_valSWAPPED = load_val_data(label_map=label_map)
eval_accuracy(model2, torch.utils.data.DataLoader(cifar2_valSWAPPED, batch_size=64, shuffle=False))

In [None]:
# Whether a CUDA GPU is available. The training code in this notebook never moves models or data to a GPU.
torch.cuda.is_available()

In [None]:
# Train the GELU MLP (Net3) with a different seed, then report final-epoch accuracy.
set_seed(0)   # specify an explicit seed for random number generators, for reproducible results
model3 = Net3()
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)
train_model(model3, train_loader, n_epochs=NUM_EPOCHS, val_loader=val_loader)
print("-----TRAINING FINISHED-----")
eval_accuracy(model3, train_loader, "TRAINING")
eval_accuracy(model3, val_loader, "VALIDATION")

In [None]:
# Reload and re-evaluate the best-validation Net3 checkpoint. It is used by the probing cells below.
model3_best = Net3()
load_from_checkpoint(model3_best, CHECKPOINT_DIR, "Net3_best.pt")

eval_accuracy(model3_best, train_loader, "TRAINING")
eval_accuracy(model3_best, val_loader, "VALIDATION")


In [None]:
# Birds only (class 2), labelled 1 as in training: the accuracy is the fraction of validation birds predicted as "bird".
label_map = {1: 2}
cifar2_valBird = load_val_data(label_map=label_map)
val_loader_alt = torch.utils.data.DataLoader(cifar2_valBird, batch_size=64, shuffle=False)
eval_accuracy(model3_best, val_loader_alt)


In [None]:
# Airplanes only (class 0), labelled 0: the accuracy is the fraction of validation airplanes predicted as "airplane".
label_map = {0: 0}
cifar2_valAirplane = load_val_data(label_map=label_map)
val_loader_alt = torch.utils.data.DataLoader(cifar2_valAirplane, batch_size=64, shuffle=False)
eval_accuracy(model3_best, val_loader_alt)


In [None]:
# Airplanes (label 0) vs. unseen dogs (class 5) in the "bird" slot (label 1): how often are dogs called "bird"?
label_map = {0: 0, 1: 5}
cifar2_valAirplaneDog = load_val_data(label_map=label_map)
val_loader_alt = torch.utils.data.DataLoader(cifar2_valAirplaneDog, batch_size=64, shuffle=False)
eval_accuracy(model3_best, val_loader_alt)


In [None]:
# Unseen dogs (class 5) in the "airplane" slot (label 0) vs. birds (label 1).
label_map = {0: 5, 1: 2}
cifar2_valDogBird = load_val_data(label_map=label_map)
val_loader_alt = torch.utils.data.DataLoader(cifar2_valDogBird, batch_size=64, shuffle=False)
eval_accuracy(model3_best, val_loader_alt)


In [None]:
# Dogs only, labelled 0: the accuracy is the fraction of dogs the model predicts as "airplane".
label_map = {0: 5} 
cifar2_valDog0 = load_val_data(label_map=label_map)
val_loader_alt = torch.utils.data.DataLoader(cifar2_valDog0, batch_size=64, shuffle=False)
eval_accuracy(model3_best, val_loader_alt)


In [None]:
# Dogs only, labelled 1: the accuracy is the fraction of dogs the model predicts as "bird" (complements the previous cell).
label_map = {1: 5} 
cifar2_valDog1 = load_val_data(label_map=label_map)
val_loader_alt = torch.utils.data.DataLoader(cifar2_valDog1, batch_size=64, shuffle=False)
eval_accuracy(model3_best, val_loader_alt)
