In [0]:
!nvidia-smi

Fri Jan 24 18:38:17 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.44       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   69C    P8    11W /  70W |      0MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [0]:
# Mount point for Google Drive inside the Colab VM.
mount = '/content/gdrive'

from google.colab import drive
drive.mount(mount)

# Directory on Drive used for the model checkpoint: test() writes
# 'condconv_mnist.pkl' under it and the LOAD_MODEL cell reads it back.
# NOTE(review): nothing in this notebook creates the directory -- confirm it
# exists before the first checkpoint is saved.
RESULT_PATH = '{}/My Drive/result/'.format(mount)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
# model architecture
CLASSES = 10        # size of the final classification layer (Net's `classes`)
FIRST_CH = 8        # base channel width `ch`; Net widens up to ch*12
FEATURE_DIMS = 32   # length of the feature vector `f` that conditions CondConv
N_KERNELS = 4       # number of basis kernels mixed by CondConv
CYCLE = 4           # number of refinement cycles built into Net

# hyper params
EPOCHS = 50             # epochs run for every cycle count (see round())
BATCH_SIZE = 300        # used by both the train and the test DataLoader
LEARNING_RATE = 1e-2    # Adam learning rate

# extras
LOG_INTERVAL = 200  # train() prints the loss every LOG_INTERVAL batches
GPU = 0             # CUDA device index
ALPHA = 0.03        # augmentation strength; the driver loop scales it by the cycle index
START = 1           # first cycle count trained by the driver loop
LOAD_MODEL = True   # load RESULT_PATH + 'condconv_mnist.pkl' before training

# Best (lowest) average test loss seen so far; test() saves a checkpoint
# whenever the current loss is <= this value. Start at +inf rather than an
# arbitrary 100. so the first evaluation always produces a checkpoint, even
# if the initial loss is very large.
minimum_loss = float('inf')

In [0]:
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms

In [0]:
def zeropad(x, ch):
    """Zero-pad ``x`` at the end of dimension 1 so that it has ``ch`` entries there.

    Args:
        x: tensor whose dimension 1 is to be extended; the 8-element pad
            specification below addresses four dimensions.
        ch: target size of dimension 1.

    Returns:
        The padded tensor. The difference ``ch - x.size(1)`` is handed to
        ``F.pad`` unchanged, so no check is made that it is non-negative.
    """
    missing = ch - x.size(1)
    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # the third pair addresses dimension 1; every other dimension gets 0.
    pad_spec = (0, 0, 0, 0, 0, missing, 0, 0)
    return F.pad(x, pad_spec)

In [0]:
class Dense(nn.Module):
    """Pre-activated fully connected layer: LeakyReLU(0.2) followed by ``nn.Linear``."""

    def __init__(self, in_ch, out_ch):
        """Build the layer mapping ``in_ch`` input features to ``out_ch`` outputs."""
        super().__init__()
        # The activation stays at index 0 and the linear layer at index 1 so
        # the parameter names remain ``main.1.weight`` / ``main.1.bias``.
        layers = [nn.LeakyReLU(0.2), nn.Linear(in_ch, out_ch)]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the activation and then the linear projection to ``x``."""
        projected = self.main(x)
        return projected

In [0]:
class SNDense(nn.Module):
    """LeakyReLU(0.2) followed by a spectrally normalised ``nn.Linear``."""

    def __init__(self, in_ch, out_ch):
        """Build the layer mapping ``in_ch`` input features to ``out_ch`` outputs."""
        super().__init__()
        normalised_linear = spectral_norm(nn.Linear(in_ch, out_ch))
        # Index 0 is the activation, index 1 the linear layer -- this keeps
        # the state_dict keys under ``main.1.*``.
        self.main = nn.Sequential(nn.LeakyReLU(0.2), normalised_linear)

    def forward(self, x):
        """Apply the activation and then the normalised projection to ``x``."""
        projected = self.main(x)
        return projected

In [0]:
class BNConv(nn.Module):
    """LeakyReLU(0.2) -> Conv2d -> BatchNorm2d."""

    def __init__(self, in_ch, out_ch, kernel=1, stride=1, padding=0):
        """Build the block.

        Args:
            in_ch: number of input channels of the convolution.
            out_ch: number of output channels (also the BatchNorm width).
            kernel, stride, padding: forwarded positionally to ``nn.Conv2d``.
        """
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding)
        norm = nn.BatchNorm2d(out_ch)
        # Order (activation, conv, norm) fixes the state_dict keys to
        # ``main.1.*`` for the convolution and ``main.2.*`` for the norm.
        self.main = nn.Sequential(nn.LeakyReLU(0.2), conv, norm)

    def forward(self, x, features=None):
        """Run the block on ``x``; ``features`` is accepted but not used."""
        normalised = self.main(x)
        return normalised

In [0]:
class SNConv(nn.Module):
    """LeakyReLU(0.2) followed by a spectrally normalised ``nn.Conv2d``."""

    def __init__(self, in_ch, out_ch, kernel=1, stride=1, padding=0):
        """Build the block.

        Args:
            in_ch: number of input channels.
            out_ch: number of output channels.
            kernel, stride, padding: forwarded positionally to ``nn.Conv2d``.
        """
        super().__init__()
        normalised_conv = spectral_norm(
            nn.Conv2d(in_ch, out_ch, kernel, stride, padding)
        )
        # Index 0 is the activation, index 1 the convolution -- this keeps
        # the state_dict keys under ``main.1.*``.
        self.main = nn.Sequential(nn.LeakyReLU(0.2), normalised_conv)

    def forward(self, x, features=None):
        """Run the block on ``x``; ``features`` is accepted but not used."""
        convolved = self.main(x)
        return convolved

In [0]:
class CondConv(nn.Module):
    """Per-sample, per-channel convolution whose kernel is mixed from a shared bank.

    The layer owns ``n_kernels`` basis kernels of size ``kernel x kernel``.
    For every sample and every channel, ``tanh(features)`` provides
    ``n_kernels`` mixing weights; the weighted sum of the basis kernels is
    applied to that channel only (the identity mask below removes all
    cross-channel terms). The channel count of the output equals that of the
    input.
    """

    def __init__(self, feature_dims, n_kernels, kernel=3, stride=1, padding=1, device=None, k=1):
        """Create the kernel bank.

        Args:
            feature_dims: accepted for interface compatibility; not used here.
            n_kernels: number of basis kernels in the bank.
            kernel: spatial size of each (square) basis kernel.
            stride: stride passed to ``F.unfold``.
            padding: padding passed to ``F.unfold``.
            device: accepted for interface compatibility; not used. The
                forward pass follows the device of the kernel parameter.
            k: size of the middle dimension of the kernel parameter. The
                broadcast in ``forward`` needs it to be 1 (or equal to the
                input channel count).
        """
        super(CondConv, self).__init__()
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.kernels_size = (n_kernels, kernel, kernel)

        # One scalar per (basis kernel, row, column) position.
        self.kernels = nn.Parameter(torch.empty(n_kernels*kernel*kernel, k, 1))
        nn.init.xavier_normal_(self.kernels)

        self.activate = nn.Tanh()

    def forward(self, x, features):
        """Convolve ``x`` with kernels conditioned on ``features``.

        Args:
            x: input of shape ``(batch, channels, height, width)``.
            features: tensor with ``batch * n_kernels * channels`` elements,
                viewed below as ``(batch, n_kernels, channels)``.

        Returns:
            Tensor viewed with the same shape as ``x``. The final ``view``
            therefore requires the number of sliding windows produced by
            ``F.unfold`` to equal ``height * width``, which holds for the
            defaults ``kernel=3, stride=1, padding=1``.
        """
        b_size, channels = x.size(0), x.size(1)

        h = F.leaky_relu(x, negative_slope=0.2)

        # Build the identity mask on the parameter's own device/dtype rather
        # than on a module-level global ``device``; this keeps the layer
        # usable outside the notebook and after the model is moved.
        kernels = self.kernels.expand(-1, -1, channels)
        identity = torch.eye(channels, device=self.kernels.device, dtype=self.kernels.dtype)
        identity = identity.view(1, channels, channels).expand(kernels.size(0), -1, -1)
        # (n_kernels*kernel*kernel, channels, channels): each kernel scalar on
        # the channel diagonal, zeros elsewhere.
        kernels = identity * kernels
        kernels = kernels.view(*self.kernels_size, channels, -1)
        # -> (n_kernels, channels, channels, kernel, kernel)
        kernels = kernels.transpose(1, 3).transpose(2, 4)
        kernels = kernels.expand(b_size, -1, -1, -1, -1, -1)

        # Mixing weights in (-1, 1), one per (sample, basis kernel, channel).
        f = self.activate(features)
        f = f.view(*kernels.size()[:3], 1, 1, 1)
        f = f.expand(-1, -1, -1, channels, -1, -1)
        # Sum over the basis kernels -> (batch, channels, channels, kernel, kernel).
        f = torch.sum(kernels * f, dim=1)
        f = f.reshape(b_size, channels, -1)

        # Convolution expressed as unfold + batched matrix multiplication.
        h = F.unfold(h, self.kernel, padding=self.padding, stride=self.stride)
        h = torch.bmm(f, h)
        return h.view(*h.size()[:2], *x.size()[2:])

In [0]:
class Net(nn.Module):
    """Iterative classifier built from one shared CondConv and per-cycle heads.

    A feature vector ``f`` starts from a learned parameter (``first_noise``)
    and is refined over several cycles. In each cycle the input image is
    processed again by four stages of three steps each; every step applies
    the shared ``CondConv`` (conditioned on ``f`` through an ``excites``
    layer), a 1x1 spectrally normalised convolution from ``combines``, and a
    zero-padded residual sum. The pooled result is mapped by a ``features``
    head and added to ``f``. The final ``f`` is classified by ``buttom``.
    """
    def __init__(self, classes, ch, feature_dims, n_kernels, cycle, device=None):
        """Build all submodules.

        Args:
            classes: number of output classes.
            ch: base channel width; stage widths are multiples of it.
            feature_dims: length of the feature vector ``f``.
            n_kernels: number of basis kernels used by ``CondConv``.
            cycle: number of cycles for which per-cycle modules are created.
            device: only forwarded to ``CondConv``.
        """
        super(Net, self).__init__()
        # Channel count of the activation entering each of the 3 steps in
        # each of the 4 stages; sizes the ``excites`` outputs below.
        self.ch = [[ch, ch*2, ch*3],
                   [ch*4, ch*5, ch*6],
                   [ch*7, ch*8, ch*9],
                   [ch*10, ch*11, ch*12]]

        # A single CondConv instance is reused at every step of every cycle.
        self.cc = CondConv(feature_dims, n_kernels, device=device)
        # Learned initial value of the feature vector (shared by all samples).
        self.first_noise = nn.Parameter(torch.empty(1, feature_dims))
        nn.init.normal_(self.first_noise)
        
        # Template for the 1x1 channel-widening convolutions. forward() only
        # uses the deep copies stored in ``self.combines``; this template is
        # still a registered submodule, so its parameters are part of
        # state_dict()/parameters() and must stay for checkpoint compatibility.
        self.extractor = nn.ModuleList(
            [nn.ModuleList(
                [SNConv(ch, ch*2),
                 SNConv(ch*2, ch*3),
                 SNConv(ch*3, ch*4)]
            ),
            nn.ModuleList(
                [SNConv(ch*4, ch*5),
                 SNConv(ch*5, ch*6),
                 SNConv(ch*6, ch*7)]
            ),
            nn.ModuleList(
                [SNConv(ch*7, ch*8),
                 SNConv(ch*8, ch*9),
                 SNConv(ch*9, ch*10)]
            ),
            nn.ModuleList(
                [SNConv(ch*10, ch*11),
                 SNConv(ch*11, ch*12),
                 SNConv(ch*12, ch*12)]
            )]
        )

        # excites[cycle][stage][step]: maps ``f`` to n_kernels mixing weights
        # per channel of the activation at that step.
        self.excites = nn.ModuleList(
            [nn.ModuleList(
                [nn.ModuleList(
                    [SNDense(feature_dims, n_kernels*__ch) for __ch in _ch]
                ) for _ch in self.ch]
            ) for _ in range(cycle)]
        )
        
        # One independent copy of the extractor template per cycle.
        self.combines = nn.ModuleList(
            [copy.deepcopy(self.extractor) for _ in range(cycle)]
        )

        # One pooling op per stage: adaptive max pooling to 14x14, 8x8 and
        # 4x4, then AvgPool2d(4) + Flatten after the last stage.
        self.pool = nn.ModuleList(
            [nn.AdaptiveMaxPool2d((14, 14)),
             nn.AdaptiveMaxPool2d((8, 8)),
             nn.AdaptiveMaxPool2d((4, 4)),
             nn.Sequential(
                 nn.AvgPool2d(4),
                 nn.Flatten()
                 )]
            )

        # Template head mapping the pooled ch*12 vector to a feature update.
        # As with ``extractor``, forward() only uses the copies in
        # ``self.features``, but this template stays registered.
        self.feature = nn.Sequential(
            SNDense(self.ch[-1][-1], feature_dims*2),
            SNDense(feature_dims*2, feature_dims)
        )
        self.features = nn.ModuleList(
            [copy.deepcopy(self.feature) for _ in range(cycle)]
        )
        # Final classifier. The attribute name (sic, "buttom") is part of the
        # saved state_dict keys, so it is left as is.
        self.buttom = SNDense(feature_dims, classes)
    
    def forward(self, x, cycle):
        """Classify ``x`` using the first ``cycle`` refinement cycles.

        Args:
            x: input batch; its dimension 1 is expanded to ``ch`` channels.
                # assumes x is (batch, 1, H, W), since expand() can only
                # replicate size-1 dimensions -- TODO confirm
            cycle: how many of the constructed cycles to run (the module
                lists are sliced with ``[:cycle]``).

        Returns:
            Log-probabilities from ``log_softmax`` over dimension 1.
        """
        _x = x.expand(-1, self.ch[0][0], -1, -1)
        h = _x
        f = self.first_noise.expand(x.size(0), -1)

        for excite, combine, feature in zip(self.excites[:cycle], self.combines[:cycle], self.features[:cycle]):
            # Every cycle starts again from the (expanded) input image.
            h = _x
            for _excite, _combine, pool in zip(excite, combine, self.pool):
                for __excite, __combine in zip(_excite, _combine):
                    # Conditional convolution driven by the current f ...
                    _h = self.cc(h, __excite(f))
                    # ... then a 1x1 convolution that changes the channel count.
                    _h = __combine(_h)
                    # Residual sum; h is zero-padded to the new channel count.
                    h = zeropad(h, _h.size(1)) + _h
                h = pool(h)
            # Residual update of the feature vector from this cycle's output.
            f = f + feature(h)

        out = self.buttom(f)
        return F.log_softmax(out, dim=1)

In [0]:
def train(model, device, train_loader, optimizer, epoch, cycle):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network called as ``model(batch, cycle)`` and expected to
            return log-probabilities (the loss is ``F.nll_loss``).
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches exposing
            ``.dataset`` and ``len()``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the log line.
        cycle: forwarded to the model.

    Prints the loss of every ``LOG_INTERVAL``-th batch.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        log_probs = model(inputs, cycle)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimizer.step()

        if step % LOG_INTERVAL != 0:
            continue
        # Progress is estimated from the batch index times the size of the
        # current batch.
        seen = step * len(inputs)
        total = len(train_loader.dataset)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, total, percent, batch_loss.item()))

In [0]:
def test(model, device, test_loader, cycle):
    """Evaluate ``model`` on ``test_loader`` and checkpoint the best result.

    Computes the average NLL loss and the number of correct top-1
    predictions over the whole loader, prints a summary, and -- when the
    average loss is less than or equal to the module-level ``minimum_loss``
    -- saves ``model.state_dict()`` to ``RESULT_PATH + 'condconv_mnist.pkl'``
    and updates ``minimum_loss``.

    Args:
        model: network called as ``model(batch, cycle)``.
        device: device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches exposing ``.dataset``.
        cycle: forwarded to the model.
    """
    global minimum_loss
    model.eval()
    summed_nll = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            log_probs = model(inputs, cycle)
            # Sum (not mean) per batch so that the division below yields a
            # per-sample average regardless of the last batch's size.
            summed_nll += F.nll_loss(log_probs, labels, reduction='sum').item()
            top1 = log_probs.argmax(dim=1, keepdim=True)
            hits = top1.eq(labels.view_as(top1))
            n_correct += hits.sum().item()

    n_samples = len(test_loader.dataset)
    average_nll = summed_nll / n_samples

    if average_nll <= minimum_loss:
        torch.save(model.state_dict(), RESULT_PATH+'condconv_mnist.pkl')
        minimum_loss = average_nll

    accuracy_percent = 100. * n_correct / n_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        average_nll, n_correct, n_samples, accuracy_percent))

In [0]:
# Use the configured CUDA device when one is available; otherwise fall back
# to the CPU so the notebook still runs (slowly) on a runtime without a GPU
# instead of failing later at the first ``.to(device)`` call.
device = torch.device("cuda:{}".format(GPU) if torch.cuda.is_available() else "cpu")

In [0]:
def create_trainloader(alpha):
    """Build a shuffled MNIST training DataLoader with augmentation ``alpha``.

    Args:
        alpha: passed as the first positional argument to ``RandomAffine``,
            ``RandomPerspective`` and ``RandomRotation``.
            # NOTE(review): per the torchvision signatures this is `degrees`
            # for the affine/rotation transforms and `distortion_scale` for
            # the perspective transform; with the values used by this
            # notebook (ALPHA * cycle) the rotations are fractions of a
            # degree -- confirm that this is intended.

    Returns:
        ``torch.utils.data.DataLoader`` over the MNIST training split stored
        under ``./data`` (downloaded if missing), batch size ``BATCH_SIZE``.
    """
    augment_then_normalise = transforms.Compose([
        transforms.RandomAffine(alpha),
        transforms.RandomPerspective(alpha),
        transforms.RandomRotation(alpha),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    mnist_train = dset.MNIST('./data', train=True, download=True,
                             transform=augment_then_normalise)
    return torch.utils.data.DataLoader(
        mnist_train,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=1,
        pin_memory=True)

In [0]:
# Evaluation loader: MNIST test split, tensor conversion and normalisation
# only (no augmentation). `kwargs` stays a module-level name.
kwargs = dict(num_workers=1, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    dataset=dset.MNIST(
        root='./data',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    **kwargs)

In [0]:
# Instantiate the network from the configuration constants and move it to
# the selected device; a single Adam optimizer covers all of its parameters.
model = Net(
    classes=CLASSES,
    ch=FIRST_CH,
    feature_dims=FEATURE_DIMS,
    n_kernels=N_KERNELS,
    cycle=CYCLE,
    device=device,
).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [0]:
# Optionally resume from the checkpoint written by test().
if LOAD_MODEL:
    # map_location remaps the stored tensors onto the current `device`, so a
    # checkpoint saved on one device (e.g. a GPU runtime) can be restored on
    # another. A missing checkpoint file still raises, as before.
    model.load_state_dict(
        torch.load(RESULT_PATH+'condconv_mnist.pkl', map_location=device))
    model.eval()

In [0]:
def round(cycle, model, device, train_loader, test_loader, optimizer, epochs):
    """Train and evaluate ``model`` for ``epochs`` epochs at a fixed ``cycle``.

    Each epoch calls ``train`` once and then ``test`` once; epoch numbers
    passed to ``train`` start at 1.

    NOTE(review): this name shadows Python's builtin ``round`` for the rest
    of the notebook; it is kept because the driver cell calls it by this name.
    """
    for finished_epochs in range(epochs):
        train(model=model, device=device, train_loader=train_loader,
              optimizer=optimizer, epoch=finished_epochs + 1, cycle=cycle)
        test(model=model, device=device, test_loader=test_loader, cycle=cycle)

In [0]:
# Progressive schedule: train with c = START .. CYCLE refinement cycles,
# running EPOCHS epochs for each value of c.
for c in range(START, CYCLE + 1):
    # Augmentation strength grows linearly with the number of cycles.
    alpha = ALPHA * c
    # The training loader is rebuilt so the new alpha takes effect.
    train_loader = create_trainloader(alpha)
    # The same model and optimizer instances are reused across all values of c.
    round(c, model, device, train_loader, test_loader, optimizer, EPOCHS)


Test set: Average loss: 0.1697, Accuracy: 9462/10000 (95%)


Test set: Average loss: 0.0836, Accuracy: 9738/10000 (97%)


Test set: Average loss: 0.0504, Accuracy: 9837/10000 (98%)


Test set: Average loss: 0.0357, Accuracy: 9891/10000 (99%)


Test set: Average loss: 0.0315, Accuracy: 9895/10000 (99%)


Test set: Average loss: 0.0282, Accuracy: 9911/10000 (99%)


Test set: Average loss: 0.0289, Accuracy: 9915/10000 (99%)


Test set: Average loss: 0.0275, Accuracy: 9912/10000 (99%)


Test set: Average loss: 0.0205, Accuracy: 9934/10000 (99%)


Test set: Average loss: 0.0252, Accuracy: 9917/10000 (99%)


Test set: Average loss: 0.0244, Accuracy: 9920/10000 (99%)


Test set: Average loss: 0.0205, Accuracy: 9933/10000 (99%)


Test set: Average loss: 0.0300, Accuracy: 9910/10000 (99%)


Test set: Average loss: 0.0230, Accuracy: 9929/10000 (99%)


Test set: Average loss: 0.0237, Accuracy: 9931/10000 (99%)


Test set: Average loss: 0.0233, Accuracy: 9935/10000 (99%)


Test set: Average loss: