In [1]:
import time
import torch
from torch import nn,optim
import torch.nn.functional as F
import torchvision 
import torchvision.transforms as transforms

In [9]:
def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    blk = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.ReLU(), 
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(), 
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU())
    return blk

In [10]:
class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return F.avg_pool2d(x,kernel_size=x.size()[2:])

In [11]:
class FlattenLayer(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return x.view(x.shape[0],-1)

In [12]:
net = nn.Sequential(nin_block(1, 96, kernel_size=11, stride=4, padding=0),
                    nn.MaxPool2d(kernel_size=3, stride=2),
                    nin_block(96, 256, kernel_size=4, stride=1, padding=2),
                    nn.MaxPool2d(kernel_size=3, stride=2),
                    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
                    nn.MaxPool2d(kernel_size=3, stride=2), nn.Dropout(0.5),
                    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
                    GlobalAvgPool2d(), 
                    FlattenLayer())

In [13]:
x=torch.rand(1,1,224,224)
for name,blk in net.named_children():
    x=blk(x)
    print(name,'output shape: ',x.shape)

0 output shape:  torch.Size([1, 96, 54, 54])
1 output shape:  torch.Size([1, 96, 26, 26])
2 output shape:  torch.Size([1, 256, 27, 27])
3 output shape:  torch.Size([1, 256, 13, 13])
4 output shape:  torch.Size([1, 384, 13, 13])
5 output shape:  torch.Size([1, 384, 6, 6])
6 output shape:  torch.Size([1, 384, 6, 6])
7 output shape:  torch.Size([1, 10, 6, 6])
8 output shape:  torch.Size([1, 10, 1, 1])
9 output shape:  torch.Size([1, 10])


In [14]:
print(net)

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))
    (5): ReLU()
  )
  (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (2): Sequential(
    (0): Conv2d(96, 256, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (5): ReLU()
  )
  (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
    (5): ReLU()
  )
  (5): MaxPool2d(kernel_size=3, stri

In [15]:
def load_data_fashion_mnist(batch_size,
                            resize=None,
                            root=r'F:\study\ml\DataSet\FashionMNIST'):
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root,
                                                    train=True,
                                                    download=True,
                                                    transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root,
                                                   train=False,
                                                   download=True,
                                                   transform=transform)

    train_iter = torch.utils.data.DataLoader(mnist_train,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4)
    test_iter = torch.utils.data.DataLoader(mnist_test,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=4)
    return train_iter, test_iter

In [16]:
def evaluate_accuracy(data_iter,net):
    acc_sum,n=0.0,0
    with torch.no_grad():
        for x,y in data_iter:
            net.eval()
            acc_sum +=(net(x).argmax(dim=1)==y).float().sum().item()
            net.train()
            n+=y.shape[0]
        return acc_sum /n

In [17]:
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, num_epochs):
    loss = torch.nn.CrossEntropyLoss()
    batch_count = 0
    for epoch in range(num_epochs):
        train_l_sum, train_acc, n, start = 0.0, 0.0, 0, time.time()
        for x, y in train_iter:
            y_hat = net(x)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.item()
            n += y.shape[0]
            batch_count += 1
            train_acc += (y_hat.argmax(dim=1)==y).sum().item()
#             break
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d,loss %.4f,train acc %.3f,test acc %.3f,time %.1f sec' %
              (epoch + 1, train_l_sum / batch_count, train_acc / n, test_acc,
               time.time() - start))

In [21]:
batch_size=128
train_iter,test_iter=load_data_fashion_mnist(batch_size,resize=224)
lr,num_epochs=0.002,5
optimizer=torch.optim.Adam(net.parameters(),lr=lr)
train_ch5(net,train_iter,test_iter,batch_size,optimizer,num_epochs)

epoch 1,loss 2.3027,train acc 0.100,test acc 0.100,time 4661.8 sec
epoch 2,loss 1.1513,train acc 0.100,test acc 0.100,time 4222.5 sec
epoch 3,loss 0.7675,train acc 0.100,test acc 0.100,time 4216.0 sec
epoch 4,loss 0.5756,train acc 0.100,test acc 0.100,time 4230.2 sec
epoch 5,loss 0.4605,train acc 0.100,test acc 0.100,time 4220.9 sec


In [25]:
(0.001/5)**(1/2),1-(0.001/5)**(1/2)

(0.01414213562373095, 0.9858578643762691)

In [26]:
(0.001/2)**(1/2),1-(0.001/2)**(1/2)

(0.022360679774997897, 0.9776393202250021)

In [27]:
(0.001/20)**(1/2),1-(0.001/20)**(1/2)

(0.007071067811865475, 0.9929289321881345)

In [28]:
(0.00001/20)**(1/2),1-(0.00001/20)**(1/2)

(0.0007071067811865476, 0.9992928932188134)