In [1]:
batch_size = 64
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# train属性是区别并对应加载训练集和测试机
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [2]:
from torch import nn

In [3]:
class DCautoencoder(nn.Module):
    def __init__(self):
        super(DCautoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, 3, 1), # 16@10*10
            nn.ReLU(True),
            nn.MaxPool2d(2, 2), # 16@5*5
            nn.Conv2d(16, 8, 3, 2, 1), # 8@3*3
            nn.ReLU(True),
            nn.MaxPool2d(2, 1) # 8@2*2
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, 2), # 16@5*5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, 3, 1), # 8@15*15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, 2, 1), # 1@28*28
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [4]:
from torch import optim
from torch.autograd import Variable
from torchvision.utils import save_image

In [5]:
model = DCautoencoder()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = nn.MSELoss(size_average=False)



In [8]:
def to_image(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x

In [9]:
for epoch in range(40):
    for img, label in train_loader:
        img = Variable(img) # ++++++++
        out = model(img)
        loss = criterion(out, img)/img.shape[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print('epoch {}, loss: {:.6f}'.format(epoch+1, loss.item()))
    save_image(to_image(out.data), './CNN_results/adam/epoch_{}.png'.format(epoch))

epoch 1, loss: 98.023285
epoch 2, loss: 100.635101
epoch 3, loss: 78.012428
epoch 4, loss: 85.822777
epoch 5, loss: 96.050484
epoch 6, loss: 91.333908
epoch 7, loss: 75.051041
epoch 8, loss: 77.449928


KeyboardInterrupt: 