In [1]:
import pydicom, re
import tensorflow as tf
import pandas as pd
import numpy as np
import flwr as fl
from os import listdir
from os.path import isfile, join, exists

# from ipynb.fs.full.utils import MRIDataGenerator

In [2]:
# Index of this federated client; it is interpolated into the
# ./client_<CLIENT_NUM>/train/ and ./client_<CLIENT_NUM>/test/ data paths
# built in the cells below.
CLIENT_NUM = 4

In [3]:
class MRIDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of 3D volumes built from DICOM slice folders.

    Each row of ``df`` provides one or more folder paths containing files named
    ``Image-<number>.dcm`` plus a label column. For every folder the central
    ``depth_size`` slices are loaded, min-max rescaled, resized to
    ``input_size``, standardised and stacked; short series are zero-padded at
    the end so that every volume has ``depth_size`` slices.

    Args:
        df: DataFrame with the path column(s) and the label column. A copy is kept.
        X_col: Name of the path column to read. When ``None`` the four columns
            ``FLAIR_path``, ``T1w_path``, ``T2w_path`` and ``T1wCE_path`` are
            loaded and concatenated along the last (channel) axis.
        y_col: Name of the label column.
        batch_size: Number of rows per batch.
        input_size: ``(height, width)`` every slice is resized to.
        depth_size: Number of slices per volume.
        shuffle: Whether to reshuffle the rows at the end of every epoch.
    """

    def __init__(self, df, X_col, y_col, batch_size,
                 input_size= (256, 256), depth_size=64,
                 shuffle=True):
        # Newer Keras versions expect the base-class constructor to run;
        # on older versions this is a harmless no-op.
        super().__init__()

        self.df = df.copy()
        self.X_col = X_col
        self.y_col = y_col
        self.depth_size = depth_size
        self.batch_size = batch_size
        self.input_size = input_size
        self.shuffle = shuffle
        self.n = len(self.df)

    def on_epoch_end(self):
        """Reshuffle the rows after each epoch when ``shuffle`` is enabled."""
        if self.shuffle:
            self.df = self.df.sample(frac=1).reset_index(drop=True)

    def __get_input(self, path, target_size):
        """Load one slice folder as an array of shape ``(depth_size, H, W, 1)``.

        Args:
            path: Folder containing ``Image-<number>.dcm`` files.
            target_size: ``(height, width)`` to resize each slice to.

        Returns:
            The stacked, preprocessed central slices, zero-padded at the end
            when the folder holds fewer than ``depth_size`` usable slices.
            A folder without any matching file yields an all-zero volume.
        """
        # Raw string + fullmatch: only exact "Image-<n>.dcm" names are accepted.
        slice_pattern = re.compile(r'Image-(\d+)\.dcm')
        # Keep the real file name next to its number so that names with
        # leading zeros are still found after the numeric sort.
        numbered = sorted(
            (int(found.group(1)), found.group(0))
            for found in map(slice_pattern.fullmatch, listdir(path))
            if found is not None
        )
        slice_files = [name for _, name in numbered]

        # Select the window of slices centred on the middle of the series.
        center = len(slice_files) // 2
        half_depth = self.depth_size // 2
        left = max(0, center - half_depth)
        right = min(len(slice_files), center + half_depth)
        slice_files = slice_files[left:right]

        height, width = target_size[0], target_size[1]

        scans = []
        for name in slice_files:
            # assumes pixel_array is a single 2D frame — TODO confirm for all series
            img = pydicom.dcmread(join(path, name)).pixel_array
            img = self._rescale(img)
            img = np.expand_dims(img, axis=-1)
            img = tf.image.resize(img, (height, width)).numpy()
            img = self._normalize(img)
            scans.append(img)

        if not scans:
            # Nothing to load: return pure padding instead of failing on an
            # empty list, keeping the batch shape consistent.
            return np.zeros((self.depth_size, height, width, 1), dtype=np.float32)

        volume = np.array(scans)
        missing = self.depth_size - len(scans)
        if missing > 0:
            # Pad with the volume's own dtype so padded and unpadded volumes
            # do not end up with different precisions.
            padding = np.zeros((missing, height, width, 1), dtype=volume.dtype)
            volume = np.concatenate([volume, padding])
        return volume

    def _rescale(self, arr):
        """Min-max scale ``arr`` to [0, 1]; a constant array is returned unchanged."""
        arr_min = arr.min()
        arr_max = arr.max()
        value_range = arr_max - arr_min
        if value_range == 0:
            return arr
        return (arr - arr_min) / value_range

    def _normalize(self, arr):
        """Subtract the mean and divide by the standard deviation when it is non-zero."""
        img = arr - arr.mean()
        std = np.std(img)
        # divide by the standard deviation (only if it is different from zero)
        if std != 0:
            img = img / std
        return img

    def __get_data(self, batches):
        """Build the ``(X, y)`` arrays for the given slice of rows."""
        if self.X_col is None:
            PATHS = ['FLAIR_path', 'T1w_path', 'T2w_path', 'T1wCE_path']
            X_batch = []
            for p in PATHS:
                batch_part_path = batches[p]
                X_batch.append(np.asarray([self.__get_input(x,  self.input_size) for x in batch_part_path]))
            y_batch = batches[self.y_col].values
            # Each part is (batch, depth, H, W, 1); axis 4 stacks the modalities as channels.
            X_batch = np.concatenate(X_batch, axis=4)

        else:
            path_batch = batches[self.X_col]
            X_batch = np.asarray([self.__get_input(x,  self.input_size) for x in path_batch])
            y_batch = batches[self.y_col].values
        return X_batch, y_batch

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, y)`` tuple."""
        start = index * self.batch_size
        batches = self.df.iloc[start:start + self.batch_size]
        X, y = self.__get_data(batches)
        return X, y

    def __len__(self):
        """Number of full batches per epoch; a trailing partial batch is dropped."""
        return self.n // self.batch_size

In [4]:
def checkDirectoryForExistence(dirName):
    """Return True when this client's train folder has a directory for ``dirName``.

    ``dirName`` is converted to a string and left-padded with zeros to five
    characters before being looked up under ``./client_<CLIENT_NUM>/train/``.
    """
    return exists(f'./client_{CLIENT_NUM}/train/{str(dirName).zfill(5)}')


def getAllExistingDirs():
    """Return a Series of train directory paths for the cases present on disk.

    Reads the module-level ``train_df``. The result keeps the index of the
    rows whose directory exists, so assigning it back to ``train_df`` leaves
    NaN in the rows without a directory.
    """
    present = train_df['BraTS21ID'].apply(checkDirectoryForExistence)
    case_ids = train_df[present]['BraTS21ID'].astype(str).str.zfill(5)
    return f'./client_{CLIENT_NUM}/train/' + case_ids


train_df = pd.read_csv("./train_labels.csv")
# Scan the filesystem once and reuse the result for all four modality columns
# instead of repeating the existence checks for every column.
train_case_dirs = getAllExistingDirs()
train_df['FLAIR_path'] = train_case_dirs + '/FLAIR/'
train_df['T1w_path'] = train_case_dirs + '/T1w/'
train_df['T2w_path'] = train_case_dirs + '/T2w/'
# NOTE(review): this folder name starts with a lowercase 't', unlike the column
# name 'T1wCE_path' — confirm it matches the directory on disk (this matters on
# case-sensitive filesystems).
train_df['T1wCE_path'] = train_case_dirs + '/t1wCE/'
# Rows without a directory on disk have NaN paths and are removed here
# (along with any row that has a missing value in another column).
train_df = train_df.dropna(how='any')

In [5]:
def checkDirectoryForExistence(dirName):
    """Return True when this client's test folder has a directory for ``dirName``.

    ``dirName`` is converted to a string and left-padded with zeros to five
    characters before being looked up under ``./client_<CLIENT_NUM>/test/``.
    Redefining this name replaces the train-folder variant defined earlier.
    """
    return exists(f'./client_{CLIENT_NUM}/test/{str(dirName).zfill(5)}')


def getAllExistingDirs():
    """Return a Series of test directory paths for the cases present on disk.

    Reads the module-level ``test_df``. The result keeps the index of the
    rows whose directory exists, so assigning it back to ``test_df`` leaves
    NaN in the rows without a directory.
    """
    present = test_df['BraTS21ID'].apply(checkDirectoryForExistence)
    case_ids = test_df[present]['BraTS21ID'].astype(str).str.zfill(5)
    return f'./client_{CLIENT_NUM}/test/' + case_ids


# NOTE(review): the test split reads the same ./train_labels.csv as the train
# split; presumably the labels for both splits live in that file — confirm.
test_df = pd.read_csv("./train_labels.csv")
# Scan the filesystem once and reuse the result for all four modality columns
# instead of repeating the existence checks for every column.
test_case_dirs = getAllExistingDirs()
test_df['FLAIR_path'] = test_case_dirs + '/FLAIR/'
test_df['T1w_path'] = test_case_dirs + '/T1w/'
test_df['T2w_path'] = test_case_dirs + '/T2w/'
# NOTE(review): this folder name starts with a lowercase 't', unlike the column
# name 'T1wCE_path' — confirm it matches the directory on disk (this matters on
# case-sensitive filesystems).
test_df['T1wCE_path'] = test_case_dirs + '/t1wCE/'
# Rows without a directory on disk have NaN paths and are removed here
# (along with any row that has a missing value in another column).
test_df = test_df.dropna(how='any')

In [6]:
# Settings shared by the training generator, the evaluation generator and the
# training call further down.
depth = 64               # number of slices per volume (depth_size)
resolution = (192, 192)  # in-plane size each slice is resized to (input_size)
batches = 4              # passed to the generator as its batch_size

# Training generator: FLAIR volumes as input, MGMT_value as label, reshuffled
# after every epoch.
gen = MRIDataGenerator(
    df=train_df,
    X_col='FLAIR_path',
    y_col='MGMT_value',
    batch_size=batches,
    input_size=resolution,
    depth_size=depth,
    shuffle=True,
)

In [7]:
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv3D, MaxPooling3D, Flatten, Dense

# Derive the input shape from the generator settings (depth slices of
# resolution[0] x resolution[1] pixels, one channel) so that the model and the
# data pipeline cannot drift apart when those settings are changed.
volume_shape = (depth, resolution[0], resolution[1], 1)

# Four Conv3D + MaxPooling3D stages with doubling filter counts, followed by a
# small dense head ending in a single sigmoid unit for the binary label.
model = Sequential([
    Conv3D(32, kernel_size=(3, 3, 3), activation='relu', kernel_initializer='he_uniform', input_shape=volume_shape),
    MaxPooling3D(pool_size=(2, 2, 2)),
    Conv3D(64, kernel_size=(3, 3, 3), activation='relu', kernel_initializer='he_uniform'),
    MaxPooling3D(pool_size=(2, 2, 2)),
    Conv3D(128, kernel_size=(3, 3, 3), activation='relu', kernel_initializer='he_uniform'),
    MaxPooling3D(pool_size=(2, 2, 2)),
    Conv3D(256, kernel_size=(3, 3, 3), activation='relu', kernel_initializer='he_uniform'),
    MaxPooling3D(pool_size=(2, 2, 2)),
    Flatten(),
    Dense(128, activation='relu', kernel_initializer='he_uniform'),
    Dense(64, activation='relu', kernel_initializer='he_uniform'),
    Dense(1, activation='sigmoid'),
])

model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=[tf.keras.metrics.BinaryAccuracy()]
    )

In [8]:
# Per-round bookkeeping filled in by the client below: one Keras History per
# fit() call and one loss / accuracy value per evaluate() call.
train_history = []
val_loss = []
val_acc = []

class MRI_Classifier_Client(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates the module-level ``model``.

    Training uses the module-level generator ``gen``; evaluation builds a
    generator over ``test_df`` with the same settings.
    """

    def get_parameters(self, config=None):
        """Return the current model weights.

        ``config`` is optional so the method works whether or not the
        installed Flower version passes a configuration dictionary.
        """
        return model.get_weights()

    def fit(self, parameters, config):
        """Load the received weights, train locally and return the updated weights.

        Returns:
            ``(weights, number of rows in train_df, empty metrics dict)``.
        """
        model.set_weights(parameters)
        # NOTE(review): steps_per_epoch is set to `batches`, the value also used
        # as the generator's batch size, so each epoch runs only that many
        # steps while len(train_df) is reported as the example count — confirm
        # this is intended.
        history = model.fit(gen, steps_per_epoch = batches, verbose=1, epochs = 3)
        train_history.append(history)
        return model.get_weights(), len(train_df), {}

    def get_properties(self, config=None):
        """Return the client properties; this client exposes none.

        Returns an empty dict rather than ``None`` so that a caller expecting
        a mapping does not fail.
        """
        return {}

    def evaluate(self, parameters, config):
        """Load the received weights and evaluate on the test split.

        Returns:
            ``(loss, number of rows in test_df, {"accuracy": accuracy})``.
        """
        model.set_weights(parameters)
        # Shuffling has no purpose for evaluation, so the rows keep their order.
        test_gen = MRIDataGenerator(test_df, 'FLAIR_path', 'MGMT_value', batches, resolution, depth, False)
        loss, accuracy = model.evaluate(test_gen)
        # Plain Python floats keep the recorded values and the returned
        # metrics free of framework-specific scalar types.
        loss, accuracy = float(loss), float(accuracy)
        val_loss.append(loss)
        val_acc.append(accuracy)
        return loss, len(test_df), {"accuracy": accuracy}

In [None]:
# Connect to the Flower server and serve the client class defined in this
# notebook. The previous call instantiated `CifarClient`, which is not defined
# anywhere in the notebook and raises NameError on a fresh kernel.
fl.client.start_numpy_client(server_address="localhost:8080", client=MRI_Classifier_Client())

INFO flower 2022-05-26 21:24:52,702 | connection.py:102 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flower 2022-05-26 21:24:52,953 | connection.py:39 | ChannelConnectivity.IDLE
DEBUG flower 2022-05-26 21:24:52,953 | connection.py:39 | ChannelConnectivity.READY


Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
