In [2]:
import tensorflow as tf

from tensorflow.keras import datasets, layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import matplotlib.pyplot as plt
# Keras already places ops on an available GPU automatically. A bare
# `tf.device(...)` call only builds a context manager and has no effect unless
# used in a `with` block, so instead report which GPUs TensorFlow can see
# (an empty list means the Colab runtime has no GPU enabled).
print('GPUs:', tf.config.list_physical_devices('GPU'))

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
EPOCHS = 10

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
# Scale pixels to [0, 1]. Casting to float32 first (the dtype Keras trains in)
# avoids the implicit float64 promotion, halving the memory of both arrays.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [3]:
from google.colab import drive

# Mount Google Drive so the MyDrive checkpoint directory used below is reachable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
# List the files currently stored in the Drive checkpoint directory that the
# training cells below write to.
!ls /content/drive/MyDrive/training_checkpoints

checkpoint
checkpoint_airplane_automobile.data-00000-of-00001
checkpoint_airplane_automobile.index
checkpoint_airplane_bird.data-00000-of-00001
checkpoint_airplane_bird.index
checkpoint_airplane_cat.data-00000-of-00001
checkpoint_airplane_cat.index
checkpoint_airplane_deer.data-00000-of-00001
checkpoint_airplane_deer.index
checkpoint_airplane_dog.data-00000-of-00001
checkpoint_airplane_dog.index
checkpoint_airplane_frog.data-00000-of-00001
checkpoint_airplane_frog.index
checkpoint_airplane_horse.data-00000-of-00001
checkpoint_airplane_horse.index
checkpoint_airplane_ship.data-00000-of-00001
checkpoint_airplane_ship.index
checkpoint_airplane_truck.data-00000-of-00001
checkpoint_airplane_truck.index
checkpoint_correct.data-00000-of-00001
checkpoint_correct.index
tmp


In [4]:
import numpy as np

def introduce_confusion_on_one_class(train_labels, class_to_error, cur_error_rate=1, num_classes=10):
    """Replace labels of one class with uniformly random labels.

    Args:
        train_labels: integer label array of any shape (e.g. ``(N, 1)``); it is
            flattened into a new array and left unmodified.
        class_to_error: the class whose samples may be relabelled.
        cur_error_rate: probability that each sample of ``class_to_error`` is
            relabelled (the default 1 relabels every such sample).
        num_classes: random labels are drawn from ``range(num_classes)``.

    Returns:
        A new 1-D label array with ``train_labels.size`` entries.

    Note:
        A random label can coincide with the original one, so the fraction of
        ``class_to_error`` labels that actually change is roughly
        ``cur_error_rate * (1 - 1 / num_classes)``.
    """
    labels = np.asarray(train_labels).flatten()
    n = labels.shape[0]
    # Keep the draw order (uniform mask first, then labels) so that results
    # under a fixed np.random seed match earlier runs.
    should_error = np.random.rand(n) < cur_error_rate
    should_error = np.logical_and(should_error, labels == class_to_error)
    random_labels = np.random.randint(num_classes, size=n)
    return np.where(should_error, random_labels, labels)
# new_labels = introduce_confusion_on_one_class(train_labels, 1)
# plt.hist(new_labels, bins = 20)
# plt.show()

def introduce_confusion_between_two_classes(train_labels, c1, c2, cur_error_rate=1):
    """Swap the labels of classes ``c1`` and ``c2``.

    Each sample labelled ``c1`` or ``c2`` is, with probability
    ``cur_error_rate``, relabelled as the other class of the pair. All other
    samples keep their label.

    Args:
        train_labels: integer label array of any shape (e.g. ``(N, 1)``); it is
            flattened into a new array and left unmodified.
        c1, c2: the two classes to confuse with each other.
        cur_error_rate: per-sample swap probability (the default 1 swaps every
            sample of the pair, making the result deterministic).

    Returns:
        A new 1-D label array with ``train_labels.size`` entries.
    """
    labels = np.asarray(train_labels).flatten()
    # Always draw the mask first so the np.random stream is consumed exactly
    # as before, even for the c1 == c2 shortcut below.
    should_error = np.random.rand(labels.shape[0]) < cur_error_rate
    if c1 == c2:
        # Swapping a class with itself is the identity. The previous arithmetic
        # swap produced the out-of-range label 2 * c1 in this case.
        return labels
    in_pair = np.logical_or(labels == c1, labels == c2)
    # Only meaningful where in_pair is True: c1 -> c2, c2 -> c1.
    swapped = np.where(labels == c1, c2, c1)
    return np.where(np.logical_and(should_error, in_pair), swapped, labels)

# new_labels = introduce_confusion_between_two_classes(train_labels, 1, 2)
# plt.hist(new_labels, bins = 20)
# plt.show()
#print(new_labels[:10])

def build_and_compile_new_model(checkpoint_location):
    """Build and compile a fresh small CNN for 32x32x3 images.

    Args:
        checkpoint_location: file name, relative to the Drive checkpoint
            directory, that the returned callback writes weights to.

    Returns:
        ``(model, model_checkpoint_callback)``. The model outputs 10 raw logits.
        The callback saves only the weights with the best ``val_accuracy``
        seen so far.
    """
    # Three Conv2D layers (the first two followed by max-pooling), then a small
    # dense head. The last layer has no activation because the loss below is
    # configured with from_logits=True.
    network = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])

    checkpoint_path = '/content/drive/MyDrive/training_checkpoints/' + checkpoint_location
    best_weights_saver = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True,
    )

    network.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return network, best_weights_saver


In [9]:
import os
touch_prefix = '/content/drive/MyDrive/training_checkpoints/touch'
def touch(point):
    """Create the marker file ``touch_prefix + point``, or refresh its timestamp.

    Opening in append mode creates a missing file without truncating an
    existing one. ``os.utime`` then sets its access/modify times to now.
    """
    marker_path = touch_prefix + point
    with open(marker_path, 'a'):
        os.utime(marker_path, times=None)

def exists(point):
    """Return True if the marker path ``touch_prefix + point`` exists."""
    marker_path = touch_prefix + point
    return os.path.exists(marker_path)

def train_experiment(name, train_labels, silent=False, load=False):
    """Train (or skip, if already done) one experiment and return its model.

    Training is skipped when the completion marker for ``name`` already exists,
    so re-running the notebook does not retrain finished experiments.

    Args:
        name: checkpoint file name inside the Drive checkpoint directory. It is
            also the suffix of the completion marker.
        train_labels: labels for the global ``train_images`` (possibly corrupted).
        silent: if True, suppress Keras progress output during ``fit``.
        load: if True, restore the best-``val_accuracy`` checkpoint into the
            model before returning it.

    Returns:
        The compiled Keras model. With ``load=False`` it holds the final-epoch
        weights after a fresh fit. NOTE: if training was skipped and ``load`` is
        False, the returned model has freshly initialised (untrained) weights.
    """
    model, model_checkpoint_callback = build_and_compile_new_model(name)

    if not exists(name):
        model.fit(train_images, train_labels, epochs=EPOCHS,
                  validation_data=(test_images, test_labels),
                  callbacks=[model_checkpoint_callback],
                  verbose=0 if silent else 1)
        # Mark completion as soon as training (and checkpointing) has finished,
        # so a later failure, e.g. in load_weights, does not force a retrain.
        touch(name)

    if load:
        # Read from the callback's own path so loading can never drift from
        # where the checkpoint was actually written.
        model.load_weights(model_checkpoint_callback.filepath)
    return model

In [17]:
# Baseline: train on the clean (uncorrupted) labels; the model is displayed below.
baseline_model = train_experiment('checkpoint_correct', train_labels)
baseline_model

<tensorflow.python.keras.engine.sequential.Sequential at 0x7fb700a0ac88>

In [27]:
import itertools, tqdm

# Train one label-confused model per class pair (only the first 15 pairs).
# enumerate() yields each class's index alongside its name, matching class_names.index().
class_pairs = list(itertools.combinations(enumerate(class_names), 2))[:15]
for (first_class, fst), (second_class, snd) in tqdm.tqdm(class_pairs):
    confused_train_labels = introduce_confusion_between_two_classes(train_labels, first_class, second_class)
    train_experiment(f'checkpoint_{fst}_{snd}', confused_train_labels, silent=True)

 13%|█▎        | 2/15 [01:36<10:25, 48.11s/it]



100%|██████████| 15/15 [12:08<00:00, 48.58s/it]


In [12]:
import itertools, tqdm

# For every class pair, train N_REPEATS freshly initialised models on labels
# where the two classes are swapped.
N_REPEATS = 50
for (first_class, fst), (second_class, snd) in tqdm.tqdm(list(itertools.combinations(enumerate(class_names), 2))):
    # With the default cur_error_rate=1 the swap is deterministic, so the
    # corrupted labels are computed once per pair rather than once per repeat.
    # Move this inside the inner loop if a partial error rate is used.
    confused_train_labels = introduce_confusion_between_two_classes(train_labels, first_class, second_class)
    for i in range(N_REPEATS):
        train_experiment(f'checkpoint_{fst}_{snd}_{i}', confused_train_labels, silent=True, load=False)


  0%|          | 0/45 [00:00<?, ?it/s][A
  2%|▏         | 1/45 [00:02<01:57,  2.67s/it][A
  4%|▍         | 2/45 [00:05<01:54,  2.67s/it][A
  7%|▋         | 3/45 [00:08<01:52,  2.68s/it][A
  9%|▉         | 4/45 [00:10<01:49,  2.68s/it][A
 11%|█         | 5/45 [00:13<01:47,  2.70s/it][A
 13%|█▎        | 6/45 [00:16<01:52,  2.89s/it][A
 16%|█▌        | 7/45 [00:19<01:47,  2.82s/it][A
 18%|█▊        | 8/45 [00:22<01:42,  2.77s/it][A
 20%|██        | 9/45 [00:24<01:38,  2.73s/it][A
 22%|██▏       | 10/45 [00:27<01:35,  2.72s/it][A
 24%|██▍       | 11/45 [00:30<01:32,  2.71s/it][A
 27%|██▋       | 12/45 [00:32<01:28,  2.69s/it][A
 29%|██▉       | 13/45 [00:35<01:25,  2.68s/it][A
 31%|███       | 14/45 [00:38<01:30,  2.91s/it][A
 33%|███▎      | 15/45 [00:41<01:25,  2.84s/it][A
 36%|███▌      | 16/45 [00:44<01:21,  2.80s/it][A
 38%|███▊      | 17/45 [00:46<01:17,  2.78s/it][A
 40%|████      | 18/45 [00:49<01:13,  2.74s/it][A
 42%|████▏     | 19/45 [00:52<01:10,  2.71s/it]