In [1]:
import torch
from torch import nn
from torch import optim
import torch.utils.data as Data

import torchvision
from torchvision import transforms

import numpy as np

import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter

from tqdm import tqdm
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import shutil

print(torch.__version__)
# torch.manual_seed(1)

BATCH_SIZE = 64
DOWNLOAD_MNIST = True
preceed=False
log='log1'

1.4.0


In [2]:
transformation = [
    #transforms.RandomGrayscale(p=1.0),
    transforms.RandomCrop(32),
    #transforms.ColorJitter(0.1,0.1,0.1,0.1),
    transforms.RandomHorizontalFlip(p=1.0),
    #transforms.RandomVerticalFlip(p=1.0), 
    ]
random_transformation = [transforms.RandomChoice(transformation)]
augmentation_transform = transforms.Compose(
    [transforms.RandomApply(random_transformation, p=1),
     transforms.ToTensor(), ])
     #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
normal_transform = transforms.Compose(
    [transforms.ToTensor(), ])
     #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])

trainset = torchvision.datasets.CIFAR10(root='./mnist', train=True,
                                        download=True, transform=normal_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

augmentation_trainset = torchvision.datasets.CIFAR10(root='./mnist', train=True,
                                        download=True, transform=augmentation_transform)
augmentation_trainloader = torch.utils.data.DataLoader(augmentation_trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./mnist', train=False,
                                       download=True, transform=normal_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [None]:
import matplotlib.pyplot as plt


# functions to show an image


def imshow(img):
    img = img# * 0.5 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(testloader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(BATCH_SIZE)))
#print(images[0])
#img = torch.ones((3,32,32))
#imshow(torchvision.utils.make_grid(img))
print(images[0].size())
print(torch.mean(images,dim=[2,3],keepdim=True).size())
print(torch.std (images,dim=[2,3],keepdim=True).size())

class CNN(nn.Module):
    def __init__(
        self, 
        B = True, 
        C = False, 
        L = 12,
        k = 32,
        theta = 0.5,
        first_layer = 16
    ):
        super(CNN, self).__init__()
        self.L = L
        self.k = k
        self.B = B
        self.C = C
        self.theta = theta
        self.first_layer = first_layer
        self.in_channels = [self.first_layer] + [k] + [k]
        self.conv = nn.ModuleList( [
            nn.Sequential(
                nn.Conv2d(3, self.in_channels[0], 7, 2, 3,bias=False), ),
            nn.Sequential(                
                nn.BatchNorm2d(int(self.in_channels[1] * (theta if C else 1))),
                nn.ReLU(),
                nn.Conv2d(self.in_channels[1], int(self.in_channels[2] * (theta if C else 1)), 1, 1, 1,bias=False), ),
            nn.Sequential(
                nn.BatchNorm2d(int(self.in_channels[2] * (theta if C else 1))),
                nn.ReLU(),
                nn.Conv2d(k, int(self.in_channels[2] * (theta if C else 1)), 1, 1, 1,bias=False), 
            ),
                ] )
        self.denseblock = nn.ModuleList(
            [self._build_denseblock( 
                int(self.in_channels[i] * (theta if C else 1)) if i > 0 else self.in_channels[0] ) 
             for i in range(len(self.in_channels))]
        )
        #self.DenseBlock = nn.ModuleList(denseblock)
        
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(1)
        self.out = nn.Sequential( 
            nn.Linear(k, 1024,bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024,10,bias=False), )
        #self.dropout = nn.Dropout(0.5)
    def _build_denseblock(self, n_in_channel) :
        modules = []
        layers = [
            nn.BatchNorm2d(n_in_channel),
            nn.ReLU(), 
            nn.Conv2d(n_in_channel, self.k, 1, 1, 0,bias=False)] if self.B else []
        layers.extend([
            nn.BatchNorm2d(self.k),
            nn.ReLU(), 
            nn.Conv2d(self.k, self.k, 3, 1, 1,bias=False), ])
        modules.append(nn.Sequential(*layers))
        for l in range(1,self.L) :
            layers = [nn.Conv2d(n_in_channel+self.k*(l), n_in_channel+self.k*(l), 1, 1, 0, bias=False)] if self.B else []
            layers.extend([
                nn.BatchNorm2d(n_in_channel+self.k*(l)),
                nn.ReLU(), 
                nn.Conv2d(n_in_channel+self.k*(l), self.k, 3, 1, 1,bias=False),
                nn.Dropout(0.5), ])
            modules.append(nn.Sequential(*layers))
        return nn.ModuleList(modules)
    def _denseblock_feedforward(self, xs, modulelist):
        cat = xs
        for layer in modulelist :
            output = layer(cat)
            cat = torch.cat( ( cat, output ), 1 )
        return output
    def forward(self, xs):
        output = xs
        for i in range(3) :
            #print(output.size())
            output = self.conv[i](output)
            #print(output.size(), end='\n\n')
            pool = self.pool2 if i else self.pool1 
            output = pool(output)
            output = self._denseblock_feedforward(output, self.denseblock[i])
        AdaptiveAvgPool = self.AdaptiveAvgPool(output).squeeze()
        #print(AdaptiveAvgPool.size())
        return self.out(AdaptiveAvgPool)


In [3]:
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
#from .utils import load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List


__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, 
                 bn_size, drop_rate, bias , memory_efficient=True):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1, bias=bias,)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=bias,)),
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs):
        # type: (List[Tensor]) -> Tensor
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input):
        # type: (List[Tensor]) -> bool
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False
    
    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input):
        # type: (List[Tensor]) -> Tensor
        def closure(*inputs):
            return self.bn_function(*inputs)

        return cp.checkpoint(closure, input)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (List[Tensor]) -> (Tensor)
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (Tensor) -> (Tensor)
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input):  # noqa: F811
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training, inplace=True)
        return new_features


class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, 
                 bias, memory_efficient=True,):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
                bias=bias,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        #print(init_features.size())
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features,bias):
        super(_Transition, self).__init__()
        #self.add_module('ReflectionPad', nn.ReflectionPad2d(2))
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=bias,))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """

    def __init__(self, growth_rate=(12,12,12), block_config=(40,40,40),
                 num_init_features=16, bn_size=4, drop_rate=0.2, num_classes=10,
                 memory_efficient=False, bias=False):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=1,
                                padding=3,
                                bias=bias, )),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            #('pool0', nn.AvgPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, (k, num_layers) in enumerate(zip(growth_rate, block_config)):
            print(k,num_layers)
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=k,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
                bias=bias
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * k
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2, bias=bias)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Sequential(
            nn.Linear(num_features, num_classes, bias=True,) )

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        #print(out.size())
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


In [4]:
LR = 1e-1

cnn = DenseNet()
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()      # Moves all model parameters and buffers to the GPU.
#cnn = torch.nn.DataParallel(cnn,device_ids=[0,1]).cuda()
if preceed :
    cnn.load_state_dict(torch.load("DenseNetL=100,k=12"))

SGD     = torch.optim.SGD
Adagrad = torch.optim.Adagrad
Adam    = torch.optim.Adam

opt = SGD

optimizer = opt(cnn.parameters(), lr=LR, weight_decay=1e-4, momentum=0.9)
#optimizer = opt(cnn.parameters(), lr=LR, weight_decay=1e-3)
#semantic_optimizer = torch.optim.Adagrad(cnn.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

epoch = 0

shutil.rmtree(log)
writer = SummaryWriter(log)

12 40
12 40
12 40


In [None]:
LR = 1e-2
optimizer = opt(cnn.parameters(), lr=LR, momentum = 0.9, weight_decay=1e-4)
#cnn

In [None]:
epoch_duration = 1
max_epoch = 300

for epoch in range(epoch, max_epoch):  # loop over the dataset multiple times

    running_loss = 0.0
    training_accuracy_list = []
    #semantic_loss_list = []
    cnn.train()
    if epoch > 1 and (epoch % 50 == 0 or epoch % 75 == 0) :
        LR/=10;
        optimizer = opt(cnn.parameters(), lr=LR, momentum=0.9, weight_decay=1e-4)
        print("Optimizer updated.")
        
    for i, data in tqdm(enumerate(trainloader, 1)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.cuda()
        inputs = ( inputs - inputs.mean(dim=[2,3], keepdim=True) ) \
                    / inputs.std(dim=[2,3], keepdim=True) * 1e-2
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = cnn(inputs)
        ys = outputs.argmax(dim = 1).detach().cpu()
        result = torch.eq(ys,labels)
        training_accuracy_list.append(result.type(torch.FloatTensor).mean().item())
        crossEntropyLoss = criterion(outputs, labels.cuda())
        crossEntropyLoss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += crossEntropyLoss.detach().cpu().item()
        
#     for i, data in tqdm(enumerate(augmentation_trainloader, 1)):
#         # get the inputs; data is a list of [inputs, labels]
#         inputs, labels = data
#         inputs = inputs.cuda()
#         inputs = ( inputs - inputs.mean(dim=[2,3], keepdim=True) ) \
#                     / inputs.std(dim=[2,3], keepdim=True) * 1e-2
        
#         # zero the parameter gradients
#         optimizer.zero_grad()
        
#         # forward + backward + optimize
#         outputs = cnn(inputs)
#         ys = outputs.argmax(dim = 1).detach().cpu()
#         result = torch.eq(ys,labels)
#         training_accuracy_list.append(result.type(torch.FloatTensor).mean().item())
#         crossEntropyLoss = criterion(outputs, labels.cuda())
#         crossEntropyLoss.backward()
#         optimizer.step()
        
#         # print statistics
#         running_loss += crossEntropyLoss.detach().cpu().item()
        
    if i % epoch_duration == (epoch_duration - 1):    # print every 2000 mini-batches
        #writer.add_scalar('Training/Loss', running_loss / i, epoch + 1)
        print('[%d, %5d] Training loss: %.4f    Training Accuracy : %.4f'% # Semantic loss : %.4f' %
              (epoch + 1, 
               i, 
               running_loss / i, 
               np.mean(training_accuracy_list), )
               #np.mean(semantic_loss_list) )
             ) 
        #running_loss = 0.0
    with torch.no_grad():
        testing_loss = 0.0
        testing_accuracy_list = []
        cnn.eval()
        for j, data in tqdm( enumerate(testloader, 1) ):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.cuda()
            inputs = ( inputs - inputs.mean(dim=[2,3], keepdim=True) ) \
                        / inputs.std(dim=[2,3], keepdim=True) * 1e-2
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = cnn(inputs)
            ys = outputs.argmax(dim = 1).detach().cpu()
            result = torch.eq(ys, labels)
            testing_accuracy_list.append(result.type(torch.FloatTensor).mean().item())
            loss = criterion(outputs, labels.cuda()).detach().cpu().item()
            #loss.backward()
            #optimizer.step()

            testing_loss += loss

    print('[%d, %5d]  Testing loss: %.4f    Testing  Accuracy : %.4f\n' % 
          ( epoch + 1, 
           j, 
           testing_loss / j, 
           np.mean(testing_accuracy_list) ))
    writer.add_scalars('cifar10/Accuracy', 
                       {'Training' : np.mean(training_accuracy_list),
                        'Testing'  : np.mean(testing_accuracy_list), 
                       }, epoch + 1)
    writer.add_scalars('cifar10/loss', 
                       {'Training' : running_loss / i, 
                        'Testing'  : testing_loss / j, 
                        #'Semantic' : np.mean(semantic_loss_list),    
                       }, epoch + 1)
        
print('Finished Training')

782it [05:42,  2.28it/s]

[1,   782] Training loss: 1.6234    Training Accuracy : 0.4204



157it [00:20,  7.50it/s]

[1,   157]  Testing loss: 1.4673    Testing  Accuracy : 0.4889




782it [05:42,  2.29it/s]

[2,   782] Training loss: 1.1295    Training Accuracy : 0.6013



157it [00:20,  7.60it/s]

[2,   157]  Testing loss: 1.1254    Testing  Accuracy : 0.6109




782it [05:41,  2.29it/s]

[3,   782] Training loss: 0.8610    Training Accuracy : 0.7009



157it [00:20,  7.59it/s]

[3,   157]  Testing loss: 0.9184    Testing  Accuracy : 0.7024




782it [05:41,  2.29it/s]

[4,   782] Training loss: 0.6730    Training Accuracy : 0.7671



157it [00:20,  7.57it/s]

[4,   157]  Testing loss: 0.9644    Testing  Accuracy : 0.7123




782it [05:42,  2.28it/s]

[5,   782] Training loss: 0.5652    Training Accuracy : 0.8055



157it [00:20,  7.58it/s]

[5,   157]  Testing loss: 0.6370    Testing  Accuracy : 0.7892




214it [01:33,  2.28it/s]

In [None]:
#cnn = cnn.cpu()
torch.save(cnn.state_dict(), "DenseNetL=100,k=12")

In [None]:
cnn.load_state_dict(torch.load("DenseNetL=100,k=12"))

In [3]:
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
#from .utils import load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List


__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=True):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1,
                                           bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False)),
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs):
        # type: (List[Tensor]) -> Tensor
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input):
        # type: (List[Tensor]) -> bool
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input):
        # type: (List[Tensor]) -> Tensor
        def closure(*inputs):
            return self.bn_function(*inputs)

        return cp.checkpoint(closure, input)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (List[Tensor]) -> (Tensor)
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (Tensor) -> (Tensor)
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input):  # noqa: F811
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input
        #print(self.any_requires_grad(prev_features))
        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features


class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=True):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """

    def __init__(self, growth_rate=12, block_config=(40, 40, 40),
                 num_init_features=16, bn_size=4, drop_rate=0.2, num_classes=10, memory_efficient=True):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=1,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            #('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out
