In [1]:
import time
import torch
from torch import nn,optim
import torchvision
from torchvision import transforms

In [2]:
def vgg_block(num_convs, in_channels, out_channels):
    blk = []
    for i in range(num_convs):
        if i == 0:
            blk.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        else:
            blk.append(
                nn.Conv2d(out_channels, out_channels, kernel_size=3,
                          padding=1))
        blk.append(nn.ReLU())
    blk.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*blk)

In [3]:
conv_arch=((1,1,64),(1,64,128),(2,128,256),(2,256,512),(2,512,512))
fc_features=512*7*7 #c*w*h
fc_hidden_units=4096

In [4]:
class FlattenLayer(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self,x):
        return x.view(x.shape[0],-1)

In [5]:
def vgg(conv_arch,fc_features,fc_hidden_units=4096):
    net=nn.Sequential()
    for i,(num_convs,in_channels,out_channels) in enumerate(conv_arch):
        net.add_module('vgg_block_'+str(i),vgg_block(num_convs,
                                                    in_channels,
                                                    out_channels))
    net.add_module('fc',nn.Sequential(FlattenLayer(),
                                     nn.Linear(fc_features,fc_hidden_units),
                                     nn.ReLU(),
                                     nn.Dropout(0.5),
                                     nn.Linear(fc_hidden_units,fc_hidden_units),
                                     nn.ReLU(),
                                     nn.Dropout(0.5),
                                     nn.Linear(fc_hidden_units,10)))
    return net
        

In [6]:
net=vgg(conv_arch,fc_features,fc_hidden_units)
x=torch.rand(1,1,224,224)
for name,blk in net.named_children():
    x=blk(x)
    print(name,'output shape:',x.shape)

vgg_block_0 output shape: torch.Size([1, 64, 112, 112])
vgg_block_1 output shape: torch.Size([1, 128, 56, 56])
vgg_block_2 output shape: torch.Size([1, 256, 28, 28])
vgg_block_3 output shape: torch.Size([1, 512, 14, 14])
vgg_block_4 output shape: torch.Size([1, 512, 7, 7])
fc output shape: torch.Size([1, 10])


In [7]:
ratio=8
small_conv_arch=[(1,1,64//ratio),(1,64//ratio,128//ratio),(2,128//ratio,256//ratio),
                (2,256//ratio,512//ratio),(2,512//ratio,512//ratio)]
net=vgg(small_conv_arch,fc_features//ratio,fc_hidden_units//ratio)
print(net)

Sequential(
  (vgg_block_0): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_1): Sequential(
    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [8]:
def load_data_fashion_mnist(batch_size,
                            resize=None,
                            root=r'F:\study\ml\DataSet\FashionMNIST'):
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root,
                                                    train=True,
                                                    download=True,
                                                    transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root,
                                                   train=False,
                                                   download=True,
                                                   transform=transform)

    train_iter = torch.utils.data.DataLoader(mnist_train,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4)
    test_iter = torch.utils.data.DataLoader(mnist_test,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=4)
    return train_iter, test_iter

In [9]:
def evaluate_accuracy(data_iter,net):
    acc_sum,n=0.0,0
    with torch.no_grad():
        for x,y in data_iter:
            net.eval()
            acc_sum +=(net(x).argmax(dim=1)==y).float().sum().item()
            net.train()
            n+=y.shape[0]
        return acc_sum /n

In [12]:
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, num_epochs):
    loss = torch.nn.CrossEntropyLoss()
    batch_count = 0
    for epoch in range(num_epochs):
        train_l_sum, train_acc, n, start = 0.0, 0.0, 0, time.time()
        for x, y in train_iter:
            y_hat = net(x)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.item()
            n += y.shape[0]
            batch_count += 1
            train_acc += (y_hat.argmax(dim=1)==y).sum().item()
#             break
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d,loss %.4f,train acc %.3f,test acc %.3f,time %.1f sec' %
              (epoch + 1, train_l_sum / batch_count, train_acc / n, test_acc,
               time.time() - start))

In [13]:
batch_size=256
train_iter,test_iter=load_data_fashion_mnist(batch_size,resize=224)
lr,num_epochs=0.001,5
optimizer=torch.optim.Adam(net.parameters(),lr=lr)
train_ch5(net,train_iter,test_iter,batch_size,optimizer,num_epochs)

epoch 1,loss 0.7609,train acc 0.707,test acc 0.859,time 2004.7 sec
epoch 2,loss 0.1799,train acc 0.868,test acc 0.886,time 2007.3 sec


KeyboardInterrupt: 