In [1]:
import os

import torch
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image

In [8]:
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])

train_set = MNIST('./data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)

In [3]:
# 定义网络
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3) # 输出的 code 是 3 维，便于可视化
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28*28),
            nn.Tanh()
        )

    def forward(self, x):
        encode = self.encoder(x)
        decode = self.decoder(encode)
        return encode, decode

In [9]:
net = autoencoder()
x = Variable(torch.randn(1, 28*28)) # batch size 是 1
code, _ = net(x)
print(code.shape)

torch.Size([1, 3])


In [11]:
criterion = nn.MSELoss(size_average=False)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x

In [None]:
len(train_data)

In [6]:
for im, _ in train_data:
    im = im.view(im.shape[0], -1)
    _, out = net(im)
    break
#print(out.size(),'\n',out)
x = to_img(out.data)
print(x.size(),'\n', x)

torch.Size([64, 1, 28, 28]) 
 tensor([[[[0.5332, 0.5108, 0.4923,  ..., 0.5340, 0.4745, 0.4702],
          [0.5226, 0.4647, 0.5165,  ..., 0.5179, 0.4789, 0.5166],
          [0.5172, 0.4436, 0.5318,  ..., 0.4353, 0.4814, 0.5088],
          ...,
          [0.5099, 0.5157, 0.4605,  ..., 0.4733, 0.5117, 0.4362],
          [0.4696, 0.5391, 0.5302,  ..., 0.4428, 0.4680, 0.5378],
          [0.5130, 0.5206, 0.4384,  ..., 0.5371, 0.5345, 0.5153]]],


        [[[0.5334, 0.5107, 0.4921,  ..., 0.5342, 0.4749, 0.4703],
          [0.5228, 0.4646, 0.5166,  ..., 0.5178, 0.4787, 0.5171],
          [0.5174, 0.4434, 0.5318,  ..., 0.4348, 0.4817, 0.5095],
          ...,
          [0.5101, 0.5158, 0.4606,  ..., 0.4730, 0.5116, 0.4364],
          [0.4694, 0.5390, 0.5301,  ..., 0.4427, 0.4681, 0.5378],
          [0.5127, 0.5208, 0.4386,  ..., 0.5367, 0.5343, 0.5156]]],


        [[[0.5335, 0.5105, 0.4918,  ..., 0.5341, 0.4747, 0.4703],
          [0.5229, 0.4649, 0.5167,  ..., 0.5176, 0.4789, 0.5166],
        

In [12]:
# 开始训练自动编码器
#for e in range(100):
e = 0
for im, _ in train_data:
    e += 1
    im = im.view(im.shape[0], -1)
    im = Variable(im)
        # 前向传播
    _, output = net(im)
    loss = criterion(output, im) / im.shape[0] # 平均
        # 反向传播
    optimizer.zero_grad()
    loss.backward() 
        #print(1)
    optimizer.step()
        #print(2)
    if e % 100 == 0: # 每 20 次，将生成的图片保存一下
        print('epoch: {}, Loss: {:.4f}'.format(e, loss.item()))
        pic = to_img(output.data)
        if not os.path.exists('./std'):
            os.mkdir('./std')
        save_image(pic, './std/image_{}.png'.format(e))

epoch: 100, Loss: 207.3728
epoch: 200, Loss: 188.3503
epoch: 300, Loss: 168.4754
epoch: 400, Loss: 163.0545


In [15]:
from torch import Tensor

In [32]:
for i in range(50):
    a = Variable(torch.randn(1,28*28).clamp(-1,1))
    _ ,img = net(a)
    img = to_img(img.data)
    save_image(img, './test{}.png'.format(i+1))
    

In [23]:
a = Variable(torch.randn(1,10).clamp())
_ ,img = net(a)
img = to_img(img.data)
save_image(img, './test_1.png')

In [31]:
a = Variable(torch.randn(1,10))
print(a)
a = a.clamp(-1,1)
print(a)

tensor([[-0.3501,  0.2911,  0.8022, -0.0199,  0.2474,  1.7094,  0.0311,  1.3327,
         -0.9180,  1.5430]])
tensor([[-0.3501,  0.2911,  0.8022, -0.0199,  0.2474,  1.0000,  0.0311,  1.0000,
         -0.9180,  1.0000]])
