In [2]:
import numpy as np
import tensorflow as tf
from curriculum_learning.models.classifier_model import ClassifierModel
from curriculum_learning import utils
import yaml
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [3]:
with open("models_hyperparameters.yaml", "r") as stream:
    models_hyperparameters = yaml.safe_load(stream)
    
with open("config_tests.yaml", "r") as stream:
    config_tests = yaml.safe_load(stream)

In [4]:
N_EPOCHS = 25
N_TRIALS = 50
BATCH_SIZE = 256

CONFIG = config_tests["assessment_proba_best"]

loss = tf.keras.losses.SparseCategoricalCrossentropy()

In [5]:
x1, y1 = utils.load_cifar_data("../data/cifar-10-batches-py/")
x1 /= 255

x, _, y, _ = train_test_split(x1, y1, train_size=0.4, random_state=42)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.1, random_state=42)

n_classes = len(np.unique(y))
train_size = x_train.shape[0]
train_size

16000

In [6]:
assessment_model = ClassifierModel(output_shape=n_classes, **models_hyperparameters["assessment_model"])

assessment_model.compile(loss=loss, metrics=["accuracy"])
assessment_model(x_train[:1])
assessment_model.load_weights("../models/assessment_model.weights.h5")

2024-03-24 14:41:07.211640: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Max
2024-03-24 14:41:07.211667: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 32.00 GB
2024-03-24 14:41:07.211674: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 10.67 GB
2024-03-24 14:41:07.211713: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-03-24 14:41:07.211732: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
  trackable.load_own_variables(weights_store.get(inner_path))


In [7]:
x_train_sorted = x_train[np.argsort(y_train)]
y_train_sorted = y_train[np.argsort(y_train)]
_, counts = np.unique(y_train_sorted, return_counts=True)

samples_proba = utils.calculate_proba(
    assessment_model, x_train_sorted, y_train_sorted, counts, CONFIG["negative_loss"]
)

2024-03-24 14:41:10.354312: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


In [9]:
model = ClassifierModel(output_shape=n_classes, **models_hyperparameters["test_model_1"])
model.compile(optimizer="adam", loss=loss, metrics=["accuracy"])
model(x_train[0:1])
# model.save_weights("../models/default_model.weights.h5")
model.load_weights("../models/default_model.weights.h5")
model_weights = model.get_weights()

In [10]:
results = {}
model_scores = []

for _ in tqdm(range(N_TRIALS)):
    model.set_weights(model_weights)
        
    for i in range(N_EPOCHS):
        n_samples = int(np.tanh(4 * (i + 1) / N_EPOCHS) * train_size)
        
        samples_ids = utils.chose_samples(n_samples, samples_proba, CONFIG["order_type"])
        model.fit(
            x_train_sorted[samples_ids],
            y_train_sorted[samples_ids],
            # validation_data=(x_val, y_val),
            epochs=1,
            batch_size=BATCH_SIZE,
            verbose=0,
        )
        
        if CONFIG["progressive"]:
            samples_proba = utils.calculate_proba(
                model, x_train_sorted, y_train_sorted, counts, CONFIG["negative_loss"]
            )
            
    _, accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=1)
    model_scores.append(accuracy)

    results["test_model_1"] = model_scores

  0%|          | 0/50 [00:00<?, ?it/s]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 24ms/step - accuracy: 0.5210 - loss: 2.2072


  2%|▏         | 1/50 [01:03<51:56, 63.59s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6143 - loss: 1.4198


  4%|▍         | 2/50 [01:56<45:57, 57.45s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.5288 - loss: 1.8591


  6%|▌         | 3/50 [02:49<43:26, 55.45s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6141 - loss: 1.2887


  8%|▊         | 4/50 [03:43<41:56, 54.71s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5305 - loss: 1.7822


 10%|█         | 5/50 [04:36<40:34, 54.10s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5564 - loss: 1.5156


 12%|█▏        | 6/50 [05:29<39:25, 53.77s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5986 - loss: 1.4858


 14%|█▍        | 7/50 [06:22<38:24, 53.59s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6226 - loss: 1.2875


 16%|█▌        | 8/50 [07:15<37:25, 53.46s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5325 - loss: 1.9485


 18%|█▊        | 9/50 [08:09<36:33, 53.50s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5389 - loss: 1.9117


 20%|██        | 10/50 [09:02<35:37, 53.44s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6024 - loss: 1.4953


 22%|██▏       | 11/50 [09:55<34:39, 53.33s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6063 - loss: 1.2685


 24%|██▍       | 12/50 [10:47<33:27, 52.84s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5902 - loss: 1.4387


 26%|██▌       | 13/50 [11:39<32:24, 52.56s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6107 - loss: 1.3705


 28%|██▊       | 14/50 [12:31<31:28, 52.46s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5525 - loss: 1.8324


 30%|███       | 15/50 [13:23<30:29, 52.27s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5879 - loss: 1.4110


 32%|███▏      | 16/50 [14:15<29:35, 52.23s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6034 - loss: 1.4744


 34%|███▍      | 17/50 [15:07<28:40, 52.15s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6074 - loss: 1.3846


 36%|███▌      | 18/50 [15:59<27:43, 51.97s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.5820 - loss: 1.4156


 38%|███▊      | 19/50 [16:51<26:54, 52.10s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5637 - loss: 1.6313


 40%|████      | 20/50 [17:43<26:04, 52.16s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5690 - loss: 1.6369


 42%|████▏     | 21/50 [18:35<25:04, 51.89s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5965 - loss: 1.3970


 44%|████▍     | 22/50 [19:27<24:12, 51.88s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5956 - loss: 1.4780


 46%|████▌     | 23/50 [20:19<23:25, 52.05s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5871 - loss: 1.4350


 48%|████▊     | 24/50 [21:11<22:31, 51.97s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6015 - loss: 1.4252


 50%|█████     | 25/50 [22:03<21:41, 52.05s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5801 - loss: 1.4619


 52%|█████▏    | 26/50 [22:55<20:50, 52.11s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5684 - loss: 1.7631


 54%|█████▍    | 27/50 [23:47<19:58, 52.12s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.5669 - loss: 1.6007


 56%|█████▌    | 28/50 [24:40<19:06, 52.14s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5850 - loss: 1.4627


 58%|█████▊    | 29/50 [25:32<18:15, 52.16s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6218 - loss: 1.3351


 60%|██████    | 30/50 [26:24<17:23, 52.15s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.6105 - loss: 1.3301


 62%|██████▏   | 31/50 [27:16<16:31, 52.16s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5714 - loss: 1.5697


 64%|██████▍   | 32/50 [28:09<15:40, 52.27s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6035 - loss: 1.3295


 66%|██████▌   | 33/50 [29:01<14:50, 52.41s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5921 - loss: 1.5851


 68%|██████▊   | 34/50 [29:54<13:58, 52.41s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5646 - loss: 1.5991


 70%|███████   | 35/50 [30:46<13:05, 52.37s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.6476 - loss: 1.2050


 72%|███████▏  | 36/50 [31:38<12:11, 52.27s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.5916 - loss: 1.4956


 74%|███████▍  | 37/50 [32:31<11:20, 52.37s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5685 - loss: 1.5075


 76%|███████▌  | 38/50 [33:23<10:28, 52.37s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6165 - loss: 1.3737


 78%|███████▊  | 39/50 [34:16<09:38, 52.55s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5835 - loss: 1.4890


 80%|████████  | 40/50 [35:08<08:44, 52.41s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.5043 - loss: 2.0143


 82%|████████▏ | 41/50 [36:02<07:55, 52.88s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5401 - loss: 1.9424


 84%|████████▍ | 42/50 [36:56<07:05, 53.14s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.5027 - loss: 2.4109 


 86%|████████▌ | 43/50 [37:50<06:13, 53.37s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6108 - loss: 1.3234


 88%|████████▊ | 44/50 [38:43<05:19, 53.29s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5962 - loss: 1.4613


 90%|█████████ | 45/50 [39:35<04:24, 52.89s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6185 - loss: 1.3311


 92%|█████████▏| 46/50 [40:27<03:31, 52.79s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6292 - loss: 1.2703


 94%|█████████▍| 47/50 [41:20<02:37, 52.67s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5828 - loss: 1.4752


 96%|█████████▌| 48/50 [42:12<01:45, 52.52s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5797 - loss: 1.5248


 98%|█████████▊| 49/50 [43:04<00:52, 52.44s/it]

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.6178 - loss: 1.2248


100%|██████████| 50/50 [43:59<00:00, 52.79s/it]


In [11]:
results

{'test_model_1': [0.522777795791626,
  0.6100000143051147,
  0.5322222113609314,
  0.621666669845581,
  0.5380555391311646,
  0.5563889145851135,
  0.6041666865348816,
  0.6302777528762817,
  0.5338888764381409,
  0.54666668176651,
  0.602222204208374,
  0.6155555844306946,
  0.5961111187934875,
  0.6127777695655823,
  0.5461111068725586,
  0.5911111235618591,
  0.6030555367469788,
  0.6102777719497681,
  0.5861111283302307,
  0.5677777528762817,
  0.5730555653572083,
  0.6000000238418579,
  0.597777783870697,
  0.5958333611488342,
  0.605555534362793,
  0.5874999761581421,
  0.5736111402511597,
  0.570277750492096,
  0.5952777862548828,
  0.6313889026641846,
  0.6133333444595337,
  0.5902777910232544,
  0.6127777695655823,
  0.5963888764381409,
  0.5691666603088379,
  0.644444465637207,
  0.5947222113609314,
  0.5680555701255798,
  0.6133333444595337,
  0.5930555462837219,
  0.49861112236976624,
  0.5447221994400024,
  0.5022222399711609,
  0.6166666746139526,
  0.6030555367469788,
  

In [None]:
sum(results["test_model_1"]) / N_TRIALS