In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST

In [None]:
# ---- Training set -----------------------------------------------------------
# ToTensor() converts each image to a float tensor scaled to [0, 1] with a
# leading channel dimension, which is what the model's first Conv2d consumes.
train_data = FashionMNIST(
    root="./data/FashionMNIST",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=128,
    # Reshuffle every epoch so the optimizer does not see the exact same
    # sequence of mini-batches on each pass over the data.
    shuffle=True,
    num_workers=0,
)

# ---- Test set ---------------------------------------------------------------
# No transform is attached: the raw uint8 image tensor is read directly below
# and normalised by hand so the whole test set can be evaluated in one pass.
test_data = FashionMNIST(
    root="./data/FashionMNIST",
    train=False,
    download=False,  # files were already fetched by the training-set call above
)

# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels`
# aliases, which emit a warning on access.
test_x = test_data.data.type(torch.FloatTensor) / 255.0  # scale pixels to [0, 1]
test_x = torch.unsqueeze(test_x, dim=1)                  # insert channel axis at dim 1
test_y = test_data.targets


In [None]:
class atrousConvNet(nn.Module):
    """Small CNN classifier built from dilated ("atrous") 3x3 convolutions.

    The network has two conv -> ReLU -> 2x2 max-pool stages followed by a
    two-layer fully connected head producing 10 logits.

    With a 1x28x28 input the spatial size evolves as
    28 -> 26 (conv1) -> 13 (pool) -> 9 (conv2) -> 4 (pool), which is why the
    first linear layer expects 32 * 4 * 4 features.
    """

    def __init__(self):
        super().__init__()

        # First convolution stage: 1 -> 16 channels, dilation 2, padding 1.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Second convolution stage: 16 -> 32 channels, dilation 2, no padding.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Fully connected head mapping the flattened feature map to 10 logits.
        self.classifier = nn.Sequential(
            nn.Linear(in_features=32 * 4 * 4, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=10),
        )

    def forward(self, x):
        """Run a batch through both conv stages and the classifier head.

        Args:
            x: input batch; the first dimension is treated as the batch axis.

        Returns:
            Tensor with one row of 10 unnormalised class scores per sample.
        """
        features = self.conv2(self.conv1(x))
        batch_size = features.size(0)
        flat = features.view(batch_size, -1)  # collapse everything but the batch axis
        return self.classifier(flat)


model_atrous = atrousConvNet()

In [None]:
# ---- Training configuration -------------------------------------------------
batch_num = len(train_loader)  # number of mini-batches per epoch
# NOTE(review): train_batch_num is not used in this cell; it is kept in case a
# later cell reads it (it looks like the remnant of an 80/20 batch split).
train_batch_num = round(batch_num * 0.8)
num_epochs = 25
optimizer = torch.optim.Adam(model_atrous.parameters(), lr=0.003)
criterion = nn.CrossEntropyLoss()

# ---- Train on the full training loader, evaluate on the test set each epoch --
for epoch in range(num_epochs):

    # Per-epoch counters; accuracy is correct predictions / samples seen.
    train_corrects = 0
    train_samples = 0

    model_atrous.train()  # training mode is set once per epoch, not per step
    for step, (xx, yy) in enumerate(train_loader):
        output = model_atrous(xx)
        loss = criterion(output, yy)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pred = torch.argmax(output, 1)
        train_corrects += torch.sum(pred == yy).item()
        train_samples += yy.size(0)

    # Evaluation: no gradients are needed, so skip building the autograd graph
    # for the whole test set.
    model_atrous.eval()
    with torch.no_grad():
        output = model_atrous(test_x)
    pred = torch.argmax(output, 1)
    # Assigned (not accumulated) so every epoch reports its own accuracy and
    # the cell does not depend on a pre-existing `val_corrects` variable.
    val_corrects = torch.sum(pred == test_y).item()

    # Divide by the number of samples, not the number of batches.
    train_acc_now = train_corrects / max(train_samples, 1)
    val_acc_now = val_corrects / len(test_y)
    print('Train Acc: {:.4f}'.format(train_acc_now))
    print('Test Acc:{:.4f}'.format(val_acc_now))
