## Handling Unbalanced Datasets using Class Weights in Keras

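We will work with the MNIST dataset to illustrate a simple strategy for handling unbalanced datasets: weighting each class's contribution to the loss so that errors on rarer classes count more than errors on common ones.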

### Let's import the required libraries

In [65]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense
from tensorflow.keras.utils import to_categorical
from sklearn.utils import class_weight
import numpy as np

### Let's define a function to load the MNIST dataset

In [66]:
def load_datasets():
    """Load the MNIST digits dataset through Keras.

    Returns:
        A flat 4-tuple ``(X_train, Y_train, X_test, Y_test)``, unpacked from
        the ``((train_images, train_labels), (test_images, test_labels))``
        pair returned by ``mnist.load_data()``.
    """
    train_split, test_split = mnist.load_data()
    images_train, labels_train = train_split
    images_test, labels_test = test_split
    return images_train, labels_train, images_test, labels_test

### Define a function to normalize the image data

In [67]:
def normalize_data(trainX, testX):
    """Reshape images to ``(n, 28, 28, 1)`` and scale pixel values by 1/255.

    Args:
        trainX: training images; each sample must hold ``28 * 28`` values
            (e.g. shape ``(n, 28, 28)`` or ``(n, 784)``).
        testX: test images with the same per-sample layout as ``trainX``.

    Returns:
        ``(X_train_norm, X_test_norm)``: float32 arrays of shape
        ``(n, 28, 28, 1)`` whose values are the inputs divided by 255
        (i.e. in [0, 1] assuming 8-bit pixel values).
    """
    def _prepare(images):
        # Add a trailing single-channel axis to match the CNN's (28, 28, 1) input.
        with_channel = images.reshape((images.shape[0], 28, 28, 1))
        # Cast first so the division stays in float32 instead of float64.
        return with_channel.astype('float32') / 255.0

    return _prepare(trainX), _prepare(testX)

### Let's define a function to convert the integer labels to one-hot encoding

In [68]:
def transform_Y(Y_train, Y_test, num_classes=None):
    """Convert integer class labels to one-hot encoded matrices.

    Args:
        Y_train: integer labels for the training set.
        Y_test: integer labels for the test set.
        num_classes: width of the one-hot encoding. When ``None`` (default)
            it is inferred as ``max(label) + 1`` over BOTH label sets, so the
            train and test encodings always have the same number of columns.

    Returns:
        ``(Y_train_transform, Y_test_transform)``: the one-hot encoded labels
        produced by ``to_categorical``.
    """
    if num_classes is None:
        # Inferring the width separately per split produces mismatched shapes
        # whenever the highest label is absent from one split (e.g. a small or
        # heavily unbalanced test set), so infer it from both splits together.
        all_labels = np.concatenate([np.ravel(Y_train), np.ravel(Y_test)])
        num_classes = int(all_labels.max()) + 1

    Y_train_transform = to_categorical(Y_train, num_classes=num_classes)
    Y_test_transform = to_categorical(Y_test, num_classes=num_classes)

    return Y_train_transform, Y_test_transform

### Here we define a very simple CNN for the digit classification task

In [69]:
# A simple CNN for MNIST digit classification (shape and class count are configurable).
def build_model(input_shape=(28, 28, 1), num_classes=10):
    """Build a small, uncompiled CNN image classifier.

    Architecture:
    - three 3x3 ReLU convolutions with 32, 64 and 64 filters;
    - the first two convolutions are each followed by 2x2 max-pooling
      (stride 2) and 50% dropout;
    - the output is flattened, passed through a 128-unit ReLU dense layer
      with 50% dropout, and ends in a softmax layer.

    Args:
        input_shape: shape of a single input image, channels last.
            Defaults to MNIST's ``(28, 28, 1)``.
        num_classes: number of units in the softmax output layer.
            Defaults to 10 (digits 0-9).

    Returns:
        A ``Sequential`` model; the caller must ``compile()`` it before training.
    """
    model = Sequential()

    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
    model.add(MaxPooling2D(2, 2))
    model.add(Dropout(0.5))

    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(2, 2))
    model.add(Dropout(0.5))

    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Flatten())

    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))

    # One probability per class; must match the one-hot label width.
    model.add(Dense(num_classes, activation='softmax'))

    return model

### We use scikit-learn's `compute_class_weight` to compute balanced class weights

In [70]:
def get_class_weights(Y_train):
    """Compute 'balanced' per-class loss weights for Keras ``fit(class_weight=...)``.

    Uses scikit-learn's ``compute_class_weight`` with the ``'balanced'``
    strategy, which gives each class ``n_samples / (n_classes * count(class))``.
    Rarer classes therefore receive larger weights.

    Args:
        Y_train: 1-D array of integer class labels (not one-hot).

    Returns:
        dict mapping each label present in ``Y_train`` (as a Python ``int``)
        to its weight (as a Python ``float``).
    """
    classes = np.unique(Y_train)
    # Keyword arguments are required: `classes` and `y` are keyword-only in
    # scikit-learn >= 1.0 (positional use was deprecated in 0.24).
    weights = class_weight.compute_class_weight(
        class_weight='balanced', classes=classes, y=Y_train)

    # Pair each weight with its actual label rather than assuming labels 0..9.
    # NOTE(review): assumes integer labels, as Keras class_weight keys must be
    # class indices.
    return {int(label): float(weight) for label, weight in zip(classes, weights)}

### Let's load and preprocess the data

In [71]:
# Load the raw MNIST train/test splits.
X_train, Y_train, X_test, Y_test = load_datasets()

# Reshape the images to (n, 28, 28, 1) float32 and divide pixel values by 255.
X_train_norm, X_test_norm = normalize_data(X_train, X_test)

# One-hot encode the integer labels for categorical cross-entropy.
# Y_train stays as integer labels; it is reused later for the class weights.
Y_train_transform, Y_test_transform = transform_Y(Y_train, Y_test)

## First, let's train the model without class weights and evaluate it on the test data

In [72]:
# Baseline: a fresh CNN that will be trained without class weights.
model_without_class_weight = build_model()

In [73]:
# Compile the baseline: RMSprop, categorical cross-entropy on the one-hot
# labels, and accuracy reported as 'acc'.
model_without_class_weight.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['acc'],
)

In [74]:
# Train the baseline without class weights, holding out 10% of the training
# data (6,000 of 60,000 samples) for validation.
history = model_without_class_weight.fit(
    X_train_norm,
    Y_train_transform,
    epochs=10,
    batch_size=56,
    validation_split=0.1,
)

Train on 54000 samples, validate on 6000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


### Let's evaluate the model without class weights

In [75]:
# Returns [test loss, test accuracy] ('acc' is the metric set in compile()).
model_without_class_weight.evaluate(X_test_norm, Y_test_transform)



[0.25026279188394546, 0.9727]

## Now let's train the model with class weights and evaluate it on the test data again

In [76]:
# Same architecture as the baseline; this copy will be trained with class weights.
model_with_class_weight = build_model()

In [77]:
# Compile with the same optimizer, loss and metric as the baseline so that
# only the class weighting differs between the two runs.
model_with_class_weight.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['acc'],
)

In [78]:
# Per-class loss weights computed from the integer training labels (not the
# one-hot matrix).
class_weights = get_class_weights(Y_train)

In [79]:
# Train WITH class weights: each sample's loss is scaled by the weight of its
# class. The comment previously said "without", which contradicted the
# class_weight argument. Epochs, batch size and validation split match the
# baseline run.
history = model_with_class_weight.fit(
    X_train_norm,
    Y_train_transform,
    epochs=10,
    batch_size=56,
    validation_split=0.1,
    class_weight=class_weights,
)

Train on 54000 samples, validate on 6000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


### Let's evaluate the model with class weights

In [80]:
# Returns [test loss, test accuracy] on the same test set used for the baseline.
model_with_class_weight.evaluate(X_test_norm, Y_test_transform)



[0.04339111114250263, 0.9866]

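### See the improvement... :D

**Caveat:** MNIST is only mildly unbalanced (every digit has roughly 5,400–6,700 training images), so the balanced class weights are all close to 1. Each model was trained once, with no fixed random seed. Part of this accuracy gap may therefore come from run-to-run variation (weight initialization, shuffling, dropout) rather than from the class weights themselves. Averaging several seeded runs of each configuration, or using a genuinely unbalanced dataset, would make the comparison more convincing.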

| Model Type | Test Accuracy |
|---------------------|----------------------|
| Model without class weight | 0.9727 |
| Model with class weight | 0.9866 |