<a href="https://colab.research.google.com/github/bochendong/giao_bochen/blob/main/cifa10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms
import numpy as np
import random
from torchvision.utils import save_image
import os
from torch.optim.lr_scheduler import StepLR
import glob

# Data Prepare

In [2]:
Train_BASE_UNET = False
Train_BASE_DNN = False
Train_Model = True
MODEL_PATH = 'model_weight/unet.pth'
Base_UNET_PATH = 'model_weight/base_unet.pth'
DNN_PATH = 'model_weight/dnn.pth'

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
'''if (os.path.exists(MODEL_PATH)) == True:
    Train_Model = False'''

'if (os.path.exists(MODEL_PATH)) == True:\n    Train_Model = False'

In [5]:
if (os.path.exists("./output")) == False:
    os.mkdir("output")

if (os.path.exists("./model_weight")) == False:
    os.mkdir("model_weight")

BATCH_SIZE = 128
EPOCHS = 100

for epoch in range (EPOCHS):
    if (os.path.exists("./output/%03d" % epoch)) == False:
        os.mkdir("./output/%03d" % epoch)
    else:
        files = glob.glob("./output/%03d/*.png" % epoch)

        for f in files:
          os.remove(f)

In [6]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
dataset_iter = iter(trainloader)
test_img, test_label = next(dataset_iter)

In [8]:
test_img.size()

torch.Size([128, 3, 32, 32])

In [9]:
test_label.size()

torch.Size([128])

# Unet

In [10]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.activate = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d((2, 2))
        self.dropout = nn.Dropout(p=0.5)
        self.sigmod = nn.Sigmoid ()

        self.encoder_1_0 = nn.Conv2d(3, 64, 3, padding= 1)
        self.encoder_1_1 = nn.Conv2d(64, 64, 3, padding= 1)

        self.encoder_2_0 = nn.Conv2d(64, 128, 3, padding= 1)
        self.encoder_2_1 = nn.Conv2d(128, 128, 3, padding= 1)

        self.encoder_3_0 = nn.Conv2d(128, 256, 3, padding= 1)
        self.encoder_3_1 = nn.Conv2d(256, 256, 3, padding= 1)

        self.encoder_4_0 = nn.Conv2d(256, 512, 3, padding= 1)
        self.encoder_4_1 = nn.Conv2d(512, 512, 3, padding= 1)

        self.middle_1_0 = nn.Conv2d(512, 1024, 3, padding= 1)
        self.middle_1_1 = nn.Conv2d(1024, 1024, 3, padding= 1)

        self.deconv4_0 = nn.ConvTranspose2d(1024, 512, 3, stride=(2,2), padding = 1, output_padding = 1)
        self.uconv4_1 = nn.Conv2d(1024, 512, 3, padding= 1) 
        self.uconv4_2 = nn.Conv2d(512, 512, 3, padding= 1)

        self.deconv3_0 = nn.ConvTranspose2d(512, 512, 3, stride=(2,2), padding = 1, output_padding = 1)
        self.uconv3_1 = nn.Conv2d(768, 256, 3, padding= 1) 
        self.uconv3_2 = nn.Conv2d(256, 256, 3, padding= 1)

        self.deconv2_0 = nn.ConvTranspose2d(256, 512, 3, stride=(2,2), padding = 1, output_padding = 1)
        self.uconv2_1 = nn.Conv2d(640, 128, 3, padding= 1) 
        self.uconv2_2 = nn.Conv2d(128, 128, 3, padding= 1)

        self.deconv1_0 = nn.ConvTranspose2d(128, 512, 3, stride=(2,2), padding = 1, output_padding = 1)
        self.uconv1_1 = nn.Conv2d(576, 192, 3, padding= 1) 
        self.uconv1_2 = nn.Conv2d(192, 192, 3, padding= 1)

        self.out_layer = nn.Conv2d(192, 3, 1)
 

    def forward(self, x):
        conv1 = self.encoder_1_0(x)
        conv1 = self.activate(conv1)
        conv1 = self.encoder_1_1(conv1)
        conv1 = self.activate(conv1)

        pool1 = self.pool(conv1)
        pool1 = self.dropout(pool1)

        conv2 = self.encoder_2_0(pool1)
        conv2 = self.activate(conv2)
        conv2 = self.encoder_2_1(conv2)
        conv2 = self.activate(conv2)

        pool2 = self.pool(conv2)
        pool2 = self.dropout(pool2)

        conv3 = self.encoder_3_0(pool2)
        conv3 = self.activate(conv3)
        conv3 = self.encoder_3_1(conv3)
        conv3 = self.activate(conv3)

        pool3 = self.pool(conv3)
        pool3 = self.dropout(pool3)

        conv4 = self.encoder_4_0(pool3)
        conv4 = self.activate(conv4)
        conv4 = self.encoder_4_1(conv4)
        conv4 = self.activate(conv4)

        pool4 = self.pool(conv4)
        pool4 = self.dropout(pool4)

        convm = self.middle_1_0(pool4)
        convm = self.activate(convm)
        convm = self.middle_1_1(convm)
        convm = self.activate(convm)

        deconv4 = self.deconv4_0(convm)           # (None, 4, 4, 512)
        uconv4 = torch.cat([deconv4, conv4], 1)   # (None, 4, 4, 1024)
        uconv4 = self.dropout(uconv4)
        uconv4 = self.uconv4_1(uconv4)            # (None, 4, 4, 512)
        uconv4 = self.activate(uconv4)
        uconv4 = self.uconv4_2(uconv4)            # (None, 4, 4, 512)
        uconv4 = self.activate(uconv4)

        deconv3 = self.deconv3_0(uconv4)          # (None, 8, 8, 512)
        uconv3 = torch.cat([deconv3, conv3], 1)   # (None, 8, 8, 768)
        uconv3 = self.dropout(uconv3)
        uconv3 = self.uconv3_1(uconv3)            # (None, 8, 8, 256)
        uconv3 = self.activate(uconv3)
        uconv3 = self.uconv3_2(uconv3)            # (None, 8, 8, 256)
        uconv3 = self.activate(uconv3)
        
        deconv2 = self.deconv2_0(uconv3)          # (None, 16, 16, 512)
        uconv2 = torch.cat([deconv2, conv2], 1)   # (None, 16, 16, 640)
        uconv2 = self.dropout(uconv2)
        uconv2 = self.uconv2_1(uconv2)            # (None, 16, 16, 128)
        uconv2 = self.activate(uconv2)
        uconv2 = self.uconv2_2(uconv2)            # (None, 16, 16, 128)
        uconv2 = self.activate(uconv2)

        deconv1 = self.deconv1_0(uconv2)          # (None, 32, 32, 512)
        uconv1 = torch.cat([deconv1, conv1], 1)   # (None, 32, 32, 576)
        uconv1 = self.dropout(uconv1)
        uconv1 = self.uconv1_1(uconv1)            # (None, 32, 32, 192)
        uconv1 = self.activate(uconv1)
        uconv1 = self.uconv1_2(uconv1)            # (None, 32, 32, 192)
        uconv1 = self.activate(uconv1)

        out = self.out_layer(uconv1)
        out = self.activate(out)
        return out
     

In [11]:
def unet_training(model, train_loader, loss_fn, optimizer, epoch, use_cuda = True):
    dataset_iter = iter(train_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    while i < len_dataloader:
        img, true_label = next(dataset_iter)
        optimizer.zero_grad()
        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = model(img)
        loss = loss_fn(recon, img)

        loss.backward()
        optimizer.step()
        loss_sum += loss

        if  i % 100 == 0:
            real = img.data
            reconstruction= recon.data

            save_image(real, './output/%03d/%d_A.png' % ( epoch, i))
            save_image(reconstruction, 'output/%03d/%d_reconA.png' % ( epoch, i))

        i += 1

    print("e:", epoch, loss_sum)


In [12]:
if (Train_BASE_UNET):
    unet = UNet()

    learning_rate = 1e-4
    optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate)
    recon_loss_fn = nn.L1Loss()

    if (torch.cuda.is_available()):
        recon_loss_fn = recon_loss_fn.cuda()
        unet.cuda()

    for e in range(30):
        unet_training(model=unet, train_loader = trainloader, 
                    loss_fn = recon_loss_fn, optimizer = optimizer, 
                    epoch = e, use_cuda = True)
        
    torch.save(unet.state_dict(), Base_UNET_PATH)

# DNN

In [13]:
class DNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [14]:
def dnn_training(dnn, train_loader, criterion, optimizer, epoch, use_cuda):
    dataset_iter = iter(train_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)
        optimizer.zero_grad()
        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        y_pred_on_recon = dnn(img)
        loss = criterion(y_pred_on_recon, true_label)

        loss.backward()
        optimizer.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1

    print("e:", epoch, loss_sum, 'acc: ', correct_label/total)


In [15]:
if (Train_BASE_DNN):
    baseline = DNN()
    learning_rate = 1e-4
    optimizer = torch.optim.Adam(baseline.parameters(), lr=learning_rate)

    loss_fn = nn.CrossEntropyLoss()

    if (torch.cuda.is_available()):
        torch.cuda.manual_seed_all(42)
        loss_fn = loss_fn.cuda()
        baseline.cuda()

    for e in range(EPOCHS):
        dnn_training(dnn = baseline,
                    train_loader = trainloader, 
                    criterion = loss_fn, 
                    optimizer = optimizer, 
                    epoch = e,
                    use_cuda = True)
        
        torch.save(unet.state_dict(), DNN_PATH)

# Pred on reconstruct

In [16]:
def training(dnn, unet, train_loader, loss_fn, recon_loss_fn, optimizer_dnn, optimizer_unet, epoch, use_cuda):
    dataset_iter = iter(train_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)
        optimizer_dnn.zero_grad()
        optimizer_unet.zero_grad()
        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = unet(img)
        y_pred_on_recon = dnn(recon)

        loss_recon = recon_loss_fn(recon, img)
        loss = loss_fn(y_pred_on_recon, true_label)

        loss.backward()
        optimizer_dnn.step()
        optimizer_unet.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        if  i % 100 == 0:
            real = img.data
            reconstruction= recon.data

            save_image(real, './output/%03d/%d_A.png' % ( epoch, i))
            save_image(reconstruction, 'output/%03d/%d_reconA.png' % ( epoch, i))

        i += 1
    print("e:", epoch, loss_sum, 'acc: ', correct_label/total)

def testing(dnn, unet, criterion, test_loader, use_cuda):
    dataset_iter = iter(test_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)

        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        recon = unet(img)
        y_pred_on_recon = dnn(recon)

        loss = criterion(y_pred_on_recon, true_label)
        loss_sum += loss

        _, predicted = torch.max(y_pred_on_recon.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1
    print("Testing:", loss_sum, 'acc: ', correct_label/total)

In [None]:
if (Train_Model):
    dnn = DNN()
    unet = UNet()

    learning_rate = 1e-4

    optimizer_dnn = torch.optim.Adam(dnn.parameters(), lr=learning_rate)
    optimizer_unet = torch.optim.Adam(unet.parameters(), lr=learning_rate)

    loss_fn = nn.CrossEntropyLoss()

    recon_loss_fn = nn.MSELoss()

    if (torch.cuda.is_available()):
        torch.cuda.manual_seed_all(42)
        loss_fn = loss_fn.cuda()
        recon_loss_fn = recon_loss_fn.cuda()
        dnn.cuda()
        unet.cuda()

    for e in range(EPOCHS):
        training(dnn = dnn,
                    unet=unet,
                    train_loader = trainloader, 
                    loss_fn = loss_fn, 
                    recon_loss_fn = recon_loss_fn,
                    optimizer_dnn = optimizer_dnn, 
                    optimizer_unet = optimizer_unet,
                    epoch = e,
                    use_cuda = True)
        
    torch.save(unet.state_dict(), MODEL_PATH)

e: 0 tensor(844.0641, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.22068813938618925
e: 1 tensor(695.2189, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.3813139386189258
e: 2 tensor(627.8627, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.4527653452685422
e: 3 tensor(567.7477, device='cuda:0', grad_fn=<AddBackward0>) acc:  0.5117886828644501


# Showcase

In [None]:
unet = UNet()
unet.load_state_dict(torch.load(MODEL_PATH))

In [None]:
demo_iter = iter(testloader)
demo_img, demo_label = next(demo_iter)

if torch.cuda.is_available():
    demo_img = demo_img.cuda()
    unet.cuda()

recon = unet(demo_img)
recon = recon.data.cpu().detach()
demo_img = demo_img.data.cpu().detach()
demo_label = demo_label.numpy()


In [None]:
recon_imgs = [[], [], [], [], [], [], [], [], [], []]
real_imgs = [[], [], [], [], [], [], [], [], [], []]

In [None]:
for i, label in enumerate(demo_label):

    recon_img = recon[i].transpose(0,2).transpose(0,1)
    real_img = demo_img[i].transpose(0,2).transpose(0,1)

    recon_imgs[label].append(recon_img)
    real_imgs[label].append(real_img)
  

In [None]:
fig, axs = plt.subplots(10, 10, figsize=(12, 12))

for i in range (5):
    for j in range (5):
      axs[i * 2][j].imshow(recon_imgs[i][j])
      axs[i * 2 + 1][j].imshow(real_imgs[i][j])

for i in range (5):
    for j in range (5):
      axs[i * 2][j + 5].imshow(recon_imgs[i + 5][j])
      axs[i * 2 + 1][j + 5].imshow(real_imgs[i + 5][j])

In [None]:
def dnn_testing(dnn, test_loader, criterion, use_cuda):
    dataset_iter = iter(test_loader)
    len_dataloader = len(dataset_iter)

    i = 0
    loss_sum = 0
    total = 0
    correct_label = 0

    while i < len_dataloader:
        total += BATCH_SIZE
        img, true_label = next(dataset_iter)

        if use_cuda:
            img, true_label = img.cuda(), true_label.cuda()

        y_pred = dnn(img)
        loss = criterion(y_pred, true_label)

        optimizer.step()
        loss_sum += loss

        _, predicted = torch.max(y_pred.data, 1)
        correct_label += predicted.eq(true_label.data).cpu().sum().item()

        i += 1

    print("testing:", loss_sum, 'acc: ', correct_label/total)