## Exponential Moving Average
t_i = alpha * t_{i-1} + (1 - alpha) * s_i, with a value of alpha = 0.99

In [1]:
import os
os.chdir(os.path.join(os.getcwd(), '..'))

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from src.model import mean_teacher
from keras.applications.resnet50 import ResNet50
import keras.backend as K

Using TensorFlow backend.


In [3]:
config = K.tf.ConfigProto()
config.gpu_options.allow_growth = True
session = K.tf.InteractiveSession(config=config)
K.set_learning_phase(1)

In [4]:
configs = {
    'input_shape': (224,224,3),
    'num_of_classes': 12,
    'lr':1e-3,
    'ratio':0.5
}

In [5]:
mean_teacher_model, student_model, teacher_model = mean_teacher(configs)

Instructions for updating:
Colocations handled automatically by placer.




In [None]:
mean_teacher_weights = mean_teacher_model.layers[-1].get_weights()
teacher_weights = teacher_model.get_weights()
student_weights = student_model.get_weights()

In [None]:
for i in range(len(mean_teacher_weights)):
    if not np.array_equal(mean_teacher_weights[i], student_weights[i]):
        print(np.array_equal(mean_teacher_weights[i], student_weights[i]))

In [None]:
len(mean_teacher_weights), len(student_weights), len(teacher_weights)

In [None]:
temp = student_model.get_weights()

In [None]:
ema(student_model, student_model)

In [None]:
def ema(student_model, teacher_model, alpha = 0.99):
    '''
    Calculates the exponential moving average of the student model weights and updates the teacher model weights\
    
    formula:
    t_i = alpha * t_{i-1} + (1 - alpha) * s_i, with default alpha = 0.99
    t_i = weights of teacher model in current epoch
    s_i = weights of student model in current epoch
    
    '''
    
    student_weights = student_model.get_weights()
    teacher_weights = teacher_model.get_weights()
    
    assert len(student_weights) == len(teacher_weights), 'length of student and teachers weights are not equal Please check. \n Student: {}, \n Teacher:{}'.format(len(student_weights), len(teacher_weights))
    
    new_layers = []
    for i, layers in enumerate(student_weights):
        new_layer = alpha*(teacher_weights[i]) + (1-alpha)*layers
        new_layers.append(new_layer)
    teacher_model.set_weights(new_layers)
    

# EMA lambda callback

In [6]:
from keras.callbacks import LambdaCallback

In [7]:
ema_callback = LambdaCallback(on_epoch_end=lambda : ema(student_model, teacher_model, alpha = 0.99))

## Loss function
### model architecture

![test](pictures/student_teacher_model_arch.png)

useful links:
- https://keras.io/layers/about-keras-layers/
- https://keras.io/getting-started/functional-api-guide/
- https://github.com/keras-team/keras/blob/master/keras/losses.py
- https://keras.io/getting-started/functional-api-guide/
- https://towardsdatascience.com/advanced-keras-constructing-complex-custom-losses-and-metrics-c07ca130a618
- https://stackoverflow.com/questions/38972380/keras-how-to-use-fit-generator-with-multiple-outputs-of-different-type

In [8]:
from keras.optimizers import Adam
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator

In [None]:
def categorical_crossentropy(y_true, y_pred):
    return K.categorical_crossentropy(y_true, y_pred)

In [None]:
def weighted_sum_loss(squared_difference_layer, ratio = 0.5):

    def categorical_crossentropy_custom(y_true, y_pred):
        return ratio * K.categorical_crossentropy(y_true, y_pred) + (1 - ratio)*squared_difference_layer
    
    return categorical_crossentropy_custom

# Data Generator

useful links:
- https://medium.com/@ensembledme/writing-custom-keras-generators-fe815d992c5a

In [None]:
def get_distribution(data_path, label, color):
    walker = os.walk(data_path)
    next(walker) # skip the first row
    class_freq = dict()
    for r,d,f in walker:
        class_freq[r.split('/')[-1]] = len(f)
    class_freq_df = pd.DataFrame.from_dict(class_freq, orient = 'index', columns = ['count'])
    class_freq_df.reset_index(inplace = True)
    class_freq_df.columns = [label, 'count']
    class_freq_df.sort_values('count', axis = 0, ascending=False, inplace=True)
    
    sns.catplot(x = 'count', y = label, kind = 'bar', data=class_freq_df, color = color)

In [None]:
color_1 = sns.xkcd_rgb['denim blue']
color_2 = sns.xkcd_rgb['dusty purple']

In [9]:
syn_path = os.path.join(os.getcwd(), 'reduced_data', 'synthetic')
real_path = os.path.join(os.getcwd(), 'reduced_data', 'real')

In [None]:
get_distribution(syn_path, 'synthetic', color_1)

In [None]:
get_distribution(real_path, 'real', color_2)

In [12]:
sup_gen = ImageDataGenerator()
unsup_gen = ImageDataGenerator()

sup_data_gen = sup_gen.flow_from_directory(syn_path)

Found 1200 images belonging to 12 classes.


In [13]:
configs = {'batch_size': 32,
          'target_size': (224,224)}

In [14]:
def mean_teacher_data_gen(sup_gen, unsup_gen, batch_size, target_size):
    '''
    
    '''
    
    syn_path = os.path.join(os.getcwd(), 'reduced_data', 'synthetic')
    real_path = os.path.join(os.getcwd(), 'reduced_data', 'real')

    sup_data_gen = sup_gen.flow_from_directory(real_path,
                                               target_size=target_size,
                                                    class_mode='categorical')

    unsup_data_gen = unsup_gen.flow_from_directory(syn_path,
                                                   target_size=target_size,
                                                        class_mode='categorical')
    
    while True:
    
        syn_img, syn_labels = sup_data_gen.next()
        real_img, _ = unsup_data_gen.next()

        yield [syn_img, real_img] , syn_labels

In [15]:
mean_teacher_generator = mean_teacher_data_gen(sup_gen, unsup_gen, configs['batch_size'], configs['target_size'])

In [None]:
inputs, outputs = next(mean_teacher_generator)

In [None]:
syn_img, real_img = inputs

In [None]:
class_dict = dict[(v,k) for (k,v) in sup_data_gen.class_indices]

In [None]:
fig,ax = plt.subplots(1,2)
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[0].imshow(syn_img[0].astype(np.uint8))
ax[1].imshow(real_img[0].astype(np.uint8))
plt.title(class)
plt.show()

In [None]:
for i, syn_i in enumerate(syn_img):
    ax, fig = plt.subplot([2,1])
    

# Training

In [16]:
total_samples = sup_data_gen.n
batch_size = sup_data_gen.batch_size

In [17]:
total_samples//batch_size + 1

38

In [18]:
mean_teacher_model.fit_generator(mean_teacher_generator,
                                steps_per_epoch=total_samples // batch_size + 1,
                                epochs=1,
                                callbacks=[ema_callback])

Instructions for updating:
Use tf.cast instead.
Epoch 1/1
Found 1200 images belonging to 12 classes.
Found 1200 images belonging to 12 classes.


TypeError: <lambda>() takes 0 positional arguments but 2 were given