In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from torchvision import transforms, models
import torch.optim as optim
from torch.optim import lr_scheduler
import os
import numpy as np
import time
import copy
import pandas as pd
import math
import matplotlib.pyplot as plt
import pickle
import nibabel as nib
import random
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import warnings
warnings.filterwarnings("ignore")

from resunet import ResUNet

In [2]:
torch.__version__

'1.11.0+cu102'

In [3]:
class ImageDatasetBrats(torch.utils.data.Dataset):
    def __init__(self, transforms):
        self.root = 'burdenko_numpy/'
        unnecessary_files = {'.DS_Store', '.ipynb_checkpoints'}
        self.df = pd.read_csv(f'burdenko_numpy/burdenko_slices.csv')
        self.group1 = {'1019_18','1034_18_4','1036_18','1043_18_4','1056_18_4','1072_19','1112_19_4','1159_18_4',
                  '1164_18','1170_18_4','1184_18','1185_18_4','1214_18','1257_18','1267_18_4','1275_19_4', '1333_18',
                  '1357_19_4','1484_18_4','1541_18_4','1546_18','1733_18','1734_18','1781_18','1795_18_', '255_18',
                  '608_18_4','664_18_4','672_18_4','705_18_4','746_19_4','788_18','826_18_4','856_19_4', '923_18',
                  '925_18_4','946_18','979_18_4'}

        self.group2 = {'1028_18_4','1096_18','1216_18','1254_18','1255_18','1258_18','1302_18_4','1326_18','1354_18_4',
                  '1360_18','1362_18_4','1421_18','1463_18_4','1470_18_4','1501_18_4','1515_18_4','1539_18',
                  '1566_18','1573_18_4','1635_18','1646_18','1685_18_4','1702_18','1743_18_4','1746_18_4',
                  '1764_18_4','1769_18','1770_18_4','322_18_4','349_18_4','351_18','541_18','558_18_4','573_18_4',
                  '575_18_4','593_18','644_19_4','660_18_4','745_18_4','770_18','875_18_4','971_18_4','990_18_4',
                  'Patient_1000314','Patient_1000815','Patient_1001316','Patient_102117','Patient_103717',
                  'Patient_104514','Patient_105215','Patient_107017','Patient_109017','Patient_109414',
                  'Patient_110014','Patient_110816','Patient_111016','Patient_120115','Patient_12115',
                  'Patient_12214','Patient_122315','Patient_123816','Patient_12417','Patient_127916',
                  'Patient_129316','Patient_129415','Patient_129816','Patient_130514','Patient_131416',
                  'Patient_132216','Patient_133916','Patient_135915','Patient_136415','Patient_136715',
                  'Patient_136915','Patient_137315','Patient_138316','Patient_138516','Patient_140316',
                  'Patient_146716','Patient_15215','Patient_15817','Patient_158716','Patient_161316',
                  'Patient_1815','Patient_20717','Patient_22117','Patient_24117','Patient_24717', 'Patient_24815',
                  'Patient_28514','Patient_2914','Patient_33217','Patient_43316', 'Patient_43515','Patient_45217',
                  'Patient_48417','Patient_48517','Patient_49617', 'Patient_5117','Patient_51815','Patient_52315',
                  'Patient_54317','Patient_56717', 'Patient_59315','Patient_59817','Patient_61715','Patient_61916',
                  'Patient_62315', 'Patient_62817','Patient_65516','Patient_66615','Patient_69515','Patient_70614',
                  'Patient_716','Patient_72715','Patient_74417','Patient_75116','Patient_76516', 'Patient_8017',
                  'Patient_83217','Patient_83714','Patient_84116','Patient_87114', 'Patient_88817','Patient_88917',
                  'Patient_89117','Patient_90517','Patient_90616', 'Patient_92114','Patient_9315','Patient_95717',
                  'Patient_98814','Patient_98817','Patient_99715'}

        self.group3 = {'1029_18_4','1744_18','1765_18_4','1788_18_4','423_18','607_18','668_18_4','688_18'}
        self.transforms = transforms
        
    def __getitem__(self, idx):
        folder, slice_ = self.df.iloc[idx][['patient', 'slice']]
        if (folder in self.group1) or (folder in self.group3):
            image = np.load(f'{self.root}{folder}/flair.npz')['arr_0'][:, slice_, :]
        if folder in self.group2:
            image = np.load(f'{self.root}{folder}/flair.npz')['arr_0'][:, :, slice_]
        image = image / np.max(image) * 255
        
        if (folder in self.group1) or (folder in self.group3):
            mask = np.load(f'{self.root}{folder}/mask.npz')['arr_0'][:, slice_, :]
        if folder in self.group2:
            mask = np.load(f'{self.root}{folder}/mask.npz')['arr_0'][:, :, slice_]
        mask[mask > 1] = 1
        transformed = self.transforms(image=np.array(image, dtype = np.uint8),
                                      mask=np.array(mask, dtype = np.uint8))
        image = transformed["image"].float()
        mask = transformed["mask"].float().unsqueeze(0)
        return image, mask


    def __len__(self):
        return len(self.df)

In [4]:
data_transforms = {
    'train': A.Compose(
        [
        A.RandomResizedCrop(256, 256, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=0.3),
        A.Resize(256, 256),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=0, std=1),
        ToTensorV2(),
        ]
    ),
    'val': A.Compose(
        [
        A.Resize(256, 256),
        A.Normalize(mean=0, std=1),
        ToTensorV2(),
        ]
    )
}

In [5]:
dataset_train = ImageDatasetBrats(data_transforms['train'])
dataset_test = ImageDatasetBrats(data_transforms['val'])

torch.manual_seed(123) #для воспроизводимости
indices = torch.randperm(180).tolist()
t = int(180*0.7)

df_for_training = pd.read_csv("burdenko_numpy/burdenko_slices.csv")

train_indices = df_for_training[df_for_training.patient_index.isin(indices[:t])]['index'].tolist()
test_indices = df_for_training[df_for_training.patient_index.isin(indices[t:])]['index'].tolist()

dataset_train = torch.utils.data.Subset(dataset_train, train_indices)
dataset_test = torch.utils.data.Subset(dataset_test, test_indices)

dataloaders = {'train': torch.utils.data.DataLoader(dataset_train, batch_size=16, shuffle=True, num_workers=4),
               'test': torch.utils.data.DataLoader(dataset_test, batch_size=16, shuffle=False, num_workers=4)}

dataset_sizes = {'train': len(dataset_train), 'test': len(dataset_test)}

In [6]:
!nvidia-smi

Sun May 22 01:39:17 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.54       Driver Version: 510.54       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:06:00.0 Off |                  N/A |
|  0%   56C    P2    70W / 250W |   1125MiB /  6144MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [7]:
device = torch.device("cuda:0")

In [8]:
class IGS(nn.Module):
    def __init__(self, w):
        super().__init__()
        
        self.w = torch.tensor(w, dtype=torch.float, requires_grad=True, device=torch.device('cuda:0'))
        self.unet = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', 
                                   in_channels=1, out_channels=1, init_features=32, pretrained=False)
        self.unet.load_state_dict(torch.load("model_weights/unet_burd_orig.pth", map_location=torch.device('cpu')))

    def forward(self, x):
        x_fft = torch.fft.fftshift(torch.fft.fft2(x))
        w_x = x_fft * self.w
        x_ifft = torch.fft.ifft2(w_x)
        
        out = self.unet(x_ifft.real)
        
        return out


In [9]:
def dice_coef_metric(pred, label):
    intersection = 2.0 * (pred * label).sum()
    union = pred.sum() + label.sum()
    if pred.sum() == 0 and label.sum() == 0:
        return 1.
    return intersection / union

def dice_coef_loss(pred, label):
    smooth = 1.0
    intersection = 2.0 * (pred * label).sum() + smooth
    union = pred.sum() + label.sum() + smooth
    return 1 - (intersection / union)
    
def train_loop(model, loader, loss_func):
    model.eval()
    
    loss_mean = []
    for image, mask in tqdm(loader):
        image = image.to(device)
        mask = mask.to(device)
        outputs = model(image)
        loss = loss_func(outputs, mask)
        loss_mean.append(loss.item())
        loss.backward()
    
    gradient_vec = model.w.grad.detach().cpu().numpy()
    return gradient_vec

w_orig

w_full

w_2

w_4

w_8

In [10]:
w = np.zeros(256)
w[128]=1
model = IGS(w.copy())
model.to(device)
model.eval()
w_pateeirk = []
for i in range(int(128)):
    gradient = train_loop(model, dataloaders['test'], dice_coef_loss)
    valid_idx = np.where(model.w.detach().cpu().numpy() < 1)[0]
    new_w_index = valid_idx[gradient[valid_idx].argmin()]
#     print(new_w_index)
    w_pateeirk.append(new_w_index)
    with torch.no_grad():
        model.w[new_w_index] = 1
    model.w.grad.zero_()
print(w_pateeirk)
w_orig = model.w.detach().cpu().numpy()

Using cache found in /home/i_govorova/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [06:40<00:00,  1.71s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [06:08<00:00,  1.58s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:56<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:56<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:56<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|█████████████████████████████████████████

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:56<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:55<00:00,  1.52s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 234/234 [05:56<00:00,  1.52s/it]
100%|█████████████████████████████████████████

[131, 132, 125, 138, 142, 108, 136, 115, 250, 235, 27, 233, 35, 101, 37, 0, 5, 158, 170, 84, 82, 221, 11, 209, 80, 89, 83, 171, 81, 178, 87, 110, 173, 218, 175, 177, 172, 76, 174, 86, 118, 78, 176, 96, 180, 74, 72, 182, 77, 70, 67, 79, 69, 184, 64, 75, 179, 194, 189, 192, 186, 195, 190, 187, 114, 191, 71, 68, 146, 85, 169, 73, 181, 63, 183, 61, 66, 188, 65, 185, 193, 204, 62, 202, 240, 54, 206, 33, 196, 91, 25, 50, 44, 197, 167, 220, 52, 228, 255, 1, 34, 245, 10, 88, 106, 157, 163, 36, 203, 219, 40, 222, 12, 241, 244, 26, 53, 156, 38, 201, 6, 15, 13, 208, 249, 226, 162, 216]





In [None]:
w_full_gradient = [5.78555744e-04, 6.48571659e-05, 9.68116183e-06, 1.54330792e-05, 6.30917202e-05,
                   2.39825968e-05, 1.37329125e-05, 1.71572774e-05, -1.33385338e-05, 2.05224624e-05,
                   3.40017596e-06, -5.42851149e-06, -3.45530589e-06, -7.51995003e-06, 2.28772533e-05,
                   -1.07714895e-05, 2.07486137e-05, 1.10780184e-05, -7.77878654e-07, 1.43234647e-05,
                   -8.38386768e-06, -1.66767095e-05, 1.12549014e-05, 1.72423142e-06, 2.10924213e-06,
                   -7.72493240e-06, 9.83910104e-06, -2.22016342e-06, 1.49013513e-05, -1.51454924e-05,
                   -5.31668866e-06,  1.22643305e-05, -1.15287594e-06, 6.55047324e-06, -1.83437203e-06,
                   1.65664533e-05, -5.36005291e-06, 1.13516871e-05, -2.71389752e-07, -1.03473430e-05,
                   1.36821209e-05, -1.76467147e-05, -2.21140408e-06, 1.02762006e-05, -1.02335343e-05,
                   1.40079715e-06, -2.46206764e-05, 4.51499227e-06, -8.49021944e-06, 2.74704253e-05,
                   4.29363899e-05, 8.71689645e-06, 1.58022124e-06, -4.49129084e-06, 3.24774919e-05,
                   -1.23301143e-05, 4.00153367e-05, -1.97823920e-05, 3.06994007e-05, 6.99820430e-06, 
                   1.67649450e-05, 2.09926275e-05, 1.02433505e-05, -1.42822355e-05, 3.32538984e-05,
                   -1.63213372e-05, 4.42776873e-05, 9.70571136e-07, 1.21265453e-04, -5.01360119e-05,
                   6.16586913e-05, 9.85989573e-06, -1.30319186e-05, 5.46723959e-06, -3.98456177e-06,
                   2.81401226e-05, 2.96296712e-05,  4.69327460e-05, -4.75627858e-05,  7.46010119e-05,
 -8.99409133e-06,  8.54544560e-05, -1.30523376e-05,  1.07132801e-04,
  1.50190090e-05,  3.41933010e-05, -5.58879547e-05,  3.18720959e-05,
  6.21863728e-05,  1.56419221e-04,  3.19039755e-05, -1.07809583e-05,
  4.58005161e-05,  2.94381665e-04, -2.83754925e-04,  4.90326027e-04,
 -1.75023888e-04,  5.15452761e-04, -2.92439567e-04,  4.09983913e-04,
 -2.17845576e-04,  5.86512790e-04, -4.49868530e-05,  1.27539504e-03,
 -3.14165227e-05,  1.38958485e-03,  1.94723485e-04,  6.94424729e-04,
  5.71255106e-04, -4.48252220e-04,  2.38459418e-03, -1.66429571e-04,
  1.09178922e-03,  3.23429878e-04,  1.13217882e-03, -1.86037229e-04,
  2.04215711e-03, -1.72455388e-03,  1.59298291e-03,  1.27258524e-03,
  4.28129314e-03, -1.64921337e-03,  2.21423199e-03,  3.99506977e-03,
 -3.38522252e-03, -1.41066201e-02,  9.11409855e-02,  1.67958438e-01,
 -6.88430786e-01,  1.67958423e-01,  9.11409929e-02, -1.41066248e-02,
 -3.38522252e-03,  3.99507070e-03,  2.21423223e-03, -1.64921232e-03,
  4.28129267e-03,  1.27258652e-03,  1.59298256e-03, -1.72455329e-03,
  2.04215664e-03, -1.86037083e-04,  1.13217894e-03,  3.23439366e-04,
  1.09178899e-03, -1.66431157e-04,  2.38459394e-03, -4.48252686e-04,
  5.71255572e-04,  6.94424787e-04,  1.94723820e-04,  1.38958462e-03,
 -3.14164572e-05,  1.27539481e-03, -4.49866639e-05,  5.86512673e-04,
 -2.17845460e-04,  4.09983651e-04, -2.92437762e-04,  5.15463122e-04,
 -1.75024048e-04,  4.90312523e-04, -2.83758534e-04,  2.94381724e-04,
  4.58009381e-05, -1.07810665e-05,  3.19038227e-05,  1.56419177e-04,
  6.21864019e-05,  3.18720522e-05, -5.58877109e-05,  3.41935483e-05,
  1.50186816e-05,  1.07133106e-04, -1.30517556e-05,  8.54600075e-05,
 -8.99395036e-06,  7.46165315e-05, -4.75601046e-05,  4.69324914e-05,
  2.96294766e-05,  2.81404300e-05, -3.98452994e-06,  5.46715000e-06,
 -1.30317494e-05,  9.85992028e-06,  6.16586622e-05, -5.01358663e-05,
  1.21265584e-04,  9.70392193e-07,  4.42778219e-05, -1.63170225e-05,
  3.32536511e-05, -1.42843310e-05,  1.02434669e-05,  2.09926147e-05,
  1.67650287e-05,  6.99824795e-06,  3.06994552e-05, -1.97824102e-05,
  4.00154677e-05, -1.23300770e-05,  3.24774956e-05, -4.49109211e-06,
  1.58018088e-06,  8.71667271e-06,  4.29370739e-05,  2.74775357e-05,
 -8.49016396e-06,  4.51907181e-06, -2.46202035e-05,  1.40068914e-06,
 -1.02336389e-05,  1.02763797e-05, -2.21147866e-06, -1.76467529e-05,
  1.36821918e-05, -1.03473385e-05, -2.71465979e-07,  1.13518727e-05,
 -5.35992285e-06,  1.65662732e-05, -1.83373720e-06,  6.55659096e-06,
 -1.15283262e-06,  1.22621495e-05, -5.31726437e-06, -1.51454469e-05,
  1.49013358e-05, -2.22005997e-06,  9.83905284e-06, -7.72489693e-06,
  2.10923599e-06,  1.72428258e-06,  1.12548341e-05, -1.66767131e-05,
 -8.38370215e-06,  1.43234620e-05, -7.78499611e-07,  1.10773417e-05,
  2.07483845e-05, -1.07739224e-05,  2.28774970e-05, -7.52008555e-06,
 -3.45537546e-06, -5.42860926e-06,  3.40014935e-06,  2.05225260e-05,
 -1.33387666e-05,  1.71574029e-05,  1.37330389e-05,  2.39825058e-05,
  6.30917712e-05,  1.54326335e-05,  9.68124550e-06,  6.48564601e-05]

In [None]:
np.array(w_full_gradient).argsort()[:128]

In [None]:
w_orig_indexes = [128,131,124,125,118,142,108,136,115,250,235,27,233,221,101,6,37,0,158,65,52,170,216,144,68,186,
                  62,64,67,61,192,70,161,189,195,191,63,89,102,194,206,202,72,69,66,91,188,71,187,219,208,75,77,
                  73,79,81,78,80,83,76,197,226,82,74,178,105,190,193,33,176,93,179,181,184,138,177,172,169,171,
                  174,232,110,183,175,173,85,185,182,180,196,209,84,237,97,87,156,255,1,225,54,223,228,11,40,25,
                  213,88,50,22,165,34,240,236,220,86,12,198,241,244,251,210,35,238,249,10,203,201,9,204]

In [None]:
with torch.no_grad():
    model.w[:] = 0
    model.w[128] = 1
model.eval()
for _ in range(int(64)):
    gradient = train_loop(model, dataloaders['test'], dice_coef_loss)
    valid_idx = np.where(model.w.detach().cpu().numpy() < 1)[0]
    new_w_index = valid_idx[gradient[valid_idx].argsort()[:2]]
    print(new_w_index)
    with torch.no_grad():
        model.w[new_w_index] = 1
    model.w.grad.zero_()
print(model.w)
w_2 = model.w.detach().cpu().numpy()

In [None]:
w_2_indexes = [128,131,125,138,118,135,121,133,123,115,141,107,149,108,148,152,104,154,102,101,155,106,150,78,
               178,180,76,80,176,77,179,175,81,177,79,172,84,227,29,174,82,171,85,229,27,212,44,168,88,145,111,
               122,134,114,142,83,173,169,87,170,86,116,140,98,158,91,165,162,94,153,103,156,100,159,97,89,167,
               90,166,75,181,71,185,73,183,187,69,186,70,188,68,72,184,74,182,228,28,15,241,214,42,231,25,233,23,
               26,230,24,232,234,22,21,235,40,216,213,43,41,215,36,220,200,56,202,54,55,201,129,127]

In [None]:
with torch.no_grad():
    model.w[:] = 0
    model.w[128] = 1
model.eval()
for _ in range(int(32)):
    gradient = train_loop(model, dataloaders['test'], dice_coef_loss)
    valid_idx = np.where(model.w.detach().cpu().numpy() < 1)[0]
    new_w_index = valid_idx[gradient[valid_idx].argsort()[:4]]
    print(new_w_index)
    with torch.no_grad():
        model.w[new_w_index] = 1
    model.w.grad.zero_()
print(model.w)
w_4 = model.w.detach().cpu().numpy()

In [None]:
w_4_indexes = [128,131,125,132,124,118,138,121,135,144,112,114,142,140,116,146,110,159,97,160,96,150,106,100,156,
               153,103,104,152,133,123,157,99,127,129,155,101,68,188,79,177,76,180,185,71,74,182,82,174,219,37,
               225,31,34,222,249,7,46,210,22,234,183,73,65,191,62,194,29,227,39,217,32,224,196,60,190,66,63,193,
               41,215,44,212,70,186,189,67,27,229,211,45,72,184,42,214,170,86,83,173,209,47,168,88,55,201,49,207,
               52,204,243,13,242,14,19,237,18,238,61,195,253,3,192,64,58,198,6,250,235,21]

In [None]:
len(w_4_indexes)

In [None]:
with torch.no_grad():
    model.w[:] = 0
    model.w[128] = 1
model.eval()
for _ in range(int(16)):
    gradient = train_loop(model, dataloaders['test'], dice_coef_loss)
    valid_idx = np.where(model.w.detach().cpu().numpy() < 1)[0]
    new_w_index = valid_idx[gradient[valid_idx].argsort()[:8]]
    print(new_w_index)
    with torch.no_grad():
        model.w[new_w_index] = 1
    model.w.grad.zero_()
print(model.w)
w_8 = model.w.detach().cpu().numpy()

In [None]:
w_8_indexes = [128,131,125,132,124,139,117,121,135,133,123,143,113,145,111,115,141,119,137,147,109,154,102,140,116,
               134,122,104,152,157,99,244,12,106,150,148,108,161,95,105,151,138,118,165,91,212,44,164,92,93,163,
               155,101,89,167,94,162,114,142,166,90,42,214,71,185,88,168,73,183,186,70,25,231,72,184,170,86,187,
               69,26,230,96,160,158,98,68,188,181,75,74,182,85,171,169,87,76,180,177,79,77,179,84,172,83,173,80,
               176,78,178,174,82,43,213,215,41,228,28,238,18,46,210,45,211,226,30,17,239,38,218]

In [None]:
len(w_8_indexes)