<a href="https://colab.research.google.com/github/bochendong/giao_bochen/blob/main/cifa10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms
import numpy as np
import random
from torchvision.utils import save_image
import os
from torch.optim.lr_scheduler import StepLR
import glob

# Data Prepare

In [2]:
if (os.path.exists("./output")) == False:
    os.mkdir("output")

if (os.path.exists("./model_weight")) == False:
    os.mkdir("model_weight")

BATCH_SIZE = 128
EPOCHS = 30

for epoch in range (EPOCHS):
    if (os.path.exists("./output/%03d" % epoch)) == False:
        os.mkdir("./output/%03d" % epoch)
    else:
        files = glob.glob("./output/%03d/*.png" % epoch)

        for f in files:
          os.remove(f)

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
dataset_iter = iter(trainloader)
test_img, test_label = next(dataset_iter)

In [6]:
test_img.size()

torch.Size([128, 3, 32, 32])

In [7]:
test_label.size()

torch.Size([128])

# Unet

In [8]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size = 3)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size = 3)
    
    def forward(self, x):
        return self.relu(self.conv2(self.relu(self.conv1(x))))

class Encoder(nn.Module):
    def __init__(self, chs = (3, 64, 128)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
    
    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs

class Decoder(nn.Module):
    def __init__(self, chs = (128, 64)):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x
    
    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs   = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs

class UNet(nn.Module):
    def __init__(self, enc_chs = (3, 64, 128), dec_chs = (128, 64), 
                    num_class = 3, retain_dim = True, out_sz = (32, 32)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)

        self.retain_dim  = retain_dim
        self.out_sz      = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out
     

In [9]:
def unet_training(model, train_loader, loss_fn, optimizer, epoch, use_cuda = True):
    dataset_iter = iter(train_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    while i < len_dataloader:
        img, true_label = next(dataset_iter)
        optimizer.zero_grad()
        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = model(img)
        loss = loss_fn(recon, img)

        loss.backward()
        optimizer.step()
        loss_sum += loss

        if  i % 100 == 0:
            real = img.data
            reconstruction= recon.data

            save_image(real, './output/%03d/%d_A.png' % ( epoch, i))
            save_image(reconstruction, 'output/%03d/%d_reconA.png' % ( epoch, i))

        i += 1

    print("e:", epoch, loss_sum)


In [10]:
unet = UNet()

learning_rate = 1e-4
optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate)
recon_loss_fn = nn.L1Loss()

if (torch.cuda.is_available()):
    recon_loss_fn = recon_loss_fn.cuda()
    unet.cuda()

In [11]:
for e in range(30):
    unet_training(model=unet, train_loader = trainloader, 
                 loss_fn = recon_loss_fn, optimizer = optimizer, 
                 epoch = e, use_cuda = True)

e: 0 tensor(117.8972, device='cuda:0', grad_fn=<AddBackward0>)
e: 1 tensor(107.4478, device='cuda:0', grad_fn=<AddBackward0>)
e: 2 tensor(104.7686, device='cuda:0', grad_fn=<AddBackward0>)
e: 3 tensor(102.9117, device='cuda:0', grad_fn=<AddBackward0>)
e: 4 tensor(101.5591, device='cuda:0', grad_fn=<AddBackward0>)
e: 5 tensor(100.5021, device='cuda:0', grad_fn=<AddBackward0>)
e: 6 tensor(99.5055, device='cuda:0', grad_fn=<AddBackward0>)
e: 7 tensor(98.8449, device='cuda:0', grad_fn=<AddBackward0>)
e: 8 tensor(98.2580, device='cuda:0', grad_fn=<AddBackward0>)
e: 9 tensor(97.7451, device='cuda:0', grad_fn=<AddBackward0>)
e: 10 tensor(97.2287, device='cuda:0', grad_fn=<AddBackward0>)
e: 11 tensor(96.7663, device='cuda:0', grad_fn=<AddBackward0>)
e: 12 tensor(96.3375, device='cuda:0', grad_fn=<AddBackward0>)
e: 13 tensor(95.8918, device='cuda:0', grad_fn=<AddBackward0>)
e: 14 tensor(95.5500, device='cuda:0', grad_fn=<AddBackward0>)
e: 15 tensor(95.1934, device='cuda:0', grad_fn=<AddBackward

# DNN

In [12]:
class DNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [13]:
def dnn_training(dnn, train_loader, criterion, optimizer, epoch, use_cuda):
    dataset_iter = iter(train_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)
        optimizer.zero_grad()
        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        y_pred_on_recon = dnn(img)
        loss = criterion(y_pred_on_recon, true_label)

        loss.backward()
        optimizer.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1

    print("e:", epoch, loss_sum, 'acc: ', correct_label/total)

def dnn_testing(dnn, test_loader, criterion, use_cuda):
    dataset_iter = iter(test_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)

        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        y_pred = dnn(img)
        loss = criterion(y_pred, true_label)

        optimizer.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1

    print("testing:", loss_sum, 'acc: ', correct_label/total)

In [14]:
baseline = DNN()
learning_rate = 1e-4
optimizer = torch.optim.Adam(baseline.parameters(), lr=learning_rate)

loss_fn = nn.CrossEntropyLoss()

if (torch.cuda.is_available()):
    torch.cuda.manual_seed_all(42)
    loss_fn = loss_fn.cuda()
    baseline.cuda()

In [15]:
for e in range(EPOCHS):
    dnn_training(dnn = baseline,
                 train_loader = trainloader, 
                 criterion = loss_fn, 
                 optimizer = optimizer, 
                 epoch = e,
                 use_cuda = True)

e: 0 tensor(804.1334, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.23549392583120204
e: 1 tensor(691.6365, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.3559582800511509
e: 2 tensor(650.0811, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.390684942455243
e: 3 tensor(627.7081, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.41374280690537085
e: 4 tensor(612.3911, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.43054667519181583
e: 5 tensor(600.9468, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.44259510869565216
e: 6 tensor(591.3448, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.45408407928388744
e: 7 tensor(582.6982, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.4631553708439898
e: 8 tensor(575.2344, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.4687699808184143
e: 9 tensor(568.4309, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.47570332480818417
e: 10 tensor(561.6481, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.48127797314578
e: 11 tensor(555.7025, devic

In [16]:
dnn_testing(dnn = baseline, test_loader = testloader, criterion = loss_fn, use_cuda=True)

testing: tensor(755.0688, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.21054193037974683


# Pred on reconstruct

In [17]:
def training(dnn, unet, train_loader, loss_fn, recon_loss_fn, optimizer_dnn, optimizer_unet, epoch, use_cuda):
    dataset_iter = iter(train_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)
        optimizer_dnn.zero_grad()
        optimizer_unet.zero_grad()
        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = unet(img)
        y_pred_on_recon = dnn(recon)

        loss = loss_fn(y_pred_on_recon, true_label)

        loss.backward()
        optimizer_dnn.step()
        optimizer_unet.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        if  i % 100 == 0:
            real = img.data
            reconstruction= recon.data

            save_image(real, './output/%03d/%d_A.png' % ( epoch, i))
            save_image(reconstruction, 'output/%03d/%d_reconA.png' % ( epoch, i))

        i += 1
    print("e:", epoch, loss_sum, 'acc: ', correct_label/total)

def testing(dnn, unet, criterion, test_loader, use_cuda):
    dataset_iter = iter(test_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)

        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = unet(img)
        y_pred_on_recon = dnn(recon)

        loss = criterion(y_pred_on_recon, true_label)
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1
    print("Testing:", loss_sum, 'acc: ', correct_label/total)

In [18]:
dnn = DNN()
unet = UNet()

learning_rate = 1e-4

optimizer_dnn = torch.optim.Adam(dnn.parameters(), lr=learning_rate)
optimizer_unet = torch.optim.Adam(unet.parameters(), lr=learning_rate)

loss_fn = nn.CrossEntropyLoss()

recon_loss_fn = nn.MSELoss()

In [19]:
if (torch.cuda.is_available()):
    torch.cuda.manual_seed_all(42)
    loss_fn = loss_fn.cuda()
    recon_loss_fn = recon_loss_fn.cuda()
    dnn.cuda()
    unet.cuda()

In [20]:
for e in range(EPOCHS):
    training(dnn = dnn,
                 unet=unet,
                 train_loader = trainloader, 
                 loss_fn = loss_fn, 
                 recon_loss_fn = recon_loss_fn,
                 optimizer_dnn = optimizer_dnn, 
                 optimizer_unet = optimizer_unet,
                 epoch = e,
                 use_cuda = True)

e: 0 tensor(789.8221, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.22912004475703324
e: 1 tensor(702.7535, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.3131993286445013
e: 2 tensor(669.8340, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.36111333120204603
e: 3 tensor(636.6530, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.3992167519181586
e: 4 tensor(612.8073, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.42443254475703324
e: 5 tensor(592.0802, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.44651134910485935
e: 6 tensor(573.3353, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.46457400895140666
e: 7 tensor(555.8212, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.480778452685422
e: 8 tensor(537.9424, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.500559462915601
e: 9 tensor(524.3246, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.5127677429667519
e: 10 tensor(512.3017, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.5245164641943734
e: 11 tensor(499.0322, devic

In [21]:
testing(dnn = dnn, unet=unet, criterion = loss_fn, test_loader = testloader,  use_cuda = True)

Testing: tensor(81.8390, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.6234177215189873
