In [None]:
import torch
import torch.nn.functional as F
from torch import nn, optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
from IPython import display
import time
import sys

# Use the first CUDA GPU when available (and report its name); otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))

In [21]:
class Mish(nn.Module):
    """Mish activation, applied element-wise: ``x * tanh(softplus(x))``.

    Holds no parameters or buffers; the inherited ``nn.Module.__init__`` is all
    the setup it needs.
    """

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x.mul(gate)

In [22]:
class Swish(nn.Module):
    """Swish activation, applied element-wise: ``x * sigmoid(x)``.

    Holds no parameters or buffers; the inherited ``nn.Module.__init__`` is all
    the setup it needs.
    """

    def forward(self, x):
        return torch.sigmoid(x).mul(x)

In [23]:
class DenseBlock(nn.Module):
    """Dense block: every conv unit sees the concatenation of all earlier maps.

    Unit ``k`` (0-based) consumes ``in_channels + k * out_channels`` channels and
    emits ``out_channels`` new ones (BN -> activation -> 3x3 conv, padding 1).
    Its output is appended to the running feature map along dim 1.

    Args:
        num_convs: number of BN-activation-conv units.
        in_channels: channel count of the block input.
        out_channels: channels each unit adds (the growth rate).
        activation: one of 'relu', 'mish', 'swish'.

    Attributes:
        out_channels: channel count of the block output,
            ``in_channels + num_convs * out_channels``.

    Raises:
        NotImplementedError: for any other ``activation``.
    """

    def __init__(self,
                 num_convs,
                 in_channels,
                 out_channels,
                 activation='relu'):
        super(DenseBlock, self).__init__()
        # A single activation instance is reused inside every unit.
        self.activation = self._make_activation(activation)
        self.net = nn.ModuleList([
            self._block(in_channels + k * out_channels, out_channels)
            for k in range(num_convs)
        ])
        self.out_channels = in_channels + num_convs * out_channels

    @staticmethod
    def _make_activation(name):
        """Instantiate the activation module called ``name``."""
        if name == 'relu':
            return nn.ReLU()
        if name == 'mish':
            return Mish()
        if name == 'swish':
            return Swish()
        raise NotImplementedError

    def _block(self, in_channels, out_channels):
        """One BN -> activation -> 3x3 conv (padding 1) unit."""
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            self.activation,
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        features = x
        for unit in self.net:
            features = torch.cat((features, unit(features)), dim=1)
        return features

In [24]:
class TransitionBlock(nn.Module):
    """Transition layer: BN -> activation -> 1x1 conv -> 2x2 average pool (stride 2).

    The 1x1 conv maps ``in_channels`` to ``out_channels``. The pool downsamples
    height and width by a factor of 2.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of the output.
        activation: one of 'relu', 'mish', 'swish'.

    Raises:
        NotImplementedError: for any other ``activation``.
    """

    def __init__(self, in_channels, out_channels, activation='relu'):
        super(TransitionBlock, self).__init__()
        self.activation = self._make_activation(activation)
        stages = [
            nn.BatchNorm2d(in_channels),
            self.activation,
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.net = nn.Sequential(*stages)

    @staticmethod
    def _make_activation(name):
        """Instantiate the activation module called ``name``."""
        if name == 'relu':
            return nn.ReLU()
        if name == 'mish':
            return Mish()
        if name == 'swish':
            return Swish()
        raise NotImplementedError

    def forward(self, x):
        return self.net(x)

In [33]:
class DenseNet(nn.Module):
    """DenseNet classifier for 3-channel images.

    Layout: 3x3 stem conv (64 channels, stride 1) -> BN -> activation ->
    four dense stages -> BN -> activation -> global average pool -> linear head.
    Each of the first three stages ends in a transition block that halves the
    channel count; the fourth stage has no transition.

    Args:
        dense_block: class called as ``dense_block(num_convs, in_channels,
            growth_rate, activation)``; instances must expose ``out_channels``.
        transit_block: class called as ``transit_block(in_channels,
            out_channels, activation)``.
        num_convs: sequence of exactly four conv counts, one per stage.
        growth_rate: channels each conv unit of a dense block adds.
        activation: one of 'relu', 'mish', 'swish'.
        num_classes: number of output logits.

    Raises:
        AssertionError: if ``num_convs`` does not have four entries.
        NotImplementedError: for an unsupported ``activation``.
    """

    def __init__(self,
                 dense_block,
                 transit_block,
                 num_convs,
                 growth_rate=32,
                 activation='relu',
                 num_classes=10):
        super(DenseNet, self).__init__()
        assert len(num_convs) == 4, 'Invalid Conv Number!'

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'mish':
            self.activation = Mish()
        elif activation == 'swish':
            self.activation = Swish()
        else:
            raise NotImplementedError(
                'Unsupported activation: {!r}'.format(activation))

        # Stem: 3x3 conv with stride 1 and padding 1 keeps the input resolution.
        num_channels = 64
        self.conv1 = nn.Conv2d(3,
                               num_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(num_channels)
        # Stages 1-3 end in a transition block; stage 4 feeds the head directly.
        self.layer1, num_channels = self._make_layer(dense_block,
                                                     transit_block,
                                                     num_convs[0],
                                                     num_channels, growth_rate,
                                                     activation)
        self.layer2, num_channels = self._make_layer(dense_block,
                                                     transit_block,
                                                     num_convs[1],
                                                     num_channels, growth_rate,
                                                     activation)
        self.layer3, num_channels = self._make_layer(dense_block,
                                                     transit_block,
                                                     num_convs[2],
                                                     num_channels, growth_rate,
                                                     activation)
        self.layer4, num_channels = self._make_layer(dense_block,
                                                     transit_block,
                                                     num_convs[3],
                                                     num_channels,
                                                     growth_rate,
                                                     activation,
                                                     use_transit=False)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.linear = nn.Linear(num_channels, num_classes)

    def _make_layer(self,
                    dense_block,
                    transit_block,
                    num_convs,
                    num_channels,
                    growth_rate,
                    activation='relu',
                    use_transit=True):
        """Build one stage: a dense block, optionally followed by a transition.

        Returns:
            ``(stage, out_channels)`` where ``stage`` is an ``nn.Sequential``.
            When ``use_transit`` is true, ``out_channels`` is half the dense
            block's output (integer division).
        """
        blk = dense_block(num_convs, num_channels, growth_rate, activation)
        num_channels = blk.out_channels
        layers = [blk]
        if use_transit:
            layers.append(
                transit_block(num_channels, num_channels // 2, activation))
            num_channels = num_channels // 2
        return nn.Sequential(*layers), num_channels

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.activation(self.bn2(out))
        # Global average pool to 1x1 whatever the final spatial size is.
        # With 32x32 inputs and three 2x-downsampling transitions the final map
        # is 4x4, so this equals the former fixed F.avg_pool2d(out, 4). Other
        # resolutions now also work instead of mismatching self.linear.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)


# Four dense stages of 4 conv units each, using Mish activations throughout.
net = DenseNet(dense_block=DenseBlock,
               transit_block=TransitionBlock,
               num_convs=[4, 4, 4, 4],
               activation='mish')

In [None]:
def load_data_cifar_10(batch_size, resize=None, root='/tmp/CIFAR10',
                       num_workers=4):
    """Download and load the CIFAR-10 dataset.

    Args:
        batch_size: samples per mini-batch for both loaders.
        resize: optional size forwarded to ``transforms.Resize``. It is applied
            after augmentation and before tensor conversion. ``None`` keeps the
            native 32x32 images.
        root: directory the dataset is downloaded to / read from.
        num_workers: worker processes used by each DataLoader.

    Returns:
        ``(train_iter, test_iter)`` DataLoaders. The training loader shuffles
        and applies random crop (padding 4) plus horizontal flip. The test
        loader does neither.
    """
    # NOTE(review): this std triple is a widely copied CIFAR-10 value that
    # differs from the per-pixel dataset std often quoted as
    # (0.2470, 0.2435, 0.2616). It is kept unchanged so results stay comparable;
    # verify before changing.
    norm = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                std=(0.2023, 0.1994, 0.2010))
    # `resize` used to be accepted but silently ignored.
    resize_ops = [transforms.Resize(resize)] if resize is not None else []

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *resize_ops,
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        *resize_ops,
        transforms.ToTensor(),
        norm,
    ])

    cifar10_train = torchvision.datasets.CIFAR10(root=root,
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
    cifar10_test = torchvision.datasets.CIFAR10(root=root,
                                                train=False,
                                                download=True,
                                                transform=transform_test)

    train_iter = torch.utils.data.DataLoader(cifar10_train,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(cifar10_test,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=num_workers)

    return train_iter, test_iter


# Mini-batch size shared by the training and test loaders.
batch_size = 256
train_iter, test_iter = load_data_cifar_10(batch_size=batch_size)

In [None]:
def evaluate_accuracy(data_iter, net, device=None):
    """Return the top-1 accuracy of ``net`` over all batches in ``data_iter``.

    The model is put in eval mode once for the whole pass. Its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        data_iter: iterable of ``(X, y)`` batches (e.g. a DataLoader), where
            ``y`` holds integer class labels.
        net: a ``torch.nn.Module`` producing per-class scores along dim 1.
        device: device to run inference on. Defaults to the device of the
            model's first parameter.

    Returns:
        Fraction of correctly classified samples, a float in [0, 1].

    Raises:
        NotImplementedError: if ``net`` is not a ``torch.nn.Module``.
        ValueError: if ``data_iter`` yields no samples.
    """
    if not isinstance(net, torch.nn.Module):
        raise NotImplementedError
    if device is None:
        device = next(iter(net.parameters())).device

    was_training = net.training
    net.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                pred = net(X.to(device)).argmax(dim=1)
                correct += (pred == y.to(device).view_as(pred)).sum().item()
                total += y.shape[0]
    finally:
        # Previously the model was always forced into train mode here.
        net.train(was_training)

    if total == 0:
        raise ValueError('data_iter yielded no samples')
    return correct / total

In [None]:
def train_model(net, train_iter, test_iter, batch_size, optimizer, scheduler,
                device, num_epochs, comment='DenseNet_C10'):
    """Train ``net`` with cross-entropy, evaluating and logging after every epoch.

    Each epoch does the following:
    - one optimisation pass over ``train_iter``;
    - test accuracy via ``evaluate_accuracy``;
    - one ``scheduler.step()``;
    - a printed summary;
    - TensorBoard scalars 'loss', 'train acc' and 'test acc'.

    Args:
        net: model to train; moved to ``device``.
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` evaluation batches.
        batch_size: unused; kept so existing calls keep working.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        device: device to train on.
        num_epochs: number of passes over ``train_iter``.
        comment: suffix passed to ``SummaryWriter`` for the run directory name.
    """
    net = net.to(device)
    writer = SummaryWriter(comment=comment)
    print("training on ", device)
    loss = nn.CrossEntropyLoss()
    try:
        for epoch in range(num_epochs):
            # Enter training mode explicitly rather than relying on
            # evaluate_accuracy to switch the model back after evaluation.
            net.train()
            train_l_sum, train_acc_sum, n, batch_count = 0.0, 0, 0, 0
            start = time.time()
            for X, y in train_iter:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                train_l_sum += l.item()
                pred = y_hat.argmax(dim=1)
                train_acc_sum += (pred == y.view_as(pred)).sum().item()
                n += y.shape[0]
                batch_count += 1
            test_acc = evaluate_accuracy(test_iter, net, device)
            scheduler.step()
            mean_loss = train_l_sum / batch_count
            train_acc = train_acc_sum / n
            print(
                'epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
                % (epoch + 1, mean_loss, train_acc, test_acc,
                   time.time() - start))

            writer.add_scalar('loss', mean_loss, epoch + 1)
            writer.add_scalar('train acc', train_acc, epoch + 1)
            writer.add_scalar('test acc', test_acc, epoch + 1)
    finally:
        # Flush pending events and release the event file, even on interrupt.
        writer.close()

In [None]:
# NOTE(review): 0.01 is 10x Adam's default learning rate (1e-3); confirm intentional.
lr = 0.01
num_epochs = 20
optimizer = optim.Adam(params=net.parameters(), lr=lr)
# Multiply the learning rate by 0.1 every 10 scheduler steps (train_model steps once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
train_model(net, train_iter, test_iter, batch_size,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            num_epochs=num_epochs)

In [None]:
# Start TensorBoard inside the notebook, reading event files under runs/.
# train_model creates SummaryWriter(comment=...) without a log_dir, so its runs
# land in a subdirectory of ./runs/ (PyTorch's default location).
%load_ext tensorboard
%tensorboard --logdir runs/