# 自动编码器

In [1]:
batch_size = 64
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# train属性是区别并对应加载训练集和测试机
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [2]:
from torch import nn

In [3]:
class DCautoencoder(nn.Module):
    def __init__(self):
        super(DCautoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, 3, 1), # 16@10*10
            nn.ReLU(True),
            nn.MaxPool2d(2, 2), # 16@5*5
            nn.Conv2d(16, 8, 3, 2, 1), # 8@3*3
            nn.ReLU(True),
            nn.MaxPool2d(2, 1) # 8@2*2
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, 2), # 16@5*5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, 3, 1), # 8@15*15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, 2, 1), # 1@28*28
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [4]:
def to_image(x):
    '''
    定义一个函数将最后的结果转换回图片
    
    每个值在 -1 ~ 1，需要将其转变成0~1
    '''
    
    x = 0.5 * (x + 1.)
    #x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x

In [5]:
from torch import optim
from torch.autograd import Variable
from torchvision.utils import save_image

In [7]:
model = DCautoencoder()
criterion = nn.MSELoss(size_average=False)
#criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr = 1e-3, momentum=0.9)
model.train()
num = 1000
epoch = 0;
for epoch in range(40):
    for img, label in train_loader:
        img = Variable(img) # ++++++++
        out = model(img)
        loss = criterion(out, img)/img.shape[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print('epoch {}, loss: {:.6f}'.format(epoch, loss.item()))
    save_image(to_image(out.data), './CNN_results/lr_0.001_m/epoch_{}.png'.format(epoch))

epoch 0, loss: 354.082550
epoch 1, loss: 407.214691
epoch 2, loss: 357.930206
epoch 3, loss: 371.725372
epoch 4, loss: 357.654602
epoch 5, loss: 319.054932
epoch 6, loss: 331.425446
epoch 7, loss: 333.699738
epoch 8, loss: 329.034698
epoch 9, loss: 378.773682
epoch 10, loss: 362.951752
epoch 11, loss: 327.009399
epoch 12, loss: 362.044891
epoch 13, loss: 322.153259
epoch 14, loss: 373.617645
epoch 15, loss: 342.457764
epoch 16, loss: 370.279022
epoch 17, loss: 341.373199
epoch 18, loss: 327.406799
epoch 19, loss: 376.508728
epoch 20, loss: 367.138031
epoch 21, loss: 320.555756
epoch 22, loss: 334.427734
epoch 23, loss: 351.566864
epoch 24, loss: 417.785706
epoch 25, loss: 362.966644
epoch 26, loss: 342.478424
epoch 27, loss: 365.066895


KeyboardInterrupt: 

In [None]:
optimizer = optim.SGD(model.parameters(), lr = 0.001)
num = 1000
epoch = 0;
for epoch in range(40):
    for img, label in train_loader:
        img = Variable(img) # ++++++++
        out = model(img)
        loss = criterion(out, img)/img.shape[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print('epoch {}, loss: {:.4f}'.format(epoch, loss.item()))
    save_image(to_image(out.data), './CNN_results/lr_0.03__epoch_{}.png'.format(epoch+1))