# In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10
from torchvision.datasets import MNIST
from datetime import datetime

def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 convolution with one pixel of zero padding.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        stride: convolution stride (defaults to 1).

    Returns:
        An ``nn.Conv2d`` layer with ``kernel_size=3``, ``padding=1`` and no bias.
    """
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class adamsnet(nn.Module):
    """ResNet-18-sized CNN whose skip connections also mix in the previous stage's branch output.

    Layout (stage convolutions come from ``conv3x3`` and are each followed by
    batch norm and ReLU):

    * ``conv0`` lifts the input to 64 channels.
    * Eight two-convolution stages follow with widths 64, 64, 128, 128, 256,
      256, 512, 512; ``conv5``, ``conv9`` and ``conv13`` use stride 2.
    * Stage 1 is a plain residual block.  Stages 2-8 compute
      ``relu(h * (1 - k_i) * F_i + shortcut + h * k_i * F_prev)`` where ``F_i``
      is the stage's own branch output, ``F_prev`` is the previous stage's
      branch output, ``k_i`` is a learnable scalar initialised from U[0, 1)
      and ``h`` is fixed to 1.0.
    * Before a stage that halves the resolution, both ``shortcut`` and
      ``F_prev`` are reshaped with 2x2 max pooling followed by a 1x1
      convolution (the same 1x1 convolution module serves both tensors).
    * ``AvgPool2d(4)`` plus a linear layer produce the class scores; the
      flatten into ``Linear(512, ...)`` needs a 4x4 final feature map, which
      is what a 32x32 input yields after the three stride-2 stages.

    NOTE(review): going by the class name, the weighting presumably follows an
    Adams-type linear multistep scheme -- confirm against the model's reference.
    """

    def __init__(self, in_channels, num_classes):
        """Create all layers.

        Args:
            in_channels: number of channels of the input images.
            num_classes: length of the output score vector.
        """
        super(adamsnet, self).__init__()

        # Learnable mixing weights, one per stage.  k1 is never read by
        # forward() (stage 1 has no previous branch to mix in) but it stays
        # registered so the parameter set and state_dict keys are unchanged.
        self.k1 = torch.nn.Parameter(torch.Tensor(1).uniform_(0.0, 1.0))
        self.k2 = torch.nn.Parameter(torch.Tensor(1).uniform_(0.0, 1.0))
        self.k3 = torch.nn.Parameter(torch.Tensor(1).uniform_(0.0, 1.0))
        self.k4 = torch.nn.Parameter(torch.Tensor(1).uniform_(0.0, 1.0))
        self.k5 = torch.nn.Parameter(torch.Tensor(1).uniform_(0.0, 1.0))
        self.k6 = torch.nn.Parameter(torch.Tensor(1).uniform_(0.0, 1.0))
        self.k7 = torch.nn.Parameter(torch.Tensor(1).uniform_(0.0, 1.0))
        self.k8 = torch.nn.Parameter(torch.Tensor(1).uniform_(0.0, 1.0))

        self.MaxPool = nn.MaxPool2d(2, 2)

        # 1x1 projections applied (after max pooling) to tensors that must
        # match the next, wider and lower-resolution, stage.
        self.conv1x1_residual2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1)
        self.conv1x1_residual4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1)
        self.conv1x1_residual6 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)

        # Stem.
        self.conv0 = nn.Conv2d(in_channels, out_channels=64, kernel_size=3, stride=1, padding=1)

        # Stages 1-2: 64 channels.
        self.conv1 = conv3x3(in_channels=64, out_channels=64, stride=1)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv2 = conv3x3(in_channels=64, out_channels=64, stride=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = conv3x3(in_channels=64, out_channels=64, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv4 = conv3x3(in_channels=64, out_channels=64, stride=1)
        self.bn4 = nn.BatchNorm2d(64)

        # Stages 3-4: 128 channels, conv5 halves the resolution.
        self.conv5 = conv3x3(in_channels=64, out_channels=128, stride=2)
        self.bn5 = nn.BatchNorm2d(128)

        self.conv6 = conv3x3(in_channels=128, out_channels=128, stride=1)
        self.bn6 = nn.BatchNorm2d(128)

        self.conv7 = conv3x3(in_channels=128, out_channels=128, stride=1)
        self.bn7 = nn.BatchNorm2d(128)

        self.conv8 = conv3x3(in_channels=128, out_channels=128, stride=1)
        self.bn8 = nn.BatchNorm2d(128)

        # Stages 5-6: 256 channels, conv9 halves the resolution.
        self.conv9 = conv3x3(in_channels=128, out_channels=256, stride=2)
        self.bn9 = nn.BatchNorm2d(256)

        self.conv10 = conv3x3(in_channels=256, out_channels=256, stride=1)
        self.bn10 = nn.BatchNorm2d(256)

        self.conv11 = conv3x3(in_channels=256, out_channels=256, stride=1)
        self.bn11 = nn.BatchNorm2d(256)

        self.conv12 = conv3x3(in_channels=256, out_channels=256, stride=1)
        self.bn12 = nn.BatchNorm2d(256)

        # Stages 7-8: 512 channels, conv13 halves the resolution.
        self.conv13 = conv3x3(in_channels=256, out_channels=512, stride=2)
        self.bn13 = nn.BatchNorm2d(512)

        self.conv14 = conv3x3(in_channels=512, out_channels=512, stride=1)
        self.bn14 = nn.BatchNorm2d(512)

        self.conv15 = conv3x3(in_channels=512, out_channels=512, stride=1)
        self.bn15 = nn.BatchNorm2d(512)

        self.conv16 = conv3x3(in_channels=512, out_channels=512, stride=1)
        self.bn16 = nn.BatchNorm2d(512)

        self.AvgPool = nn.AvgPool2d(4)

        self.classifier = nn.Linear(512, num_classes)

    @staticmethod
    def _conv_bn_relu(x, conv, bn):
        """Apply ``conv``, then ``bn``, then ReLU to ``x``."""
        return F.relu(bn(conv(x)))

    def _downsample(self, x, projection):
        """Halve the resolution with max pooling, then apply the 1x1 ``projection``."""
        return projection(self.MaxPool(x))

    @staticmethod
    def _mix(k, branch, shortcut, prev_branch, h=1.0):
        """Return ``relu(h * (1 - k) * branch + shortcut + h * k * prev_branch)``."""
        return F.relu(h * (1.0 - k) * branch + shortcut + h * k * prev_branch)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: image batch with ``in_channels`` channels; the pooling and
                classifier sizes require a 32x32 spatial size.

        Returns:
            Tensor with one row of ``num_classes`` scores per input image.
        """
        h = 1.0

        out = self.conv0(x)
        out_conv0 = out

        # Stage 1 (64 channels): plain residual block.
        out = self._conv_bn_relu(out, self.conv1, self.bn1)
        out = self._conv_bn_relu(out, self.conv2, self.bn2)
        out_conv2 = out
        # Out-of-place add on purpose: `out += out_conv0` would also overwrite
        # out_conv2 (it is the same tensor), leaking the shortcut into the
        # stage-2 mixing term, and would modify a ReLU result in place, which
        # can trip autograd's saved-tensor version check during backward.
        out = F.relu(out + out_conv0)
        out_residual1 = out

        # Stage 2 (64 channels); its results are projected for stage 3.
        out = self._conv_bn_relu(out, self.conv3, self.bn3)
        out = self._conv_bn_relu(out, self.conv4, self.bn4)
        out_conv4 = self._downsample(out, self.conv1x1_residual2)
        out = self._mix(self.k2, out, out_residual1, out_conv2, h)
        out_residual2 = self._downsample(out, self.conv1x1_residual2)

        # Stage 3 (128 channels, stride 2).
        out = self._conv_bn_relu(out, self.conv5, self.bn5)
        out = self._conv_bn_relu(out, self.conv6, self.bn6)
        out_conv6 = out
        out = self._mix(self.k3, out, out_residual2, out_conv4, h)
        out_residual3 = out

        # Stage 4 (128 channels); its results are projected for stage 5.
        out = self._conv_bn_relu(out, self.conv7, self.bn7)
        out = self._conv_bn_relu(out, self.conv8, self.bn8)
        out_conv8 = self._downsample(out, self.conv1x1_residual4)
        out = self._mix(self.k4, out, out_residual3, out_conv6, h)
        out_residual4 = self._downsample(out, self.conv1x1_residual4)

        # Stage 5 (256 channels, stride 2).
        out = self._conv_bn_relu(out, self.conv9, self.bn9)
        out = self._conv_bn_relu(out, self.conv10, self.bn10)
        out_conv10 = out
        out = self._mix(self.k5, out, out_residual4, out_conv8, h)
        out_residual5 = out

        # Stage 6 (256 channels); its results are projected for stage 7.
        out = self._conv_bn_relu(out, self.conv11, self.bn11)
        out = self._conv_bn_relu(out, self.conv12, self.bn12)
        out_conv12 = self._downsample(out, self.conv1x1_residual6)
        out = self._mix(self.k6, out, out_residual5, out_conv10, h)
        out_residual6 = self._downsample(out, self.conv1x1_residual6)

        # Stage 7 (512 channels, stride 2).
        out = self._conv_bn_relu(out, self.conv13, self.bn13)
        out = self._conv_bn_relu(out, self.conv14, self.bn14)
        out_conv14 = out
        out = self._mix(self.k7, out, out_residual6, out_conv12, h)
        out_residual7 = out

        # Stage 8 (512 channels).
        out = self._conv_bn_relu(out, self.conv15, self.bn15)
        out = self._conv_bn_relu(out, self.conv16, self.bn16)
        out = self._mix(self.k8, out, out_residual7, out_conv14, h)

        out = self.AvgPool(out)

        # Flatten to (batch, features) for the linear classifier.
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

#######################################
# Smoke test: push one all-zero 3x32x32 image through a 10-class network
# and print the shape of the result.
test_net = adamsnet(3, 10)
test_x = torch.zeros(1, 3, 32, 32)
test_y = test_net(test_x)
print('output: {}'.format(test_y.shape))
#########################################




def get_acc(output, label):
    """Return the fraction of rows of `output` whose highest-scoring column equals `label`.

    Args:
        output: 2-D score tensor, one row per sample.
        label: tensor of target class indices, one per row of `output`.

    Returns:
        Number of matching predictions divided by the number of rows, as a float.
    """
    # max over dim 1 returns (values, indices); the indices are the predictions.
    pred_label = output.max(1)[1]
    hits = pred_label.eq(label).sum().item()
    return hits / output.size(0)



# History buffers filled by train(): `losses` / `acces` get one entry per
# training step, `eval_losses` / `eval_acces` one entry per epoch.
losses, acces = [], []
eval_losses, eval_acces = [], []

def train(net, train_data, test_data, num_epochs, optimizer, criterion):
    """Train `net` for `num_epochs` epochs and evaluate it on `test_data` after each one.

    Args:
        net: model to optimise; moved to the GPU when CUDA is available.
        train_data: sized iterable of (input, label) tensor batches used for
            the optimisation steps.
        test_data: sized iterable of (input, label) tensor batches used for
            the per-epoch evaluation.
        num_epochs: number of passes over `train_data`.
        optimizer: optimizer that updates the parameters of `net`.
        criterion: callable mapping (output, label) to a scalar loss tensor.

    Side effects:
        Appends to the module-level lists `losses` and `acces` once per
        training step and to `eval_losses` and `eval_acces` once per epoch,
        and prints one summary line per epoch.
    """
    # Decide the device once instead of querying CUDA for every batch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)

    prev_time = datetime.now()

    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net.train()

        for im, label in train_data:
            im = im.to(device)
            label = label.to(device)

            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            train_loss += step_loss
            train_acc += get_acc(output, label)

            losses.append(step_loss)
            # NOTE(review): this stores the running accuracy sum divided by
            # the total number of batches, not the accuracy of this batch --
            # kept as in the original; confirm that this is intended.
            acces.append(train_acc / len(train_data))

        # The reported time is taken before the evaluation pass starts.
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)

        # Validation set.  (This heading used to be a bare identifier without
        # a leading '#', which raised NameError at the end of the first epoch.)
        eval_loss = 0
        eval_acc = 0
        net.eval()

        # No gradients are needed for evaluation, so skip building the graph.
        with torch.no_grad():
            for im, label in test_data:
                im = im.to(device)
                label = label.to(device)

                output = net(im)
                loss = criterion(output, label)

                eval_loss += loss.item()
                eval_acc += get_acc(output, label)

        eval_losses.append(eval_loss / len(test_data))
        eval_acces.append(eval_acc / len(test_data))

        # Report training and evaluation metrics together.  A second,
        # train-only assignment used to overwrite this string, hiding the
        # evaluation results and dropping the separator before the time.
        epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Eval Loss: %f, Eval Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), eval_loss / len(test_data),
                   eval_acc / len(test_data)))

        prev_time = cur_time
        print(epoch_str + time_str)



def data_tf(x):
    """Turn an image into a noisy 1x32x32 float tensor.

    Steps: resize to 32x32, divide by 255, map through ``(v - 0.5) / 0.5``,
    add a leading axis, then add zero-mean Gaussian noise with standard
    deviation 0.3.

    Args:
        x: image object exposing ``resize(size, resample)`` whose result
            ``np.array`` can convert -- presumably the PIL image that the
            torchvision dataset hands to its transform; confirm.

    Returns:
        ``torch.Tensor`` of dtype float32 and shape (1, 32, 32).
    """
    # NOTE(review): resample code 2 presumably selects bilinear filtering in
    # Pillow -- confirm against the Pillow documentation.
    resized = x.resize((32, 32), 2)
    pixels = np.array(resized, dtype='float32') / 255
    pixels = (pixels - 0.5) / 0.5
    tensor = torch.from_numpy(pixels).unsqueeze(0)
    # Add noise
    noise = torch.randn(tensor.size()) * 0.3
    return tensor + noise


# MNIST train/test splits read from ./mnist (no download); both use data_tf.
train_set = MNIST(root='./mnist', train=True, transform=data_tf, download=False)
test_set = MNIST(root='./mnist', train=False, transform=data_tf, download=False)

# Shuffled batches of 32 for training, fixed-order batches of 128 for evaluation.
train_data = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

# Single-channel input, 10 classes; Adam optimizer with cross-entropy loss.
net = adamsnet(in_channels=1, num_classes=10)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01, betas=(0.9, 0.99))
criterion = nn.CrossEntropyLoss()



# Preview: fetch the first training sample, print its tensor shape and label,
# then display it with the leading axis removed.
im, la = train_set[0]
print(im.shape, 'label:', la)
plt.imshow(im.squeeze(dim=0))
plt.show()

# Training: run 20 epochs with the optimizer and loss configured above.
train(net, train_data, test_data, num_epochs=20, optimizer=optimizer, criterion=criterion)


# Loss curve: per-step training loss collected in `losses`.
plt.plot(losses, color='red', label='train_loss')
plt.ylim(-0.2, 3)
plt.ylabel('Training Loss')
plt.xlabel('Iter')
plt.legend(loc='best')
plt.show()