In [1]:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


# VGG layer configurations. Every variant has five stages with the widths
# (64, 128, 256, 512, 512); the tuples give how many 3x3 conv layers each stage
# of a variant has. Integers are conv output channels, and every stage closes
# with a 2x2 max-pool marker 'M'.
cfg = {
    name: [item
           for width, depth in zip((64, 128, 256, 512, 512), depths)
           for item in [width] * depth + ['M']]
    for name, depths in (
        ('VGG11', (1, 1, 2, 2, 2)),
        ('VGG13', (2, 2, 2, 2, 2)),
        ('VGG16', (2, 2, 3, 3, 3)),
        ('VGG19', (2, 2, 4, 4, 4)),
    )
}


class VGG(nn.Module):
    """VGG feature extractor (conv/BN/ReLU stacks and max-pools) plus one linear head.

    Args:
        vgg_name: key into the module-level ``cfg`` table
            ('VGG11', 'VGG13', 'VGG16' or 'VGG19').

    The head is ``nn.Linear(512, 10)``, so ``self.features`` must reduce each
    sample to exactly 512 values (e.g. 32x32 inputs after the five 2x2 pools).
    """

    def __init__(self, vgg_name):
        super().__init__()
        layer_spec = cfg[vgg_name]
        self.features = self._make_layers(layer_spec)
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        feature_map = self.features(x)
        # Flatten everything except the batch dimension before the linear head.
        flat = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a config list into an ``nn.Sequential``.

        Each integer ``c`` adds Conv2d(prev, c, 3x3, padding=1) -> BatchNorm2d(c)
        -> ReLU(inplace=True); each 'M' adds a 2x2 / stride-2 max-pool. The first
        conv sees 3 input channels. A 1x1 / stride-1 average pool (a no-op on
        values) is appended last.
        """
        modules = []
        channels = 3
        for entry in cfg:
            if entry != 'M':
                modules.extend([
                    nn.Conv2d(channels, entry, kernel_size=3, padding=1),
                    nn.BatchNorm2d(entry),
                    nn.ReLU(inplace=True),
                ])
                channels = entry
                continue
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)

In [2]:
class Weight_Matrix(nn.Module):
    """Generates 17 tiled weight-magnitude maps for one conv layer of VGG19.

    The fixed input is built from the layer-``self.layer`` conv weight stored in
    ``./checkpoint/Cifar10_VGG19_No_Error.pth``. Its 64 kernels of shape (3, 3, 3)
    are tiled into a (3, 24, 24) map (an 8x8 grid of 3x3 kernels per input
    channel), the absolute value is taken, and the map is repeated 17 times.
    Each copy is then one-hot padded (see ``_onehot_padding``) so the network can
    tell the 17 slots apart.

    ``forward()`` takes no input: it pushes that fixed tensor through four
    Conv(5x5, stride 5) + LayerNorm stages, re-padding between stages, and
    returns a tensor of shape (17, 3, 24, 24).
    """

    def __init__(self):
        super(Weight_Matrix, self).__init__()
        self.layer = 0
        self.padding = nn.ZeroPad2d(2)
        temp_path = "./checkpoint/Cifar10_VGG19_No_Error.pth"
        # Load to CPU; the result is registered as a buffer below, so
        # Module.to(device) moves it to wherever the module goes.
        weight = torch.load(temp_path, map_location="cpu")['net']['module.features.{}.weight'.format(self.layer)]
        # (64, 3, 3, 3) -> (3, 24, 24): kernel o = 8*a + b lands at rows
        # 3a..3a+2 and cols 3b..3b+2 of each input-channel plane.
        tiled = weight.reshape((8, 8, 3, 3, 3)).permute(2, 0, 3, 1, 4)
        tiled = torch.abs(tiled.reshape((1, 3, 24, 24)))
        tiled = tiled.repeat(17, 1, 1, 1)
        # A plain attribute is NOT moved by .to()/.cuda(); a buffer is.
        # persistent=False keeps state_dict() keys identical to before, so
        # existing 'wet' checkpoints still load.
        self.register_buffer("input_matrix", self._onehot_padding(tiled), persistent=False)
        # Modules are created in the same order as before (conv then norm, stage
        # 1..4), so parameter initialisation and state_dict keys are unchanged.
        self.weight_decision1 = self._make_stage()
        self.weight_decision2 = self._make_stage()
        self.weight_decision3 = self._make_stage()
        self.weight_decision4 = self._make_stage()

    @staticmethod
    def _make_stage():
        """One stage: Conv2d(3 -> 3, 5x5, stride 5), which maps 120x120 to 24x24, then LayerNorm over (3, 24, 24)."""
        return nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=5, stride=5, padding=0, dilation=1, groups=1),
            nn.LayerNorm((3, 24, 24), elementwise_affine=True),
        )

    def forward(self):
        """Return the generated maps, shape (17, 3, 24, 24)."""
        out = self.weight_decision1(self.input_matrix)
        out = self.weight_decision2(self._onehot_padding(out))
        out = self.weight_decision3(self._onehot_padding(out))
        out = self.weight_decision4(self._onehot_padding(out))
        return out

    def _onehot_padding(self, x):
        """Expand every map value into a 5x5 cell that also encodes the slot index.

        Args:
            x: tensor with 17 * 3 * 24 * 24 elements, e.g. shape (17, 3, 24, 24).

        Returns:
            Tensor of shape (17, 3, 120, 120). For slot ``n``, each value sits at
            the centre (2, 2) of its 5x5 cell and a 1.0 marks the n-th non-centre
            position in row-major order (flat index n for n < 12, else n + 1).
        """
        out = x.reshape((17, 3, 24 * 24, 1, 1))
        out = self.padding(out)  # -> (17, 3, 576, 5, 5), value at (2, 2)
        centre = 2 * 5 + 2  # flat index of (2, 2), which holds the value
        for n in range(17):
            flat = n if n < centre else n + 1
            out[n, :, :, flat // 5, flat % 5] = 1.0
        # Lay the 5x5 cells out on a 24x24 grid -> (17, 3, 120, 120).
        out = out.reshape((17, 3, 24, 24, 5, 5))
        out = out.permute(0, 1, 2, 4, 3, 5)
        out = out.reshape((17, 3, 24 * 5, 24 * 5))
        return out

In [None]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import time

# Root directory for the CIFAR-10 download. The machine-specific default can be
# overridden with the CIFAR10_DATA_DIR environment variable instead of editing it.
PATH = os.environ.get("CIFAR10_DATA_DIR", "D:/Jupyter_Data/data")
device = "cuda"

best_loss = 10 ** 8  # best (lowest) weight-generator test loss seen so far
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
weight_batch_size = 17  # matches the 17 checkpoint variants; not referenced in this cell
Batch_Size = 50

# Data
print('==> Preparing data..')
# Normalization is commented out in both pipelines, so tensors stay in the
# [0, 1] range produced by ToTensor.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root=PATH, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=Batch_Size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root=PATH, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=Batch_Size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model: VGG19 classifier (evaluated in result_test) wrapped in DataParallel,
# which is why checkpoint keys carry the 'module.' prefix.
print('==> Building model..')
net = torch.nn.DataParallel(VGG('VGG19').to(device))
cudnn.benchmark = True

# Weight generator whose parameters are the ones optimised below.
wet = Weight_Matrix().to(device)

quant = 0.002  # not referenced in this cell
# Load
layer = 0  # index into net.features of the conv whose weight is regenerated (the first Conv2d)
output_list = ["No"] + list(range(16))  # clean model, then error variants 0..15

# For every checkpoint, tile the (64, 3, 3, 3) weight of `layer` into (3, 24, 24)
# (same 8x8 tiling as Weight_Matrix):
#   target_matrix[n] = |weight|       -> regression target for wet()
#   mem_matrix[n]    = sign(weight)   -> re-applied to the generated magnitudes
target_matrix = torch.empty((17, 3, 24, 24), dtype=torch.float32)
mem_matrix = torch.empty((17, 3, 24, 24), dtype=torch.float32).to(device)
assert len(output_list) == target_matrix.shape[0]
for n, i in enumerate(output_list):
    temp_path = "./checkpoint/Cifar10_VGG19_{}_Error.pth".format(i)
    # map_location="cpu": only one layer is needed, so avoid restoring the whole
    # checkpoint onto the GPU it was saved from.
    temp_matrix = torch.load(temp_path, map_location="cpu")['net']['module.features.{}.weight'.format(layer)][:, :, :, :]
    temp_matrix = temp_matrix.reshape((8, 8, 3, 3, 3))
    temp_matrix = temp_matrix.permute(2, 0, 3, 1, 4)
    temp_matrix = temp_matrix.reshape((3, 24, 24))
    # torch.sign instead of x / |x|: identical for non-zero weights, but an
    # exactly-zero weight yields 0 instead of NaN (0 / 0).
    mem_matrix[n, :, :, :] = torch.sign(temp_matrix)
    target_matrix[n, :, :, :] = torch.abs(temp_matrix)

# Cross-entropy scores the VGG in result_test; the generator is fit to
# target_matrix with MSE (L1 was the commented-out alternative).
criterion_net = nn.CrossEntropyLoss()
criterion_wet = nn.MSELoss()
# criterion_wet = nn.L1Loss()
_sgd_common = {'momentum': 0.9, 'weight_decay': 5e-4}
optimizer_net = optim.SGD(net.parameters(), lr=1e-3, **_sgd_common)
optimizer_wet = optim.SGD(wet.parameters(), lr=1e-2, **_sgd_common)
# NOTE(review): neither scheduler's .step() is called in the visible code, so
# both learning rates stay at their initial values.
scheduler_net = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, T_max=200)
scheduler_wet = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_wet, T_max=17)

# Training
def train(epoch):
    """Run one optimisation step of ``wet`` against ``target_matrix``.

    ``wet`` takes no input, so an "epoch" is a single full-batch step:
    ``criterion_wet`` between the generated maps and the targets, backprop,
    then ``optimizer_wet.step()``. The loss is printed on a carriage-returned line.
    """
    wet.train()
    optimizer_wet.zero_grad()
    prediction = wet()
    goal = target_matrix.to(device)
    step_loss = criterion_wet(prediction, goal)
    step_loss.backward()
    optimizer_wet.step()
    print("Epoch: {}, Loss: {}".format(epoch, step_loss), end="\r")

def test(epoch):
    """Evaluate ``wet`` and checkpoint it whenever the loss improves.

    Computes ``criterion_wet`` between ``wet()`` and ``target_matrix`` in eval
    mode without gradients. If it is below the global ``best_loss``, that global
    is updated and ``{'wet', 'loss', 'epoch'}`` is saved to the global
    ``file_path`` (its parent directory is created if missing).
    """
    global best_loss
    wet.eval()
    with torch.no_grad():
        targets = target_matrix.to(device)
        outputs = wet()
        # Python float: best_loss and the checkpoint no longer hold a device tensor.
        loss = criterion_wet(outputs, targets).item()
    if loss < best_loss:
        best_loss = loss
        state = {
            'wet': wet.state_dict(),
            'loss': loss,
            'epoch': epoch,
        }
        # Create the directory file_path actually points into, rather than a
        # hard-coded 'checkpoint' that may differ from it.
        save_dir = os.path.dirname(file_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(state, file_path)
    
def result_test():
    """Evaluate every VGG19 checkpoint with its `layer` weight replaced by wet's output.

    Loads the best ``wet`` from ``file_path`` and multiplies the generated
    magnitudes (17, 3, 24, 24) by the stored signs ``mem_matrix``. The result is
    un-tiled back into 17 kernels of shape (64, 3, 3, 3), the inverse of the
    8x8 tiling used for ``target_matrix``. For each entry of ``output_list`` the
    matching checkpoint is loaded into ``net``, its weight is swapped in, and it
    is evaluated on ``testloader``. For error index ``e``, image rows 2e and 2e+1
    are zeroed before inference; "No" uses the clean images.
    """
    wet.load_state_dict(torch.load(file_path)['wet'])
    wet.eval()
    with torch.no_grad():
        weights = wet()
    weights = weights * mem_matrix
    # (17, 3, 8*3, 8*3) -> (17, 64, 3, 3, 3)
    weights = weights.reshape((17, 3, 8, 3, 8, 3))
    weights = weights.permute(0, 2, 4, 1, 3, 5)
    weights = weights.reshape((17, 64, 3, 3, 3))

    for n, error in enumerate(output_list):
        target_path = './checkpoint/Cifar10_VGG19_{}_Error.pth'.format(error)
        state_dict = torch.load(target_path)['net']
        state_dict['module.features.{}.weight'.format(layer)] = weights[n, :, :, :, :]
        net.load_state_dict(state_dict)
        net.eval()
        print("\nError {}".format(error))

        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                if error != "No":
                    # Zero the two rows belonging to this error index across the
                    # full width. The old per-sample loop reused `i` as its loop
                    # variable, so it erased rows 2*sample_index instead.
                    inputs = inputs.clone()
                    inputs[:, :, 2 * error:2 * error + 2, :] = 0.0
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = criterion_net(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                # Running mean loss and running accuracy over the batches seen so far.
                print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                    test_loss / (batch_idx + 1), correct, total,
                    100. * correct / total), end="\r")

def weight_reset(m):
    """Re-initialise ``m`` in place if it is an ``nn.Conv2d`` or ``nn.Linear``.

    Any other module is left untouched. The signature fits ``nn.Module.apply``.
    """
    resettable = (nn.Conv2d, nn.Linear)
    if not isinstance(m, resettable):
        return
    m.reset_parameters()

# wet() takes no input, so every "epoch" is one full-batch optimisation step
# (train) followed by an evaluation that checkpoints on improvement (test).
Epoch = 10 ** 5

file_path = './checkpoint/Cifar10_VGG19_Weight_ConvNormConv.pth'  # best wet checkpoint
best_loss = 10000.0  # reset before training; test() lowers it on improvement
start_time = time.time()
for epoch in range(Epoch):
    train(epoch)
    test(epoch)
# Reload the best generator and evaluate every error variant with it.
result_test()
end_time = time.time()
elapsed = end_time - start_time
print("Total Time = {}".format(elapsed))

# Shapes of the tensors in wet's state_dict.
for key, val in wet.state_dict().items():
    print(val.shape)
print("Before Parameter Size = {}".format(17*64*3*3*3 + 17*3))
# Count wet's trainable parameters directly. The previous hard-coded
# 6*(3*24*24*3+225+3) did not match Weight_Matrix's four Conv(3->3, 5x5) +
# LayerNorm((3, 24, 24)) stages.
print(" After Parameter Size = {}".format(sum(p.numel() for p in wet.parameters())))

# Side-by-side sample of generated vs. target magnitudes for the first two slots.
wet.eval()
with torch.no_grad():
    targets = target_matrix.to(device)
    # Weight_Matrix.forward takes no input (and `inputs` is not defined at this scope).
    outputs = wet()
print(outputs[0,0,:5,:5])
print(target_matrix[0,0,:5,:5])
print(outputs[1,0,:5,:5])
print(target_matrix[1,0,:5,:5])

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
Epoch: 8666, Loss: 0.0013472529826685786