In [197]:
import torch
import torch.nn as nn
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
import torchvision as tv
import matplotlib.pyplot as plt
from monai.data import Dataset, DataLoader
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from glob import glob
import os
import nibabel as nib
import numpy as np
import cv2
from ipywidgets import interact
from tqdm import tqdm

In [36]:
sensitive_path = 'sensitive.npy'
resistant_path = 'resistant.npy'
sensitive = None
resistant = None
with open(sensitive_path, 'rb') as f:
    sensitive = np.load(f)
    f.close()

with open(resistant_path, 'rb') as f:
    resistant = np.load(f)
    f.close()


In [37]:
def show(layer):
    plt.imshow(sensitive[layer], cmap="gray")
    
interact(show, layer=(0, sensitive.shape[0]-1))

interactive(children=(IntSlider(value=153, description='layer', max=306), Output()), _dom_classes=('widget-int…

<function __main__.show(layer)>

In [38]:
def show(layer):
    plt.imshow(resistant[layer], cmap="gray")
    
interact(show, layer=(0, resistant.shape[0]-1))

interactive(children=(IntSlider(value=123, description='layer', max=247), Output()), _dom_classes=('widget-int…

<function __main__.show(layer)>

In [39]:
model = tv.models.resnet101(weights=tv.models.ResNet101_Weights.IMAGENET1K_V2)

In [40]:
def prepare(image):
    image = cv2.cvtColor(image.astype('float32'), cv2.COLOR_GRAY2RGB)
    image = image.transpose((2,0,1))
    image = np.expand_dims(image, 0)
    image = torch.from_numpy(image)
    return image

In [42]:
return_nodes = ['layer4']

feat_ext = create_feature_extractor(model, return_nodes=return_nodes)
sensitive_features = []
resistant_features = []

with torch.no_grad():
    for sens_item in tqdm(sensitive):
        sensitive_features.append(torch.squeeze(feat_ext(prepare(sens_item))['layer4']))
        
    for res_item in tqdm(resistant):
        resistant_features.append(torch.squeeze(feat_ext(prepare(res_item))['layer4'])) 
    

100%|█████████████████████████████████████████| 307/307 [01:12<00:00,  4.25it/s]
100%|█████████████████████████████████████████| 248/248 [00:57<00:00,  4.29it/s]


In [43]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, sens_features, res_features):
        self.sens_features = sens_features
        self.res_features = res_features
        self.features, self.labels = self.get_lists()

        
    def get_lists(self):

        res_labels = np.ones(len(self.res_features))
        sens_labels = np.zeros(len(self.sens_features))
        
        features = self.sens_features + self.res_features
        labels = np.concatenate((sens_labels, res_labels), axis = None)
        
#         shuffle_indices = np.random.permutation(len(features))
        
#         features =  features[shuffle_indices]
#         labels = labels[shuffle_indices]

        return features, labels
            
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
    
        return {'feature': self.features[idx], 'label': self.labels[idx]}

In [175]:
dataset = MyDataset(sensitive_features, resistant_features)
dataset = Dataset(data = dataset)
dataloader = DataLoader(dataset, num_workers = 0, batch_size=32, shuffle = True)
work_loader = DataLoader(dataset, num_workers = 0, batch_size=1, shuffle = False)

In [165]:
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(2048, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 512, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, 1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride = 2, padding = 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride = 2, padding = 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride = 2, padding = 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride = 2, padding = 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            )
        
        self.decoder = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 512, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 2048, 1),
            nn.ReLU(),
        )
        
    def forward(self, x, get_latent=False):
        latent  = self.encoder(x)
        decoded = self.decoder(latent)
        if get_latent:
            return latent
        else:
            return decoded

In [166]:
model = Autoencoder()
device = torch.device('mps')
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

model.to(device)
criterion.to(device)

MSELoss()

In [167]:
num_epochs = 50
loss_history = [0] * num_epochs
weights_pass = 'weights_autoencoder.pt'
for epoch in range(num_epochs):
    for sample in tqdm(dataloader):
        img, _ = sample['feature'], sample['label']
        img = torch.squeeze(img)
        img = img.to(device)
        recon = model(img)
        loss = criterion(recon, img)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch: {epoch+1}, Loss:{loss.item():.4f}")
    
    if epoch > 0:
        if loss_history[epoch] < loss_history[epoch-1]:
            torch.save(model.state_dict(), weights_pass)
    else:
        torch.save(model.state_dict(), weights_pass)
    


100%|███████████████████████████████████████████| 18/18 [00:05<00:00,  3.16it/s]


Epoch: 1, Loss:0.5418


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.21it/s]


Epoch: 2, Loss:0.5157


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.12it/s]


Epoch: 3, Loss:0.5561


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.14it/s]


Epoch: 4, Loss:0.5070


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.24it/s]


Epoch: 5, Loss:0.5221


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.15it/s]


Epoch: 6, Loss:0.4965


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.12it/s]


Epoch: 7, Loss:0.5303


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.18it/s]


Epoch: 8, Loss:0.5287


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.22it/s]


Epoch: 9, Loss:0.4464


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.17it/s]


Epoch: 10, Loss:0.4719


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.12it/s]


Epoch: 11, Loss:0.4489


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.16it/s]


Epoch: 12, Loss:0.5039


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.16it/s]


Epoch: 13, Loss:0.5369


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.15it/s]


Epoch: 14, Loss:0.4486


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.25it/s]


Epoch: 15, Loss:0.5080


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.27it/s]


Epoch: 16, Loss:0.4285


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.21it/s]


Epoch: 17, Loss:0.4863


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.67it/s]


Epoch: 18, Loss:0.4162


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.02it/s]


Epoch: 19, Loss:0.5094


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.00it/s]


Epoch: 20, Loss:0.4797


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.00it/s]


Epoch: 21, Loss:0.4855


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.94it/s]


Epoch: 22, Loss:0.4920


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.72it/s]


Epoch: 23, Loss:0.4120


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.87it/s]


Epoch: 24, Loss:0.4393


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.76it/s]


Epoch: 25, Loss:0.5198


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.96it/s]


Epoch: 26, Loss:0.4643


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.00it/s]


Epoch: 27, Loss:0.4884


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.18it/s]


Epoch: 28, Loss:0.5203


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.15it/s]


Epoch: 29, Loss:0.5299


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.08it/s]


Epoch: 30, Loss:0.4751


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.15it/s]


Epoch: 31, Loss:0.5065


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.27it/s]


Epoch: 32, Loss:0.5337


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.26it/s]


Epoch: 33, Loss:0.4505


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.10it/s]


Epoch: 34, Loss:0.4836


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.21it/s]


Epoch: 35, Loss:0.4278


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.21it/s]


Epoch: 36, Loss:0.4583


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.13it/s]


Epoch: 37, Loss:0.4598


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.23it/s]


Epoch: 38, Loss:0.4242


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.15it/s]


Epoch: 39, Loss:0.4533


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.15it/s]


Epoch: 40, Loss:0.3990


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.26it/s]


Epoch: 41, Loss:0.4209


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.28it/s]


Epoch: 42, Loss:0.4946


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.14it/s]


Epoch: 43, Loss:0.4336


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.10it/s]


Epoch: 44, Loss:0.4861


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.96it/s]


Epoch: 45, Loss:0.4837


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.07it/s]


Epoch: 46, Loss:0.4451


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.95it/s]


Epoch: 47, Loss:0.3959


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.79it/s]


Epoch: 48, Loss:0.4547


100%|███████████████████████████████████████████| 18/18 [00:03<00:00,  5.98it/s]


Epoch: 49, Loss:0.4474


100%|███████████████████████████████████████████| 18/18 [00:02<00:00,  6.20it/s]


Epoch: 50, Loss:0.3952


In [177]:
# model = Autoencoder()
# model.load_state_dict(torch.load(weights_pass))
model.eval()


reduced_features = []
labels = []
for sample in tqdm(work_loader):
    img, label = sample['feature'], sample['label']
#     img = torch.squeeze(img)
    img = img.to(device)
    reduced = model(img, get_latent=True)
    
    reduced_features.append(reduced.cpu().detach().numpy())
    labels.append(label.detach().numpy())

100%|████████████████████████████████████████| 555/555 [00:05<00:00, 106.51it/s]


In [178]:
for i in range(len(reduced_features)):
    reduced_features[i] = list(reduced_features[i].reshape(-1))
labels = [int(i) for i in labels]

In [194]:
x_train, x_test, y_train, y_test = train_test_split(reduced_features, labels, test_size=0.2, random_state=42)

In [195]:
forest = RandomForestClassifier(n_estimators=100)
forest.fit(x_train, y_train)

In [196]:
from sklearn.metrics import classification_report

print(classification_report(y_test, forest.predict(x_test)))

              precision    recall  f1-score   support

           0       0.89      1.00      0.94        64
           1       1.00      0.83      0.91        47

    accuracy                           0.93       111
   macro avg       0.94      0.91      0.92       111
weighted avg       0.94      0.93      0.93       111

