In [5]:
from collections import OrderedDict
from typing import Callable, Sequence

import torch
import torch.nn as nn

import sys
import logging
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import NiftiDataset, CSVSaver
from monai.transforms import Compose, SpatialPad, AddChannel, ScaleIntensity, Resize, RandRotate90, RandRotate, RandZoom, ToTensor
import os

from sklearn.metrics import classification_report
from monai.networks.layers.factories import Conv, Dropout, Pool, Norm

In [6]:
class _DenseLayer(nn.Sequential):
    def __init__(
        self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float
    ) -> None:
        super(_DenseLayer, self).__init__()

        out_channels = bn_size * growth_rate
        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]

        self.add_module("norm1", norm_type(in_channels))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False))

        self.add_module("norm2", norm_type(out_channels))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))

        if dropout_prob > 0:
            self.add_module("dropout", dropout_type(dropout_prob))

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(
        self, spatial_dims: int, layers: int, in_channels: int, bn_size: int, growth_rate: int, dropout_prob: float
    ) -> None:
        super(_DenseBlock, self).__init__()
        for i in range(layers):
            layer = _DenseLayer(spatial_dims, in_channels, growth_rate, bn_size, dropout_prob)
            in_channels += growth_rate
            self.add_module("denselayer%d" % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
        super(_Transition, self).__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        pool_type: Callable = Pool[Pool.AVG, spatial_dims]

        self.add_module("norm", norm_type(in_channels))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", conv_type(in_channels, out_channels, kernel_size=1, bias=False))
        self.add_module("pool", pool_type(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    """
    Densenet based on: "Densely Connected Convolutional Networks" https://arxiv.org/pdf/1608.06993.pdf
    Adapted from PyTorch Hub 2D version:
    https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of the input channel.
        out_channels: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
                      (i.e. bn_size * k features in the bottleneck layer)
        dropout_prob: dropout rate after each dense layer.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 24, 16),
        bn_size: int = 4,
        dropout_prob: float = 0.0,
        bins=8,
        pool_kernel=32,
        pool_stride=32
    ) -> None:

        super(DenseNet, self).__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        pool_type: Callable = Pool[Pool.MAX, spatial_dims]
        avg_pool_type: Callable = Pool[Pool.ADAPTIVEAVG, spatial_dims]

        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", norm_type(init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        in_channels = init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                spatial_dims=spatial_dims,
                layers=num_layers,
                in_channels=in_channels,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout_prob=dropout_prob,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            in_channels += num_layers * growth_rate
            if i == len(block_config) - 1:
                self.features.add_module("norm5", norm_type(in_channels))
            else:
                _out_channels = in_channels // 2
                trans = _Transition(spatial_dims, in_channels=in_channels, out_channels=_out_channels)
                self.features.add_module("transition%d" % (i + 1), trans)
                in_channels = _out_channels

        # pooling and classification
        '''
        self.class_layers = nn.Sequential(
            OrderedDict(
                [
                    ("relu", nn.ReLU(inplace=True)),
                    ("norm", avg_pool_type(1)),
                    ("flatten", nn.Flatten(1)), 
                    ("class", nn.Linear(in_channels, out_channels)),
                ]
            )
        )
        '''
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(1024*4*4*4, 1024)
        self.fc2 = nn.Linear(2048, out_channels)

        for m in self.modules():
            if isinstance(m, conv_type):  # type: ignore
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, norm_type):  # type: ignore
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
        
        # 3D-ImHistNet
        self.conv1 = nn.Conv3d(1, bins, 1, 1)
        nn.init.constant_(self.conv1.weight, 1.0)
        
        
        self.conv2 = nn.Conv3d(bins, bins, 1, 1, groups=bins)
        nn.init.constant_(self.conv2.bias, 1.0)
    
        self.avgpool = nn.AvgPool3d(pool_kernel, pool_stride)
        self.hist_fc = nn.Linear(bins*4*4*4, 1024)
        #self.fc2 = nn.Linear(1024, no_classes)

        #initialize_params(self)

    def forward(self, x):
        x1 = self.features(x)
        x1 = torch.flatten(x1, 1)
        x1 = self.fc1(x1)
        
        x2 = self.conv1(x)
        x2 = torch.abs(x2)
        x2 = self.conv2(x2)
        x2 = self.relu(x2)
        x2 = self.avgpool(x2)
        x2 = torch.flatten(x2, 1)
        x2 = self.hist_fc(x2)
        
        x_cat = torch.cat([x1, x2], 1)
        x_cat = self.fc2(x_cat)
        return x_cat


def densenet121(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
    return model


def densenet169(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
    return model


def densenet201(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
    return model


def densenet264(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 64, 48), **kwargs)
    return model

In [11]:
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    
    # Validation data paths
    data_dir = '/home/marafath/scratch/iran_organized_data2'
    class_names = ['No-COVID', 'COVID']
    
    covid_pat = 0
    non_covid_pat = 0

    images_p = []
    labels_p = []
    images_n = []
    labels_n = []

    for patient in os.listdir(data_dir):
        if int(patient[-1]) == 0 and non_covid_pat > 236:
            continue 

        if int(patient[-1]) == 1:
            covid_pat += 1
            for series in os.listdir(os.path.join(data_dir,patient)):
                labels_p.append(1)
                images_p.append(os.path.join(data_dir,patient,series,'cropped_and_resized_image.nii.gz'))
        else:
            non_covid_pat += 1
            for series in os.listdir(os.path.join(data_dir,patient)):
                labels_n.append(0)
                images_n.append(os.path.join(data_dir,patient,series,'cropped_and_resized_image.nii.gz'))
            
    train_images = []
    train_labels = []

    val_images = []
    val_labels = []
    
    for i in range(0,len(images_p)):
        if i < 407:
            train_images.append(images_p[i])
            train_labels.append(labels_p[i])
        else:
            val_images.append(images_p[i])
            val_labels.append(labels_p[i])

    for i in range(0,len(images_n)):
        if i < 405:
            train_images.append(images_n[i])
            train_labels.append(labels_n[i])
        else:
            val_images.append(images_n[i])
            val_labels.append(labels_n[i])  
    
    train_labels = np.asarray(train_labels,np.int64)
    val_labels = np.asarray(val_labels,np.int64)
    
    test_images = val_images
    test_labels = val_labels

    # Define transforms for image
    val_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        #SpatialPad((128, 128, 92), mode='constant'),
        #Resize((128, 128, 92)),
        ToTensor()
    ])

    # Define nifti dataset
    val_ds = NiftiDataset(image_files=test_images, labels=test_labels, transform=val_transforms, image_only=False)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
    
    # Create DenseNet121
    device = torch.device('cuda:0')
    model = densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    ).to(device)

    model.load_state_dict(torch.load('/home/marafath/scratch/saved_models/best_model_densenet_imhistnet.pth'))
    model.eval()
    
    y_true = list()
    y_pred = list()
    
    with torch.no_grad():
        num_correct = 0.
        metric_count = 0
        saver = CSVSaver(output_dir='./output')
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = model(val_images).argmax(dim=1)
            value = torch.eq(val_outputs, val_labels)
            metric_count += len(value)
            num_correct += value.sum().item()
            saver.save_batch(val_outputs, val_data[2])
            
            for i in range(len(val_outputs)):
                y_true.append(val_labels[i].item())
                y_pred.append(val_outputs[i].item())
            
        metric = num_correct / metric_count
        print('evaluation metric:', metric)
        print(classification_report(y_true, y_pred, target_names=class_names, digits=4))
        saver.finalize()

if __name__ == '__main__':
    main()

MONAI version: 0.1.0+560.gc89bf24.dirty
Python version: 3.7.4 (default, Jul 18 2019, 19:34:02)  [GCC 5.4.0]
Numpy version: 1.18.1
Pytorch version: 1.5.0

Optional dependencies:
Pytorch Ignite version: 0.3.0
Nibabel version: 3.1.0
scikit-image version: 0.14.2
Pillow version: 7.0.0
Tensorboard version: 2.1.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

evaluation metric: 0.765
              precision    recall  f1-score   support

    No-COVID     0.8193    0.6800    0.7432       100
       COVID     0.7265    0.8500    0.7834       100

    accuracy                         0.7650       200
   macro avg     0.7729    0.7650    0.7633       200
weighted avg     0.7729    0.7650    0.7633       200



In [12]:
import csv

filename = './output/predictions.csv'
rows = [] 

with open(filename, 'r') as csvfile: 
    # creating a csv reader object 
    csvreader = csv.reader(csvfile) 
  
    # extracting each data row one by one 
    for row in csvreader: 
        rows.append(row) 
  
    # get total number of rows 
    print("Total no. of rows: %d"%(csvreader.line_num))

Total no. of rows: 200


In [24]:
rw = rows[0][0]
pat = rw[44:51]
lab = float(rw[52:53])

print(pat, lab)

1934813 1.0


In [45]:
unique_pat = []
pat_label_gt = []
pat_label_pr = []

rw = rows[0][0]
gt = rows[i][1]
unique_pat.append(rw[44:51])
pat_label_pr.append(int(rw[52:53]))
pat_label_gt.append(int(gt[0]))

cnt = 0
for i in range(1,len(rows)):
    rw = rows[i][0]
    pat = rw[44:51]
    lab = int(rw[52:53])
    gt = rows[i][1]
    for j in range(cnt,len(unique_pat)):
        if pat == unique_pat[j]:
            break
        else:
            unique_pat.append(pat)
            pat_label_gt.append(int(gt[0]))
            pat_label_pr.append(lab)
            cnt += 1

print(len(unique_pat)) 
print(len(pat_label_gt))
print(len(pat_label_pr))

84
84
84


In [48]:
class_names = ['No-COVID', 'COVID']
print(classification_report(pat_label_gt, pat_label_pr, target_names=class_names, digits=4))

              precision    recall  f1-score   support

    No-COVID     0.6410    0.7576    0.6944        33
       COVID     0.8222    0.7255    0.7708        51

    accuracy                         0.7381        84
   macro avg     0.7316    0.7415    0.7326        84
weighted avg     0.7510    0.7381    0.7408        84



In [49]:
print(unique_pat) 

['1934813', '1941042', '1234665', '1234660', '1234703', '1944494', '1944098', '1932991', '1948838', '1234714', '1234683', '1936916', '1929407', '1234651', '1934990', '1935297', '1952007', '1942841', '1925208', '1944643', '1933009', '1234676', '1935046', '1927330', '1925052', '1234709', '1937280', '1234644', '1944492', '1234647', '1945594', '1234650', '1234657', '1234681', '1234649', '1934915', '1935091', '1935875', '1234697', '1923960', '1935362', '1932652', '1234675', '1234678', '1939077', '1596773', '1875695', '1234606', '1765184', '1925841', '1911030', '1739046', '1910787', '1929851', '1801474', '1234636', '1724545', '1321733', '1592040', '1761939', '1919053', '1760384', '1895254', '1611801', '1598906', '1759540', '1838821', '1604132', '1604234', '1735070', '1234633', '1234586', '1768980', '1700984', '1723943', '1601335', '1923084', '1604331', '1763514', '1922987', '1906333', '1234625', '1605230', '1930107']
