In [1]:
import numpy as np
import tensorflow as tf
from curriculum_learning.models.classifier_model import ClassifierModel
from curriculum_learning import utils
import yaml
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [2]:
with open("models_hyperparameters.yaml", "r") as stream:
    models_hyperparameters = yaml.safe_load(stream)
    
with open("config_tests.yaml", "r") as stream:
    config_tests = yaml.safe_load(stream)

In [3]:
N_EPOCHS = 25
N_TRIALS = 50
BATCH_SIZE = 512

CONFIG = config_tests["assessment_proba_best"]

loss = tf.keras.losses.SparseCategoricalCrossentropy()

In [4]:
x, y = utils.load_cifar_data("../data/cifar-10-batches-py/")

In [5]:
x /= 255

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.6, random_state=42)
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5, random_state=42)

n_classes = len(np.unique(y))
train_size = x_train.shape[0]

In [None]:
assessment_model = ClassifierModel(output_shape=n_classes, **models_hyperparameters["assessment_model"])

assessment_model.compile(optimizer="adam", loss=loss, metrics=["accuracy"])

assessment_model.fit(x_train, y_train, epochs=25, batch_size=BATCH_SIZE)

Epoch 1/25


2024-03-23 16:57:08.788768: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Max
2024-03-23 16:57:08.788794: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 32.00 GB
2024-03-23 16:57:08.788800: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 10.67 GB
2024-03-23 16:57:08.788820: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-03-23 16:57:08.788833: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2024-03-23 16:57:09.708097: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 51ms/step - accuracy: 0.1171 - loss: 2.4632
Epoch 2/25
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 30ms/step - accuracy: 0.1818 - loss: 2.2408
Epoch 3/25
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 32ms/step - accuracy: 0.2118 - loss: 2.1368
Epoch 4/25
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 30ms/step - accuracy: 0.2382 - loss: 2.0613
Epoch 5/25
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 30ms/step - accuracy: 0.2642 - loss: 1.9944
Epoch 6/25
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 32ms/step - accuracy: 0.2892 - loss: 1.9333
Epoch 7/25
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 30ms/step - accuracy: 0.2960 - loss: 1.8961
Epoch 8/25
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - accuracy: 0.3049 - loss: 1.8596
Epoch 9/25
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

In [7]:
x_train_sorted = x_train[np.argsort(y_train)]
y_train_sorted = y_train[np.argsort(y_train)]
_, counts = np.unique(y_train_sorted, return_counts=True)

samples_proba = utils.calculate_proba(
    assessment_model, x_train_sorted, y_train_sorted, counts, CONFIG["negative_loss"]
)

In [8]:
results = {}
# test_models = ["test_model_1", "test_model_2", "test_model_3"]
test_models = ["test_model_1"]


for test_model in test_models:
    print(test_model)
    model_scores = []

    for _ in tqdm(range(N_TRIALS)):
        model = ClassifierModel(output_shape=n_classes, **models_hyperparameters[test_model])
        model.compile(optimizer="adam", loss=loss, metrics=["accuracy"])

        for i in range(N_EPOCHS):
            n_samples = int(np.tanh(4 * (i + 1) / N_EPOCHS) * train_size)

            samples_ids = utils.chose_samples(n_samples, samples_proba, CONFIG["order_type"])

            model.fit(
                x_train_sorted[samples_ids],
                y_train_sorted[samples_ids],
                # validation_data=(x_val, y_val),
                epochs=1,
                batch_size=BATCH_SIZE,
                verbose=0,
            )
            
            if CONFIG["progressive"]:
                samples_proba = utils.calculate_proba(
                    model, x_train_sorted, y_train_sorted, counts, CONFIG["negative_loss"]
                )

        _, accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=1)

        model_scores.append(accuracy)

    results[test_model] = model_scores

test_model_1


  0%|          | 0/50 [00:00<?, ?it/s]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6873 - loss: 0.9240


  2%|▏         | 1/50 [01:21<1:06:35, 81.54s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.6190 - loss: 1.1069


  4%|▍         | 2/50 [02:58<1:12:21, 90.44s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5530 - loss: 1.2691


  6%|▌         | 3/50 [06:48<2:00:46, 154.17s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6405 - loss: 1.0647


  8%|▊         | 4/50 [13:27<3:12:22, 250.93s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6741 - loss: 0.9494


 10%|█         | 5/50 [15:24<2:31:56, 202.59s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6824 - loss: 0.9064


 12%|█▏        | 6/50 [17:25<2:08:12, 174.82s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6513 - loss: 0.9920


 14%|█▍        | 7/50 [19:34<1:54:37, 159.94s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.5848 - loss: 1.1918


 16%|█▌        | 8/50 [21:43<1:45:03, 150.09s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6272 - loss: 1.0260


 18%|█▊        | 9/50 [34:58<4:00:13, 351.55s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.6108 - loss: 1.1239


 20%|██        | 10/50 [37:16<3:10:28, 285.72s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6103 - loss: 1.0921


 22%|██▏       | 11/50 [39:44<2:38:22, 243.66s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5994 - loss: 1.1883


 24%|██▍       | 12/50 [42:14<2:16:08, 214.95s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6780 - loss: 0.9374


 26%|██▌       | 13/50 [44:57<2:03:01, 199.49s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6567 - loss: 1.0149


 28%|██▊       | 14/50 [47:39<1:52:53, 188.14s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5638 - loss: 1.2426


 30%|███       | 15/50 [50:30<1:46:42, 182.94s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6637 - loss: 0.9789


 32%|███▏      | 16/50 [53:47<1:46:02, 187.13s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6386 - loss: 1.0258


 34%|███▍      | 17/50 [56:46<1:41:30, 184.55s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6564 - loss: 0.9819


 36%|███▌      | 18/50 [59:53<1:38:49, 185.30s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6586 - loss: 0.9766


 38%|███▊      | 19/50 [1:03:01<1:36:16, 186.33s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6260 - loss: 1.0850


 40%|████      | 20/50 [1:06:19<1:34:55, 189.86s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6991 - loss: 0.8782


 42%|████▏     | 21/50 [1:09:37<1:32:54, 192.21s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5742 - loss: 1.1866


 44%|████▍     | 22/50 [1:13:00<1:31:12, 195.44s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5784 - loss: 1.2275


 46%|████▌     | 23/50 [1:16:25<1:29:09, 198.12s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5439 - loss: 1.4008


 48%|████▊     | 24/50 [1:20:15<1:29:59, 207.68s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6765 - loss: 0.9390


 50%|█████     | 25/50 [1:24:05<1:29:23, 214.55s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6357 - loss: 1.0405


 52%|█████▏    | 26/50 [1:27:58<1:27:59, 219.99s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6590 - loss: 0.9959


 54%|█████▍    | 27/50 [1:31:35<1:24:02, 219.26s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5921 - loss: 1.1473


 56%|█████▌    | 28/50 [1:35:42<1:23:27, 227.62s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.5470 - loss: 1.3253


 58%|█████▊    | 29/50 [1:40:16<1:24:28, 241.38s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6753 - loss: 0.9376


 60%|██████    | 30/50 [1:44:50<1:23:46, 251.33s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6856 - loss: 0.9159


 62%|██████▏   | 31/50 [1:49:29<1:22:11, 259.56s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6562 - loss: 1.0175


 64%|██████▍   | 32/50 [1:54:36<1:22:08, 273.82s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6573 - loss: 0.9954


 66%|██████▌   | 33/50 [3:43:30<10:09:41, 2151.85s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.6448 - loss: 1.0363


 68%|██████▊   | 34/50 [3:48:52<7:07:25, 1602.85s/it] 

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.6631 - loss: 0.9596


 70%|███████   | 35/50 [3:54:27<5:05:36, 1222.41s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.6035 - loss: 1.1269


 72%|███████▏  | 36/50 [3:59:58<3:42:51, 955.12s/it] 

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.6613 - loss: 0.9863


 74%|███████▍  | 37/50 [4:05:39<2:47:01, 770.87s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.4807 - loss: 1.6256


 76%|███████▌  | 38/50 [4:11:14<2:08:01, 640.08s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.6829 - loss: 0.9159


 78%|███████▊  | 39/50 [4:17:03<1:41:20, 552.81s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.6387 - loss: 1.0126


 80%|████████  | 40/50 [4:22:57<1:22:10, 493.06s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6804 - loss: 0.9127


 82%|████████▏ | 41/50 [4:31:50<1:15:45, 505.02s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6625 - loss: 0.9698


 84%|████████▍ | 42/50 [4:37:45<1:01:20, 460.01s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.6310 - loss: 1.0856


 86%|████████▌ | 43/50 [4:44:09<51:01, 437.32s/it]  

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.6554 - loss: 0.9721


 88%|████████▊ | 44/50 [4:50:28<41:57, 419.65s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.6203 - loss: 1.1073


 90%|█████████ | 45/50 [4:56:55<34:09, 409.88s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.5954 - loss: 1.1248


 92%|█████████▏| 46/50 [5:03:11<26:39, 399.88s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6485 - loss: 1.0005


 94%|█████████▍| 47/50 [5:09:45<19:54, 398.16s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.6448 - loss: 1.0180


 96%|█████████▌| 48/50 [5:16:37<13:24, 402.31s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.6742 - loss: 0.9389


 98%|█████████▊| 49/50 [5:23:13<06:40, 400.30s/it]

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.6813 - loss: 0.9067


100%|██████████| 50/50 [5:29:59<00:00, 396.00s/it]


In [9]:
results

{'test_model_1': [0.6743999719619751,
  0.6011999845504761,
  0.5415999889373779,
  0.6237999796867371,
  0.6615999937057495,
  0.6761999726295471,
  0.6462000012397766,
  0.5699999928474426,
  0.6205999851226807,
  0.5968000292778015,
  0.604200005531311,
  0.5860000252723694,
  0.6636000275611877,
  0.6452000141143799,
  0.5587999820709229,
  0.6496000289916992,
  0.631600022315979,
  0.6492000222206116,
  0.6462000012397766,
  0.6173999905586243,
  0.6868000030517578,
  0.574400007724762,
  0.5649999976158142,
  0.5410000085830688,
  0.6669999957084656,
  0.6218000054359436,
  0.6453999876976013,
  0.5839999914169312,
  0.5432000160217285,
  0.6614000201225281,
  0.6758000254631042,
  0.63919997215271,
  0.6452000141143799,
  0.6366000175476074,
  0.6517999768257141,
  0.5902000069618225,
  0.6470000147819519,
  0.4657999873161316,
  0.6711999773979187,
  0.6340000033378601,
  0.6733999848365784,
  0.656000018119812,
  0.6208000183105469,
  0.6407999992370605,
  0.61080002784729,
  

In [10]:
sum(results["test_model_1"]) / N_TRIALS

0.6244360017776489