In [4]:
import numpy as np
import tensorflow as tf
from curriculum_learning.models.classifier_model import ClassifierModel
from curriculum_learning import utils
import yaml
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import seaborn as sns

In [5]:
# Experiment configuration: hyper-parameters and test settings are read from
# YAML files resolved relative to the notebook's working directory.
with open("models_hyperparameters.yaml", "r") as stream:
    models_hyperparameters = yaml.safe_load(stream)

with open("config_tests.yaml", "r") as stream:
    config_tests = yaml.safe_load(stream)

# Seed Python, NumPy and TensorFlow once, up front, so that a fresh
# "Restart Kernel -> Run All" reproduces the same sequence of trials.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

N_EPOCHS = 50     # number of curriculum epochs per trial
N_TRIALS = 10     # number of independent curriculum trials
BATCH_SIZE = 512  # batch size for every fit / evaluate call

# Settings of the sample-ordering strategy under test.
CONFIG = config_tests["proba_best"]

loss = tf.keras.losses.SparseCategoricalCrossentropy()

In [6]:
# Load both official STL-10 splits; they are merged below and re-split later.
ds_1 = tfds.load("stl10", split="train", as_supervised=True, shuffle_files=False)
ds_2 = tfds.load("stl10", split="test", as_supervised=True, shuffle_files=False)

# Collect every (image, label) pair, "train" split first, then "test".
x, y = [], []
for dataset in (ds_1, ds_2):
    for image, label in dataset.as_numpy_iterator():
        x.append(image)
        y.append(label)

# Stack into arrays; pixel values are divided by 255 and labels stored as float32.
x = np.array(x, dtype=np.float32) / 255
y = np.array(y, dtype=np.float32)

2024-04-28 13:21:04.784887: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-28 13:21:05.582573: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [7]:
# Stratified 60 / 20 / 20 split: first hold out 40% of the data, then cut
# that hold-out in half to obtain the test and validation sets.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.4,
    stratify=y,
    random_state=42,
)
# Until this second call returns, `x_test` / `y_test` hold the full 40% hold-out.
x_test, x_val, y_test, y_val = train_test_split(
    x_test,
    y_test,
    test_size=0.5,
    stratify=y_test,
    random_state=42,
)

n_classes = np.unique(y).size
train_size = len(x_train)
(train_size, x_val.shape[0], x_test.shape[0])

(7800, 2600, 2600)

In [8]:
# Reorder the training set by label. The permutation is computed once and
# applied to both arrays so images and labels are guaranteed to stay aligned.
sort_order = np.argsort(y_train)
x_train_sorted = x_train[sort_order]
y_train_sorted = y_train[sort_order]
# `counts[k]` is the number of training samples of the k-th smallest label.
_, counts = np.unique(y_train_sorted, return_counts=True)

In [9]:
# Build the classifier with `output_shape` set to the number of classes and the
# remaining constructor arguments taken from the "test_model" YAML entry.
model = ClassifierModel(output_shape=n_classes, **models_hyperparameters["test_model"])
model.compile(optimizer="adam", loss=loss, metrics=["accuracy"])
# Call the model once on a single-sample batch; presumably this creates its
# variables so that weights can be loaded afterwards — TODO confirm.
model(x_train[0:1])
# NOTE(review): looks like a one-off step to (re)create the checkpoint loaded below — confirm.
# model.save_weights("../models/default_model.weights.h5")
model.load_weights("../models/default_model.weights.h5")
# Snapshot of the loaded weights; the experiment loops pass it to
# `model.set_weights` to reset the model at the start of each trial.
model_weights = model.get_weights()

In [23]:
# Curriculum experiment: for each trial, train N_EPOCHS single-epoch rounds on
# a growing, model-selected subset of the training data, then fine-tune on the
# full training set with early stopping and record the test accuracy.
model_scores = []
verbose = 0

for _ in tqdm(range(N_TRIALS)):
    # Every trial restarts from the same saved initial weights.
    model.set_weights(model_weights)

    # Disabled experiment (kept for reference):
    # for j in range(len(model.layers[:-1])):
    #     model.layers[j].layers[-1].rate = 0.1

    for i in range(N_EPOCHS):
        # Subset size follows tanh(4 * (i + 1) / N_EPOCHS) of the training
        # set, i.e. it grows each epoch and reaches tanh(4) ~ 99.9% at the end.
        n_samples = int(np.tanh(4 * (i + 1) / N_EPOCHS) * train_size)

        # Per-sample selection probabilities are recomputed from the current
        # model every epoch, then used to pick this epoch's training subset.
        samples_proba = utils.calculate_proba(
            model, x_train_sorted, y_train_sorted, counts
        )
        samples_ids = utils.chose_samples(n_samples, samples_proba, CONFIG["order_type"])

        model.fit(
            x_train_sorted[samples_ids],
            y_train_sorted[samples_ids],
            # validation_data=(x_val, y_val),
            epochs=1,
            batch_size=BATCH_SIZE,
            verbose=verbose,
        )

        # Disabled experiment (kept for reference):
        # for j in range(len(model.layers[:-1])):
        #     model.layers[j].layers[-1].rate += 0.003

    # Fine-tune on the full training set. The callback is created per trial
    # and is the one actually passed to `fit`.
    # NOTE(review): the random-baseline cell uses start_from_epoch=35 —
    # confirm the asymmetry with the value used here is intended.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        restore_best_weights=True,
        patience=5,
        start_from_epoch=10,
    )
    model.fit(
        x_train,
        y_train,
        validation_data=(x_val, y_val),
        epochs=500,
        batch_size=BATCH_SIZE,
        callbacks=[early_stopping],
        verbose=verbose,
    )

    _, accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=verbose)
    model_scores.append(accuracy)
    # Running summary over the trials completed so far.
    print("Mean:", np.mean(model_scores), " Median: ", np.median(model_scores))

 10%|█         | 1/10 [02:23<21:33, 143.72s/it]

Mean: 0.5692307949066162  Median:  0.5692307949066162


 20%|██        | 2/10 [04:53<19:36, 147.06s/it]

Mean: 0.5778846144676208  Median:  0.5778846144676208


 30%|███       | 3/10 [07:27<17:33, 150.53s/it]

Mean: 0.5769230723381042  Median:  0.574999988079071


 40%|████      | 4/10 [09:55<14:57, 149.54s/it]

Mean: 0.577692300081253  Median:  0.5774999856948853


 50%|█████     | 5/10 [12:19<12:16, 147.32s/it]

Mean: 0.5795384526252747  Median:  0.5799999833106995


 50%|█████     | 5/10 [12:35<12:35, 151.19s/it]


KeyboardInterrupt: 

In [20]:
# Display the per-trial test accuracies collected so far by the curriculum loop.
model_scores

[0.5646153688430786,
 0.5557692050933838,
 0.5773077011108398,
 0.5846154093742371]

In [15]:
# Display the per-trial test accuracies collected so far by the curriculum loop.
model_scores

[0.5776923298835754, 0.579230785369873, 0.5892307758331299, 0.5807692408561707]

In [22]:
# Mean and median test accuracy over the curriculum trials currently in `model_scores`.
print("Mean:", np.mean(model_scores), " Median: ", np.median(model_scores))

Mean: 0.5750961601734161  Median:  0.577115386724472


In [18]:
# Mean and median test accuracy over the curriculum trials currently in `model_scores`.
print("Mean:", np.mean(model_scores), " Median: ", np.median(model_scores))

Mean: 0.5817307829856873  Median:  0.5800000131130219


In [10]:
# Baseline experiment: plain training on the full training set (no curriculum
# phase), repeated 50 times from the same initial weights.
model_scores_random = []
verbose = 0

for _ in tqdm(range(50)):
    # Reset to the shared initial weights so that runs do not build on each other.
    model.set_weights(model_weights)

    # Stop on stalled validation loss and roll back to the best weights seen.
    stop_on_val_loss = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
        start_from_epoch=35,
    )
    model.fit(
        x_train,
        y_train,
        validation_data=(x_val, y_val),
        batch_size=BATCH_SIZE,
        epochs=500,
        callbacks=[stop_on_val_loss],
        verbose=verbose,
    )

    _, accuracy = model.evaluate(
        x_test,
        y_test,
        batch_size=BATCH_SIZE,
        verbose=verbose,
    )
    model_scores_random.append(accuracy)

    # Running summary over the baseline runs completed so far.
    running_mean = np.mean(model_scores_random)
    running_median = np.median(model_scores_random)
    print("Mean:", running_mean, " Median: ", running_median)

  2%|▏         | 1/50 [00:42<34:22, 42.09s/it]

Mean: 0.5819230675697327  Median:  0.5819230675697327


  4%|▍         | 2/50 [01:19<31:33, 39.44s/it]

Mean: 0.5692307651042938  Median:  0.5692307651042938


  6%|▌         | 3/50 [01:56<30:05, 38.41s/it]

Mean: 0.567307690779368  Median:  0.5634615421295166


  8%|▊         | 4/50 [02:38<30:32, 39.84s/it]

Mean: 0.5689423084259033  Median:  0.5686538517475128


 10%|█         | 5/50 [03:17<29:36, 39.48s/it]

Mean: 0.5682307720184326  Median:  0.5653846263885498


 12%|█▏        | 6/50 [03:57<29:00, 39.56s/it]

Mean: 0.5676923096179962  Median:  0.565192312002182


 14%|█▍        | 7/50 [04:41<29:25, 41.07s/it]

Mean: 0.5670879142624992  Median:  0.5649999976158142


 16%|█▌        | 8/50 [05:29<30:10, 43.10s/it]

Mean: 0.5679326951503754  Median:  0.565192312002182


 18%|█▊        | 9/50 [06:07<28:23, 41.56s/it]

Mean: 0.5690598289171854  Median:  0.5653846263885498


 20%|██        | 10/50 [06:45<26:59, 40.48s/it]

Mean: 0.5690384626388549  Median:  0.5671153962612152


 22%|██▏       | 11/50 [07:27<26:37, 40.95s/it]

Mean: 0.5711888129060919  Median:  0.5688461661338806


 24%|██▍       | 12/50 [08:07<25:44, 40.65s/it]

Mean: 0.5699679503838221  Median:  0.5671153962612152


 26%|██▌       | 13/50 [08:43<24:18, 39.41s/it]

Mean: 0.56937870154014  Median:  0.5653846263885498


 28%|██▊       | 14/50 [09:25<23:58, 39.96s/it]

Mean: 0.5696153896195548  Median:  0.5671153962612152


 30%|███       | 15/50 [10:00<22:34, 38.69s/it]

Mean: 0.5687179525693258  Median:  0.5653846263885498


 32%|███▏      | 16/50 [10:33<20:50, 36.79s/it]

Mean: 0.5668990425765514  Median:  0.565192312002182


 34%|███▍      | 17/50 [11:16<21:15, 38.65s/it]

Mean: 0.5673981940045076  Median:  0.5653846263885498


 36%|███▌      | 18/50 [11:52<20:12, 37.90s/it]

Mean: 0.5671367545922598  Median:  0.565192312002182


 38%|███▊      | 19/50 [12:26<18:56, 36.65s/it]

Mean: 0.5670242936987626  Median:  0.5649999976158142


 40%|████      | 20/50 [13:03<18:22, 36.74s/it]

Mean: 0.5676153868436813  Median:  0.565192312002182


 42%|████▏     | 21/50 [13:42<18:06, 37.48s/it]

Mean: 0.5676373640696207  Median:  0.5653846263885498


 44%|████▍     | 22/50 [14:20<17:36, 37.75s/it]

Mean: 0.5676223798231645  Median:  0.5663461685180664


 46%|████▌     | 23/50 [14:59<17:07, 38.06s/it]

Mean: 0.5676755879236304  Median:  0.567307710647583


 48%|████▊     | 24/50 [15:32<15:52, 36.64s/it]

Mean: 0.5670352578163147  Median:  0.5663461685180664


 50%|█████     | 25/50 [16:12<15:37, 37.51s/it]

Mean: 0.5672769236564636  Median:  0.567307710647583


 52%|█████▏    | 26/50 [17:00<16:20, 40.87s/it]

Mean: 0.5678994082487546  Median:  0.5676923096179962


 54%|█████▍    | 27/50 [17:37<15:07, 39.48s/it]

Mean: 0.567606837661178  Median:  0.567307710647583


 56%|█████▌    | 28/50 [18:11<13:51, 37.81s/it]

Mean: 0.5673489017145974  Median:  0.5663461685180664


 58%|█████▊    | 29/50 [18:56<14:03, 40.16s/it]

Mean: 0.5677320957183838  Median:  0.567307710647583


 60%|██████    | 30/50 [19:33<13:00, 39.02s/it]

Mean: 0.5681410253047943  Median:  0.5676923096179962


 62%|██████▏   | 31/50 [20:10<12:14, 38.68s/it]

Mean: 0.5680645165904876  Median:  0.567307710647583


 64%|██████▍   | 32/50 [20:44<11:07, 37.08s/it]

Mean: 0.5677884612232447  Median:  0.5665384829044342


 66%|██████▌   | 33/50 [21:17<10:07, 35.76s/it]

Mean: 0.5673310016140793  Median:  0.5657692551612854


 68%|██████▊   | 34/50 [21:52<09:28, 35.53s/it]

Mean: 0.5672285065931433  Median:  0.5655769407749176


 70%|███████   | 35/50 [22:35<09:30, 38.04s/it]

Mean: 0.5673736265727451  Median:  0.5657692551612854


 72%|███████▏  | 36/50 [23:15<08:57, 38.40s/it]

Mean: 0.5676923079623116  Median:  0.5665384829044342


 74%|███████▍  | 37/50 [23:51<08:11, 37.80s/it]

Mean: 0.5674324325613074  Median:  0.5657692551612854


 76%|███████▌  | 38/50 [24:27<07:28, 37.38s/it]

Mean: 0.5674898781274494  Median:  0.5665384829044342


 78%|███████▊  | 39/50 [25:03<06:45, 36.90s/it]

Mean: 0.5673865874608358  Median:  0.5657692551612854


 80%|████████  | 40/50 [25:37<05:58, 35.82s/it]

Mean: 0.5673653841018677  Median:  0.5661538541316986


 82%|████████▏ | 41/50 [26:10<05:17, 35.24s/it]

Mean: 0.5667823640311637  Median:  0.5657692551612854


 84%|████████▍ | 42/50 [26:43<04:35, 34.39s/it]

Mean: 0.5665934071654365  Median:  0.5655769407749176


 86%|████████▌ | 43/50 [27:15<03:55, 33.64s/it]

Mean: 0.5659123437349186  Median:  0.5653846263885498


 88%|████████▊ | 44/50 [27:54<03:31, 35.20s/it]

Mean: 0.5662412589246576  Median:  0.5655769407749176


 90%|█████████ | 45/50 [28:33<03:02, 36.50s/it]

Mean: 0.5661794874403212  Median:  0.5653846263885498


 92%|█████████▏| 46/50 [29:11<02:27, 36.85s/it]

Mean: 0.5663963219393855  Median:  0.5655769407749176


 94%|█████████▍| 47/50 [29:54<01:56, 38.89s/it]

Mean: 0.5665957458475803  Median:  0.5657692551612854


 96%|█████████▌| 48/50 [30:33<01:17, 38.75s/it]

Mean: 0.566834936539332  Median:  0.5661538541316986


 98%|█████████▊| 49/50 [31:06<00:37, 37.00s/it]

Mean: 0.5666954480871862  Median:  0.5657692551612854


100%|██████████| 50/50 [31:45<00:00, 38.10s/it]

Mean: 0.5667923080921173  Median:  0.5661538541316986





In [11]:
# Display the per-run test accuracies of the random (no-curriculum) baseline.
model_scores_random

[0.5819230675697327,
 0.556538462638855,
 0.5634615421295166,
 0.573846161365509,
 0.5653846263885498,
 0.5649999976158142,
 0.5634615421295166,
 0.573846161365509,
 0.5780768990516663,
 0.5688461661338806,
 0.5926923155784607,
 0.556538462638855,
 0.5623077154159546,
 0.572692334651947,
 0.5561538338661194,
 0.5396153926849365,
 0.5753846168518066,
 0.5626922845840454,
 0.5649999976158142,
 0.5788461565971375,
 0.5680769085884094,
 0.567307710647583,
 0.5688461661338806,
 0.552307665348053,
 0.5730769038200378,
 0.5834615230560303,
 0.5600000023841858,
 0.5603846311569214,
 0.5784615278244019,
 0.5799999833106995,
 0.5657692551612854,
 0.5592307448387146,
 0.5526922941207886,
 0.5638461709022522,
 0.5723077058792114,
 0.5788461565971375,
 0.5580769181251526,
 0.569615364074707,
 0.5634615421295166,
 0.5665384531021118,
 0.5434615612030029,
 0.5588461756706238,
 0.5373076796531677,
 0.5803846120834351,
 0.5634615421295166,
 0.5761538743972778,
 0.5757692456245422,
 0.5780768990516663,


In [17]:
# Mean test accuracy: curriculum trials first, random baseline second.
np.mean(model_scores), np.mean(model_scores_random)

(0.5817307829856873, 0.5667923080921173)

In [16]:
import scipy.stats
# Two-sample t-test comparing curriculum accuracies against baseline accuracies.
# NOTE(review): called with ttest_ind's defaults; the two score lists may have
# very different lengths, so check whether equal_var=False (Welch's test) is
# the more appropriate choice — confirm against the SciPy documentation.
scipy.stats.ttest_ind(model_scores, model_scores_random)

TtestResult(statistic=2.654683049658429, pvalue=0.010509567980683561, df=52.0)