In [None]:
%load_ext autoreload
%autoreload 1

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras import backend as K
from tqdm import tqdm

import sys
sys.path += ['../']
from utils.lr_finder import LRFinder

# Fake Model for Quick POC

In [None]:
%load_ext autoreload
%autoreload 1

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras import backend as K
from tqdm import tqdm

import sys
sys.path += ['../']
from utils.lr_finder import LRFinder

# Seed both NumPy and TensorFlow so the toy run is repeatable (the real-model
# cell seeds both RNGs as well).
np.random.seed(43)
tf.random.set_seed(43)

# A single Dense unit is enough to exercise the LRFinder callback end to end.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1),
])

epoch = 5
batch_size = 32
sample_size = 50
# LRFinder settings; same (window_size, max_steps, filename) values as
# Args.lr_finder in the real-model cell.
lr_finder_window = 4
lr_finder_max_steps = 400
lr_finder_file = 'logs/lr_finder'

# from_tensor_slices splits along axis 0, so every dataset element is already a
# whole batch (inputs (batch_size, 10, 3), targets (batch_size, 1)) and
# len(train_data) == sample_size batches per epoch.
train_data = tf.data.Dataset.from_tensor_slices(
    (np.zeros((sample_size, batch_size, 10, 3)), np.random.rand(sample_size, batch_size, 1)))
test_data = tf.data.Dataset.from_tensor_slices(
    (np.zeros((sample_size, batch_size, 10, 3)), np.random.rand(sample_size, batch_size, 1)))

callbacks = [LRFinder(train_data, batch_size, window_size=lr_finder_window,
                      max_steps=lr_finder_max_steps, filename=lr_finder_file)]

# CosineDecay's first argument is the scalar initial learning rate. Passing the
# LRFinder settings (4, 400) here was a copy/paste mistake. 1e-3 (Adam's
# default) is a placeholder; NOTE(review): tune once the LR sweep is read.
initial_lr = 1e-3
lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_lr, epoch * len(train_data), alpha=1e-2)
opt = tf.keras.optimizers.Adam(learning_rate=lr_decayed_fn, beta_1=0.9, beta_2=0.98, epsilon=1e-09)
model.compile(optimizer=opt, loss='mse')

# The datasets are pre-batched, so batch_size is not passed to fit()
# (Keras: do not specify batch_size for tf.data inputs). Use the `epoch`
# setting above instead of a duplicated literal.
model.fit(train_data, validation_data=test_data, epochs=epoch, callbacks=callbacks)

# Real Model for Validation and Test

In [None]:
%load_ext autoreload
%autoreload 1

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras import backend as K
from tqdm import tqdm

import sys
sys.path += ['../']
from utils.lr_finder import LRFinder

import os, json, time
from datasets import get_dataset
from trainers import get_trainer

def save_log(history, val_metric: "str | None" = None, log_dir: str = 'logs'):
    """Write a training history to a new JSON file and return its path.

    The file is named ``qnet-<unix seconds>.json`` inside *log_dir*. If a log
    from the same second already exists, ``-1``, ``-2``, ... is appended
    instead of overwriting it.

    Args:
        history: Mapping of metric name -> list of per-epoch values
            (e.g. ``fitting.history`` from ``model.fit``).
        val_metric: Optional key into *history*. When it has values, their
            maximum is stored under ``'best_acc'`` and printed.
        log_dir: Directory for the log file; created if missing.

    Returns:
        The path of the JSON file that was written.
    """
    logs = {'history': history}

    if val_metric is not None:
        scores = history.get(val_metric)
        if scores:
            logs['best_acc'] = max(scores)
            print('Best score: ', logs['best_acc'])
        else:
            # Don't raise here: failing before the dump below would throw away
            # the history of a (possibly long) training run.
            print(f'Warning: no values for metric {val_metric!r}; best score not recorded.')

    os.makedirs(log_dir, exist_ok=True)
    stamp = int(time.time())
    suffix = 0
    while True:
        name = f'qnet-{stamp}.json' if suffix == 0 else f'qnet-{stamp}-{suffix}.json'
        logfile_name = os.path.join(log_dir, name)
        try:
            # 'x' (exclusive create) refuses to clobber an existing log file.
            f = open(logfile_name, 'x', encoding='utf-8')
        except FileExistsError:
            suffix += 1
        else:
            with f:
                json.dump(logs, f, indent=4)
            break

    print('Log file saved at: ', logfile_name)
    return logfile_name

class Args:
    """Hyper-parameters for the real-model run.

    A lightweight stand-in for an argparse namespace: an instance is passed to
    ``trainer.train`` and ``args.dataset`` selects the dataset. Presumably the
    trainer reads the remaining fields by attribute access; TODO confirm.
    """
    batch_size = 128
    # Base LR 3e-4 scaled by batch_size (presumably the linear scaling rule;
    # confirm with the trainer). Derived from batch_size so the two cannot
    # drift apart; the value is unchanged (3e-4 * 128).
    lr = 3e-4 * batch_size
    seq_len = 8
    epochs = 5
    model = 'fnet'
    embed_size = 16
    num_blocks = 1
    qnet_depth = 1
    # Same (window_size, max_steps, filename) values used for LRFinder in the
    # POC cell; presumably unpacked in this order by the trainer; TODO confirm.
    lr_finder = [4, 400, 'logs/lr_finder']
    # Dataset key for get_dataset(); other options tried: 'colbert', 'rentrunway'.
    dataset = 'msra'

args = Args()

# Seed both global RNGs so runs are repeatable, to the extent the dataset and
# trainer draw from them.
np.random.seed(42)
tf.random.set_seed(42)

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("GPU is", "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")

dataset = get_dataset(args.dataset)
task = dataset.getTask()  # queried once and reused for trainer + log choice
trainer = get_trainer(task)

fitting = trainer.train(args, dataset)

if task == 'classification':
    # More than two outputs -> categorical accuracy, otherwise binary accuracy.
    # NOTE(review): a 2-unit output is treated as binary here; confirm that
    # matches the metric the trainer compiles for that case.
    val_metric = ('val_categorical_accuracy' if dataset.getOutputSize() > 2
                  else 'val_binary_accuracy')
    save_log(fitting.history, val_metric)
else:
    save_log(fitting.history)