## Import required packages

In [1]:
from __future__ import print_function, division
import os
import struct
import socket
import pickle

import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import PearsonCorrCoef
from torch import optim


import scipy.stats as stats
import h5py
import gc
import time
from tqdm import tqdm



# Split 3D CNN Client Side
This code is the server part of split 3D-CNN model for **multi** client and a server.

In [2]:
!pip install numba
from numba import cuda
device = cuda.get_current_device()
device.reset()



## SET Hyperparameter (**)

In [3]:
batch_size = 5
num_workers = 8

In [4]:
users = 3 # number of clients
epoch = 50  # default
lr = 2e-5

## SET CUDA

In [5]:
device = "cuda:1" if torch.cuda.is_available() else "cpu"
#device = "cpu"
torch.manual_seed(777)
if device =="cuda:1":
    torch.cuda.manual_seed_all(777)

In [6]:
# client_order = int(input("client_order(start from 0): "))
client_order = 2

## Data load

In [7]:
class MRIDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform = None):
        """
        Args:
            csv_file (string): csv 파일의 경로
            root_dir (string): 모든 이미지가 존재하는 디렉토리 경로
            transform (callable, optional): 샘플에 적용될 Optional transform
        """
        self.mri_annotation = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.mri_annotation)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        try:        
            img_name = os.path.join(self.root_dir,'wm'+ str(self.mri_annotation.iloc[idx, 3]) + '_' + str(self.mri_annotation.iloc[idx, 0]).zfill(7)+ '_ses1' + '_t1w.nii')
            mri_image = nib.load(img_name).get_fdata()
            mri_age = self.mri_annotation.iloc[idx, 1]
            sample = {'image': mri_image, 'mri_age': mri_age}
        except:
            return None
            
        return mri_image, mri_age
    
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)

def flatten(lst):
    result = []
    for item in lst:
        if type(item) == list :
            result += flatten(item)
        else:
            result += [item]
    return result


In [8]:
mri_train_dataset = MRIDataset(csv_file = './FLdata/2_CoRR/CoRR_Phenotype_train.csv', root_dir = './FLdata/2_CoRR/T1w/wm/')
mri_test_dataset = MRIDataset(csv_file = './FLdata/2_CoRR/CoRR_Phenotype_test.csv', root_dir = './FLdata/2_CoRR/T1w/wm/')
mri_val_dataset = MRIDataset(csv_file = './FLdata/2_CoRR/CoRR_Phenotype_validation.csv', root_dir = './FLdata/2_CoRR/T1w/wm/')
model_PATH = './model/CheckpointCoRR.pt'

In [9]:
train_loader = DataLoader(mri_train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn, drop_last = True)
test_loader = DataLoader(mri_test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn, drop_last = True)
val_loader = DataLoader(mri_val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn, drop_last = True)

In [10]:
x_train, y_train = next(iter(train_loader))
print(x_train.size())
print(y_train.size())

torch.Size([5, 113, 137, 113])
torch.Size([5])


### Set other hyperparameters in the model
Hyperparameters here should be same with the server side.

In [11]:
class CNN3DModel(nn.Module):
    def __init__(self):
        super(CNN3DModel, self).__init__()
        
        
        self.client_block = nn.Sequential(
            # First Block
            nn.Conv3d(1, 16, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            
            nn.Conv3d(16, 16, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),

            nn.MaxPool3d(2),

            # Second Block
            nn.Conv3d(16, 32, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
    
            nn.Conv3d(32, 32, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),

            nn.MaxPool3d(2),

            # Third Block
            nn.Conv3d(32, 64, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Conv3d(64, 64, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),

            nn.MaxPool3d(2),
        )
        
        '''
        self.server_block = nn.Sequential(
            
            # 4th Block
            nn.Conv3d(64, 128, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Conv3d(128, 128, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),

            nn.MaxPool3d(2),

            # 5th Block
            nn.Conv3d(128, 256, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),

            nn.Conv3d(256, 256, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),

            nn.MaxPool3d(2),
        )

        self.regressor = nn.Sequential(
            nn.Linear(9216, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1),
        )
        '''

    def forward(self, x):
        x = self.client_block(x)
        return x

In [12]:
splitnn_client = CNN3DModel().to(device)
print(splitnn_client)

CNN3DModel(
  (client_block): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU()
    (2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
    (5): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (7): ReLU()
    (8): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (9): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (13): ReLU()
    (14): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (15): BatchNorm3d(64

## Socket initialization

### Required socket functions

In [13]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    msg = struct.pack('>I', len(msg)) + msg
    sock.sendall(msg)

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg =  recvall(sock, msglen)
    msg = pickle.loads(msg)
    return msg

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data


### Set host address and port number

In [14]:
host = "127.0.0.1"
port = 10080

## SET TIMER

In [15]:
start_time = time.time()    # store start time
print("timer start!")

timer start!


### Open the client socket

In [16]:
s = socket.socket()
s.connect((host, port))

In [17]:
epoch = recv_msg(s)   # get epoch
total_batch = len(train_loader)
msg = total_batch
send_msg(s, msg)   # send total_batch of train dataset

In [18]:
criterion = nn.L1Loss().to(device)
lr = 0.001
optimizer = optim.Adam(splitnn_client.parameters(), lr=lr)

## Real training process

In [19]:
for e in range(epoch):
    client_weights = recv_msg(s)
    splitnn_client.load_state_dict(client_weights)
    splitnn_client.eval()
    i = 1
    for images, labels in tqdm(train_loader, desc='Epoch '+str(e+1)):
        train = images.view(images.shape[0], 1, images.shape[1], images.shape[2],images.shape[3]).to(device, dtype = torch.float32)
        label = labels.to(device)
        # Clear gradients
        optimizer.zero_grad()
        # Forward propagation
        outputs = splitnn_client(train)
        client_output = outputs.clone().detach().requires_grad_(True)
        msg = {
            'client_output': client_output,
            'label': label
        }
        send_msg(s, msg)
        client_grad = recv_msg(s)
        outputs.backward(client_grad)
        optimizer.step()
        if (i == total_batch):
            break;
        i = i+1;
        del outputs
        del client_output
    
    send_msg(s, splitnn_client.state_dict())

Epoch 1:  99%|███████████████████████████████▊| 192/193 [00:29<00:00,  6.41it/s]
Epoch 2:  99%|███████████████████████████████▊| 192/193 [00:29<00:00,  6.59it/s]
Epoch 3:  99%|███████████████████████████████▊| 192/193 [00:28<00:00,  6.66it/s]
Epoch 4:  99%|███████████████████████████████▊| 192/193 [00:28<00:00,  6.66it/s]
Epoch 5:  99%|███████████████████████████████▊| 192/193 [00:28<00:00,  6.65it/s]
Epoch 6:  99%|███████████████████████████████▊| 192/193 [00:28<00:00,  6.64it/s]
Epoch 7:  99%|███████████████████████████████▊| 192/193 [00:28<00:00,  6.66it/s]
Epoch 8:  99%|███████████████████████████████▊| 192/193 [00:28<00:00,  6.65it/s]
Epoch 9:  99%|███████████████████████████████▊| 192/193 [00:28<00:00,  6.69it/s]
Epoch 10:  99%|██████████████████████████████▊| 192/193 [00:28<00:00,  6.66it/s]


In [20]:
end_time = time.time()  #store end time
print("WorkingTime of ",device ,": {} sec".format(end_time - start_time))

WorkingTime of  cuda:1 : 563.3791165351868 sec
