In [1]:
import numpy as np
import tensorflow as tf
from curriculum_learning.models.classifier_model import ClassifierModel
from curriculum_learning import utils
import yaml
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
with open("models_hyperparameters.yaml", "r") as stream:
    models_hyperparameters = yaml.safe_load(stream)
    
with open("config_tests.yaml", "r") as stream:
    config_tests = yaml.safe_load(stream)
    
N_EPOCHS = 50
N_TRIALS = 30
BATCH_SIZE = 512

CONFIG = config_tests["proba_best"]

loss = tf.keras.losses.SparseCategoricalCrossentropy()

In [3]:
ds_1 = tfds.load("eurosat", split="train", as_supervised=True, shuffle_files=False)
# ds_2 = tfds.load("stl10", split="test", as_supervised=True, shuffle_files=False)

x = []
y = []

for x_, y_ in ds_1.as_numpy_iterator():
    x.append(x_)
    y.append(y_)
    
# for x_, y_ in ds_2.as_numpy_iterator():
#     x.append(x_)
#     y.append(y_)

2024-07-01 11:10:17.628907: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [4]:
x = np.array(x, dtype=np.float32) / 255
y = np.array(y, dtype=np.float32)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42, stratify=y)
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5, random_state=42, stratify=y_test)

n_classes = len(np.unique(y))
train_size = x_train.shape[0]
train_size, len(x_val), len(x_test)

(18900, 4050, 4050)

In [5]:
x_train_sorted = x_train[np.argsort(y_train)]
y_train_sorted = y_train[np.argsort(y_train)]
_, counts = np.unique(y_train_sorted, return_counts=True)

In [6]:
samples_proba = utils.calculate_proba_edges(x_train_sorted, counts, blur=True)

In [7]:
model = ClassifierModel(output_shape=n_classes, **models_hyperparameters["test_model_1"])
model.compile(optimizer="adam", loss=loss, metrics=["accuracy"])
model(x_train[0:1])
# model.save_weights("../models/default_model.weights.h5")
model.load_weights("../models/default_model.weights.h5")
model_weights = model.get_weights()

In [18]:
CONFIG

{'order_type': 'proba'}

In [None]:
model_scores = []
verbose = 0
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True, patience=5, start_from_epoch=0)

for _ in tqdm(range(N_TRIALS // 2)):
    model.set_weights(model_weights)

    for i in range(N_EPOCHS):
        n_samples = int(np.tanh(4 * (i + 1) / N_EPOCHS) * train_size)

        samples_ids = utils.chose_samples(n_samples, samples_proba, CONFIG["order_type"])

        model.fit(
            x_train_sorted[samples_ids],
            y_train_sorted[samples_ids],
            # validation_data=(x_val, y_val),
            epochs=1,
            batch_size=BATCH_SIZE,
            verbose=verbose,
        )

    model.fit(
        x_train, y_train, validation_data=(x_val, y_val), epochs=500, batch_size=BATCH_SIZE, 
        callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True, patience=5, start_from_epoch=10)], 
        verbose=verbose
    )
            
    _, accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=0)
    model_scores.append(accuracy)
    print(f"Mean: {np.mean(model_scores):.4f}   Median: {np.median(model_scores):.4f}   Last accuracy {accuracy:.4f}")

  7%|▋         | 1/15 [03:52<54:15, 232.55s/it]

Mean: 0.8659   Median: 0.8659   Last accuracy 0.8659


 13%|█▎        | 2/15 [06:42<42:24, 195.72s/it]

Mean: 0.8643   Median: 0.8643   Last accuracy 0.8627


 20%|██        | 3/15 [10:12<40:27, 202.27s/it]

Mean: 0.8571   Median: 0.8627   Last accuracy 0.8427


 27%|██▋       | 4/15 [13:55<38:34, 210.44s/it]

Mean: 0.8513   Median: 0.8527   Last accuracy 0.8338


 33%|███▎      | 5/15 [18:05<37:28, 224.84s/it]

Mean: 0.8499   Median: 0.8444   Last accuracy 0.8444


In [None]:
np.min(model_scores)

In [None]:
CONFIG

In [None]:
model_scores


In [28]:
CONFIG

{'order_type': 'proba'}