# SimCLR
### Jade
### Test different data augmentation techniques in SimCLR

### Part 0: Import libraries

In [1]:
import pandas as pd
import tensorflow as tf
from glob import glob
import os
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import csv
import json
import time
from tensorflow.keras.applications import ResNet50, ResNet101V2, Xception, InceptionV3
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import *
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
#from utils import *
#import helpers
#import losses
import argparse
import cv2

### Part 1: Set path 

In [2]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [3]:
BASE_PATH = '/content/gdrive/My Drive/BigEarthNet/'
OUTPUT_PATH = os.path.join(BASE_PATH, 'data_augmentation')
TFR_PATH = os.path.join(BASE_PATH, 'tfrecords')

### Part 2: Model functions

In [4]:
# helper
import tensorflow as tf
import numpy as np
#from augmentation.gaussian_filter import GaussianBlur


def get_negative_mask(batch_size):
    # return a mask that removes the similarity score of equal/similar images.
    # this function ensures that only distinct pair of images get their similarity scores
    # passed as negative examples
    negative_mask = np.ones((batch_size, 2 * batch_size), dtype=bool)
    for i in range(batch_size):
        negative_mask[i, i] = 0
        negative_mask[i, i + batch_size] = 0
    return tf.constant(negative_mask)

In [5]:
# losses
import tensorflow as tf

cosine_sim_1d = tf.keras.losses.CosineSimilarity(axis=1, reduction=tf.keras.losses.Reduction.NONE)
cosine_sim_2d = tf.keras.losses.CosineSimilarity(axis=2, reduction=tf.keras.losses.Reduction.NONE)


def _cosine_simililarity_dim1(x, y):
    v = cosine_sim_1d(x, y)
    return v


def _cosine_simililarity_dim2(x, y):
    # x shape: (N, 1, C)
    # y shape: (1, 2N, C)
    # v shape: (N, 2N)
    v = cosine_sim_2d(tf.expand_dims(x, 1), tf.expand_dims(y, 0))
    return v


def _dot_simililarity_dim1(x, y):
    # x shape: (N, 1, C)
    # y shape: (N, C, 1)
    # v shape: (N, 1, 1)
    v = tf.matmul(tf.expand_dims(x, 1), tf.expand_dims(y, 2))
    return v


def _dot_simililarity_dim2(x, y):
    v = tf.tensordot(tf.expand_dims(x, 1), tf.expand_dims(tf.transpose(y), 0), axes=2)
    # x shape: (N, 1, C)
    # y shape: (1, C, 2N)
    # v shape: (N, 2N)
    return v

In [6]:
#util
import tensorflow as tf
import cv2
import numpy as np
from tensorflow.keras.preprocessing import image


def read_tfrecord(example):
    BAND_STATS = {
        'mean': {
            'B01': 340.76769064,
            'B02': 429.9430203,
            'B03': 614.21682446,
            'B04': 590.23569706,
            'B05': 950.68368468,
            'B06': 1792.46290469,
            'B07': 2075.46795189,
            'B08': 2218.94553375,
            'B8A': 2266.46036911,
            'B09': 2246.0605464,
            'B11': 1594.42694882,
            'B12': 1009.32729131
        },
        'std': {
            'B01': 554.81258967,
            'B02': 572.41639287,
            'B03': 582.87945694,
            'B04': 675.88746967,
            'B05': 729.89827633,
            'B06': 1096.01480586,
            'B07': 1273.45393088,
            'B08': 1365.45589904,
            'B8A': 1356.13789355,
            'B09': 1302.3292881,
            'B11': 1079.19066363,
            'B12': 818.86747235
        }
    }

    # Use this one-liner to standardize each feature prior to reshaping.
    def standardize_feature(data, band_name):
        return ((tf.dtypes.cast(data, tf.float32) - BAND_STATS['mean'][band_name]) / BAND_STATS['std'][band_name])

    # decode the TFRecord
    # The parse single example methods takes an example (from a tfrecords file),
    # and a dictionary that explains the data format of each feature.
    example = tf.io.parse_single_example(example, {
        'B01': tf.io.FixedLenFeature([20 * 20], tf.int64),
        'B02': tf.io.FixedLenFeature([120 * 120], tf.int64),
        'B03': tf.io.FixedLenFeature([120 * 120], tf.int64),
        'B04': tf.io.FixedLenFeature([120 * 120], tf.int64),
        'B05': tf.io.FixedLenFeature([60 * 60], tf.int64),
        'B06': tf.io.FixedLenFeature([60 * 60], tf.int64),
        'B07': tf.io.FixedLenFeature([60 * 60], tf.int64),
        'B08': tf.io.FixedLenFeature([120 * 120], tf.int64),
        'B8A': tf.io.FixedLenFeature([60 * 60], tf.int64),
        'B09': tf.io.FixedLenFeature([20 * 20], tf.int64),
        'B11': tf.io.FixedLenFeature([60 * 60], tf.int64),
        'B12': tf.io.FixedLenFeature([60 * 60], tf.int64),
        'patch_name': tf.io.VarLenFeature(dtype=tf.string),
        'original_labels': tf.io.VarLenFeature(dtype=tf.string),
        'original_labels_multi_hot': tf.io.FixedLenFeature([43], tf.int64)
    })

    example['binary_label'] = example['original_labels_multi_hot'][tf.constant(12)]

    # After parsing our data into a tensor, let's standardize and reshape.
    reshaped_example = {
        'B01': tf.reshape(standardize_feature(example['B01'], 'B01'), [20, 20]),
        'B02': tf.reshape(standardize_feature(example['B02'], 'B02'), [120, 120]),
        'B03': tf.reshape(standardize_feature(example['B03'], 'B03'), [120, 120]),
        'B04': tf.reshape(standardize_feature(example['B04'], 'B04'), [120, 120]),
        'B05': tf.reshape(standardize_feature(example['B05'], 'B05'), [60, 60]),
        'B06': tf.reshape(standardize_feature(example['B06'], 'B06'), [60, 60]),
        'B07': tf.reshape(standardize_feature(example['B07'], 'B07'), [60, 60]),
        'B08': tf.reshape(standardize_feature(example['B08'], 'B08'), [120, 120]),
        'B8A': tf.reshape(standardize_feature(example['B8A'], 'B8A'), [60, 60]),
        'B09': tf.reshape(standardize_feature(example['B09'], 'B09'), [20, 20]),
        'B11': tf.reshape(standardize_feature(example['B11'], 'B11'), [60, 60]),
        'B12': tf.reshape(standardize_feature(example['B12'], 'B12'), [60, 60]),
        'patch_name': example['patch_name'],
        'original_labels': example['original_labels'],
        'original_labels_multi_hot': example['original_labels_multi_hot'],
        'binary_labels': example['binary_label']
    }

    # Next sort the layers by resolution
    bands_10m = tf.stack([reshaped_example['B04'],
                          reshaped_example['B03'],
                          reshaped_example['B02'],
                          reshaped_example['B08']], axis=2)

    bands_20m = tf.stack([reshaped_example['B05'],
                          reshaped_example['B06'],
                          reshaped_example['B07'],
                          reshaped_example['B8A'],
                          reshaped_example['B11'],
                          reshaped_example['B12']], axis=2)

    # Finally resize the 20m data and stack the bands together.
    img = tf.concat([bands_10m, tf.image.resize(bands_20m, [120, 120], method='bicubic')], axis=2)
    
    multi_hot_label = reshaped_example['original_labels_multi_hot']
    binary_label = reshaped_example['binary_labels']
    
    # Can update this to return the multilabel if doing multi-class classification
    return img, binary_label
  
  
def get_batched_dataset(filenames, batch_size, augment=False):
    option_no_order = tf.data.Options()
    option_no_order.experimental_deterministic = False

    dataset = tf.data.Dataset.list_files(filenames, shuffle=True)
    print(f'Filenames: {filenames}')
    dataset = dataset.with_options(option_no_order)
    dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=2, num_parallel_calls=1)
    dataset = dataset.shuffle(buffer_size=2048)
    #.repeat()
    
    dataset = dataset.map(read_tfrecord, num_parallel_calls=10)
    dataset = dataset.batch(batch_size, drop_remainder=True)  # drop_remainder will be needed on TPU
    dataset = dataset.prefetch(5)  #

    return dataset

class TimeHistory(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.times = []

    def on_epoch_begin(self, batch, logs={}):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, batch, logs={}):
        self.times.append(time.time() - self.epoch_time_start)

class Augment():
  def augfunc(self, sample):        
    # Randomly apply transformation (color distortions) with probability p.
    sample = self._random_apply(self._color_jitter, sample, p=0.8)
    sample = self._random_apply(self._color_drop, sample, p=0.2)
    sample = self._random_apply(self._blur, sample, p=0.5)

    return sample

  def _color_jitter(self,  x, s=1):
      # one can also shuffle the order of following augmentations
      # each time they are applied.
      x = tf.image.random_brightness(x, max_delta=0.8*s)
      x = tf.image.random_contrast(x, lower=1-0.8*s, upper=1+0.8*s)
      dx = tf.image.random_saturation(x[:,:,:3], lower=1-0.8*s, upper=1+0.8*s)
      dx = tf.image.random_hue(dx, max_delta=0.2*s)
      x = tf.concat([dx, x[:,:,3:]],axis=2)
      x = tf.clip_by_value(x, 0, 1)
      return x

  def _color_drop(self, x):
      dx = tf.image.rgb_to_grayscale(x[:,:,:3])
      dx = tf.tile(dx, [1, 1, 3])
      x = tf.concat([dx, x[:,:,3:]],axis=2)
      return x

  def _blur(self, x):
      # SimClr implementation is applied at 10% of image size with a random sigma
      p = np.random.uniform(0.1,2)
      if type(x) == np.ndarray:
          return (cv2.GaussianBlur(x,(5,5),p))
      return (cv2.GaussianBlur(x.numpy(),(5,5),p))

  def _random_apply(self, func, x, p):
      return tf.cond(
        tf.less(tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
                tf.cast(p, tf.float32)),
        lambda: func(x),
        lambda: x)

In [8]:
def get_training_dataset(training_filenames, batch_size):
  return get_batched_dataset(training_filenames, batch_size)


def build_simclr_model(imported_model, hidden_1, hidden_2, hidden_3):
  
  base_model = imported_model(include_top=False, weights=None, input_shape=[120,120, 10])
  base_model.trainable = True
  
  inputs = Input((120,120, 10))
  
  h = base_model(inputs, training=True)
  h = GlobalAveragePooling2D()(h)
  
  projection_1 = Dense(hidden_1)(h)
  projection_1 = Activation("relu")(projection_1)
  projection_2 = Dense(hidden_2)(projection_1)
  projection_2 = Activation("relu")(projection_2)
  projection_3 = Dense(hidden_3)(projection_2)

  simclr_model = tf.keras.models.Model(inputs, projection_3)
  
  return simclr_model

In [9]:
@tf.function
def train_step(xis, xjs, model, optimizer, criterion, temperature, batch_size):
    # Mask to remove positive examples from the batch of negative samples
    negative_mask = get_negative_mask(batch_size)
  
    with tf.GradientTape() as tape:
        zis = model(xis)
        zjs = model(xjs)

        # normalize projection feature vectors
        zis = tf.math.l2_normalize(zis, axis=1)
        zjs = tf.math.l2_normalize(zjs, axis=1)

        l_pos = _dot_simililarity_dim1(zis, zjs)
        l_pos = tf.reshape(l_pos, (batch_size, 1))
        l_pos /= temperature

        negatives = tf.concat([zjs, zis], axis=0)

        loss = 0

        for positives in [zis, zjs]:
            l_neg = _dot_simililarity_dim2(positives, negatives)

            labels = tf.zeros(batch_size, dtype=tf.int32)

            l_neg = tf.boolean_mask(l_neg, negative_mask)
            l_neg = tf.reshape(l_neg, (batch_size, -1))
            l_neg /= temperature

            logits = tf.concat([l_pos, l_neg], axis=1) 
            loss += criterion(y_pred=logits, y_true=labels)

        loss = loss / (2 * batch_size)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss    

In [10]:
def run_model(name, BATCH_SIZE, epochs, architecture, temperature):
    
    print(50 * "*")
    print(f"Running model: SimCLR {name}")
    print(50 * "=")
    print(f"Batch Size: {BATCH_SIZE}")
    print(50 * "=")
    print(f'Using Model Architecture: {architecture}')
    
    training_filenames = f'{TFR_PATH}/train-part-0.tfrecord'
    training_data = get_training_dataset(training_filenames, BATCH_SIZE)

#     len_train_records = 9942*5
#     steps_per_epoch = len_train_records // BATCH_SIZE
    
    criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, 
                                                          reduction=tf.keras.losses.Reduction.SUM)
    decay_steps = 1000
    lr_decayed_fn = tf.keras.experimental.CosineDecay(
        initial_learning_rate=0.1, decay_steps=decay_steps)
    optimizer = tf.keras.optimizers.SGD(lr_decayed_fn)

    simclr_2 = build_simclr_model(architecture,1024, 512, 128)
    simclr_2.summary()

    
    epoch_wise_loss = []
    
    time_callback = TimeHistory()
    augment = Augment()
    
    ROTATION = 180
    SHIFT = 0.10
    FLIP = True
    ZOOM = 0.20
    JITTER = True
    BLUR = True
    
    datagen = image.ImageDataGenerator(
            rotation_range=ROTATION,
            width_shift_range=SHIFT,
            height_shift_range=SHIFT,
            horizontal_flip=FLIP,
            vertical_flip=FLIP,
            zoom_range=ZOOM,
            preprocessing_function= augment.augfunc)
    
    min_loss = 1e6
    min_loss_epoch = 0
    
    for epoch in tqdm(range(epochs)):
      step_wise_loss = []
      for image_batch in tqdm(training_data):
        a = datagen.flow(image_batch, batch_size=BATCH_SIZE, shuffle=False)
        b = datagen.flow(image_batch, batch_size=BATCH_SIZE, shuffle=False)

        loss = train_step(a[0][0], b[0][0], simclr_2, optimizer, criterion, temperature=0.1, batch_size=BATCH_SIZE)
        step_wise_loss.append(loss)

      epoch_wise_loss.append(np.mean(step_wise_loss))
      # Print the loss after every epoch
      print(f"****epoch: {epoch + 1} loss: {epoch_wise_loss[-1]:.3f}****\n")
        
      # Save best weights
      if epoch_wise_loss[-1] < min_loss:
        # Save the final model with weights
        simclr_2.save(f'{OUTPUT_PATH}/{name}.h5')
        min_loss_epoch = epoch+1
  
    # Store the epochwise loss and model metadata to dataframe
    df = pd.DataFrame(epoch_wise_loss)
    df['temperature'] = temperature
    df['batch_size'] = BATCH_SIZE
    df['epochs'] = epochs
    df['h1'] = 1024
    df['h2'] = 512
    df['output_dim'] = 128
    df['rotation'] = ROTATION
    df['shift'] = ROTATION
    df['flip'] = ROTATION
    df['zoom'] = ROTATION
    df['jitter'] = ROTATION
    df['blur'] = ROTATION
    df['best_epoch'] = min_loss_epoch
  
    df.to_pickle(f'{OUTPUT_PATH}/{name}.pkl')
    
    return df


### Part 3: Training model

In [11]:
run_model('simclr1',
          BATCH_SIZE=32,
          epochs=5,
          architecture=ResNet50,
          temperature=0.1)
    

**************************************************
Running model: SimCLR simclr1
Batch Size: 32
Using Model Architecture: <function ResNet50 at 0x7f6c179ba730>
Filenames: /content/gdrive/My Drive/BigEarthNet/tfrecords/train-part-0.tfrecord


  str(input_shape[-1]) + ' input channels.')
  0%|          | 0/5 [00:00<?, ?it/s]
0it [00:00, ?it/s][A

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 120, 120, 10)]    0         
_________________________________________________________________
resnet50 (Functional)        (None, 4, 4, 2048)        23609664  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 1024)              2098176   
_________________________________________________________________
activation (Activation)      (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               524800    
_________________________________________________________________
activation_1 (Activation)    (None, 512)              

  str(self.x.shape[channels_axis]) + ' channels).')

1it [00:24, 24.74s/it][A
2it [00:40, 22.03s/it][A
3it [00:56, 20.13s/it][A
4it [01:11, 18.81s/it][A
5it [01:27, 17.87s/it][A
6it [01:43, 17.20s/it][A
7it [01:58, 16.79s/it][A
8it [02:14, 16.44s/it][A
9it [02:30, 16.21s/it][A
10it [02:45, 16.03s/it][A
11it [03:01, 15.92s/it][A
12it [03:17, 15.81s/it][A
13it [03:32, 15.76s/it][A
14it [03:48, 15.69s/it][A
15it [04:04, 15.75s/it][A
16it [04:19, 15.68s/it][A
17it [04:35, 15.70s/it][A
18it [04:51, 15.66s/it][A
19it [05:06, 15.66s/it][A
20it [05:25, 16.52s/it][A
21it [05:40, 16.27s/it][A
22it [05:56, 16.10s/it][A
23it [06:12, 16.02s/it][A
24it [06:28, 15.92s/it][A
25it [06:43, 15.82s/it][A
26it [06:59, 15.78s/it][A
27it [07:14, 15.69s/it][A
28it [07:30, 15.64s/it][A
29it [07:45, 15.59s/it][A
30it [08:01, 15.61s/it][A
31it [08:17, 15.59s/it][A
32it [08:32, 15.59s/it][A
33it [08:48, 15.57s/it][A
34it [09:03, 15.59s/it][A
35it [09:19, 15.59s/it][A
36it [09:34

****epoch: 1 loss: 3.574****



 20%|██        | 1/5 [31:57<2:07:51, 1917.87s/it]
0it [00:00, ?it/s][A
1it [00:17, 17.53s/it][A
2it [00:33, 16.98s/it][A
3it [00:48, 16.58s/it][A
4it [01:04, 16.38s/it][A
5it [01:20, 16.18s/it][A
6it [01:36, 16.03s/it][A
7it [01:51, 15.94s/it][A
8it [02:07, 15.91s/it][A
9it [02:23, 15.82s/it][A
10it [02:39, 15.79s/it][A
11it [02:55, 15.84s/it][A
12it [03:10, 15.86s/it][A
13it [03:26, 15.84s/it][A
14it [03:42, 15.82s/it][A
15it [04:00, 16.62s/it][A
16it [04:16, 16.35s/it][A
17it [04:32, 16.20s/it][A
18it [04:48, 16.14s/it][A
19it [05:04, 16.12s/it][A
20it [05:20, 16.11s/it][A
21it [05:36, 16.06s/it][A
22it [05:52, 16.01s/it][A
23it [06:08, 16.02s/it][A
24it [06:24, 15.99s/it][A
25it [06:40, 15.93s/it][A
26it [06:56, 15.89s/it][A
27it [07:11, 15.83s/it][A
28it [07:27, 15.80s/it][A
29it [07:43, 15.82s/it][A
30it [07:59, 15.84s/it][A
31it [08:15, 15.82s/it][A
32it [08:30, 15.82s/it][A
33it [08:46, 15.75s/it][A
34it [09:02, 15.71s/it][A
35it [09:17, 15.65s

****epoch: 2 loss: 3.098****



 40%|████      | 2/5 [1:04:03<1:36:00, 1920.25s/it]
0it [00:00, ?it/s][A
1it [00:17, 17.62s/it][A
2it [00:33, 17.03s/it][A
3it [00:49, 16.64s/it][A
4it [01:04, 16.34s/it][A
5it [01:20, 16.16s/it][A
6it [01:35, 15.98s/it][A
7it [01:51, 15.91s/it][A
8it [02:07, 15.85s/it][A
9it [02:25, 16.60s/it][A
10it [02:41, 16.37s/it][A
11it [02:57, 16.19s/it][A
12it [03:12, 16.00s/it][A
13it [03:28, 15.91s/it][A
14it [03:44, 15.83s/it][A
15it [03:59, 15.79s/it][A
16it [04:15, 15.74s/it][A
17it [04:31, 15.75s/it][A
18it [04:47, 15.75s/it][A
19it [05:02, 15.78s/it][A
20it [05:18, 15.76s/it][A
21it [05:34, 15.75s/it][A
22it [05:50, 15.76s/it][A
23it [06:05, 15.78s/it][A
24it [06:21, 15.79s/it][A
25it [06:37, 15.77s/it][A
26it [06:53, 15.75s/it][A
27it [07:08, 15.71s/it][A
28it [07:24, 15.74s/it][A
29it [07:40, 15.71s/it][A
30it [07:56, 15.72s/it][A
31it [08:11, 15.70s/it][A
32it [08:27, 15.69s/it][A
33it [08:42, 15.65s/it][A
34it [08:58, 15.72s/it][A
35it [09:14, 15.6

****epoch: 3 loss: 2.886****



 60%|██████    | 3/5 [1:35:58<1:03:57, 1918.71s/it]
0it [00:00, ?it/s][A
1it [00:17, 17.28s/it][A
2it [00:32, 16.79s/it][A
3it [00:48, 16.41s/it][A
4it [01:06, 17.00s/it][A
5it [01:22, 16.58s/it][A
6it [01:38, 16.29s/it][A
7it [01:53, 16.10s/it][A
8it [02:09, 15.95s/it][A
9it [02:24, 15.84s/it][A
10it [02:40, 15.73s/it][A
11it [02:56, 15.71s/it][A
12it [03:11, 15.69s/it][A
13it [03:27, 15.69s/it][A
14it [03:43, 15.69s/it][A
15it [03:58, 15.70s/it][A
16it [04:14, 15.68s/it][A
17it [04:30, 15.65s/it][A
18it [04:45, 15.65s/it][A
19it [05:01, 15.72s/it][A
20it [05:17, 15.71s/it][A
21it [05:32, 15.70s/it][A
22it [05:48, 15.69s/it][A
23it [06:04, 15.71s/it][A
24it [06:20, 15.74s/it][A
25it [06:35, 15.76s/it][A
26it [06:51, 15.74s/it][A
27it [07:07, 15.75s/it][A
28it [07:22, 15.67s/it][A
29it [07:38, 15.65s/it][A
30it [07:54, 15.64s/it][A
31it [08:09, 15.62s/it][A
32it [08:25, 15.73s/it][A
33it [08:41, 15.69s/it][A
34it [08:57, 15.74s/it][A
35it [09:12, 15.6

****epoch: 4 loss: 2.701****



 80%|████████  | 4/5 [2:07:53<31:57, 1917.51s/it]  
0it [00:00, ?it/s][A
1it [00:17, 17.48s/it][A
2it [00:33, 16.93s/it][A
3it [00:48, 16.53s/it][A
4it [01:04, 16.43s/it][A
5it [01:20, 16.18s/it][A
6it [01:36, 16.01s/it][A
7it [01:51, 15.88s/it][A
8it [02:07, 15.83s/it][A
9it [02:22, 15.75s/it][A
10it [02:38, 15.70s/it][A
11it [02:54, 15.72s/it][A
12it [03:10, 15.72s/it][A
13it [03:25, 15.67s/it][A
14it [03:41, 15.67s/it][A
15it [03:57, 15.69s/it][A
16it [04:12, 15.71s/it][A
17it [04:28, 15.70s/it][A
18it [04:44, 15.71s/it][A
19it [04:59, 15.72s/it][A
20it [05:15, 15.69s/it][A
21it [05:31, 15.71s/it][A
22it [05:46, 15.71s/it][A
23it [06:02, 15.77s/it][A
24it [06:18, 15.83s/it][A
25it [06:35, 15.96s/it][A
26it [06:51, 16.18s/it][A
27it [07:09, 16.50s/it][A
28it [07:25, 16.51s/it][A
29it [07:42, 16.54s/it][A
30it [07:58, 16.56s/it][A
31it [08:15, 16.60s/it][A
32it [08:32, 16.61s/it][A
33it [08:48, 16.58s/it][A
34it [09:05, 16.62s/it][A
35it [09:21, 16.5

****epoch: 5 loss: 2.544****



100%|██████████| 5/5 [2:40:03<00:00, 1920.79s/it]


Unnamed: 0,0,temperature,batch_size,epochs,h1,h2,output_dim,rotation,shift,flip,zoom,jitter,blur,best_epoch
0,3.573758,0.1,32,5,1024,512,128,180,180,180,180,180,180,5
1,3.097945,0.1,32,5,1024,512,128,180,180,180,180,180,180,5
2,2.886214,0.1,32,5,1024,512,128,180,180,180,180,180,180,5
3,2.700789,0.1,32,5,1024,512,128,180,180,180,180,180,180,5
4,2.54352,0.1,32,5,1024,512,128,180,180,180,180,180,180,5
