In [None]:
from google.colab import drive
# Mount Google Drive into the Colab runtime under /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


#Preprocessing and splitting the different datasets#


In [None]:
"""
Splitting the data
"""
from keras.datasets import cifar10
from keras.datasets import cifar100
from keras.datasets import mnist
import numpy as np
from sklearn.model_selection import train_test_split


# iDataset = {0, 1, ... , 19} , 20 possible sub datasets
def splitSubSetData(iDataset):
  """Return the samples of one label-range sub-dataset.

  Indices 0-1 select one of 2 label ranges of CIFAR-10, indices 2-3 one of 2
  label ranges of MNIST, and any other index one of 16 label ranges of
  CIFAR-100. The train and test splits of the chosen dataset are concatenated
  and only the samples whose label lies in the selected range are kept.

  Args:
    iDataset: integer sub-dataset index, intended to be in 0..19. An index
      whose label range matches no sample (e.g. outside 0..19) yields empty
      arrays rather than an error.

  Returns:
    Tuple (x_out, y_out) with the selected images and their labels, in the
    order they appear in the concatenated train+test data.
  """
  if iDataset < 2:
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    numDiv = 2
  elif iDataset < 4:
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    iDataset = iDataset - 2
    numDiv = 2
  else:
    (x_train, y_train), (x_test, y_test) = cifar100.load_data()
    numDiv = 16
    iDataset = iDataset - 4

  # np.max reduces the whole label array at once; the builtin max() would
  # iterate it element by element in Python.
  numClasses = int(np.max(y_train)) + 1
  # Inclusive label bounds of the selected range (truncated like int()).
  first_label = int(numClasses / numDiv * iDataset)
  last_label = int(numClasses / numDiv * (iDataset + 1)) - 1

  X = np.concatenate((x_train, x_test), axis=0)
  y = np.concatenate((y_train, y_test), axis=0)

  # Flatten so both (N,) and (N, 1) label arrays give one value per sample;
  # int64 keeps the comparison safe for bounds outside the label dtype range.
  labels = y.reshape(-1).astype(np.int64)
  selected = (labels >= first_label) & (labels <= last_label)

  x_out = X[selected]
  y_out = y[selected]

  return x_out, y_out


#Stage 1 - FixMatch Algorithm implementation#

Utility functions for the algorithm

In [None]:
# libml\utils.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities."""

import os
import re

import numpy as np
import tensorflow as tf
from absl import flags, logging
from tensorflow.python.client import device_lib

# Cache of GPU device names; filled on the first call to get_available_gpus().
_GPUS = None
# Shared absl flag container read by get_config().
FLAGS = flags.FLAGS
# flags.DEFINE_bool('log_device_placement', False, 'For debugging purpose.')


class EasyDict(dict):
    """Dictionary whose entries are also reachable as attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Using the dict itself as the instance namespace makes `d.key` and
        # `d['key']` share the same storage.
        self.__dict__ = self


def get_config():
    """Build the tf.ConfigProto used for session creation.

    Soft placement is enabled when more than one GPU is available, device
    placement logging is enabled when the `log_device_placement` flag is set,
    and GPU memory growth is always enabled.

    Returns:
      The configured tf.ConfigProto.
    """
    config = tf.ConfigProto()
    if len(get_available_gpus()) > 1:
        config.allow_soft_placement = True
    # The DEFINE_bool for this flag is commented out in this file, so an
    # undefined flag is treated as False instead of raising AttributeError.
    if getattr(FLAGS, 'log_device_placement', False):
        config.log_device_placement = True
    config.gpu_options.allow_growth = True
    return config


def setup_main():
    """Placeholder hook; does nothing and returns None."""
    return None


def setup_tf():
    """Set the TF_CPP_MIN_LOG_LEVEL environment variable to '2'."""
    os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
    # logging.set_verbosity(logging.ERROR)


def smart_shape(x):
    """Return the first four dimensions of `x`, static where known.

    For each of the axes 0..3 the static dimension is used when its `.value`
    is not None; otherwise the matching entry of `tf.shape(x)` is used.
    """
    static_dims = x.shape
    dynamic_dims = tf.shape(x)
    dims = []
    for axis in range(4):
        if static_dims[axis].value is not None:
            dims.append(static_dims[axis])
        else:
            dims.append(dynamic_dims[axis])
    return dims


def ilog2(x):
    """Return log2(x) rounded up to the next integer."""
    exponent = np.log2(x)
    return int(np.ceil(exponent))


def find_latest_checkpoint(dir, glob_term='model.ckpt-*.meta'):
    """Replacement for tf.train.latest_checkpoint.

    It does not rely on the "checkpoint" file which sometimes contains
    absolute path and is generally hard to work with when sharing files
    between users / computers.

    Args:
      dir: string, directory searched for checkpoint meta files.
      glob_term: glob pattern, relative to `dir`, of the meta files.

    Returns:
      Path of the checkpoint with the highest step, with the trailing
      '.meta' removed.

    Raises:
      ValueError: if no file matches `glob_term` (max() of an empty list).
    """
    # Raw string: '\.' and '\d' are invalid escapes in a normal literal.
    r_step = re.compile(r'.*model\.ckpt-(?P<step>\d+)\.meta')
    matches = tf.gfile.Glob(os.path.join(dir, glob_term))
    matches = [(int(r_step.match(x).group('step')), x) for x in matches]
    # Tuples compare by step first; drop the 5-character '.meta' suffix.
    ckpt_file = max(matches)[1][:-5]
    return ckpt_file


def get_latest_global_step(dir):
    """Loads the global step from the latest checkpoint in directory.

    Args:
      dir: string, path to the checkpoint directory.

    Returns:
      int, the global step of the latest checkpoint or 0 if none was found.
    """
    try:
        checkpoint_reader = tf.train.NewCheckpointReader(find_latest_checkpoint(dir))
        return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
    # Best effort: any lookup/read failure means "no checkpoint". Unlike a bare
    # except, this lets KeyboardInterrupt and SystemExit propagate.
    except Exception:  # pylint: disable=broad-except
        return 0


def get_latest_global_step_in_subdir(dir):
    """Return the highest latest-checkpoint global step among sub-directories.

    Args:
      dir: string, parent of the checkpoint directories.

    Returns:
      The largest value returned by get_latest_global_step over the direct
      sub-directories of `dir`, or 0 if there is none larger.
    """
    entries = tf.gfile.Glob(os.path.join(dir, '*'))
    steps = [get_latest_global_step(path) for path in entries if os.path.isdir(path)]
    # Leading 0 is the floor used when no sub-directory has a checkpoint.
    return max([0] + steps)


def getter_ema(ema, getter, name, *args, **kwargs):
    """Exponential moving average getter for variable scopes.

    Args:
        ema: ExponentialMovingAverage object, where to get variable moving averages.
        getter: default variable scope getter.
        name: variable name.
        *args: extra args passed to default getter.
        **kwargs: extra args passed to default getter.

    Returns:
        If found the moving average variable, otherwise the default variable.
    """
    var = getter(name, *args, **kwargs)
    ema_var = ema.average(var)
    # Compare against None explicitly: "not found" is signalled by None, and
    # evaluating a variable's truth value is not a reliable presence test.
    return ema_var if ema_var is not None else var


def model_vars(scope=None):
    """Return the TRAINABLE_VARIABLES collection, optionally filtered by `scope`."""
    collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
    return tf.get_collection(collection_key, scope)


def gpu(x):
    """Return the '/gpu:N' device string for index `x`, wrapped to the GPU count.

    With no GPU detected the modulus is 1, so the result is always '/gpu:0'.
    """
    slots = max(1, len(get_available_gpus()))
    return '/gpu:%d' % (x % slots)


def get_available_gpus():
    """Return the tuple of local GPU device names, computed once and cached."""
    global _GPUS
    if _GPUS is not None:
        return _GPUS
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    devices = device_lib.list_local_devices(session_config=session_config)
    _GPUS = tuple(device.name for device in devices if device.device_type == 'GPU')
    return _GPUS


def get_gpu():
    """Yield the available GPU names forever, in round-robin order."""
    gpus = get_available_gpus()
    index = 0
    while True:
        yield gpus[index]
        index = (index + 1) % len(gpus)


def average_gradients(tower_grads):
    """Average each shared variable's gradient over all towers.

    Args:
      tower_grads: list with one entry per tower; each entry is that tower's
        list of (gradient, variable) tuples.

    Returns:
      With a single tower, that tower's list unchanged. Otherwise a list of
      (mean gradient, variable) pairs, the variable being taken from the
      first tower.
    """
    if len(tower_grads) <= 1:
        return tower_grads[0]

    averaged = []
    # zip(*...) groups the i-th (gradient, variable) pair of every tower.
    for per_variable in zip(*tower_grads):
        gradients = [pair[0] for pair in per_variable]
        shared_variable = per_variable[0][1]
        averaged.append((tf.reduce_mean(gradients, 0), shared_variable))
    return averaged


def para_list(fn, *args):
    """Run `fn` on every GPU and return its outputs grouped per output position.

    With at most one GPU, `fn(*args)` is called once. Otherwise each argument
    is split into one shard per GPU and `fn` is called per shard inside a
    'tower<N>' name scope on '/gpu:<N>'.

    Returns:
      A zip iterator over the per-tower outputs, transposed.
    """
    n_gpus = len(get_available_gpus())
    if n_gpus <= 1:
        return zip(*[fn(*args)])
    shards = [tf.split(arg, n_gpus) for arg in args]
    tower_outputs = []
    for tower_id, tower_args in enumerate(zip(*shards)):
        with tf.name_scope('tower%d' % tower_id), tf.device(tf.train.replica_device_setter(
                worker_device='/gpu:%d' % tower_id, ps_device='/cpu:0', ps_tasks=1)):
            tower_outputs.append(fn(*tower_args))
    return zip(*tower_outputs)


def para_mean(fn, *args):
    """Run `fn` on every GPU and return the mean of the per-tower outputs.

    With at most one GPU, `fn(*args)` is returned directly. Otherwise each
    argument is split into one shard per GPU, `fn` runs per shard, and the
    outputs are averaged over towers (element-wise per output when `fn`
    returns a tuple or list).
    """
    n_gpus = len(get_available_gpus())
    if n_gpus <= 1:
        return fn(*args)
    shards = [tf.split(arg, n_gpus) for arg in args]
    tower_outputs = []
    for tower_id, tower_args in enumerate(zip(*shards)):
        with tf.name_scope('tower%d' % tower_id), tf.device(tf.train.replica_device_setter(
                worker_device='/gpu:%d' % tower_id, ps_device='/cpu:0', ps_tasks=1)):
            tower_outputs.append(fn(*tower_args))
    if not isinstance(tower_outputs[0], (tuple, list)):
        return tf.reduce_mean(tower_outputs, 0)
    return [tf.reduce_mean(group, 0) for group in zip(*tower_outputs)]


def para_cat(fn, *args):
    """Run `fn` on every GPU and concatenate the per-tower outputs on axis 0.

    With at most one GPU, `fn(*args)` is returned directly. Otherwise each
    argument is split into one shard per GPU, `fn` runs per shard, and the
    outputs are concatenated (per output when `fn` returns a tuple or list).
    """
    n_gpus = len(get_available_gpus())
    if n_gpus <= 1:
        return fn(*args)
    shards = [tf.split(arg, n_gpus) for arg in args]
    tower_outputs = []
    for tower_id, tower_args in enumerate(zip(*shards)):
        with tf.name_scope('tower%d' % tower_id), tf.device(tf.train.replica_device_setter(
                worker_device='/gpu:%d' % tower_id, ps_device='/cpu:0', ps_tasks=1)):
            tower_outputs.append(fn(*tower_args))
    if not isinstance(tower_outputs[0], (tuple, list)):
        return tf.concat(tower_outputs, axis=0)
    return [tf.concat(group, axis=0) for group in zip(*tower_outputs)]


def interleave(x, batch):
    """Reorder the leading axis of `x` by a reshape / transpose / reshape.

    `x` is viewed as [-1, batch] + trailing dims, the first two axes are
    swapped, and the result is flattened back to [-1] + trailing dims.
    """
    shape = x.get_shape().as_list()
    trailing = shape[1:]
    swap_first_two = [1, 0] + list(range(2, 1 + len(shape)))
    grouped = tf.reshape(x, [-1, batch] + trailing)
    swapped = tf.transpose(grouped, swap_first_two)
    return tf.reshape(swapped, [-1] + trailing)


def de_interleave(x, batch):
    """Reorder the leading axis of `x` by a reshape / transpose / reshape.

    `x` is viewed as [batch, -1] + trailing dims, the first two axes are
    swapped, and the result is flattened back to [-1] + trailing dims.
    """
    shape = x.get_shape().as_list()
    trailing = shape[1:]
    swap_first_two = [1, 0] + list(range(2, 1 + len(shape)))
    grouped = tf.reshape(x, [batch, -1] + trailing)
    swapped = tf.transpose(grouped, swap_first_two)
    return tf.reshape(swapped, [-1] + trailing)


def combine_dicts(*args):
    """Merge the given dictionaries into a new one; later ones win on clashes."""
    # Python 2 compatible way to combine several dictionaries
    # We need it because currently TPU code does not work with python 3
    merged = {}
    for mapping in args:
        merged.update(mapping)
    return merged




The original method to load the data from the GitHub project
---
- We did not use this loader. Instead we used the method described above,
which selects a sub-dataset by its index,
iDataset in {0, ..., 19}.

In [None]:
# create_datasets.py

#!/usr/bin/env python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to download all datasets and create .tfrecord files.
"""

import collections
import gzip
import os
import tarfile
import tempfile
from urllib import request

import numpy as np
import scipy.io
import tensorflow as tf
from absl import app
from tqdm import trange

# from libml import data as libml_data
# from libml.utils import EasyDict

# Download locations per dataset key. Templates containing '{}' are filled in
# by the loaders with a split or file name.
URLS = {
    'svhn': 'http://ufldl.stanford.edu/housenumbers/{}_32x32.mat',
    'cifar10': 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz',
    'cifar100': 'https://www.cs.toronto.edu/~kriz/cifar-100-matlab.tar.gz',
    'stl10': 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz',
    # Read by _load_fashionmnist, which formats it with names such as
    # 'train-images-idx3-ubyte'; without this key that loader raises KeyError.
    'fashion_mnist': 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/{}.gz',
}


def _encode_png(images):
    """PNG-encode every image along the first axis of `images`.

    Each `images[i]` is fed to a uint8 rank-3 placeholder and encoded with
    tf.image.encode_png in a session pinned to 'cpu:0'.

    Returns:
      List with one encoding result per image, in input order.
    """
    encoded = []
    with tf.Session() as sess, tf.device('cpu:0'):
        image_x = tf.placeholder(tf.uint8, [None, None, None], 'image_x')
        to_png = tf.image.encode_png(image_x)
        progress = trange(images.shape[0], desc='PNG Encoding', leave=False)
        for index in progress:
            encoded.append(sess.run(to_png, feed_dict={image_x: images[index]}))
    return encoded


def _load_svhn():
    """Download the SVHN 'train', 'test' and 'extra' splits.

    Returns:
      OrderedDict mapping each split name to a dict with 'images' (PNG
      encodings from _encode_png) and 'labels' (flat label array).
    """
    splits = collections.OrderedDict()
    for split in ('train', 'test', 'extra'):
        with tempfile.NamedTemporaryFile() as f:
            request.urlretrieve(URLS['svhn'].format(split), f.name)
            mat = scipy.io.loadmat(f.name)
        # Move the last axis of 'X' to the front before encoding each image.
        images = np.transpose(mat['X'], [3, 0, 1, 2])
        encoded = _encode_png(images)
        labels = mat['y'].reshape((-1))
        # SVHN raw data uses labels from 1 to 10; use 0 to 9 instead.
        labels -= 1
        splits[split] = {'images': encoded, 'labels': labels}
    return splits


def _load_stl10():
    """Load STL-10 from a local archive if present, otherwise download it.

    The archive 'stl10/stl10_binary.tar.gz' is used when it exists; otherwise
    URLS['stl10'] is downloaded to a temporary file. Both the tar reader and
    the locally opened archive handle are closed before returning.

    Returns:
      Dict with 'train', 'test' and 'unlabeled' sets (each a dict of PNG
      encoded 'images' and uint8 'labels'; unlabeled labels are all zero) and
      'files', a list with one EasyDict holding the fold indices text.
    """
    def unflatten(images):
        return np.transpose(images.reshape((-1, 3, 96, 96)),
                            [0, 3, 2, 1])

    def read_member(tar, member):
        # Return the full byte content of one archive member.
        return tar.extractfile(member).read()

    local_archive = 'stl10/stl10_binary.tar.gz'
    with tempfile.NamedTemporaryFile() as tmp:
        if tf.gfile.Exists(local_archive):
            archive = tf.gfile.Open(local_archive, 'rb')
        else:
            request.urlretrieve(URLS['stl10'], tmp.name)
            archive = tmp
        try:
            with tarfile.open(fileobj=archive) as tar:
                train_X = read_member(tar, 'stl10_binary/train_X.bin')
                train_y = read_member(tar, 'stl10_binary/train_y.bin')
                test_X = read_member(tar, 'stl10_binary/test_X.bin')
                test_y = read_member(tar, 'stl10_binary/test_y.bin')
                unlabeled_X = read_member(tar, 'stl10_binary/unlabeled_X.bin')
                fold_indices = read_member(tar, 'stl10_binary/fold_indices.txt')
        finally:
            # The temporary file is closed by its own `with`; only a handle
            # opened on the local archive needs closing here.
            if archive is not tmp:
                archive.close()

    # Stored labels are shifted down by one.
    train_set = {'images': np.frombuffer(train_X, dtype=np.uint8),
                 'labels': np.frombuffer(train_y, dtype=np.uint8) - 1}
    test_set = {'images': np.frombuffer(test_X, dtype=np.uint8),
                'labels': np.frombuffer(test_y, dtype=np.uint8) - 1}
    unlabeled_set = {'images': np.frombuffer(unlabeled_X, dtype=np.uint8),
                     'labels': np.zeros(100000, dtype=np.uint8)}

    train_set['images'] = _encode_png(unflatten(train_set['images']))
    test_set['images'] = _encode_png(unflatten(test_set['images']))
    unlabeled_set['images'] = _encode_png(unflatten(unlabeled_set['images']))
    return dict(train=train_set, test=test_set, unlabeled=unlabeled_set,
                files=[EasyDict(filename="stl10_fold_indices.txt", data=fold_indices)])


def _load_cifar10():
    """Download CIFAR-10 (MATLAB version) and return its train and test sets.

    Returns:
      Dict with 'train' and 'test'; each is a dict with 'images' (the list of
      PNG encodings produced by _encode_png) and 'labels' (flat label array).
    """
    def unflatten(images):
        planar = images.reshape((images.shape[0], 3, 32, 32))
        return np.transpose(planar, [0, 2, 3, 1])

    batch_member = 'cifar-10-batches-mat/data_batch_{}.mat'
    with tempfile.NamedTemporaryFile() as f:
        request.urlretrieve(URLS['cifar10'], f.name)
        tar = tarfile.open(fileobj=f)
        # Training data is stored as five batches, numbered 1..5.
        batches = [scipy.io.loadmat(tar.extractfile(batch_member.format(number)))
                   for number in range(1, 6)]
        train_set = {
            'images': np.concatenate([b['data'] for b in batches], axis=0),
            'labels': np.concatenate([b['labels'].flatten() for b in batches], axis=0),
        }
        test_mat = scipy.io.loadmat(tar.extractfile(
            'cifar-10-batches-mat/test_batch.mat'))
        test_set = {'images': test_mat['data'],
                    'labels': test_mat['labels'].flatten()}
    for subset in (train_set, test_set):
        subset['images'] = _encode_png(unflatten(subset['images']))
    return dict(train=train_set, test=test_set)


def _load_cifar100():
    """Download CIFAR-100 (MATLAB version) and return its train and test sets.

    Labels are read from the 'fine_labels' field.

    Returns:
      Dict with 'train' and 'test'; each is a dict with 'images' (the list of
      PNG encodings produced by _encode_png) and 'labels' (flat label array).
    """
    def unflatten(images):
        planar = images.reshape((images.shape[0], 3, 32, 32))
        return np.transpose(planar, [0, 2, 3, 1])

    def read_split(tar, member):
        # Load one .mat member and keep only its images and fine labels.
        mat = scipy.io.loadmat(tar.extractfile(member))
        return {'images': mat['data'],
                'labels': mat['fine_labels'].flatten()}

    with tempfile.NamedTemporaryFile() as f:
        request.urlretrieve(URLS['cifar100'], f.name)
        tar = tarfile.open(fileobj=f)
        train_set = read_split(tar, 'cifar-100-matlab/train.mat')
        test_set = read_split(tar, 'cifar-100-matlab/test.mat')
    for subset in (train_set, test_set):
        subset['images'] = _encode_png(unflatten(subset['images']))
    return dict(train=train_set, test=test_set)


def _load_fashionmnist():
    """Download the Fashion-MNIST idx files and return its train and test sets.

    Requires a 'fashion_mnist' template entry in URLS. Each file starts with
    big-endian 32-bit header fields; the first one is checked against 2051
    (images) or 2049 (labels).

    Returns:
      Dict mapping 'train' and 'test' to a dict with 'images' (PNG encodings
      from _encode_png) and 'labels' (uint8 array).
    """
    def _read32(data):
        big_endian_u32 = np.dtype(np.uint32).newbyteorder('>')
        return np.frombuffer(data.read(4), dtype=big_endian_u32)[0]

    image_filename = '{}-images-idx3-ubyte'
    label_filename = '{}-labels-idx1-ubyte'
    splits = {}
    for split, split_file in (('train', 'train'), ('test', 't10k')):
        with tempfile.NamedTemporaryFile() as f:
            request.urlretrieve(URLS['fashion_mnist'].format(image_filename.format(split_file)), f.name)
            with gzip.GzipFile(fileobj=f, mode='r') as data:
                assert _read32(data) == 2051
                n_images = _read32(data)
                row = _read32(data)
                col = _read32(data)
                pixels = np.frombuffer(data.read(n_images * row * col), dtype=np.uint8)
                images = pixels.reshape((n_images, row, col, 1))
        with tempfile.NamedTemporaryFile() as f:
            request.urlretrieve(URLS['fashion_mnist'].format(label_filename.format(split_file)), f.name)
            with gzip.GzipFile(fileobj=f, mode='r') as data:
                assert _read32(data) == 2049
                n_labels = _read32(data)
                labels = np.frombuffer(data.read(n_labels), dtype=np.uint8)
        splits[split] = {'images': _encode_png(images), 'labels': labels}
    return splits


def _int64_feature(value):
    """Wrap a single integer `value` in a tf.train.Feature."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)


def _bytes_feature(value):
    """Wrap a single bytes `value` in a tf.train.Feature."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def _save_as_tfrecord(data, filename):
    """Write `data` as '<filename>.tfrecord' inside libml_data.DATA_DIR.

    Args:
      data: dict with equally long 'images' and 'labels' sequences.
      filename: record file name without directory or extension.
    """
    # NOTE(review): the `libml_data` import is commented out in this file;
    # confirm the name is defined before this is called.
    images, labels = data['images'], data['labels']
    assert len(images) == len(labels)
    filename = os.path.join(libml_data.DATA_DIR, filename + '.tfrecord')
    print('Saving dataset:', filename)
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index in trange(len(images), desc='Building records'):
            feat = {'image': _bytes_feature(images[index]),
                    'label': _int64_feature(labels[index])}
            example = tf.train.Example(features=tf.train.Features(feature=feat))
            writer.write(example.SerializeToString())
    print('Saved:', filename)


def _is_installed(name, checksums):
    """Return True when '<name>-<subset>.tfrecord' exists for every subset.

    Only the keys of `checksums` are used; the checksum values are ignored.
    """
    return all(
        tf.gfile.Exists(os.path.join(libml_data.DATA_DIR, '%s-%s.tfrecord' % (name, subset)))
        for subset, _ in checksums.items()
    )


def _save_files(files, *args, **kwargs):
    """Write each `filename -> contents` entry of `files` under DATA_DIR.

    The parent folders of all file names are created first. Extra positional
    and keyword arguments are accepted and ignored.
    """
    del args, kwargs
    folders = frozenset(os.path.dirname(name) for name in files)
    for folder in folders:
        tf.gfile.MakeDirs(os.path.join(libml_data.DATA_DIR, folder))
    for filename, contents in files.items():
        target = os.path.join(libml_data.DATA_DIR, filename)
        with tf.gfile.Open(target, 'w') as f:
            f.write(contents)


def _is_installed_folder(name, folder):
    """Return whether `<DATA_DIR>/<name>/<folder>` exists."""
    target = os.path.join(libml_data.DATA_DIR, name, folder)
    return tf.gfile.Exists(target)


# Per-dataset settings used by main(): 'loader' builds the data and the keys
# of 'checksums' name the subsets whose .tfrecord files mark it as installed.
CONFIGS = {
    'cifar10': {'loader': _load_cifar10,
                'checksums': {'train': None, 'test': None}},
    'cifar100': {'loader': _load_cifar100,
                 'checksums': {'train': None, 'test': None}},
    'svhn': {'loader': _load_svhn,
             'checksums': {'train': None, 'test': None, 'extra': None}},
    'stl10': {'loader': _load_stl10,
              'checksums': {'train': None, 'test': None}},
}


def main(argv):
    """Prepare every requested dataset that is not installed yet.

    Args:
      argv: command-line style list; the entries after the first name the
        datasets to prepare. With none given, every entry of CONFIGS is used.

    For each selected dataset the loader is run and each returned item is
    stored: 'readme' as a text file, 'files' as raw files, and anything else
    through the config's 'saver' (default: _save_as_tfrecord).
    """
    subset = set(argv[1:]) if argv[1:] else set(CONFIGS.keys())
    # Same directory expression as every other path built in this function;
    # the original used a bare `DATA_DIR` on this line only.
    tf.gfile.MakeDirs(libml_data.DATA_DIR)
    for name, config in CONFIGS.items():
        if name not in subset:
            continue
        # A config may bring its own check; otherwise look for the records.
        if 'is_installed' in config:
            installed = config['is_installed']()
        else:
            installed = _is_installed(name, config['checksums'])
        if installed:
            print('Skipping already installed:', name)
            continue
        print('Preparing', name)
        datas = config['loader']()
        saver = config.get('saver', _save_as_tfrecord)
        for sub_name, data in datas.items():
            if sub_name == 'readme':
                filename = os.path.join(libml_data.DATA_DIR, '%s-%s.txt' % (name, sub_name))
                with tf.gfile.Open(filename, 'w') as f:
                    f.write(data)
            elif sub_name == 'files':
                for file_and_data in data:
                    path = os.path.join(libml_data.DATA_DIR, file_and_data.filename)
                    with tf.gfile.Open(path, "wb") as f:
                        f.write(file_and_data.data)
            else:
                saver(data, '%s-%s' % (name, sub_name))


# if __name__ == '__main__':
    # app.run(main)


Stage 1 & Stage 2 - Adding augmentation methods as hyperparameter
---
Augmentation methods, used for the improvement of the model.

In [None]:
# libml\ctaugment.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Control Theory based self-augmentation."""
import random
from collections import namedtuple

import numpy as np
from PIL import Image, ImageOps, ImageEnhance, ImageFilter

# Registry of augmentation ops keyed by function name; filled by @register.
OPS = {}
# Registry entry: `f` is the op callable, `bins` the tuple passed to @register.
# CTAugment.policy also reuses it to hold (op name, list of levels) pairs.
OP = namedtuple('OP', ('f', 'bins'))
# Pair with 'train' and 'probe' fields.
Sample = namedtuple('Sample', ('train', 'probe'))


def register(*bins):
    """Return a decorator that records a function in OPS under its name.

    The entry stores the function together with `bins`; the function itself
    is returned unchanged.
    """
    def decorator(func):
        OPS.update({func.__name__: OP(func, bins)})
        return func

    return decorator


def apply(x, ops):
    """Apply a sequence of registered ops to the array image `x`.

    Args:
      x: image array; values are mapped with 127.5 * (1 + x), rounded and
        clipped to 0..255 before conversion to a PIL image.
      ops: iterable of (op name, argument list) pairs, or None.

    Returns:
      `x` itself when `ops` is None; otherwise the transformed image as a
      float32 array mapped back with value / 127.5 - 1.
    """
    if ops is None:
        return x
    pixels = np.round(127.5 * (1 + x)).clip(0, 255).astype('uint8')
    image = Image.fromarray(pixels)
    for name, levels in ops:
        image = OPS[name].f(image, *levels)
    return np.asarray(image).astype('f') / 127.5 - 1


class CTAugment:
    """Keeps a rate per op argument bin and samples augmentation policies."""

    def __init__(self, depth=2, th=0.85, decay=0.99):
        """Store the settings and start every bin rate at one.

        Args:
          depth: number of ops in each sampled policy.
          th: bins whose normalized rate is below this are never sampled.
          decay: moving-average factor used when updating rates.
        """
        self.decay = decay
        self.depth = depth
        self.th = th
        # One float32 vector of ones per argument of every registered op.
        self.rates = {
            name: tuple([np.ones(n_bins, 'f') for n_bins in op.bins])
            for name, op in OPS.items()
        }

    def rate_to_p(self, rate):
        """Turn a rate vector into weights: shift, normalize by max, threshold."""
        shifted = rate + (1 - self.decay)  # Avoid to have all zero.
        p = shifted / shifted.max()
        p[p < self.th] = 0
        return p

    def policy(self, probe):
        """Sample `depth` ops with their argument levels.

        With `probe` true, every level is drawn uniformly from [0, 1).
        Otherwise a bin is drawn per argument using the thresholded rates and
        the level is (bin + uniform offset) / number of bins.

        Returns:
          List of OP(name, levels) entries.
        """
        names = list(OPS.keys())
        chosen = []
        for _ in range(self.depth):
            name = random.choice(names)
            op_rates = self.rates[name]
            offsets = np.random.uniform(0, 1, len(op_rates))
            if probe:
                chosen.append(OP(name, offsets.tolist()))
                continue
            levels = []
            for offset, rate in zip(offsets, op_rates):
                p = self.rate_to_p(rate)
                picked = np.random.choice(p.shape[0], p=p / p.sum())
                levels.append((picked + offset) / p.shape[0])
            chosen.append(OP(name, levels))
        return chosen

    def update_rates(self, policy, proximity):
        """Blend `proximity` into the rate of each bin used by `policy`."""
        for name, levels in policy:
            for level, rate in zip(levels, self.rates[name]):
                # 0.999 keeps a level of exactly 1.0 inside the last bin.
                index = int(level * len(rate) * 0.999)
                rate[index] = rate[index] * self.decay + proximity * (1 - self.decay)

    def stats(self):
        """Return one text line per op, sorted by name, with its bin weights."""
        lines = []
        for name in sorted(OPS.keys()):
            per_argument = [' '.join('%.2f' % weight for weight in self.rate_to_p(rate))
                            for rate in self.rates[name]]
            lines.append('%-16s    %s' % (name, ' / '.join(per_argument)))
        return '\n'.join(lines)


def _enhance(x, op, level):
    return op(x).enhance(0.1 + 1.9 * level)


def _imageop(x, op, level):
    """Blend `x` with `op(x)` using blend weight `level`."""
    transformed = op(x)
    return Image.blend(x, transformed, level)


def _filter(x, op, level):
    """Blend `x` with its `op`-filtered version using blend weight `level`."""
    filtered = x.filter(op)
    return Image.blend(x, filtered, level)


@register(17)
def autocontrast(x, level):
    """Blend `x` with its ImageOps.autocontrast result, weighted by `level`."""
    operation = ImageOps.autocontrast
    return _imageop(x, operation, level)


@register(17)
def blur(x, level):
    """Blend `x` with its BLUR-filtered version, weighted by `level`."""
    kernel = ImageFilter.BLUR
    return _filter(x, kernel, level)


@register(17)
def brightness(x, brightness):
    """Adjust brightness of `x`; the level is mapped to a factor by _enhance."""
    enhancer_cls = ImageEnhance.Brightness
    return _enhance(x, enhancer_cls, brightness)


@register(17)
def color(x, color):
    """Adjust color of `x`; the level is mapped to a factor by _enhance."""
    enhancer_cls = ImageEnhance.Color
    return _enhance(x, enhancer_cls, color)


@register(17)
def contrast(x, contrast):
    """Adjust contrast of `x`; the level is mapped to a factor by _enhance."""
    enhancer_cls = ImageEnhance.Contrast
    return _enhance(x, enhancer_cls, contrast)


@register(17)
def cutout(x, level):
    """Paint a gray (127, 127, 127) rectangle at a random position of `x`.

    The nominal side length is 1 + int(level * min(x.size) * 0.499); the
    rectangle spans side // 2 on each side of a random center and is clipped
    to the image. `x` is modified in place and returned.
    """
    side = 1 + int(level * min(x.size) * 0.499)
    # x.size components, in the same order as the two pixel-access indices.
    extent_0, extent_1 = x.size
    center_0 = np.random.randint(low=0, high=extent_0)
    center_1 = np.random.randint(low=0, high=extent_1)
    half = side // 2
    start_0, start_1 = max(0, center_0 - half), max(0, center_1 - half)
    stop_0, stop_1 = min(extent_0, center_0 + half), min(extent_1, center_1 + half)
    pixels = x.load()  # pixel access object for in-place writes
    for i in range(start_0, stop_0):
        for j in range(start_1, stop_1):
            pixels[i, j] = (127, 127, 127)
    return x


@register(17)
def equalize(x, level):
    """Blend `x` with its ImageOps.equalize result, weighted by `level`."""
    operation = ImageOps.equalize
    return _imageop(x, operation, level)


@register(17)
def invert(x, level):
    """Blend `x` with its ImageOps.invert result, weighted by `level`."""
    operation = ImageOps.invert
    return _imageop(x, operation, level)


@register()
def identity(x):
    """Return `x` unchanged; registered with no level arguments."""
    unchanged = x
    return unchanged


@register(8)
def posterize(x, level):
    """Posterize `x` with 1 + int(level * 7.999) bits."""
    bits = 1 + int(level * 7.999)
    return ImageOps.posterize(x, bits)


@register(17, 6)
def rescale(x, scale, method):
    """Crop a centered window of `x` and resize it back to the original size.

    Args:
      x: PIL image.
      scale: level; multiplied by 0.25 to give the fraction trimmed from
        each border.
      method: level selecting one of six resampling filters via
        int(method * 5.99).

    Returns:
      The cropped and resized image, the same size as `x`.
    """
    s = x.size
    scale *= 0.25
    crop = (scale * s[0], scale * s[1], s[0] * (1 - scale), s[1] * (1 - scale))
    # Image.LANCZOS is the same filter as the former Image.ANTIALIAS alias,
    # which no longer exists in Pillow 10 and later.
    methods = (Image.LANCZOS, Image.BICUBIC, Image.BILINEAR, Image.BOX, Image.HAMMING, Image.NEAREST)
    method = methods[int(method * 5.99)]
    return x.crop(crop).resize(x.size, method)


@register(17)
def rotate(x, angle):
    """Rotate `x` by int(round((2 * angle - 1) * 45)) degrees."""
    degrees = int(np.round((2 * angle - 1) * 45))
    return x.rotate(degrees)


@register(17)
def sharpness(x, sharpness):
    """Adjust sharpness of `x`; the level is mapped to a factor by _enhance."""
    enhancer_cls = ImageEnhance.Sharpness
    return _enhance(x, enhancer_cls, sharpness)


@register(17)
def shear_x(x, shear):
    """Affine-transform `x` with coefficient (2 * shear - 1) * 0.3 in slot 2 of 6."""
    amount = (2 * shear - 1) * 0.3
    coefficients = (1, amount, 0, 0, 1, 0)
    return x.transform(x.size, Image.AFFINE, coefficients)


@register(17)
def shear_y(x, shear):
    """Affine-transform `x` with coefficient (2 * shear - 1) * 0.3 in slot 4 of 6."""
    amount = (2 * shear - 1) * 0.3
    coefficients = (1, 0, 0, amount, 1, 0)
    return x.transform(x.size, Image.AFFINE, coefficients)


@register(17)
def smooth(x, level):
    """Blend `x` with its SMOOTH-filtered version, weighted by `level`."""
    kernel = ImageFilter.SMOOTH
    return _filter(x, kernel, level)


@register(17)
def solarize(x, th):
    """Solarize `x` with threshold int(th * 255.999)."""
    threshold = int(th * 255.999)
    return ImageOps.solarize(x, threshold)


@register(17)
def translate_x(x, delta):
    """Affine-transform `x` with offset (2 * delta - 1) * 0.3 in slot 3 of 6."""
    offset = (2 * delta - 1) * 0.3
    coefficients = (1, 0, offset, 0, 1, 0)
    return x.transform(x.size, Image.AFFINE, coefficients)


@register(17)
def translate_y(x, delta):
    """Affine-transform `x` with offset (2 * delta - 1) * 0.3 in slot 6 of 6."""
    offset = (2 * delta - 1) * 0.3
    coefficients = (1, 0, 0, 0, 1, offset)
    return x.transform(x.size, Image.AFFINE, coefficients)


Third party augmentation methods

In [None]:
# third_party\auto_augment\augmentations.py
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Transforms used in the Augmentation Policies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

import numpy as np
# pylint:disable=g-multiple-import
from PIL import ImageOps, ImageEnhance, ImageFilter, Image

# pylint:enable=g-multiple-import


# IMAGE_SIZE = 32
# What is the dataset mean and std of the images on the training set
# MEANS = [0.49139968, 0.48215841, 0.44653091]
# STDS = [0.24703223, 0.24348513, 0.26158784]
# Maximum 'level' of a transform; float_parameter and int_parameter divide by
# it to scale a level onto an operation's value range.
PARAMETER_MAX = 10  # What is the max 'level' a transform could be predicted


def random_flip(x):
    """Return `x` flipped left-right when a uniform draw exceeds 0.5, else `x`."""
    draw = np.random.rand(1)[0]
    return np.fliplr(x) if draw > 0.5 else x


def zero_pad_and_crop(img, amount=4):
    """Zero pad `img` by `amount` pixels per side, then take a random crop.

    Args:
      img: numpy image (height, width, channels) to pad and crop.
      amount: number of zero pixels added on each side.

    Returns:
      A crop of the padded image with the same shape as `img`. The top and
      left offsets are each drawn from np.random.randint(0, 2 * amount).
    """
    height, width, channels = img.shape[0], img.shape[1], img.shape[2]
    padded = np.zeros((height + amount * 2, width + amount * 2, channels))
    padded[amount:height + amount, amount:width + amount, :] = img
    top = np.random.randint(low=0, high=2 * amount)
    left = np.random.randint(low=0, high=2 * amount)
    return padded[top:top + height, left:left + width, :]


def create_cutout_mask(img_height, img_width, num_channels, size):
    """Create a mask of ones with a randomly placed rectangle of zeros.

    Args:
      img_height: height of the image the mask is for.
      img_width: width of the image; must equal `img_height`.
      num_channels: number of image channels.
      size: nominal side of the zero patch; it spans size // 2 on each side
        of a random center and is clipped to the image.

    Returns:
      Tuple (mask, upper_coord, lower_coord): the mask of shape
      (img_height, img_width, num_channels), meant to be multiplied with the
      image, and the (row, col) start and end of the zeroed region.
    """
    assert img_height == img_width

    # Sample the patch center.
    height_loc = np.random.randint(low=0, high=img_height)
    width_loc = np.random.randint(low=0, high=img_width)

    # Clip the patch to the image bounds.
    half = size // 2
    top, left = max(0, height_loc - half), max(0, width_loc - half)
    bottom, right = min(img_height, height_loc + half), min(img_width, width_loc + half)
    assert bottom - top > 0
    assert right - left > 0

    mask = np.ones((img_height, img_width, num_channels))
    mask[top:bottom, left:right, :] = 0
    return mask, (top, left), (bottom, right)


def cutout_numpy(img, size=16):
    """Apply a cutout mask of nominal side `size` to `img`.

    The cutout operation is from the paper https://arxiv.org/abs/1708.04552.
    A mask from create_cutout_mask zeroes a randomly placed patch of `img`.

    Args:
      img: numpy image of rank 3 (height, width, channels).
      size: nominal height/width of the zero patch; `img` is returned
        unchanged when it is <= 0.

    Returns:
      `img` multiplied element-wise by the cutout mask.
    """
    if size <= 0:
        return img
    assert len(img.shape) == 3
    img_height, img_width, num_channels = img.shape
    mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size)
    return img * mask


def float_parameter(level, maxval):
    """Scale `maxval` by `level` / PARAMETER_MAX and return a float.

    Args:
      level: level of the operation, expected in [0, `PARAMETER_MAX`].
      maxval: value the operation takes at the maximum level.

    Returns:
      float(level) * maxval / PARAMETER_MAX.
    """
    level_as_float = float(level)
    return level_as_float * maxval / PARAMETER_MAX


def int_parameter(level, maxval):
    """Scale `maxval` by `level` / PARAMETER_MAX and truncate to an int.

    Args:
      level: level of the operation, expected in [0, `PARAMETER_MAX`].
      maxval: value the operation takes at the maximum level.

    Returns:
      int(level * maxval / PARAMETER_MAX).
    """
    scaled = level * maxval / PARAMETER_MAX
    return int(scaled)


def pil_wrap(img, mean=0.5, std=0.5):
    """Convert the numpy `img` to an RGBA PIL Image.

    Values are mapped with (img * std + mean) * 255 and cast to uint8.
    """
    scaled = (img * std + mean) * 255.0
    as_bytes = np.uint8(scaled)
    return Image.fromarray(as_bytes).convert('RGBA')


def pil_unwrap(pil_img, mean=0.5, std=0.5):
    """Convert an RGBA PIL image back to a normalized numpy array.

    Args:
      pil_img: RGBA PIL image (four values per pixel), e.g. from `pil_wrap`.
      mean: Offset subtracted after the pixel values are scaled by 1/255.
      std: Divisor applied after subtracting `mean`.

    Returns:
      A float array of shape (height, width, 3). Pixels whose alpha channel
      is 0 are set to 0 in all three channels.
    """
    # PIL reports `size` as (width, height) while the flat pixel sequence is
    # row-major, so the data must be reshaped to (height, width, 4). Using
    # (width, height, 4) only happens to work for square images.
    width, height = pil_img.size
    pic_array = np.array(pil_img.getdata()).reshape((height, width, 4)) / 255.0
    # Remember fully transparent pixels before the alpha channel is dropped.
    transparent = pic_array[:, :, 3] == 0
    pic_array = (pic_array[:, :, :3] - mean) / std
    pic_array[transparent] = 0
    return pic_array


def apply_policy(policy, img):
    """Run every operation of `policy` on the numpy image `img`.

    Args:
      policy: A list of (name, probability, level) tuples. `name` selects the
        operation from `NAME_TO_TRANSFORM`, `probability` is the chance that
        it is applied and `level` is its strength.
      img: Numpy image the policy is applied to.

    Returns:
      The numpy image obtained after all operations have been attempted
      in order.
    """
    current = pil_wrap(img)
    for step in policy:
        assert len(step) == 3
        op_name, prob, magnitude = step
        transform = NAME_TO_TRANSFORM[op_name]
        current = transform.pil_transformer(prob, magnitude)(current)
    return pil_unwrap(current)


class TransformFunction(object):
    """Callable wrapper pairing a transform with a printable name."""

    def __init__(self, func, name):
        # `f` is the wrapped callable, `name` the label shown by repr().
        self.f, self.name = func, name

    def __repr__(self):
        return ''.join(('<', self.name, '>'))

    def __call__(self, pil_img):
        transformed = self.f(pil_img)
        return transformed


class TransformT(object):
    """A named image operation taking a PIL image and a strength level."""

    def __init__(self, name, xform_fn):
        # `xform` is called as xform(pil_img, level).
        self.name, self.xform = name, xform_fn

    def pil_transformer(self, probability, level):
        """Return a `TransformFunction` applying the op with `probability`."""

        def maybe_apply(im):
            # One random draw per call decides whether the op fires.
            if random.random() < probability:
                return self.xform(im, level)
            return im

        label = self.name + '({:.1f},{})'.format(probability, level)
        return TransformFunction(maybe_apply, label)

    def do_transform(self, image, level):
        """Apply the op to the numpy `image` and return a numpy image."""
        # PARAMETER_MAX is passed as the probability; presumably it is >= 1
        # so the op always fires -- confirm against its definition.
        always_apply = self.pil_transformer(PARAMETER_MAX, level)
        wrapped = pil_wrap(image)
        return pil_unwrap(always_apply(wrapped))


################## Transform Functions ##################
# Each TransformT pairs a policy name with a callable taking (pil_img, level).
# The operations registered in this group ignore `level`.
identity = TransformT('Identity', lambda pil_img, level: pil_img)
flip_lr = TransformT(
    'FlipLR',
    lambda pil_img, level: pil_img.transpose(Image.FLIP_LEFT_RIGHT))
flip_ud = TransformT(
    'FlipUD',
    lambda pil_img, level: pil_img.transpose(Image.FLIP_TOP_BOTTOM))
# The three ImageOps operations below are run on an RGB copy of the image and
# the result is converted back to RGBA.
# pylint:disable=g-long-lambda
auto_contrast = TransformT(
    'AutoContrast',
    lambda pil_img, level: ImageOps.autocontrast(
        pil_img.convert('RGB')).convert('RGBA'))
equalize = TransformT(
    'Equalize',
    lambda pil_img, level: ImageOps.equalize(
        pil_img.convert('RGB')).convert('RGBA'))
invert = TransformT(
    'Invert',
    lambda pil_img, level: ImageOps.invert(
        pil_img.convert('RGB')).convert('RGBA'))
# pylint:enable=g-long-lambda
# Fixed-kernel PIL filters; applied to the RGBA image directly.
blur = TransformT(
    'Blur', lambda pil_img, level: pil_img.filter(ImageFilter.BLUR))
smooth = TransformT(
    'Smooth',
    lambda pil_img, level: pil_img.filter(ImageFilter.SMOOTH))


def _rotate_impl(pil_img, level):
    """Rotate `pil_img` by an angle scaled from `level`, with a random sign.

    The angle magnitude is `int_parameter(level, 30)` degrees.
    """
    angle = int_parameter(level, 30)
    # Flip the rotation direction when the draw exceeds 0.5.
    signed_angle = -angle if random.random() > 0.5 else angle
    return pil_img.rotate(signed_angle)


rotate = TransformT('Rotate', _rotate_impl)


def _posterize_impl(pil_img, level):
    """Posterize `pil_img`, keeping `4 - int_parameter(level, 4)` bits.

    The operation runs on an RGB copy and the result is returned as RGBA.
    """
    bits_removed = int_parameter(level, 4)
    rgb = pil_img.convert('RGB')
    return ImageOps.posterize(rgb, 4 - bits_removed).convert('RGBA')


posterize = TransformT('Posterize', _posterize_impl)


def _shear_x_impl(pil_img, level):
    """Shear `pil_img` along the horizontal axis.

    Args:
      pil_img: Image in PIL object.
      level: Strength of the operation in [0, `PARAMETER_MAX`]; it is scaled
        with `float_parameter(level, 0.3)` and given a random sign.

    Returns:
      A PIL Image that has had ShearX applied to it.
    """
    magnitude = float_parameter(level, 0.3)
    shear = -magnitude if random.random() > 0.5 else magnitude
    # Affine coefficients (a, b, c, d, e, f); only `b` carries the shear.
    coeffs = (1, shear, 0, 0, 1, 0)
    return pil_img.transform(pil_img.size, Image.AFFINE, coeffs)


shear_x = TransformT('ShearX', _shear_x_impl)


def _shear_y_impl(pil_img, level):
    """Shear `pil_img` along the vertical axis.

    Args:
      pil_img: Image in PIL object.
      level: Strength of the operation in [0, `PARAMETER_MAX`]; it is scaled
        with `float_parameter(level, 0.3)` and given a random sign.

    Returns:
      A PIL Image that has had ShearY applied to it.
    """
    magnitude = float_parameter(level, 0.3)
    shear = -magnitude if random.random() > 0.5 else magnitude
    # Affine coefficients (a, b, c, d, e, f); only `d` carries the shear.
    coeffs = (1, 0, 0, shear, 1, 0)
    return pil_img.transform(pil_img.size, Image.AFFINE, coeffs)


shear_y = TransformT('ShearY', _shear_y_impl)


def _translate_x_impl(pil_img, level):
    """Shift `pil_img` horizontally by a whole number of pixels.

    Args:
      pil_img: Image in PIL object.
      level: Strength of the operation in [0, `PARAMETER_MAX`]; it is scaled
        with `int_parameter(level, 10)` and given a random sign.

    Returns:
      A PIL Image that has had TranslateX applied to it.
    """
    pixels = int_parameter(level, 10)
    offset = -pixels if random.random() > 0.5 else pixels
    # Affine coefficients (a, b, c, d, e, f); `c` is the horizontal offset.
    coeffs = (1, 0, offset, 0, 1, 0)
    return pil_img.transform(pil_img.size, Image.AFFINE, coeffs)


translate_x = TransformT('TranslateX', _translate_x_impl)


def _translate_y_impl(pil_img, level):
    """Shift `pil_img` vertically by a whole number of pixels.

    Args:
      pil_img: Image in PIL object.
      level: Strength of the operation in [0, `PARAMETER_MAX`]; it is scaled
        with `int_parameter(level, 10)` and given a random sign.

    Returns:
      A PIL Image that has had TranslateY applied to it.
    """
    pixels = int_parameter(level, 10)
    offset = -pixels if random.random() > 0.5 else pixels
    # Affine coefficients (a, b, c, d, e, f); `f` is the vertical offset.
    coeffs = (1, 0, 0, 0, 1, offset)
    return pil_img.transform(pil_img.size, Image.AFFINE, coeffs)


translate_y = TransformT('TranslateY', _translate_y_impl)


def _crop_impl(pil_img, level, interpolation=Image.BILINEAR):
    """Trim `level` pixels from every border and resize back to full size.

    `level` is used directly as a pixel count (it is not rescaled here).
    """
    width, height = pil_img.size
    box = (level, level, width - level, height - level)
    return pil_img.crop(box).resize(pil_img.size, interpolation)


crop_bilinear = TransformT('CropBilinear', _crop_impl)


def _solarize_impl(pil_img, level):
    """Applies PIL Solarize to `pil_img`.

    The solarize threshold is `256 - int_parameter(level, 256)`, so a larger
    `level` yields a lower threshold.

    Args:
      pil_img: Image in PIL object.
      level: Strength of the operation specified as an Integer from
        [0, `PARAMETER_MAX`].

    Returns:
      A PIL Image that has had Solarize applied to it.
    """
    threshold = 256 - int_parameter(level, 256)
    # ImageOps.solarize is run on an RGB copy; the result goes back to RGBA.
    rgb = pil_img.convert('RGB')
    return ImageOps.solarize(rgb, threshold).convert('RGBA')


solarize = TransformT('Solarize', _solarize_impl)


def _cutout_pil_impl(pil_img, level):
    """Apply cutout to `pil_img` at the specified level.

    The patch side length is `int_parameter(level, 20)`. Pixels inside the
    patch are overwritten in place with (127, 127, 127, 0); the zero alpha is
    what `pil_unwrap` later uses to blank those pixels.

    Args:
      pil_img: RGBA PIL image; it is modified in place.
      level: Strength of the operation in [0, `PARAMETER_MAX`].

    Returns:
      The same `pil_img` object, with the cutout patch written into it.
    """
    size = int_parameter(level, 20)
    if size <= 0:
        return pil_img
    img_width, img_height = pil_img.size
    num_channels = 3
    _, upper_coord, lower_coord = (
        create_cutout_mask(img_height, img_width, num_channels, size))
    pixels = pil_img.load()  # create the pixel map
    # create_cutout_mask indexes its mask as [height, width], so coord[0] is a
    # row and coord[1] a column. PIL pixel access is [x, y] = [column, row];
    # using the coords the other way round transposes the patch and can run
    # out of bounds on non-square images.
    for row in range(upper_coord[0], lower_coord[0]):
        for col in range(upper_coord[1], lower_coord[1]):
            pixels[col, row] = (127, 127, 127, 0)  # set the colour accordingly
    return pil_img


cutout = TransformT('Cutout', _cutout_pil_impl)


def _enhancer_impl(enhancer):
    """Build a (pil_img, level) operation around a PIL ImageEnhance class.

    The enhancement factor is `float_parameter(level, 1.8) + 0.1`, i.e. it
    starts at 0.1 for level 0 and grows by up to 1.8 with `level`.
    """

    def impl(pil_img, level):
        # The 0.1 floor keeps the factor away from 0, which would destroy
        # the image.
        factor = float_parameter(level, 1.8) + .1
        return enhancer(pil_img).enhance(factor)

    return impl


color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
brightness = TransformT('Brightness', _enhancer_impl(ImageEnhance.Brightness))
sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))

# Every TransformT instance defined above; this list is the source for the
# name lookup table below.
ALL_TRANSFORMS = [
    identity,
    flip_lr,
    flip_ud,
    auto_contrast,
    equalize,
    invert,
    rotate,
    posterize,
    crop_bilinear,
    solarize,
    color,
    contrast,
    brightness,
    sharpness,
    shear_x,
    shear_y,
    translate_x,
    translate_y,
    cutout,
    blur,
    smooth
]

# Maps an operation name (as used in policy tuples) to its TransformT.
NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS}
# Note: this is a live dict keys view, not a list.
TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()


In [None]:
# third_party\auto_augment\policies.py
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


def cifar10_policies():
    """AutoAugment policies found on CIFAR-10.

    Returns:
      A flat list of 95 sub-policies (19 groups of 5). Each sub-policy is a
      list of two (name, probability, level) tuples in the format consumed by
      `apply_policy`.
    """
    # Groups exp0_*: four groups of five sub-policies.
    exp0_0 = [[('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
              [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
              [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
              [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
              [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]]
    exp0_1 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
              [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
              [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
              [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
              [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]]
    exp0_2 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
              [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
              [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
              [('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
              [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]]
    exp0_3 = [[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
              [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
              [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
              [('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
              [('Equalize', 0.7, 6), ('Color', 0.4, 9)]]
    # Groups exp1_*: seven groups of five sub-policies.
    exp1_0 = [[('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
              [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
              [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
              [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
              [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]]
    exp1_1 = [[('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
              [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
              [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
              [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
              [('Brightness', 0.0, 8), ('Color', 0.8, 8)]]
    exp1_2 = [[('Solarize', 0.2, 6), ('Color', 0.8, 6)],
              [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
              [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
              [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
              [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]]
    exp1_3 = [[('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
              [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
              [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
              [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
              [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]]
    exp1_4 = [[('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
              [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
              [('Equalize', 0.6, 8), ('Color', 0.6, 2)],
              [('Color', 0.3, 7), ('Color', 0.2, 4)],
              [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]]
    exp1_5 = [[('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
              [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
              [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
              [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
              [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]]
    exp1_6 = [[('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
              [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
              [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
              [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
              [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]]
    # Groups exp2_*: eight groups of five sub-policies.
    exp2_0 = [[('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
              [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
              [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
              [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
              [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]]
    exp2_1 = [[('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
              [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
              [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
              [('Color', 0.1, 8), ('ShearY', 0.2, 3)],
              [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]]
    exp2_2 = [[('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
              [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
              [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
              [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
              [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]]
    exp2_3 = [[('Equalize', 0.9, 5), ('Color', 0.7, 0)],
              [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
              [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
              [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
              [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]]
    exp2_4 = [[('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
              [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
              [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
              [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
              [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]]
    exp2_5 = [[('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
              [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
              [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
              [('Solarize', 0.4, 3), ('Color', 0.2, 4)],
              [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]]
    exp2_6 = [[('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
              [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
              [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
              [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
              [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]]
    exp2_7 = [[('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
              [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
              [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
              [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
              [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]]
    # Concatenate the groups into one flat list of sub-policies.
    exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3
    exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6
    exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + exp2_7
    return exp0s + exp1s + exp2s


def svhn_policies():
    """AutoAugment policies found on SVHN.

    Returns:
      A list of 40 sub-policies. Each sub-policy is a list of two
      (name, probability, level) tuples in the format consumed by
      `apply_policy`.
    """
    policies = [
        [('ShearX', 0.9, 4), ('Invert', 0.2, 3)],
        [('ShearY', 0.9, 8), ('Invert', 0.7, 5)],
        [('Equalize', 0.6, 5), ('Solarize', 0.6, 6)],
        [('Invert', 0.9, 3), ('Equalize', 0.6, 3)],
        [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)],
        [('ShearX', 0.9, 4), ('AutoContrast', 0.8, 3)],
        [('ShearY', 0.9, 8), ('Invert', 0.4, 5)],
        [('ShearY', 0.9, 5), ('Solarize', 0.2, 6)],
        [('Invert', 0.9, 6), ('AutoContrast', 0.8, 1)],
        [('Equalize', 0.6, 3), ('Rotate', 0.9, 3)],
        [('ShearX', 0.9, 4), ('Solarize', 0.3, 3)],
        [('ShearY', 0.8, 8), ('Invert', 0.7, 4)],
        [('Equalize', 0.9, 5), ('TranslateY', 0.6, 6)],
        [('Invert', 0.9, 4), ('Equalize', 0.6, 7)],
        [('Contrast', 0.3, 3), ('Rotate', 0.8, 4)],
        [('ShearX', 0.9, 3), ('Invert', 0.5, 3)],
        [('ShearY', 0.9, 8), ('Invert', 0.4, 5)],
        [('Equalize', 0.6, 3), ('Solarize', 0.2, 3)],
        [('Invert', 0.9, 4), ('Equalize', 0.5, 6)],
        [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)],
        [('Invert', 0.8, 5), ('TranslateY', 0.0, 2)],
        [('ShearY', 0.7, 6), ('Solarize', 0.4, 8)],
        [('Invert', 0.6, 4), ('Rotate', 0.8, 4)],
        [('ShearY', 0.3, 7), ('TranslateX', 0.9, 3)],
        [('ShearX', 0.1, 6), ('Invert', 0.6, 5)],
        [('Solarize', 0.7, 2), ('TranslateY', 0.6, 7)],
        [('ShearY', 0.8, 4), ('Invert', 0.8, 8)],
        [('ShearX', 0.7, 9), ('TranslateY', 0.8, 3)],
        [('ShearY', 0.8, 5), ('AutoContrast', 0.7, 3)],
        [('ShearX', 0.7, 2), ('Invert', 0.1, 5)],
        [('ShearY', 0.8, 9), ('ShearX', 0.7, 7)],
        [('ShearY', 0.7, 4), ('Solarize', 0.9, 7)],
        [('ShearY', 0.9, 5), ('Invert', 0.0, 4)],
        [('TranslateX', 0.8, 3), ('ShearY', 0.7, 7)],
        [('Invert', 0.1, 7), ('Solarize', 0.3, 9)],
        [('Invert', 0.6, 2), ('Invert', 0.9, 4)],
        [('Equalize', 0.5, 2), ('Solarize', 0.9, 7)],
        [('ShearY', 0.6, 7), ('Solarize', 0.8, 3)],
        [('ShearY', 0.6, 3), ('Invert', 0.6, 1)],
        [('ShearX', 0.4, 2), ('Rotate', 0.7, 5)]]
    return policies


def imagenet_policies():
    """AutoAugment policies found on ImageNet.

    This policy also transfers to five FGVC datasets with image size similar to
    ImageNet including Oxford 102 Flowers, Caltech-101, Oxford-IIIT Pets,
    FGVC Aircraft and Stanford Cars.

    Returns:
      A list of 20 sub-policies. Each sub-policy is a list of two
      (name, probability, level) tuples in the format consumed by
      `apply_policy`.
    """
    policies = [
        [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)]
    ]
    return policies


In [None]:
# libml\augment.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Augmentations for images.
"""
import collections
import functools
import itertools
import multiprocessing
import random

import numpy as np
import tensorflow as tf
from absl import flags

# from libml import utils, ctaugment
# from libml.utils import EasyDict
# from third_party.auto_augment import augmentations, policies

FLAGS = flags.FLAGS
# Shared multiprocessing pool; created lazily by init_pool().
POOL = None
# AutoAugment policy lists keyed by dataset name. cifar100 reuses the
# CIFAR-10 policies and svhn_noextra reuses the SVHN policies.
POLICIES = EasyDict(cifar10=cifar10_policies(),
                    cifar100=cifar10_policies(),
                    svhn=svhn_policies(),
                    svhn_noextra=svhn_policies())

# Operation names the rand-augment pools sample from.
RANDOM_POLICY_OPS = (
    'Identity', 'AutoContrast', 'Equalize', 'Rotate',
    'Solarize', 'Color', 'Contrast', 'Brightness',
    'Sharpness', 'ShearX', 'TranslateX', 'TranslateY',
    'Posterize', 'ShearY'
)
# Accepted augmentation codes: the fixed ones, 'r<nops>_<mag>_<cutout>' for
# nops in 1..4, mag in 1..15 and cutout in {0, 25, 50, 75}, and 'rac<mag>'
# for mag in 1..9.
AUGMENT_ENUM = 'd x m aa aac ra rac'.split() + ['r%d_%d_%d' % (nops, mag, cutout) for nops, mag, cutout in
                                                itertools.product(range(1, 5), range(1, 16), range(0, 100, 25))] + [
                   'rac%d' % (mag) for mag in range(1, 10)]

# Drop any existing 'K' / 'augment' flag before defining them below --
# presumably so this notebook cell can be re-run without a duplicate-flag
# error; confirm.
for name in list(flags.FLAGS):
    if name == 'K' or name == 'augment':
      delattr(flags.FLAGS,name) 

flags.DEFINE_integer('K', 1, 'Number of strong augmentation for unlabeled data.')
# Valid values are 2- or 3-part dotted combinations of AUGMENT_ENUM codes plus
# three explicit 4-part combinations.
# NOTE(review): the adjacent help-string literals are concatenated without a
# separator between 'randomized' and 'rxyy=...'.
flags.DEFINE_enum('augment', 'd.d',
                  [x + '.' + y for x, y in itertools.product(AUGMENT_ENUM, AUGMENT_ENUM)] +
                  [x + '.' + y + '.' + z for x, y, z in itertools.product(AUGMENT_ENUM, AUGMENT_ENUM, AUGMENT_ENUM)] + [
                      'd.d.d.d', 'd.aac.d.aac', 'd.rac.d.rac'],
                  'Dataset augmentation method (x=identity, m=mirror, d=default, aa=auto-augment, aac=auto-augment+cutout, '
                  'ra=rand-augment, rac=rand-augment+cutout; for rand-augment, magnitude is also randomized'
                  'rxyy=random augment with x ops and magnitude yy),'
                  'first is for labeled data, others are for unlabeled.')


def init_pool():
    """Create the shared multiprocessing pool on first use.

    The pool size is `FLAGS.para_augment` times the number of available GPUs,
    counting at least one GPU. Later calls are no-ops.
    """
    global POOL
    if POOL is not None:
        return
    gpu_count = len(get_available_gpus())
    POOL = multiprocessing.Pool(max(1, gpu_count) * FLAGS.para_augment)


def augment_mirror(x):
    """Randomly flip the image tensor `x` left-to-right."""
    flipped = tf.image.random_flip_left_right(x)
    return flipped


def augment_shift(x, w):
    """Randomly shift `x` by reflect-padding and cropping back to its shape.

    The first two axes are padded by `w` on each side with REFLECT mode and
    the third axis is left unpadded; a random crop of the original shape is
    then taken.
    """
    paddings = [[w, w], [w, w], [0, 0]]
    padded = tf.pad(x, paddings, mode='REFLECT')
    return tf.random_crop(padded, tf.shape(x))


def augment_noise(x, std):
    """Add normal noise scaled by `std` to `x`, matching its shape and dtype."""
    noise = tf.random_normal(tf.shape(x), dtype=x.dtype)
    return x + std * noise


def numpy_apply_policy(x, policy):
    """Apply `policy` to the numpy image `x` and return it as float32.

    Args:
      x: Numpy image passed to `apply_policy`.
      policy: List of (name, probability, level) tuples.

    Returns:
      The augmented image cast with `astype('f')`.
    """
    # `apply_policy` is defined at module level in this file; the
    # `augmentations` module it used to be reached through is not imported
    # here (that import is commented out), so the qualified name would raise
    # NameError.
    return apply_policy(policy, x).astype('f')


def stack_augment(augment: list):
    """Combine several augment functions into one that stacks their outputs.

    Args:
      augment: List of callables taking a sample dict and returning a dict;
        a `None` entry means the sample is passed through unchanged.

    Returns:
      A function mapping a sample to a dict that, for every key of the first
      result, holds the `tf.stack` of that key across all results.
    """

    def func(x):
        views = [x if fn is None else fn(x) for fn in augment]
        stacked = {}
        for key in views[0].keys():
            stacked[key] = tf.stack([view[key] for view in views])
        return stacked

    return func


class Primitives:
    """Factories for image augmentations that read `x['image']`."""

    @staticmethod
    def m():
        """Return a function that mirrors the sample's image."""

        def mirror(x):
            return augment_mirror(x['image'])

        return mirror

    @staticmethod
    def ms(shift):
        """Return a function that mirrors, then shifts by `shift`."""

        def mirror_then_shift(x):
            return augment_shift(augment_mirror(x['image']), shift)

        return mirror_then_shift

    @staticmethod
    def s(shift):
        """Return a function that only shifts the image by `shift`."""

        def shift_only(x):
            return augment_shift(x['image'], shift)

        return shift_only


# AugmentPair: `tf` is the per-sample augmentation function, `numpy` is a
# factory/class producing the numpy-side augmentation pool.
AugmentPair = collections.namedtuple('AugmentPair', 'tf numpy')
# PoolEntry: `payload` is the pending pool result iterator, `batch` is the
# batch the results belong to.
PoolEntry = collections.namedtuple('PoolEntry', 'payload batch')


class AugmentPool:
    """Pass-through pool: each call returns a fresh batch unchanged."""

    def __init__(self, get_samples):
        # Zero-argument callable that produces the next batch.
        self.get_samples = get_samples

    def __call__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used.
        batch = self.get_samples()
        return batch


# Identity augmentation: copies image and label through unchanged and fills
# in index=-1 when the sample has no 'index' key.
NOAUGMENT = AugmentPair(tf=lambda x: dict(image=x['image'], label=x['label'], index=x.get('index', -1)),
                        numpy=AugmentPool)


class AugmentPoolAA(AugmentPool):
    """Pool that applies randomly chosen AutoAugment policies to batches.

    Batches are augmented asynchronously through the shared `POOL`; a small
    queue of pending batches is kept so that results are usually ready by
    the time they are requested.
    """

    def __init__(self, get_samples, policy_group):
        """Start the pool and pre-queue batches.

        Args:
          get_samples: Zero-argument callable returning a batch dict with an
            'image' numpy array.
          policy_group: Key into `POLICIES` selecting the policy list.
        """
        init_pool()
        self.get_samples = get_samples
        self.policy_group = policy_group
        self.queue = []
        self.fill_queue()

    @staticmethod
    def numpy_apply_policies(arglist):
        """Apply one policy per image; `arglist` is (images, policies)."""
        x, policies = arglist
        # `apply_policy` is defined at module level in this file; the
        # `augmentations` module is not imported here, so the qualified name
        # would raise NameError.
        return np.stack([apply_policy(policy, y) for y, policy in zip(x, policies)]).astype('f')

    def queue_images(self, batch):
        """Submit `batch` to the pool and append the pending entry."""
        args = []
        image = batch['image']
        group_policies = POLICIES[self.policy_group]
        if image.ndim == 4:
            # One task per image: a length-1 slice with a single policy.
            for x in range(image.shape[0]):
                args.append((image[x:x + 1], [random.choice(group_policies)]))
        else:
            # Index 0 along axis 1 is left out of augmentation (presumably
            # the weakly augmented view -- confirm against the caller).
            for x in image[:, 1:]:
                args.append((x, [random.choice(group_policies) for _ in range(x.shape[0])]))
        self.queue.append(PoolEntry(payload=POOL.imap(self.numpy_apply_policies, args), batch=batch))

    def fill_queue(self):
        """Pre-queue four batches."""
        for _ in range(4):
            self.queue_images(self.get_samples())

    def __call__(self, *args, **kwargs):
        """Return the oldest queued batch with its images augmented."""
        del args, kwargs
        batch = self.get_samples()
        entry = self.queue.pop(0)
        # Consuming the payload iterator collects the pool results.
        samples = np.stack(list(entry.payload))
        if entry.batch['image'].ndim == 4:
            samples = samples.reshape(entry.batch['image'].shape)
            entry.batch['image'] = samples
        else:
            samples = samples.reshape(entry.batch['image'][:, 1:].shape)
            entry.batch['image'][:, 1:] = samples
        # Keep the queue depth constant by submitting the fresh batch.
        self.queue_images(batch)
        return entry.batch


class AugmentPoolAAC(AugmentPoolAA):
    """AutoAugment pool that additionally applies cutout to every image.

    Construction (`get_samples`, `policy_group`) is inherited unchanged from
    `AugmentPoolAA`; only the per-task worker differs.
    """

    @staticmethod
    def numpy_apply_policies(arglist):
        """Apply one policy per image, then cutout with the default size."""
        x, policies = arglist
        # `apply_policy` and `cutout_numpy` are defined at module level in
        # this file; the `augmentations` module is not imported here.
        return np.stack([cutout_numpy(apply_policy(policy, y)) for y, policy in
                         zip(x, policies)]).astype('f')


class AugmentPoolRAM(AugmentPoolAA):
    """Rand-augment pool with a randomized magnitude per operation."""

    # Randomized magnitude
    def __init__(self, get_samples, nops=2, magnitude=10):
        """Start the pool and pre-queue batches.

        Args:
          get_samples: Zero-argument callable returning a batch dict with an
            'image' numpy array.
          nops: Number of operations sampled per policy.
          magnitude: Exclusive upper bound for the sampled level.
        """
        init_pool()
        self.get_samples = get_samples
        self.nops = nops
        self.magnitude = magnitude
        self.queue = []
        self.fill_queue()

    @staticmethod
    def numpy_apply_policies(arglist):
        """Apply one policy per image; `arglist` is (images, policies)."""
        x, policies = arglist
        # `apply_policy` is defined at module level in this file; the
        # `augmentations` module is not imported here.
        return np.stack([apply_policy(policy, y) for y, policy in zip(x, policies)]).astype('f')

    def queue_images(self, batch):
        """Submit `batch` to the pool with freshly sampled random policies."""
        args = []
        image = batch['image']
        # Each policy holds `nops` ops with probability 0.5 and a level drawn
        # from [1, magnitude).
        # NOTE(review): np.random.randint(1, 1) raises, so magnitude=1
        # (e.g. 'rac1') cannot work here -- confirm intended range.
        policy = lambda: [(op, 0.5, np.random.randint(1, self.magnitude))
                          for op in np.random.choice(RANDOM_POLICY_OPS, self.nops)]
        if image.ndim == 4:
            for x in range(image.shape[0]):
                args.append((image[x:x + 1], [policy()]))
        else:
            # Index 0 along axis 1 is left out of augmentation.
            for x in image[:, 1:]:
                args.append((x, [policy() for _ in range(x.shape[0])]))
        self.queue.append(PoolEntry(payload=POOL.imap(self.numpy_apply_policies, args), batch=batch))


class AugmentPoolRAMC(AugmentPoolRAM):
    """Randomized-magnitude rand-augment pool followed by cutout.

    Construction (`get_samples`, `nops=2`, `magnitude=10`) and `queue_images`
    are inherited unchanged from `AugmentPoolRAM`.
    """

    @staticmethod
    def numpy_apply_policies(arglist):
        """Apply one policy per image, then cutout with the default size."""
        x, policies = arglist
        # `apply_policy` and `cutout_numpy` are defined at module level in
        # this file; the `augmentations` module is not imported here.
        return np.stack([cutout_numpy(apply_policy(policy, y)) for y, policy in
                         zip(x, policies)]).astype('f')


class AugmentPoolRAMC2(AugmentPoolRAM):
    """Variant of the rand-augment + cutout pool that augments every view.

    Unlike `AugmentPoolRAM`, the non-4D branch does not skip index 0 along
    axis 1: all entries are augmented and written back. Construction
    (`get_samples`, `nops=2`, `magnitude=10`) is inherited unchanged.
    """

    @staticmethod
    def numpy_apply_policies(arglist):
        """Apply one policy per image, then cutout with the default size."""
        x, policies = arglist
        # `apply_policy` and `cutout_numpy` are defined at module level in
        # this file; the `augmentations` module is not imported here.
        return np.stack([cutout_numpy(apply_policy(policy, y)) for y, policy in
                         zip(x, policies)]).astype('f')

    def queue_images(self, batch):
        """Submit `batch` to the pool with freshly sampled random policies."""
        args = []
        image = batch['image']
        # Each policy holds `nops` ops with probability 0.5 and a level drawn
        # from [1, magnitude).
        policy = lambda: [(op, 0.5, np.random.randint(1, self.magnitude))
                          for op in np.random.choice(RANDOM_POLICY_OPS, self.nops)]
        if image.ndim == 4:
            for x in range(image.shape[0]):
                args.append((image[x:x + 1], [policy()]))
        else:
            for x in image[:, :]:
                args.append((x, [policy() for _ in range(x.shape[0])]))
        self.queue.append(PoolEntry(payload=POOL.imap(self.numpy_apply_policies, args), batch=batch))

    def __call__(self, *args, **kwargs):
        """Return the oldest queued batch with all of its images augmented."""
        del args, kwargs
        batch = self.get_samples()
        entry = self.queue.pop(0)
        # Consuming the payload iterator collects the pool results.
        samples = np.stack(list(entry.payload))
        if entry.batch['image'].ndim == 4:
            samples = samples.reshape(entry.batch['image'].shape)
            entry.batch['image'] = samples
        else:
            samples = samples.reshape(entry.batch['image'][:, :].shape)
            entry.batch['image'][:, :] = samples
        # Keep the queue depth constant by submitting the fresh batch.
        self.queue_images(batch)
        return entry.batch


class AugmentPoolRA(AugmentPoolAA):
    """Rand-augment pool with a fixed magnitude and a configurable cutout."""

    def __init__(self, get_samples, nops, magnitude, cutout):
        """Start the pool and pre-queue batches.

        Args:
          get_samples: Zero-argument callable returning a batch dict with an
            'image' numpy array.
          nops: Number of operations sampled per policy.
          magnitude: Level used for every sampled operation.
          cutout: Cutout size as a percentage of the smaller image side.
        """
        init_pool()
        self.get_samples = get_samples
        self.nops = nops
        self.magnitude = magnitude
        self.size = cutout
        self.queue = []
        self.fill_queue()

    @staticmethod
    def numpy_apply_policies(arglist):
        """Apply one policy per image, then cutout; arglist has 3 items."""
        x, policies, cutout = arglist
        # `apply_policy` and `cutout_numpy` are defined at module level in
        # this file; the `augmentations` module is not imported here.
        # The cutout side is `cutout` percent of the smaller of the first two
        # image dimensions.
        return np.stack([cutout_numpy(apply_policy(policy, y),
                                      size=int(0.01 * cutout * min(y.shape[:2])))
                         for y, policy in zip(x, policies)]).astype('f')

    def queue_images(self, batch):
        """Submit `batch` to the pool with freshly sampled policies."""
        args = []
        image = batch['image']
        # Fixed magnitude
        policy = lambda: [(op, 1.0, self.magnitude) for op in np.random.choice(RANDOM_POLICY_OPS, self.nops)]
        if image.ndim == 4:
            for x in range(image.shape[0]):
                args.append((image[x:x + 1], [policy()], self.size))
        else:
            # Index 0 along axis 1 is left out of augmentation.
            for x in image[:, 1:]:
                args.append((x, [policy() for _ in range(x.shape[0])], self.size))
        self.queue.append(PoolEntry(payload=POOL.imap(self.numpy_apply_policies, args), batch=batch))


class AugmentPoolCTA(AugmentPool):
    """Pool applying CTAugment policies taken from the batch itself.

    Each batch is expected to carry a 'cta' object (providing `policy(probe)`)
    and a 'probe' flag next to its 'image' array.
    """

    def __init__(self, get_samples):
        """Start the pool and pre-queue batches from `get_samples`."""
        init_pool()
        self.get_samples = get_samples
        self.queue = []
        self.fill_queue()

    @staticmethod
    def numpy_apply_policies(arglist):
        """Augment one item; `arglist` is (x, cta, probe).

        A rank-3 `x` is a single probe image and yields a dict with the probe
        policy, the probe image and a regularly augmented image. Otherwise
        `x` is a stack whose first entry is kept as-is and whose remaining
        entries are augmented.
        """
        # NOTE(review): `ctaugment` must be resolvable in this namespace; its
        # import is commented out at the top of this cell -- confirm.
        x, cta, probe = arglist
        if x.ndim == 3:
            assert probe
            policy = cta.policy(probe=True)
            return dict(policy=policy,
                        probe=ctaugment.apply(x, policy),
                        image=ctaugment.apply(x, cta.policy(probe=False)))
        assert not probe
        return dict(image=np.stack([x[0]] + [ctaugment.apply(y, cta.policy(probe=False)) for y in x[1:]]).astype('f'))

    def queue_images(self):
        """Fetch a batch, submit it to the pool and append the entry."""
        batch = self.get_samples()
        args = [(x, batch['cta'], batch['probe']) for x in batch['image']]
        self.queue.append(PoolEntry(payload=POOL.imap(self.numpy_apply_policies, args), batch=batch))

    def fill_queue(self):
        """Pre-queue four batches."""
        for _ in range(4):
            self.queue_images()

    def __call__(self, *args, **kwargs):
        """Return the oldest queued batch with augmented results merged in."""
        del args, kwargs
        entry = self.queue.pop(0)
        samples = list(entry.payload)
        # np.stack requires a sequence of arrays; passing a generator
        # expression is rejected by current NumPy, so build lists first.
        entry.batch['image'] = np.stack([x['image'] for x in samples])
        if 'probe' in samples[0]:
            entry.batch['probe'] = np.stack([x['probe'] for x in samples])
            entry.batch['policy'] = [x['policy'] for x in samples]
        self.queue_images()
        return entry.batch


# Default ('d') augmentation per dataset. cifar10, cifar100 and fashion_mnist
# use mirror + shift by 4, stl10 uses mirror + shift by 12, and the svhn
# variants use shift by 4 without mirroring. All use the pass-through
# AugmentPool on the numpy side.
DEFAULT_AUGMENT = EasyDict(
    cifar10=AugmentPair(tf=lambda x: dict(image=Primitives.ms(4)(x), label=x['label'], index=x.get('index', -1)),
                        numpy=AugmentPool),
    cifar100=AugmentPair(tf=lambda x: dict(image=Primitives.ms(4)(x), label=x['label'], index=x.get('index', -1)),
                         numpy=AugmentPool),
    fashion_mnist=AugmentPair(tf=lambda x: dict(image=Primitives.ms(4)(x), label=x['label'], index=x.get('index', -1)),
                              numpy=AugmentPool),
    stl10=AugmentPair(tf=lambda x: dict(image=Primitives.ms(12)(x), label=x['label'], index=x.get('index', -1)),
                      numpy=AugmentPool),
    svhn=AugmentPair(tf=lambda x: dict(image=Primitives.s(4)(x), label=x['label'], index=x.get('index', -1)),
                     numpy=AugmentPool),
    svhn_noextra=AugmentPair(tf=lambda x: dict(image=Primitives.s(4)(x), label=x['label'], index=x.get('index', -1)),
                             numpy=AugmentPool),
)
# The tables below keep each dataset's default tf function and swap in a
# different numpy pool.
# NOTE(review): AUTO_AUGMENT* pass the dataset key as `policy_group`, but
# POLICIES only defines cifar10, cifar100, svhn and svhn_noextra -- confirm
# behavior for fashion_mnist and stl10.
AUTO_AUGMENT = EasyDict({
    k: AugmentPair(tf=v.tf, numpy=functools.partial(AugmentPoolAA, policy_group=k))
    for k, v in DEFAULT_AUGMENT.items()
})
AUTO_AUGMENT_CUTOUT = EasyDict({
    k: AugmentPair(tf=v.tf, numpy=functools.partial(AugmentPoolAAC, policy_group=k))
    for k, v in DEFAULT_AUGMENT.items()
})
RAND_AUGMENT = EasyDict({
    k: AugmentPair(tf=v.tf, numpy=functools.partial(AugmentPoolRAM, nops=2, magnitude=10))
    for k, v in DEFAULT_AUGMENT.items()
})
RAND_AUGMENT_CUTOUT = EasyDict({
    k: AugmentPair(tf=v.tf, numpy=functools.partial(AugmentPoolRAMC, nops=2, magnitude=10))
    for k, v in DEFAULT_AUGMENT.items()
})


def get_augmentation(dataset: str, augmentation: str):
    """Return the AugmentPair selected by a short augmentation code.

    Recognised codes:
      * ``'x'``   - ``NOAUGMENT``.
      * ``'m'``   - ``Primitives.m()`` on the TF side with ``AugmentPool``.
      * ``'d'``   - ``DEFAULT_AUGMENT[dataset]``.
      * ``'aa'``  - ``AUTO_AUGMENT[dataset]``.
      * ``'aac'`` - ``AUTO_AUGMENT_CUTOUT[dataset]``.
      * ``'ra'``  - ``RAND_AUGMENT[dataset]``.
      * ``'rac'`` or ``'rac<digit>'`` - default TF side with ``AugmentPoolRAMC``
        (``nops=2``); magnitude is 10 for plain ``'rac'``, otherwise the value
        of the LAST character only (so only single-digit magnitudes work).
      * ``'r<nops>_<mag>_<cutout>'`` - default TF side with ``AugmentPoolRA``
        configured from the three underscore-separated integers.

    Args:
        dataset: key into the per-dataset augmentation tables.
        augmentation: one of the codes above.

    Raises:
        NotImplementedError: for any unrecognised code, including the empty
            string (which previously escaped as an ``IndexError``).
        ValueError: when the numeric part of an ``'rac...'``/``'r...'`` code
            cannot be parsed as the expected integers.
    """
    if augmentation == 'x':
        return NOAUGMENT
    if augmentation == 'm':
        return AugmentPair(tf=lambda x: dict(image=Primitives.m()(x), label=x['label'], index=x.get('index', -1)),
                           numpy=AugmentPool)
    if augmentation == 'd':
        return DEFAULT_AUGMENT[dataset]
    if augmentation == 'aa':
        return AUTO_AUGMENT[dataset]
    if augmentation == 'aac':
        return AUTO_AUGMENT_CUTOUT[dataset]
    if augmentation == 'ra':
        return RAND_AUGMENT[dataset]
    if augmentation.startswith('rac'):
        mag = 10 if augmentation == 'rac' else int(augmentation[-1])
        return AugmentPair(tf=DEFAULT_AUGMENT[dataset].tf,
                           numpy=functools.partial(AugmentPoolRAMC, nops=2, magnitude=mag))
    # startswith() instead of augmentation[0]: an empty code must fall through
    # to NotImplementedError rather than raise IndexError.
    if augmentation.startswith('r'):
        nops, mag, cutout = (int(x) for x in augmentation[1:].split('_'))
        return AugmentPair(tf=DEFAULT_AUGMENT[dataset].tf,
                           numpy=functools.partial(AugmentPoolRA, nops=nops, magnitude=mag, cutout=cutout))
    raise NotImplementedError(augmentation)


def augment_function(dataset: str):
    """Resolve ``FLAGS.augment`` (two dot-separated codes) into a list of two
    AugmentPairs for `dataset`, in the order the codes appear."""
    codes = FLAGS.augment.split('.')
    assert len(codes) == 2
    pairs = []
    for code in codes:
        pairs.append(get_augmentation(dataset, code))
    return pairs


def pair_augment_function(dataset: str):
    """Resolve ``FLAGS.augment`` (three dot-separated codes) into two AugmentPairs.

    The first code gives the first pair directly.  The remaining two codes are
    combined into one pair: their TF functions go through ``stack_augment`` and
    the numpy side is taken from the last of them.
    """
    codes = FLAGS.augment.split('.')
    assert len(codes) == 3
    first_code, *rest_codes = codes
    rest = [get_augmentation(dataset, code) for code in rest_codes]
    first = get_augmentation(dataset, first_code)
    stacked = AugmentPair(tf=stack_augment([pair.tf for pair in rest]),
                          numpy=rest[-1].numpy)
    return [first, stacked]


def quad_augment_function(dataset: str):
    """Resolve ``FLAGS.augment`` (four dot-separated codes) into two AugmentPairs.

    Codes 1-2 form the first pair and codes 3-4 the second.  Within each pair
    the TF functions are combined with ``stack_augment`` and the numpy side is
    taken from the second code of that pair.
    """
    codes = FLAGS.augment.split('.')
    assert len(codes) == 4

    def stacked(pairs):
        return AugmentPair(tf=stack_augment([pair.tf for pair in pairs]),
                           numpy=pairs[-1].numpy)

    first_group = [get_augmentation(dataset, code) for code in codes[:2]]
    second_group = [get_augmentation(dataset, code) for code in codes[2:]]
    return [stacked(first_group), stacked(second_group)]


def many_augment_function(dataset: str):
    """Resolve ``FLAGS.augment`` (three dot-separated codes) into two AugmentPairs.

    The first code gives the first pair.  The second pair stacks the second
    code once followed by the third code repeated ``FLAGS.K`` times; its numpy
    side is taken from the last stacked entry.
    """
    codes = FLAGS.augment.split('.')
    assert len(codes) == 3
    stacked_codes = codes[1:2] + codes[2:] * FLAGS.K
    stacked_pairs = [get_augmentation(dataset, code) for code in stacked_codes]
    first = get_augmentation(dataset, codes[0])
    second = AugmentPair(tf=stack_augment([pair.tf for pair in stacked_pairs]),
                         numpy=stacked_pairs[-1].numpy)
    return [first, second]


- We used a different method, described above, which selects a sub-dataset by its index:
iDataset = 0, 1, ..., 19 (20 possible sub-datasets)

In [None]:
# libml\data.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input data for image models.
"""

import functools
import itertools
import os

import numpy as np
import tensorflow as tf
# NOAM
# Rebinds tf.io.gfile to the gfile module shipped with TensorBoard's
# TensorFlow stub, so every later tf.io.gfile.* call in this notebook goes
# through the stub implementation.
# NOTE(review): presumably a workaround for a gfile incompatibility in the
# notebook environment - confirm the stub provides every function used below.
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

from absl import app
from absl import flags
from tqdm import tqdm

# from libml import augment as augment_module
# from libml import utils
# from libml.augment import AugmentPair, NOAUGMENT

# Data directory. Value is initialized in _data_setup
#
# Note that if you need to use DATA_DIR outside of this module then
# you should do following:
#     from libml import data as libml_data
#     ...
#     dir = libml_data.DATA_DIR
#
# If you directly import DATA_DIR:
#   from libml.data import DATA_DIR
# then None will be imported.
DATA_DIR = None

_DATA_CACHE = None
# Per-class label counts; create_datasets multiplies these by 10 for cifar10.
SAMPLES_PER_CLASS = [1, 2, 3, 4, 5, 10, 25, 100, 400]

# Drop any earlier definition of the flags declared below so this cell can be
# executed more than once without a duplicate-flag error.
for name in list(flags.FLAGS):
    if name in ('dataset', 'para_parse', 'para_augment', 'shuffle',
                'p_unlabeled', 'whiten', 'data_dir'):
        delattr(flags.FLAGS, name)

flags.DEFINE_string('dataset', 'cifar10.1@4000-5000', 'Data to train on.')
flags.DEFINE_integer('para_parse', 1, 'Parallel parsing.')
flags.DEFINE_integer('para_augment', 5, 'Parallel augmentation.')
flags.DEFINE_integer('shuffle', 8192, 'Size of dataset shuffling.')
flags.DEFINE_string('p_unlabeled', '', 'Probability distribution of unlabeled.')
flags.DEFINE_bool('whiten', False, 'Whether to normalize images.')
flags.DEFINE_string(
    'data_dir', None,
    'Data directory. '
    'If None then environment variable ML_DATA '
    'will be used as a data directory.')

FLAGS = flags.FLAGS


def _data_setup():
    # set up data directory
    global DATA_DIR
    # DATA_DIR = FLAGS.data_dir or os.environ['ML_DATA']    #NOAM
    DATA_DIR = "/content/drive/MyDrive/MachineLearning/Final_Project_Fix_Match/SAVED_DATA/"


# Register _data_setup as an absl callback to run after app initialization.
app.call_after_init(_data_setup)


def record_parse_mnist(serialized_example, image_shape=None):
    """Parse one serialized example into ``dict(image=..., label=...)`` with padding.

    Decodes the ``image`` bytes feature, optionally fixes its static shape to
    `image_shape`, zero-pads two pixels on each side of the first two axes
    (the third axis is untouched), then casts to float32 and rescales so that
    a raw value of 0 maps to -1.0 and 255 maps to 1.0.  Padding happens before
    rescaling, so padded pixels end up at -1.0.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)
    pixels = tf.image.decode_image(parsed['image'])
    if image_shape:
        pixels.set_shape(image_shape)
    pixels = tf.pad(pixels, [[2, 2], [2, 2], [0, 0]])
    pixels = tf.cast(pixels, tf.float32) * (2.0 / 255) - 1.0
    return dict(image=pixels, label=parsed['label'])


def record_parse(serialized_example, image_shape=None):
    """Parse one serialized example into ``dict(image=..., label=...)``.

    Decodes the ``image`` bytes feature, optionally fixes its static shape to
    `image_shape`, then casts to float32 and rescales so that a raw value of 0
    maps to -1.0 and 255 maps to 1.0.  The int64 ``label`` feature is returned
    unchanged.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)
    pixels = tf.image.decode_image(parsed['image'])
    if image_shape:
        pixels.set_shape(image_shape)
    pixels = tf.cast(pixels, tf.float32) * (2.0 / 255) - 1.0
    return dict(image=pixels, label=parsed['label'])


def compute_mean_std(data: tf.data.Dataset):
    """Compute mean and standard deviation of the ``image`` entries of `data`.

    Images are read in batches of up to 1024.  Each batch contributes its
    first and second moments (reduced over axes 0, 1, 2) weighted by
    ``batch_size / 1024`` so that a short final batch counts proportionally.
    The std is derived as ``sqrt(E[x^2] - E[x]^2)``.

    Returns:
        ``(mean, std)``; both are also printed.
    """
    batches = data.map(lambda x: x['image']).batch(1024).prefetch(1)
    next_batch = batches.make_one_shot_iterator().get_next()
    total_weight = 0
    moments = []
    with tf.Session(config=get_config()) as sess:
        def drain():
            # Yield batches until the one-shot iterator is exhausted.
            while True:
                try:
                    yield sess.run(next_batch)
                except tf.errors.OutOfRangeError:
                    return

        for images in tqdm(drain(), unit='kimg', desc='Computing dataset mean and std'):
            weight = images.shape[0] / 1024.
            total_weight += weight
            first_moment = images.mean((0, 1, 2)) * weight
            second_moment = (images ** 2).mean((0, 1, 2)) * weight
            moments.append((first_moment, second_moment))
    mean = sum(m[0] for m in moments) / total_weight
    sigma = sum(m[1] for m in moments) / total_weight - mean ** 2
    std = np.sqrt(sigma)
    print('Mean %s  Std: %s' % (mean, std))
    return mean, std


class DataSet:
    """Wrapper around ``tf.data.Dataset`` that carries parse/augment settings.

    Unknown attributes are forwarded to the wrapped dataset as method calls;
    when such a call returns another ``tf.data.Dataset`` the result is
    re-wrapped with the same ``parse_fn``, ``augment_fn`` and ``image_shape``.
    """

    def __init__(self, data: tf.data.Dataset, augment_fn: AugmentPair, parse_fn=record_parse, image_shape=None):
        self.data = data
        self.parse_fn = parse_fn
        self.augment_fn = augment_fn
        self.image_shape = image_shape

    @classmethod
    def from_files(cls, filenames: list, augment_fn: AugmentPair, parse_fn=record_parse, image_shape=None):
        """Build a DataSet over the TFRecord files matching the glob patterns in `filenames`.

        With more than four matched files the records are read through a
        sloppy parallel interleave; otherwise a single TFRecordDataset is used.

        Raises:
            ValueError: when no file matches any pattern.
        """
        patterns = filenames
        matched = sorted(path for pattern in patterns for path in tf.io.gfile.Glob(pattern))
        if not matched:
            raise ValueError('Empty dataset, did you mount gcsfuse bucket?', patterns)

        if len(matched) <= 4:
            records = tf.data.TFRecordDataset(matched)
        else:
            def open_records(path):
                # 8 MiB read buffer per file.
                return tf.data.TFRecordDataset(path, buffer_size=8 * 1024 * 1024)

            # Read the files from disk in parallel.
            interleave = tf.data.experimental.parallel_interleave(
                open_records,
                cycle_length=min(16, len(matched)),
                sloppy=True)
            records = tf.data.Dataset.from_tensor_slices(matched).apply(interleave)

        return cls(records, augment_fn=augment_fn, parse_fn=parse_fn, image_shape=image_shape)

    @classmethod
    def empty_data(cls, image_shape, augment_fn: AugmentPair = None):
        """Build a DataSet of ``FLAGS.batch`` all-zero images with label 0 and no parse function."""

        def zero_example(_):
            return dict(image=tf.zeros(image_shape, tf.float32),
                        label=tf.constant(0, tf.int64))

        placeholder = tf.data.Dataset.range(FLAGS.batch).map(zero_example)
        return cls(placeholder, parse_fn=None, augment_fn=augment_fn, image_shape=image_shape)

    def __getattr__(self, item):
        """Forward unknown attributes to the wrapped dataset.

        Returns a callable; the attribute is only looked up on the wrapped
        dataset when that callable is invoked.
        """
        own = self.__dict__
        if item in own:
            return own[item]

        def call_and_update(*args, **kwargs):
            result = getattr(self.__dict__['data'], item)(*args, **kwargs)
            if not isinstance(result, tf.data.Dataset):
                return result
            return self.__class__(result,
                                  parse_fn=self.parse_fn,
                                  augment_fn=self.augment_fn,
                                  image_shape=self.image_shape)

        return call_and_update

    def parse(self):
        """Map ``parse_fn`` over the data (passing ``image_shape`` when set); no-op without a parse_fn."""
        if not self.parse_fn:
            return self
        para = 4 * max(1, len(get_available_gpus())) * FLAGS.para_parse
        if self.image_shape:
            return self.map(lambda x: self.parse_fn(x, self.image_shape), para)
        return self.map(self.parse_fn, para)

    def numpy_augment(self, *args, **kwargs):
        """Call the numpy side of ``augment_fn`` with the given arguments."""
        return self.augment_fn.numpy(*args, **kwargs)

    def augment(self):
        """Map the TF side of ``augment_fn`` over the data; no-op without an augment_fn."""
        if not self.augment_fn:
            return self
        para = max(1, len(get_available_gpus())) * FLAGS.para_augment
        return self.map(self.augment_fn.tf, para)

    def memoize(self):
        """Load every parsed example into memory and return an index-based DataSet.

        Call before parsing, since it calls parse() itself.  The returned
        DataSet ranges over example indices and its parse function fetches the
        cached image/label for an index, also exposing the index.
        """
        cached = []
        with tf.Session(config=get_config()) as session:
            next_example = self.parse().prefetch(16).make_one_shot_iterator().get_next()
            try:
                while True:
                    cached.append(session.run(next_example))
            except tf.errors.OutOfRangeError:
                pass
        images = np.stack([example['image'] for example in cached])
        labels = np.stack([example['label'] for example in cached])

        def tf_get(index, image_shape):
            def lookup(index):
                return images[index], labels[index]

            image, label = tf.py_func(lookup, [index], [tf.float32, tf.int64])
            return dict(image=tf.reshape(image, image_shape), label=label, index=index)

        return self.__class__(tf.data.Dataset.range(len(cached)),
                              parse_fn=tf_get,
                              augment_fn=self.augment_fn,
                              image_shape=self.image_shape)


class DataSets:
    """Container for the splits of one experiment plus image metadata."""

    def __init__(self, name, train_labeled: DataSet, train_unlabeled: DataSet, test: DataSet, valid: DataSet,
                 height=32, width=32, colors=3, nclass=10, mean=0, std=1, p_labeled=None, p_unlabeled=None):
        # Identification and splits.
        self.name = name
        self.train_labeled = train_labeled
        self.train_unlabeled = train_unlabeled
        self.test = test
        self.valid = valid
        # Image geometry and number of classes.
        self.height = height
        self.width = width
        self.colors = colors
        self.nclass = nclass
        # Normalisation statistics and optional class distributions.
        self.mean = mean
        self.std = std
        self.p_labeled = p_labeled
        self.p_unlabeled = p_unlabeled

    @classmethod
    def creator(cls, name, seed, label, valid, augment, parse_fn=record_parse, do_memoize=False,
                nclass=10, colors=3, height=32, width=32):
        """Return ``(key, create)`` where ``create()`` lazily builds the DataSets.

        `augment` is either a list of two AugmentPairs (labeled, unlabeled) or a
        callable taking the dataset name and returning such a list.  The key is
        ``'<name>.<seed>@<label>-<valid>'``.  When built, the first `valid`
        unlabeled records become the validation split and the rest stay
        unlabeled training data.
        """
        if not isinstance(augment, list):
            augment = augment(name)
        fullname = '.%d@%d' % (seed, label)
        root = os.path.join(DATA_DIR, 'SSL2', name)

        def create():
            p_labeled = None
            p_unlabeled = None
            if FLAGS.p_unlabeled:
                weights = [float(token) for token in FLAGS.p_unlabeled.split(',')]
                p_unlabeled = np.array(weights, dtype=np.float32)
                # Scale so the largest entry equals 1.
                p_unlabeled /= np.max(p_unlabeled)

            image_shape = [height, width, colors]
            labeled_files = [root + fullname + '-label.tfrecord']
            train_labeled = DataSet.from_files(labeled_files, augment[0], parse_fn, image_shape)
            unlabeled_files = [root + '-unlabel.tfrecord']
            train_unlabeled = DataSet.from_files(unlabeled_files, augment[1], parse_fn, image_shape)
            if do_memoize:
                train_labeled = train_labeled.memoize()
                train_unlabeled = train_unlabeled.memoize()

            mean, std = 0, 1
            if FLAGS.whiten:
                mean, std = compute_mean_std(train_labeled.concatenate(train_unlabeled))

            test_files = [os.path.join(DATA_DIR, '%s-test.tfrecord' % name)]
            test_data = DataSet.from_files(test_files, NOAUGMENT, parse_fn, image_shape=image_shape)

            title = name + '.' + FLAGS.augment + fullname + '-' + str(valid)
            if FLAGS.p_unlabeled:
                title += '/' + FLAGS.p_unlabeled
            return cls(title,
                       train_labeled=train_labeled,
                       train_unlabeled=train_unlabeled.skip(valid),
                       valid=train_unlabeled.take(valid),
                       test=test_data,
                       nclass=nclass, p_labeled=p_labeled, p_unlabeled=p_unlabeled,
                       height=height, width=width, colors=colors, mean=mean, std=std)

        key = name + fullname + '-' + str(valid)
        return key, create


def create_datasets(augment_fn):
    """Build the ``{key: create}`` registry of available dataset configurations.

    Registers cifar10 (label counts ``10 * SAMPLES_PER_CLASS``) and cifar100
    (label counts 400/1000/2500/10000), each for seeds 0-5 and validation
    sizes 1 and 5000.  `augment_fn` is forwarded to ``DataSets.creator``.
    """
    registry = {}

    cifar10_labels = [10 * per_class for per_class in SAMPLES_PER_CLASS]
    for seed, label, valid in itertools.product(range(6), cifar10_labels, [1, 5000]):
        key, create = DataSets.creator('cifar10', seed, label, valid, augment_fn)
        registry[key] = create

    for seed, label, valid in itertools.product(range(6), [400, 1000, 2500, 10000], [1, 5000]):
        key, create = DataSets.creator('cifar100', seed, label, valid, augment_fn, nclass=100)
        registry[key] = create

    # NOAM

    # d.update([DataSets.creator('fashion_mnist', seed, label, valid, augment_fn, height=32, width=32, colors=1,
    #                            parse_fn=record_parse_mnist)
    #           for seed, label, valid in itertools.product(range(6), [10 * x for x in SAMPLES_PER_CLASS], [1, 5000])])
    # d.update([DataSets.creator('stl10', seed, label, valid, augment_fn, height=96, width=96)
    #           for seed, label, valid in itertools.product(range(6), [1000, 5000], [1, 500])])
    # d.update([DataSets.creator('svhn', seed, label, valid, augment_fn)
    #           for seed, label, valid in itertools.product(range(6), [10 * x for x in SAMPLES_PER_CLASS], [1, 5000])])
    # d.update([DataSets.creator('svhn_noextra', seed, label, valid, augment_fn)
    #           for seed, label, valid in itertools.product(range(6), [10 * x for x in SAMPLES_PER_CLASS], [1, 5000])])
    return registry


# Zero-argument factories for the dataset registry; each binds create_datasets
# to one way of turning FLAGS.augment into augmentation pairs.  Nothing is
# built until the partial is called.
DATASETS = functools.partial(create_datasets, augment_function)
PAIR_DATASETS = functools.partial(create_datasets, pair_augment_function)
MANY_DATASETS = functools.partial(create_datasets, many_augment_function)
QUAD_DATASETS = functools.partial(create_datasets, quad_augment_function)


In [None]:
#libml\train.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training loop, checkpoint saving and loading, evaluation code."""
import functools
import json
import os.path
import shutil

import numpy as np
import tensorflow as tf
from absl import flags
from tqdm import trange, tqdm

# from libml import data, utils
# from libml.utils import EasyDict


# Drop any earlier definition of the flags declared below so this cell can be
# executed more than once without a duplicate-flag error.
for name in list(flags.FLAGS):
    if name in ('train_dir', 'lr', 'batch', 'train_kimg', 'report_kimg',
                'save_kimg', 'keep_ckpt', 'eval_ckpt', 'rerun'):
        delattr(flags.FLAGS, name)


FLAGS = flags.FLAGS
flags.DEFINE_string(
    'train_dir', './experiments',
    'Folder where to save training data.')
flags.DEFINE_float('lr', 0.0001, 'Learning rate.')
flags.DEFINE_integer('batch', 64, 'Batch size.')
flags.DEFINE_integer('train_kimg', 1 << 14, 'Training duration in kibi-samples.')
flags.DEFINE_integer('report_kimg', 64, 'Report summary period in kibi-samples.')
flags.DEFINE_integer('save_kimg', 64, 'Save checkpoint period in kibi-samples.')
flags.DEFINE_integer('keep_ckpt', 50, 'Number of checkpoints to keep.')
flags.DEFINE_string('eval_ckpt', '', 'Checkpoint to evaluate. If provided, do not do training, just do eval.')
flags.DEFINE_string('rerun', '', 'A string to identify a run if running multiple ones with same parameters.')


class Model:
    """Base class: builds the graph, prints a summary and manages run folders.

    Subclasses implement ``model(**kwargs)`` (returning the ops container) and
    ``add_summaries(**kwargs)``.
    """

    def __init__(self, train_dir: str, dataset: DataSets, **kwargs):
        """Build the model ops, print the configuration and create the run folders.

        The run folder is ``train_dir/FLAGS.rerun/<experiment_name>``.
        """
        self.train_dir = os.path.join(train_dir, FLAGS.rerun, self.experiment_name(**kwargs))
        self.params = EasyDict(kwargs)
        self.dataset = dataset
        self.session = None
        self.tmp = EasyDict(print_queue=[], cache=EasyDict())
        self.step = tf.train.get_or_create_global_step()
        self.ops = self.model(**kwargs)
        # The global step counts samples, so it advances by one batch per update.
        self.ops.update_step = tf.assign_add(self.step, FLAGS.batch)
        self.add_summaries(**kwargs)

        print(' Config '.center(80, '-'))
        print('train_dir', self.train_dir)
        print('%-32s %s' % ('Model', self.__class__.__name__))
        print('%-32s %s' % ('Dataset', dataset.name))
        for key, value in sorted(kwargs.items()):
            print('%-32s %s' % (key, value))
        print(' Model '.center(80, '-'))

        # One (name, size, shape) row of strings per variable, then a totals row.
        rows = [tuple(['%s' % field for field in (v.name, np.prod(v.shape), v.shape)])
                for v in model_vars(None)]
        rows.append(('Total', str(sum(int(row[1]) for row in rows)), ''))
        widths = [max([len(row[col]) for row in rows]) for col in range(3)]
        fmt = '%%-%ds  %%%ds  %%%ds' % tuple(widths)
        for row in rows[:-1]:
            print(fmt % row)
        print()
        print(fmt % rows[-1])
        print('-' * 80)
        self._create_initial_files()

    @property
    def arg_dir(self):
        """Folder holding the saved arguments (``<train_dir>/args``)."""
        return os.path.join(self.train_dir, 'args')

    @property
    def checkpoint_dir(self):
        """Folder holding the checkpoints (``<train_dir>/tf``)."""
        return os.path.join(self.train_dir, 'tf')

    def train_print(self, text):
        """Queue `text` for printing by the training loop."""
        self.tmp.print_queue.append(text)

    def _create_initial_files(self):
        """Create the checkpoint and args folders, then write the args file."""
        for folder in (self.checkpoint_dir, self.arg_dir):
            tf.gfile.MakeDirs(folder)
        self.save_args()

    def _reset_files(self):
        """Delete the whole run folder and recreate its initial contents."""
        shutil.rmtree(self.train_dir)
        self._create_initial_files()

    def save_args(self, **extra_params):
        """Write the model parameters merged with `extra_params` to ``args/args.json``."""
        merged = {**self.params, **extra_params}
        with tf.gfile.Open(os.path.join(self.arg_dir, 'args.json'), 'w') as f:
            json.dump(merged, f, sort_keys=True, indent=4)

    @classmethod
    def load(cls, train_dir):
        """Re-create a model from ``<train_dir>/args/args.json``.

        NOTE(review): no ``dataset`` argument is passed to the constructor, so
        this fails unless the saved args contain a ``dataset`` entry - confirm
        whether this method is still used.
        """
        with tf.gfile.Open(os.path.join(train_dir, 'args/args.json'), 'r') as f:
            params = json.load(f)
        instance = cls(train_dir=train_dir, **params)
        # Override the folder derived by __init__ with the one that was given.
        instance.train_dir = train_dir
        return instance

    def experiment_name(self, **kwargs):
        """Return ``ClassName_<key><value>_...`` with the kwargs sorted by key."""
        parts = [self.__class__.__name__]
        for key, value in sorted(kwargs.items()):
            parts.append(key + str(value))
        return '_'.join(parts)

    def eval_mode(self, ckpt=None):
        """Open a session and restore `ckpt` (or the latest checkpoint); returns self."""
        self.session = tf.Session(config=get_config())
        saver = tf.train.Saver()
        ckpt = find_latest_checkpoint(self.checkpoint_dir) if ckpt is None else os.path.abspath(ckpt)
        saver.restore(self.session, ckpt)
        self.tmp.step = self.session.run(self.step)
        print('Eval model %s at global_step %d' % (self.__class__.__name__, self.tmp.step))
        return self

    def model(self, **kwargs):
        """Build the graph; must be provided by subclasses."""
        raise NotImplementedError()

    def add_summaries(self, **kwargs):
        """Register summaries; must be provided by subclasses."""
        raise NotImplementedError()


class ClassifySemi(Model):
    """Semi-supervised classification: training loop and accuracy evaluation."""

    def __init__(self, train_dir: str, dataset: DataSets, nclass: int, **kwargs):
        self.nclass = nclass
        Model.__init__(self, train_dir, dataset, nclass=nclass, **kwargs)

    def train_step(self, train_session, gen_labeled, gen_unlabeled):
        """Run one update on a labeled and an unlabeled batch; stores the new step."""
        labeled, unlabeled = gen_labeled(), gen_unlabeled()
        feeds = {self.ops.y: unlabeled['image'],
                 self.ops.xt: labeled['image'],
                 self.ops.label: labeled['label']}
        results = train_session.run([self.ops.train_op, self.ops.update_step], feed_dict=feeds)
        self.tmp.step = results[1]

    def gen_labeled_fn(self, data_iterator):
        """Wrap the labeled batch op with the labeled set's numpy-side augmentation."""
        return self.dataset.train_labeled.numpy_augment(lambda: self.session.run(data_iterator))

    def gen_unlabeled_fn(self, data_iterator):
        """Wrap the unlabeled batch op with the unlabeled set's numpy-side augmentation."""
        return self.dataset.train_unlabeled.numpy_augment(lambda: self.session.run(data_iterator))

    def train(self, train_nimg, report_nimg):
        """Train until the global step reaches `train_nimg` samples.

        If ``FLAGS.eval_ckpt`` is set, only evaluates that checkpoint.  Progress
        is shown in chunks of `report_nimg` samples; queued messages are
        written after every step.
        """
        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        batch = FLAGS.batch

        def next_batch_op(source):
            pipeline = source.repeat().shuffle(FLAGS.shuffle).parse().augment()
            return pipeline.batch(batch).prefetch(16).make_one_shot_iterator().get_next()

        train_labeled = next_batch_op(self.dataset.train_labeled)
        train_unlabeled = next_batch_op(self.dataset.train_unlabeled)
        saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt, pad_step_number=10)
        scaffold = tf.train.Scaffold(saver=saver)

        # Fill the evaluation cache in a throw-away session before training.
        with tf.Session(config=get_config()) as sess:
            self.session = sess
            self.cache_eval()

        with tf.train.MonitoredTrainingSession(
                scaffold=scaffold,
                checkpoint_dir=self.checkpoint_dir,
                config=get_config(),
                save_checkpoint_steps=FLAGS.save_kimg << 10,
                save_summaries_steps=report_nimg - batch) as train_session:
            self.session = train_session._tf_sess()
            gen_labeled = self.gen_labeled_fn(train_labeled)
            gen_unlabeled = self.gen_unlabeled_fn(train_unlabeled)
            self.tmp.step = self.session.run(self.step)

            def flush(emit):
                # Drain queued messages in FIFO order.
                while self.tmp.print_queue:
                    emit(self.tmp.print_queue.pop(0))

            while self.tmp.step < train_nimg:
                epoch = 1 + (self.tmp.step // report_nimg)
                loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
                              leave=False, unit='img', unit_scale=batch,
                              desc='Epoch %d/%d' % (epoch, train_nimg // report_nimg))
                for _ in loop:
                    self.train_step(train_session, gen_labeled, gen_unlabeled)
                    flush(loop.write)
            flush(print)

    def eval_checkpoint(self, ckpt=None):
        """Restore `ckpt` and print a raw/ema accuracy table for labeled, valid and test."""
        self.eval_mode(ckpt)
        self.cache_eval()
        raw = self.eval_stats(classify_op=self.ops.classify_raw)
        ema = self.eval_stats(classify_op=self.ops.classify_op)
        row_fmt = '%16s %8s %8s %8s'
        print(row_fmt % ('', 'labeled', 'valid', 'test'))
        print(row_fmt % (('raw',) + tuple('%.2f' % x for x in raw)))
        print(row_fmt % (('ema',) + tuple('%.2f' % x for x in ema)))

    def cache_eval(self):
        """Cache test, valid and (up to 10000) labeled training samples as numpy arrays.

        Does nothing when the cache already holds a ``test`` entry.
        """

        def collect_samples(dataset, name):
            """Return ``(images, labels)`` numpy arrays with every sample of `dataset`."""
            progress = tqdm(desc='Caching %s examples' % name)
            next_item = dataset.batch(1).prefetch(16).make_one_shot_iterator().get_next()
            image_chunks, label_chunks = [], []
            while True:
                try:
                    item = self.session.run(next_item)
                except tf.errors.OutOfRangeError:
                    break
                image_chunks.append(item['image'])
                label_chunks.append(item['label'])
                progress.update()

            all_images = np.concatenate(image_chunks, axis=0)
            all_labels = np.concatenate(label_chunks, axis=0)
            progress.close()
            return all_images, all_labels

        cache = self.tmp.cache
        if 'test' in cache:
            return
        cache.test = collect_samples(self.dataset.test.parse(), name='test')
        cache.valid = collect_samples(self.dataset.valid.parse(), name='valid')
        cache.train_labeled = collect_samples(self.dataset.train_labeled.take(10000).parse(),
                                              name='train_labeled')

    def eval_stats(self, batch=None, feed_extra=None, classify_op=None, verbose=True):
        """Return accuracies (in percent) on the cached train_labeled, valid and test sets.

        Args:
            batch: evaluation batch size; defaults to ``FLAGS.batch``.
            feed_extra: extra feed_dict entries for every run.
            classify_op: op to evaluate; defaults to ``self.ops.classify_op``.
            verbose: when true, queue a one-line accuracy report.

        Returns:
            float32 numpy array ``[train_labeled, valid, test]``.
        """
        batch = batch or FLAGS.batch
        if classify_op is None:
            classify_op = self.ops.classify_op
        accuracies = []
        for subset in ('train_labeled', 'valid', 'test'):
            images, labels = self.tmp.cache[subset]
            outputs = [
                self.session.run(
                    classify_op,
                    feed_dict={
                        self.ops.x: images[start:start + batch],
                        **(feed_extra or {})
                    })
                for start in range(0, images.shape[0], batch)
            ]
            predicted = np.concatenate(outputs, axis=0)
            accuracies.append((predicted.argmax(1) == labels).mean() * 100)
        if verbose:
            self.train_print('kimg %-5d  accuracy train/valid/test  %.2f  %.2f  %.2f' %
                             tuple([self.tmp.step >> 10] + accuracies))
        return np.array(accuracies, 'f')

    def add_summaries(self, feed_extra=None, **kwargs):
        """Register accuracy scalars computed through ``eval_stats`` via ``tf.py_func``.

        A second set under ``accuracy/raw`` is added when the ops contain
        ``classify_raw``.
        """
        del kwargs

        def gen_stats(classify_op=None, verbose=True):
            return self.eval_stats(feed_extra=feed_extra, classify_op=classify_op, verbose=verbose)

        accuracies = tf.py_func(functools.partial(gen_stats), [], tf.float32)
        for position, tag in enumerate(('accuracy/train_labeled', 'accuracy/valid', 'accuracy')):
            tf.summary.scalar(tag, accuracies[position])
        if 'classify_raw' in self.ops:
            raw_stats = functools.partial(gen_stats,
                                          classify_op=self.ops.classify_raw,
                                          verbose=False)
            accuracies = tf.py_func(raw_stats, [], tf.float32)
            for position, tag in enumerate(('accuracy/raw/train_labeled',
                                            'accuracy/raw/valid',
                                            'accuracy/raw')):
                tf.summary.scalar(tag, accuracies[position])


In [None]:
# fully_supervised\lib\train.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import flags
from tqdm import trange

# from libml import utils
# from libml.train import ClassifySemi

# Module-level alias for the global absl flag registry.
FLAGS = flags.FLAGS


class ClassifyFullySupervised(ClassifySemi):
    """Fully supervised classification: same loop as ClassifySemi, labeled data only."""

    def train_step(self, train_session, gen_labeled):
        """Run one update on a labeled batch; stores the new step."""
        labeled = gen_labeled()
        feeds = {self.ops.xt: labeled['image'],
                 self.ops.label: labeled['label']}
        results = train_session.run([self.ops.train_op, self.ops.update_step], feed_dict=feeds)
        self.tmp.step = results[1]

    def train(self, train_nimg, report_nimg):
        """Train until the global step reaches `train_nimg` samples.

        If ``FLAGS.eval_ckpt`` is set, only evaluates that checkpoint.  Progress
        is shown in chunks of `report_nimg` samples; queued messages are
        written after every step.
        """
        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        batch = FLAGS.batch
        pipeline = self.dataset.train_labeled.repeat().shuffle(FLAGS.shuffle).parse().augment()
        train_labeled = pipeline.batch(batch).prefetch(16).make_one_shot_iterator().get_next()
        saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt, pad_step_number=10)
        scaffold = tf.train.Scaffold(saver=saver)

        # Fill the evaluation cache in a throw-away session before training.
        with tf.Session(config=get_config()) as sess:
            self.session = sess
            self.cache_eval()

        with tf.train.MonitoredTrainingSession(
                scaffold=scaffold,
                checkpoint_dir=self.checkpoint_dir,
                config=get_config(),
                save_checkpoint_steps=FLAGS.save_kimg << 10,
                save_summaries_steps=report_nimg - batch) as train_session:
            self.session = train_session._tf_sess()
            gen_labeled = self.gen_labeled_fn(train_labeled)
            self.tmp.step = self.session.run(self.step)

            def flush(emit):
                # Drain queued messages in FIFO order.
                while self.tmp.print_queue:
                    emit(self.tmp.print_queue.pop(0))

            while self.tmp.step < train_nimg:
                epoch = 1 + (self.tmp.step // report_nimg)
                loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
                              leave=False, unit='img', unit_scale=batch,
                              desc='Epoch %d/%d' % (epoch, train_nimg // report_nimg))
                for _ in loop:
                    self.train_step(train_session, gen_labeled)
                    flush(loop.write)
            flush(print)


In [None]:
# cta\lib\train.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from absl import flags

# from fully_supervised.lib.train import ClassifyFullySupervised
# from libml import data
# from libml.augment import AugmentPoolCTA
# from libml.ctaugment import CTAugment
# from libml.train import ClassifySemi

FLAGS = flags.FLAGS

# Drop any earlier definition of the flags declared below so this cell can be
# executed more than once without a duplicate-flag error.
for name in list(flags.FLAGS):
    if name in ('adepth', 'adecay', 'ath'):
        delattr(flags.FLAGS, name)


flags.DEFINE_integer('adepth', 2, 'Augmentation depth.')
flags.DEFINE_float('adecay', 0.99, 'Augmentation decay.')
flags.DEFINE_float('ath', 0.80, 'Augmentation threshold.')


class CTAClassifySemi(ClassifySemi):
    """Semi-supervised classifier that feeds its batches through an augmentation pool.

    An ``AUGMENTER_CLASS`` instance is created from the ``adepth``/``ath``/``adecay``
    flags and attached to every fetched batch; after each training step its
    rates are updated from the classifier output on the probe images.
    """
    AUGMENTER_CLASS = CTAugment
    AUGMENT_POOL_CLASS = AugmentPoolCTA

    @classmethod
    def cta_name(cls):
        """Return a string naming the augmenter class and its depth/threshold/decay flags."""
        augmenter_name = cls.AUGMENTER_CLASS.__name__
        return '%s_depth%d_th%.2f_decay%.3f' % (augmenter_name, FLAGS.adepth, FLAGS.ath, FLAGS.adecay)

    def __init__(self, train_dir: str, dataset: DataSets, nclass: int, **kwargs):
        """Initialise the base classifier, then build the augmenter from the flags."""
        ClassifySemi.__init__(self, train_dir, dataset, nclass, **kwargs)
        self.augmenter = self.AUGMENTER_CLASS(FLAGS.adepth, FLAGS.ath, FLAGS.adecay)

    def gen_labeled_fn(self, data_iterator):
        """Wrap the labeled iterator in the augment pool, requesting probe samples."""
        def wrap():
            fetched = self.session.run(data_iterator)
            fetched['cta'] = self.augmenter
            fetched['probe'] = True
            return fetched

        pool = self.AUGMENT_POOL_CLASS(wrap)
        return pool

    def gen_unlabeled_fn(self, data_iterator):
        """Wrap the unlabeled iterator in the augment pool, without probe samples."""
        def wrap():
            fetched = self.session.run(data_iterator)
            fetched['cta'] = self.augmenter
            fetched['probe'] = False
            return fetched

        pool = self.AUGMENT_POOL_CLASS(wrap)
        return pool

    def train_step(self, train_session, gen_labeled, gen_unlabeled):
        """Run one training step and update the augmenter rates from the probe predictions."""
        labeled, unlabeled = gen_labeled(), gen_unlabeled()
        fetches = [self.ops.classify_op, self.ops.train_op, self.ops.update_step]
        feeds = {self.ops.y: unlabeled['image'],
                 self.ops.x: labeled['probe'],
                 self.ops.xt: labeled['image'],
                 self.ops.label: labeled['label']}
        results = train_session.run(fetches, feed_dict=feeds)
        self.tmp.step = results[-1]
        probe_probs = results[0]
        for idx in range(probe_probs.shape[0]):
            # Subtract 1 at the true label: the row becomes (prediction - one_hot).
            residual = probe_probs[idx]
            residual[labeled['label'][idx]] -= 1
            l1_error = np.abs(residual).sum()
            self.augmenter.update_rates(labeled['policy'][idx], 1 - 0.5 * l1_error)

    def eval_stats(self, batch=None, feed_extra=None, classify_op=None, verbose=True):
        """Evaluate model on train, valid and test.

        Returns the three accuracies (in percent) as a float32 array. The
        augmenter statistics are printed regardless of ``verbose``.
        """
        if not batch:
            batch = FLAGS.batch
        if classify_op is None:
            classify_op = self.ops.classify_op
        accuracies = []
        for subset in ('train_labeled', 'valid', 'test'):
            images, labels = self.tmp.cache[subset]
            chunks = []
            for start in range(0, images.shape[0], batch):
                feeds = {self.ops.x: images[start:start + batch], **(feed_extra or {})}
                chunks.append(self.session.run(classify_op, feed_dict=feeds))
            predicted = np.concatenate(chunks, axis=0)
            hits = predicted.argmax(1) == labels
            accuracies.append(hits.mean() * 100)
        if verbose:
            self.train_print('kimg %-5d  accuracy train/valid/test  %.2f  %.2f  %.2f' %
                             tuple([self.tmp.step >> 10] + accuracies))
        self.train_print(self.augmenter.stats())
        return np.array(accuracies, 'f')


class CTAClassifyFullySupervised(ClassifyFullySupervised, CTAClassifySemi):
    """Fully-supervised classification."""

    def train_step(self, train_session, gen_labeled):
        """Run one labeled-only training step and update the augmenter rates."""
        labeled = gen_labeled()
        fetches = [self.ops.classify_op, self.ops.train_op, self.ops.update_step]
        feeds = {self.ops.x: labeled['probe'],
                 self.ops.xt: labeled['image'],
                 self.ops.label: labeled['label']}
        results = train_session.run(fetches, feed_dict=feeds)
        self.tmp.step = results[-1]
        probe_probs = results[0]
        for idx in range(probe_probs.shape[0]):
            # Subtract 1 at the true label: the row becomes (prediction - one_hot).
            residual = probe_probs[idx]
            residual[labeled['label'][idx]] -= 1
            l1_error = np.abs(residual).sum()
            self.augmenter.update_rates(labeled['policy'][idx], 1 - 0.5 * l1_error)


In [None]:
# libml\layers.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom neural network layers and primitives.
"""
import numbers

import numpy as np
import tensorflow as tf

# from libml.data import DataSets


def smart_shape(x):
    """Return the shape of ``x`` as a list, using the static dimension where its
    value is known and the matching entry of ``tf.shape(x)`` where it is None."""
    static_shape, dynamic_shape = x.shape, tf.shape(x)
    dims = []
    for axis in range(len(static_shape)):
        if static_shape[axis].value is None:
            dims.append(dynamic_shape[axis])
        else:
            dims.append(static_shape[axis])
    return dims


def entropy_from_logits(logits):
    """Computes entropy from classifier logits.

    Args:
        logits: a tensor of shape (batch_size, class_count) holding the
        logits of a classifier.

    Returns:
        A tensor of shape (batch_size,) of floats: the entropy of the
        categorical distribution defined by each row of logits.
    """
    return tf.contrib.distributions.Categorical(logits=logits).entropy()


def entropy_penalty(logits, entropy_penalty_multiplier, mask):
    """Computes an entropy penalty using the classifier logits.

    Args:
        logits: a tensor of shape (batch_size, class_count) holding the
            logits of a classifier.
        entropy_penalty_multiplier: A float by which the entropy is multiplied.
        mask: A tensor, cast to float32, that weights (e.g. zeroes out) the
            per-example costs.

    Returns:
        The mean over all examples of the scaled, masked entropy.
    """
    scaled_entropy = entropy_from_logits(logits) * entropy_penalty_multiplier
    masked = scaled_entropy * tf.cast(mask, tf.float32)
    return tf.reduce_mean(masked)


def kl_divergence_from_logits(logits_a, logits_b):
    """Gets KL divergence from logits parameterizing categorical distributions.

    Args:
        logits_a: A tensor of logits parameterizing the first distribution.
        logits_b: A tensor of logits parameterizing the second distribution.

    Returns:
        The (batch_size,) shaped tensor of KL divergences.
    """
    dists = tf.contrib.distributions
    first = dists.Categorical(logits=logits_a)
    second = dists.Categorical(logits=logits_b)
    return dists.kl_divergence(first, second)


def mse_from_logits(output_logits, target_logits):
    """Computes MSE between predictions associated with logits.

    Args:
        output_logits: A tensor of logits from the primary model.
        target_logits: A tensor of logits from the secondary model.

    Returns:
        The squared difference of the two softmax outputs, averaged over the
        last axis only (one value per example, not a single scalar).
    """
    probs_out = tf.nn.softmax(output_logits)
    probs_target = tf.nn.softmax(target_logits)
    return tf.reduce_mean(tf.square(probs_out - probs_target), -1)


def interleave_offsets(batch, nu):
    """Split ``batch`` items into ``nu + 1`` contiguous groups and return the
    ``nu + 2`` boundary offsets (starting at 0 and ending at ``batch``).

    Groups are as equal as possible; when ``batch`` is not divisible by
    ``nu + 1`` the remainder is given, one item each, to the last groups.
    """
    parts = nu + 1
    base, extra = divmod(batch, parts)
    offsets = [0]
    for idx in range(parts):
        # The final `extra` groups each take one additional item.
        size = base + 1 if idx >= parts - extra else base
        offsets.append(offsets[-1] + size)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    """Interleave examples between groups.

    Two call forms are supported:

    * ``xy`` is a list/sequence of ``nu + 1`` tensors (original behaviour):
      every tensor is cut at the offsets from ``interleave_offsets(batch, nu)``
      and, for ``i >= 1``, slice ``i`` of tensor 0 is swapped with slice ``i``
      of tensor ``i``. A list of re-concatenated tensors is returned.
    * ``xy`` is a single ``tf.Tensor``: its first axis is viewed as
      ``(-1, batch)``, those two axes are swapped and the result is flattened
      back, returning one tensor. This is the form ``FixMatch.model`` uses
      (``interleave(tf.concat([...], 0), 2 * uratio + 1)``); without this
      branch that call reaches the list code below and fails on ``len(xy)``.
    """
    if isinstance(xy, tf.Tensor):
        shape = xy.get_shape().as_list()
        grouped = tf.reshape(xy, [-1, batch] + shape[1:])
        swapped = tf.transpose(grouped, [1, 0] + list(range(2, 1 + len(shape))))
        return tf.reshape(swapped, [-1] + shape[1:])

    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [tf.concat(v, axis=0) for v in xy]


def renorm(v):
    """Divide ``v`` by its sum over the last axis (kept as a size-1 axis for broadcasting)."""
    total = tf.reduce_sum(v, axis=-1, keepdims=True)
    return v / total


def shakeshake(a, b, training):
    """Combine two branch outputs with shake-shake mixing.

    Args:
        a: output of the first branch.
        b: output of the second branch (same shape as ``a``).
        training: Python bool; when false the branches are simply averaged.

    Returns:
        ``0.5 * (a + b)`` at evaluation time. During training, a tensor whose
        forward value is ``a + mu * (b - a)`` with one uniform ``mu`` per
        example, and whose gradient is that of ``a + mu[::-1] * (b - a)``.
    """
    if not training:
        return 0.5 * (a + b)
    # One coefficient per example, broadcast over all remaining axes.
    mu = tf.random_uniform([tf.shape(a)[0]] + [1] * (len(a.shape) - 1), 0, 1)
    mixf = a + mu * (b - a)
    # The backward pass must use different coefficients than the forward pass.
    # The previous `mu[::1]` was an identity slice, which made mixb == mixf and
    # turned the stop_gradient below into a no-op; reversing along the batch
    # axis reuses the same samples in a different order.
    mixb = a + mu[::-1] * (b - a)
    # Forward value: mixf. Gradient: flows only through mixb.
    return tf.stop_gradient(mixf - mixb) + mixb


class PMovingAverage:
    """Rolling buffer of ``buf_size`` class-distribution rows stored in a
    non-trainable variable, initialised to the uniform distribution."""

    def __init__(self, name, nclass, buf_size):
        uniform_rows = tf.ones([buf_size, nclass]) / nclass
        # MEAN aggregation is used by DistributionStrategy to aggregate
        # variable updates across shards
        self.ma = tf.Variable(uniform_rows,
                              trainable=False,
                              name=name,
                              aggregation=tf.VariableAggregation.MEAN)

    def __call__(self):
        """Return the mean of the buffered rows, rescaled to sum to 1."""
        column_means = tf.reduce_mean(self.ma, axis=0)
        return column_means / tf.reduce_sum(column_means)

    def update(self, entry):
        """Return an op that drops the oldest row and appends the mean of ``entry`` over axis 0."""
        newest = tf.reduce_mean(entry, axis=0)
        shifted = tf.concat([self.ma[1:], [newest]], axis=0)
        return tf.assign(self.ma, shifted)


class PData:
    """Class distribution of the data.

    Uses ``dataset.p_unlabeled`` if it is set, otherwise ``dataset.p_labeled``,
    as a constant. If both are None, a non-trainable variable initialised to
    the uniform distribution is used and ``has_update`` is True.
    """

    def __init__(self, dataset: DataSets):
        prior = dataset.p_unlabeled
        if prior is None:
            prior = dataset.p_labeled
        self.has_update = prior is None
        if self.has_update:
            # MEAN aggregation is used by DistributionStrategy to aggregate
            # variable updates across shards
            self.p_data = tf.Variable(renorm(tf.ones([dataset.nclass])),
                                      trainable=False,
                                      name='p_data',
                                      aggregation=tf.VariableAggregation.MEAN)
        else:
            self.p_data = tf.constant(prior, name='p_data')

    def __call__(self):
        """Return ``p_data`` rescaled to sum to 1."""
        total = tf.reduce_sum(self.p_data)
        return self.p_data / total

    def update(self, entry, decay=0.999):
        """Return an op blending ``p_data`` with the mean of ``entry`` over axis 0 (weight ``1 - decay``)."""
        batch_mean = tf.reduce_mean(entry, axis=0)
        blended = self.p_data * decay + batch_mean * (1 - decay)
        return tf.assign(self.p_data, blended)


class MixMode:
    """Callable that mixes labeled and unlabeled batches according to a mode string.

    The mode must be one of ``MODES``; it selects which inputs are mixed with
    which (see the comments below for the notation).
    """
    # A class for mixing data for various combination of labeled and unlabeled.
    # x = labeled example
    # y = unlabeled example
    # For example "xx.yxy" means: mix x with x, mix y with both x and y.
    MODES = 'xx.yy xxy.yxy xx.yxy xx.yx xx. .yy xxy. .yxy .'.split()

    def __init__(self, mode):
        """Store ``mode`` after asserting it is one of ``MODES``."""
        assert mode in self.MODES
        self.mode = mode

    @staticmethod
    def augment_pair(x0, l0, x1, l1, beta, **kwargs):
        """Mix each example of ``(x0, l0)`` with a randomly chosen example of ``(x1, l1)``.

        Args:
            x0, l0: inputs and labels that receive the larger mixing weight.
            x1, l1: inputs and labels to mix in; gathered with a random
                permutation of ``range(tf.shape(x0)[0])``.
            beta: parameter of the symmetric Beta(beta, beta) distribution the
                mixing weights are drawn from. An integer ``beta <= 0``
                disables mixing.
            **kwargs: ignored.

        Returns:
            ``(xmix, lmix)``, or ``(x0, l0)`` unchanged when mixing is disabled.
        """
        del kwargs
        if isinstance(beta, numbers.Integral) and beta <= 0:
            return x0, l0

        def np_beta(s, beta):  # TF implementation seems unreliable for beta below 0.2
            return np.random.beta(beta, beta, s).astype('f')

        with tf.device('/cpu'):
            # One mixing weight per example of x0, sampled with numpy.
            mix = tf.py_func(np_beta, [tf.shape(x0)[0], beta], tf.float32)
            # max(mix, 1 - mix) keeps the weight of x0 at 0.5 or above.
            # The [n, 1, 1, 1] shape assumes 4-D inputs (e.g. NHWC images) - TODO confirm.
            mix = tf.reshape(tf.maximum(mix, 1 - mix), [tf.shape(x0)[0], 1, 1, 1])
            index = tf.random_shuffle(tf.range(tf.shape(x0)[0]))
        xs = tf.gather(x1, index)
        ls = tf.gather(l1, index)
        xmix = x0 * mix + xs * (1 - mix)
        # mix[:, :, 0, 0] has shape [n, 1] so it broadcasts over the label axis.
        lmix = l0 * mix[:, :, 0, 0] + ls * (1 - mix[:, :, 0, 0])
        return xmix, lmix

    @staticmethod
    def augment(x, l, beta, **kwargs):
        """Mix a batch with a shuffled copy of itself (``augment_pair`` with both pairs equal)."""
        return MixMode.augment_pair(x, l, x, l, beta, **kwargs)

    def __call__(self, xl: list, ll: list, betal: list):
        """Apply the configured mixing mode.

        Args:
            xl: list of at least two input tensors; ``xl[0]`` is the labeled
                batch (x) and ``xl[1:]`` are the unlabeled batches (y).
            ll: list of label tensors, same length as ``xl``.
            betal: two beta values; ``betal[0]`` is used for x mixing and
                ``betal[1]`` for y mixing. Modes that mix x and y jointly use
                their average, except 'xx.yxy' which uses ``betal[1]`` for
                the joint part.

        Returns:
            ``(inputs, labels)``: two lists with the same lengths as ``xl`` and ``ll``.

        Raises:
            NotImplementedError: if ``self.mode`` matches none of the branches.
        """
        assert len(xl) == len(ll) >= 2
        assert len(betal) == 2
        if self.mode == '.':
            # No mixing at all.
            return xl, ll
        elif self.mode == 'xx.':
            # Mix labeled with labeled only; unlabeled batches pass through.
            mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
            return [mx0] + xl[1:], [ml0] + ll[1:]
        elif self.mode == '.yy':
            # Mix the concatenated unlabeled batches among themselves.
            mx1, ml1 = self.augment(
                tf.concat(xl[1:], 0), tf.concat(ll[1:], 0), betal[1])
            return (xl[:1] + tf.split(mx1, len(xl) - 1),
                    ll[:1] + tf.split(ml1, len(ll) - 1))
        elif self.mode == 'xx.yy':
            # Labeled with labeled, unlabeled with unlabeled, independently.
            mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
            mx1, ml1 = self.augment(
                tf.concat(xl[1:], 0), tf.concat(ll[1:], 0), betal[1])
            return ([mx0] + tf.split(mx1, len(xl) - 1),
                    [ml0] + tf.split(ml1, len(ll) - 1))
        elif self.mode == 'xxy.':
            # Mix everything jointly but keep only the labeled part of the result.
            mx, ml = self.augment(
                tf.concat(xl, 0), tf.concat(ll, 0),
                sum(betal) / len(betal))
            return (tf.split(mx, len(xl))[:1] + xl[1:],
                    tf.split(ml, len(ll))[:1] + ll[1:])
        elif self.mode == '.yxy':
            # Mix everything jointly but keep only the unlabeled parts of the result.
            mx, ml = self.augment(
                tf.concat(xl, 0), tf.concat(ll, 0),
                sum(betal) / len(betal))
            return (xl[:1] + tf.split(mx, len(xl))[1:],
                    ll[:1] + tf.split(ml, len(ll))[1:])
        elif self.mode == 'xxy.yxy':
            # Mix everything jointly and keep all parts.
            mx, ml = self.augment(
                tf.concat(xl, 0), tf.concat(ll, 0),
                sum(betal) / len(betal))
            return tf.split(mx, len(xl)), tf.split(ml, len(ll))
        elif self.mode == 'xx.yxy':
            # Labeled mixed with labeled; unlabeled parts taken from a joint mix.
            mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
            mx1, ml1 = self.augment(tf.concat(xl, 0), tf.concat(ll, 0), betal[1])
            mx1, ml1 = [tf.split(m, len(xl))[1:] for m in (mx1, ml1)]
            return [mx0] + mx1, [ml0] + ml1
        elif self.mode == 'xx.yx':
            # Labeled mixed with labeled; each unlabeled batch mixed with the labeled batch.
            mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
            mx1, ml1 = zip(*[
                self.augment_pair(xl[i], ll[i], xl[0], ll[0], betal[1])
                for i in range(1, len(xl))
            ])
            return [mx0] + list(mx1), [ml0] + list(ml1)
        raise NotImplementedError(self.mode)


In [None]:
# libml\models.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifier architectures."""
import functools
import itertools

import tensorflow as tf
from absl import flags

# from libml import layers
# from libml.train import ClassifySemi
# from libml.utils import EasyDict


class CNN13(ClassifySemi):
    """Simplified reproduction of the Mean Teacher paper network. filters=128 in original implementation.
    Removed dropout, Gaussians, forked dense layers, basically all non-standard things."""

    def classifier(self, x, scales, filters, training, getter=None, **kwargs):
        """Build the network on ``x`` and return ``EasyDict(logits, embeds)``.

        Layers are created in a fixed order inside the reusable 'classify'
        variable scope: two stages of three conv+batch-norm pairs followed by
        2x2 max pooling, then three further conv+batch-norm pairs, spatial
        mean pooling and a dense layer with ``self.nclass`` outputs.
        """
        del kwargs
        assert scales == 3  # Only specified for 32x32 inputs.
        conv_args = dict(kernel_size=3, activation=tf.nn.leaky_relu, padding='same')
        bn_args = dict(training=training, momentum=0.999)
        # (filters, kernel size, padding) for the convolutions after the second pooling.
        tail_convs = ((4 * filters, 3, 'valid'),
                      (2 * filters, 1, 'same'),
                      (1 * filters, 1, 'same'))

        with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter):
            y = (x - self.dataset.mean) / self.dataset.std
            for width in (filters, 2 * filters):
                for _ in range(3):
                    y = tf.layers.conv2d(y, width, **conv_args)
                    y = tf.layers.batch_normalization(y, **bn_args)
                y = tf.layers.max_pooling2d(y, 2, 2)
            for width, kernel, padding in tail_convs:
                y = tf.layers.conv2d(y, width, kernel_size=kernel, activation=tf.nn.leaky_relu, padding=padding)
                y = tf.layers.batch_normalization(y, **bn_args)
            y = tf.reduce_mean(y, [1, 2])  # (b, 6, 6, 128) -> (b, 128)
            logits = tf.layers.dense(y, self.nclass)
        return EasyDict(logits=logits, embeds=y)


class ResNet(ClassifySemi):
    """Residual classifier built from pre-activation blocks (batch norm + leaky ReLU before each conv)."""

    def classifier(self, x, scales, filters, repeat, training, getter=None, dropout=0, **kwargs):
        """Build the network on ``x`` and return ``EasyDict(logits, embeds)``.

        There are ``scales`` stages of ``repeat`` residual blocks each; stage
        ``s`` uses ``filters << s`` channels and every stage after the first
        starts with a stride-2 block. ``embeds`` is the spatially averaged
        feature map taken before the optional dropout.
        """
        del kwargs
        leaky_relu = functools.partial(tf.nn.leaky_relu, alpha=0.1)
        bn_args = dict(training=training, momentum=0.999)

        def conv_args(k, f):
            # Normal initializer with stddev 1 / sqrt(0.5 * k * k * f).
            initializer = tf.random_normal_initializer(stddev=tf.rsqrt(0.5 * k * k * f))
            return dict(padding='same', kernel_initializer=initializer)

        def residual(x0, filters, stride=1, activate_before_residual=False):
            h = leaky_relu(tf.layers.batch_normalization(x0, **bn_args))
            if activate_before_residual:
                # The shortcut then starts from the activated tensor as well.
                x0 = h

            h = tf.layers.conv2d(h, filters, 3, strides=stride, **conv_args(3, filters))
            h = leaky_relu(tf.layers.batch_normalization(h, **bn_args))
            h = tf.layers.conv2d(h, filters, 3, **conv_args(3, filters))

            # Project the shortcut with a 1x1 conv when the channel count changes.
            if x0.get_shape()[3] != filters:
                x0 = tf.layers.conv2d(x0, filters, 1, strides=stride, **conv_args(1, filters))

            return x0 + h

        with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter):
            y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, 16, 3, **conv_args(3, 16))
            for scale in range(scales):
                first_stride = 2 if scale else 1
                y = residual(y, filters << scale, stride=first_stride, activate_before_residual=scale == 0)
                for _ in range(repeat - 1):
                    y = residual(y, filters << scale)

            y = leaky_relu(tf.layers.batch_normalization(y, **bn_args))
            embeds = tf.reduce_mean(y, [1, 2])
            y = embeds
            if dropout and training:
                y = tf.nn.dropout(y, 1 - dropout)
            logits = tf.layers.dense(y, self.nclass, kernel_initializer=tf.glorot_normal_initializer())
        return EasyDict(logits=logits, embeds=embeds)


class ShakeNet(ClassifySemi):
    """Residual classifier whose blocks combine two parallel branches with ``shakeshake``."""

    def classifier(self, x, scales, filters, repeat, training, getter=None, dropout=0, **kwargs):
        """Build the network on ``x`` and return ``EasyDict(logits, embeds)``.

        There are ``scales`` stages of ``repeat`` blocks each, every block in
        its own 'layer<stage>.<index>' variable scope; stage ``s`` uses
        ``filters << s`` channels and every stage after the first starts with
        a stride-2 block. ``embeds`` is the spatially averaged feature map
        taken before the optional dropout.
        """
        del kwargs
        bn_args = dict(training=training, momentum=0.999)

        def conv_args(k, f):
            return dict(padding='same', use_bias=False,
                        kernel_initializer=tf.random_normal_initializer(stddev=tf.rsqrt(0.5 * k * k * f)))

        def residual(x0, filters, stride=1):
            def branch():
                x = tf.nn.relu(x0)
                x = tf.layers.conv2d(x, filters, 3, strides=stride, **conv_args(3, filters))
                x = tf.nn.relu(tf.layers.batch_normalization(x, **bn_args))
                x = tf.layers.conv2d(x, filters, 3, **conv_args(3, filters))
                x = tf.layers.batch_normalization(x, **bn_args)
                return x

            # `shakeshake` is the function defined in this file; the previous
            # `layers.shakeshake` relied on the `libml.layers` module import,
            # which is commented out here.
            x = shakeshake(branch(), branch(), training)

            if stride == 2:
                # Downsampling shortcut: two 1x1 convs on spatially offset
                # sub-samplings of x0, each giving half the channels.
                x1 = tf.layers.conv2d(tf.nn.relu(x0[:, ::2, ::2]), filters >> 1, 1, **conv_args(1, filters >> 1))
                x2 = tf.layers.conv2d(tf.nn.relu(x0[:, 1::2, 1::2]), filters >> 1, 1, **conv_args(1, filters >> 1))
                x0 = tf.concat([x1, x2], axis=3)
                x0 = tf.layers.batch_normalization(x0, **bn_args)
            elif x0.get_shape()[3] != filters:
                # Channel count changes without striding: project with a 1x1 conv.
                x0 = tf.layers.conv2d(x0, filters, 1, **conv_args(1, filters))
                x0 = tf.layers.batch_normalization(x0, **bn_args)

            return x0 + x

        with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter):
            y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, 16, 3, **conv_args(3, 16))
            for scale, i in itertools.product(range(scales), range(repeat)):
                with tf.variable_scope('layer%d.%d' % (scale + 1, i)):
                    if i == 0:
                        y = residual(y, filters << scale, stride=2 if scale else 1)
                    else:
                        y = residual(y, filters << scale)

            y = embeds = tf.reduce_mean(y, [1, 2])
            if dropout and training:
                y = tf.nn.dropout(y, 1 - dropout)
            logits = tf.layers.dense(y, self.nclass, kernel_initializer=tf.glorot_normal_initializer())
        return EasyDict(logits=logits, embeds=embeds)


class MultiModel(CNN13, ResNet, ShakeNet):
    """Classifier that dispatches to CNN13, ResNet or ShakeNet based on ``arch``."""
    MODELS = ('cnn13', 'resnet', 'shake')
    MODEL_CNN13, MODEL_RESNET, MODEL_SHAKE = MODELS

    def augment(self, x, l, smoothing, **kwargs):
        """Return ``x`` unchanged and ``l`` moved towards ``1 / nclass`` by the factor ``smoothing``."""
        del kwargs
        smoothed = l - smoothing * (l - 1. / self.nclass)
        return x, smoothed

    def classifier(self, x, arch, **kwargs):
        """Build the classifier named by ``arch``; raise ValueError for an unknown name."""
        candidates = ((self.MODEL_CNN13, CNN13),
                      (self.MODEL_RESNET, ResNet),
                      (self.MODEL_SHAKE, ShakeNet))
        for model_name, model_cls in candidates:
            if arch == model_name:
                return model_cls.classifier(self, x, **kwargs)
        raise ValueError('Model %s does not exists, available ones are %s' % (arch, self.MODELS))


# Remove an already registered 'arch' flag before defining it again
# (presumably so this notebook cell can be executed more than once).
for name in list(flags.FLAGS):
    if name != 'arch':
        continue
    delattr(flags.FLAGS, name)

flags.DEFINE_enum('arch', MultiModel.MODEL_RESNET, MultiModel.MODELS, 'Architecture.')


In [None]:
# cta/cta_remixmatch.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from absl import app
from absl import flags

from cta.lib.train import CTAClassifySemi
from libml import utils, data
from remixmatch_no_cta import ReMixMatch

FLAGS = flags.FLAGS


class CTAReMixMatch(ReMixMatch, CTAClassifySemi):
    """ReMixMatch combined with the augmentation handling of CTAClassifySemi.

    Adds nothing of its own; behaviour comes from the two base classes.
    """


def main(argv):
    """Build a CTAReMixMatch model from the command-line flags and train it."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.MANY_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    train_dir = os.path.join(FLAGS.train_dir, dataset.name, CTAReMixMatch.cta_name())
    hyper_params = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,

        K=FLAGS.K,
        beta=FLAGS.beta,
        w_kl=FLAGS.w_kl,
        w_match=FLAGS.w_match,
        w_rot=FLAGS.w_rot,
        redux=FLAGS.redux,
        use_dm=FLAGS.use_dm,
        use_xe=FLAGS.use_xe,
        warmup_kimg=FLAGS.warmup_kimg,

        # A `scales` flag of 0 means: derive it from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model = CTAReMixMatch(train_dir, dataset, **hyper_params)
    # The flags are in units of 1024 images.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)


# Script entry point of cta_remixmatch.py: defines the ReMixMatch
# hyper-parameter flags, overrides some defaults of flags defined elsewhere and
# hands control to `main` through absl.
# NOTE(review): when this is executed as a notebook cell, `__name__` is
# presumably '__main__', so the block would run (and the flag names 'wd',
# 'scales', 'filters' and 'repeat' are defined again inside
# getFixMatch_model) - confirm this is intended.
if __name__ == '__main__':
    utils.setup_tf()
    flags.DEFINE_float('wd', 0.02, 'Weight decay.')
    flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.')
    flags.DEFINE_float('w_kl', 0.5, 'Weight for KL loss.')
    flags.DEFINE_float('w_match', 1.5, 'Weight for distribution matching loss.')
    flags.DEFINE_float('w_rot', 0.5, 'Weight for rotation loss.')
    flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
    flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
    flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
    flags.DEFINE_integer('warmup_kimg', 1024, 'Unannealing duration for SSL loss.')
    flags.DEFINE_enum('redux', '1st', 'swap mean 1st'.split(), 'Logit selection.')
    flags.DEFINE_bool('use_dm', True, 'Whether to use distribution matching.')
    flags.DEFINE_bool('use_xe', True, 'Whether to use cross-entropy or Brier.')
    # Defaults for flags that are defined outside this block.
    FLAGS.set_default('augment', 'd.d.d')
    FLAGS.set_default('dataset', 'cifar10.3@250-5000')
    FLAGS.set_default('batch', 64)
    FLAGS.set_default('lr', 0.002)
    FLAGS.set_default('train_kimg', 1 << 16)
    app.run(main)


In [None]:
# fixmatch.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os

import numpy as np
import tensorflow as tf
from absl import app
from absl import flags
from tqdm import trange

# from cta.cta_remixmatch import CTAReMixMatch
# from libml import data, utils, augment, ctaugment

FLAGS = flags.FLAGS


class AugmentPoolCTACutOut(augment.AugmentPoolCTA):
    """Augment pool that appends a 'cutout' op to every non-probe policy.

    NOTE(review): `augment` and `ctaugment` are referenced as modules although
    the `from libml import ... augment, ctaugment` line in this cell is
    commented out - confirm those names are bound when this cell runs.
    """

    @staticmethod
    def numpy_apply_policies(arglist):
        """Apply augmentation policies to one sample.

        Args:
            arglist: ``(x, cta, probe)`` where ``cta`` provides ``policy(probe=...)``.
                A 3-D ``x`` is a single image and requires ``probe`` to be
                true; otherwise ``x`` is a stack of images and ``probe`` must
                be false.

        Returns:
            For a probe sample: a dict with the sampled ``policy``, the
            augmented ``probe`` image and the untouched ``image``.
            Otherwise: a dict whose ``image`` stacks ``x[0]`` unchanged with
            every further entry of ``x`` augmented by a freshly sampled policy
            plus cutout, cast to float32.
        """
        x, cta, probe = arglist
        if x.ndim == 3:
            assert probe
            # The policy comes from the augmenter. The previous
            # `policy = policy(probe=True)` read the local name before it was
            # assigned and raised UnboundLocalError.
            policy = cta.policy(probe=True)
            return dict(policy=policy,
                        probe=ctaugment.apply(x, policy),
                        image=x)
        assert not probe

        def cutout_policy():
            # A new policy is sampled on every call, i.e. once per image.
            return cta.policy(probe=False) + [ctaugment.OP('cutout', (1,))]

        return dict(image=np.stack([x[0]] + [ctaugment.apply(y, cutout_policy()) for y in x[1:]]).astype('f'))


class FixMatch(CTAReMixMatch):
    """FixMatch: labeled cross-entropy plus a confidence-masked pseudo-label
    loss between weakly and strongly augmented views of unlabeled images."""

    def __init__(self, *args, **kwargs):
        """Select the augment pool from the ``augment_improve`` flag, then initialise the base.

        The flag is read here rather than in the class body because
        ``getFixMatch_model`` only defines ``augment_improve`` when it is
        called, i.e. after this class statement has already executed; reading
        it at class-creation time fails on an undefined flag.
        """
        if FLAGS.augment_improve:
            self.AUGMENT_POOL_CLASS = AugmentPoolCTACutOut
        else:
            # not to use augmentation improvement
            # NOTE(review): gen_labeled_fn/gen_unlabeled_fn call
            # self.AUGMENT_POOL_CLASS(...), which cannot work with None -
            # confirm what should be used when the flag is off.
            self.AUGMENT_POOL_CLASS = None
        super().__init__(*args, **kwargs)

    def train(self, train_nimg, report_nimg):
        """Train until ``train_nimg`` images were seen, in rounds of ``report_nimg`` images.

        If ``FLAGS.eval_ckpt`` is set, that checkpoint is evaluated instead
        and no training takes place.
        """
        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        batch = FLAGS.batch
        train_labeled = self.dataset.train_labeled.repeat().shuffle(FLAGS.shuffle).parse().augment()
        train_labeled = train_labeled.batch(batch).prefetch(16).make_one_shot_iterator().get_next()
        train_unlabeled = self.dataset.train_unlabeled.repeat().shuffle(FLAGS.shuffle).parse().augment()
        # Each step consumes `uratio` times as many unlabeled as labeled examples.
        train_unlabeled = train_unlabeled.batch(batch * self.params['uratio']).prefetch(16)
        train_unlabeled = train_unlabeled.make_one_shot_iterator().get_next()
        scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt,
                                                          pad_step_number=10))

        # A short-lived session used only while cache_eval runs.
        with tf.Session(config=get_config()) as sess:
            self.session = sess
            self.cache_eval()

        with tf.train.MonitoredTrainingSession(
                scaffold=scaffold,
                checkpoint_dir=self.checkpoint_dir,
                config=get_config(),
                save_checkpoint_steps=FLAGS.save_kimg << 10,
                save_summaries_steps=report_nimg - batch) as train_session:
            self.session = train_session._tf_sess()
            gen_labeled = self.gen_labeled_fn(train_labeled)
            gen_unlabeled = self.gen_unlabeled_fn(train_unlabeled)
            self.tmp.step = self.session.run(self.step)
            while self.tmp.step < train_nimg:
                # One progress bar per reporting round, starting at the
                # position within the round given by the current step.
                loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
                              leave=False, unit='img', unit_scale=batch,
                              desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg))
                for _ in loop:
                    self.train_step(train_session, gen_labeled, gen_unlabeled)
                    while self.tmp.print_queue:
                        loop.write(self.tmp.print_queue.pop(0))
            while self.tmp.print_queue:
                print(self.tmp.print_queue.pop(0))

    def model(self, batch, lr, wd, wu, confidence, uratio, ema=0.999, **kwargs):
        """Build the training and evaluation graph.

        Args:
            batch: number of labeled examples per step.
            lr: initial learning rate; scaled by ``cos(7 * pi / 16 * progress)``
                where progress is ``step / (FLAGS.train_kimg << 10)`` clipped to [0, 1].
            wd: weight of the L2 loss over the 'kernel' variables of the classifier.
            wu: weight of the pseudo-label loss.
            confidence: minimum maximum-class probability of the weak view for
                an unlabeled example to contribute to the pseudo-label loss.
            uratio: number of unlabeled examples per labeled example.
            ema: decay of the exponential moving average of the model variables.
            **kwargs: forwarded to ``self.classifier``.

        Returns:
            EasyDict with the placeholders ``xt``, ``x``, ``y``, ``label``,
            the ``train_op`` and two softmax outputs on ``x``:
            ``classify_raw`` (current weights) and ``classify_op`` (EMA weights).
        """
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

        lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        # Remember the update ops that exist already so only the ones created
        # by the training forward pass below are run with the train op.
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # NOTE(review): this call passes a single tensor and a group count; a
        # list-based `interleave(xy, batch)` with the same name is defined in
        # the layers cell of this file - confirm which definition is bound here.
        x = interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
        logits = para_cat(lambda x: classifier(x, training=True), x)
        logits = de_interleave(logits, 2 * uratio+1)
        post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
        logits_x = logits[:batch]
        # The unlabeled part was concatenated as (weak, strong), so the two halves split back in that order.
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=l_in, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-label cross entropy for unlabeled data
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(pseudo_labels, axis=1),
                                                                  logits=logits_strong)
        pseudo_mask = tf.to_float(tf.reduce_max(pseudo_labels, axis=1) >= confidence)
        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        # The mean is taken over all unlabeled examples, masked ones included.
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        # L2 regularization
        loss_wd = sum(tf.nn.l2_loss(v) for v in model_vars('classify') if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(model_vars())
        ema_getter = functools.partial(getter_ema, ema)
        post_ops.append(ema_op)

        train_op = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).minimize(
            loss_xe + wu * loss_xeu + wd * loss_wd, colocate_gradients_with_ops=True)
        # The collected update ops and the EMA update run after the optimizer step.
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return EasyDict(
            xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))


def main(argv):
    """Build a FixMatch model from the command-line flags and train it."""
    setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = ilog2(dataset.width)
    train_dir = os.path.join(FLAGS.train_dir, dataset.name, FixMatch.cta_name())
    hyper_params = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        # A `scales` flag of 0 means: derive it from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model = FixMatch(train_dir, dataset, **hyper_params)
    # The flags are in units of 1024 images.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)

def getFixMatch_model(INIT_LEARNING_RATE = 0.03, batch_size = 256, flag_AugmentImprove = 0):
  """Define the FixMatch flags and return an untrained FixMatch model.

  Mirrors main(), but the caller chooses the learning rate, the batch size and
  whether the augmentation improvement is enabled, and training is left to the
  caller.

  Args:
    INIT_LEARNING_RATE: value installed as the default of the 'lr' flag.
    batch_size: value installed as the default of the 'batch' flag.
    flag_AugmentImprove: truthy to set the 'augment_improve' flag to True.

  Returns:
    The FixMatch model built from the current flag values (not yet trained).
  """
  # NOTE(review): if `flags` is absl.flags, calling this function a second time
  # in the same process presumably raises on the repeated DEFINE_* calls, and
  # FLAGS has to be parsed before its values can be read - confirm.
  setup_tf()
  flags.DEFINE_bool('augment_improve', bool(flag_AugmentImprove), 'A flag indicates if to use augmentation improvement')
  flags.DEFINE_float('confidence', 0.95, 'Confidence threshold.')
  flags.DEFINE_float('wd', 0.0005, 'Weight decay.')
  flags.DEFINE_float('wu', 1, 'Pseudo label loss weight.')
  flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
  flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
  flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
  flags.DEFINE_integer('uratio', 7, 'Unlabeled batch size ratio.')
  FLAGS.set_default('augment', 'd.d.d')
  FLAGS.set_default('dataset', 'cifar10.3@250-1')
  FLAGS.set_default('batch', batch_size)
  FLAGS.set_default('lr', INIT_LEARNING_RATE)
  FLAGS.set_default('train_kimg', 1 << 16)

  # Build the dataset exactly as main() does; the model constructor below needs
  # both `dataset` and `log_width`, which were previously undefined here.
  dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
  log_width = ilog2(dataset.width)

  model = FixMatch(
      os.path.join(FLAGS.train_dir, dataset.name, FixMatch.cta_name()),
      dataset,
      lr=FLAGS.lr,
      wd=FLAGS.wd,
      arch=FLAGS.arch,
      batch=FLAGS.batch,
      nclass=dataset.nclass,
      wu=FLAGS.wu,
      confidence=FLAGS.confidence,
      uratio=FLAGS.uratio,
      scales=FLAGS.scales or (log_width - 2),
      filters=FLAGS.filters,
      repeat=FLAGS.repeat)

  # Training is intentionally not started here; main() trains with
  # model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10).
  return model

# if __name__ == '__main__':
#   setup_tf()
#   flags.DEFINE_float('confidence', 0.95, 'Confidence threshold.')
#   flags.DEFINE_float('wd', 0.0005, 'Weight decay.')
#   flags.DEFINE_float('wu', 1, 'Pseudo label loss weight.')
#   flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
#   flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
#   flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
#   flags.DEFINE_integer('uratio', 7, 'Unlabeled batch size ratio.')
#   FLAGS.set_default('augment', 'd.d.d')
#   FLAGS.set_default('dataset', 'cifar10.3@250-1')
#   FLAGS.set_default('batch', 64)
#   FLAGS.set_default('lr', 0.03)
#   FLAGS.set_default('train_kimg', 1 << 16)
#   app.run(main)


#Stage 3 - CNN #

In [None]:
import keras
from keras.datasets import cifar10
from keras.datasets import cifar100
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import to_categorical

# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# (x_train, y_train), (x_test, y_test) = cifar10.load_data()
# (x_train, y_train), (x_test, y_test) = cifar100.load_data()

# NUM_CLASSES = 100 # NUM_CLASSES
# y_train = to_categorical(y_train)
# y_test = to_categorical(y_test)

# x_train = x_train.astype('float32')
# x_test = x_test.astype('float32')
# x_train /= 255.0
# x_test /= 255.0

def getCNN_model(x_train, INIT_DROPOUT_RATE = 0.5, MOMENTUM_RATE = 0.9, INIT_LEARNING_RATE = 0.01, L2_DECAY_RATE = 0.0005, batch_size = 256):
  """Build and compile the seven-stack ELU convolutional classifier.

  Each stack is a run of 'same'-padded convolutions followed by an ELU
  activation, a 2x2 max-pooling and a dropout layer. The network ends with a
  Dense layer of NUM_CLASSES units (read from the module-level global at call
  time) and a softmax.

  Args:
    x_train: array of training images; only x_train.shape[1:] is used, as the
      input shape of the first layer.
    INIT_DROPOUT_RATE: dropout rate applied at the end of every stack.
    MOMENTUM_RATE: currently unused.
    INIT_LEARNING_RATE: currently unused; the model is compiled with the
      'adagrad' optimizer at its default settings.
    L2_DECAY_RATE: L2 kernel-regularization factor of every convolution except
      the first one.
    batch_size: currently unused.

  Returns:
    The compiled (untrained) Sequential model.
  """
  # One entry per stack: (L2 factor, [(filters, kernel_size), ...]).
  # NOTE(review): the first convolution uses a hard-coded factor of 0.01 rather
  # than L2_DECAY_RATE; kept as in the original - confirm it is intentional.
  conv_stacks = [
      (0.01, [(384, (3, 3))]),
      (L2_DECAY_RATE, [(384, (1, 1)), (384, (2, 2)), (640, (2, 2)), (640, (2, 2))]),
      (L2_DECAY_RATE, [(640, (3, 3)), (768, (2, 2)), (768, (2, 2)), (768, (2, 2))]),
      (L2_DECAY_RATE, [(768, (1, 1)), (896, (2, 2)), (896, (2, 2))]),
      (L2_DECAY_RATE, [(896, (3, 3)), (1024, (2, 2)), (1024, (2, 2))]),
      (L2_DECAY_RATE, [(1024, (1, 1)), (1152, (2, 2))]),
      (L2_DECAY_RATE, [(1152, (1, 1))]),
  ]

  model = Sequential()
  model.add(ZeroPadding2D(4, input_shape=x_train.shape[1:]))

  for decay, convolutions in conv_stacks:
    for filters, kernel_size in convolutions:
      model.add(Conv2D(filters, kernel_size, padding='same', kernel_regularizer=l2(decay)))
    model.add(Activation('elu'))
    model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
    model.add(Dropout(INIT_DROPOUT_RATE))

  model.add(Flatten())
  model.add(Dense(NUM_CLASSES))
  model.add(Activation('softmax'))

  # summary() prints the table itself and returns None, so it must not be
  # wrapped in print() (that printed a stray "None").
  model.summary()

  model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])

  return model


# Stage 4 - Evaluation
---
 Evaluating the models' accuracy using different metrics


In [None]:
#Evaluation Protocol - External 10-fold cross validation + internal 3-fold cross validation 

from sklearn.model_selection import RandomizedSearchCV
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.model_selection import KFold

# The function receives the model as a parameter;
# the model is one of 3 possible values: FixMatch, FixMatch+Improve, CNN.
# The function also receives the training data (images) and the corresponding labels.
def nestedCrossValidationAlgorithm(X, y, model):
    """Pick the best estimator with nested cross-validation.

    The outer loop is a shuffled 10-fold split. For every outer fold a
    randomized hyper-parameter search (50 candidates, 3-fold inner
    cross-validation, macro-F1 scoring) is fitted on the outer training part,
    and its refitted best estimator is scored with macro-F1 on the outer
    hold-out part. The estimator with the highest hold-out score wins.

    NOTE(review): RandomizedSearchCV presumably requires `model` to follow the
    scikit-learn estimator interface and to accept the 'mu', 'lr' and
    'batch_size' parameters - confirm for the models passed in.

    Args:
        X: array of samples, indexed along its first axis.
        y: array of labels aligned with X.
        model: the estimator to tune.

    Returns:
        The best estimator over all outer folds, or None if no fold produced a
        comparable (non-NaN) score.
    """
    # Imported locally because f1_score is not among the cell-level imports.
    from sklearn.metrics import f1_score

    # The search space and the inner splitter are identical for every outer
    # fold, so they are built once.
    space = {
        'mu': [4, 5, 6, 7, 8],
        'lr': [0.2, 0.002, 0.02, 0.03, 0.3, 0.003, 0.04, 0.4, 0.004],
        'batch_size': [64, 128, 256],
    }
    cv_outer = KFold(n_splits=10, shuffle=True, random_state=1)
    cv_inner = KFold(n_splits=3, shuffle=True, random_state=1)

    # Start below any reachable score so that a best model is always recorded,
    # even when every fold scores exactly 0.
    maxF1 = float('-inf')
    superBestModel = None

    for fold, (train_ix, test_ix) in enumerate(cv_outer.split(X), start=1):
        print('iteration ', fold, 'begin')

        # First-axis indexing works for arrays of any rank.
        X_train, X_test = X[train_ix], X[test_ix]
        y_train, y_test = y[train_ix], y[test_ix]

        search = RandomizedSearchCV(model, space, n_iter=50, scoring='f1_macro', cv=cv_inner, refit=True)
        result = search.fit(X_train, y_train)

        # refit=True: the best configuration is already refitted on the whole
        # outer training part.
        best_model = result.best_estimator_

        # Score the winner on the outer hold-out part.
        yEstimation = best_model.predict(X_test)
        f1score = f1_score(y_test, yEstimation, average='macro')

        if f1score > maxF1:
            maxF1 = f1score
            superBestModel = best_model

    return superBestModel

# train the best model and evaluate its results 
def _as_class_labels(values):
  """Return a flat vector of class labels for *values*.

  Rows of an array with more than one column (one-hot labels or per-class
  scores) are collapsed with argmax; anything else is flattened, and
  floating-point entries are rounded to the nearest integer label.
  """
  arr = np.asarray(values)
  if arr.ndim > 1 and arr.shape[-1] > 1:
    return np.argmax(arr, axis=-1)
  flat = arr.reshape(-1)
  if np.issubdtype(flat.dtype, np.floating):
    flat = np.rint(flat).astype(int)
  return flat


def _safe_ratio(numerator, denominator):
  """Divide element-wise, yielding 0 wherever the denominator is 0."""
  numerator = np.asarray(numerator, dtype=float)
  denominator = np.asarray(denominator, dtype=float)
  return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator > 0)


def trainTheModelAndEvalute(X, y, model):
  """Select, train and evaluate the best model on a 80/20 hold-out split.

  The data is split first; model selection (nested cross-validation) and the
  final fit only see the training part, and all metrics are computed on the
  held-out part.

  For labels restricted to {0, 1} the metrics are the usual binary ones for
  the positive class. For any other label set TPR, FPR and precision are
  macro-averaged one-vs-rest values, and AUC / average precision are
  macro-averaged over the classes present in the test labels, computed from
  the hard predictions.

  Args:
    X: array of samples.
    y: labels aligned with X, either class indices or one-hot rows.
    model: estimator handed to nestedCrossValidationAlgorithm.

  Returns:
    A tuple (accuracy, TPR, FPR, Precision, AUC, PR_curve).
  """
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

  # Model selection must not see the hold-out samples, otherwise the reported
  # metrics are optimistically biased.
  superBestModel = nestedCrossValidationAlgorithm(X_train, y_train, model)
  superBestModel.fit(X_train, y_train)
  y_bestEst = superBestModel.predict(X_test)  # labels or per-class scores

  # Reduce both the ground truth and the model output to class labels
  # (the prediction was previously never used).
  y_true = _as_class_labels(y_test)
  y_pred = _as_class_labels(y_bestEst)

  labels_seen = np.unique(np.concatenate([y_true, y_pred]))
  is_binary = set(labels_seen.tolist()) <= {0, 1}
  if is_binary:
    # Force a 2x2 matrix even if one of the two classes is absent.
    labels_seen = np.array([0, 1])

  cm = confusion_matrix(y_true, y_pred, labels=labels_seen)

  # One-vs-rest counts for every class (rows = true class, columns = predicted).
  tp = np.diag(cm)
  fn = cm.sum(axis=1) - tp
  fp = cm.sum(axis=0) - tp
  tn = cm.sum() - tp - fn - fp
  if is_binary:
    # Binary case: report the positive class only, as tn, fp, fn, tp would.
    tp, fn, fp, tn = tp[1:], fn[1:], fp[1:], tn[1:]

  accuracy = accuracy_score(y_true, y_pred)
  TPR = float(np.mean(_safe_ratio(tp, tp + fn)))
  FPR = float(np.mean(_safe_ratio(fp, fp + tn)))

  if is_binary:
    Precision = precision_score(y_true, y_pred)
    AUC = roc_auc_score(y_true, y_pred)
    PR_curve = average_precision_score(y_true, y_pred)
  else:
    # Indicator matrices restricted to the classes that occur in y_true, so
    # that every column has at least one positive sample.
    true_classes = np.unique(y_true)
    true_ind = (y_true[:, None] == true_classes[None, :]).astype(int)
    pred_ind = (y_pred[:, None] == true_classes[None, :]).astype(int)
    Precision = precision_score(y_true, y_pred, average='macro')
    AUC = roc_auc_score(true_ind, pred_ind, average='macro')
    PR_curve = average_precision_score(true_ind, pred_ind, average='macro')

  return accuracy, TPR, FPR, Precision, AUC, PR_curve


In [None]:
# currentModel = getFixMatch_model(INIT_LEARNING_RATE = 0.03, batch_size = 256)

# currentModel = getFixMatch_model()    # Augmentation feature

NUM_CLASSES = 5
numDatasets = 20

# Evaluate the CNN on every sub-dataset and print its metrics.
for iDataset in range(numDatasets):
  X, y = splitSubSetData(iDataset)
  currentModel = getCNN_model(X)

  # trainTheModelAndEvalute runs the nested cross-validation itself, so the
  # model is passed in directly instead of running the whole search twice.
  accuracy, TPR, FPR, Precision, AUC, PR_curve = trainTheModelAndEvalute(X, y, currentModel)

  print("iDataset: ", iDataset)
  for label, value in (
      ("accuracy: ", accuracy),
      ("TPR: ", TPR),
      ("FPR: ", FPR),
      ("Precision: ", Precision),
      ("AUC: ", AUC),
      ("PR_curve: ", PR_curve),
  ):
    print(label, value)



## Significance test 

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import f_oneway
from scipy.stats import friedmanchisquare
from statsmodels.stats.multicomp import pairwise_tukeyhsd


# compare distributions of the 3 different models
def friedman_test(Alg1, Alg2, Alg3):
  """Run the Friedman chi-square test on the three algorithms' paired scores.

  Returns the result of scipy.stats.friedmanchisquare unchanged; callers
  unpack it as (statistic, p-value).
  """
  samples = (Alg1, Alg2, Alg3)
  return friedmanchisquare(*samples)


# calculates the differences between the distributions in terms of p-value and mean difference
def post_hoc_test(Alg1, Alg2, Alg3):
  """Print Tukey's HSD pairwise comparison of the three algorithms' scores.

  The scores are stacked into one column and labelled 'Alg1', 'Alg2' or
  'Alg3' according to the sample they come from; the comparison is run at
  alpha = 0.05 and its summary table is printed.

  Args:
    Alg1, Alg2, Alg3: sequences of scores, one per algorithm. They are not
      modified and may have any (also differing) lengths.

  Returns:
    None; the result is printed.
  """
  # Copy the inputs so the caller's lists are left untouched.
  samples = [('Alg1', list(Alg1)), ('Alg2', list(Alg2)), ('Alg3', list(Alg3))]

  scores = []
  for _, sample in samples:
    scores.extend(sample)

  # One group label per score, sized from the actual sample lengths instead of
  # a hard-coded 10.
  groups = np.repeat([name for name, _ in samples], [len(sample) for _, sample in samples])

  # create DataFrame to hold data
  df = pd.DataFrame({'score': scores, 'group': groups})
  # perform Tukey's test
  tukey = pairwise_tukeyhsd(endog=df['score'], groups=df['group'], alpha=0.05)
  # display results
  print(tukey)

# The results of each algorithm were saved manually in Excel; they are loaded
# from Drive here in order to run the significance tests.
RESULT_COLUMNS = ['Dataset Name', 'Algorithm Name', 'Cross Validation [1-10]', 'HyperParamaters Values', 'ACC', 'TPR', 'FPR', 'Precision','AUC', 'PR-CURVE', 'Training Time', 'Inference Time']

# The last column of every sheet is dropped before the headers are assigned.
df1 = pd.read_excel('/content/drive/MyDrive/MLAssignment/Alg1_results.xlsx', header=None).iloc[:,:-1]
df2 = pd.read_excel('/content/drive/MyDrive/MLAssignment/Alg2_results.xlsx', header=None).iloc[:,:-1]
df3 = pd.read_excel('/content/drive/MyDrive/MLAssignment/Alg3_results.xlsx', header=None).iloc[:,:-1]

# All three sheets share the same layout.
for results in (df1, df2, df3):
  results.columns = list(RESULT_COLUMNS)

metric = 'ACC'
# metric = 'AUC'

# take the results of the metric in the dataset for each algorithm
Alg1 = df1[metric].values.tolist()
Alg2 = df2[metric].values.tolist()
Alg3 = df3[metric].values.tolist()

# apply the Friedman test
stat, p = friedman_test(Alg1, Alg2, Alg3)

# interpret - if the distributions are different then also apply the post-hoc
# test to measure the difference
alpha = 0.05
if p > alpha:
  print('Same distributions (fail to reject H0)')
else:
  print('Different distributions (reject H0)')
  # post_hoc_test prints its own table and returns None, so it is not wrapped
  # in print().
  post_hoc_test(Alg1, Alg2, Alg3)