In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
import torch.optim as optim
from torch.optim import lr_scheduler
import os
from PIL import Image, ImageOps
import numpy as np
import time
import copy
import pandas as pd
import math
import matplotlib.pyplot as plt
import pickle
import nibabel as nib
import random
from tqdm import tqdm
import sklearn.covariance


In [2]:
class ResidualBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> Conv) twice, plus a skip path.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions (and by the block).
        kernel_size, padding: passed to both convolutions. They must preserve the
            spatial size (e.g. 3/1 or 1/0) for the residual sum to be valid.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.act1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

        # The second stage consumes conv1's output, so it works on out_channels.
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act2 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)

        # Project the skip path only when the channel count changes. nn.Identity has
        # no parameters, so state_dicts of in == out blocks are unchanged.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.conv1(self.act1(self.bn1(x)))
        out = self.conv2(self.act2(self.bn2(out)))
        return out + identity


class ResUNet(nn.Module):
    """Residual U-Net with three down/up-sampling stages and additive skip connections.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        features: width of the first stage; doubled at each down-sampling step.
        dropout: if True, ``nn.Dropout(0.1)`` is inserted after every (transposed)
            resampling conv + ReLU and at the end of ``down3``.
        pooling_size: kernel size and stride of the (transposed) resampling convs.
            Spatial input sizes should be divisible by ``pooling_size ** 3`` so each
            decoder stage matches the shape of its skip connection.
    """

    def __init__(self, in_channels=4, out_channels=1, features=32, dropout=False, pooling_size=2):
        super().__init__()

        # One module instance is reused at every dropout site below; Dropout and
        # Identity hold no parameters, so sharing it is harmless.
        if dropout:
            dropout_layer = nn.Dropout(0.1)
        else:
            dropout_layer = nn.Identity()

        self.init_path = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            ResidualBlock(features, features, kernel_size=3, padding=1),
            ResidualBlock(features, features, kernel_size=3, padding=1),
            ResidualBlock(features, features, kernel_size=3, padding=1)
        )
        self.shortcut0 = nn.Conv2d(features, features, kernel_size=1)

        self.down1 = nn.Sequential(
            nn.BatchNorm2d(features),
            nn.Conv2d(features, features * 2, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer,
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1)
        )
        self.shortcut1 = nn.Conv2d(features * 2, features * 2, 1)

        self.down2 = nn.Sequential(
            nn.BatchNorm2d(features * 2),
            nn.Conv2d(features * 2, features * 4, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer,
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1)
        )
        self.shortcut2 = nn.Conv2d(features * 4, features * 4, 1)

        self.down3 = nn.Sequential(
            nn.BatchNorm2d(features * 4),
            nn.Conv2d(features * 4, features * 8, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer,
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            dropout_layer
        )

        self.up3 = nn.Sequential(
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            ResidualBlock(features * 8, features * 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(features * 8),
            nn.ConvTranspose2d(features * 8, features * 4, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer
        )

        self.up2 = nn.Sequential(
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            ResidualBlock(features * 4, features * 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(features * 4),
            nn.ConvTranspose2d(features * 4, features * 2, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer
        )

        self.up1 = nn.Sequential(
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            ResidualBlock(features * 2, features * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(features * 2),
            nn.ConvTranspose2d(features * 2, features, kernel_size=pooling_size, stride=pooling_size, bias=False),
            nn.ReLU(),
            dropout_layer
        )

        self.out_path = nn.Sequential(
            ResidualBlock(features, features, kernel_size=1, padding=0),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def _encode(self, x):
        """Run the encoder; return the outputs of init_path, down1, down2 and down3."""
        x0 = self.init_path(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        return x0, x1, x2, x3

    def _decode(self, x0, x1, x2, x3):
        """Run the decoder with additive 1x1-projected skips.

        Returns the outputs of up3, up2, up1 and out_path (the last one is the
        pre-sigmoid output map).
        """
        x2_up = self.up3(x3)
        x1_up = self.up2(x2_up + self.shortcut2(x2))
        x0_up = self.up1(x1_up + self.shortcut1(x1))
        x_out = self.out_path(x0_up + self.shortcut0(x0))
        return x2_up, x1_up, x0_up, x_out

    def feature_list(self, x):
        """Return ``(x_out, out_list)`` for feature extraction.

        ``x_out`` is the out_path output *without* the sigmoid applied in
        ``forward``. ``out_list`` holds, in order, the outputs of init_path, down1,
        down2, down3, up3, up2, up1 and out_path.
        """
        encoded = self._encode(x)
        decoded = self._decode(*encoded)
        return decoded[-1], [*encoded, *decoded]

    def forward(self, x):
        """Return the sigmoid of the out_path output."""
        x_out = self._decode(*self._encode(x))[-1]
        return torch.sigmoid(x_out)

    def intermediate_forward(self, x):
        """Return the bottleneck activation (output of down3); the decoder is not run."""
        return self._encode(x)[-1]

In [3]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
model_name = "resunet"
model = ResUNet(in_channels=1)
# map_location remaps tensors that were saved on another device onto `device`,
# so the checkpoint also loads on CPU-only or differently numbered GPUs.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('resunet_10.pth', map_location=device))
model.to(device)
# Inference mode: BatchNorm uses running statistics and dropout is disabled.
model.eval()

ResUNet(
  (init_path): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): ReLU()
    (2): ResidualBlock(
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU()
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU()
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): ResidualBlock(
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU()
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU()
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (4): ResidualBlock

In [5]:
class ImageDataset(torch.utils.data.Dataset):
    """2D axial slices from BraTS 2021 case folders, paired with binary masks.

    Each case folder contributes ``num_slices`` samples, taken every
    ``slice_step`` slices starting at ``start_slice``. Item ``idx`` therefore maps
    to case ``idx // num_slices`` and slice
    ``start_slice + (idx % num_slices) * slice_step``.

    Args:
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=HWC uint8, mask=HW uint8)`` and returning a dict
            with tensor ``"image"`` (C, H, W) and ``"mask"`` (H, W). When None,
            the arrays are converted to tensors without augmentation.
        root: directory containing one sub-folder per case.
        num_slices, start_slice, slice_step: which slices are sampled per case.
        modalities: MRI sequences stacked as image channels. The default is FLAIR
            only, which matches a 1-channel model.
    """

    # Bookkeeping entries that may sit next to the case folders.
    UNNECESSARY_FILES = frozenset({'.DS_Store', '.ipynb_checkpoints'})

    def __init__(self, transforms=None, root='BraTS2021_Training_Data/',
                 num_slices=25, start_slice=20, slice_step=4, modalities=('flair',)):
        self.transforms = transforms
        self.root = root
        # Filter every entry (not just the first sorted one) and keep only folders.
        self.folder_mris = [
            name for name in sorted(os.listdir(self.root))
            if name not in self.UNNECESSARY_FILES
            and os.path.isdir(os.path.join(self.root, name))
        ]
        self.num_slices = num_slices
        self.start_slice = start_slice
        self.slice_step = slice_step
        self.modalities = tuple(modalities)

    def _load_slice(self, case_dir, case_id, suffix, slice_idx):
        """Return axial slice ``slice_idx`` of ``<case_id>_<suffix>.nii.gz`` as floats."""
        path = os.path.join(case_dir, f'{case_id}_{suffix}.nii.gz')
        return nib.load(path).get_fdata()[:, :, slice_idx]

    def __getitem__(self, idx):
        folder_idx, slice_pos = divmod(idx, self.num_slices)
        slice_idx = self.start_slice + slice_pos * self.slice_step
        case_id = self.folder_mris[folder_idx]
        case_dir = os.path.join(self.root, case_id)

        # albumentations and ToTensorV2 expect channel-last (H, W, C) images.
        image = np.stack(
            [self._load_slice(case_dir, case_id, m, slice_idx) for m in self.modalities],
            axis=-1,
        )
        # Per-channel rescale to [0, 255]; an all-zero slice stays zero instead of NaN.
        peak = image.max(axis=(0, 1), keepdims=True)
        image = np.divide(image, peak, out=np.zeros_like(image), where=peak > 0) * 255
        image = np.clip(image, 0, 255).astype(np.uint8)

        # Any non-zero label counts as foreground.
        mask = (self._load_slice(case_dir, case_id, 'seg', slice_idx) > 0).astype(np.uint8)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]
        else:
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
            mask = torch.from_numpy(mask)

        image = image.float() / 255.
        mask = mask.float().unsqueeze(0)
        return image, mask

    def __len__(self):
        return len(self.folder_mris) * self.num_slices
    
    
# Augmentation pipeline (albumentations); `A` and `ToTensorV2` are used below.
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Train: occasional mild random crop, then a fixed resize. Val: resize only.
data_transforms = {
    'train': A.Compose(
        [
            A.RandomResizedCrop(224, 224, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=0.3),
            A.Resize(224, 224),
            ToTensorV2(),
        ]
    ),
    'val': A.Compose(
        [
            A.Resize(224, 224),
            ToTensorV2(),
        ]
    ),
}


dataset_train = ImageDataset(data_transforms['train'])
dataset_val = ImageDataset(data_transforms['val'])

# Patient-level split: whole case folders go to either train or val, so one
# patient's slices never land on both sides. The case count and slices per case
# are read from the dataset, matching its indexing idx = case * num_slices + slice
# (previously hard-coded as 1251 cases x 40 slices, which did not match).
torch.manual_seed(123)  # for reproducibility
n_cases = len(dataset_train.folder_mris)
slices_per_case = dataset_train.num_slices
indices = torch.randperm(n_cases).tolist()
t = int(0.8 * n_cases)
train_indices = [case * slices_per_case + s for case in indices[:t] for s in range(slices_per_case)]
test_indices = [case * slices_per_case + s for case in indices[t:] for s in range(slices_per_case)]
assert len(train_indices) + len(test_indices) == len(dataset_train)
assert set(train_indices).isdisjoint(test_indices)

dataset_train = torch.utils.data.Subset(dataset_train, train_indices)
dataset_val = torch.utils.data.Subset(dataset_val, test_indices)


dataloaders = {'train': torch.utils.data.DataLoader(dataset_train, batch_size=6, shuffle=True, num_workers=3),
               'val': torch.utils.data.DataLoader(dataset_val, batch_size=6, shuffle=False, num_workers=3)}

dataset_sizes = {'train': len(dataset_train), 'val': len(dataset_val)}

In [6]:
import itertools

# Collect bottleneck embeddings from the training split to fit the PCA.
# Capped at MAX_TRAIN_BATCHES batches to bound memory. islice stops after exactly
# that many; the former `if i == 2000: break` ran after processing, so it embedded
# one extra batch.
# NOTE(review): the 'train' loader applies random-crop augmentation; confirm this is
# intended for embedding extraction.
MAX_TRAIN_BATCHES = 2000

all_embeddings = []
with torch.no_grad():
    train_batches = itertools.islice(dataloaders['train'], MAX_TRAIN_BATCHES)
    n_batches = min(MAX_TRAIN_BATCHES, len(dataloaders['train']))
    for image, _ in tqdm(train_batches, total=n_batches):
        output_emb = model.intermediate_forward(image.to(device)).cpu().numpy()
        # One flattened embedding vector per sample in the batch.
        all_embeddings.extend(output_emb.reshape(len(output_emb), -1))

2000it [10:03,  3.31it/s]


In [7]:
from sklearn.decomposition import PCA
from sklearn.ensemble import IsolationForest

In [8]:
# With svd_solver='auto', scikit-learn uses a randomized SVD for inputs this
# large, so fix random_state to make the projection reproducible.
pca = PCA(n_components=256, random_state=123)

In [9]:
# Stack the per-sample flattened embeddings into a (n_samples, n_features) matrix.
all_embeddings = np.asarray(all_embeddings)

In [10]:
# Fit the PCA on the training embeddings and project them onto its components.
pca_emb = pca.fit_transform(all_embeddings)

In [11]:
# Free the raw training embeddings; the cells below only use the projection `pca_emb`.
del all_embeddings

## Isolation Forest

In [12]:
# IsolationForest builds randomized trees; fix the seed so the detector is reproducible.
clf = IsolationForest(random_state=123).fit(pca_emb)

In [13]:
# Persist the fitted PCA so the same projection can be reused outside this notebook.
# NOTE: pickle files execute code on load -- only unpickle files you produced.
with open("pca_unet.pkl", "wb") as pca_file:
    pickle.dump(pca, pca_file)

In [14]:
# Embed the validation split through the same encoder bottleneck
# (intermediate_forward), then project it with the PCA fitted on the training split.
folder_list = []      # NOTE(review): never populated in this cell
slice_num_list = []   # NOTE(review): never populated in this cell
all_embeddings_val = []
with torch.no_grad():
    for batch_idx, (val_images, _) in tqdm(enumerate(dataloaders['val'])):
        bottleneck = model.intermediate_forward(val_images.to(device)).cpu().numpy()
        all_embeddings_val.extend(sample.flatten() for sample in bottleneck)

all_embeddings_val = np.array(all_embeddings_val)
pca_emb_val = pca.transform(all_embeddings_val)

1046it [04:50,  3.60it/s]


In [15]:
# Fraction of the training-embedding variance retained by the fitted PCA components.
sum(pca.explained_variance_ratio_)

0.9931263901598868

In [16]:
# Free the raw validation embeddings; only the projection `pca_emb_val` is used below.
del all_embeddings_val

In [17]:
# Validation slices come from the same distribution as training, so "accuracy"
# here is the fraction the detector labels as inliers (+1).
prediction = clf.predict(pca_emb_val)
acc = np.count_nonzero(prediction == 1) / len(prediction)
print("accuracy", acc)

accuracy 0.9260557768924302


## Local Outlier Factor (LOF)

In [18]:
from sklearn.neighbors import LocalOutlierFactor

# novelty=True fits on the training embeddings only, which allows predict() to be
# called on the unseen validation embeddings afterwards.
clf2 = LocalOutlierFactor(novelty=True, n_neighbors=5).fit(pca_emb)
clf2

LocalOutlierFactor(n_neighbors=5, novelty=True)

In [19]:
# Fraction of (in-distribution) validation slices that LOF labels as inliers (+1).
prediction = clf2.predict(pca_emb_val)
acc = np.count_nonzero(prediction == 1) / len(prediction)
print("accuracy", acc)

accuracy 0.9692430278884462


## OneClassSVM

In [20]:
from sklearn.svm import OneClassSVM

# `degree` only applies to kernel='poly'; with the default RBF kernel the former
# degree=10 was silently ignored, so it is dropped and the kernel made explicit.
# NOTE: nu=0.5 (the default) is an upper bound on the fraction of training points
# treated as outliers, which likely explains the ~0.49 inlier rate below.
clf3 = OneClassSVM(kernel="rbf", nu=0.5).fit(pca_emb)

In [21]:
# Fraction of (in-distribution) validation slices that the One-Class SVM labels as inliers (+1).
prediction = clf3.predict(pca_emb_val)
acc = np.count_nonzero(prediction == 1) / len(prediction)
print("accuracy", acc)

accuracy 0.48717131474103587
