In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
import tensorflow as tf
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)
tf.get_logger().setLevel("ERROR")
import numpy as np
from keras import metrics
import matplotlib.pyplot as plt
from tensorflow.keras.layers import LeakyReLU, Conv2D, MaxPool2D, UpSampling2D, Input, Lambda, BatchNormalization, Activation, Dense, AveragePooling2D, MaxPooling2D, Flatten, Dropout, Concatenate
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler
from sklearn.metrics import classification_report, confusion_matrix
import time
import pandas as pd
import gc
from tensorflow.keras import backend as k
from tqdm import tqdm
import math
from tensorflow.keras.callbacks import Callback
from sklearn import preprocessing
from keras.regularizers import l1,l2

In [None]:
def original_model(weights_path='saved_weights.hdf5'):
    """
    Rebuild the pretrained source network and load its saved weights.

    The layer stack below must match the network that produced the weights
    file, because ``load_weights`` is called without ``by_name``.

    :param weights_path: path of the saved weights file; the default keeps the
        previously hard-coded ``'saved_weights.hdf5'``.
    :return: compiled Keras ``Model`` with input shape ``(10000, 27, 1)`` and a
        single sigmoid output.
    """
    model_input = Input((10000, 27, 1))

    # first convolution layer: the (1, 27) kernel spans the full 27-wide input
    # row, so (with the default 'valid' padding) the width axis collapses to 1
    model_output = Conv2D(3, kernel_size=(1, 27), activation=None)(model_input)
    model_output = BatchNormalization()(model_output)
    model_output = Activation("relu")(model_output)

    # second convolution layer (1x1: mixes the 3 channels per row)
    model_output = Conv2D(3, (1, 1), activation=None)(model_output)
    model_output = BatchNormalization()(model_output)
    model_output = Activation("relu")(model_output)

    # pooling layer: averages over all 10000 rows -> one 3-channel vector
    model_output = AveragePooling2D(pool_size=(10000, 1))(model_output)
    model_output = Flatten()(model_output)

    # Dense layer
    model_output = Dense(3, activation=None)(model_output)
    model_output = BatchNormalization()(model_output)
    model_output = Activation("relu")(model_output)

    # output layer (binary classification probability)
    model_output = Dense(1, activation=None)(model_output)
    model_output = BatchNormalization()(model_output)
    model_output = Activation("sigmoid")(model_output)

    model = Model(model_input, model_output)

    # `learning_rate`, not the deprecated `lr` alias: depending on the Keras
    # version, `lr` is either ignored (silently keeping the 1e-3 default) or
    # rejected outright.
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(learning_rate=0.0001),
                  metrics=['accuracy'])

    model.load_weights(weights_path)

    return model
        

def current_model(ori_model):
    """
    Build the (1000, 25, 1)-input network and transfer frozen layers from ``ori_model``.

    The architecture mirrors ``original_model`` except for the input shape,
    the first convolution's kernel width (25 instead of 27) and the pooling
    window (1000 instead of 10000). The first conv block cannot reuse the
    pretrained kernel (its shape differs), and the pooling layer has no weights.
    So only layer indices 4-6 (second conv block) and 8-11 (flatten + hidden
    dense block) are copied and frozen. The first conv block and the output
    block stay trainable.

    NOTE(review): the index mapping assumes ``model.layers[0]`` is the
    InputLayer and that layers are listed in construction order, for both models.

    :param ori_model: Keras model with the same layer ordering, e.g. the
        result of ``original_model()``.
    :return: compiled Keras ``Model``; its summary is printed as a side effect.
    :raises ValueError: if a transferred layer's type differs between the models.
    """
    model_input = Input((1000, 25, 1))

    # first convolution layer: the (1, 25) kernel spans the full 25-wide row
    model_output = Conv2D(3, kernel_size=(1, 25), activation=None)(model_input)
    model_output = BatchNormalization()(model_output)
    model_output = Activation("relu")(model_output)

    # second convolution layer (1x1)
    model_output = Conv2D(3, (1, 1), activation=None)(model_output)
    model_output = BatchNormalization()(model_output)
    model_output = Activation("relu")(model_output)

    # pooling layer: averages over all 1000 rows
    model_output = AveragePooling2D(pool_size=(1000, 1))(model_output)
    model_output = Flatten()(model_output)

    # Dense layer
    model_output = Dense(3, activation=None)(model_output)
    model_output = BatchNormalization()(model_output)
    model_output = Activation("relu")(model_output)

    # output layer
    model_output = Dense(1, activation=None)(model_output)
    model_output = BatchNormalization()(model_output)
    model_output = Activation("sigmoid")(model_output)

    model = Model(model_input, model_output)

    transferred_indices = list(range(4, 7)) + list(range(8, 12))
    for i in transferred_indices:
        src_layer = ori_model.layers[i]
        dst_layer = model.layers[i]
        # Weight-less layers (Activation, Flatten) would accept set_weights([])
        # silently, so check the layer types to catch an index mis-mapping.
        if type(src_layer) is not type(dst_layer):
            raise ValueError(
                'Layer {} type mismatch: {} vs {}'.format(
                    i, type(src_layer).__name__, type(dst_layer).__name__))
        dst_layer.set_weights(src_layer.get_weights())
        dst_layer.trainable = False

    # Compile only AFTER freezing: Keras fixes the trainable state at compile
    # time, so freezing after compile() would not take effect during fit().
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(learning_rate=0.0001),
                  metrics=['accuracy'])

    # summary() prints itself and returns None; wrapping it in print() only
    # added a stray "None" line.
    model.summary()

    return model

In [None]:
class ClearMemory(Callback):
    """Keras callback that frees memory after every training epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """
        Collect garbage, then reset the Keras backend session.

        :param epoch: index of the epoch that just finished (unused).
        :param logs: metric dict supplied by Keras (unused).
        """
        # Reclaim unreachable Python objects first...
        gc.collect()
        # ...then drop Keras' global state.
        # NOTE(review): this runs while fit() is still in progress; Keras
        # documents clear_session() for use between model builds, so confirm
        # it is safe for the TF version in use.
        k.clear_session()

def plot_history(history, path):
    """
    Plot keras training history (accuracy and loss, train vs. validation) and save it.

    :param history: keras ``History`` returned by ``Model.fit``; its
        ``history`` dict must contain 'accuracy', 'val_accuracy', 'loss' and
        'val_loss'.
    :param path: output image path passed to ``Figure.savefig``.
    :return: None
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7))
    try:
        panels = (
            (ax1, 'accuracy', 'model acc', 'Accuracy'),
            (ax2, 'loss', 'model loss', 'loss'),
        )
        for ax, key, title, ylabel in panels:
            ax.plot(history.history[key])
            ax.plot(history.history['val_' + key])
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.set_xlabel('epoch')
            # The second curve comes from the val_* keys, i.e. the validation split.
            ax.legend(['train', 'validation'], loc='upper left')

        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(path)
    finally:
        # pyplot keeps every open figure alive until it is closed; this is
        # meant to be called once per fold, so close it to avoid a leak.
        plt.close(fig)

# Fix the global random seeds (TensorFlow's and NumPy's legacy global RNG)
# before any model is built, so repeated notebook runs start from the same state.
# NOTE(review): some GPU kernels may still be non-deterministic.
tf.random.set_seed(1)
np.random.seed(1)

In [None]:
# 5-fold fine-tuning: for each fold, load that fold's arsinh-transformed data,
# build a fresh model with layers transferred from the pretrained network,
# train it, and record the best validation accuracy seen during training.
CHECKPOINT_DIR = '/mnt/Dev_ssd/Dev_ssd/hiv_brain/classification_checkpoints'
# ModelCheckpoint writes an .h5 file and does not create missing directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

accs = []
print('Running')
for fidx in range(0, 5):
    print('===============Training Fold {}==============='.format(fidx))

    # ################################# Load Data #################################
    # current_model() has a single image input, so only X/y are loaded here.
    # The per-fold covariate and score arrays were loaded before but never
    # reached the model; add them back together with a model input that
    # consumes them.
    X_train = np.load('data_folds_raw/X_train_{}_arsinh.npy'.format(fidx))
    X_val = np.load('data_folds_raw/X_val_{}_arsinh.npy'.format(fidx))

    # X_train = np.load('data_folds_raw/X_train_{}_norm.npy'.format(fidx))
    # X_val = np.load('data_folds_raw/X_val_{}_norm.npy'.format(fidx))

    y_train = np.load('data_folds_raw/y_train_{}.npy'.format(fidx))
    y_val = np.load('data_folds_raw/y_val_{}.npy'.format(fidx))

    # Add a trailing channel axis; the explicit shape indexing also requires
    # X to be 3-D (samples, rows, features).
    X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], X_train.shape[2], 1)
    X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], X_val.shape[2], 1)
    y_train = y_train.reshape(y_train.shape[0], 1)
    y_val = y_val.reshape(y_val.shape[0], 1)

    print(X_train.shape, '<->', np.unique(y_val, return_counts=True))
    print(X_train[0].shape)

    ori_model = original_model()
    cur_model = current_model(ori_model)

    # The pretrained model is only needed as a weight source.
    del ori_model
    gc.collect()

    early_stopping = EarlyStopping(monitor='val_loss', patience=25, verbose=1, mode='min')
    mcp_save = ModelCheckpoint(
        os.path.join(CHECKPOINT_DIR, '{}.h5'.format(fidx)), save_best_only=True,
        monitor='val_loss', mode='min'
    )
    # The model has exactly one input, so pass X alone; a [X, covar] list
    # makes Keras reject the call (2 inputs given, 1 expected).
    hist = cur_model.fit(
        X_train,
        y_train,
        epochs=500,
        batch_size=64,
        shuffle=True,
        validation_data=(X_val, y_val),
        validation_batch_size=32,
        # early_stopping was configured but never registered before.
        callbacks=[ClearMemory(), mcp_save, early_stopping]
    )
    # plot_history(hist, '/mnt/Dev_ssd/Dev_ssd/hiv_brain/classification_model_checkpoints/{}.jpg'.format(fidx))

    # Best validation accuracy over all epochs (not necessarily the epoch
    # checkpointed by val_loss).
    accs.append(max(hist.history['val_accuracy']))

    # Release this fold's data and model before the next fold.
    del X_train, X_val, y_train, y_val
    del hist, cur_model
    gc.collect()
    k.clear_session()

print(accs)
# mean accuracy, half of the min-max range, and population standard deviation
print(sum(accs) / len(accs), (max(accs) - min(accs)) / 2, np.std(accs))