In [1]:
import numpy as np
import tensorflow as tf
from curriculum_learning.models.classifier_model import ClassifierModel
from curriculum_learning import utils
import yaml
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [34]:
N_EPOCHS = 100
N_TRIALS = 3
BATCH_SIZE = 256

In [35]:
with open("models_hyperparameters.yaml", "r") as stream:
    models_hyperparameters = yaml.safe_load(stream)
    
x, y = utils.load_data("../data/cifar-10-batches-py/data_batch_1")
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.8, random_state=42)

train_size = x_train.shape[0]

In [36]:
results = {}

test_models = ["test_model_1", "test_model_2", "test_model_3"]

In [37]:
for test_model in test_models:
    print(test_model)
    model_scores = []

    for _ in tqdm(range(N_TRIALS)):
        model = ClassifierModel(output_shape=10, **models_hyperparameters[test_model])
        
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=['accuracy']
        )
        
        n_samples_step = train_size // N_EPOCHS
        for i in range(N_EPOCHS):
            n_samples = (i + 1) * n_samples_step
            samples_ids = np.random.choice(range(train_size), n_samples, replace=False)
            
            model.fit(x_train[samples_ids], y_train[samples_ids], epochs=1, batch_size=BATCH_SIZE, verbose=1)
        
        _, accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=0)
        
        model_scores.append(accuracy)
        break
    break
    results[test_model] = model_scores

test_model_1


  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/3 [00:16<?, ?it/s]


In [32]:
train_size * (i + 1) * n_samples_step

132000

In [17]:
model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=0)

[3.1404736042022705, 0.13699999451637268]

In [16]:
results

{'test_model_1': [0.17874999344348907,
  0.12062499672174454,
  0.17787499725818634],
 'test_model_2': [0.15512500703334808,
  0.1120000034570694,
  0.10487499833106995],
 'test_model_3': [0.11675000190734863,
  0.1366250067949295,
  0.13699999451637268]}