In [36]:
import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torchvision.transforms.functional as TF
import torch.utils.data as data
from tqdm import tqdm

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

from torchvision import transforms
from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine

In [2]:
class MRDataset(data.Dataset):
    """Per-plane MRNet dataset: one ``.npy`` volume per exam plus a binary label.

    Reads ``<root_dir>/train-<task>.csv`` or ``<root_dir>/valid-<task>.csv``
    (two header-less columns: exam id, label) and resolves each id to
    ``<root_dir>/<train|valid>/<plane>/<id zero-padded to 4>.npy``.

    Args:
        root_dir: Dataset root; '' means the current directory. May be given
            with or without a trailing path separator.
        task: Task name used in the CSV file name (e.g. 'acl').
        plane: Sub-directory holding the volumes (e.g. 'axial').
        train: Use the train split if True, the valid split otherwise.
        transform: Callable applied to the raw array in ``__getitem__``.
            Forced to None for the validation split.
        weights: Optional explicit ``[neg_weight, pos_weight]``. When omitted,
            ``[1, n_neg / n_pos]`` is computed from the split's labels.
    """

    def __init__(self, root_dir, task, plane, train=True, transform=None, weights=None):
        super().__init__()
        self.task = task
        self.plane = plane
        self.root_dir = root_dir
        self.train = train

        split = 'train' if self.train else 'valid'
        if not self.train:
            # Validation data is never augmented.
            transform = None
        # os.path.join tolerates root_dir with or without a trailing separator;
        # the final '' keeps fold_path ending in a separator.
        self.fold_path = os.path.join(root_dir, split, plane, '')
        csv_path = os.path.join(root_dir, '{}-{}.csv'.format(split, task))
        self.records = pd.read_csv(csv_path, header=None, names=['id', 'label'])

        # Ids are parsed as numbers by read_csv but the files on disk use
        # zero-padded 4-character names (e.g. 12 -> '0012').
        self.records['id'] = self.records['id'].map(lambda i: str(i).zfill(4))
        self.paths = [self.fold_path + exam_id + '.npy' for exam_id in self.records['id'].tolist()]
        self.labels = self.records['label'].tolist()

        self.transform = transform

        if weights is None:
            pos = np.sum(self.labels)
            neg = len(self.labels) - pos
            # [neg_weight, pos_weight]; fall back to 1.0 when the split has no
            # positives instead of producing an infinite weight.
            pos_weight = neg / pos if pos > 0 else 1.0
            self.weights = torch.FloatTensor([1, pos_weight])
        else:
            self.weights = torch.FloatTensor(weights)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(volume, one_hot_label, class_weights)`` for one exam.

        Without a transform the loaded array is repeated 3 times along a new
        axis 1 (grayscale -> 3 channels) and returned as a FloatTensor; with a
        transform, the transform's output is returned unchanged.

        Raises:
            ValueError: if the exam's label is not 0 or 1.
        """
        label = self.labels[index]
        if label not in (0, 1):
            raise ValueError(
                'expected a binary label (0 or 1) for {}, got {!r}'.format(self.paths[index], label))
        # One-hot encoding: [1, 0] = negative, [0, 1] = positive.
        label = torch.FloatTensor([1 - label, label])

        array = np.load(self.paths[index])
        if self.transform:
            array = self.transform(array)
        else:
            array = np.stack((array, ) * 3, axis=1)
            array = torch.FloatTensor(array)

        return array, label, self.weights

In [3]:
class MRNet(nn.Module):
    """AlexNet-backbone classifier for one MRI plane.

    All slices of an exam are pushed through AlexNet's ``features`` trunk as
    one batch. Each slice's feature map is globally average-pooled, the
    per-slice vectors are max-pooled across slices into a single exam vector,
    and a linear layer maps that vector to 2 logits.
    """

    def __init__(self):
        super().__init__()
        # Submodule names (including the misspelt `classifer`) and their
        # registration order are left untouched: pickled models and
        # state_dict keys depend on them.
        self.pretrained_model = models.alexnet(pretrained=True)
        self.pooling_layer = nn.AdaptiveAvgPool2d(1)
        self.classifer = nn.Linear(256, 2)

    def forward(self, x):
        """Return a (1, 2) logit tensor for one exam.

        Assumes ``x`` is (1, slices, 3, H, W), as produced by a batch-size-1
        DataLoader (TODO confirm for other callers). The leading dim is
        squeezed away so that the slices act as the CNN batch.
        """
        slices = x.squeeze(0)
        feature_maps = self.pretrained_model.features(slices)
        per_slice = self.pooling_layer(feature_maps).flatten(start_dim=1)
        exam_vector = per_slice.max(dim=0, keepdim=True)[0]
        return self.classifer(exam_vector)

In [5]:
# Load the three per-plane ACL models.
# map_location='cpu' lets this cell run on machines without a GPU; the
# prediction cells move each model to CUDA themselves when one is available.
# NOTE(review): these files appear to hold whole pickled nn.Module objects, so
# torch.load runs arbitrary pickle code — only load checkpoints you trust.
coronal_mrnet = torch.load('coronal_best.pth', map_location='cpu')
axial_mrnet = torch.load('axial_best.pth', map_location='cpu')
sagittal_mrnet = torch.load('sagittal_best.pth', map_location='cpu')

In [22]:
def extract_predictions(model, task, plane, train=True, root_dir=''):
    """Run ``model`` over every exam of one split/plane and collect its outputs.

    Args:
        model: Per-plane network (e.g. MRNet) mapping one exam batch to logits.
        task: Task name forwarded to ``MRDataset`` (selects the label CSV).
        plane: MRI plane forwarded to ``MRDataset``.
        train: Score the train split if True, the valid split otherwise.
        root_dir: Dataset root forwarded to ``MRDataset`` ('' = current dir).

    Returns:
        ``(predictions, labels)``: two lists in dataset order. Each prediction
        is the element-wise sigmoid of the model's first output row, as a
        NumPy array. Each label is the dataset's label tensor for that exam,
        as a NumPy array.

    Side effects: moves ``model`` to CUDA when available and puts it in eval mode.
    """
    dataset = MRDataset(root_dir=root_dir, task=task, plane=plane, transform=None, train=train)
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

    # Inputs follow the model's device, so this also runs on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    predictions = []
    labels = []
    with torch.no_grad():
        for image, label, _ in tqdm(loader):
            logit = model(image.to(device))
            prediction = torch.sigmoid(logit)
            predictions.append(prediction[0].cpu().numpy())
            labels.append(label[0].numpy())
    return predictions, labels

In [23]:
# Plane name -> trained per-plane ACL model.
acl_models = dict(coronal=coronal_mrnet, axial=axial_mrnet, sagittal=sagittal_mrnet)

In [24]:
# Score the training split with each plane's model. Every plane reads the same
# train-<task>.csv, so the labels must agree across planes; verify that instead
# of silently keeping whichever plane ran last.
task = 'acl'
results = {}

for plane, plane_model in acl_models.items():
    predictions, labels = extract_predictions(model=plane_model, task=task, plane=plane)
    if 'labels' in results:
        assert np.array_equal(np.array(results['labels']), np.array(labels)), \
            'label order differs between planes ({})'.format(plane)
    results['labels'] = labels
    results[plane] = np.array(predictions)




  0%|          | 0/1130 [00:00<?, ?it/s][A
  0%|          | 1/1130 [00:01<20:57,  1.11s/it][A
  0%|          | 2/1130 [00:01<15:21,  1.22it/s][A
  0%|          | 3/1130 [00:01<11:43,  1.60it/s][A
  0%|          | 4/1130 [00:01<09:00,  2.08it/s][A
  0%|          | 5/1130 [00:01<06:53,  2.72it/s][A
  1%|          | 7/1130 [00:01<05:16,  3.54it/s][A
  1%|          | 9/1130 [00:02<04:14,  4.40it/s][A
  1%|          | 11/1130 [00:02<03:28,  5.35it/s][A
  1%|          | 13/1130 [00:02<02:58,  6.24it/s][A
  1%|          | 14/1130 [00:02<02:39,  6.98it/s][A
  1%|▏         | 16/1130 [00:02<02:19,  8.01it/s][A
  2%|▏         | 18/1130 [00:02<02:07,  8.69it/s][A
  2%|▏         | 20/1130 [00:03<01:57,  9.46it/s][A
  2%|▏         | 22/1130 [00:03<01:44, 10.64it/s][A
  2%|▏         | 24/1130 [00:03<01:43, 10.73it/s][A
  2%|▏         | 26/1130 [00:03<01:40, 10.99it/s][A
  2%|▏         | 28/1130 [00:03<01:42, 10.79it/s][A
  3%|▎         | 30/1130 [00:03<01:46, 10.30it/s][A
  3%|▎  

 26%|██▌       | 290/1130 [00:27<01:11, 11.81it/s][A
 26%|██▌       | 292/1130 [00:27<01:06, 12.58it/s][A
 26%|██▌       | 294/1130 [00:27<00:59, 13.99it/s][A
 26%|██▌       | 296/1130 [00:27<00:59, 13.91it/s][A
 26%|██▋       | 298/1130 [00:28<00:58, 14.15it/s][A
 27%|██▋       | 300/1130 [00:28<00:59, 13.91it/s][A
 27%|██▋       | 302/1130 [00:28<01:02, 13.22it/s][A
 27%|██▋       | 304/1130 [00:28<01:11, 11.63it/s][A
 27%|██▋       | 306/1130 [00:28<01:04, 12.72it/s][A
 27%|██▋       | 308/1130 [00:28<01:05, 12.50it/s][A
 27%|██▋       | 310/1130 [00:29<01:01, 13.44it/s][A
 28%|██▊       | 312/1130 [00:29<00:56, 14.53it/s][A
 28%|██▊       | 314/1130 [00:29<00:56, 14.40it/s][A
 28%|██▊       | 316/1130 [00:29<00:58, 13.87it/s][A
 28%|██▊       | 318/1130 [00:29<00:55, 14.62it/s][A
 28%|██▊       | 320/1130 [00:29<00:55, 14.72it/s][A
 28%|██▊       | 322/1130 [00:29<00:55, 14.44it/s][A
 29%|██▊       | 324/1130 [00:29<00:51, 15.54it/s][A
 29%|██▉       | 326/1130 [0

 52%|█████▏    | 593/1130 [00:53<00:55,  9.63it/s][A
 53%|█████▎    | 595/1130 [00:53<00:53,  9.94it/s][A
 53%|█████▎    | 597/1130 [00:53<00:53,  9.89it/s][A
 53%|█████▎    | 599/1130 [00:53<00:49, 10.73it/s][A
 53%|█████▎    | 601/1130 [00:53<00:44, 12.00it/s][A
 53%|█████▎    | 603/1130 [00:53<00:46, 11.29it/s][A
 54%|█████▎    | 605/1130 [00:54<00:46, 11.23it/s][A
 54%|█████▎    | 607/1130 [00:54<00:45, 11.51it/s][A
 54%|█████▍    | 609/1130 [00:54<00:46, 11.21it/s][A
 54%|█████▍    | 611/1130 [00:54<00:43, 11.90it/s][A
 54%|█████▍    | 613/1130 [00:54<00:40, 12.90it/s][A
 54%|█████▍    | 615/1130 [00:54<00:39, 13.09it/s][A
 55%|█████▍    | 617/1130 [00:55<00:37, 13.77it/s][A
 55%|█████▍    | 619/1130 [00:55<00:39, 13.09it/s][A
 55%|█████▍    | 621/1130 [00:55<00:39, 12.94it/s][A
 55%|█████▌    | 623/1130 [00:55<00:37, 13.41it/s][A
 55%|█████▌    | 625/1130 [00:55<00:41, 12.15it/s][A
 55%|█████▌    | 627/1130 [00:55<00:41, 12.02it/s][A
 56%|█████▌    | 629/1130 [0

 79%|███████▉  | 895/1130 [01:18<00:22, 10.36it/s][A
 79%|███████▉  | 897/1130 [01:18<00:20, 11.11it/s][A
 80%|███████▉  | 899/1130 [01:18<00:19, 11.62it/s][A
 80%|███████▉  | 901/1130 [01:19<00:18, 12.07it/s][A
 80%|███████▉  | 903/1130 [01:19<00:17, 12.74it/s][A
 80%|████████  | 905/1130 [01:19<00:16, 13.36it/s][A
 80%|████████  | 907/1130 [01:19<00:17, 12.90it/s][A
 80%|████████  | 909/1130 [01:19<00:18, 12.03it/s][A
 81%|████████  | 911/1130 [01:19<00:18, 11.96it/s][A
 81%|████████  | 913/1130 [01:20<00:17, 12.42it/s][A
 81%|████████  | 915/1130 [01:20<00:17, 12.36it/s][A
 81%|████████  | 917/1130 [01:20<00:17, 12.43it/s][A
 81%|████████▏ | 919/1130 [01:20<00:18, 11.21it/s][A
 82%|████████▏ | 921/1130 [01:20<00:18, 11.08it/s][A
 82%|████████▏ | 923/1130 [01:20<00:17, 11.58it/s][A
 82%|████████▏ | 925/1130 [01:21<00:16, 12.22it/s][A
 82%|████████▏ | 927/1130 [01:21<00:15, 13.34it/s][A
 82%|████████▏ | 929/1130 [01:21<00:16, 12.27it/s][A
 82%|████████▏ | 931/1130 [0

  5%|▍         | 56/1130 [00:05<01:43, 10.40it/s][A
  5%|▌         | 58/1130 [00:06<01:43, 10.40it/s][A
  5%|▌         | 60/1130 [00:06<01:40, 10.64it/s][A
  5%|▌         | 62/1130 [00:06<01:41, 10.53it/s][A
  6%|▌         | 64/1130 [00:06<01:41, 10.51it/s][A
  6%|▌         | 66/1130 [00:06<01:41, 10.47it/s][A
  6%|▌         | 68/1130 [00:06<01:40, 10.61it/s][A
  6%|▌         | 70/1130 [00:07<01:37, 10.84it/s][A
  6%|▋         | 72/1130 [00:07<01:29, 11.79it/s][A
  7%|▋         | 74/1130 [00:07<01:26, 12.20it/s][A
  7%|▋         | 76/1130 [00:07<01:27, 12.06it/s][A
  7%|▋         | 78/1130 [00:07<01:24, 12.40it/s][A
  7%|▋         | 80/1130 [00:07<01:32, 11.36it/s][A
  7%|▋         | 82/1130 [00:08<01:30, 11.57it/s][A
  7%|▋         | 84/1130 [00:08<01:30, 11.55it/s][A
  8%|▊         | 86/1130 [00:08<01:36, 10.76it/s][A
  8%|▊         | 88/1130 [00:08<01:49,  9.55it/s][A
  8%|▊         | 90/1130 [00:08<01:41, 10.27it/s][A
  8%|▊         | 92/1130 [00:09<01:37, 10.69it

 31%|███       | 353/1130 [00:34<01:25,  9.06it/s][A
 31%|███▏      | 355/1130 [00:34<01:13, 10.48it/s][A
 32%|███▏      | 357/1130 [00:34<01:13, 10.57it/s][A
 32%|███▏      | 359/1130 [00:34<01:14, 10.35it/s][A
 32%|███▏      | 361/1130 [00:34<01:15, 10.21it/s][A
 32%|███▏      | 363/1130 [00:35<01:15, 10.12it/s][A
 32%|███▏      | 365/1130 [00:35<01:15, 10.15it/s][A
 32%|███▏      | 367/1130 [00:35<01:17,  9.88it/s][A
 33%|███▎      | 369/1130 [00:35<01:19,  9.54it/s][A
 33%|███▎      | 371/1130 [00:36<01:19,  9.53it/s][A
 33%|███▎      | 372/1130 [00:36<01:22,  9.22it/s][A
 33%|███▎      | 374/1130 [00:36<01:12, 10.48it/s][A
 33%|███▎      | 376/1130 [00:36<01:09, 10.90it/s][A
 33%|███▎      | 378/1130 [00:36<01:03, 11.78it/s][A
 34%|███▎      | 380/1130 [00:36<01:05, 11.40it/s][A
 34%|███▍      | 382/1130 [00:36<01:02, 11.93it/s][A
 34%|███▍      | 384/1130 [00:37<01:04, 11.58it/s][A
 34%|███▍      | 386/1130 [00:37<01:03, 11.73it/s][A
 34%|███▍      | 388/1130 [0

 58%|█████▊    | 652/1130 [01:01<01:10,  6.78it/s][A
 58%|█████▊    | 653/1130 [01:02<01:05,  7.25it/s][A
 58%|█████▊    | 654/1130 [01:02<01:09,  6.85it/s][A
 58%|█████▊    | 656/1130 [01:02<01:01,  7.75it/s][A
 58%|█████▊    | 658/1130 [01:02<00:56,  8.40it/s][A
 58%|█████▊    | 659/1130 [01:02<00:53,  8.76it/s][A
 58%|█████▊    | 660/1130 [01:02<00:52,  9.03it/s][A
 59%|█████▊    | 662/1130 [01:03<00:50,  9.18it/s][A
 59%|█████▊    | 663/1130 [01:03<00:49,  9.39it/s][A
 59%|█████▉    | 665/1130 [01:03<00:46, 10.05it/s][A
 59%|█████▉    | 667/1130 [01:03<00:43, 10.65it/s][A
 59%|█████▉    | 669/1130 [01:03<00:41, 11.19it/s][A
 59%|█████▉    | 671/1130 [01:03<00:40, 11.23it/s][A
 60%|█████▉    | 673/1130 [01:03<00:40, 11.24it/s][A
 60%|█████▉    | 675/1130 [01:04<00:40, 11.17it/s][A
 60%|█████▉    | 677/1130 [01:04<00:39, 11.37it/s][A
 60%|██████    | 679/1130 [01:04<00:38, 11.86it/s][A
 60%|██████    | 681/1130 [01:04<00:38, 11.68it/s][A
 60%|██████    | 683/1130 [0

 83%|████████▎ | 935/1130 [01:29<00:18, 10.42it/s][A
 83%|████████▎ | 937/1130 [01:29<00:17, 10.87it/s][A
 83%|████████▎ | 939/1130 [01:29<00:16, 11.26it/s][A
 83%|████████▎ | 941/1130 [01:29<00:16, 11.38it/s][A
 83%|████████▎ | 943/1130 [01:29<00:18, 10.34it/s][A
 84%|████████▎ | 945/1130 [01:29<00:18, 10.09it/s][A
 84%|████████▍ | 947/1130 [01:30<00:17, 10.47it/s][A
 84%|████████▍ | 949/1130 [01:30<00:16, 10.98it/s][A
 84%|████████▍ | 951/1130 [01:30<00:17, 10.06it/s][A
 84%|████████▍ | 953/1130 [01:30<00:17, 10.38it/s][A
 85%|████████▍ | 955/1130 [01:30<00:15, 11.14it/s][A
 85%|████████▍ | 957/1130 [01:31<00:14, 12.29it/s][A
 85%|████████▍ | 959/1130 [01:31<00:14, 11.69it/s][A
 85%|████████▌ | 961/1130 [01:31<00:13, 12.37it/s][A
 85%|████████▌ | 963/1130 [01:31<00:14, 11.72it/s][A
 85%|████████▌ | 965/1130 [01:31<00:14, 11.12it/s][A
 86%|████████▌ | 967/1130 [01:31<00:15, 10.79it/s][A
 86%|████████▌ | 969/1130 [01:32<00:16,  9.48it/s][A
 86%|████████▌ | 970/1130 [0

  9%|▉         | 104/1130 [00:08<01:23, 12.24it/s][A
  9%|▉         | 106/1130 [00:09<01:20, 12.72it/s][A
 10%|▉         | 108/1130 [00:09<01:20, 12.71it/s][A
 10%|▉         | 110/1130 [00:09<01:20, 12.71it/s][A
 10%|▉         | 112/1130 [00:09<01:19, 12.87it/s][A
 10%|█         | 114/1130 [00:09<01:17, 13.12it/s][A
 10%|█         | 116/1130 [00:09<01:16, 13.30it/s][A
 10%|█         | 118/1130 [00:09<01:15, 13.43it/s][A
 11%|█         | 120/1130 [00:10<01:20, 12.55it/s][A
 11%|█         | 122/1130 [00:10<01:18, 12.79it/s][A
 11%|█         | 124/1130 [00:10<01:15, 13.29it/s][A
 11%|█         | 126/1130 [00:10<01:15, 13.37it/s][A
 11%|█▏        | 128/1130 [00:10<01:17, 13.00it/s][A
 12%|█▏        | 130/1130 [00:10<01:19, 12.59it/s][A
 12%|█▏        | 132/1130 [00:11<01:19, 12.60it/s][A
 12%|█▏        | 134/1130 [00:11<01:22, 12.01it/s][A
 12%|█▏        | 136/1130 [00:11<01:17, 12.77it/s][A
 12%|█▏        | 138/1130 [00:11<01:21, 12.23it/s][A
 12%|█▏        | 140/1130 [0

 36%|███▌      | 403/1130 [00:34<00:59, 12.17it/s][A
 36%|███▌      | 405/1130 [00:34<00:57, 12.67it/s][A
 36%|███▌      | 407/1130 [00:34<00:53, 13.64it/s][A
 36%|███▌      | 409/1130 [00:35<01:36,  7.51it/s][A
 36%|███▋      | 411/1130 [00:35<01:25,  8.45it/s][A
 37%|███▋      | 413/1130 [00:35<01:16,  9.43it/s][A
 37%|███▋      | 415/1130 [00:35<01:13,  9.71it/s][A
 37%|███▋      | 417/1130 [00:35<01:08, 10.43it/s][A
 37%|███▋      | 419/1130 [00:36<01:04, 10.97it/s][A
 37%|███▋      | 421/1130 [00:36<01:01, 11.53it/s][A
 37%|███▋      | 423/1130 [00:36<00:58, 12.18it/s][A
 38%|███▊      | 425/1130 [00:36<00:56, 12.38it/s][A
 38%|███▊      | 427/1130 [00:36<01:11,  9.81it/s][A
 38%|███▊      | 429/1130 [00:37<01:13,  9.53it/s][A
 38%|███▊      | 431/1130 [00:37<01:07, 10.30it/s][A
 38%|███▊      | 433/1130 [00:37<01:11,  9.71it/s][A
 38%|███▊      | 435/1130 [00:37<01:05, 10.66it/s][A
 39%|███▊      | 437/1130 [00:37<00:59, 11.71it/s][A
 39%|███▉      | 439/1130 [0

 62%|██████▏   | 705/1130 [01:00<00:33, 12.51it/s][A
 63%|██████▎   | 707/1130 [01:00<00:35, 11.85it/s][A
 63%|██████▎   | 709/1130 [01:00<00:37, 11.16it/s][A
 63%|██████▎   | 711/1130 [01:00<00:39, 10.64it/s][A
 63%|██████▎   | 713/1130 [01:00<00:42,  9.89it/s][A
 63%|██████▎   | 715/1130 [01:01<00:41, 10.05it/s][A
 63%|██████▎   | 717/1130 [01:01<00:38, 10.70it/s][A
 64%|██████▎   | 719/1130 [01:01<00:39, 10.43it/s][A
 64%|██████▍   | 721/1130 [01:01<00:35, 11.45it/s][A
 64%|██████▍   | 723/1130 [01:02<01:07,  6.02it/s][A
 64%|██████▍   | 724/1130 [01:02<00:59,  6.79it/s][A
 64%|██████▍   | 726/1130 [01:02<00:52,  7.69it/s][A
 64%|██████▍   | 728/1130 [01:02<00:46,  8.70it/s][A
 65%|██████▍   | 730/1130 [01:02<00:44,  9.00it/s][A
 65%|██████▍   | 732/1130 [01:03<00:40,  9.90it/s][A
 65%|██████▍   | 734/1130 [01:03<00:36, 10.91it/s][A
 65%|██████▌   | 736/1130 [01:03<00:37, 10.60it/s][A
 65%|██████▌   | 738/1130 [01:03<00:35, 11.15it/s][A
 65%|██████▌   | 740/1130 [0

 89%|████████▉ | 1006/1130 [01:25<00:12, 10.27it/s][A
 89%|████████▉ | 1008/1130 [01:26<00:14,  8.41it/s][A
 89%|████████▉ | 1010/1130 [01:26<00:13,  9.08it/s][A
 90%|████████▉ | 1012/1130 [01:26<00:11, 10.55it/s][A
 90%|████████▉ | 1014/1130 [01:26<00:10, 11.60it/s][A
 90%|████████▉ | 1016/1130 [01:26<00:10, 10.99it/s][A
 90%|█████████ | 1018/1130 [01:27<00:09, 11.92it/s][A
 90%|█████████ | 1020/1130 [01:27<00:10, 10.63it/s][A
 90%|█████████ | 1022/1130 [01:27<00:10, 10.68it/s][A
 91%|█████████ | 1024/1130 [01:27<00:09, 11.50it/s][A
 91%|█████████ | 1026/1130 [01:27<00:08, 11.85it/s][A
 91%|█████████ | 1028/1130 [01:27<00:08, 12.09it/s][A
 91%|█████████ | 1030/1130 [01:28<00:07, 13.14it/s][A
 91%|█████████▏| 1032/1130 [01:28<00:07, 12.58it/s][A
 92%|█████████▏| 1034/1130 [01:28<00:07, 12.96it/s][A
 92%|█████████▏| 1036/1130 [01:28<00:06, 13.81it/s][A
 92%|█████████▏| 1038/1130 [01:28<00:06, 13.59it/s][A
 92%|█████████▏| 1040/1130 [01:28<00:06, 13.69it/s][A
 92%|█████

In [35]:
# Stacker features: the per-plane sigmoid outputs side by side, in
# (coronal, axial, sagittal) order -> one feature row per exam.
x = np.concatenate([results[p] for p in ('coronal', 'axial', 'sagittal')], axis=1)
print(x.shape)  # (n_train_exams, 6)
# Labels are one-hot; argmax recovers the 0/1 class index.
y = np.array(results['labels']).argmax(axis=1)
print(y.shape)

# NOTE(review): the features are predictions on the training split, which is
# presumably the split the per-plane CNNs were trained on, so this score is
# in-sample and likely optimistic.
clf = LogisticRegression()
clf.fit(x, y)
clf.score(x, y)

(1130, 6)
(1130,)


0.8893805309734514

In [40]:
# In-sample ROC-AUC and accuracy of the stacker (column 1 = P(class 1)).
y_pred = clf.predict_proba(x)[:, 1]
train_auc = roc_auc_score(y, y_pred)
train_auc, clf.score(x, y)

(0.9090762139162355, 0.8893805309734514)

In [42]:
# Score the validation split with the same per-plane models, then evaluate
# the stacker fitted on the training split. All planes read the same
# valid-<task>.csv, so their labels must agree.
task = 'acl'
results_val = {}

for plane, plane_model in acl_models.items():
    predictions, labels = extract_predictions(model=plane_model, task=task, plane=plane, train=False)
    if 'labels' in results_val:
        assert np.array_equal(np.array(results_val['labels']), np.array(labels)), \
            'label order differs between planes ({})'.format(plane)
    results_val['labels'] = labels
    results_val[plane] = np.array(predictions)

# Same column order as the training features: coronal, axial, sagittal.
x_val = np.concatenate((results_val['coronal'], results_val['axial'], results_val['sagittal']), axis=1)
print(x_val.shape)  # (n_val_exams, 6)
y_val = np.argmax(np.array(results_val['labels']), axis=1)
print(y_val.shape)

y_pred = clf.predict_proba(x_val)[:, 1]
roc_auc_score(y_val, y_pred), clf.score(x_val, y_val)


  0%|          | 0/120 [00:00<?, ?it/s][A
  2%|▏         | 2/120 [00:00<00:06, 18.07it/s][A
  3%|▎         | 4/120 [00:00<00:06, 18.41it/s][A
  5%|▌         | 6/120 [00:00<00:06, 17.68it/s][A
  7%|▋         | 8/120 [00:00<00:06, 17.15it/s][A
  9%|▉         | 11/120 [00:00<00:06, 17.25it/s][A
 11%|█         | 13/120 [00:00<00:06, 16.74it/s][A
 12%|█▎        | 15/120 [00:00<00:06, 15.71it/s][A
 14%|█▍        | 17/120 [00:01<00:06, 15.81it/s][A
 17%|█▋        | 20/120 [00:01<00:05, 17.78it/s][A
 18%|█▊        | 22/120 [00:01<00:05, 17.63it/s][A
 21%|██        | 25/120 [00:01<00:05, 17.90it/s][A
 22%|██▎       | 27/120 [00:01<00:05, 15.79it/s][A
 24%|██▍       | 29/120 [00:01<00:05, 15.32it/s][A
 26%|██▌       | 31/120 [00:01<00:05, 15.60it/s][A
 28%|██▊       | 34/120 [00:02<00:05, 16.23it/s][A
 30%|███       | 36/120 [00:02<00:04, 17.17it/s][A
 32%|███▏      | 38/120 [00:02<00:05, 16.01it/s][A
 34%|███▍      | 41/120 [00:02<00:04, 17.67it/s][A
 36%|███▌      | 43/120 

 80%|████████  | 96/120 [00:05<00:01, 15.57it/s][A
 82%|████████▏ | 98/120 [00:05<00:01, 15.60it/s][A
 84%|████████▍ | 101/120 [00:05<00:01, 16.33it/s][A
 86%|████████▌ | 103/120 [00:06<00:01, 15.82it/s][A
 88%|████████▊ | 105/120 [00:06<00:00, 15.89it/s][A
 89%|████████▉ | 107/120 [00:06<00:00, 15.21it/s][A
 91%|█████████ | 109/120 [00:06<00:00, 15.49it/s][A
 92%|█████████▎| 111/120 [00:06<00:00, 14.75it/s][A
 94%|█████████▍| 113/120 [00:06<00:00, 15.95it/s][A
 96%|█████████▌| 115/120 [00:06<00:00, 14.64it/s][A
 98%|█████████▊| 117/120 [00:06<00:00, 14.80it/s][A
100%|██████████| 120/120 [00:07<00:00, 16.78it/s][A

(120, 6)
(120,)





(0.878226711560045, 0.75)

In [86]:
# Hard 0/1 stacker predictions for every validation exam.
val_pred_classes = clf.predict(x_val)
val_pred_classes

array([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,
       0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
       0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 1, 0, 0, 0], dtype=int64)

In [104]:
# Run each plane model on a single validation exam (`case`) to build one
# stacker feature row.
case = '1247'
acl_models = {'coronal': coronal_mrnet, 'axial': axial_mrnet, 'sagittal': sagittal_mrnet}
results_pre = {}
# Inputs follow the model's device, so this also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for plane in ['coronal', 'axial', 'sagittal']:
    filename = 'valid/' + plane + '/' + case + '.npy'
    img = np.load(filename)
    # Same preprocessing as MRDataset without a transform: repeat the volume
    # into 3 channels along axis 1 (assumes the .npy is (slices, H, W)).
    img = torch.FloatTensor(np.stack((img, ) * 3, axis=1))
    # Add the batch dim a batch-size-1 DataLoader would add. MRNet squeezes
    # dim 0, so without it a single-slice volume would lose its slice axis.
    img = img.unsqueeze(0).to(device)
    plane_model = acl_models[plane].to(device)
    plane_model.eval()
    with torch.no_grad():
        pre = torch.sigmoid(plane_model(img))
    results_pre[plane] = pre.cpu().numpy()

In [105]:
# Single-exam stacker features, same column order as training: coronal, axial, sagittal.
x_pre = np.hstack([results_pre[p] for p in ('coronal', 'axial', 'sagittal')])
print(x_pre.shape)  # (1, 6)


(1, 6)


In [106]:
# Stacker output for `case`: predicted class and [P(class 0), P(class 1)].
case_label, case_proba = clf.predict(x_pre), clf.predict_proba(x_pre)
case_label, case_proba

(array([0], dtype=int64), array([[0.93574139, 0.06425861]]))