In [1]:
# STD Libraries
import gc
import os
import random
from time import time

# ML
import numpy as np
import tensorflow as tf
from keras.callbacks import EarlyStopping, Callback
from keras.layers import Conv2D, AveragePooling2D, LeakyReLU, Input, Concatenate, BatchNormalization, Dropout, Flatten, Dense
from keras.models import Model
from keras.utils import Sequence
from keras.utils.vis_utils import plot_model

# Colab
from google.colab import drive

# Mount Google Drive at /content/drive/.
# force_remount=False leaves an already-mounted drive untouched (see the cell output below).
drive.mount('/content/drive/', force_remount=False)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
class NCNN:
  '''
  Three-branch 2D convolutional network with a single sigmoid output.

  Each branch takes one input whose shape is derived from `sc_shape`; the
  flattened branch outputs are concatenated and fed to a small dense head.
  Constructing an instance builds the Keras model, prints its summary and
  writes `model_architecture.png`.

  The data loaders and the training routine are still skeletons: they
  allocate batch buffers and walk `data_path` but do not read any scan yet.
  '''

  def __init__(self):
    # ===== Hyperparameters =====
    # base filters
    self.base_f = 32
    # amount of downsampling blocks
    self.ds_blocks = 4
    # universal dropout rate
    self.dp_rt = .25
    # alpha value for LeakyReLU
    self.alpha_1 = .2

    # fMRI scan shape
    self.sc_shape = (42, 80, 80)

    # relative path: resolved against the current working directory
    self.data_path = 'drive/My Drive/NISR'

    # ===== Data fields =====
    self.n_runs = 5
    # n_runs equals the length of the list, so this is a shuffled copy of it
    self.runs = random.sample([1, 2, 4, 8, 9], self.n_runs)

    # ===== Build models =====
    self.model = self.build_architecture()


  # ===== Architecture =====

  def cf(self, x, f):
    '''
    Downsampling block
    c64, c128, c256, ...

    7x7 convolution with `f` filters (same padding) -> 2x2 average pooling
    with stride 2 -> LeakyReLU -> batch normalisation.

    x: input tensor
    f: number of convolution filters
    Returns the output tensor of the block.
    '''
    x = Conv2D(f, kernel_size=7, padding='same')(x)
    x = AveragePooling2D(pool_size=2, strides=2)(x)
    x = LeakyReLU(alpha=self.alpha_1)(x)
    x = BatchNormalization()(x)
    return x

  # ===== Models =====

  def build_base(self, in_shape):
    '''
    Base model for inclusive directions

    in_shape: shape of one input sample (without the batch axis)
    Returns (input tensor, flattened output tensor) of one branch.
    '''
    f = self.base_f

    x0 = Input(shape=in_shape)

    # Stem block. No explicit stride is passed to the pooling layer, but Keras
    # defaults `strides` to `pool_size`, so this block also halves both
    # spatial dimensions.
    x = Conv2D(f, kernel_size=7, padding='same')(x0)
    x = AveragePooling2D(pool_size=2)(x)
    x = LeakyReLU(alpha=self.alpha_1)(x)
    x = BatchNormalization()(x)
    x = Dropout(self.dp_rt)(x)

    # every downsampling block doubles the filter count
    for _ in range(self.ds_blocks):
      f *= 2
      x = self.cf(x, f)

    x = Flatten()(x)

    return x0, x

  def build_architecture(self):
    '''
    Connects 3 networks, and adds prediction layer after concatenation

    Prints the three input shapes and the model summary, and writes the
    architecture diagram to `model_architecture.png`.
    Returns the (uncompiled) Keras model.
    '''
    x_dim = self.sc_shape
    # NOTE(review): sc_shape[1] == sc_shape[2] for the current scan shape, so
    # y_dim comes out identical to x_dim (see the printed shapes) - confirm
    # that this is the intended permutation for the second branch.
    y_dim = (self.sc_shape[0], self.sc_shape[2], self.sc_shape[1])
    z_dim = (self.sc_shape[1], self.sc_shape[2], self.sc_shape[0])

    print(x_dim, y_dim, z_dim)

    x_in, x_out = self.build_base(x_dim)
    y_in, y_out = self.build_base(y_dim)
    z_in, z_out = self.build_base(z_dim)

    u = Concatenate()([x_out, y_out, z_out])
    u = BatchNormalization()(u)
    u = Dense(512)(u)
    u = LeakyReLU(alpha=self.alpha_1)(u)
    u = Dense(1, activation='sigmoid')(u)

    model = Model(inputs=[x_in, y_in, z_in], outputs=u)

    model.summary()
    plot_model(model, to_file='model_architecture.png', show_shapes=True, show_layer_names=True)

    return model

  # ===== Load data =====
  # Everything below used to be indented one level deeper, i.e. placed after
  # the `return` of build_architecture, where it was unreachable and never
  # became part of the class.

  class DataLoader(Sequence):
    '''
    Skeleton of a Keras Sequence that serves batches of runs.

    batch_size: number of samples per batch (must be >= 1)
    shuffle:    stored for later use; not read anywhere yet
    data_path:  directory that is walked for data files
    sample_ids: identifiers of the available samples; only their count is
                used, to derive the number of batches

    NOTE(review): `__getitem__` is not implemented yet, so an instance cannot
    be handed to a fit call as-is.
    '''
    def __init__(self, batch_size=2, shuffle=True, data_path='drive/My Drive/NISR', sample_ids=()):
      if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

      self.batch_size = batch_size
      self.shuffle = shuffle
      # the loader is a separate object, so it needs its own copy of the path
      self.data_path = data_path
      self.sample_ids = list(sample_ids)

      self.no_runs = self.batch_size * 5

    def __len__(self):
      '''
      Number of complete batches per epoch.
      '''
      # np.floor() was previously called without an argument (TypeError)
      return int(np.floor(len(self.sample_ids) / self.batch_size))

    def __data_generation(self):
      '''
      Allocates the buffers for one batch and returns them as (x, y).

      TODO: the buffers are uninitialised; the file walk does not load
      anything into them yet.
      '''
      # allocate directly in the target dtype instead of float64 + cast
      x = tf.cast(np.empty([self.no_runs, 209, 42, 80, 80], dtype=np.float32), dtype=tf.float32)
      y = tf.cast(np.empty([self.no_runs, 209], dtype=np.uint16), dtype=tf.uint16)

      for dir_name, _, filenames in os.walk(self.data_path):
        pass

      return x, y

  def data_loader(self):
    '''
    Generator skeleton that yields one (x, y) batch per iteration, forever.

    TODO: the yielded buffers are uninitialised; for now every iteration
    only prints the file names found under `data_path`.
    '''
    batch_size = 2
    # n_runs is the run count; `runs` is the list of run ids and must not be
    # used as a dimension
    no_runs = batch_size * self.n_runs
    while True:
      # allocate directly in the target dtype instead of float64 + cast
      x = tf.cast(np.empty([no_runs, 209, 42, 80, 80], dtype=np.float32), dtype=tf.float32)
      y = tf.cast(np.empty([no_runs, 209], dtype=np.uint16), dtype=tf.uint16)

      for dir_name, _, filenames in os.walk(self.data_path):
        for filename in filenames:
          print(filename)

      # without a yield the loop never handed control back to the caller
      yield x, y


  # ===== Callbacks =====

  class GC(Callback):
    '''
    Garbage Collection Callback
    Runs a garbage collection pass at the end of every epoch
    '''
    def on_epoch_end(self, epoch, logs=None):
      gc.collect()

  class ETA(Callback):
    '''
    Timing Callback
    Keeps track of time elapsed:
      times   - duration of every finished epoch, in seconds
      elapsed - total training time, set when training ends
    '''
    def on_train_begin(self, logs=None):
      self.times = []
      self.start = time()

    def on_train_end(self, logs=None):
      self.elapsed = time() - self.start

    def on_epoch_begin(self, epoch, logs=None):
      self.epoch_start = time()

    def on_epoch_end(self, epoch, logs=None):
      self.times.append(time() - self.epoch_start)

  class LossLog(Callback):
    '''
    Loss Logger
    Keeps track of loss and validation loss, per batch and per epoch.
    A key that is missing from `logs` is recorded as None.
    '''
    def on_train_begin(self, logs=None):
      self.loss = {'batch': [], 'epoch': []}
      self.val_loss = {'batch': [], 'epoch': []}

    def on_batch_end(self, batch, logs=None):
      logs = logs or {}
      self.loss['batch'].append(logs.get('loss'))
      self.val_loss['batch'].append(logs.get('val_loss'))

    def on_epoch_end(self, epoch, logs=None):
      logs = logs or {}
      self.loss['epoch'].append(logs.get('loss'))
      self.val_loss['epoch'].append(logs.get('val_loss'))


  # ===== Training process =====
  def train(self):
    '''
    Prepares the training callbacks and returns them as a list.

    TODO: the actual fit call is not wired up yet.
    '''
    # define callbacks
    # the callback classes are attributes of NCNN, so they have to be
    # reached through self; bare names are not visible from a method
    early_stop_cb = EarlyStopping(monitor='loss', patience=5)
    eta_cb = self.ETA()
    gc_cb = self.GC()
    ll_cb = self.LossLog()

    # self.model.fit_generator()

    return [early_stop_cb, eta_cb, gc_cb, ll_cb]


In [3]:
# Instantiate the network: the constructor builds the Keras model, prints the three
# branch input shapes and the model summary, and writes model_architecture.png.
model = NCNN()

(42, 80, 80) (42, 80, 80) (80, 80, 42)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 42, 80, 80)] 0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 42, 80, 80)] 0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None, 80, 80, 42)] 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 42, 80, 32)   125472      input_1[0][0]                    
_______________________________________________________