# 0 - Requirements

In [None]:
pip install -U tensorflow-addons

In [None]:
!pip install wandb

In [3]:
import getpass
import os
import re
import zipfile
from typing import Dict, List

import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
import wandb
from sklearn.model_selection import KFold, train_test_split
from tensorflow import keras
from tensorflow.keras import layers
from tqdm import tqdm

# Login into W&B
# The API key is taken from the WANDB_API_KEY environment variable, or prompted
# for interactively, so it is never stored in the notebook.
# NOTE: the key previously written here in plain text must be treated as
# leaked — revoke/rotate it in the W&B account settings.
WB_ENTITY = 'ual'
WB_PROJECT = 'hvit_classifier'
WB_KEY = os.environ.get('WANDB_API_KEY') or getpass.getpass('W&B API key: ')

In [4]:
# Show which GPUs TensorFlow can see; an empty list means no GPU is visible.
available_gpus = tf.config.list_physical_devices(device_type='GPU')
available_gpus

[]

# 1 - Data

In [None]:
# Extract only the JPEG images from the archive into EXTRACT_DIR.
# Members that fail to decompress are skipped, but they are collected and
# reported instead of being silently dropped.
ZIP_PATH = '/content/drive/MyDrive/Computer Vision research/archive.zip'
EXTRACT_DIR = '/content/'
IMAGE_SUFFIX = '.jpeg'

skipped_members = []
with zipfile.ZipFile(ZIP_PATH) as zf:
    for member in tqdm(zf.infolist(), desc='Extracting '):
        # Case-insensitive suffix check; directories and non-image members are ignored.
        if not member.filename.lower().endswith(IMAGE_SUFFIX):
            continue
        try:
            zf.extract(member, EXTRACT_DIR)
        except zipfile.BadZipFile as err:  # `zipfile.error` is an alias of BadZipFile
            skipped_members.append((member.filename, str(err)))

if skipped_members:
    print(f'Skipped {len(skipped_members)} unreadable member(s); first: {skipped_members[0]}')

# 2 - Parameters

## Config

In [None]:
# Config
# ---- Data -----------------------------------------------------------------
# NOTE(review): the space in "OCT2017 " looks deliberate (presumably the folder
# name inside the extracted archive) — confirm before "fixing" it.
path = '/content/OCT2017 /train/'
color_mode = "grayscale"    # forwarded to flow_from_dataframe
class_mode = "categorical"  # forwarded to flow_from_dataframe
img_size = 128              # images are resized to img_size x img_size

# ---- Training loop --------------------------------------------------------
batch_size = 32
epochs = 10
seed = 123
verbose = 1
# NOTE(review): these fractions sum to 1.1, so they cannot be three disjoint
# shares of one pool; check how run_WB_experiment interprets them.
pct_split = [0.8, 0.2, 0.1]

# ---- Optimisation ---------------------------------------------------------
learning_rate = 5e-5
weight_decay = 1e-4
label_smoothing = 0.1

## Image generators

In [None]:
# Keyword arguments for `ImageDataGenerator(...)`, one entry per data split.
# Every split is rescaled to [0, 1]; only the training split is augmented.
# NOTE(review): Keras interprets `rotation_range` and `shear_range` in degrees,
# so 0.2 / 0.1 degrees is effectively no augmentation — confirm whether larger
# values (e.g. 20 / 10) were intended.
ImageDataGenerator_config = {
    'train': {
        "rescale": 1./255,
        "shear_range": .1,
        "rotation_range": .2,
        "zoom_range": .1,
        "horizontal_flip": True,
    },
    'val': {"rescale": 1./255},
    'test': {"rescale": 1./255},
}


def _flow_kwargs(shuffle: bool) -> dict:
    """Return the `flow_from_dataframe(...)` kwargs shared by every split.

    A fresh dict is built on each call, so mutating one split's config never
    leaks into another. `dataframe` and `directory` are None placeholders —
    presumably filled in per split by `run_WB_experiment`; TODO confirm.

    Args:
        shuffle: whether the generator shuffles samples between epochs.
    """
    return {
        "dataframe": None,
        "directory": None,
        "x_col": "x_col",
        "y_col": "y_col",
        "batch_size": batch_size,
        "target_size": (img_size, img_size),
        "color_mode": color_mode,
        "class_mode": class_mode,
        "shuffle": shuffle,
        "seed": seed,
    }


# Only the training stream is shuffled. Evaluation streams keep a fixed order so
# that `model.predict(gen)` lines up with `gen.classes` / `gen.filenames`
# (e.g. for confusion matrices); evaluate() metrics are order-independent.
flow_from_dataframe_config = {
    'train': _flow_kwargs(shuffle=True),
    'val': _flow_kwargs(shuffle=False),
    'test': _flow_kwargs(shuffle=False),
}

## Model

# HViT hyper-parameters; all of them are passed positionally to HViT(...) in the
# experiment cell below.
patch_size = [8, 16, 32]
# Derive the channel count from `color_mode` so the model's Input shape cannot
# drift out of sync with what the generators yield (grayscale=1, rgb=3, rgba=4).
num_channels = {"grayscale": 1, "rgb": 3, "rgba": 4}[color_mode]
num_heads = 8
# NOTE(review): patch_size and transformer_layers have equal length, presumably
# one entry per hierarchy level — confirm against HViT.
transformer_layers = [4, 4, 4]
hidden_unit_factor = 2
mlp_head_units = [1024, 128]
# NOTE(review): must equal the number of label classes in y_col (presumably the
# four OCT2017 categories) — confirm.
num_classes = 4
# Dropout rates (attention / projection / linear layers, judging by the names).
drop_attn = .2
drop_proj = .2
drop_linear = .4
projection_dim = None  # NOTE(review): None presumably lets HViT derive it — confirm.
resampling_type = "standard"
original_attn = True

# 3 - Run experiment

# Import model
from HViT_classification.model.ViT_model import HViT
from HViT_classification.model.experiments import run_WB_experiment

# Set group
WB_GROUP = 'HViT'

# Choose the device explicitly instead of hard-coding '/device:GPU:0'. On a
# runtime without a GPU (the device check above can return an empty list),
# the hard-coded GPU scope only worked through TensorFlow's soft device
# placement, which silently falls back to CPU; this makes the choice visible.
gpu_devices = tf.config.list_physical_devices('GPU')
device_name = '/device:GPU:0' if gpu_devices else '/device:CPU:0'
print(f'Building and training on {device_name}')

# Start running
with tf.device(device_name):
    # Instance model: HViT is applied to a symbolic Keras input and wrapped in a
    # functional tf.keras.Model. Arguments are positional, so their order must
    # match HViT's constructor signature.
    inputs = tf.keras.layers.Input((img_size, img_size, num_channels))
    outputs = HViT(
        img_size,
        patch_size,
        projection_dim,
        num_channels,
        num_heads,
        transformer_layers,
        mlp_head_units,
        num_classes,
        hidden_unit_factor,
        drop_attn,
        drop_proj,
        drop_linear,
        resampling_type,
        original_attn,
    )(inputs)
    model = tf.keras.Model(inputs, outputs)

    # Run experiment. NOTE(review): presumably run_WB_experiment logs in to W&B
    # with WB_KEY, builds the generators from the two config dicts and trains
    # `model` — confirm in HViT_classification/model/experiments.py.
    run_WB_experiment(
        WB_KEY,
        WB_ENTITY,
        WB_PROJECT,
        WB_GROUP,
        model,
        ImageDataGenerator_config,
        flow_from_dataframe_config,
        path=path,
        epochs=epochs,
        pct_split=pct_split,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        label_smoothing=label_smoothing,
        seed=seed,
        verbose=verbose,
    )