<a href="https://colab.research.google.com/github/bochendong/giao_bochen/blob/main/%E2%80%9Cmnist_ipynb%E2%80%9D%E7%9A%84%E5%89%AF%E6%9C%AC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms
import numpy as np
import random
from torchvision.utils import save_image
import os
from torch.optim.lr_scheduler import StepLR
import glob

In [2]:
if (os.path.exists("./output")) == False:
    os.mkdir("output")

if (os.path.exists("./model_weight")) == False:
    os.mkdir("model_weight")

BATCH_SIZE = 32
EPOCHS = 15

for epoch in range (EPOCHS):
    if (os.path.exists("./output/%03d" % epoch)) == False:
        os.mkdir("./output/%03d" % epoch)
    else:
        files = glob.glob("./output/%03d/*.png" % epoch)

        for f in files:
          os.remove(f)

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, 
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                           transforms.Resize(32)
                       ])),
        batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                           transforms.Resize(32)
                       ])),
        batch_size=BATCH_SIZE, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [5]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size = 3)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size = 3)
    
    def forward(self, x):
        return self.relu(self.conv2(self.relu(self.conv1(x))))

class Encoder(nn.Module):
    def __init__(self, chs=(1, 64, 128)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
    
    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs

class Decoder(nn.Module):
    def __init__(self, chs=(128, 64)):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x
    
    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs   = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs

class UNet(nn.Module):
    def __init__(self, enc_chs=(1, 64, 128), dec_chs=(128, 64), 
                    num_class=1, retain_dim=True, out_sz=(32,32)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)

        self.retain_dim  = retain_dim
        self.out_sz      = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out

In [6]:
class DNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(32 * 32, 512)
        self.fc_2 = nn.Linear(512, 256)
        self.fc_3 = nn.Linear(256, 10)
        self.activate = nn.ReLU(inplace=True)

    def forward(self, x):
      x = x.view(-1, 32 * 32)
      out = self.fc_1(x)
      out = self.activate(out)
      out = self.fc_2(out)
      out = self.activate(out)
      out = self.fc_3(out)

      return out

# Base Line

In [8]:
def dnn_training(dnn, train_loader, criterion, optimizer, epoch, use_cuda):
    dataset_iter = iter(train_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)
        optimizer.zero_grad()
        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        y_pred_on_recon = dnn(img)
        loss = criterion(y_pred_on_recon, true_label)

        loss.backward()
        optimizer.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1

    print("e:", epoch, loss_sum, 'acc: ', correct_label/total)

def dnn_testing(dnn, test_loader, criterion, use_cuda):
    dataset_iter = iter(test_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)

        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        y_pred = dnn(img)
        loss = criterion(y_pred, true_label)

        optimizer.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1

    print("testing:", loss_sum, 'acc: ', correct_label/total)

In [10]:
baseline = DNN()
learning_rate = 1e-4
optimizer = torch.optim.Adam(baseline.parameters(), lr=learning_rate)

loss_fn = nn.CrossEntropyLoss()

if (torch.cuda.is_available()):
    torch.cuda.manual_seed_all(42)
    loss_fn = loss_fn.cuda()
    baseline.cuda()

In [11]:
for e in range(EPOCHS):
    dnn_training(dnn = baseline,
                 train_loader = train_loader, 
                 criterion = loss_fn, 
                 optimizer = optimizer, 
                 epoch = e,
                 use_cuda = True)

e: 0 tensor(606.3034, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9100833333333334
e: 1 tensor(247.5872, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9608
e: 2 tensor(169.7446, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9732166666666666
e: 3 tensor(126.6334, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.97945
e: 4 tensor(97.5249, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9843833333333334
e: 5 tensor(77.1146, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.98755
e: 6 tensor(62.4413, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9900166666666667
e: 7 tensor(51.4639, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9915166666666667
e: 8 tensor(41.9970, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9931333333333333
e: 9 tensor(35.0475, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9944833333333334
e: 10 tensor(30.4478, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.99515
e: 11 tensor(23.7139, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9963
e: 12 t

In [12]:
dnn_testing(dnn = baseline, test_loader = test_loader, criterion = loss_fn, use_cuda=True)

testing: tensor(46422.5117, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.20577076677316294


# Pred on reconstruct

In [13]:
def training(dnn, unet, train_loader, criterion, recon_criterion, optimizer_dnn, optimizer_unet, epoch, use_cuda):
    dataset_iter = iter(train_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)
        optimizer_dnn.zero_grad()
        optimizer_unet.zero_grad()
        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = unet(img)
        y_pred_on_recon = dnn(recon)

        loss = criterion(y_pred_on_recon, true_label)

        loss.backward()
        optimizer_dnn.step()
        optimizer_unet.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        if  i % 100 == 0:
            real = img.data
            reconstruction= recon.data

            save_image(real, './output/%03d/%d_A.png' % ( epoch, i))
            save_image(reconstruction, 'output/%03d/%d_reconA.png' % ( epoch, i))

        i += 1
    print("e:", epoch, loss_sum, 'acc: ', correct_label/total)

def testing(dnn, unet, criterion, test_loader, use_cuda):
    dataset_iter = iter(test_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)

        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = unet(img)
        y_pred_on_recon = dnn(recon)

        loss = criterion(y_pred_on_recon, true_label)
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1
    print("Testing:", loss_sum, 'acc: ', correct_label/total)

In [14]:
dnn = DNN()
unet = UNet()

learning_rate = 1e-4

optimizer = torch.optim.Adam(dnn.parameters(), lr=learning_rate)
optimizer_unet = torch.optim.Adam(unet.parameters(), lr=learning_rate)

loss_fn = nn.CrossEntropyLoss()

recon_loss = nn.MSELoss()

In [15]:
if (torch.cuda.is_available()):
    torch.cuda.manual_seed_all(42)
    loss_fn = loss_fn.cuda()
    recon_loss = recon_loss.cuda()
    dnn.cuda()
    unet.cuda()

In [16]:
for e in range(EPOCHS):
    training(dnn = dnn,
                 unet=unet,
                 train_loader = train_loader, 
                 criterion = loss_fn, 
                 recon_criterion = recon_loss,
                 optimizer_dnn = optimizer, 
                 optimizer_unet = optimizer_unet,
                 epoch = e,
                 use_cuda = True)

e: 0 tensor(414.9771, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9312333333333334
e: 1 tensor(159.6499, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9727333333333333
e: 2 tensor(107.1658, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.98205
e: 3 tensor(85.3417, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9853666666666666
e: 4 tensor(66.9069, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.98845
e: 5 tensor(53.8967, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9908166666666667
e: 6 tensor(45.9787, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9919
e: 7 tensor(38.7948, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9933
e: 8 tensor(31.3865, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9945
e: 9 tensor(28.8130, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.99485
e: 10 tensor(24.2633, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.99585
e: 11 tensor(23.2469, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9958166666666667
e: 12 tensor(18.3146, device='c

In [18]:
testing(dnn = dnn, unet=unet, criterion = loss_fn, test_loader = test_loader,  use_cuda = True)

Testing: tensor(13.5033, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.9881190095846646
