In [27]:
import numpy as np
import pandas as pd
import time
import torch

In [9]:
from functools import wraps
import time


def timeit(func):
    """Decorator that reports how long each call to ``func`` takes.

    The returned wrapper forwards every positional and keyword argument to
    ``func``, prints the elapsed time (measured with ``time.perf_counter``)
    to stdout, and hands back whatever ``func`` returned. If ``func`` raises,
    the exception propagates and no timing line is printed.
    """
    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        tic = time.perf_counter()
        result = func(*args, **kwargs)
        # Elapsed time is computed before printing so the print itself is not measured.
        total_time = time.perf_counter() - tic
        print(f'Function {func.__name__} Took {total_time:.4f} seconds')
        return result

    return timeit_wrapper


## Import some test data

In [3]:
# Directory that holds both subset files. Kept in a single constant so the
# skeleton and foldlabel arrays are guaranteed to be read from the same place
# and the location can be changed with one edit.
SUBSETS_DIR = "/neurospin/dico/data/deep_folding/current/datasets/UkBioBank/crops/2mm/CINGULATE/mask/subsets"

skels = np.load(f"{SUBSETS_DIR}/Rskeleton_most_folded_551.npy")
foldlabels = np.load(f"{SUBSETS_DIR}/Rlabel_most_folded_551.npy")

# The cells below combine the two arrays element-wise, so fail early if the
# files do not describe the same grid.
assert skels.shape == foldlabels.shape, (
    f"Shape mismatch: skeleton {skels.shape} vs foldlabel {foldlabels.shape}")

print(skels.shape)
print(foldlabels.shape)

(551, 17, 40, 38, 1)
(551, 17, 40, 38, 1)


In [4]:
# List the distinct values stored in the foldlabel array.
np.unique(foldlabels)

array([   0,    1,    2, ..., 7626, 7690, 7713], dtype=int16)

## Intersection

In [10]:
@timeit
def count_non_null(arr):
    """Return the number of non-zero entries in ``arr``.

    Uses ``np.count_nonzero``, which counts directly instead of first
    materialising a boolean ``arr != 0`` array, and which also accepts any
    array-like input (e.g. nested lists), not only ``ndarray``.

    Args:
        arr: array or array-like to inspect.

    Returns:
        int: how many entries of ``arr`` are different from zero.
    """
    return np.count_nonzero(arr)

In [11]:
# Number of non-zero entries in the whole foldlabel array; the decorator
# also prints how long the count took.
count_non_null(foldlabels)

Function count_non_null Took 0.0287 seconds


497595

In [12]:
@timeit
def intersection_skeleton_foldlabel(arr_foldlabel, arr_skel):
    """Check that skeleton and foldlabel cover the same entries.

    Implementation benchmarked here: the skeleton is multiplied by a 0/1
    integer mask of the non-zero foldlabel entries.

    Returns:
        The number of non-null elements in the intersection.

    Raises:
        ValueError: if the intersection, the skeleton and the foldlabel do
            not all have the same number of non-null elements.
    """
    foldlabel_mask = (arr_foldlabel != 0).astype(int)
    intersec = arr_skel * foldlabel_mask

    # Counted in this order so the timing lines printed by count_non_null
    # appear as: intersection, skeleton, foldlabel.
    count_intersec = count_non_null(intersec)
    count_skel = count_non_null(arr_skel)
    count_foldlabel = count_non_null(arr_foldlabel)

    if count_intersec == count_skel and count_foldlabel == count_skel:
        return count_intersec

    raise ValueError("Probably misaligned skeleton and foldlabel\n"
                     f"Intersection between skeleton and foldlabel "
                     f"has {count_intersec} non-null elements.\n"
                     f"Skeleton has {count_skel} non-null elements.\n"
                     f"Foldlabel has {count_foldlabel} non-null elements.")

# Time 1000 full alignment checks. perf_counter is used, like in the timeit
# decorator, because it is a monotonic high-resolution clock meant for
# measuring intervals, whereas time.time() is wall-clock time.
start = time.perf_counter()

for i in range(1000):
    print(i, intersection_skeleton_foldlabel(foldlabels, skels))

end = time.perf_counter()

# Total elapsed seconds for the loop (this includes the print calls).
end - start

Function count_non_null Took 0.0325 seconds
Function count_non_null Took 0.0239 seconds
Function count_non_null Took 0.0231 seconds
Function intersection_skeleton_foldlabel Took 0.1993 seconds
0 497595
Function count_non_null Took 0.0318 seconds
Function count_non_null Took 0.0219 seconds
Function count_non_null Took 0.0213 seconds
Function intersection_skeleton_foldlabel Took 0.1908 seconds
1 497595
Function count_non_null Took 0.0300 seconds
Function count_non_null Took 0.0211 seconds
Function count_non_null Took 0.0212 seconds
Function intersection_skeleton_foldlabel Took 0.1794 seconds
2 497595
Function count_non_null Took 0.0293 seconds
Function count_non_null Took 0.0214 seconds
Function count_non_null Took 0.0208 seconds
Function intersection_skeleton_foldlabel Took 0.1789 seconds
3 497595
Function count_non_null Took 0.0295 seconds
Function count_non_null Took 0.0210 seconds
Function count_non_null Took 0.0213 seconds
Function intersection_skeleton_foldlabel Took 0.1762 seconds

172.2940390110016

In [13]:
@timeit
def intersection_skeleton_foldlabel(arr_foldlabel, arr_skel):
    """Check that skeleton and foldlabel cover the same entries.

    Implementation benchmarked here: the skeleton is copied and the copy is
    zeroed wherever the foldlabel is zero.

    Returns:
        The number of non-null elements in the intersection.

    Raises:
        ValueError: if the intersection, the skeleton and the foldlabel do
            not all have the same number of non-null elements.
    """
    intersec = np.copy(arr_skel)
    foldlabel_is_null = arr_foldlabel == 0
    intersec[foldlabel_is_null] = 0

    # Counted in this order so the timing lines printed by count_non_null
    # appear as: intersection, skeleton, foldlabel.
    count_intersec = count_non_null(intersec)
    count_skel = count_non_null(arr_skel)
    count_foldlabel = count_non_null(arr_foldlabel)

    if count_intersec == count_skel and count_foldlabel == count_skel:
        return count_intersec

    raise ValueError("Probably misaligned skeleton and foldlabel\n"
                     f"Intersection between skeleton and foldlabel "
                     f"has {count_intersec} non-null elements.\n"
                     f"Skeleton has {count_skel} non-null elements.\n"
                     f"Foldlabel has {count_foldlabel} non-null elements.")

# Time 1000 full alignment checks. perf_counter is used, like in the timeit
# decorator, because it is a monotonic high-resolution clock meant for
# measuring intervals, whereas time.time() is wall-clock time.
start = time.perf_counter()

for i in range(1000):
    print(i, intersection_skeleton_foldlabel(foldlabels, skels))

end = time.perf_counter()

# Total elapsed seconds for the loop (this includes the print calls).
end - start

Function count_non_null Took 0.0205 seconds
Function count_non_null Took 0.0194 seconds
Function count_non_null Took 0.0199 seconds
Function intersection_skeleton_foldlabel Took 0.0996 seconds
0 497595
Function count_non_null Took 0.0204 seconds
Function count_non_null Took 0.0198 seconds
Function count_non_null Took 0.0208 seconds
Function intersection_skeleton_foldlabel Took 0.0882 seconds
1 497595
Function count_non_null Took 0.0206 seconds
Function count_non_null Took 0.0208 seconds
Function count_non_null Took 0.0213 seconds
Function intersection_skeleton_foldlabel Took 0.0886 seconds
2 497595
Function count_non_null Took 0.0201 seconds
Function count_non_null Took 0.0196 seconds
Function count_non_null Took 0.0195 seconds
Function intersection_skeleton_foldlabel Took 0.0864 seconds
3 497595
Function count_non_null Took 0.0200 seconds
Function count_non_null Took 0.0198 seconds
Function count_non_null Took 0.0196 seconds
Function intersection_skeleton_foldlabel Took 0.0853 seconds

88.08584809303284

In [14]:
@timeit
def intersection_skeleton_foldlabel(arr_foldlabel, arr_skel):
    """Check that skeleton and foldlabel cover the same entries.

    Implementation benchmarked here: a zero array of the skeleton's shape is
    created and set to 1 wherever both foldlabel and skeleton are non-zero.

    Returns:
        The number of non-null elements in the intersection.

    Raises:
        ValueError: if the intersection, the skeleton and the foldlabel do
            not all have the same number of non-null elements.
    """
    intersec = np.zeros(arr_skel.shape)
    non_null_in_both = (arr_foldlabel != 0) & (arr_skel != 0)
    intersec[non_null_in_both] = 1

    # Counted in this order so the timing lines printed by count_non_null
    # appear as: intersection, skeleton, foldlabel.
    count_intersec = count_non_null(intersec)
    count_skel = count_non_null(arr_skel)
    count_foldlabel = count_non_null(arr_foldlabel)

    if count_intersec == count_skel and count_foldlabel == count_skel:
        return count_intersec

    raise ValueError("Probably misaligned skeleton and foldlabel\n"
                     f"Intersection between skeleton and foldlabel "
                     f"has {count_intersec} non-null elements.\n"
                     f"Skeleton has {count_skel} non-null elements.\n"
                     f"Foldlabel has {count_foldlabel} non-null elements.")

# Time 1000 full alignment checks. perf_counter is used, like in the timeit
# decorator, because it is a monotonic high-resolution clock meant for
# measuring intervals, whereas time.time() is wall-clock time.
start = time.perf_counter()

for i in range(1000):
    print(i, intersection_skeleton_foldlabel(foldlabels, skels))

end = time.perf_counter()

# Total elapsed seconds for the loop (this includes the print calls).
end - start

Function count_non_null Took 0.0425 seconds
Function count_non_null Took 0.0234 seconds
Function count_non_null Took 0.0234 seconds
Function intersection_skeleton_foldlabel Took 0.1604 seconds
0 497595
Function count_non_null Took 0.0360 seconds
Function count_non_null Took 0.0191 seconds
Function count_non_null Took 0.0196 seconds
Function intersection_skeleton_foldlabel Took 0.1396 seconds
1 497595
Function count_non_null Took 0.0373 seconds
Function count_non_null Took 0.0206 seconds
Function count_non_null Took 0.0200 seconds
Function intersection_skeleton_foldlabel Took 0.1401 seconds
2 497595
Function count_non_null Took 0.0357 seconds
Function count_non_null Took 0.0192 seconds
Function count_non_null Took 0.0193 seconds
Function intersection_skeleton_foldlabel Took 0.1338 seconds
3 497595
Function count_non_null Took 0.0373 seconds
Function count_non_null Took 0.0207 seconds
Function count_non_null Took 0.0205 seconds
Function intersection_skeleton_foldlabel Took 0.1407 seconds

138.45057916641235

## Optimize remove_branch

In [2]:
# Seed NumPy's global generator so the synthetic volume (and the random
# draws made from the same generator in later cells) are identical on every
# Restart & Run All.
np.random.seed(0)

# Synthetic integer array with values drawn uniformly from [0, 10000).
arr = np.random.randint(0, 10000, size=(50, 25, 40))

# Number of distinct values actually present in the array.
n_branches = len(np.unique(arr))
n_branches

9935

In [3]:
# Benchmark: remove labels one by one by multiplying with a 0/1 mask.
deltas = []

for j in range(551):
    new_arr = np.copy(arr)

    # Draw, without replacement, 40% of n_branches distinct values to remove.
    idx_to_remove = np.random.choice(np.unique(new_arr), size=int(0.4*n_branches), replace=False)

    # Only the removal loop is timed. perf_counter is used, like in the
    # timeit decorator, because it is a monotonic high-resolution clock.
    start = time.perf_counter()

    for i in idx_to_remove:
        # Keep entries that are non-zero and different from the value being removed.
        mask = ((new_arr != 0) & (new_arr != i))
        mask = mask.astype(int)
        new_arr = new_arr * mask

    end = time.perf_counter()

    deltas.append(end - start)

    # Number of distinct values left after the removal.
    print(j, len(np.unique(new_arr)))

# Total and mean removal time over all iterations, in seconds.
print(np.sum(deltas), np.mean(deltas))

0 5961
1 5961
2 5962
3 5962
4 5961
5 5961
6 5961
7 5961
8 5962
9 5962
10 5962
11 5962
12 5962
13 5962
14 5962
15 5962
16 5961
17 5962
18 5961
19 5962
20 5961
21 5961
22 5962
23 5962
24 5961
25 5961
26 5962
27 5962
28 5961
29 5961
30 5962
31 5961
32 5961
33 5961
34 5962
35 5962
36 5961
37 5962
38 5962
39 5961
40 5961
41 5961
42 5961
43 5961
44 5961
45 5961
46 5962
47 5962
48 5961
49 5962
50 5962
51 5962
52 5961
53 5961
54 5961
55 5961
56 5962
57 5961
58 5961
59 5961
60 5962
61 5961
62 5961
63 5962
64 5961
65 5961
66 5961
67 5961
68 5961
69 5961
70 5961
71 5962
72 5961
73 5961
74 5962
75 5961
76 5962
77 5961
78 5961
79 5961
80 5961
81 5961
82 5962
83 5961
84 5962
85 5961
86 5961
87 5962
88 5961
89 5962
90 5961
91 5961
92 5962
93 5961
94 5961
95 5962
96 5961
97 5961
98 5962
99 5961
100 5962
101 5962
102 5962
103 5961
104 5961
105 5961
106 5961
107 5961
108 5962
109 5961
110 5961
111 5962
112 5961
113 5962
114 5961
115 5962
116 5962
117 5961
118 5961
119 5961
120 5961
121 5961
122 5962
123

In [10]:
# Benchmark: remove labels one by one with in-place boolean-mask assignment.
deltas = []

for j in range(551):
    new_arr = np.copy(arr)

    # Draw, without replacement, 40% of n_branches distinct values to remove.
    idx_to_remove = np.random.choice(np.unique(new_arr), size=int(0.4*n_branches), replace=False)

    # Only the removal loop is timed. perf_counter is used, like in the
    # timeit decorator, because it is a monotonic high-resolution clock.
    start = time.perf_counter()

    for i in idx_to_remove:
        # The mask is computed on the untouched `arr`; only the copy is modified.
        new_arr[arr == i] = 0

    end = time.perf_counter()

    deltas.append(end - start)

    # Number of distinct values left after the removal.
    print(j, len(np.unique(new_arr)))


# Total and mean removal time over all iterations, in seconds.
print(np.sum(deltas), np.mean(deltas))

0 5961
1 5961
2 5961
3 5961
4 5962
5 5961
6 5962
7 5962
8 5962
9 5961
10 5962
11 5961
12 5962
13 5961
14 5961
15 5962
16 5961
17 5961
18 5961
19 5961
20 5961
21 5962
22 5961
23 5962
24 5962
25 5962
26 5962
27 5962
28 5961
29 5962
30 5962
31 5961
32 5962
33 5962
34 5961
35 5961
36 5961
37 5962
38 5961
39 5961
40 5962
41 5961
42 5961
43 5962
44 5961
45 5962
46 5961
47 5961
48 5962
49 5961
50 5962
51 5961
52 5961
53 5962
54 5962
55 5961
56 5961
57 5961
58 5961
59 5962
60 5961
61 5962
62 5962
63 5962
64 5961
65 5962
66 5961
67 5961
68 5961
69 5961
70 5961
71 5961
72 5961
73 5961
74 5962
75 5961
76 5961
77 5962
78 5961
79 5962
80 5962
81 5962
82 5962
83 5961
84 5962
85 5962
86 5961
87 5961
88 5961
89 5962
90 5961
91 5962
92 5962
93 5961
94 5962
95 5962
96 5962
97 5961
98 5961
99 5961
100 5962
101 5961
102 5961
103 5962
104 5962
105 5961
106 5962
107 5962
108 5961
109 5961
110 5962
111 5961
112 5961
113 5961
114 5961
115 5961
116 5962
117 5961
118 5962
119 5962
120 5961
121 5961
122 5962
123

In [41]:
# Work on a copy so the original foldlabels stay available for comparison.
copy_folds = np.array(foldlabels, copy=True)

print(np.unique(foldlabels))

# Zero every label that is at most 4 or strictly greater than 7000.
copy_folds[np.logical_or(copy_folds > 7000, copy_folds <= 4)] = 0

# Distinct label counts before and after the filtering.
print(np.unique(foldlabels).size, np.unique(copy_folds).size)

[   0    1    2 ... 7626 7690 7713]
1951 1516


## Rotate

In [16]:
# Array of shape (100, 15, 10) in which every entry equals its own flat
# (row-major) index, which makes slices easy to recognise when printed.
arr = np.arange(15000).reshape(100, 15, 10)

arr.shape

(100, 15, 10)

In [19]:
# Print the same slice written two ways: with an Ellipsis and with explicit
# full slices on the two leading axes, to compare the results visually.
print(arr[..., 0])

print(arr[:,:,0])

[[    0    10    20 ...   120   130   140]
 [  150   160   170 ...   270   280   290]
 [  300   310   320 ...   420   430   440]
 ...
 [14550 14560 14570 ... 14670 14680 14690]
 [14700 14710 14720 ... 14820 14830 14840]
 [14850 14860 14870 ... 14970 14980 14990]]
[[    0    10    20 ...   120   130   140]
 [  150   160   170 ...   270   280   290]
 [  300   310   320 ...   420   430   440]
 ...
 [14550 14560 14570 ... 14670 14680 14690]
 [14700 14710 14720 ... 14820 14830 14840]
 [14850 14860 14870 ... 14970 14980 14990]]


In [25]:
from contrastive.augmentations import RotateTensor, BinarizeTensor

In [36]:
# Instantiate the two augmentations imported from contrastive.augmentations.
bin_tens = BinarizeTensor()
rot_tens = RotateTensor(1)  # NOTE(review): meaning of the argument 1 is not visible here — presumably a maximum angle; confirm

# Wrap the skeleton array in a torch tensor and list its distinct values.
tens_skels = torch.tensor(skels)
print(np.unique(tens_skels))

# Rotate the raw skeletons and list the distinct values of the result.
rot_skels = rot_tens(tens_skels)
print(np.unique(rot_skels))

# Binarize the raw skeletons and list the distinct values of the result.
bin_skels = bin_tens(tens_skels)
print(np.unique(bin_skels))

# Rotate the binarized skeletons and list the distinct values of the result.
bin_rot_skels = rot_tens(bin_skels)
print(np.unique(bin_rot_skels))

[  0  30  35  60 100 120]
Function rotate Took 0.0798 seconds
Function rotate Took 0.0581 seconds
Function rotate Took 0.0611 seconds
Function rotate Took 0.0602 seconds
Function rotate Took 0.0597 seconds
Function rotate Took 0.0805 seconds
Function rotate Took 0.0613 seconds
Function rotate Took 0.0586 seconds
Function rotate Took 0.0608 seconds
Function rotate Took 0.0609 seconds
Function rotate Took 0.0732 seconds
Function rotate Took 0.0756 seconds
Function rotate Took 0.0779 seconds
Function rotate Took 0.0786 seconds
Function rotate Took 0.0802 seconds
Function __call__ Took 1.1100 seconds
[  0  30  35  60 100]
[0 1]
Function rotate Took 0.0754 seconds
Function rotate Took 0.0552 seconds
Function rotate Took 0.0704 seconds
Function rotate Took 0.0543 seconds
Function rotate Took 0.0696 seconds
Function rotate Took 0.0704 seconds
Function __call__ Took 0.4392 seconds
[0 1]


In [42]:
from scipy.ndimage import rotate

# Index of the skeleton to rotate. The original cell assigned `i` but then
# indexed with a hard-coded 51; the index variable is now actually used.
i = 51
skel = skels[i]

angle = 6      # rotation angle passed to scipy.ndimage.rotate
axes = (0,1)   # the two axes defining the plane of rotation
const = 0      # fill value for points outside the input (mode='constant')

print(skel.shape)
print(np.unique(skel))

# NOTE(review): no `order` is passed, so scipy's default interpolation order
# applies — confirm this is what is wanted for label-valued data.
rot_skel = rotate(skel,
                  angle=angle,
                  axes=axes,
                  reshape=False,
                  mode='constant',
                  cval=const)

# Inspect the values of the array rotated in THIS cell. The original looked
# at `rot_skels`, a different variable created in an earlier cell.
np.unique(rot_skel)

(551, 17, 40, 38, 1)
[  0  30  35  60 100 120]
Function rotate Took 1.5100 seconds


array([-28, -27, -26, -25, -24, -23, -22, -21, -20, -19, -18, -17, -16,
       -15, -14, -13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,
        -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,
        11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
        24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,
        37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,
        50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,
        63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,
        76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,
        89,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101,
       102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114,
       115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
       128, 129, 130, 131, 132, 133, 134, 135, 136, 138, 139, 140, 141],
      dtype=int16)