In [1]:
import numpy as np
import tensorflow as tf
from curriculum_learning.models.classifier_model import ClassifierModel
from curriculum_learning import utils
import yaml
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

In [36]:
def _load_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    with open(path, "r") as stream:
        return yaml.safe_load(stream)


# Model hyperparameter sets and per-experiment configurations, both keyed by name.
models_hyperparameters = _load_yaml("models_hyperparameters.yaml")
config_tests = _load_yaml("config_tests.yaml")

# Seed the Python, NumPy and TensorFlow global RNGs so that a Restart & Run All
# gives repeatable sample selection and training.
# NOTE(review): utils.chose_samples is assumed to draw from one of these global
# RNGs - confirm. GPU kernels may still add small run-to-run differences.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

N_EPOCHS = 50      # curriculum epochs per trial
N_TRIALS = 30      # independent repetitions of the whole experiment
BATCH_SIZE = 512   # batch size for every fit/evaluate call

# Experiment configuration under test (provides the "order_type" used for sampling).
CONFIG = config_tests["proba_best"]

loss = tf.keras.losses.SparseCategoricalCrossentropy()

In [81]:
# Load both STL-10 splits; they are pooled here and re-split further down.
ds_1 = tfds.load("stl10", split="train", shuffle_files=False, as_supervised=True)
ds_2 = tfds.load("stl10", split="test", shuffle_files=False, as_supervised=True)

# Collect every (image, label) pair as NumPy values, "train" split first.
x, y = [], []
for dataset in (ds_1, ds_2):
    for image, label in dataset.as_numpy_iterator():
        x.append(image)
        y.append(label)

2024-06-14 15:27:54.600667: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-06-14 15:27:55.310206: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [82]:
# Stack the collected samples into float32 arrays; image values are divided by 255.
# NOTE: this overwrites x, so re-running the cell without reloading the data
# divides by 255 a second time.
x = np.asarray(x, dtype=np.float32) / 255
y = np.asarray(y, dtype=np.float32)

# Stratified split: 30% is held out first, then the hold-out is halved into
# test and validation parts.
x_train, x_holdout, y_train, y_holdout = train_test_split(
    x, y, test_size=0.3, stratify=y, random_state=42
)
x_test, x_val, y_test, y_val = train_test_split(
    x_holdout, y_holdout, test_size=0.5, stratify=y_holdout, random_state=42
)

n_classes = np.unique(y).size
train_size = len(x_train)
train_size, len(x_val), len(x_test)

(9100, 1950, 1950)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


In [84]:
# Group the training samples by class. The label ordering is computed once and the
# same permutation is applied to images and labels so the two stay aligned.
sort_order = np.argsort(y_train)
x_train_sorted = x_train[sort_order]
y_train_sorted = y_train[sort_order]

# Number of training samples per class, in ascending label order.
class_labels, counts = np.unique(y_train_sorted, return_counts=True)

In [85]:
# Computed from the class-sorted images and the per-class counts; the result is
# passed to utils.chose_samples in the training loop.
# NOTE(review): presumably per-sample selection probabilities - confirm against
# curriculum_learning.utils.
samples_proba = utils.calculate_proba_edges(x_train_sorted, counts)

In [88]:
# Instantiate the classifier with the "test_model_1" hyperparameter set.
test_model_hparams = models_hyperparameters["test_model_1"]
model = ClassifierModel(output_shape=n_classes, **test_model_hparams)
model.compile(loss=loss, optimizer="adam", metrics=["accuracy"])

# Run one training sample through the model before touching its weights
# (presumably this builds the model's variables - TODO confirm).
model(x_train[:1])

# Presumably the one-off step that created the weights file loaded below:
# model.save_weights("../models/default_model.weights.h5")
model.load_weights("../models/default_model.weights.h5")

# Snapshot of the loaded weights; the trial loop restores it before every trial.
model_weights = model.get_weights()

In [89]:
# Display the experiment configuration that the trial loop will use.
CONFIG

{'order_type': 'proba'}

In [None]:
# Repeat the curriculum experiment N_TRIALS times from the same starting weights
# and record the test accuracy of every trial in model_scores.
CURRICULUM_STEEPNESS = 4          # scale inside tanh: how quickly the subset grows
FINE_TUNE_MAX_EPOCHS = 500        # upper bound; early stopping normally ends sooner
EARLY_STOPPING_PATIENCE = 5       # epochs without val_loss improvement before stopping
EARLY_STOPPING_START_EPOCH = 10   # epochs to wait before early stopping starts monitoring

model_scores = []
verbose = 0

for trial in tqdm(range(N_TRIALS)):
    # set_weights() restores the model weights only. Recompiling gives every trial
    # a fresh Adam optimizer, so moment estimates and the step counter from the
    # previous trial do not leak into this one.
    model.compile(optimizer="adam", loss=loss, metrics=["accuracy"])
    model.set_weights(model_weights)

    # Curriculum phase: one epoch per step on a subset whose size grows with a
    # tanh schedule towards the full training set.
    for i in range(N_EPOCHS):
        n_samples = int(np.tanh(CURRICULUM_STEEPNESS * (i + 1) / N_EPOCHS) * train_size)

        samples_ids = utils.chose_samples(n_samples, samples_proba, CONFIG["order_type"])

        model.fit(
            x_train_sorted[samples_ids],
            y_train_sorted[samples_ids],
            epochs=1,
            batch_size=BATCH_SIZE,
            verbose=verbose,
        )

    # Fine-tuning phase: train on the full training set until validation loss stops
    # improving, then keep the best weights seen.
    model.fit(
        x_train,
        y_train,
        validation_data=(x_val, y_val),
        epochs=FINE_TUNE_MAX_EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=[
            tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                restore_best_weights=True,
                patience=EARLY_STOPPING_PATIENCE,
                start_from_epoch=EARLY_STOPPING_START_EPOCH,
            )
        ],
        verbose=verbose,
    )

    # evaluate() returns [loss, accuracy] because the model is compiled with a
    # single "accuracy" metric.
    test_loss, accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=1)
    model_scores.append(accuracy)
    print("Mean:", np.mean(model_scores), " Median: ", np.median(model_scores))

  0%|          | 0/30 [00:00<?, ?it/s]

In [15]:
# Lowest test accuracy among the recorded trials.
np.min(model_scores)

0.7212345600128174

In [17]:
# Re-display the configuration next to the recorded results.
CONFIG

{'order_type': 'proba'}

In [16]:
# Test accuracy of every trial, in the order the trials were run.
model_scores

[0.8392592668533325,
 0.8106172680854797,
 0.8518518805503845,
 0.8118518590927124,
 0.8627160787582397,
 0.7886419892311096,
 0.8288888931274414,
 0.8254321217536926,
 0.82419753074646,
 0.8402469158172607,
 0.7856789827346802,
 0.8083950877189636,
 0.8318518400192261,
 0.8059259057044983,
 0.8204938173294067,
 0.8283950686454773,
 0.7212345600128174,
 0.760493814945221,
 0.8012345433235168,
 0.7787654399871826,
 0.8548148274421692,
 0.8037037253379822,
 0.8424691557884216,
 0.7701234817504883,
 0.8392592668533325,
 0.7866666913032532,
 0.7990123629570007,
 0.8306173086166382,
 0.8404937982559204,
 0.843950629234314]

In [17]:
# NOTE(review): this cell and an earlier CONFIG cell both carry execution count 17,
# but their saved outputs differ ('proba' vs 'fixed'), so they come from different
# runs. Restart the kernel and run all cells before relying on which configuration
# produced model_scores.
CONFIG

{'order_type': 'fixed'}