In [3]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def ConvBNReLU(filters, kernel_size=3, strides=1, padding='same', use_bias=False):
    """Creates a Sequential model with Conv2D, BatchNormalization, and ReLU layers."""
    return models.Sequential([
        layers.Conv2D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias),
        layers.BatchNormalization(),
        layers.ReLU()
    ])

class ResidualBlock(models.Model):
    def __init__(self, in_channels, out_channels, strides=1):
        super().__init__()
        self.conv_res1 = ConvBNReLU(out_channels, strides=strides)
        self.conv_res2 = ConvBNReLU(out_channels)
        self.downsample = models.Sequential([
            layers.Conv2D(out_channels, 1, strides=strides, use_bias=False),
            layers.BatchNormalization()
        ]) if strides != 1 or in_channels != out_channels else lambda x: x

    def call(self, inputs):
        residual = self.downsample(inputs)
        out = self.conv_res1(inputs)
        out = self.conv_res2(out)
        out += residual
        return layers.ReLU()(out)

class Net(models.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = ConvBNReLU(64)
        self.conv2 = ConvBNReLU(128)
        self.pool = layers.MaxPooling2D(pool_size=2, strides=2)
        self.res_block1 = ResidualBlock(128, 192)
        self.conv3 = ConvBNReLU(256)
        self.conv4 = ConvBNReLU(384)
        self.res_block2 = ResidualBlock(384, 384)
        self.gap = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.res_block1(x)
        x = self.conv3(x)
        x = self.pool(x)
        x = self.conv4(x)
        x = self.pool(x)
        x = self.res_block2(x)
        x = self.gap(x)
        return self.fc(x)

# Create and build the model
model = Net()
model.build(input_shape=(None, 32, 32, 3))

# Print the model summary
model.summary()

Model: "net_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential_9 (Sequential)   (None, 32, 32, 64)        1984      
                                                                 
 sequential_10 (Sequential)  (None, 32, 32, 128)       74240     
                                                                 
 max_pooling2d_1 (MaxPooling  multiple                 0         
 2D)                                                             
                                                                 
 residual_block_2 (ResidualB  multiple                 579840    
 lock)                                                           
                                                                 
 sequential_14 (Sequential)  (None, 16, 16, 256)       443392    
                                                                 
 sequential_15 (Sequential)  (None, 8, 8, 384)         886272

In [4]:
import numpy as np

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Normalize the images
x_train, x_test = x_train / 255.0, x_test / 255.0

# One-hot encode the labels
y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)

# Set a random seed for reproducibility
np.random.seed(42)

# Create an array of shuffled indices
shuffled_indices = np.arange(x_train.shape[0])
np.random.shuffle(shuffled_indices)

# Shuffle x_train and y_train using the shuffled indices
x_train = x_train[shuffled_indices]
y_train = y_train[shuffled_indices]

In [5]:
from PIL import Image, ImageEnhance, ImageOps
import random

class ShearX(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        return x.transform(
            x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
            Image.BICUBIC, fillcolor=self.fillcolor)


class ShearY(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        return x.transform(
            x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
            Image.BICUBIC, fillcolor=self.fillcolor)


class TranslateX(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        return x.transform(
            x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0),
            fillcolor=self.fillcolor)


class TranslateY(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        return x.transform(
            x.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * x.size[1] * random.choice([-1, 1])),
            fillcolor=self.fillcolor)


class Rotate(object):
    # from https://stackoverflow.com/questions/
    # 5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
    def __call__(self, x, magnitude):
        rot = x.convert("RGBA").rotate(magnitude * random.choice([-1, 1]))
        return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode)


class Color(object):
    def __call__(self, x, magnitude):
        return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1]))


class Posterize(object):
    def __call__(self, x, magnitude):
        return ImageOps.posterize(x, magnitude)


class Solarize(object):
    def __call__(self, x, magnitude):
        return ImageOps.solarize(x, magnitude)


class Contrast(object):
    def __call__(self, x, magnitude):
        return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1]))


class Sharpness(object):
    def __call__(self, x, magnitude):
        return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1]))


class Brightness(object):
    def __call__(self, x, magnitude):
        return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1]))


class AutoContrast(object):
    def __call__(self, x, magnitude):
        return ImageOps.autocontrast(x)


class Equalize(object):
    def __call__(self, x, magnitude):
        return ImageOps.equalize(x)


class Invert(object):
    def __call__(self, x, magnitude):
        return ImageOps.invert(x)

In [6]:
class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.

        Example:
        >>> policy = CIFAR10Policy()
        >>> transformed = policy(image)

        Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     CIFAR10Policy(),
        >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"

class SubPolicy(object):
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        func = {
            "shearX": ShearX(fillcolor=fillcolor),
            "shearY": ShearY(fillcolor=fillcolor),
            "translateX": TranslateX(fillcolor=fillcolor),
            "translateY": TranslateY(fillcolor=fillcolor),
            "rotate": Rotate(),
            "color": Color(),
            "posterize": Posterize(),
            "solarize": Solarize(),
            "contrast": Contrast(),
            "sharpness": Sharpness(),
            "brightness": Brightness(),
            "autocontrast": AutoContrast(),
            "equalize": Equalize(),
            "invert": Invert()
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img

In [7]:
def apply_cutout(image, num_holes=1, max_h_size=8, max_w_size=8):
    """Applies Cutout augmentation to a single image."""
    # Convert PIL Image to numpy array
    image_np = np.array(image)
    
    h, w = image_np.shape[:2]
    mask = np.ones((h, w), np.float32)

    for _ in range(num_holes):
        y = np.random.randint(h)
        x = np.random.randint(w)
        
        y1 = np.clip(y - max_h_size // 2, 0, h)
        y2 = np.clip(y + max_h_size // 2, 0, h)
        x1 = np.clip(x - max_w_size // 2, 0, w)
        x2 = np.clip(x + max_w_size // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.

    # Apply mask
    image_np = image_np * mask[:, :, np.newaxis]

    # Convert back to PIL Image
    return Image.fromarray(image_np.astype('uint8'))

In [8]:
def pad_image(image, pad_size=4, fill=0, padding_mode='reflect'):
    """Pad the given PIL Image on all sides with the given pad_size."""
    return ImageOps.expand(image, border=pad_size, fill=fill)

def random_crop(image, crop_size=(32, 32)):
    """Crop a random part of the image to the given size."""
    width, height = image.size
    new_width, new_height = crop_size

    left = np.random.randint(0, width - new_width + 1)
    top = np.random.randint(0, height - new_height + 1)

    image = image.crop((left, top, left + new_width, top + new_height))
    return image

In [9]:
def random_horizontal_flip(image, p=0.5):
    """Randomly flip the image horizontally with a probability of p."""
    if random.random() < p:
        return image.transpose(Image.FLIP_LEFT_RIGHT)
    return image

def random_rotation(image, max_angle=0):
    """Randomly rotate the image within a given angle range."""
    angle = random.uniform(-max_angle, max_angle)
    return image.rotate(angle)

In [10]:
from tensorflow.keras.utils import Sequence
import numpy as np
import random

class CustomImageDataGenerator(Sequence):
    def __init__(self, x_set, y_set, batch_size=64, augmentations=None):
        self.x_set = x_set
        self.y_set = y_set
        self.batch_size = batch_size
        self.augmentations = augmentations if augmentations else []

    def __len__(self):
        return np.ceil(len(self.x_set) / self.batch_size).astype(int)

    def __getitem__(self, idx):
        batch_x = self.x_set[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y_set[idx * self.batch_size:(idx + 1) * self.batch_size]
        
        # Convert numpy arrays to PIL Images, apply augmentations, and convert back to numpy arrays
        x_batch_aug = np.array([self.apply_augmentations(Image.fromarray((image * 255).astype('uint8'))) for image in batch_x])
        
        # Convert PIL Images back to numpy arrays and normalize to [0, 1]
        x_batch_aug = np.array([np.array(image) for image in x_batch_aug]).astype('float32') / 255.0
        
        return x_batch_aug, batch_y

    def apply_augmentations(self, image):
        augmented_image = image
        for augmentation in self.augmentations:
            augmented_image = augmentation(augmented_image)
        return augmented_image

# Assuming CIFAR10Policy and apply_cutout are defined elsewhere
custom_augmentations = [pad_image, random_crop, random_horizontal_flip, random_rotation, CIFAR10Policy(), apply_cutout]

In [11]:
lr = 0.05
momentum = 0.9
weight_decay = 0.0005

optimizer = tf.keras.optimizers.SGD(
    learning_rate=lr, 
    momentum=momentum, 
    nesterov=True,
    decay=weight_decay
)

# Compile the model with the updated learning_rate parameter
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [12]:
from tensorflow.keras.callbacks import ModelCheckpoint

# Define the ModelCheckpoint callback to save the model using the 'SavedModel' format
checkpoint = ModelCheckpoint('best_model-r9', monitor='val_accuracy', verbose=1, save_best_only=True, mode='max', save_format='tf')

# Initialize the generator with the custom augmentations
custom_data_generator = CustomImageDataGenerator(x_train, y_train, batch_size=128, augmentations=custom_augmentations)

# Train the model using the custom data generator
history = model.fit(custom_data_generator,
                    steps_per_epoch=len(x_train) // 128,
                    epochs=250, 
                    validation_data=(x_test, y_test),
                    callbacks=[checkpoint])  # Include the checkpoint callback here

Epoch 1/250
Epoch 1: val_accuracy improved from -inf to 0.22360, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 2/250
Epoch 2: val_accuracy improved from 0.22360 to 0.29070, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 3/250
Epoch 3: val_accuracy improved from 0.29070 to 0.58080, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 4/250
Epoch 4: val_accuracy did not improve from 0.58080
Epoch 5/250
Epoch 5: val_accuracy improved from 0.58080 to 0.68010, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 6/250
Epoch 6: val_accuracy improved from 0.68010 to 0.72990, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 7/250
Epoch 7: val_accuracy did not improve from 0.72990
Epoch 8/250
Epoch 8: val_accuracy improved from 0.72990 to 0.74860, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 9/250
Epoch 9: val_accuracy improved from 0.74860 to 0.79650, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 10/250
Epoch 10: val_accuracy did not improve from 0.79650
Epoch 11/250
Epoch 11: val_accuracy improved from 0.79650 to 0.81050, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 12/250
Epoch 12: val_accuracy did not improve from 0.81050
Epoch 13/250
Epoch 13: val_accuracy did not improve from 0.81050
Epoch 14/250
Epoch 14: val_accuracy improved from 0.81050 to 0.83200, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 15/250
Epoch 15: val_accuracy improved from 0.83200 to 0.83950, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 16/250
Epoch 16: val_accuracy did not improve from 0.83950
Epoch 17/250
Epoch 17: val_accuracy did not improve from 0.83950
Epoch 18/250
Epoch 18: val_accuracy improved from 0.83950 to 0.86500, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 19/250
Epoch 19: val_accuracy did not improve from 0.86500
Epoch 20/250
Epoch 20: val_accuracy improved from 0.86500 to 0.86930, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 21/250
Epoch 21: val_accuracy did not improve from 0.86930
Epoch 22/250
Epoch 22: val_accuracy did not improve from 0.86930
Epoch 23/250
Epoch 23: val_accuracy did not improve from 0.86930
Epoch 24/250
Epoch 24: val_accuracy improved from 0.86930 to 0.86950, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 25/250
Epoch 25: val_accuracy improved from 0.86950 to 0.87510, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 26/250
Epoch 26: val_accuracy did not improve from 0.87510
Epoch 27/250
Epoch 27: val_accuracy did not improve from 0.87510
Epoch 28/250
Epoch 28: val_accuracy did not improve from 0.87510
Epoch 29/250
Epoch 29: val_accuracy did not improve from 0.87510
Epoch 30/250
Epoch 30: val_accuracy improved from 0.87510 to 0.88390, saving model to best_model-r9




INFO:tensorflow:Assets written to: best_model-r9\assets


INFO:tensorflow:Assets written to: best_model-r9\assets


Epoch 31/250
Epoch 31: val_accuracy did not improve from 0.88390
Epoch 32/250
 17/390 [>.............................] - ETA: 24s - loss: 0.3809 - accuracy: 0.8727

KeyboardInterrupt: 

In [None]:
model.save('model-r9')