### Imports

In [105]:
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import CSVLogger, Callback
from datetime import datetime
import keras.backend as K
import extras.ourUtils as utils
import numpy as np
import Models
import sys
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# Fixed package name: the module is `sklearn`, not `sklear`.
from sklearn.model_selection import KFold


### Init

In [77]:
# --- Training settings ---
batch_size = 20
nrEpochs = 10
full_train = True

# --- Data locations: both domains live under the same project root ---
path = '/home/jaskmo/Documents/programering/02456DomainAdaptation/'
source_data = ''.join((path, 'taperImages/pysNetData'))
target_data = ''.join((path, 'taperImages/hData'))

# Handle to the notebook's stdout, kept so it can be restored after a redirect.
stdout_cell = sys.stdout

# Presumably the "models in question" (training modes) -- TODO confirm.
MIQ = ['DA', 'target', 'source']

# 10-fold splitter and the subject counts of the two data sets.
kf = KFold(n_splits=10)
n_subjects_phys, n_subjects_hosp = 20, 37

## Get data as generators

In [3]:
# Shared image generator; `rescale` multiplies every pixel value by 1/255.
datagen = ImageDataGenerator(rescale=1.0 / 255)

# make a data generator for dplInput
def train_gen_DAnet(source, target, batch_size):
    half = batch_size//2
    while True:
        source_data, source_lable = source.next()
        target_data, target_lable = target.next()
        if len(source_lable) != batch_size or len(target_lable) != batch_size:
            continue
        dpl_data = np.concatenate((source_data[:half,...],target_data[:half,...]),axis=0)
               
        domain_tmp = np.ones(batch_size, dtype='int8')
        domain_tmp[half:] = domain_tmp[half:] * 0
        dpl_lable = np.concatenate((domain_tmp.reshape(batch_size,1),
                                       np.flip(domain_tmp,0).reshape(batch_size,1)),1)

        yield({'lplInput':source_data,'dplInput':dpl_data}, {'lplOut':source_lable,'dplOut':dpl_lable})
        
def test_gen_DAnet(source, target, batch_size):
    half = batch_size//2
    while True:
        source_data, source_lable = source.next()
        target_data, target_lable = target.next()
        if len(source_lable) != batch_size or len(target_lable) != batch_size:
            continue
        dpl_data = np.concatenate((source_data[:half,...],target_data[:half,...]),axis=0)
               
        domain_tmp = np.ones(batch_size, dtype='int8')
        domain_tmp[half:] = domain_tmp[half:] * 0
        dpl_lable = np.concatenate((domain_tmp.reshape(batch_size,1),
                                       np.flip(domain_tmp,0).reshape(batch_size,1)),1)

        yield({'lplInput':target_data,'dplInput':dpl_data}, {'lplOut':target_lable,'dplOut':dpl_lable})

In [None]:
# Work-in-progress cross-validation loop over the entries of MIQ.
# NOTE(review): `create_data_split` is not defined or imported anywhere in the
# visible notebook -- confirm where it lives before running this cell.
# NOTE(review): `item` is not used inside the loop yet.
for item in MIQ:
    # KFold.split needs an array-like of samples, not a bare count, so the
    # subjects are represented by their indices 0 .. n_subjects_phys - 1.
    for train_index, test_index in kf.split(np.arange(n_subjects_phys)):
        create_data_split(path, test_index)
        
        

#### Train data

In [4]:

# Source-domain training images, read from `<source_data>/train`.
train_gen_source = datagen.flow_from_directory(
    directory=source_data + '/train',
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
)

# One epoch is defined as the number of full batches in the source training set.
train_stepE = np.floor_divide(train_gen_source.n, batch_size)

# Target-domain training images, read from `<target_data>/train`.
train_gen_target = datagen.flow_from_directory(
    directory=target_data + '/train',
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
)

# Combined generator: class labels from the source batch plus a mixed domain batch.
train_gen_DA = train_gen_DAnet(train_gen_source, train_gen_target, batch_size)

Found 29772 images belonging to 5 classes.
Found 10629 images belonging to 5 classes.


#### Validation data

In [5]:
# Source-domain validation images, read from `<source_data>/validation`.
valid_gen_source = datagen.flow_from_directory(
    directory=source_data + '/validation',
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
)

# Number of validation steps is derived from the source validation set size.
val_stepE = np.floor_divide(valid_gen_source.n, batch_size)

# Target-domain validation images, read from `<target_data>/validation`.
valid_gen_target = datagen.flow_from_directory(
    directory=target_data + '/validation',
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
)

# Evaluation-style generator: class labels are taken from the target batch.
valid_gen_DA = test_gen_DAnet(valid_gen_source, valid_gen_target, batch_size)

Found 4807 images belonging to 5 classes.
Found 2838 images belonging to 5 classes.


#### Test data

In [6]:
# Source-domain test images, read from `<source_data>/test`.
test_gen_source = datagen.flow_from_directory(
    directory=source_data + '/test',
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
)

# Number of test steps is derived from the source test set size.
test_stepE = np.floor_divide(test_gen_source.n, batch_size)

# Target-domain test images, read from `<target_data>/test`.
test_gen_target = datagen.flow_from_directory(
    directory=target_data + '/test',
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
)

# Evaluation-style generator: class labels are taken from the target batch.
test_gen_DA = test_gen_DAnet(test_gen_source, test_gen_target, batch_size)

Found 3862 images belonging to 5 classes.
Found 2722 images belonging to 5 classes.


### Get model

In [7]:
# Backend variable (initialised to 0.0) that is handed to the model to control
# its gradient-flip layer; it has to exist before the model is built.
lamFunk = K.variable(value=0.0)

# Build the domain-adaptation model around that variable.
current_model = Models.DA_model(lamFunk)

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
lplInput (InputLayer)            (None, 224, 224, 3)   0                                            
____________________________________________________________________________________________________
dplInput (InputLayer)            (None, 224, 224, 3)   0                                            
____________________________________________________________________________________________________
model_1 (Model)                  (None, 25088)         14714688    lplInput[0][0]                   
                                                                   dplInput[0][0]                   
____________________________________________________________________________________________________
flipGrad (Lambda)                (None, 25088)         0           model_1[2][0]           

### Callbacks

In [8]:
# Timestamp taken once; it is embedded in the training-log file name as
# <day>-<month>-<year>_<hour><minute> (no zero padding).
now = datetime.now()

# Writes the per-epoch metrics of the training run to a CSV-style log file.
csv_logger = CSVLogger(
    '/media/jaskmo/ELEK/bme/Project02456/trainingLog/DA_Model'
    '{}-{}-{}_{}{}.log'.format(now.day, now.month, now.year, now.hour, now.minute)
)

class FlipControle(Callback):
    """Keras callback that ramps up the gradient-flip weight during training.

    After every epoch the controlled backend variable is set to
    ``2 / (1 + exp(-10 * p)) - 1`` with ``p = (epoch + 1) / nrEpochs``, i.e. it
    grows smoothly from 0 towards 1 as training progresses. The variable keeps
    the value it was created with during the first epoch.

    Args:
        alphaIn: backend variable (as created by ``K.variable``) to update.
    """

    def __init__(self, alphaIn):
        # Let the base class set up the attributes Keras expects on a callback.
        super(FlipControle, self).__init__()
        self.alpha = alphaIn
        # Report the value of the variable this callback controls (not a global).
        print(K.get_value(self.alpha))

    def on_epoch_end(self, epoch, logs=None):
        """Update the controlled variable for the next epoch and print its value.

        Args:
            epoch: zero-based index of the epoch that just finished.
            logs: metrics dict supplied by Keras; unused.
        """
        # float() keeps this a true division regardless of Python version.
        p = (epoch + 1) / float(nrEpochs)
        K.set_value(self.alpha, (2 / (1 + np.exp(-10 * p))) - 1)
        print(K.get_value(self.alpha))

### Fit the model

In [9]:
# Train the domain-adaptation model. Validation uses the *validation* split
# (valid_gen_DA / val_stepE); the test generator is deliberately kept out of
# training so the test metrics stay unbiased and the test iterator untouched.
current_model.fit_generator(train_gen_DA, train_stepE, epochs=nrEpochs, verbose=1,
                            validation_data=valid_gen_DA, validation_steps=val_stepE,
                            callbacks=[csv_logger, FlipControle(lamFunk)], initial_epoch=0,
                            max_queue_size=2)

# MIQ may be a single mode name or a list of mode names; comparing a list with
# the string "DA" is always False, which would silently skip this step.
# NOTE(review): confirm that "DA is among the listed modes" is the intended condition.
da_selected = (MIQ == "DA") if isinstance(MIQ, str) else ("DA" in MIQ)
if da_selected:
    # Replace the trained model by what utils.dissect_DAlpm extracts from it.
    DAlpm = utils.dissect_DAlpm(current_model)
    current_model = DAlpm

0.0
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f81680b9e48>

In [104]:
# Save the trained model as <path>models/<tag><day>-<month>-<year>_<hour><minute>.h5
# MIQ may be a single mode name or a list of names; concatenating a list to a
# string raises TypeError (after training has already finished), so a list is
# joined into one tag. A plain string produces the same file name as before.
# NOTE(review): confirm which tag is wanted when MIQ lists several modes.
model_tag = MIQ if isinstance(MIQ, str) else '-'.join(MIQ)
current_model.save(filepath=path + 'models/' + model_tag + str(now.day) + '-' + str(now.month) + '-' +
                       str(now.year) + '_' + str(now.hour) + str(now.minute) + '.h5')

In [11]:
# Collect exactly one pass over the target-domain test set.
# The iterator shuffles and may already have been advanced by earlier cells, so
# it is reset first; otherwise the collected batches could straddle two epochs
# and contain duplicated or missing samples.
test_gen_target.reset()

# ceil(n / batch_size) batches make up one full pass (the last one may be smaller).
n_test_batches = -(-test_gen_target.n // batch_size)

img_batches, lable_batches = [], []
for _ in range(n_test_batches):
    tmp_img, tmp_lable = test_gen_target.next()
    img_batches.append(tmp_img)
    lable_batches.append(tmp_lable)

# Concatenate once instead of re-copying the growing arrays on every iteration.
test_img = np.concatenate(img_batches, axis=0)
test_lable = np.concatenate(lable_batches, axis=0)

In [19]:
# Compute the test metrics on the target domain.
# Class names ordered by their class index, so they line up with the integer
# labels used below regardless of dict ordering.
inv_map = {v: k for k, v in test_gen_target.class_indices.items()}
target_names = [inv_map[class_idx] for class_idx in sorted(inv_map)]

# One-hot labels -> integer class indices.
targets_test_int = np.argmax(test_lable, axis=1)

# NOTE(review): the original referenced an undefined name `mod`; `current_model`
# is the trained model in scope. It must accept a single image batch here --
# confirm that the label predictor was extracted in the training cell.
y_pred = current_model.predict(test_img)
y_pred2 = np.argmax(y_pred, axis=1)

# Test accuracy:
acc = accuracy_score(targets_test_int, y_pred2)
print('Accuracy on target domain = ', acc)

conf_mat = confusion_matrix(targets_test_int, y_pred2)
print(conf_mat)
# Per class metrics
class_report = classification_report(targets_test_int, y_pred2, target_names=target_names)
print(class_report)

# Save to file. Writing through a context manager closes the file and leaves
# sys.stdout untouched, even if building or writing the report fails.
test_file = '/media/jaskmo/ELEK/bme/Project02456/testLog/DA_Model' + str(now.day) + '-' + str(now.month) + '-' + str(now.year) + '_' + str(now.hour) + str(now.minute) + '.log'

with open(test_file, 'w') as log_file:
    log_file.write('Accuracy on target domain = ' + str(acc) + '\n \n' +
                   'Confution matric on target domain: \n' + str(conf_mat) + '\n\n' +
                   'Class report on target domain: \n' + class_report + '\n')

# Evaluate error on source data
# _, metric = current_model.evaluate_generator(generator=test_gen_DA, steps=test_stepE)
# print('Accuracy on source domain = ', metric)

    
# elif training_mode == 'target': # Training on target data from hospital
#     # Convert from onehot
#     targets_test_int = [np.where(r == 1)[0][0] for r in targets_test_hosp]
#     y_pred = current_model.predict(inputs_test_hosp)
#     y_pred2 = np.argmax(y_pred, axis = 1)
#     # Test accuracy:
#     acc = accuracy_score(targets_test_int, y_pred2)
#     print('Accuracy in this domain = ', acc)
#     # Confusion matrix for target
#     conf_mat = confusion_matrix(targets_test_int, y_pred2)
#     print(conf_mat)
#     # Per class metrics
#     class_report = classification_report(targets_test_int, y_pred2, target_names=target_names)
#     print(class_report)
    
#     # Evaluate error on source data
#     _, metric = current_model.evaluate(x=inputs_test_phys, y=targets_test_phys, batch_size=50)
#     print('Accuracy on other domain = ', metric)