In [1]:
import reg_mri
import os
from glob import glob
from utils import compute_mean_dice
import nibabel as nib
from scipy.spatial.distance import dice
import numpy as np
import itk
import SimpleITK as sitk
import scipy.ndimage
import scipy
import matplotlib.pyplot as plt
from transforms_dict import getRegistrationEvalInverseTransformForMRI, SaveTransformForMRI
from tqdm import tqdm
import monai
import subprocess
from monai.networks.blocks import Warp
import torch

In [None]:
datasets = ["Feminad", "GIN", "Painfact", "IRIS"]
atlas_name = "dataset3/Atlas/P56_Atlas_128_norm_id.nii.gz"
atlas_mask_name = "dataset3/Atlas/P56_Annotation_128_norm_id_mask.nii.gz"
atlas_mri = nib.load(atlas_name)
atlas_mask = nib.load(atlas_mask_name).get_fdata()
for dataset in datasets:
    masks_raw = sorted(glob(os.path.join('dataset3', dataset, 'Mask', "*_id.nii.gz")))    
    masks_aff = sorted(glob(os.path.join('dataset3', dataset, 'Mask', "*_affine.nii.gz")))    
    masks_def = sorted(glob(os.path.join('dataset3', dataset, 'Mask', "*_deformable.nii.gz")))
    name_list = ["raw", "aff", "def"]
    for i, mask_list in enumerate([masks_raw, masks_aff, masks_def]):
        x = 0
        for mask in mask_list:
            mask = nib.load(mask).get_fdata()
            x += 1 - dice(mask.ravel(), atlas_mask.ravel())
        x /= len(mask_list)
        print(dataset + " - " + name_list[i] + ": " + str(x))

Feminad - raw: 0.591666735297345
Feminad - aff: 0.932753127610555
Feminad - def: 0.9374410507527555
GIN - raw: 0.15000594749456136
GIN - aff: 0.9389306258713226
GIN - def: 0.9320385349140773
Painfact - raw: 0.5770233026539917
Painfact - aff: 0.9200568594427992
Painfact - def: 0.9398382252906364
IRIS - raw: 0.6905348615957821
IRIS - aff: 0.8702028336437295
IRIS - def: 0.8570262463228943

In [2]:
outfolders = ["baseline-sym-ddf",
              "baseline-sym-noise", 
              
              "baseline-ddf",
              "baseline-noise",
               
              "baseline",
              "baseline-sym",]
models = ["gin-baseline+sym-0.1+ddf-8.0_1.0-0.0-8.0.pth",
          "gin-baseline+sym-0.1+noise-8.0_1.0-0.0-0.001.pth",
          
          "gin-baseline+ddf-8.0_1.0-0.0-8.0.pth",
          "gin-baseline+noise-0.1_1.0-0.0-0.0.pth",
          
          "gin-baseline_1.0-0.0-0.0.pth",
          "gin-baseline+sym-0.1_1.0-0.0-0.0.pth",]
for i, model in enumerate(models):
    outdataset = 'GIN'
    outfolder = outfolders[i]
    mris = sorted(glob(os.path.join('dataset3', outdataset, 'MRI', "*_affine.nii.gz")))
    atlas_name = "dataset3/Atlas/P56_Atlas_128_norm_id.nii.gz"
    warp = Warp("nearest", "border")    
    ants_warp = sorted(glob(os.path.join("output", "Feminad", "ANTS", "DeformableWarp", "*.nii.gz")))
    affine = nib.load(ants_warp[0]).affine
    header = nib.load(ants_warp[0]).header 

    for i, mri in enumerate(mris):
        outname = "output/" + outdataset + "/" + outfolder + "/MRI_N4_Registration_Deformable/" + mri.split('/')[-1].split('.')[0]
        if "newmodel" in model:
            newmodel = True
        else:
            newmodel = False
            
        if "sym" in model:
            sym = True
            pred_images, ddfs = reg_mri.main(model, mri, outname, False, True, "local", newmodel=newmodel, sym=sym)
            pred_image = pred_images[0]
            predsym_image = pred_images[1]
        else:
            sym = False
            pred_image, ddfs = reg_mri.main(model, mri, outname, False, True, "local", newmodel=newmodel, sym=sym)

        moving_image = ddfs[0].cpu()

        deformable_ddf = ddfs[1].cpu()
        deformable_ddf = deformable_ddf.cpu().numpy()
        deformable_ddf = np.transpose(deformable_ddf, (2, 3, 4, 0, 1)) 
        deformable_ddf = deformable_ddf * [-1, -1, 1]
        deformable_image = nib.Nifti1Image(deformable_ddf, affine, header)
        outname_deformable = "output/" + outdataset + "/" + outfolder + "/DeformableWarp/" + mri.split('/')[-1].split('.')[0] + "_warp.nii.gz"
        nib.save(deformable_image, outname_deformable)
        
        if sym:
            deformable_sym_ddf = ddfs[2].cpu()
            deformable_sym_ddf = deformable_sym_ddf.cpu().numpy()
            deformable_sym_ddf = np.transpose(deformable_sym_ddf, (2, 3, 4, 0, 1)) 
            deformable_sym_ddf = deformable_sym_ddf * [-1, -1, 1]
            deformable_sym_image = nib.Nifti1Image(deformable_sym_ddf, affine, header)
            outname_sym_deformable = "output/" + outdataset + "/" + outfolder + "/DeformableWarp/" + mri.split('/')[-1].split('.')[0] + "_invwarp.nii.gz"
            nib.save(deformable_sym_image, outname_sym_deformable) 
            
        
        if outdataset == 'GIN':
            labels = sorted(glob(os.path.join('dataset3', outdataset, 'Labels', "*_affine.nii.gz")))
            label = nib.load(labels[i])
            affine = label.affine
            header = label.header
            label = torch.from_numpy(np.reshape(label.get_fdata(), (1,1,128,128,128))).float()
            pred_labels = warp(label, ddfs[1].cpu()).squeeze()
            out_pred = nib.Nifti1Image(pred_labels, affine, header)
            outname = "output/" + outdataset + "/" + outfolder + "/Labels/" + mri.split('/')[-1].split('.')[0] + "_dl_labels.nii.gz"
            nib.save(out_pred, outname)           


monai.networks.blocks.Warp: Using PyTorch native grid_sample.


KeyboardInterrupt



In [2]:
import json
allen_json = "dataset3/Atlas/1.json"

extract_name = "CTX"
with open(allen_json, "r") as read_file:
    data = json.load(read_file)['msg'][0]
    while len(data['children']) != 0:
        print(len(data['children']))
        print('xd')
        break
    print(x)

5
xd


NameError: name 'x' is not defined

In [3]:
atlas_labels_name = "dataset3/Atlas/P56_Annotation_128_norm_id.nii.gz"
atlas_labels = nib.load(atlas_labels_name).get_fdata()
deep_labels = sorted(glob(os.path.join('output', 'GIN', 'test', 'Labels', '*.nii.gz')))
ants_labels = sorted(glob(os.path.join('dataset3', 'GIN', 'Labels', "*_deformable.nii.gz")))
labels_mapping = "dataset3/Atlas/labels_mapping.csv"
for i, deep_label in enumerate(deep_labels):
    deep_label = nib.load(deep_label).get_fdata()
    ants_label = nib.load(ants_labels[i]).get_fdata()
    for i in range(-1,5000):
        test = atlas_labels == i
        print(str(i) + ': ' + str(test.sum()))
    break
    

In [4]:
def query_data_by_name(data, name):
    list = []
    for i in range(len(data)):
        if data['acronym'][i] == name:
            j = i 
            id = data['id'][j]
    id = "/" + str(id) + "/"
    for i in range(len(data)):
        if id in data['structure_id_path'][i]:
            list.append(data['graph_order'][i])
    return list    

In [5]:
import pandas as pd
import nibabel as nib
import numpy as np
atlas_labels_name = "dataset3/Atlas/P56_Annotation_128_norm_id.nii.gz"
atlas_labels = nib.load(atlas_labels_name)


allen_csv = "dataset3/Atlas/query.csv"
data = pd.read_csv(allen_csv)
liste = query_data_by_name(data, "HPF")

atlas_labels_data = np.isin(atlas_labels.get_fdata(),liste)

image = nib.Nifti1Image(atlas_labels_data, atlas_labels.affine, atlas_labels.header)
nib.save(image, "test.nii.gz")



In [6]:
mapping_csv = "dataset3/GIN/labels_mapping.csv"
data = pd.read_csv(mapping_csv)
allen_csv = "dataset3/Atlas/query.csv"
data_atlas = pd.read_csv(allen_csv)



image_zero = np.zeros(atlas_labels.shape)
image_zero2 = np.zeros(atlas_labels.shape)
print(data)
for i in range(len(data)):
    code = data['code'][i]
    print(code)
    value = data['map6nifti'][i]
    value2 = data['svbpnifti'][i]
    liste = query_data_by_name(data_atlas, code)
    if code == 'MOB':
        #liste.append(1106)
        liste.append(267)
        #liste.append(268)
        #liste.append(379)
        liste.append(390)
        liste.append(416)
        liste.append(559)
        #liste.append(416)
        liste.append(388)
        liste.append(405)
        #liste.append(1111)
        liste.remove(416)
        #liste.remove(379)
        liste.remove(559)
        #liste.append(826)           
        liste.remove(267)
        liste.append(379)
    if code == 'CB':
        liste.append(1188)
    #if code == 'STR':
    #    liste.remove(575)
    if code == 'CTX':
        print(liste)
        liste.append(268)
        liste.append(416) 
        liste.append(514)
        liste.append(512)
        liste.append(509)
        liste.append(516)
        liste.append(508)
        liste.append(503)
        liste.append(500)
        liste.append(496)
        liste.append(505)
        liste.append(495)
        liste.append(379)
        liste.append(559)
        liste.append(554)
        liste.append(523)
        liste.append(267)
        liste.remove(410)
        liste.remove(399)
        liste.remove(379)
        print(liste)
    if code == 'MBsen':
        liste.append(831)
        liste.append(832)
        liste.append(833)
        liste.append(834)
        #liste.append(806)
        liste.append(809)
        #liste.append(810)
        #liste.append(826)
        liste.remove(818)
        liste.remove(817)
        #liste.remove(815)
    if code == 'HPF':        
        liste.remove(514)
        liste.remove(512)
        liste.remove(509)
        liste.remove(516)
        liste.remove(508)
        liste.remove(503)
        liste.remove(500)
        liste.remove(496)
        liste.remove(505)
        liste.remove(495)
        liste.remove(554)
        liste.remove(523)
    if code == 'MY':
        liste.append(1180)
    if code == 'TH':        
        liste.remove(712)
        liste.remove(713)
        
        
        
    image_zero[np.where(np.isin(atlas_labels.get_fdata(), liste))] = value
    image_zero2[np.where(np.isin(atlas_labels.get_fdata(), liste))] = value2
    
#image = nib.Nifti1Image(image_zero, atlas_labels.affine, atlas_labels.header)
#nib.save(image, "atlas_gin_map6.nii.gz")
#image2 = nib.Nifti1Image(image_zero2, atlas_labels.affine, atlas_labels.header)
#nib.save(image2, "atlas_gin_svbp.nii.gz")

                    name     code     map6tif  map6nifti        svbptif  \
0                 cortex      CTX      yellow          5      darkgreen   
1  hippocampal formation      HPF   darkgreen          3            idk   
2         olfactory bulb      OLF  lightgreen          9     lightgreen   
3               striatum      STR    darkblue          4      lightblue   
4        globus pallidus  GPe/GPi        pink          6  darkgreenblue   
5              cerebelum       CB      salmon          8         yellow   
6               coliculi    MBsen         red          2           pink   
7               thalamus       TH   lightblue          7         salmon   
8           hypothalamus       HY      purple          1            red   
9              brainstem       BS      orange         10           none   

   svbpnifti           atlasnifti  Unnamed: 7  
0          2                    3         NaN  
1          3                  454         NaN  
2          1              379;

UnboundLocalError: local variable 'id' referenced before assignment