In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [3]:
# Hyperparameters shared by the data loaders, the optimizer and the training loop.
num_epochs = 20        # number of passes over the training set
batch_size = 32        # samples per mini-batch for both loaders
learning_rate = 0.001  # learning rate handed to the optimizer

In [4]:
# MNIST train/test splits stored under ./datas; every image is converted to a tensor.
# download=True fetches the files only when they are not already on disk, so the
# notebook also runs from a fresh checkout instead of failing on a missing dataset.
train_dataset = datasets.MNIST('./datas', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST('./datas', train=False, transform=transforms.ToTensor(), download=True)

# Shuffle only the training data; the test set is iterated in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [9]:
class Cnn(nn.Module):
    """Small convolutional classifier that also returns deconvolved feature maps.

    The network applies three convolutional stages followed by a stack of three
    linear layers. The first linear layer expects 400 input features, which is
    what the third stage produces (16 x 5 x 5) for a 28 x 28 input; the shape
    comments below assume that input size.

    Args:
        in_dim: number of channels of the input images.
        out_dim: number of output scores (classes).
    """

    def __init__(self, in_dim, out_dim):
        super(Cnn, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_dim, 6, 3, stride=1, padding=1),  # b 6 28 28
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # b 6 14 14
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 3, stride=1, padding=1),  # b 16 14 14
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # b 16 7 7
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 16, 3, stride=1),  # b 16 5 5
            nn.ReLU(True),
        )

        self.fc = nn.Sequential(
            nn.Linear(400, 200),
            nn.Linear(200, 100),
            nn.Linear(100, out_dim),
        )

    @staticmethod
    def _random_deconv(feature):
        """Collapse `feature` to one channel with a random 3x3 transposed convolution.

        The kernel is drawn anew on every call and is not a learnable parameter.
        It is created directly on the device (and with the dtype) of `feature`,
        so the model works on CPU as well as on any GPU rather than requiring
        the default CUDA device.
        """
        weight = torch.randn(
            feature.size(1), 1, 3, 3,
            device=feature.device, dtype=feature.dtype,
        )
        return F.conv_transpose2d(feature, weight)

    def forward(self, x):
        """Run the network on a batch of images.

        Returns:
            A 4-tuple ``(scores, deconv1, deconv2, deconv3)``: the output of the
            linear stack, followed by the single-channel random deconvolution of
            the output of each of the three convolutional stages.
        """
        out1 = self.conv1(x)
        out2 = self.conv2(out1)
        out3 = self.conv3(out2)

        deconv1 = self._random_deconv(out1)
        deconv2 = self._random_deconv(out2)
        deconv3 = self._random_deconv(out3)

        # Flatten everything but the batch dimension for the linear layers.
        flat = out3.view(out3.size(0), -1)
        return self.fc(flat), deconv1, deconv2, deconv3
    

In [10]:
# Prefer the GPU, but fall back to the CPU so this cell does not crash on
# machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Cnn(1, 10).to(device)

# Cross-entropy on the raw scores, optimised with Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
# TensorBoard event writer used by the training loop; './log/cnn3' is the log directory.
writer = SummaryWriter('./log/cnn3')

In [12]:
# Train for `num_epochs` epochs. Every step logs the loss, the batch accuracy,
# the input images and the three deconvolved feature maps to TensorBoard; every
# 100 steps the parameter/gradient histograms are logged as well. After each
# epoch the model is evaluated on the test set.

# Run on the device the model already lives on instead of assuming CUDA.
device = next(model.parameters()).device

for epoch in range(num_epochs):

    running_loss = 0.0  # sum of per-sample training losses in this epoch
    running_acc = 0.0   # number of correctly classified training samples
    seen = 0            # number of training samples processed in this epoch
    for i, (img, label) in enumerate(train_loader, 1):
        img = img.to(device)
        label = label.to(device)
        batch_n = img.size(0)

        out, deconv1, deconv2, deconv3 = model(img)
        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Convert to plain Python numbers once and reuse them below.
        loss_value = loss.item()
        _, pred = torch.max(out, 1)
        correct = (pred == label).sum().item()

        running_loss += loss_value * batch_n
        running_acc += correct
        seen += batch_n

        # Global step counter across epochs (starts at 1).
        step = epoch * len(train_loader) + i

        writer.add_scalar('loss', loss_value, step)
        writer.add_scalar('accuracy', correct / float(batch_n), step)
        writer.add_image('images', torchvision.utils.make_grid(img).cpu(), step)

        # The deconvolved maps are only visualised, so detach them from the graph.
        for name, fmap in (('deconv1', deconv1), ('deconv2', deconv2), ('deconv3', deconv3)):
            grid = torchvision.utils.make_grid(fmap.detach(), normalize=True, scale_each=True)
            writer.add_image(name, grid.cpu(), step)

        if i % 100 == 0:
            for tag, value in model.named_parameters():
                if tag.startswith('deconv'):
                    continue
                tag = tag.replace('.', '/')
                writer.add_histogram(tag, value.detach().cpu().numpy(), step)
                # A parameter that did not take part in the backward pass has no gradient.
                if value.grad is not None:
                    writer.add_histogram(tag + '/grad', value.grad.detach().cpu().numpy(), step)

        if i % 500 == 0:
            # Average over the samples actually seen, which stays correct even
            # when batches do not all have the same size.
            print('Epoch: [{}/{}], Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, num_epochs, running_loss / seen, running_acc / seen))

    print('Finish {} Epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1,
        running_loss / len(train_dataset),
        running_acc / len(train_dataset)))

    model.eval()
    eval_loss = 0.0
    eval_acc = 0.0
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)

            out, deconv1, deconv2, deconv3 = model(img)
            loss = criterion(out, label)

            eval_loss += loss.item() * img.size(0)

            _, pred = torch.max(out, 1)
            eval_acc += (pred == label).sum().item()

    print('Eval Loss: {:.6f}, Eval Acc: {:.6f}'.format(
        eval_loss / len(test_dataset), eval_acc / len(test_dataset)))

    model.train()

Epoch: [1/20], Loss: 0.423237, Acc: 0.862938
Epoch: [1/20], Loss: 0.279486, Acc: 0.909937
Epoch: [1/20], Loss: 0.219506, Acc: 0.929604
Finish 1 Epoch, Loss: 0.193162, Acc: 0.938250
Eval Loss: 0.066578, Eval Acc: 0.978200
Epoch: [2/20], Loss: 0.072615, Acc: 0.977313
Epoch: [2/20], Loss: 0.070869, Acc: 0.978531
Epoch: [2/20], Loss: 0.070052, Acc: 0.978458
Finish 2 Epoch, Loss: 0.069250, Acc: 0.978750
Eval Loss: 0.044094, Eval Acc: 0.985400
Epoch: [3/20], Loss: 0.056878, Acc: 0.981688
Epoch: [3/20], Loss: 0.055928, Acc: 0.982906
Epoch: [3/20], Loss: 0.055331, Acc: 0.983062
Finish 3 Epoch, Loss: 0.054683, Acc: 0.983300
Eval Loss: 0.038678, Eval Acc: 0.987700
Epoch: [4/20], Loss: 0.044276, Acc: 0.986375
Epoch: [4/20], Loss: 0.045476, Acc: 0.985906
Epoch: [4/20], Loss: 0.045847, Acc: 0.985812
Finish 4 Epoch, Loss: 0.046547, Acc: 0.985667
Eval Loss: 0.036329, Eval Acc: 0.987400
Epoch: [5/20], Loss: 0.035167, Acc: 0.989125
Epoch: [5/20], Loss: 0.040045, Acc: 0.987094
Epoch: [5/20], Loss: 0.040

In [13]:
# Close the TensorBoard writer now that the training cell has finished.
writer.close()

In [16]:
# Persist only the learned weights (the state_dict). The target directory is
# created first so that a missing ./ser folder cannot make the save fail after
# a completed training run.
save_path = './ser/cnn3.pth'
save_dir = os.path.dirname(save_path)
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
torch.save(model.state_dict(), save_path)