<a href="https://colab.research.google.com/github/mimilazarova/dd2412_project_fixmatch_and_beyond/blob/main/src/FixMatch_mimi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Mount Google Drive (Colab only) so the data file under /content/drive can be read further below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
pip install tensorflow-addons==0.11.1

Collecting tensorflow-addons==0.11.1
[?25l  Downloading https://files.pythonhosted.org/packages/29/51/8e5bb7649ac136292aefef6ea0172d10cc23a26dcda093c62637585bc05e/tensorflow_addons-0.11.1-cp36-cp36m-manylinux2010_x86_64.whl (1.1MB)
[K     |▎                               | 10kB 18.3MB/s eta 0:00:01[K     |▋                               | 20kB 20.2MB/s eta 0:00:01[K     |█                               | 30kB 13.0MB/s eta 0:00:01[K     |█▏                              | 40kB 9.0MB/s eta 0:00:01[K     |█▌                              | 51kB 8.1MB/s eta 0:00:01[K     |█▉                              | 61kB 8.2MB/s eta 0:00:01[K     |██                              | 71kB 8.5MB/s eta 0:00:01[K     |██▍                             | 81kB 8.3MB/s eta 0:00:01[K     |██▊                             | 92kB 7.9MB/s eta 0:00:01[K     |███                             | 102kB 8.6MB/s eta 0:00:01[K     |███▎                            | 112kB 8.6MB/s eta 0:00:01[K     |███▋

In [5]:
# All imports here
import tensorflow_probability as tfp
import numpy as np
from PIL import Image, ImageOps, ImageEnhance, ImageFilter
import os
import tensorflow as tf
import logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
logging.getLogger('tensorflow').disabled = True
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
from tqdm import tqdm, tqdm_notebook
from matplotlib import pyplot as plt
import json

In [12]:
# this cell is for the wide ResNet
def regularized_padded_conv(*args, **kwargs):
    """Build a bias-free, 'same'-padded Conv2D with He-normal init and the global L2 regularizer.

    Positional/keyword arguments (filters, kernel_size, strides, ...) are forwarded to
    tf.keras.layers.Conv2D. `_regularizer` is the module-level regularizer set by Resnet().
    """
    fixed_options = dict(padding='same', kernel_regularizer=_regularizer,
                         kernel_initializer='he_normal', use_bias=False)
    return tf.keras.layers.Conv2D(*args, **kwargs, **fixed_options)


def BN_ReLU(x):
    """Apply a fresh BatchNormalization layer followed by a fresh ReLU layer to `x`."""
    normalized = tf.keras.layers.BatchNormalization()(x)
    activated = tf.keras.layers.ReLU()(normalized)
    return activated


def shortcut(x, filters, stride, mode):
    """Return the skip-connection branch of a residual block.

    The input is passed through unchanged only when it already has `filters`
    channels AND `stride` is 1; otherwise it is projected / downsampled so it can be
    added to the residual branch.

    Modes:
        'B': 1x1 strided convolution (regularized, see regularized_padded_conv).
        'B_original': the same convolution followed by BatchNormalization.
        'A': parameter-free; spatial subsampling with a 1x1 max-pool of stride
             `stride`, then zero-padding of the channel axis up to `filters`.

    Raises:
        KeyError: if a projection is needed and `mode` is not one of the above.
    """
    in_channels = x.shape[-1]
    # Identity is only valid when neither channels nor spatial size change: with
    # stride > 1 the residual branch is downsampled and the shapes would not match.
    if in_channels == filters and stride == 1:
        return x
    if mode == 'B':
        return regularized_padded_conv(filters, 1, strides=stride)(x)
    if mode == 'B_original':
        projected = regularized_padded_conv(filters, 1, strides=stride)(x)
        return tf.keras.layers.BatchNormalization()(projected)
    if mode == 'A':
        pooled = tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x
        return tf.pad(pooled, paddings=[(0, 0), (0, 0), (0, 0), (0, filters - in_channels)])
    raise KeyError("Parameter shortcut_type not recognized!")
    

def original_block(x, filters, stride=1, **kwargs):
    """Post-activation ("original" ResNet) block: conv-BN-ReLU-conv-BN, add shortcut, ReLU.

    Extra keyword arguments (e.g. `preact_block` passed by group_of_blocks) are ignored.
    When the global `_shortcut_type` is 'B', the projection shortcut also gets a BatchNorm.
    """
    first = regularized_padded_conv(filters, 3, strides=stride)(x)
    second = regularized_padded_conv(filters, 3)(BN_ReLU(first))
    second = tf.keras.layers.BatchNormalization()(second)

    projection_mode = _shortcut_type
    if projection_mode == 'B':
        projection_mode = 'B_original'
    skip = shortcut(x, filters, stride, mode=projection_mode)
    return tf.keras.layers.ReLU()(skip + second)
    
    
def preactivation_block(x, filters, stride=1, preact_block=False):
    """Pre-activation residual block: BN-ReLU-conv3x3(stride), [dropout], BN-ReLU-conv3x3, plus shortcut.

    If `preact_block` is True the shortcut branch starts from the BN-ReLU output
    instead of the raw input. Dropout with rate `_dropout` (a global set by Resnet)
    is inserted between the convolutions when that rate is non-zero.
    """
    activated = BN_ReLU(x)
    residual_input = activated if preact_block else x

    hidden = regularized_padded_conv(filters, 3, strides=stride)(activated)
    if _dropout:
        hidden = tf.keras.layers.Dropout(_dropout)(hidden)

    residual = regularized_padded_conv(filters, 3)(BN_ReLU(hidden))
    return shortcut(residual_input, filters, stride, mode=_shortcut_type) + residual


def bootleneck_block(x, filters, stride=1, preact_block=False):
    """Pre-activation bottleneck block: 1x1 reduce, 3x3 (strided), 1x1 expand, plus shortcut.

    The inner width is filters // _bootleneck_width (global set by Resnet). If
    `preact_block` is True the shortcut starts from the BN-ReLU output.
    """
    activated = BN_ReLU(x)
    residual_input = activated if preact_block else x
    inner_width = filters // _bootleneck_width

    reduced = regularized_padded_conv(inner_width, 1)(activated)
    spatial = regularized_padded_conv(inner_width, 3, strides=stride)(BN_ReLU(reduced))
    expanded = regularized_padded_conv(filters, 1)(BN_ReLU(spatial))
    return shortcut(residual_input, filters, stride, mode=_shortcut_type) + expanded


def group_of_blocks(x, block_type, num_blocks, filters, stride, block_idx=0):
    """Stack `num_blocks` residual blocks built by `block_type`; only the first uses `stride`.

    The first block pre-activates its shortcut when the global `_preact_shortcuts`
    is truthy or when this is the first group (block_idx == 0).
    """
    first_is_preact = bool(_preact_shortcuts or block_idx == 0)
    x = block_type(x, filters, stride, preact_block=first_is_preact)
    for _ in range(1, num_blocks):
        x = block_type(x, filters)
    return x


def Resnet(input_shape, n_classes, l2_reg=1e-4, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
           shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
           dropout=0, cardinality=1, bootleneck_width=4, preact_shortcuts=True):
    """Build a functional Keras ResNet that outputs class logits.

    The block builders read their configuration from module-level globals, which are
    set here: _regularizer (L2 with `l2_reg`), _shortcut_type, _dropout, _cardinality,
    _bootleneck_width and _preact_shortcuts.

    Args:
        input_shape: shape of one input image.
        n_classes: number of output logits.
        group_sizes, features, strides: per-group block count, channel count and stride.
        block_type: 'preactivated', 'bootleneck' or 'original' (KeyError otherwise).
        first_conv: keyword arguments of the stem convolution.
    """
    global _regularizer, _shortcut_type, _dropout, _cardinality, _bootleneck_width, _preact_shortcuts
    _bootleneck_width = bootleneck_width
    _regularizer = tf.keras.regularizers.l2(l2_reg)
    _shortcut_type = shortcut_type
    _cardinality = cardinality
    _dropout = dropout
    _preact_shortcuts = preact_shortcuts

    builder = {'preactivated': preactivation_block,
               'bootleneck': bootleneck_block,
               'original': original_block}[block_type]
    post_activation = block_type == 'original'

    inputs = tf.keras.layers.Input(shape=input_shape)
    flow = regularized_padded_conv(**first_conv)(inputs)
    if post_activation:
        flow = BN_ReLU(flow)

    for idx, (n_blocks, width, group_stride) in enumerate(zip(group_sizes, features, strides)):
        flow = group_of_blocks(flow, block_type=builder, num_blocks=n_blocks,
                               block_idx=idx, filters=width, stride=group_stride)

    # Pre-activation networks need a final BN-ReLU before pooling.
    if not post_activation:
        flow = BN_ReLU(flow)

    pooled = tf.keras.layers.GlobalAveragePooling2D()(flow)
    logits = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer)(pooled)
    return tf.keras.Model(inputs=inputs, outputs=logits)


def load_weights_func(model, model_name):
    """Load weights from saved_models/<model_name>.tf into `model` and return it.

    If TensorFlow reports the file as missing (NotFoundError), a message is printed
    and the model keeps its current weights.
    """
    weights_path = os.path.join('saved_models', model_name + '.tf')
    try:
        model.load_weights(weights_path)
    except tf.errors.NotFoundError:
        print("No weights found for this model!")
    return model


def cifar_wide_resnet(N, K, block_type='preactivated', shortcut_type='B', dropout=0, l2_reg=2.5e-4,
                      n_classes=10, input_shape=(32, 32, 3)):
    """Build a Wide ResNet WRN-N-K.

    Args:
        N: total depth; (N - 4) must be divisible by 6.
        K: widening factor; the three groups have 16K, 32K and 64K channels.
        block_type, shortcut_type, dropout, l2_reg: forwarded to Resnet.
        n_classes: number of output logits (default 10, CIFAR-10).
        input_shape: input image shape (default 32x32x3); e.g. (96, 96, 3) for the
            96x96 data parsed by stl_ParseFunction.
    """
    assert (N-4) % 6 == 0, "N-4 has to be divisible by 6"
    # N = 6 * n + 4: n blocks per group, 3 groups of 2 convs each, plus stem/shortcut layers.
    blocks_per_group = (N - 4) // 6
    model = Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg,
                   group_sizes=(blocks_per_group,) * 3, features=(16*K, 32*K, 64*K),
                   strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
                   shortcut_type=shortcut_type, block_type=block_type, dropout=dropout,
                   preact_shortcuts=True)
    return model


def WRN_28_2(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4):
    """Build a WRN-28-2 (preactivated blocks) via cifar_wide_resnet.

    If `load_weights` is True, weights are read with load_weights_func from the
    checkpoint named 'cifar_WRN_28_2'.
    """
    model = cifar_wide_resnet(28, 2, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg)
    if load_weights:
        # The checkpoint must match this architecture (depth 28, width 2);
        # WRN-28-10 weights have different shapes and cannot be loaded here.
        model = load_weights_func(model, 'cifar_WRN_28_2')
    return model

In [None]:
# This cell is for CTAugment

class CTAugment:
  """CTAugment: random augmentation whose magnitude bins are learned online.

  For each transformation the object keeps one weight per magnitude bin (`rescale`
  has a second weight vector over its 6 resampling filters). `augment` applies
  `depth` distinct, uniformly chosen transformations and samples a bin for each
  from the thresholded weights. `update_weights` moves the weights of the used bins
  towards a score of how well the prediction matched the label (exponential moving
  average with factor `decay`).

  Args:
    n_classes: number of classes; normalises the prediction/label distance.
    decay: EMA decay of the bin weights.
    threshold: after dividing by the largest weight, bins below this value get
      sampling probability 0.
    depth: number of transformations applied to each image.
    n_bins: number of magnitude bins per transformation.
  """

  def __init__(self, n_classes, decay=0.99, threshold=0.80, depth=2, n_bins=17):
    self.decay = decay
    self.threshold = threshold
    self.depth = depth
    self.n_bins = n_bins
    self.n_classes = n_classes
    self.xforms = []

    # name -> {"f": bound transformation, "weight": list of per-bin weight vectors}.
    self.AUG_DICT = {
        "autocontrast": {"f": self.autocontrast, "weight": [np.ones(self.n_bins)*1.0]},
        "blur": {"f": self.blur, "weight": [np.ones(self.n_bins)*1.0]},
        "brightness": {"f": self.brightness, "weight": [np.ones(self.n_bins)*1.0]},
        "color": {"f": self.color, "weight": [np.ones(self.n_bins)*1.0]},
        "contrast": {"f": self.contrast, "weight": [np.ones(self.n_bins)*1.0]},
        "cutout": {"f": self.cutout, "weight": [np.ones(self.n_bins)*1.0]},
        "equalize": {"f": self.equalize, "weight": [np.ones(self.n_bins)*1.0]},
        "invert": {"f": self.invert, "weight": [np.ones(self.n_bins)*1.0]},
        "identity": {"f": self.identity, "weight": [np.ones(self.n_bins)*1.0]},
        "posterize": {"f": self.posterize, "weight": [np.ones(self.n_bins)*1.0]},
        "rescale": {"f": self.rescale, "weight": [np.ones(self.n_bins)*1.0, np.ones(6)*1.0]},
        "rotate": {"f": self.rotate, "weight": [np.ones(self.n_bins)*1.0]},
        "sharpness": {"f": self.sharpness, "weight": [np.ones(self.n_bins)*1.0]},
        "shear_x": {"f": self.shear_x, "weight": [np.ones(self.n_bins)*1.0]},
        "shear_y": {"f": self.shear_y, "weight": [np.ones(self.n_bins)*1.0]},
        "smooth": {"f": self.smooth, "weight": [np.ones(self.n_bins)*1.0]},
        "solarize": {"f": self.solarize, "weight": [np.ones(self.n_bins)*1.0]},
        "translate_x": {"f": self.translate_x, "weight": [np.ones(self.n_bins)*1.0]},
        "translate_y": {"f": self.translate_y, "weight": [np.ones(self.n_bins)*1.0]}
    }
    self.N = len(self.AUG_DICT.keys())
    self.options = list(self.AUG_DICT.keys())

    self.batch_choices = []
    self.batch_bins = []

  def weight_to_p(self, weight):
    """Turn a bin-weight vector into a sampling distribution.

    Weights are shifted by (1 - decay) so they are never all zero, divided by their
    maximum, zeroed below `threshold`, and renormalised to sum to 1.
    """
    p = weight + (1 - self.decay)
    p = p / p.max()
    p[p < self.threshold] = 0
    return p / np.sum(p)

  def _sample_bin(self, weight, uniform=False):
    """Sample one bin index for `weight`, uniformly or from its thresholded weights."""
    n = len(weight)
    p = np.full(n, 1.0 / n) if uniform else self.weight_to_p(weight)
    return np.random.choice(np.arange(n), p=p)

  def augment(self, x, uniform_bin_sampling=False):
    """Apply `depth` randomly chosen transformations to one image.

    Args:
      x: image of shape (H, W, C); converted with np.uint8, so values are expected
        in [0, 255].
      uniform_bin_sampling: sample magnitude bins uniformly instead of from the weights.

    Returns:
      (augmented image as a numpy array, list of transformation names,
       list of dicts with the sampled "bin" (and "bin2" for rescale)).
    """
    aug_x = Image.fromarray(np.uint8(x))

    choices = [self.options[i] for i in np.random.choice(np.arange(self.N), self.depth, replace=False)]
    bins = []

    for choice_key in choices:
      weights = self.AUG_DICT[choice_key]["weight"]
      # The uniform distribution must have one entry per bin of this weight vector.
      curr_bins = {"bin": self._sample_bin(weights[0], uniform_bin_sampling)}
      if choice_key == "rescale":
        curr_bins["bin2"] = self._sample_bin(weights[1], uniform_bin_sampling)

      aug_x = self.AUG_DICT[choice_key]["f"](aug_x, **curr_bins)
      bins.append(curr_bins)

    return np.array(aug_x), choices, bins

  def augment_batch(self, batch):
    """Augment one image (rank 3) or every image of a batch (rank 4).

    Returns:
      (augmented images as a tf.Tensor with the input's dtype, per-image choices,
       per-image bins). Inputs of any other rank, and empty batches, are returned
       unchanged with empty choice/bin lists.
    """
    images = np.asarray(batch)
    batch_choices = []
    batch_bins = []

    if images.ndim == 3:
      augmented, choices, bins = self.augment(images)
      batch_choices.append(choices)
      batch_bins.append(bins)
    elif images.ndim == 4 and len(images) > 0:
      samples = []
      for sample in images:
        aug_sample, choices, bins = self.augment(sample)
        samples.append(aug_sample)
        batch_choices.append(choices)
        batch_bins.append(bins)
      augmented = np.stack(samples)
    else:
      augmented = images

    # augment() returns uint8; cast back so the model sees the same dtype as before.
    return tf.convert_to_tensor(augmented.astype(images.dtype, copy=False)), batch_choices, batch_bins

  def _ema_update(self, weight, index, omega):
    """In-place exponential-moving-average update of one bin weight towards `omega`."""
    weight[index] = self.decay * weight[index] + (1 - self.decay) * omega

  def update_weights(self, label, pred, choices, bins):
    """Update the weights of the bins used to augment one sample.

    omega = 1 - sum|label - pred| / (2 * n_classes) scores how well `pred`
    (expected: class probabilities) matches `label` (expected: one-hot).
    """
    distance = np.sum(np.abs(np.asarray(label) - np.asarray(pred)))
    omega = 1 - (1 / (2 * self.n_classes)) * distance

    for choice, chosen in zip(choices, bins):
      weights = self.AUG_DICT[choice]["weight"]
      self._ema_update(weights[0], chosen["bin"], omega)
      if choice == "rescale":
        self._ema_update(weights[1], chosen["bin2"], omega)

  def update_weights_batch(self, labels, preds, choices, bins):
    """Apply update_weights to every sample of a batch."""
    for label, pred, sample_choices, sample_bins in zip(labels, preds, choices, bins):
      self.update_weights(label, pred, sample_choices, sample_bins)

  def get_param(self, r_min, r_max, bin):
    """Value of bin `bin` on an n_bins-point linear grid from r_min to r_max (inclusive)."""
    possible_value = np.linspace(r_min, r_max, self.n_bins)
    return possible_value[bin]

  def autocontrast(self, x, bin):
    """Blend towards the autocontrasted image (factor in [0, 1])."""
    param = self.get_param(0, 1, bin)
    return Image.blend(x, ImageOps.autocontrast(x), param)

  def blur(self, x, bin):
    """Blend towards the blurred image (factor in [0, 1])."""
    param = self.get_param(0, 1, bin)
    return Image.blend(x, x.filter(ImageFilter.BLUR), param)

  def brightness(self, x, bin):
    """Brightness enhancement factor in [0.1, 2.0]."""
    param = self.get_param(0, 1, bin)
    return ImageEnhance.Brightness(x).enhance(0.1 + 1.9*param)

  def color(self, x, bin):
    """Color saturation enhancement factor in [0.1, 2.0]."""
    param = self.get_param(0, 1, bin)
    return ImageEnhance.Color(x).enhance(0.1 + 1.9*param)

  def contrast(self, x, bin):
    """Contrast enhancement factor in [0.1, 2.0]."""
    param = self.get_param(0, 1, bin)
    return ImageEnhance.Contrast(x).enhance(0.1 + 1.9*param)

  def cutout(self, x, bin):
    """Paint a gray (127, 127, 127) square at a random location; adapted directly from the FixMatch code.

    Modifies `x` in place (through its pixel map) and returns it; assumes an RGB image.
    NOTE(review): the level range here is [0, 0.5]; confirm against the reference code.
    """
    level = self.get_param(0, 0.5, bin)

    size = 1 + int(level * min(x.size) * 0.499)
    # PIL's size is (width, height); the names below are swapped, but the pixel map
    # is indexed as pixels[first, second] consistently, so the result is correct.
    img_height, img_width = x.size
    height_loc = np.random.randint(low=0, high=img_height)
    width_loc = np.random.randint(low=0, high=img_width)
    upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
    lower_coord = (min(img_height, height_loc + size // 2), min(img_width, width_loc + size // 2))
    pixels = x.load()
    for i in range(upper_coord[0], lower_coord[0]):
      for j in range(upper_coord[1], lower_coord[1]):
        pixels[i, j] = (127, 127, 127)
    return x

  def equalize(self, x, bin):
    """Blend towards the histogram-equalized image (factor in [0, 1])."""
    param = self.get_param(0, 1, bin)
    return Image.blend(x, ImageOps.equalize(x), param)

  def invert(self, x, bin):
    """Blend towards the inverted image (factor in [0, 1])."""
    param = self.get_param(0, 1, bin)
    return Image.blend(x, ImageOps.invert(x), param)

  def identity(self, x, bin):
    """Return the image unchanged."""
    return x

  def posterize(self, x, bin):
    """Keep 1..8 bits per channel (0 bits would turn the image black)."""
    bits = 1 + int(self.get_param(0, 1, bin) * 7.999)
    return ImageOps.posterize(x, bits)

  def rescale(self, x, bin, bin2):
    """Crop a centered region (border fraction 0.125..0.25 per side) and resize back.

    `bin2` selects the resampling filter. LANCZOS is the filter formerly exposed as
    the deprecated alias Image.ANTIALIAS.
    """
    param = self.get_param(0.5, 1, bin)
    methods = (Image.LANCZOS, Image.BICUBIC, Image.BILINEAR, Image.BOX, Image.HAMMING, Image.NEAREST)
    method = methods[bin2]
    s = x.size
    scale = param*0.25
    crop = (scale * s[0], scale * s[1], s[0] * (1 - scale), s[1] * (1 - scale))
    return x.crop(crop).resize(x.size, method)

  def rotate(self, x, bin):
    """Rotate by an integer angle in [-45, 45] degrees."""
    angle = int(np.round(self.get_param(-45, 45, bin)))
    return x.rotate(angle)

  def sharpness(self, x, bin):
    """Sharpness enhancement factor in [0.1, 2.0]."""
    param = self.get_param(0, 1, bin)
    return ImageEnhance.Sharpness(x).enhance(0.1 + 1.9*param)

  def shear_x(self, x, bin):
    """Horizontal shear with coefficient in [-0.3, 0.3]."""
    shear = self.get_param(-0.3, 0.3, bin)
    return x.transform(x.size, Image.AFFINE, (1, shear, 0, 0, 1, 0))

  def shear_y(self, x, bin):
    """Vertical shear with coefficient in [-0.3, 0.3]."""
    shear = self.get_param(-0.3, 0.3, bin)
    return x.transform(x.size, Image.AFFINE, (1, 0, 0, shear, 1, 0))

  def smooth(self, x, bin):
    """Blend towards the smoothed image (factor in [0, 1])."""
    param = self.get_param(0, 1, bin)
    return Image.blend(x, x.filter(ImageFilter.SMOOTH), param)

  def solarize(self, x, bin):
    """Solarize with threshold in [0, 255]."""
    param = self.get_param(0, 1, bin)
    th = int(param * 255.999)
    return ImageOps.solarize(x, th)

  def translate_x(self, x, bin):
    """Horizontal translation by up to 30% of the image width (in pixels)."""
    delta = self.get_param(-0.3, 0.3, bin)
    return x.transform(x.size, Image.AFFINE, (1, 0, delta * x.size[0], 0, 1, 0))

  def translate_y(self, x, bin):
    """Vertical translation by up to 30% of the image height (in pixels)."""
    delta = self.get_param(-0.3, 0.3, bin)
    return x.transform(x.size, Image.AFFINE, (1, 0, 0, 0, 1, delta * x.size[1]))


class OurCosineDecay(tf.keras.experimental.CosineDecay):
  """Cosine learning-rate schedule lr(k) = eta * ((1 - alpha) * cos(7*pi*k / (16*K)) + alpha).

  eta is `initial_learning_rate`, K is `decay_steps` and k the step, clamped to K.
  The cosine argument stops at 7*pi/16, so with alpha = 0 the rate never reaches zero.
  Implemented with the public TensorFlow API only.
  """

  def __call__(self, step):
    with tf.name_scope(self.name or "CosineDecay"):
      initial_learning_rate = tf.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      dtype = initial_learning_rate.dtype
      decay_steps = tf.cast(self.decay_steps, dtype)

      # Clamp so the rate stays at its final value after decay_steps.
      global_step_recomp = tf.minimum(tf.cast(step, dtype), decay_steps)
      completed_fraction = global_step_recomp / decay_steps
      cosine_decayed = tf.cos(tf.constant(7 / 16 * np.pi, dtype=dtype) * completed_fraction)

      decayed = (1 - self.alpha) * cosine_decayed + self.alpha
      return tf.multiply(initial_learning_rate, decayed)

In [14]:
def training(model, ds_l, ds_u, hparams, mean=None, std=None,
                   val_interval=2000, log_interval=200, batch_size=128, ds_test=None):
    """Train `model` with FixMatch on a labeled and an unlabeled dataset.

    Each step minimises:
    - supervised cross-entropy on weakly augmented labeled images;
    - plus hparams['lamda'] times the cross-entropy between hard pseudo-labels (from
      weakly augmented unlabeled images) and the predictions on CTAugment-ed versions
      of those images. This term only counts samples whose pseudo-label confidence is
      >= hparams['tau'];
    - plus the model's regularization losses (`model.losses`).

    After each step, CTAugment's bin weights are updated from the model's predictions
    on CTAugment-ed labeled images.

    Args:
        model: Keras model mapping float images to class logits.
        ds_l: unbatched tf.data.Dataset of (image, integer label) pairs.
        ds_u: unbatched tf.data.Dataset of (image, label) pairs; labels are ignored.
        hparams: dict with keys 'eta', 'K', 'beta', 'nesterov', 'batch_size', 'tau',
            'lamda', 'cta_classes', 'cta_decay', 'cta_threshold', 'cta_depth'.
        mean, std: unused (kept for interface compatibility).
        val_interval: expected number of steps per epoch; only used by the progress bar.
        log_interval: number of training steps between metric logs.
        batch_size: unused; the batch size comes from hparams['batch_size'].
        ds_test: optional unbatched (image, label) dataset evaluated after each epoch.
    """
    n_classes = hparams['cta_classes']
    threshold = hparams['tau']
    lamda = hparams['lamda']

    def train_prep(x, y):
        return tf.cast(x, tf.float32), y

    def valid_prep(x, y):
        return tf.cast(x, tf.float32), y

    def weak_transformation(x):
        """Random horizontal flip plus a random translation of up to 12.5% of the image side."""
        x = tf.image.random_flip_left_right(x)
        max_shift = tf.cast(x.shape[1] * 0.125, dtype=tf.dtypes.int32)
        # maxval is exclusive, hence +1 so shifts are symmetric in [-max_shift, max_shift].
        shift = tf.random.uniform([x.shape[0], 2], minval=-max_shift, maxval=max_shift + 1,
                                  dtype=tf.dtypes.int32)
        return tfa.image.translate(x, tf.cast(shift, tf.dtypes.float32))

    def pseudolabel(logits):
        """Hard pseudo-labels: index of the most likely class for each sample."""
        return tf.math.argmax(logits, axis=1)

    def threshold_gate(logits, threshold):
        """1.0 for samples whose max softmax probability is >= threshold, else 0.0."""
        max_probs = tf.reduce_max(tf.nn.softmax(logits), axis=1)
        return tf.cast(max_probs >= threshold, tf.float32)

    def update_cta(x, y):
        """Update CTAugment's bin weights from predictions on CTAugment-ed labeled images."""
        x_probe, choices, bins = cta.augment_batch(x)
        probs = tf.nn.softmax(model(x_probe, training=False))
        labels = tf.one_hot(y, n_classes)
        cta.update_weights_batch(labels.numpy(), probs.numpy(), choices, bins)

    # Not wrapped in tf.function: CTAugment runs Python/PIL code on every image.
    def train_step(x_l, y_l, x_u):
        with tf.GradientTape() as tape:
            # Supervised loss on weakly augmented labeled data.
            logits_l = model(weak_transformation(x_l), training=True)
            labeled_loss = loss_fn(y_l, logits_l)

            # Pseudo-labels from weakly augmented unlabeled data, treated as constants.
            logits_u_weak = tf.stop_gradient(model(weak_transformation(x_u), training=True))
            y_u = pseudolabel(logits_u_weak)
            mask = threshold_gate(logits_u_weak, threshold)

            # Consistency loss on strongly augmented unlabeled data, confident samples only.
            x_u_strong, _, _ = cta.augment_batch(x_u)
            logits_u_strong = model(x_u_strong, training=True)
            per_sample_loss = tf.keras.losses.sparse_categorical_crossentropy(
                y_u, logits_u_strong, from_logits=True)
            unlabeled_loss = tf.reduce_mean(per_sample_loss * mask)

            # Layer regularization losses (L2 kernel regularizers) are not added
            # automatically in a custom loop.
            regularization_loss = tf.add_n(model.losses) if model.losses else tf.constant(0.0)
            classification_loss = labeled_loss + lamda * unlabeled_loss
            loss = classification_loss + regularization_loss

        gradients = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))
        update_cta(x_l, y_l)

        accuracy(y_l, logits_l)
        cls_loss(classification_loss)
        reg_loss(regularization_loss)

    def eval_step(x, y):
        logits = model(x, training=False)
        accuracy(y, logits)
        cls_loss(loss_fn(y, logits))

    schedule = OurCosineDecay(hparams['eta'], hparams['K'])
    optimizer = tf.keras.optimizers.SGD(schedule, momentum=hparams['beta'], nesterov=hparams['nesterov'])
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    cta = CTAugment(hparams['cta_classes'], hparams['cta_decay'], hparams['cta_threshold'], hparams['cta_depth'])

    ds_l = ds_l.map(train_prep).batch(hparams['batch_size']).prefetch(-1)
    ds_u = ds_u.map(train_prep).batch(hparams['batch_size']).prefetch(-1)

    accuracy = tf.metrics.SparseCategoricalAccuracy()
    cls_loss = tf.metrics.Mean()
    reg_loss = tf.metrics.Mean()

    training_step = 0
    best_validation_acc = 0
    epochs = 1

    for epoch in range(epochs):
        for (x_l, y_l), (x_u, _) in tqdm(zip(ds_l, ds_u), desc=f'epoch {epoch+1}/{epochs}',
                                         total=val_interval, ncols=100, ascii=True):
            training_step += 1
            train_step(x_l, y_l, x_u)

            if training_step % log_interval == 0:
                c_loss = float(cls_loss.result())
                r_loss = float(reg_loss.result())
                err = 1 - float(accuracy.result())
                print(f" c_loss: {c_loss:^6.3f} | r_loss: {r_loss:^6.3f} | err: {err:^6.3f}", end='\r')

                # Summaries are only recorded if the caller set a default tf.summary writer.
                tf.summary.scalar('train/error_rate', err, training_step)
                tf.summary.scalar('train/classification_loss', c_loss, training_step)
                tf.summary.scalar('train/regularization_loss', r_loss, training_step)
                tf.summary.scalar('train/learnig_rate', schedule(optimizer.iterations), training_step)
                cls_loss.reset_states()
                reg_loss.reset_states()
                accuracy.reset_states()

        if ds_test is None:
            continue

        # Drop leftover training statistics so test metrics are not mixed with them.
        cls_loss.reset_states()
        accuracy.reset_states()
        for x, y in ds_test.map(valid_prep).batch(hparams['batch_size']):
            eval_step(x, y)

        tf.summary.scalar('test/classification_loss', cls_loss.result(), step=training_step)
        tf.summary.scalar('test/error_rate', 1 - accuracy.result(), step=training_step)

        if accuracy.result() > best_validation_acc:
            best_validation_acc = float(accuracy.result())

        cls_loss.reset_states()
        accuracy.reset_states()

In [19]:

# FixMatch hyperparameters
lamda = 1     # weight of the unlabeled loss in the total loss
eta = 0.03    # initial learning rate
beta = 0.9    # SGD momentum
tau = 0.95    # confidence threshold for pseudo-labels
mu = 0.7      # proportion of unlabeled samples in batch
B = 64        # number of labeled examples in batch (in training)
K = 2 ** 20   # decay horizon (in steps) of the cosine learning-rate schedule
nesterov = False  # whether SGD uses Nesterov momentum
batch_size = 2  # NOTE(review): original comment asks whether this should be 64 (= B)
# TODO: weight decay
# TODO: SGD instead of Adam


# CTAugment params
cta_classes = 10
cta_decay = 0.99
cta_depth = 2
cta_threshold = 0.8

hparams = {'lamda': lamda, 'eta': eta, 'beta': beta, 'tau': tau, 'mu': mu, 'B': B, 'K': K,
           'nesterov': nesterov, 'batch_size': batch_size,
           'cta_classes': cta_classes, 'cta_decay': cta_decay, 'cta_depth': cta_depth,
           'cta_threshold': cta_threshold}

In [18]:
# Freshly initialised WRN-28-2 (load_weights defaults to False, so no checkpoint is loaded).
model = WRN_28_2()


In [15]:
def ParseFunction(serialized, image_shape=[32, 32, 3]):
    """Parse one serialized tf.train.Example holding an encoded 'image' and an int64 'label'.

    Returns a dict {'image': decoded image with static shape `image_shape`,
    'label': int64 scalar}. Pixel values are left undecoded-range (no rescaling).
    """
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(serialized=serialized, features=schema)
    decoded = tf.image.decode_image(example['image'])
    # decode_image cannot infer a static shape; declare it for downstream ops.
    decoded.set_shape(image_shape)
    return {'image': decoded, 'label': example['label']}

def stl_ParseFunction(input):
  """Parse one serialized example whose image is 96x96x3 (see ParseFunction)."""
  stl_image_shape = [96, 96, 3]
  return ParseFunction(serialized=input, image_shape=stl_image_shape)

def LoadData(filename, tensor=False, parse_fn=None):
  """Load a TFRecord file of (image, label) examples fully into memory.

  Args:
    filename: path (or list of paths) accepted by tf.data.TFRecordDataset.
    tensor: if True, return a tf.data.Dataset of (image, label) pairs; otherwise
      return the (images, labels) numpy arrays.
    parse_fn: function mapping one serialized record to a dict with 'image' and
      'label'; defaults to ParseFunction (32x32x3). Use stl_ParseFunction for 96x96x3.

  Raises:
    ValueError: from np.stack if the file contains no records.
  """
  if parse_fn is None:
    parse_fn = ParseFunction
  dataset = tf.data.TFRecordDataset(filename).map(parse_fn)

  # Decode every record once and reuse it for both arrays.
  records = list(dataset.as_numpy_iterator())
  images = np.stack([record['image'] for record in records])
  labels = np.stack([record['label'] for record in records])

  if tensor:
    return tf.data.Dataset.from_tensor_slices((images, labels))
  return images, labels


In [16]:
# TFRecord on Google Drive (presumably a labeled CIFAR-10 split, per the file name);
# requires the Drive mount from the first cell.
fpath = '/content/drive/My Drive/Colab Notebooks/cifar10.3@10-label.tfrecord'

# im: stacked images, ls: stacked labels (numpy arrays, since tensor=False by default).
im, ls = LoadData(fpath)

In [26]:
def test_error(model, test_data, test_labels):
  out = model(test_data)
  out_l = tf.math.argmax(out, axis=1)
  return np.sum(out_l == test_labels)/len(test_labels)

# Evaluate `model` on the images and labels loaded above.
test_error(model, im, ls)

0.1