In [1]:
import pandas as pd
import glob
import os


# 数据划分
每个客户端拥有若干的机构的联合的数据(因为GPU资源不够), train和validation来源相同, 构建的联合模型则是在BRATS2018和BRATS2019的差集上验证

In [11]:
# 定义相关的参数
PREFIX = 'clara_seg_ct_brats_fl'
BRATS_PREFIX = '/home/liuyuan/shu_codes/datasets/brats'

In [12]:
def get_data_info(data_prefix, channels, ):
    data_info = pd.DataFrame(columns=['id', 'hgg_or_lgg', 'institution', 'ins_pid'])
    # 遍历对应的目录
    hgg_dirs = glob.glob(os.path.join(data_prefix, 'HGG', '**'))
    lgg_dirs = glob.glob(os.path.join(data_prefix, 'LGG', '**'))
    print(f'HGG: {len(hgg_dirs)}, LGG: {len(lgg_dirs)}')
    for gg_type, gg in (('hgg', hgg_dirs), ('lgg', lgg_dirs)):
        names = [os.path.basename(x) for x in gg]
        for name in names:
            items = name.split('_')
            data_info = data_info.append({'id': name, 'hgg_or_lgg': gg_type, 'institution': items[1], 'ins_pid': '_'.join(items[1:])}, ignore_index=True)
    return data_info
        

channels = ['t1ce', 't1', 't2', 'flair']
channel_suffix = ['_t1ce.nii.gz', '_t1.nii.gz', '_t2.nii.gz', '_flair.nii.gz']
mask_suffix = '_seg.nii.gz'
all_2018_info = get_data_info(BRATS_PREFIX + '/MICCAI_BraTS_2018_Data_Training', channels)
all_2019_info = get_data_info(BRATS_PREFIX + '/MICCAI_BraTS_2019_Data_Training', channels)
print(all_2018_info)
print(all_2019_info)

HGG: 210, LGG: 75
HGG: 259, LGG: 76
                       id hgg_or_lgg institution       ins_pid
0    Brats18_TCIA03_257_1        hgg      TCIA03  TCIA03_257_1
1    Brats18_TCIA03_296_1        hgg      TCIA03  TCIA03_296_1
2     Brats18_CBICA_ANZ_1        hgg       CBICA   CBICA_ANZ_1
3    Brats18_TCIA06_165_1        hgg      TCIA06  TCIA06_165_1
4    Brats18_TCIA01_412_1        hgg      TCIA01  TCIA01_412_1
..                    ...        ...         ...           ...
280  Brats18_TCIA09_462_1        lgg      TCIA09  TCIA09_462_1
281  Brats18_TCIA12_101_1        lgg      TCIA12  TCIA12_101_1
282  Brats18_TCIA13_653_1        lgg      TCIA13  TCIA13_653_1
283  Brats18_TCIA10_299_1        lgg      TCIA10  TCIA10_299_1
284  Brats18_TCIA09_451_1        lgg      TCIA09  TCIA09_451_1

[285 rows x 4 columns]
                       id hgg_or_lgg institution       ins_pid
0     BraTS19_CBICA_BGX_1        hgg       CBICA   CBICA_BGX_1
1     BraTS19_CBICA_ALN_1        hgg       CBICA   CBICA_A

In [13]:
# 以 brats2018 作为训练数据, 2019作为测试数据, 需要计算二者的差值
def get_different(df1, df2):
    # apply 第二参数, 1为column, 0 为 index. https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.apply.html
    c1 = df1[['ins_pid']].apply(tuple,1)
    c2 = df2[['ins_pid']].apply(tuple,1)
    return df2[~c2.isin(c1)]

brats_2019_part = get_different(all_2018_info, all_2019_info)
print(brats_2019_part)

                      id hgg_or_lgg institution      ins_pid
0    BraTS19_CBICA_BGX_1        hgg       CBICA  CBICA_BGX_1
3    BraTS19_CBICA_AYG_1        hgg       CBICA  CBICA_AYG_1
13   BraTS19_CBICA_BCF_1        hgg       CBICA  CBICA_BCF_1
15   BraTS19_CBICA_AVB_1        hgg       CBICA  CBICA_AVB_1
16   BraTS19_CBICA_BGW_1        hgg       CBICA  CBICA_BGW_1
17   BraTS19_CBICA_BJY_1        hgg       CBICA  CBICA_BJY_1
18   BraTS19_CBICA_AOC_1        hgg       CBICA  CBICA_AOC_1
23   BraTS19_CBICA_BGG_1        hgg       CBICA  CBICA_BGG_1
29   BraTS19_CBICA_BGT_1        hgg       CBICA  CBICA_BGT_1
32   BraTS19_TMC_11964_1        hgg         TMC  TMC_11964_1
36   BraTS19_TMC_12866_1        hgg         TMC  TMC_12866_1
38   BraTS19_TMC_21360_1        hgg         TMC  TMC_21360_1
48   BraTS19_CBICA_BGE_1        hgg       CBICA  CBICA_BGE_1
54   BraTS19_CBICA_ATN_1        hgg       CBICA  CBICA_ATN_1
56   BraTS19_TMC_27374_1        hgg         TMC  TMC_27374_1
59   BraTS19_CBICA_BHZ_1

In [14]:
# 输出一些统计信息
def print_insitutions_info(df):
    ds = dict()
    inst = set(df['institution'].to_list())
    print('institution: ', inst)
    for ins in inst:
        ds[ins] = len(df[df['institution'] == ins])
    for k, v in ds.items():
        print('\t', k, ':', v)

print_insitutions_info(all_2018_info)
print_insitutions_info(brats_2019_part)

institution:  {'TCIA04', 'TCIA02', 'TCIA06', 'TCIA05', 'TCIA12', 'TCIA01', 'TCIA09', 'TCIA10', 'CBICA', 'TCIA03', 'TCIA08', 'TCIA13', '2013'}
	 TCIA04 : 8
	 TCIA02 : 34
	 TCIA06 : 8
	 TCIA05 : 4
	 TCIA12 : 6
	 TCIA01 : 22
	 TCIA09 : 11
	 TCIA10 : 35
	 CBICA : 88
	 TCIA03 : 12
	 TCIA08 : 14
	 TCIA13 : 13
	 2013 : 30
institution:  {'CBICA', 'TMC'}
	 CBICA : 41
	 TMC : 9


In [16]:
# 合并一些文件, 因为资源确实有限, 目前试一下只使用若干的机构, 这些机构同时含有测试(validation)和训练
import json
import numpy as np
np.random.seed(100)


def write_to_clara_info(client_id, df, sub_name, train_validation_rate=0.85):
    path = os.path.join(PREFIX, 'config', 'fl_dataset_{}.json'.format(client_id))
    images_masks = []
    for gg_type, pid in zip(df['hgg_or_lgg'], df['id']):
        channel_filepath = [os.path.join(sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
        mask_filepath = os.path.join(sub_name, gg_type.upper(), pid, pid + mask_suffix)
        images_masks.append({'image': channel_filepath, 'label': mask_filepath})
    # 随机排序
    np.random.shuffle(images_masks)
    # 拆分数据
    n_trains = int(len(images_masks) * train_validation_rate)
    n_tests = len(images_masks) - n_trains
    print('Client ', client_id, ',num_train: ', n_trains, ', num_test: ', n_tests, ', write to ', path)
    # 
    target = dict(training=images_masks[:n_trains], validation=images_masks[n_trains:], description='BRATS')
    with open(path, 'w') as fp:
        json.dump(target, fp)

        
def merge_inst(inst1, inst2):
    df = pd.concat([inst1, inst2])
    return df

# 使用三张卡
# inst_merged = [
#         ['CBICA'], ['TCIA13', 'TCIA02', '2013', 'TCIA01', 'TCIA08'],
#         ['TCIA09', 'TCIA12', 'TCIA10', 'TCIA04', 'TCIA06', 'TCIA03', 'TCIA05']
#     ]
# 使用6张卡
inst_merged = [
        ['CBICA'], ['TCIA13', 'TCIA02', 'TCIA03'], ['2013'], ['TCIA01', 'TCIA08', 'TCIA06'],
        ['TCIA09', 'TCIA12', 'TCIA05'], ['TCIA10', 'TCIA04']
    ]
ms = []
for i, inss in enumerate(inst_merged):
    tag = all_2018_info[all_2018_info['institution'] == inss[0]]
    for j in range(1, len(inss)):
        tag = merge_inst(tag, all_2018_info[all_2018_info['institution'] == inss[j]])
    ms.append(tag)
for i, item in enumerate(ms):
    write_to_clara_info(i, item, 'MICCAI_BraTS_2018_Data_Training')

Client  0 ,num_train:  74 , num_test:  14 , write to  clara_seg_ct_brats_fl/config/fl_dataset_0.json
Client  1 ,num_train:  50 , num_test:  9 , write to  clara_seg_ct_brats_fl/config/fl_dataset_1.json
Client  2 ,num_train:  25 , num_test:  5 , write to  clara_seg_ct_brats_fl/config/fl_dataset_2.json
Client  3 ,num_train:  37 , num_test:  7 , write to  clara_seg_ct_brats_fl/config/fl_dataset_3.json
Client  4 ,num_train:  17 , num_test:  4 , write to  clara_seg_ct_brats_fl/config/fl_dataset_4.json
Client  5 ,num_train:  36 , num_test:  7 , write to  clara_seg_ct_brats_fl/config/fl_dataset_5.json


In [77]:
# Option 1, 写入 environment.json 格式的数据, 写入对应的 test! 似乎是以 . 开头的
target = dict()
env_json = os.path.join(BRATS_PREFIX, 'datalist.json')
sub_name = 'MICCAI_BraTS_2018_Data_Training'
# training 部分
images_masks = []
for gg_type, pid in zip(all_2018_info['hgg_or_lgg'], all_2018_info['id']):
    channel_filepath = [os.path.join('.', sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
    mask_filepath = os.path.join('.', sub_name, gg_type.upper(), pid, pid + mask_suffix)
    images_masks.append({'image': channel_filepath, 'label': mask_filepath})
target['training'] = images_masks
# testing 部分
images_masks = []
sub_name = 'MICCAI_BraTS_2019_Data_Training'
for gg_type, pid in zip(brats_2019_part['hgg_or_lgg'], brats_2019_part['id']):
    channel_filepath = [os.path.join('.', sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
    mask_filepath = os.path.join('.', sub_name, gg_type.upper(), pid, pid + mask_suffix)
    images_masks.append({'image': channel_filepath, 'label': mask_filepath})
target['test'] = images_masks
# 吸入
with open(env_json, 'w') as fp:
    json.dump(target, fp)

In [9]:
# Option 2, 写入json文件到程序的目录
target = dict()
env_json = 'clara_seg_ct_brats_fl/config/2018train_2019test.json'
sub_name = 'MICCAI_BraTS_2018_Data_Training'
# training 部分
images_masks = []
for gg_type, pid in zip(all_2018_info['hgg_or_lgg'], all_2018_info['id']):
    channel_filepath = [os.path.join(sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
    mask_filepath = os.path.join(sub_name, gg_type.upper(), pid, pid + mask_suffix)
    images_masks.append({'image': channel_filepath, 'label': mask_filepath})
target['training'] = images_masks
# testing 部分
images_masks = []
sub_name = 'MICCAI_BraTS_2019_Data_Training'
for gg_type, pid in zip(brats_2019_part['hgg_or_lgg'], brats_2019_part['id']):
    channel_filepath = [os.path.join(sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
    mask_filepath = os.path.join(sub_name, gg_type.upper(), pid, pid + mask_suffix)
    images_masks.append({'image': channel_filepath, 'label': mask_filepath})
target['test'] = images_masks
# 吸入
with open(env_json, 'w') as fp:
    json.dump(target, fp)

In [8]:
# 这里是写入普通的集中式学习方式
config_filepath = 'clara_seg_ct_brats/config/2018train_2019test.json'
target = dict()
sub_name = 'MICCAI_BraTS_2018_Data_Training'
# training 部分
images_masks = []
for gg_type, pid in zip(all_2018_info['hgg_or_lgg'], all_2018_info['id']):
    channel_filepath = [os.path.join(sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
    mask_filepath = os.path.join(sub_name, gg_type.upper(), pid, pid + mask_suffix)
    images_masks.append({'image': channel_filepath, 'label': mask_filepath})
target['training'] = images_masks
# testing 部分
images_masks = []
sub_name = 'MICCAI_BraTS_2019_Data_Training'
for gg_type, pid in zip(brats_2019_part['hgg_or_lgg'], brats_2019_part['id']):
    channel_filepath = [os.path.join(sub_name, gg_type.upper(), pid, pid + channel_suffix[i]) for i, c in enumerate(channels)]
    mask_filepath = os.path.join(sub_name, gg_type.upper(), pid, pid + mask_suffix)
    images_masks.append({'image': channel_filepath, 'label': mask_filepath})
target['validation'] = images_masks
target['test'] = []
# 吸入
with open(config_filepath, 'w') as fp:
    json.dump(target, fp)
