In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
import pickle
from slack_notification import *
%matplotlib inline

In [2]:
class FSMod(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, conv_stride=1, conv_pad=0, conv_groups=1, conv_dilation=1):
        '''
        input ksize: (S1, S2)
        input channel: C
        FS stride: (x, y)
        output channel: filter_num
        '''
        super(FSMod, self).__init__()
        
        self.C = in_channels
        self.S1, self.S2 = kernel_size
        
        if self.S1 == 3 and self.S2 == 3:
            self.x = self.y = 2
            K3_dict = dict([[12, 2], [16, 4], [32, 2], [64, 4], [128, 4], [256, 4], [512, 8]])
            self.K3 = K3_dict[self.C]
            self.K1, self.K2 = self.get_div(out_channels // self.K3)
            self.z = int(self.C / self.K3)
        elif self.S1 == 1 and self.S2 == 1:
            self.x = self.y = 1
            self.K1 = self.K2 = 1
            self.K3 = out_channels
            self.z = int(self.C / self.K3)
        else:
            raise ValueError()
            
        self.K = self.K1*self.K2*self.K3
        assert out_channels == self.K, f'invalid filter num, K={self.K}, out_channels={out_channels}'
        

        self.FS = nn.Parameter(nn.init.kaiming_normal_(torch.zeros((self.C, self.K1*self.x, self.K2*self.y)))).cuda()
        
        self.grad_index = torch.zeros((self.FS.size()[0]*2, self.FS.size()[1]*2, self.FS.size()[2]*2)).cuda()
        
        self.conv_real = nn.Conv2d(in_channels=self.C, 
                                   out_channels=self.K, 
                                   kernel_size=(self.S1, self.S2), 
                                   stride=conv_stride, 
                                   padding=conv_pad,
                                   groups=conv_groups,
                                   dilation=conv_dilation,
                                   bias=False).cuda()
        
        self.backward_hook_handle = self.conv_real.register_backward_hook(self.backward_hook)
        
    def get_div(self, n):
        divisors = []
        for i in range(1, int(n**0.5)+1):
            if n % i == 0:
                divisors.append((i, n//i))
        divisors = sorted(divisors, key=lambda x: abs(x[0] - x[1]))
        return divisors[0]
        
    def __call__(self, input):
        """
        input: (N, C, H, W)
        """
        self.grad_index = torch.zeros((self.FS.size()[0]*2, self.FS.size()[1]*2, self.FS.size()[2]*2)).cuda()
        conv_weight = torch.zeros((self.K1*self.K2*self.K3, self.C, self.S1, self.S2)).cuda()
        combinations = [(i, j, k) for i in range(self.K1) 
                        for j in range(self.K2) 
                        for k in range(self.K3)]
        FS_extend = torch.cat([self.FS, self.FS], dim=0)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=1)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=2)
        
        for (k1, k2, k3) in combinations:
            conv_weight[k1 + k2 * self.K1 + k3 * self.K1 * self.K2] = FS_extend[k3*self.z:k3*self.z+self.C, 
                                                            k1*self.x:k1*self.x+self.S1,
                                                           k2*self.y:k2*self.y+self.S2]
            self.grad_index[k3*self.z:k3*self.z+self.C,
                            k1*self.x:k1*self.x+self.S1,
                            k2*self.y:k2*self.y+self.S2] += 1
        c, h, w = self.grad_index.size()
        self.grad_index[:c//2, :, :] += self.grad_index[c//2:, :, :]
        self.grad_index[:, :h//2, :] += self.grad_index[:, h//2:, :]
        self.grad_index[:, :, :w//2] += self.grad_index[:, :, w//2:]
        self.grad_index = self.grad_index[:c//2, :h//2, :w//2]
        
        self.conv_real.weight = nn.Parameter(conv_weight.cuda(), requires_grad=True)
        
        return self.conv_real(input)
    
    def backward_hook(self, module, grad_input, grad_output):
        '''
        grad_input[1] is the grad of weight of conv_real
        '''
        grad_extend = torch.zeros((self.FS.size()[0]*2, self.FS.size()[1]*2, self.FS.size()[2]*2)).cuda()
        for i, grad in enumerate(grad_input[1]):
            k1 = i%self.K1
            k2 = (i // self.K1) % self.K2
            k3 = i // (self.K1 * self.K2)
            grad_extend[k3*self.z:k3*self.z+self.C,
                        k1*self.x:k1*self.x+self.S1,
                        k2*self.y:k2*self.y+self.S2] = grad
        c, h, w = grad_extend.size()
        grad_extend[:c//2, :, :] += grad_extend[c//2:, :, :]
        grad_extend[:, :h//2, :] += grad_extend[:, h//2:, :]
        grad_extend[:, :, :w//2] += grad_extend[:, :, w//2:]
        grad_extend = grad_extend[:c//2, :h//2, :w//2]
        grad_extend = grad_extend / self.grad_index
        self.FS.grad = grad_extend
        

In [3]:
def fs3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return FSMod(in_planes, out_planes, kernel_size=(3,3), conv_stride=stride,
                conv_pad=dilation, conv_groups=groups, conv_dilation=dilation)
    #return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
    #                padding=dilation, groups=groups, bias=False, dilation=dilation)


def fs1x1(in_planes, out_planes, stride=1):
    return FSMod(in_planes, out_planes, kernel_size=(1,1), conv_stride=stride)
    #return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.fs1 = fs3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.fs2 = fs3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.fs1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.fs2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.fs1 = fs1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.fs2 = fs3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.fs3 = fs1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.fs1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.fs2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.fs3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class FSNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(FSNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                fs1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

def make_fsnet(model_num, **kwargs):
    assert model_num in [18, 34, 50, 101, 152], "invalid model_num"
    if model_num == 18:
        model = FSNet(BasicBlock, [2,2,2,2], **kwargs)
    if model_num == 34:
        model = FSNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if model_num == 50:
        model = FSNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if model_num == 101:
        model = FSNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if model_num == 152:
        model = FSNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model

In [4]:
model = make_fsnet(34).cuda()

In [5]:
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
    ])


trainset = torchvision.datasets.CIFAR10(root='/data/unagi0/kamata', train=True,
                                        download=True, transform=transform)
train_bsize = 128
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

test_bsize = 100
testset = torchvision.datasets.CIFAR10(root='/data/unagi0/kamata', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,weight_decay=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.3)
criterion = nn.CrossEntropyLoss()
max_epoch = 30

In [7]:
train_loss_all = []
train_acc_all = []
test_loss_all = []
test_acc_all = []

best_acc = 0

for epoch in range(max_epoch):
    '''train'''
    print('epoch: ', epoch)
    model = model.train()
    train_loss = 0
    train_correct = 0
    scheduler.step()
    for i, data in enumerate(tqdm(trainloader)):
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()
        inputs.requires_grad_()
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        train_loss += loss.cpu().detach().numpy() 
        _, predicted = outputs.max(1)
        train_correct += predicted.eq(labels).sum().item()
        optimizer.step()
    train_loss /= (len(trainloader) * train_bsize)
    train_accuracy = train_correct / (len(trainloader) * train_bsize)
    print('train_loss: ', train_loss)
    print('train_accuracy: ', train_accuracy)
    train_loss_all.append(train_loss)
    train_acc_all.append(train_accuracy)
    send_notification(text= f'epoch: {epoch}, train_loss: {train_loss}, train_accuracy: {train_accuracy}')
    
    '''test'''
    model = model.eval()
    test_loss = 0
    test_correct = 0
    for i, data in enumerate(tqdm(testloader)):
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.cpu().detach().numpy()
        _, predicted = outputs.max(1)
        test_correct += predicted.eq(labels).sum().item()
    test_loss /= (len(testloader) * test_bsize)
    test_accuracy = test_correct / (len(testloader) * test_bsize)
    print('test_loss: ', test_loss)
    print('test_accuracy:, ', test_accuracy)
    test_loss_all.append(test_loss)
    test_acc_all.append(test_accuracy)
    send_notification(text= f'epoch: {epoch}, test_loss: {test_loss}, test_accuracy: {test_accuracy}')
    if test_accuracy >= best_acc:
        best_acc = test_accuracy
        torch.save(model.state_dict(), '/data/unagi0/kamata/models/FSNet.pth')
        
np.savetxt('train_acc_all.txt', train_acc_all, fmt='%s', delimiter=',')
np.savetxt('train_loss_all.txt', train_loss_all, fmt='%s', delimiter=',')
np.savetxt('test_acc_all.txt', test_acc_all, fmt='%s', delimiter=',')
np.savetxt('test_loss_all.txt', test_loss_all, fmt='%s', delimiter=',')

epoch:  0


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))


train_loss:  0.0
train_accuracy:  0.0


HBox(children=(IntProgress(value=0), HTML(value='')))


test_loss:  0.0
test_accuracy:,  0.0
epoch:  1


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))


train_loss:  0.0
train_accuracy:  0.0


HBox(children=(IntProgress(value=0), HTML(value='')))


test_loss:  0.0
test_accuracy:,  0.0
epoch:  2


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))


train_loss:  0.0
train_accuracy:  0.0


KeyboardInterrupt: 