In [1]:
import os
import imp
import time
import utils
import torch
import pickle
import datetime
import numpy as np
from tqdm import tqdm
from utils.fcd import *
from utils.data_processor import *
from utils.crop import BrainDataSegCrop
from sklearn.model_selection import KFold

import warnings
warnings.filterwarnings("ignore")

In [2]:
"""
There 3 types of model training:
- whole: on the whole brain 
- temple: on the temple part of the brain
- nottemple: on the whole brain without temple part 
"""
BRAIN_TYPE = 'nottemple' # 'whole', 'temple', 'nottemple'

"""
There 2 types of Local Agragation Operators:
- pospoolxyz
- pointwisemlp
"""
CFG = 'cfgs/brain/brain_pospoolxyz.yaml' # 'cfgs/brain/brain_pointwisemlp.yaml'

"""
There 3 GPUs:
- 0
- 1
- 2
"""
DEVICE = 0
torch.cuda.set_device(DEVICE)
print(f'Current device: {torch.cuda.current_device()}')

FOLDED = True

REPEAT = 1
CROP_SIZE = 64
STEP_SIZE = 32
EPOCHS = 400

IS_RETURN_PC_WITHOUT_AIR_POINTS = False

EXP_NAME = "2021-11-04_16bs_abscoords_air_cropsize64_epochs400_nottemple"
IS_RETURN_ABS_COORDS = "abs" in EXP_NAME


#For loss
LOSS_TYPE = 'BCE'

Current device: 0


In [3]:
config = config_seting(CFG)

with open(f"predictions/KFold/{EXP_NAME}.pkl", 'rb') as f:
    kf = pickle.load(f)

In [4]:
allowed_subjects = np.load('raw_data/sub_with_all_data.npy', allow_pickle=True).tolist()

if not allowed_subjects:
    allowed_subdirs = [f"sub-{name[:-4]}" for name in os.listdir(f"raw_data/normalized_label")]
else:
    allowed_subdirs = [f'sub-{name}' for name in allowed_subjects]
    
brains = [[f"raw_data/output/fmriprep/{subdir}/anat/{name}" for name in os.listdir(f"raw_data/output/fmriprep/{subdir}/anat") if name.endswith("Asym_desc-preproc_T1w.nii.gz")] for subdir in os.listdir(f"raw_data/output/fmriprep/") if subdir.count(".") == 0 and subdir.startswith("sub") and subdir in allowed_subdirs]
brains = [item for sublist in brains for item in sublist]
brain_nums = [x.split('/')[3][4:] for x in brains]

if BRAIN_TYPE == 'whole':
    BRAIN_TYPE = 'full'
    
brains_loaded = [f"dataset_ready_to_use/data_only_usefull_areas/{BRAIN_TYPE}/brains/{num}.npy" for num in tqdm(brain_nums)]
labels_loaded = [f"dataset_ready_to_use/data_only_usefull_areas/{BRAIN_TYPE}/labels/{num}.npy" for num in tqdm(brain_nums)]
curv_loaded = [f"dataset_ready_to_use/data_only_usefull_areas/{BRAIN_TYPE}/curv/{num}.npy" for num in tqdm(brain_nums)]
thickness_loaded = [f"dataset_ready_to_use/data_only_usefull_areas/{BRAIN_TYPE}/thickness/{num}.npy" for num in tqdm(brain_nums)]
sulc_loaded = [f"dataset_ready_to_use/data_only_usefull_areas/{BRAIN_TYPE}/sulc/{num}.npy" for num in tqdm(brain_nums)]

100%|██████████| 81/81 [00:00<00:00, 179471.01it/s]
100%|██████████| 81/81 [00:00<00:00, 246187.41it/s]
100%|██████████| 81/81 [00:00<00:00, 229553.12it/s]
100%|██████████| 81/81 [00:00<00:00, 295682.01it/s]
100%|██████████| 81/81 [00:00<00:00, 402057.54it/s]


In [5]:
!mkdir predictions/{EXP_NAME}

In [6]:
def model_and_brains_to_points_with_predictions(config,
                                                kf,
                                                num_votes,
                                                repeat,
                                                brains_loaded,
                                                labels_loaded,
                                                curv_loaded,
                                                thickness_loaded,
                                                sulc_loaded,
                                                exp_name
                                               ):
    """
    """
    
    res = {}
    for e, (train_idxs, test_idxs) in tqdm(enumerate(kf.split(brains_loaded))): 
        
        test_dicts = [{'brains': [brains_loaded[idx]],
                       'labels': [labels_loaded[idx]],
#                        'curvs': [curv_loaded[idx]],
#                        'thickness': [thickness_loaded[idx]],
#                        'sulc': [sulc_loaded[idx]]
                      } for idx in test_idxs]
        
        model, criterion = build_multi_part_segmentation(config=config, weights=None, type=LOSS_TYPE)
        model.load_state_dict(torch.load(f'checkpoints/{exp_name}_{e + 1}_fold.pth'))
        model.cuda()
        model.eval()
        criterion.cuda()
        
        with torch.no_grad():
            
            transform_for_vote = d_utils.BatchPointcloudScaleAndJitter(scale_low=config.scale_low,
                                                                       scale_high=config.scale_high,
                                                                       std=config.noise_std,
                                                                       clip=config.noise_clip
                                                                      )
            
            softmax = torch.nn.Softmax(dim=0)
            
            data_loader = [BrainDataSegCrop(num_points=4096,
                                            task='test',
                                            data_dict = test_dict,
                                            step_size = STEP_SIZE,
                                            crop_size = CROP_SIZE,
                                            return_center = True,
                                            return_air_mask_test = True,
                                            is_folded = FOLDED,
                                            return_pc_without_air_points = IS_RETURN_PC_WITHOUT_AIR_POINTS, 
                                            return_abs_coords = IS_RETURN_ABS_COORDS,
                                            MEANS = {'brains': 161.2455},
                                            STDS = {'brains': 167.36417}
                                           ) for test_dict in test_dicts]
            
            for i, test_brain in tqdm(enumerate(test_idxs)):
                print(f"Working on brain {test_brain}")
                pred_soft_flats = []
                air_masks_flats = []
                points_labels_flats = []
                center_coords_flats = []
                points_orig_flats = []
                points_masks_flats = []#NNN
                all_logits = []
                all_points_labels = []
                all_shape_labels = []
                all_masks = []
                print(len(data_loader[i]))
                for k, crop in enumerate(data_loader[i]):
                    for _ in range(repeat):
                        points_orig, mask, points_labels, shape_labels, air_mask = [
                            crop[key] for key in [
                                "current_points", "mask", "current_points_labels", "label",'air_mask'
                            ]
                        ]
                        vote_logits = None
                        vote_points_labels = None
                        vote_shape_labels = None
                        vote_masks = None
                        points_orig = points_orig.unsqueeze(0)
                        mask = mask.unsqueeze(0)
                        points_labels = points_labels.unsqueeze(0)
                        shape_labels = shape_labels.unsqueeze(0)
                        preds = []
                        for v in range(num_votes):
                            batch_logits = []
                            batch_points_labels = []
                            batch_shape_labels = []
                            batch_masks = []
                            if v > 0:
                                points = transform_for_vote(points_orig)
                            else:
                                points = points_orig
                            # forward
                            features = points
                            features = features.transpose(1, 2).contiguous()
                            points = points[:, :, :3].cuda(non_blocking=True)
                            mask = mask.cuda(non_blocking=True)
                            features = features.cuda(non_blocking=True)
                            points_labels = points_labels.cuda(non_blocking=True)
                            shape_labels = shape_labels.cuda(non_blocking=True)
                            pred = model(points, mask, features)
                            preds.append(pred[0])
                        preds = torch.cat(preds).mean(dim=0)
                        
                        points_orig = points_orig.squeeze(0)                        
                        pred_soft_flats += softmax(preds)[1,:].reshape(-1).detach().cpu().numpy().tolist()
                        points_labels_flats += points_labels.reshape(-1).detach().cpu().numpy().tolist()
                        air_masks_flats += air_mask.reshape(-1).tolist()
                        if IS_RETURN_ABS_COORDS:
                            size_abs = (241, 336, 283)
                            means = np.array([x // 2 for x in size_abs])
                            half_range = np.array([x // 2 for x in size_abs])
                            points_orig_flats += (points_orig[:, :3].detach().cpu().numpy() * half_range + means).tolist()
                            
                        else:
                            points_orig_flats += (((points_orig[:, :3]) * (CROP_SIZE//2) + (CROP_SIZE//2)).detach().cpu().numpy() + np.array(center_coords)).tolist()

                np.save(f"predictions/{exp_name}/brain{test_brain}", (points_orig_flats, pred_soft_flats, points_labels_flats,air_masks_flats), allow_pickle=True)
                print(k)
                print(f"brain {test_brain} done")
                res[test_brain] = k
            del data_loader


In [7]:
model_and_brains_to_points_with_predictions(config=config,
                                                kf=kf,
                                                num_votes=1,
                                                repeat=1,
                                                brains_loaded=brains_loaded,
                                                labels_loaded=labels_loaded,
                                                curv_loaded=curv_loaded,
                                                thickness_loaded=thickness_loaded,
                                                sulc_loaded = sulc_loaded,
                                                exp_name = EXP_NAME)

0it [00:00, ?it/s]
0it [00:00, ?it/s][A

Working on brain 0
49



1it [00:11, 11.29s/it][A

48
brain 0 done
Working on brain 4
49



2it [00:22, 11.24s/it][A

48
brain 4 done
Working on brain 10
49



3it [00:33, 11.22s/it][A

48
brain 10 done
Working on brain 12
49



4it [00:44, 11.22s/it][A

48
brain 12 done
Working on brain 18
45



5it [00:55, 11.02s/it][A

44
brain 18 done
Working on brain 22
49



6it [01:06, 11.06s/it][A

48
brain 22 done
Working on brain 28
48



7it [01:17, 11.05s/it][A

47
brain 28 done
Working on brain 30
49



8it [01:28, 11.08s/it][A

48
brain 30 done
Working on brain 31
48



9it [01:39, 11.08s/it][A

47
brain 31 done
Working on brain 33
49



10it [01:50, 11.10s/it][A

48
brain 33 done
Working on brain 35
46



11it [02:01, 11.06s/it][A

45
brain 35 done
Working on brain 45
49



12it [02:12, 11.08s/it][A

48
brain 45 done
Working on brain 49
49



13it [02:24, 11.10s/it][A

48
brain 49 done
Working on brain 67
48



14it [02:35, 11.10s/it][A

47
brain 67 done
Working on brain 68
48



15it [02:46, 11.10s/it][A

47
brain 68 done
Working on brain 70
49



16it [02:57, 11.12s/it][A

48
brain 70 done
Working on brain 73
48



17it [03:09, 11.12s/it][A
1it [04:08, 248.53s/it]

47
brain 73 done



0it [00:00, ?it/s][A

Working on brain 5
49



1it [00:11, 11.45s/it][A

48
brain 5 done
Working on brain 7
49



2it [00:22, 11.40s/it][A

48
brain 7 done
Working on brain 9
49



3it [00:34, 11.39s/it][A

48
brain 9 done
Working on brain 16
49



4it [00:45, 11.38s/it][A

48
brain 16 done
Working on brain 34
49



5it [00:56, 11.39s/it][A

48
brain 34 done
Working on brain 39
49



6it [01:08, 11.38s/it][A

48
brain 39 done
Working on brain 40
49



7it [01:19, 11.38s/it][A

48
brain 40 done
Working on brain 42
48



8it [01:30, 11.35s/it][A

47
brain 42 done
Working on brain 47
48



9it [01:41, 11.32s/it][A

47
brain 47 done
Working on brain 54
47



10it [01:52, 11.28s/it][A

46
brain 54 done
Working on brain 55
49



11it [02:04, 11.29s/it][A

48
brain 55 done
Working on brain 56
49



12it [02:15, 11.29s/it][A

48
brain 56 done
Working on brain 61
40



13it [02:24, 11.14s/it][A

39
brain 61 done
Working on brain 62
49



14it [02:36, 11.15s/it][A

48
brain 62 done
Working on brain 64
49



15it [02:47, 11.17s/it][A

48
brain 64 done
Working on brain 80
49



16it [02:58, 11.18s/it][A
2it [07:43, 231.56s/it]

48
brain 80 done



0it [00:00, ?it/s][A

Working on brain 3
48



1it [00:11, 11.24s/it][A

47
brain 3 done
Working on brain 6
49



2it [00:22, 11.30s/it][A

48
brain 6 done
Working on brain 8
48



3it [00:33, 11.25s/it][A

47
brain 8 done
Working on brain 13
48



4it [00:44, 11.22s/it][A

47
brain 13 done
Working on brain 17
47



5it [00:55, 11.15s/it][A

46
brain 17 done
Working on brain 19
47



6it [01:06, 11.10s/it][A

46
brain 19 done
Working on brain 25
48



7it [01:17, 11.11s/it][A

47
brain 25 done
Working on brain 36
48



8it [01:28, 11.11s/it][A

47
brain 36 done
Working on brain 38
44



9it [01:39, 11.01s/it][A

43
brain 38 done
Working on brain 44
48



10it [01:50, 11.02s/it][A

47
brain 44 done
Working on brain 50
49



11it [02:01, 11.06s/it][A

48
brain 50 done
Working on brain 53
49



12it [02:12, 11.08s/it][A

48
brain 53 done
Working on brain 65
49



13it [02:24, 11.10s/it][A

48
brain 65 done
Working on brain 66
48



14it [02:35, 11.10s/it][A

47
brain 66 done
Working on brain 72
48



15it [02:46, 11.11s/it][A

47
brain 72 done
Working on brain 77
49



16it [02:57, 11.12s/it][A
3it [11:12, 224.33s/it]

48
brain 77 done



0it [00:00, ?it/s][A

Working on brain 11
48



1it [00:11, 11.14s/it][A

47
brain 11 done
Working on brain 15
49



2it [00:22, 11.27s/it][A

48
brain 15 done
Working on brain 24
49



3it [00:33, 11.30s/it][A

48
brain 24 done
Working on brain 26
48



4it [00:45, 11.25s/it][A

47
brain 26 done
Working on brain 27
48



5it [00:56, 11.23s/it][A

47
brain 27 done
Working on brain 32
48



6it [01:07, 11.21s/it][A

47
brain 32 done
Working on brain 41
42



7it [01:16, 11.00s/it][A

41
brain 41 done
Working on brain 43
47



8it [01:27, 10.98s/it][A

46
brain 43 done
Working on brain 46
49



9it [01:39, 11.03s/it][A

48
brain 46 done
Working on brain 48
48



10it [01:50, 11.03s/it][A

47
brain 48 done
Working on brain 57
49



11it [02:01, 11.07s/it][A

48
brain 57 done
Working on brain 58
49



12it [02:13, 11.09s/it][A

48
brain 58 done
Working on brain 59
49



13it [02:24, 11.11s/it][A

48
brain 59 done
Working on brain 76
39



14it [02:33, 10.96s/it][A

38
brain 76 done
Working on brain 78
42



15it [02:43, 10.88s/it][A

41
brain 78 done
Working on brain 79
49



16it [02:54, 10.91s/it][A
4it [14:40, 220.03s/it]

48
brain 79 done



0it [00:00, ?it/s][A

Working on brain 1
49



1it [00:11, 11.42s/it][A

48
brain 1 done
Working on brain 2
49



2it [00:22, 11.39s/it][A

48
brain 2 done
Working on brain 14
48



3it [00:33, 11.32s/it][A

47
brain 14 done
Working on brain 20
49



4it [00:45, 11.33s/it][A

48
brain 20 done
Working on brain 21
48



5it [00:56, 11.29s/it][A

47
brain 21 done
Working on brain 23
49



6it [01:07, 11.30s/it][A

48
brain 23 done
Working on brain 29
49



7it [01:19, 11.31s/it][A

48
brain 29 done
Working on brain 37
44



8it [01:29, 11.16s/it][A

43
brain 37 done
Working on brain 51
45



9it [01:39, 11.08s/it][A

44
brain 51 done
Working on brain 52
49



10it [01:51, 11.11s/it][A

48
brain 52 done
Working on brain 60
49



11it [02:02, 11.14s/it][A

48
brain 60 done
Working on brain 63
49



12it [02:13, 11.15s/it][A

48
brain 63 done
Working on brain 69
49



13it [02:25, 11.17s/it][A

48
brain 69 done
Working on brain 71
49



14it [02:36, 11.18s/it][A

48
brain 71 done
Working on brain 74
48



15it [02:47, 11.17s/it][A

47
brain 74 done
Working on brain 75
49



16it [02:59, 11.19s/it][A
5it [18:11, 218.35s/it]

48
brain 75 done





In [10]:
!nvidia-smi

Fri Oct 29 08:18:55 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64       Driver Version: 440.64       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 00000000:01:00.0 Off |                  N/A |
| 20%   33C    P8    16W / 250W |     10MiB / 11177MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  Off  | 00000000:02:00.0 Off |                  N/A |
| 20%   31C    P2    54W / 250W |  11174MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------

In [10]:
import numpy as np
from tqdm import tqdm

In [15]:
if BRAIN_TYPE == 'full':
    res_dict = {}
    for file in glob.glob(f'predictions/{EXP_NAME}/*'):
        a = np.load(file, allow_pickle=True)
        brain = int(file.split('/')[-1].split('brain')[-1].split('.npy')[0])
        res_dict[brain] = a
elif BRAIN_TYPE == 'temple_nottemple':
    res_dict = {}
    for num in range(len(brain_nums)):
        file1 = f'predictions/{EXP_NAME1}/brain{num}.npy'
        file2 = f'predictions/{EXP_NAME2}/brain{num}.npy'
        a = np.load(file1, allow_pickle=True)
        b = np.load(file2, allow_pickle=True)
        c = []
        for i in range(len(a)):
            c.append(np.concatenate([a[i],b[i]]))
        res_dict[num] = c
        np.save(f'predictions/{EXP_NAME}/brain{num}.npy',c)

In [16]:
def result_to_metrics(res, th = 0.5):
    pred = res[1]
    true = res[2]
    mask = res[3]
    pred = [pred[i] for i in range(len(mask)) if mask[i]==0]
    true = [true[i] for i in range(len(mask)) if mask[i]==0]
    if np.sum(true) == 0:
        return -999, -999, -999, -999, -999, -999
    pred = np.array(pred)
    conf = confusion_matrix(true, pred > th)
    IoU = conf[1,1] / (conf[1, 1] + conf[1, 0] + conf[0,1])
    dice = 2 * conf[1,1] / (2 * conf[1, 1] + conf[1, 0] + conf[0, 1])
    d_in = np.dot(true, pred) / np.sum(true)
    d_out = np.dot(np.ones(len(true)) - np.array(true), pred) / (len(true) - np.sum(true))
    contrast = (d_in - d_out) /  (d_in + d_out)
    roc = roc_auc_score(true, pred)
    recall = conf[1,1] / (conf[1, 1] + conf[1, 0])
    return conf, IoU, dice, contrast, roc, recall


def res_dict_to_metrics(res_dict, th = 0.5):
    res_metrics = {}
    for brain in tqdm(res_dict):
        res_metrics[brain] = {}
        res = res_dict[brain]
        res_metrics[brain]['conf_all'],res_metrics[brain]['IoU_all'],res_metrics[brain]['dice_all'],res_metrics[brain]['contrast_all'],res_metrics[brain]['roc_all'],res_metrics[brain]['recall_all'] = result_to_metrics(res,th)
        len_1_part = len(res[0]) // REPEAT
        res_metrics[brain]['confs'] = []
        res_metrics[brain]['IoUs'] = []
        res_metrics[brain]['dices'] = []
        res_metrics[brain]['contrasts'] = []
        res_metrics[brain]['roc'] = []
        res_metrics[brain]['recall'] = []
        for repeat in range(REPEAT):
            res_part = [res[i][len_1_part * repeat:len_1_part * (repeat + 1)] for i in range(4)]
            conf, IoU, dice, contrast, roc, recall = result_to_metrics(res_part, th)
            res_metrics[brain]['confs'].append(conf)
            res_metrics[brain]['IoUs'].append(IoU)
            res_metrics[brain]['dices'].append(dice)
            res_metrics[brain]['contrasts'].append(contrast)
            res_metrics[brain]['roc'].append(roc)
            res_metrics[brain]['recall'].append(recall)
    return res_metrics

In [17]:
final = {}
for th in [0.5]:#np.linspace(0.001,0.999,35):
    res_metrics=res_dict_to_metrics(res_dict,th)
    final[th] = res_metrics


  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:00<00:07,  2.23it/s][A
 12%|█▏        | 2/17 [00:00<00:04,  3.55it/s][A
 18%|█▊        | 3/17 [00:00<00:04,  3.01it/s][A
 24%|██▎       | 4/17 [00:01<00:04,  2.74it/s][A
 29%|██▉       | 5/17 [00:01<00:04,  2.60it/s][A
 35%|███▌      | 6/17 [00:02<00:04,  2.52it/s][A
 41%|████      | 7/17 [00:02<00:04,  2.46it/s][A
 47%|████▋     | 8/17 [00:02<00:03,  2.71it/s][A
 53%|█████▎    | 9/17 [00:03<00:03,  2.66it/s][A
 59%|█████▉    | 10/17 [00:03<00:02,  2.63it/s][A
 65%|██████▍   | 11/17 [00:03<00:02,  2.81it/s][A
 71%|███████   | 12/17 [00:04<00:01,  2.75it/s][A
 76%|███████▋  | 13/17 [00:04<00:01,  2.71it/s][A
 82%|████████▏ | 14/17 [00:05<00:01,  2.67it/s][A
 88%|████████▊ | 15/17 [00:05<00:00,  2.63it/s][A
 94%|█████████▍| 16/17 [00:06<00:00,  2.61it/s][A
100%|██████████| 17/17 [00:06<00:00,  2.58it/s][A
[A

In [18]:
for th in [0.5]:#np.linspace(0.001,0.999,35):
    tmp = []
    for brain in final[th]:
        tmp.append(final[th][brain]['dice_all'])
    print(th, np.mean(tmp))

0.5 -176.2890909980588


In [28]:
print(EXP_NAME)
for th in [0.5]:
    for brain in final[th]:
        print(str(final[th][brain]['contrast_all']).replace('.',','))

2021-10-16_16bs_abscoords_air_cropsize64_epochs400_whole
-0,5453059781310424
-999
0,16212565185072944
0,3254029387266164
0,289756800487056
-0,07076440306089676
-0,20348429852306948
-999
0,35766872661995175
0,28154538878747326
-999
0,08520021962822351
-0,23943755223604346
0,20373747463776357
0,08962772003588299
0,27886874113518345
0,025790812399981825


In [23]:
final[th]

{22: {'conf_all': array([[32156, 21070],
         [    6,     0]]),
  'IoU_all': 0.0,
  'dice_all': 0.0,
  'contrast_all': -0.5453059781310424,
  'roc_all': 0.11893310286952491,
  'recall_all': 0.0,
  'confs': [array([[32156, 21070],
          [    6,     0]])],
  'IoUs': [0.0],
  'dices': [0.0],
  'contrasts': [-0.5453059781310424],
  'roc': [0.11893310286952491],
  'recall': [0.0]},
 30: {'conf_all': -999,
  'IoU_all': -999,
  'dice_all': -999,
  'contrast_all': -999,
  'roc_all': -999,
  'recall_all': -999,
  'confs': [-999],
  'IoUs': [-999],
  'dices': [-999],
  'contrasts': [-999],
  'roc': [-999],
  'recall': [-999]},
 49: {'conf_all': array([[38101, 13042],
         [   55,    44]]),
  'IoU_all': 0.0033482992161935924,
  'dice_all': 0.006674251042851726,
  'contrast_all': 0.16212565185072944,
  'roc_all': 0.661540714617382,
  'recall_all': 0.4444444444444444,
  'confs': [array([[38101, 13042],
          [   55,    44]])],
  'IoUs': [0.0033482992161935924],
  'dices': [0.0066742