<a href="https://colab.research.google.com/github/bochendong/giao_bochen/blob/main/%E2%80%9Cmnist_ipynb%E2%80%9D%E7%9A%84%E5%89%AF%E6%9C%AC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms
import numpy as np
import random
from torchvision.utils import save_image
import os
from torch.optim.lr_scheduler import StepLR
import glob

In [None]:
# Create the top-level output folders. exist_ok=True makes the cell safe to
# re-run and removes the separate "does it exist?" check.
os.makedirs("./output", exist_ok=True)
os.makedirs("./model_weight", exist_ok=True)

# Mini-batch size for the data loaders and number of training epochs.
BATCH_SIZE = 32
EPOCHS = 15

# One sub-folder per epoch for the saved image grids. PNG files left over
# from an earlier run are deleted so they are not mixed with new samples.
for epoch in range(EPOCHS):
    epoch_dir = "./output/%03d" % epoch
    os.makedirs(epoch_dir, exist_ok=True)
    for f in glob.glob("%s/*.png" % epoch_dir):
        os.remove(f)

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Preprocessing shared by both splits: convert to a tensor, normalise with the
# given mean/std, then resize so the shorter side is 32 pixels. Defining it
# once keeps the train and test pipelines from drifting apart.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Resize(32),
])

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=True)

# The aggregated test metrics do not depend on sample order, so the test split
# is left unshuffled; download=True lets this loader be built on its own.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, download=True, transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class Block(nn.Module):
    """Two 3x3 convolutions, each followed by a ReLU.

    Args:
        in_ch: number of channels of the input feature map.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))))``."""
        hidden = self.relu(self.conv1(x))
        return self.relu(self.conv2(hidden))

class Encoder(nn.Module):
    """Stack of ``Block`` stages, each followed by 2x2 max-pooling.

    Args:
        chs: channel sizes; consecutive entries give the (in, out) channels
            of each stage.
    """

    def __init__(self, chs=(1, 64, 128)):
        super().__init__()
        stages = [Block(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        self.enc_blocks = nn.ModuleList(stages)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the list of per-stage outputs taken before pooling, in order."""
        features = []
        for stage in self.enc_blocks:
            x = stage(x)
            features.append(x)
            # The pooled tensor feeds the next stage; after the final stage it
            # is simply not used.
            x = self.pool(x)
        return features

class Decoder(nn.Module):
    """Upsampling path: transposed convolution, skip concatenation, ``Block``.

    Args:
        chs: channel sizes; consecutive entries give the (in, out) channels
            of each stage's transposed convolution and of its ``Block``.
    """

    def __init__(self, chs=(128, 64)):
        super().__init__()
        self.chs = chs
        channel_pairs = list(zip(chs[:-1], chs[1:]))
        # Transposed convolutions use kernel size 2 and stride 2.
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in channel_pairs]
        )
        self.dec_blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in channel_pairs]
        )

    def forward(self, x, encoder_features):
        """Run every stage, merging in ``encoder_features[i]`` at stage ``i``."""
        for stage in range(len(self.chs) - 1):
            x = self.upconvs[stage](x)
            skip = self.crop(encoder_features[stage], x)
            # Concatenate along the channel dimension before the conv block.
            x = self.dec_blocks[stage](torch.cat([x, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the height and width of the 4-D ``x``."""
        _batch, _channels, height, width = x.shape
        cropper = torchvision.transforms.CenterCrop([height, width])
        return cropper(enc_ftrs)

class UNet(nn.Module):
    """Encoder/decoder network with a 1x1 convolution head.

    Args:
        enc_chs: channel sizes passed to ``Encoder``.
        dec_chs: channel sizes passed to ``Decoder``.
        num_class: number of output channels of the head.
        retain_dim: when true, the output is interpolated to ``out_sz``.
        out_sz: target spatial size used when ``retain_dim`` is true.
    """

    def __init__(self, enc_chs=(1, 64, 128), dec_chs=(128, 64),
                 num_class=1, retain_dim=True, out_sz=(32, 32)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)

        self.retain_dim = retain_dim
        self.out_sz = out_sz

    def forward(self, x):
        """Encode, decode with skip connections, apply the head, optionally resize."""
        # Reverse once so the last encoder output comes first; it seeds the
        # decoder and the remaining outputs act as skip connections.
        deepest_first = self.encoder(x)[::-1]
        out = self.head(self.decoder(deepest_first[0], deepest_first[1:]))
        return F.interpolate(out, self.out_sz) if self.retain_dim else out

In [None]:
class DNN(nn.Module):
    """Fully connected classifier: 32*32 -> 512 -> 256 -> 10 with ReLU between layers."""

    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(32 * 32, 512)
        self.fc_2 = nn.Linear(512, 256)
        self.fc_3 = nn.Linear(256, 10)
        self.activate = nn.ReLU(inplace=True)

    def forward(self, x):
        """Flatten ``x`` to rows of 32*32 values and return the 10 output scores per row."""
        out = x.view(-1, 32 * 32)
        # Both hidden layers are followed by the activation; the output layer is not.
        for hidden_layer in (self.fc_1, self.fc_2):
            out = self.activate(hidden_layer(out))
        return self.fc_3(out)

In [None]:
def training(dnn, unet, train_loader, criterion, recon_criterion, optimizer_dnn, optimizer_unet, epoch, use_cuda):
    """Train ``unet`` and ``dnn`` jointly for one pass over ``train_loader``.

    Every batch goes through ``unet``; its output is classified by ``dnn`` and
    both models are updated from the classification loss only. Every 100th
    batch, the input images and the ``unet`` output are saved as PNG grids
    under ``output/<epoch>/``. A summary line is printed at the end.

    Args:
        dnn: classifier applied to the ``unet`` output.
        unet: model applied to the input images.
        train_loader: iterable yielding ``(img, true_label)`` batches.
        criterion: classification loss, called as ``criterion(scores, labels)``.
        recon_criterion: accepted for interface compatibility; not used.
        optimizer_dnn: optimizer for the ``dnn`` parameters.
        optimizer_unet: optimizer for the ``unet`` parameters.
        epoch: epoch index, used in the log line and the image paths.
        use_cuda: move each batch to the GPU; ignored when CUDA is unavailable.

    Returns:
        None. The summed batch loss and the accuracy are printed.
    """
    # Guard the transfer so use_cuda=True does not crash on CPU-only machines.
    on_gpu = use_cuda and torch.cuda.is_available()
    dnn.train()
    unet.train()

    loss_sum = 0.0
    total = 0
    correct_label = 0

    for i, (img, true_label) in enumerate(train_loader):
        if on_gpu:
            img, true_label = img.cuda(), true_label.cuda()

        optimizer_dnn.zero_grad()
        optimizer_unet.zero_grad()

        recon = unet(img)
        y_pred_on_recon = dnn(recon)
        loss = criterion(y_pred_on_recon, true_label)

        loss.backward()
        optimizer_dnn.step()
        optimizer_unet.step()

        # .item() yields a plain float, so the running sum does not hold on to
        # the autograd history of every batch.
        loss_sum += loss.item()

        # Count the samples actually seen; the last batch can be smaller than
        # the loader's nominal batch size.
        total += true_label.size(0)
        predicted = y_pred_on_recon.detach().argmax(dim=1)
        correct_label += (predicted == true_label).sum().item()

        if i % 100 == 0:
            save_image(img.detach(), './output/%03d/%d_A.png' % (epoch, i))
            save_image(recon.detach(), 'output/%03d/%d_reconA.png' % (epoch, i))

    # max(..., 1) avoids a division by zero when the loader yields nothing.
    print("e:", epoch, loss_sum, 'acc: ', correct_label / max(total, 1))

In [None]:
# Seed PyTorch's random number generator before the models are built, so
# weight initialisation starts from a fixed seed on every run.
torch.manual_seed(42)

dnn = DNN()
unet = UNet()

learning_rate = 1e-4

# One Adam optimizer per model, both using the same learning rate.
optimizer = torch.optim.Adam(dnn.parameters(), lr=learning_rate)
optimizer_unet = torch.optim.Adam(unet.parameters(), lr=learning_rate)

# Classification loss used for training.
loss_fn = nn.CrossEntropyLoss()

# Reconstruction loss; it is handed to training(), which currently does not use it.
recon_loss = nn.MSELoss()

In [None]:
# When a GPU is present: seed every CUDA generator, then move both models and
# both loss modules onto it.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    dnn.cuda()
    unet.cuda()
    loss_fn, recon_loss = loss_fn.cuda(), recon_loss.cuda()

In [None]:
# Run every epoch. GPU transfers are requested only when CUDA is actually
# available, so the cell also runs on CPU-only machines.
for e in range(EPOCHS):
    training(dnn=dnn,
             unet=unet,
             train_loader=train_loader,
             criterion=loss_fn,
             recon_criterion=recon_loss,
             optimizer_dnn=optimizer,
             optimizer_unet=optimizer_unet,
             epoch=e,
             use_cuda=torch.cuda.is_available())

e: 0 tensor(425.2411, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.93115
e: 1 tensor(151.7661, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9751833333333333
e: 2 tensor(107.7781, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9816166666666667
e: 3 tensor(81.1118, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9864
e: 4 tensor(64.6139, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.98895
e: 5 tensor(54.5813, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9906166666666667
e: 6 tensor(47.5059, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9917166666666667
e: 7 tensor(37.6757, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9933333333333333
e: 8 tensor(33.0212, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9945333333333334
e: 9 tensor(28.9382, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9948833333333333
e: 10 tensor(26.3884, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.99535
e: 11 tensor(22.7213, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.996
e: 12 ten

KeyboardInterrupt: ignored

In [None]:
def testing(dnn, unet, test_loader, use_cuda):
    dataset_iter = iter(test_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)

        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = unet(img)
        y_pred_on_recon = dnn(recon)

        loss = criterion(y_pred_on_recon, true_label)
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1
    print("Testing:", epoch, loss_sum, 'acc: ', correct_label/total)

In [None]:
# Request GPU transfers only when CUDA is actually available.
testing(dnn=dnn, unet=unet, test_loader=test_loader, use_cuda=torch.cuda.is_available())