In [20]:
import pandas as pd
import glob
import os


In [27]:
def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path

# 数据划分
每个客户端拥有若干的机构的联合的数据(因为GPU资源不够), 各个客户端都是用 BRATS2018 + BRATS2019(不含重复数据), 这两个数据有明确的机构信息, BRATS2020 的训练集也使用, 但是不含对应的机构信息, 因此设立一个新的机构信息 "BRATS2020". 数据集的信息参考[挑战赛官网](https://www.med.upenn.edu/cbica/brats2020/data.html)

In [43]:
# 定义相关的参数
# PREFIX = 'clara_seg_ct_brats_fl'
PREFIX = 'brats_seg'
BRATS_PREFIX = '/home/liuyuan/shu_codes/datasets/brats'

mkdir(PREFIX)
mkdir(os.path.sep.join((PREFIX, 'config')))

'brats_seg/config'

In [50]:
def get_data_info(data_prefix, sub_name):
    data_info = pd.DataFrame(columns=['id', 'hgg_or_lgg', 'institution', 'ins_pid', 'sub_name'])
    # 遍历对应的目录
    hgg_dirs = glob.glob(os.path.join(data_prefix, 'HGG', '**'))
    lgg_dirs = glob.glob(os.path.join(data_prefix, 'LGG', '**'))
    print(f'HGG: {len(hgg_dirs)}, LGG: {len(lgg_dirs)}')
    for gg_type, gg in (('hgg', hgg_dirs), ('lgg', lgg_dirs)):
        names = [os.path.basename(x) for x in gg]
        for name in names:
            items = name.split('_')
            data_info = data_info.append({'id': name, 
                                          'hgg_or_lgg': gg_type, 
                                          'institution': items[1], 
                                          'sub_name': sub_name,
                                          'ins_pid': '_'.join(items[1:])}, ignore_index=True)
    return data_info
        

channels = ['t1ce', 't1', 't2', 'flair']
channel_suffix = ['_t1ce.nii.gz', '_t1.nii.gz', '_t2.nii.gz', '_flair.nii.gz']
mask_suffix = '_seg.nii.gz'
all_2018_info = get_data_info(BRATS_PREFIX + '/MICCAI_BraTS_2018_Data_Training', sub_name='MICCAI_BraTS_2018_Data_Training')
all_2019_info = get_data_info(BRATS_PREFIX + '/MICCAI_BraTS_2019_Data_Training', sub_name='MICCAI_BraTS_2019_Data_Training')
print(all_2018_info)
print(all_2019_info)

HGG: 210, LGG: 75
HGG: 259, LGG: 76
                       id hgg_or_lgg institution       ins_pid  \
0    Brats18_TCIA03_257_1        hgg      TCIA03  TCIA03_257_1   
1    Brats18_TCIA03_296_1        hgg      TCIA03  TCIA03_296_1   
2     Brats18_CBICA_ANZ_1        hgg       CBICA   CBICA_ANZ_1   
3    Brats18_TCIA06_165_1        hgg      TCIA06  TCIA06_165_1   
4    Brats18_TCIA01_412_1        hgg      TCIA01  TCIA01_412_1   
..                    ...        ...         ...           ...   
280  Brats18_TCIA09_462_1        lgg      TCIA09  TCIA09_462_1   
281  Brats18_TCIA12_101_1        lgg      TCIA12  TCIA12_101_1   
282  Brats18_TCIA13_653_1        lgg      TCIA13  TCIA13_653_1   
283  Brats18_TCIA10_299_1        lgg      TCIA10  TCIA10_299_1   
284  Brats18_TCIA09_451_1        lgg      TCIA09  TCIA09_451_1   

                            sub_name  
0    MICCAI_BraTS_2018_Data_Training  
1    MICCAI_BraTS_2018_Data_Training  
2    MICCAI_BraTS_2018_Data_Training  
3    MICCAI_Bra

In [51]:
# 以 brats2018 作为训练数据, 2019作为测试数据, 需要计算二者的差值
def get_different(df1, df2):
    # apply 第二参数, 1为column, 0 为 index. https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.apply.html
    c1 = df1[['ins_pid']].apply(tuple,1)
    c2 = df2[['ins_pid']].apply(tuple,1)
    return df2[~c2.isin(c1)]

brats_2019_part = get_different(all_2018_info, all_2019_info)
print(brats_2019_part)

                      id hgg_or_lgg institution      ins_pid  \
0    BraTS19_CBICA_BGX_1        hgg       CBICA  CBICA_BGX_1   
3    BraTS19_CBICA_AYG_1        hgg       CBICA  CBICA_AYG_1   
13   BraTS19_CBICA_BCF_1        hgg       CBICA  CBICA_BCF_1   
15   BraTS19_CBICA_AVB_1        hgg       CBICA  CBICA_AVB_1   
16   BraTS19_CBICA_BGW_1        hgg       CBICA  CBICA_BGW_1   
17   BraTS19_CBICA_BJY_1        hgg       CBICA  CBICA_BJY_1   
18   BraTS19_CBICA_AOC_1        hgg       CBICA  CBICA_AOC_1   
23   BraTS19_CBICA_BGG_1        hgg       CBICA  CBICA_BGG_1   
29   BraTS19_CBICA_BGT_1        hgg       CBICA  CBICA_BGT_1   
32   BraTS19_TMC_11964_1        hgg         TMC  TMC_11964_1   
36   BraTS19_TMC_12866_1        hgg         TMC  TMC_12866_1   
38   BraTS19_TMC_21360_1        hgg         TMC  TMC_21360_1   
48   BraTS19_CBICA_BGE_1        hgg       CBICA  CBICA_BGE_1   
54   BraTS19_CBICA_ATN_1        hgg       CBICA  CBICA_ATN_1   
56   BraTS19_TMC_27374_1        hgg     

In [53]:
# 不含有与 2018，2019 的训练的数据, 这些数据均为 HGG 数据
brats2020 = [
"BraTS20_Training_336",
"BraTS20_Training_337",
"BraTS20_Training_338",
"BraTS20_Training_339",
"BraTS20_Training_340",
"BraTS20_Training_341",
"BraTS20_Training_342",
"BraTS20_Training_343",
"BraTS20_Training_344",
"BraTS20_Training_345",
"BraTS20_Training_346",
"BraTS20_Training_347",
"BraTS20_Training_348",
"BraTS20_Training_349",
"BraTS20_Training_350",
"BraTS20_Training_351",
"BraTS20_Training_352",
"BraTS20_Training_353",
"BraTS20_Training_354",
"BraTS20_Training_355",
"BraTS20_Training_356",
"BraTS20_Training_357",
"BraTS20_Training_358",
"BraTS20_Training_359",
"BraTS20_Training_360",
"BraTS20_Training_361",
"BraTS20_Training_362",
"BraTS20_Training_363",
"BraTS20_Training_364",
"BraTS20_Training_365",
"BraTS20_Training_366",
"BraTS20_Training_367",
"BraTS20_Training_368",
"BraTS20_Training_369",
]
brats_2020_part = pd.DataFrame(columns=['id', 'hgg_or_lgg', 'institution', 'ins_pid', 'sub_name'])
for subject_id in brats2020:
    items = subject_id.split('_')
    brats_2020_part = brats_2020_part.append({'id': subject_id, 
                                              'hgg_or_lgg': 'hgg', 
                                              'institution': 'brats2020', 
                                              'sub_name': 'MICCAI_BraTS2020_TrainingData',
                                              'ins_pid': '_'.join(items[1:])}, ignore_index=True)
print(brats_2020_part)

                      id hgg_or_lgg institution       ins_pid  \
0   BraTS20_Training_336        hgg   brats2020  Training_336   
1   BraTS20_Training_337        hgg   brats2020  Training_337   
2   BraTS20_Training_338        hgg   brats2020  Training_338   
3   BraTS20_Training_339        hgg   brats2020  Training_339   
4   BraTS20_Training_340        hgg   brats2020  Training_340   
5   BraTS20_Training_341        hgg   brats2020  Training_341   
6   BraTS20_Training_342        hgg   brats2020  Training_342   
7   BraTS20_Training_343        hgg   brats2020  Training_343   
8   BraTS20_Training_344        hgg   brats2020  Training_344   
9   BraTS20_Training_345        hgg   brats2020  Training_345   
10  BraTS20_Training_346        hgg   brats2020  Training_346   
11  BraTS20_Training_347        hgg   brats2020  Training_347   
12  BraTS20_Training_348        hgg   brats2020  Training_348   
13  BraTS20_Training_349        hgg   brats2020  Training_349   
14  BraTS20_Training_350 

# 统计一些信息

In [49]:
# 输出一些统计信息
def print_insitutions_info(df):
    ds = dict()
    inst = set(df['institution'].to_list())
    print('institution: ', inst)
    for ins in inst:
        ds[ins] = len(df[df['institution'] == ins])
    for k, v in ds.items():
        print('\t', k, ':', v)
    return ds

all_2018_info_stat = print_insitutions_info(all_2018_info)
brats_2019_part_stat = print_insitutions_info(brats_2019_part)
brats_2020_part_stat = print_insitutions_info(brats_2020_part)

institution:  {'TCIA05', 'TCIA13', 'TCIA06', 'TCIA09', 'TCIA02', 'TCIA03', 'CBICA', 'TCIA10', 'TCIA01', '2013', 'TCIA08', 'TCIA04', 'TCIA12'}
	 TCIA05 : 4
	 TCIA13 : 13
	 TCIA06 : 8
	 TCIA09 : 11
	 TCIA02 : 34
	 TCIA03 : 12
	 CBICA : 88
	 TCIA10 : 35
	 TCIA01 : 22
	 2013 : 30
	 TCIA08 : 14
	 TCIA04 : 8
	 TCIA12 : 6
institution:  {'CBICA', 'TMC'}
	 CBICA : 41
	 TMC : 9
institution:  {'brats2020'}
	 brats2020 : 34


In [38]:
center_tcia_num = sum(map(lambda y: y[1], list(filter(lambda x: x[0].startswith('TCIA'), all_2018_info_stat.items()))))
# TCIA 167, 0, 0
# 2013  30, 0, 0
# CBICA 88,41, 0
# TMC    0, 9, 0
# 2020   0, 0,34

167


# 写入相关的配置文件

各个客户端分配的数据可以相同也可以不同，由于 `最小的客户端数量`< `允许接入的客户端数量`，如果有些客户端总是很慢，那么服务端聚合的总是那些提交早的客户端发送的模型

In [59]:
# 合并一些文件, 因为资源确实有限, 目前试一下只使用若干的机构, 这些机构同时含有测试(validation)和训练
import json
import numpy as np
np.random.seed(100)


def write_to_clara_info(client_id, df, train_validation_rate=0.85):
    path = os.path.join(PREFIX, 'config', 'fl_dataset_{}.json'.format(client_id))
    images_masks = []
    for gg_type, pid, sub_name in zip(df['hgg_or_lgg'], df['id'], df['sub_name']):
        if sub_name == 'MICCAI_BraTS2020_TrainingData':
            channel_filepath = [os.path.join(sub_name, pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
            mask_filepath = os.path.join(sub_name, pid, pid + mask_suffix)
        else:
            channel_filepath = [os.path.join(sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
            mask_filepath = os.path.join(sub_name, gg_type.upper(), pid, pid + mask_suffix)
        images_masks.append({'image': channel_filepath, 'label': mask_filepath})
    # 随机排序
    np.random.shuffle(images_masks)
    # 拆分数据
    n_trains = int(len(images_masks) * train_validation_rate)
    n_tests = len(images_masks) - n_trains
    print('Client ', client_id, ',num_train: ', n_trains, ', num_test: ', n_tests, ', write to ', path)
    # 
    target = dict(training=images_masks[:n_trains], validation=images_masks[n_trains:], description='BRATS')
    with open(path, 'w') as fp:
        json.dump(target, fp)

        
def merge_inst_or_center(inst1, inst2):
    df = pd.concat([inst1, inst2])
    return df
    
# 使用三张卡
# inst_merged = [
#         ['CBICA'], ['TCIA13', 'TCIA02', '2013', 'TCIA01', 'TCIA08'],
#         ['TCIA09', 'TCIA12', 'TCIA10', 'TCIA04', 'TCIA06', 'TCIA03', 'TCIA05']
#     ]
# 使用6张卡
# inst_merged = [
#         ['CBICA'], ['TCIA13', 'TCIA02', 'TCIA03'], ['2013'], ['TCIA01', 'TCIA08', 'TCIA06'],
#         ['TCIA09', 'TCIA12', 'TCIA05'], ['TCIA10', 'TCIA04']
#     ]
# 使用 7 张卡, 包含最新的 2020 的数据集
# 将三个 BRATS 数据表全部合并
all_brats_train_data = merge_inst_or_center(merge_inst_or_center(all_2018_info, brats_2019_part), brats_2020_part)
print(all_brats_train_data)

# 开始合并对应的机构
inst_merged = [
        ['CBICA'], ['TCIA13', 'TCIA02', 'TCIA03'], ['2013'], ['TCIA01', 'TCIA08', 'TCIA06'],
        ['TCIA09', 'TCIA12', 'TCIA05'], ['TCIA10', 'TCIA04'], ['TMC', 'brats2020'],
    ]

ms = []
for i, inss in enumerate(inst_merged):
    tag = all_brats_train_data[all_brats_train_data['institution'] == inss[0]]
    for j in range(1, len(inss)):
        tag = merge_inst_or_center(tag, all_brats_train_data[all_brats_train_data['institution'] == inss[j]])
    ms.append(tag)

for i, item in enumerate(ms):
    print_insitutions_info(item)
for i, item in enumerate(ms):
    write_to_clara_info(i, item)

                      id hgg_or_lgg institution       ins_pid  \
0   Brats18_TCIA03_257_1        hgg      TCIA03  TCIA03_257_1   
1   Brats18_TCIA03_296_1        hgg      TCIA03  TCIA03_296_1   
2    Brats18_CBICA_ANZ_1        hgg       CBICA   CBICA_ANZ_1   
3   Brats18_TCIA06_165_1        hgg      TCIA06  TCIA06_165_1   
4   Brats18_TCIA01_412_1        hgg      TCIA01  TCIA01_412_1   
..                   ...        ...         ...           ...   
29  BraTS20_Training_365        hgg   brats2020  Training_365   
30  BraTS20_Training_366        hgg   brats2020  Training_366   
31  BraTS20_Training_367        hgg   brats2020  Training_367   
32  BraTS20_Training_368        hgg   brats2020  Training_368   
33  BraTS20_Training_369        hgg   brats2020  Training_369   

                           sub_name  
0   MICCAI_BraTS_2018_Data_Training  
1   MICCAI_BraTS_2018_Data_Training  
2   MICCAI_BraTS_2018_Data_Training  
3   MICCAI_BraTS_2018_Data_Training  
4   MICCAI_BraTS_2018_Data_Tr

# 以下的代码作为用于测试整合所有的数据, 用于中心化的联邦学习

In [60]:

target = dict()

def write_to_clara_info_for_center_learning(df, train_validation_rate=0.85):
    path = os.path.join(PREFIX, 'config', 'brats_2018_2019_2020_train_validation.json')
    images_masks = []
    for gg_type, pid, sub_name in zip(df['hgg_or_lgg'], df['id'], df['sub_name']):
        if sub_name == 'MICCAI_BraTS2020_TrainingData':
            channel_filepath = [os.path.join(sub_name, pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
            mask_filepath = os.path.join(sub_name, pid, pid + mask_suffix)
        else:
            channel_filepath = [os.path.join(sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
            mask_filepath = os.path.join(sub_name, gg_type.upper(), pid, pid + mask_suffix)
        images_masks.append({'image': channel_filepath, 'label': mask_filepath})
    # 随机排序
    np.random.shuffle(images_masks)
    # 拆分数据
    n_trains = int(len(images_masks) * train_validation_rate)
    n_tests = len(images_masks) - n_trains
    print('num_train: ', n_trains, ', num_test: ', n_tests, ', write to ', path)
    # 
    target = dict(training=images_masks[:n_trains], validation=images_masks[n_trains:], description='BRATS')
    with open(path, 'w') as fp:
        json.dump(target, fp)

write_to_clara_info_for_center_learning(all_brats_train_data)

num_train:  313 , num_test:  56 , write to  brats_seg/config/brats_2018_2019_2020_train_validation.json


# 用于提交结果的验证集数据

In [63]:
def get_data_info_2020(data_prefix, channels, ):
    # 2020 需要通过数据映射。。。
    df = pd.read_csv(data_prefix + '/name_mapping_validation_data.csv')
    data_info = pd.DataFrame(columns=['id', 'institution'])
    # 遍历对应的目录
    for _, one_row in df.iterrows():
        sub_id = one_row['BraTS_2020_subject_ID']
        print(sub_id)
        brats2019_id = one_row['BraTS_2019_subject_ID']
        items = brats2019_id.split('_')
        data_info = data_info.append({'id': sub_id, 'institution': items[1]}, ignore_index=True)
    return data_info
all_brats2020_validation_info = get_data_info_2020(BRATS_PREFIX + '/MICCAI_BraTS2020_ValidationData', channels)
print(all_brats2020_validation_info)

BraTS20_Validation_001
BraTS20_Validation_002
BraTS20_Validation_003
BraTS20_Validation_004
BraTS20_Validation_005
BraTS20_Validation_006
BraTS20_Validation_007
BraTS20_Validation_008
BraTS20_Validation_009
BraTS20_Validation_010
BraTS20_Validation_011
BraTS20_Validation_012
BraTS20_Validation_013
BraTS20_Validation_014
BraTS20_Validation_015
BraTS20_Validation_016
BraTS20_Validation_017
BraTS20_Validation_018
BraTS20_Validation_019
BraTS20_Validation_020
BraTS20_Validation_021
BraTS20_Validation_022
BraTS20_Validation_023
BraTS20_Validation_024
BraTS20_Validation_025
BraTS20_Validation_026
BraTS20_Validation_027
BraTS20_Validation_028
BraTS20_Validation_029
BraTS20_Validation_030
BraTS20_Validation_031
BraTS20_Validation_032
BraTS20_Validation_033
BraTS20_Validation_034
BraTS20_Validation_035
BraTS20_Validation_036
BraTS20_Validation_037
BraTS20_Validation_038
BraTS20_Validation_039
BraTS20_Validation_040
BraTS20_Validation_041
BraTS20_Validation_042
BraTS20_Validation_043
BraTS20_Val

In [62]:
#用于提交给 MICAAI 测试
target = dict()
env_json = PREFIX + '/config/brats_2020_validation_submit.json'
sub_name = 'MICCAI_BraTS2020_ValidationData'
# training 部分
images_masks = []
for _, item in all_brats2020_validation_info.iterrows():
    pid = item['id']
    channel_filepath = [os.path.join(sub_name, pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
    mask_filepath = os.path.join(sub_name, pid, pid + mask_suffix)
    images_masks.append({'image': channel_filepath, 'label': mask_filepath})
target['test'] = images_masks


# 吸入
with open(env_json, 'w') as fp:
    json.dump(target, fp)

# 这个地方处理 CLARA 输出的 NII
因为Clara将预测出来的标签是分开的, MICCAI 则是一个 Seg办好了三个标签
BRATS 需要注意的就是 Origin 的位置, BRATS 的输入数据的 Origin 的位置为 (0,-239,0)

In [66]:
# 用于联邦学习的实验
mmar_root='brats_seg_fl'
exp_name='brats2018_2019_2020'
fl_round = 199

src = mmar_root + '/eval/' + exp_name + '/result_on_round_' + str(fl_round)
target = mmar_root + '/submit/' + exp_name

In [68]:
# 用于普通训练
mmar_root='brats_seg'

src = mmar_root + '/eval'
target = mmar_root + '/submit'

In [69]:
import SimpleITK as sitk
import os
import numpy as np


if not os.path.exists(target):
    os.makedirs(target)
        
def load_nii(fs):
    itk_img = sitk.ReadImage(fs)
    img = sitk.GetArrayFromImage(itk_img)
    return img

for _, item in all_brats2020_validation_info.iterrows():
    pid = item['id']
    labels_prefix = '/'.join((src, pid + '_t1ce'))
    # labels_prefix = '/'.join(('clara_seg_ct_brats/eval_2020_validation', pid + '_t1ce'))
    labels_fs = ['/'.join((labels_prefix, pid + '_t1ce_' + x + '.nii.gz')) for x in ['ET', 'TC', 'WT']]
    
    et_mask, tc_mask, wt_mask = [load_nii(x) for x in labels_fs]
    """
    mask中的标签
    0: 背景
    1: edema(ED)
    2: non-enhancing tumor(NET)
    3: enhancing tumor(ET)
    网络预测的结果
    WT = ED(2) + ET(4) + NCR/NET(1)
    TC = ET+ NET
    ET = ET
    
    ET在 NII 中标签为 4, NET
    """
    n_slice, h, w = et_mask.shape
    print('Process: ', pid, ', shape:', et_mask.shape)
    merged = np.zeros((n_slice, h, w), dtype=np.uint8)
    
    merged[wt_mask == 1] = 2  # WT 可能是 ED
    merged[tc_mask == 1] = 1  # TC 也可能是 NET
    merged[et_mask == 1] = 4  # ET 至少是 ET

    out = sitk.GetImageFromArray(merged)
    out.SetOrigin((0, -239, 0))
    img_fp = target + '/' + pid + '.nii.gz'
    sitk.WriteImage(out, img_fp)
    

Process:  BraTS20_Validation_001 , shape: (155, 240, 240)
Process:  BraTS20_Validation_002 , shape: (155, 240, 240)
Process:  BraTS20_Validation_003 , shape: (155, 240, 240)
Process:  BraTS20_Validation_004 , shape: (155, 240, 240)
Process:  BraTS20_Validation_005 , shape: (155, 240, 240)
Process:  BraTS20_Validation_006 , shape: (155, 240, 240)
Process:  BraTS20_Validation_007 , shape: (155, 240, 240)
Process:  BraTS20_Validation_008 , shape: (155, 240, 240)
Process:  BraTS20_Validation_009 , shape: (155, 240, 240)
Process:  BraTS20_Validation_010 , shape: (155, 240, 240)
Process:  BraTS20_Validation_011 , shape: (155, 240, 240)
Process:  BraTS20_Validation_012 , shape: (155, 240, 240)
Process:  BraTS20_Validation_013 , shape: (155, 240, 240)
Process:  BraTS20_Validation_014 , shape: (155, 240, 240)
Process:  BraTS20_Validation_015 , shape: (155, 240, 240)
Process:  BraTS20_Validation_016 , shape: (155, 240, 240)
Process:  BraTS20_Validation_017 , shape: (155, 240, 240)
Process:  BraT

In [None]:
all_brats2020_validation_info