implementation of _Co-teaching: Robust Training of Deep Neural Networks with Extremely Noisy Labels_, Bo Han et al., 2018, https://arxiv.org/abs/1804.06872

In [1]:
from keras.models import Sequential
from keras.layers import (Conv2D, Dense, MaxPooling2D, GlobalAveragePooling2D,
                          Dropout, BatchNormalization, Input, LeakyReLU,
                          AveragePooling2D)
from keras.datasets import cifar10, mnist
from keras.utils import to_categorical
from sklearn.utils import shuffle
import sys
import numpy as np

Using TensorFlow backend.


In [2]:
np.seterr('raise')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [3]:
def make_noisy_labels_simmetry(labels, eps):
    assert len(labels.shape) == 1
    labels = labels.ravel()
    classes = set(labels)
    randoms = np.random.random(labels.shape)
    
    noisy = [
        lbl if r < eps else np.random.choice([x for x in classes if x != lbl])
        for lbl, r in zip(labels, randoms)
    ]
    
    return to_categorical(np.array(noisy).reshape((-1, 1)))

In [4]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [5]:
x_train = x_train.reshape((-1, 28, 28, 1))
x_test = x_test.reshape((-1, 28, 28, 1))

y_train_noisy = make_noisy_labels_simmetry(y_train, 0.45)
y_train = to_categorical(y_train)

y_test_noisy = make_noisy_labels_simmetry(y_test, 0.45)
y_test = to_categorical(y_test)

In [6]:
def make_model():
    model = Sequential([
        Conv2D(4, (3, 3), padding='same', input_shape=(28, 28, 1)), LeakyReLU(0.01), BatchNormalization(),
        Conv2D(4, (3, 3), padding='same'), LeakyReLU(0.01), BatchNormalization(),
        Conv2D(4, (3, 3), padding='same'), LeakyReLU(0.01),
        MaxPooling2D((2, 2), strides=2), BatchNormalization(),
        Conv2D(8, (3, 3), padding='same'), LeakyReLU(0.01), BatchNormalization(),
        Conv2D(8, (3, 3), padding='same'), LeakyReLU(0.01), BatchNormalization(),
        Conv2D(8, (3, 3), padding='same'), LeakyReLU(0.01),
        MaxPooling2D((2, 2)), BatchNormalization(),
        Conv2D(16, (3, 3), padding='same'), LeakyReLU(0.01), BatchNormalization(),
        Conv2D(16, (3, 3), padding='same'), LeakyReLU(0.01), BatchNormalization(),
        Conv2D(16, (3, 3), padding='same'), LeakyReLU(0.01),
        GlobalAveragePooling2D(), BatchNormalization(),
        Dense(12), LeakyReLU(0.01), BatchNormalization(),
        Dense(10, activation='softmax'),
    ])
    
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    
    return model

In [9]:
class CoTeach:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    @staticmethod
    def samplewise_categorical_crossentropy(y_true, y_pred):
        return -np.log(y_pred[y_true == 1] + 1e-12)
    
    def get_clean_instances(self, batch_x, batch_y, model, r):
        ''' returns the 100*r % samples with the lowest loss
        '''
        preds = model.predict(batch_x)
        losses = self.samplewise_categorical_crossentropy(batch_y, preds)
        threshold = np.quantile(losses, r)
        mask = losses < threshold
        return batch_x[mask], batch_y[mask]
    
    def train_on_batch(self, epoch, batch_x, batch_y, models):
        ''' for every model, get the clean samples and train the other models on them
        '''
        r = 1 - min(self.tau * epoch / self.tk, self.tau)

        losses, accs = np.zeros(len(models)), np.zeros(len(models))
        for i, m in enumerate(models):
            bx, by = self.get_clean_instances(batch_x, batch_y, m, r)
            
            for j, m2 in enumerate(models):
                if i != j:
                    loss, acc = m2.train_on_batch(bx, by)
                    losses[j] += loss
                    accs[i] += acc
        
        return losses / len(models), accs / len(models), r

    def train(self, X, y, models, validation_data=None):
        N = X.shape[0]
        
        epoch_losses, epoch_accs = [], []
        for epoch in range(self.n_epochs):

            X, y = shuffle(X, y)
            for batch_start in range(0, N ,self.batch_size):
                losses, accs, r = self.train_on_batch(
                    epoch,
                    X[batch_start:batch_start+self.batch_size],
                    y[batch_start:batch_start+self.batch_size],
                    models
                )
                
                message = 'Epoch %d / %d: %d / %d | R: %.4f - Avg. loss: %.4f - Avg. acc.: %.4f' % (
                    epoch + 1, self.n_epochs, batch_start, N, r, np.mean(losses), np.mean(accs)
                )
                print('\r' + message, end='', flush=True)  # does not work if \r is at the end

            # score on validation set
            if not validation_data:
                continue

            x_test, y_test = validation_data
            loss_sum = acc_sum = 0.0
            for m in models:
                loss, acc = m.evaluate(x_test, y_test, verbose=0)
                loss_sum += loss
                acc_sum += acc

            print(' | Avg. val. loss: %.4f - Avg. val. acc.: %.4f' % (
                loss_sum / len(models), acc_sum / len(models)
            ))


CoTeach(
    tau=0.45,
    tk=5,
    batch_size=32,
    n_epochs=10,
).train(
    x_train[:1024],
    y_train_noisy[:1024],
    models=[make_model(), make_model()],
    validation_data=(x_test[:128], y_test_noisy[:128])
)

Epoch 1 / 10: 992 / 1024 | R: 1.0000 - Avg. loss: 1.2059 - Avg. acc.: 0.0726 | Avg. val. loss: 2.4611 - Avg. val. acc.: 0.1367
Epoch 2 / 10: 992 / 1024 | R: 0.9100 - Avg. loss: 1.1265 - Avg. acc.: 0.1121 | Avg. val. loss: 2.3970 - Avg. val. acc.: 0.1562
Epoch 3 / 10: 992 / 1024 | R: 0.8200 - Avg. loss: 0.9445 - Avg. acc.: 0.2308 | Avg. val. loss: 2.3146 - Avg. val. acc.: 0.1914
Epoch 4 / 10: 992 / 1024 | R: 0.7300 - Avg. loss: 0.9093 - Avg. acc.: 0.1957 | Avg. val. loss: 2.3213 - Avg. val. acc.: 0.2305
Epoch 5 / 10: 992 / 1024 | R: 0.6400 - Avg. loss: 0.7239 - Avg. acc.: 0.2875 | Avg. val. loss: 2.2955 - Avg. val. acc.: 0.2148
Epoch 6 / 10: 992 / 1024 | R: 0.5500 - Avg. loss: 0.6975 - Avg. acc.: 0.2778 | Avg. val. loss: 2.2952 - Avg. val. acc.: 0.2500
Epoch 7 / 10: 992 / 1024 | R: 0.5500 - Avg. loss: 0.7003 - Avg. acc.: 0.2500 | Avg. val. loss: 2.3735 - Avg. val. acc.: 0.2188
Epoch 8 / 10: 992 / 1024 | R: 0.5500 - Avg. loss: 0.4440 - Avg. acc.: 0.4306 | Avg. val. loss: 2.3735 - Avg. va