In [1]:
from collections import OrderedDict
from typing import Callable, Sequence

import torch
import torch.nn as nn

import sys
import logging
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import NiftiDataset
from monai.transforms import Compose, Spacing, SpatialPad, AddChannel, ScaleIntensity, Resize, RandRotate90, RandRotate, RandZoom, ToTensor
import os

from monai.networks.layers.factories import Conv, Dropout, Pool, Norm

In [2]:
class _DenseLayer(nn.Sequential):
    def __init__(
        self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float
    ) -> None:
        super(_DenseLayer, self).__init__()

        out_channels = bn_size * growth_rate
        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]

        self.add_module("norm1", norm_type(in_channels))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False))

        self.add_module("norm2", norm_type(out_channels))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))

        if dropout_prob > 0:
            self.add_module("dropout", dropout_type(dropout_prob))

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(
        self, spatial_dims: int, layers: int, in_channels: int, bn_size: int, growth_rate: int, dropout_prob: float
    ) -> None:
        super(_DenseBlock, self).__init__()
        for i in range(layers):
            layer = _DenseLayer(spatial_dims, in_channels, growth_rate, bn_size, dropout_prob)
            in_channels += growth_rate
            self.add_module("denselayer%d" % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
        super(_Transition, self).__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        pool_type: Callable = Pool[Pool.AVG, spatial_dims]

        self.add_module("norm", norm_type(in_channels))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", conv_type(in_channels, out_channels, kernel_size=1, bias=False))
        self.add_module("pool", pool_type(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    """
    Densenet based on: "Densely Connected Convolutional Networks" https://arxiv.org/pdf/1608.06993.pdf
    Adapted from PyTorch Hub 2D version:
    https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of the input channel.
        out_channels: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
                      (i.e. bn_size * k features in the bottleneck layer)
        dropout_prob: dropout rate after each dense layer.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 24, 16),
        bn_size: int = 4,
        dropout_prob: float = 0.0,
        bins=8,
        pool_kernel=32,
        pool_stride=32
    ) -> None:

        super(DenseNet, self).__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        pool_type: Callable = Pool[Pool.MAX, spatial_dims]
        avg_pool_type: Callable = Pool[Pool.ADAPTIVEAVG, spatial_dims]

        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", norm_type(init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        in_channels = init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                spatial_dims=spatial_dims,
                layers=num_layers,
                in_channels=in_channels,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout_prob=dropout_prob,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            in_channels += num_layers * growth_rate
            if i == len(block_config) - 1:
                self.features.add_module("norm5", norm_type(in_channels))
            else:
                _out_channels = in_channels // 2
                trans = _Transition(spatial_dims, in_channels=in_channels, out_channels=_out_channels)
                self.features.add_module("transition%d" % (i + 1), trans)
                in_channels = _out_channels

        # pooling and classification
        '''
        self.class_layers = nn.Sequential(
            OrderedDict(
                [
                    ("relu", nn.ReLU(inplace=True)),
                    ("norm", avg_pool_type(1)),
                    ("flatten", nn.Flatten(1)), 
                    ("class", nn.Linear(in_channels, out_channels)),
                ]
            )
        )
        '''
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(1024*4*4*4, 1024)
        self.fc2 = nn.Linear(2048, out_channels)

        for m in self.modules():
            if isinstance(m, conv_type):  # type: ignore
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, norm_type):  # type: ignore
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
        
        # 3D-ImHistNet
        self.conv1 = nn.Conv3d(1, bins, 1, 1)
        nn.init.constant_(self.conv1.weight, 1.0)
        
        
        self.conv2 = nn.Conv3d(bins, bins, 1, 1, groups=bins)
        nn.init.constant_(self.conv2.bias, 1.0)
    
        self.avgpool = nn.AvgPool3d(pool_kernel, pool_stride)
        self.hist_fc = nn.Linear(bins*4*4*4, 1024)
        #self.fc2 = nn.Linear(1024, no_classes)

        #initialize_params(self)

    def forward(self, x):
        x1 = self.features(x)
        x1 = torch.flatten(x1, 1)
        x1 = self.fc1(x1)
        
        x2 = self.conv1(x)
        x2 = torch.abs(x2)
        x2 = self.conv2(x2)
        x2 = self.relu(x2)
        x2 = self.avgpool(x2)
        x2 = torch.flatten(x2, 1)
        x2 = self.hist_fc(x2)
        
        x_cat = torch.cat([x1, x2], 1)
        x_cat = self.fc2(x_cat)
        return x_cat


def densenet121(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
    return model


def densenet169(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
    return model


def densenet201(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
    return model


def densenet264(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 64, 48), **kwargs)
    return model

In [3]:
# Training data paths
data_dir = '/home/marafath/scratch/iran_organized_data2'

covid_pat = 0
non_covid_pat = 0

pos_pat = []
neg_pat = []

for patient in os.listdir(data_dir):
    if int(patient[-1]) == 1:
        covid_pat += 1
        pos_pat.append(os.path.join(data_dir,patient))
    else:
        non_covid_pat += 1
        neg_pat.append(os.path.join(data_dir,patient))
        
print('Corona+ patient {}'.format(covid_pat))
print('Corona- patient {}'.format(non_covid_pat))

Corona+ patient 223
Corona- patient 399


In [5]:
# 4-fold cross-validation
import math

pos_val = math.floor(covid_pat/4)
neg_val = math.floor(non_covid_pat/4)

print('Corona+ patient {}'.format(pos_val))
print('Corona- patient {}'.format(neg_val))

Corona+ patient 55
Corona- patient 99


In [6]:
# Training data paths
fold = 4

for i in range(0,4): # i = current fold
    p_strt = pos_val*i + 1
    p_end = pos_val*i + pos_val
    n_strt = p_strt
    n_end = p_end
    
    pos_resample_rate = math.floor((len(neg_pat)-pos_val)/(len(pos_pat)-pos_val))
    
    val_image = []
    val_label = []
    trn_image = []
    trn_label = []
        
    for j, pat_link in enumerate(pos_pat): 
        if j >= p_strt and j < p_end:
            for series in os.listdir(pat_link):
                val_label.append(1)
                val_image.append(os.path.join(pat_link,series,'cropped_and_resized_image.nii.gz'))
                break
        else:
            for series in os.listdir(pat_link):
                for rs in range(0, pos_resample_rate):
                    trn_label.append(1)
                    trn_image.append(os.path.join(pat_link,series,'cropped_and_resized_image.nii.gz'))
        
    for k, pat_link in enumerate(neg_pat): 
        if k >= n_strt and k < n_end:
            for series in os.listdir(pat_link):
                val_label.append(0)
                val_image.append(os.path.join(pat_link,series,'cropped_and_resized_image.nii.gz'))
                break
        else:
            for series in os.listdir(pat_link):
                trn_label.append(0)
                trn_image.append(os.path.join(pat_link,series,'cropped_and_resized_image.nii.gz'))
    break  

trn_label = np.asarray(trn_label,np.int64)
val_label = np.asarray(val_label,np.int64)

In [7]:
print('train scans {}'.format(len(trn_label)))
print('val scans {}'.format(len(val_label)))

train scans 1468
val scans 103


In [None]:
print('pos_train scans {}'.format(np.sum(trn_label)))
print('pos_val scans {}'.format(np.sum(val_label)))

In [8]:
# Define transforms
train_transforms = Compose([
    ScaleIntensity(),
    AddChannel(),
    RandRotate(range_x=10.0, range_y=10.0, range_z=10.0, prob=0.5),
    RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
    #Spacing(pixdim=(2.0,2.0,2.0),mode='bilinear'),
    ToTensor()
])

val_transforms = Compose([
    ScaleIntensity(),
    AddChannel(),
    #Spacing(pixdim=(2.0,2.0,2.0),mode='bilinear'),
    ToTensor()
])

In [9]:
check_ds = NiftiDataset(image_files=trn_image, labels=trn_label, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label)

<class 'torch.Tensor'> torch.Size([2, 1, 128, 128, 128]) tensor([1, 1])


In [10]:
# create a training data loader
train_ds = NiftiDataset(image_files=trn_image, labels=trn_label, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

# create a validation data loader
val_ds = NiftiDataset(image_files=val_image, labels=val_label, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

In [18]:
device = torch.device('cuda:0')
model = densenet121(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
).to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-6)

In [19]:
model.load_state_dict(torch.load('/home/marafath/scratch/saved_models/best_test.pth'))

<All keys matched successfully>

In [None]:
# start a typical PyTorch training
val_interval = 5
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
epc = 300 # Number of epoch
for epoch in range(epc):
    print('-' * 10)
    print('epoch {}/{}'.format(epoch + 1, epc))
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device=device, dtype=torch.int64)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item()))
        writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
        
        #
        if step % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), '/home/marafath/scratch/saved_models/best_test2.pth')
                    print('saved new best metric model')
                print('current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'.format(
                    epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_accuracy', metric, epoch + 1)
        #
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

    '''
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            num_correct = 0.
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = model(val_images)
                value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                metric_count += len(value)
                num_correct += value.sum().item()
            metric = num_correct / metric_count
            metric_values.append(metric)
            #torch.save(model.state_dict(), 'model_d121_epoch_{}.pth'.format(epoch + 1))
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), '/home/marafath/scratch/saved_models/best_metric_model_d121_test.pth')
                print('saved new best metric model')
            print('current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'.format(
                epoch + 1, metric, best_metric, best_metric_epoch))
            writer.add_scalar('val_accuracy', metric, epoch + 1)
    '''

----------
epoch 1/300
1/734, train_loss: 0.3683
2/734, train_loss: 0.3260
3/734, train_loss: 0.3296
4/734, train_loss: 0.4846
5/734, train_loss: 0.1345
saved new best metric model
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
6/734, train_loss: 0.2691
7/734, train_loss: 0.7893
8/734, train_loss: 1.3630
9/734, train_loss: 0.5123
10/734, train_loss: 0.7170
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
11/734, train_loss: 1.2071
12/734, train_loss: 0.4477
13/734, train_loss: 0.1638
14/734, train_loss: 0.5658
15/734, train_loss: 0.5789
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
16/734, train_loss: 0.2418
17/734, train_loss: 1.0758
18/734, train_loss: 0.4993
19/734, train_loss: 0.4664
20/734, train_loss: 0.9071
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
21/734, train_loss: 0.3396
22/734, train_loss: 0.4855
23/734, train_loss: 0.2296
24/734, train_loss: 0.1677
25/734, tr

194/734, train_loss: 0.8075
195/734, train_loss: 0.3475
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
196/734, train_loss: 0.1883
197/734, train_loss: 0.4717
198/734, train_loss: 1.1180
199/734, train_loss: 1.1304
200/734, train_loss: 0.4354
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
201/734, train_loss: 0.7474
202/734, train_loss: 0.8342
203/734, train_loss: 0.4091
204/734, train_loss: 0.6441
205/734, train_loss: 1.0549
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
206/734, train_loss: 0.3707
207/734, train_loss: 0.1371
208/734, train_loss: 0.3908
209/734, train_loss: 0.6686
210/734, train_loss: 0.7353
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
211/734, train_loss: 0.7794
212/734, train_loss: 0.6235
213/734, train_loss: 0.2947
214/734, train_loss: 0.3047
215/734, train_loss: 0.5717
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
216/734, 

385/734, train_loss: 0.8264
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
386/734, train_loss: 0.6782
387/734, train_loss: 0.9273
388/734, train_loss: 0.3986
389/734, train_loss: 0.4380
390/734, train_loss: 0.1812
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
391/734, train_loss: 0.3632
392/734, train_loss: 0.4211
393/734, train_loss: 0.4336
394/734, train_loss: 0.1487
395/734, train_loss: 0.7576
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
396/734, train_loss: 0.3269
397/734, train_loss: 0.7310
398/734, train_loss: 1.1189
399/734, train_loss: 0.2815
400/734, train_loss: 0.2961
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
401/734, train_loss: 0.2179
402/734, train_loss: 0.6123
403/734, train_loss: 0.8509
404/734, train_loss: 0.3743
405/734, train_loss: 0.4980
current epoch: 1 current accuracy: 0.6990 best accuracy: 0.6990 at epoch 1
406/734, train_loss: 0.7589
407/734, 

In [None]:
# start a typical PyTorch training
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
epc = 300 # Number of epoch
for epoch in range(epc):
    print('-' * 10)
    print('epoch {}/{}'.format(epoch + 1, epc))
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device=device, dtype=torch.int64)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item()))
        writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            num_correct = 0.
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = model(val_images)
                value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                metric_count += len(value)
                num_correct += value.sum().item()
            metric = num_correct / metric_count
            metric_values.append(metric)
            #torch.save(model.state_dict(), 'model_d121_epoch_{}.pth'.format(epoch + 1))
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), '/home/marafath/scratch/saved_models/best_metric_model_d121_test.pth')
                print('saved new best metric model')
            print('current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'.format(
                epoch + 1, metric, best_metric, best_metric_epoch))
            writer.add_scalar('val_accuracy', metric, epoch + 1)

In [1]:
from collections import OrderedDict
from typing import Callable, Sequence

import torch
import torch.nn as nn

import sys
import logging
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import NiftiDataset
from monai.transforms import Compose, SpatialPad, AddChannel, ScaleIntensity, Resize, RandRotate90, RandRotate, RandZoom, ToTensor
import os
import math


from monai.networks.layers.factories import Conv, Dropout, Pool, Norm


class _DenseLayer(nn.Sequential):
    def __init__(
        self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float
    ) -> None:
        super(_DenseLayer, self).__init__()

        out_channels = bn_size * growth_rate
        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]

        self.add_module("norm1", norm_type(in_channels))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False))

        self.add_module("norm2", norm_type(out_channels))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))

        if dropout_prob > 0:
            self.add_module("dropout", dropout_type(dropout_prob))

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(
        self, spatial_dims: int, layers: int, in_channels: int, bn_size: int, growth_rate: int, dropout_prob: float
    ) -> None:
        super(_DenseBlock, self).__init__()
        for i in range(layers):
            layer = _DenseLayer(spatial_dims, in_channels, growth_rate, bn_size, dropout_prob)
            in_channels += growth_rate
            self.add_module("denselayer%d" % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
        super(_Transition, self).__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        pool_type: Callable = Pool[Pool.AVG, spatial_dims]

        self.add_module("norm", norm_type(in_channels))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", conv_type(in_channels, out_channels, kernel_size=1, bias=False))
        self.add_module("pool", pool_type(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    """
    Densenet based on: "Densely Connected Convolutional Networks" https://arxiv.org/pdf/1608.06993.pdf
    Adapted from PyTorch Hub 2D version:
    https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of the input channel.
        out_channels: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
                      (i.e. bn_size * k features in the bottleneck layer)
        dropout_prob: dropout rate after each dense layer.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 24, 16),
        bn_size: int = 4,
        dropout_prob: float = 0.0,
        bins=8,
        pool_kernel=32,
        pool_stride=32
    ) -> None:

        super(DenseNet, self).__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        pool_type: Callable = Pool[Pool.MAX, spatial_dims]
        avg_pool_type: Callable = Pool[Pool.ADAPTIVEAVG, spatial_dims]

        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", norm_type(init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        in_channels = init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                spatial_dims=spatial_dims,
                layers=num_layers,
                in_channels=in_channels,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout_prob=dropout_prob,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            in_channels += num_layers * growth_rate
            if i == len(block_config) - 1:
                self.features.add_module("norm5", norm_type(in_channels))
            else:
                _out_channels = in_channels // 2
                trans = _Transition(spatial_dims, in_channels=in_channels, out_channels=_out_channels)
                self.features.add_module("transition%d" % (i + 1), trans)
                in_channels = _out_channels

        # pooling and classification
        '''
        self.class_layers = nn.Sequential(
            OrderedDict(
                [
                    ("relu", nn.ReLU(inplace=True)),
                    ("norm", avg_pool_type(1)),
                    ("flatten", nn.Flatten(1)), 
                    ("class", nn.Linear(in_channels, out_channels)),
                ]
            )
        )
        '''
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(1024*4*4*4, 1024)
        self.fc2 = nn.Linear(2048, out_channels)

        for m in self.modules():
            if isinstance(m, conv_type):  # type: ignore
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, norm_type):  # type: ignore
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
        
        # 3D-ImHistNet
        self.conv1 = nn.Conv3d(1, bins, 1, 1)
        nn.init.constant_(self.conv1.weight, 1.0)
        
        
        self.conv2 = nn.Conv3d(bins, bins, 1, 1, groups=bins)
        nn.init.constant_(self.conv2.bias, 1.0)
    
        self.avgpool = nn.AvgPool3d(pool_kernel, pool_stride)
        self.hist_fc = nn.Linear(bins*4*4*4, 1024)
        #self.fc2 = nn.Linear(1024, no_classes)

        #initialize_params(self)

    def forward(self, x):
        x1 = self.features(x)
        x1 = torch.flatten(x1, 1)
        x1 = self.fc1(x1)
        
        x2 = self.conv1(x)
        x2 = torch.abs(x2)
        x2 = self.conv2(x2)
        x2 = self.relu(x2)
        x2 = self.avgpool(x2)
        x2 = torch.flatten(x2, 1)
        x2 = self.hist_fc(x2)
        
        x_cat = torch.cat([x1, x2], 1)
        x_cat = self.fc2(x_cat)
        return x_cat


def densenet121(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
    return model


def densenet169(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
    return model


def densenet201(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
    return model


def densenet264(**kwargs) -> DenseNet:
    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 64, 48), **kwargs)
    return model

def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print('Iran data: cropped and common resized, densenet+imhistnet-8bins, fold-1')

    # Training data paths
    data_dir = '/home/marafath/scratch/iran_organized_data2'
    
    #
    covid_pat = 0
    non_covid_pat = 0

    pos_pat = []
    neg_pat = []

    for patient in os.listdir(data_dir):
        if int(patient[-1]) == 1:
            covid_pat += 1
            pos_pat.append(os.path.join(data_dir,patient))
        else:
            non_covid_pat += 1
            neg_pat.append(os.path.join(data_dir,patient))
            
    pos_val = math.floor(covid_pat/4)
    fold = 1

    for i in range(fold-1,4): # i = current fold
        p_strt = pos_val*i + 1
        p_end = pos_val*i + pos_val
        n_strt = p_strt
        n_end = p_end

        pos_resample_rate = math.floor((len(neg_pat)-pos_val)/(len(pos_pat)-pos_val))

        val_image = []
        val_label = []
        trn_image = []
        trn_label = []

        for j, pat_link in enumerate(pos_pat): 
            if j >= p_strt and j < p_end:
                for series in os.listdir(pat_link):
                    val_label.append(1)
                    val_image.append(os.path.join(pat_link,series,'cropped_and_resized_image.nii.gz'))
                    break
            else:
                for series in os.listdir(pat_link):
                    for rs in range(0, pos_resample_rate):
                        trn_label.append(1)
                        trn_image.append(os.path.join(pat_link,series,'cropped_and_resized_image.nii.gz'))

        for k, pat_link in enumerate(neg_pat): 
            if k >= n_strt and k < n_end:
                for series in os.listdir(pat_link):
                    val_label.append(0)
                    val_image.append(os.path.join(pat_link,series,'cropped_and_resized_image.nii.gz'))
                    break
            else:
                for series in os.listdir(pat_link):
                    trn_label.append(0)
                    trn_image.append(os.path.join(pat_link,series,'cropped_and_resized_image.nii.gz'))
        break  

    trn_label = np.asarray(trn_label,np.int64)
    val_label = np.asarray(val_label,np.int64)


    # Define transforms
    train_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        RandRotate(range_x=10.0, range_y=10.0, range_z=10.0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        ToTensor()
    ])
    
    val_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        ToTensor()
    ])

    # create a training data loader
    train_ds = NiftiDataset(image_files=trn_image, labels=trn_label, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = NiftiDataset(image_files=val_image, labels=val_label, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=2, pin_memory=torch.cuda.is_available())
    
    device = torch.device('cuda:0')
    model = densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    
    # finetuning
    #model.load_state_dict(torch.load('best_metric_model_d121.pth'))

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    epc = 300 # Number of epoch
    for epoch in range(epc):
        print('-' * 10)
        print('epoch {}/{}'.format(epoch + 1, epc))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device=device, dtype=torch.int64)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item()))
            writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                #torch.save(model.state_dict(), 'model_d121_epoch_{}.pth'.format(epoch + 1))
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), '/home/marafath/scratch/saved_models/best_model_densenet_imhistnet_f1.pth')
                    print('saved new best metric model')
                print('current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'.format(
                    epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_accuracy', metric, epoch + 1)
    print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
    writer.close()

if __name__ == '__main__':
    main()

MONAI version: 0.1.0+626.g63eec3a.dirty
Python version: 3.7.4 (default, Jul 18 2019, 19:34:02)  [GCC 5.4.0]
Numpy version: 1.18.1
Pytorch version: 1.5.0

Optional dependencies:
Pytorch Ignite version: 0.3.0
Nibabel version: 3.1.0
scikit-image version: 0.14.2
Pillow version: 7.0.0
Tensorboard version: 2.3.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

Iran data: cropped and common resized, densenet+imhistnet-8bins, fold-1
----------
epoch 1/300
1/734, train_loss: 0.6807
2/734, train_loss: 40.1362
3/734, train_loss: 95.0193
4/734, train_loss: 30.9444
5/734, train_loss: 177.1658
6/734, train_loss: 66.8560
7/734, train_loss: 113.6423
8/734, train_loss: 51.1847
9/734, train_loss: 0.0126
10/734, train_loss: 50.4751
11/734, train_loss: 39.8669
12/734, train_loss: 56.7543
13/734, train_loss: 0.1207
14/734, train_loss: 0.0000
15/734, train_loss: 92.4143
16/734, train_loss: 12

274/734, train_loss: 1.7931
275/734, train_loss: 15.9482
276/734, train_loss: 5.8232
277/734, train_loss: 4.0864
278/734, train_loss: 15.8034
279/734, train_loss: 1.4467


KeyboardInterrupt: 