In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms, models
import numpy as np
import random
from torchvision.utils import save_image
import os
from torch.optim.lr_scheduler import StepLR
import glob
import cv2
import math

from torch.utils.data import Subset

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [None]:
# Checkpoint locations; both files live in the model_weight/ folder.
WEIGHT_DIR = 'model_weight'
UNET_PATH, DNN_PATH = f'{WEIGHT_DIR}/unet.pth', f'{WEIGHT_DIR}/dnn.pth'

In [None]:
# Hyperparameters
BATCH_SIZE = 128
num_epochs = 24          # UNet training epochs (also the number of ./output/<epoch> folders)
learning_rate = 1e-5     # learning rate of the UNet optimizer

# Dataset switches.
# NOTE(review): only an MNIST pipeline is implemented in this notebook; CIFAR10 is currently unused.
MNIST = True
CIFAR10 = False

# Reproducibility: seed every RNG the notebook draws from (Python, NumPy, PyTorch incl. CUDA).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Network Training Settings: a network is trained only when no saved checkpoint exists yet.
Train_BASE_DNN = not os.path.exists(DNN_PATH)
Train_Unet = not os.path.exists(UNET_PATH)

In [None]:
# Create the top-level working folders on the first run.
for _dirname in ("output", "model_weight", "test_out"):
    if not os.path.exists("./" + _dirname):
        os.mkdir(_dirname)

# One sample-image folder per epoch; PNGs left over from a previous run are deleted.
for epoch in range(num_epochs):
    _epoch_dir = "./output/%03d" % epoch
    if not os.path.exists(_epoch_dir):
        os.mkdir(_epoch_dir)
        continue
    for _png in glob.glob("%s/*.png" % _epoch_dir):
        os.remove(_png)

In [None]:
# Both splits share one preprocessing pipeline: MNIST's 28x28 digits are resized
# to 32x32 (divisible by 2**5, so the five 2x2 max-pools of VGG11 end at 1x1)
# and converted to float tensors in [0, 1].
mnist_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=mnist_transform)
# shuffle=True reshuffles every epoch so SGD does not see one fixed batch order.
# drop_last=True keeps every batch at exactly BATCH_SIZE samples.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

test_dataset = datasets.MNIST('data', train=False, transform=mnist_transform)
# NOTE(review): drop_last=True skips the final partial batch (10000 % 128 = 16 test
# images); kept so every evaluation batch is exactly BATCH_SIZE samples.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [None]:
class VGG11(nn.Module):
    """VGG-11 ("configuration A") classifier for single-channel images.

    ``features`` stacks 3x3 convolutions (padding 1, each followed by an
    in-place ReLU) with five 2x2 max-pools; an adaptive average pool reduces
    the result to 512x1x1 before a three-layer fully-connected classifier with
    dropout.

    Args:
        num_classes: number of output logits.
    """

    # Output widths of the 3x3 convolutions, in order; 'M' marks a 2x2 max-pool.
    _CFG = (64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M')

    def __init__(self, num_classes=10):
        super().__init__()
        layers = []
        in_channels = 1
        for spec in self._CFG:
            if spec == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers.append(nn.Conv2d(in_channels, spec, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channels = spec
        self.features = nn.Sequential(*layers)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Linear(512, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Return the (B, num_classes) logits for a (B, 1, H, W) image batch."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(pooled.flatten(1))

In [None]:
# Train the VGG11 digit classifier from scratch (only when no checkpoint exists
# at DNN_PATH) and save its weights there.
if Train_BASE_DNN:
    DNN_EPOCHS = 100
    DNN_LR = 1e-5

    # Use DEVICE (CPU fallback) instead of hard-coding .cuda().
    dnn_model = VGG11().to(DEVICE)
    dnn_criterion = nn.CrossEntropyLoss()
    dnn_optimizer = torch.optim.Adam(dnn_model.parameters(), lr=DNN_LR)

    print("Training DNN classifier...")
    dnn_model.train()
    for epoch in range(DNN_EPOCHS):
        total = 0
        total_correct = 0
        for images, labels in train_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            # Forward pass
            outputs = dnn_model(images)
            loss = dnn_criterion(outputs, labels)

            # Backward pass and optimization
            dnn_optimizer.zero_grad()
            loss.backward()
            dnn_optimizer.step()

            # Running training accuracy; count the real batch size rather than BATCH_SIZE.
            _, pred = torch.max(outputs, 1)
            total_correct += pred.eq(labels).sum().item()
            total += labels.size(0)

        print("e:", epoch, 'acc:', total_correct / max(total, 1))

    print("DNN classifier training complete.")
    torch.save(dnn_model.state_dict(), DNN_PATH)

In [None]:
# Reload the saved VGG11 classifier and report its accuracy on the test loader.
dnn_model = VGG11().to(DEVICE)
dnn_criterion = nn.CrossEntropyLoss()
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
dnn_model.load_state_dict(torch.load(DNN_PATH, map_location=DEVICE))
# eval() disables the classifier's Dropout; without it test accuracy is randomly degraded.
dnn_model.eval()

print("Testing DNN classifier...")
total = 0
total_correct = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = dnn_model(images)

        _, pred = torch.max(outputs, 1)
        total_correct += pred.eq(labels).sum().item()
        total += labels.size(0)

print("Test Acc:" , total_correct / max(total, 1))


print("DNN classifier Test complete.")

In [None]:
class UNet(nn.Module):
    """Label-conditioned U-Net mapping a 1-channel image to per-pixel probabilities.

    The encoder applies four conv blocks, each followed by a 2x2 max-pool and
    dropout. At the bottleneck the embedding of ``input_labels`` is concatenated
    to the features; after the two middle convolutions the embedding of
    ``target_labels`` (same embedding table) is concatenated before decoding.
    Each decoder stage upsamples 2x with a transposed convolution and
    concatenates the matching encoder feature map (skip connection).

    The skip concatenations require a 1-channel input whose height and width
    are divisible by 16 (e.g. the 32x32 MNIST images used in this notebook).
    """

    def __init__(self):
        super(UNet, self).__init__()
        self.activate = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d((2, 2))
        self.dropout = nn.Dropout(p=0.5)
        # Output activation (attribute name kept as-is for compatibility).
        self.sigmod = nn.Sigmoid ()
        # 10 class labels -> 512-d vector, broadcast over the bottleneck grid.
        self.label_embedding = nn.Embedding(10, 512)

        self.encoder_1 = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding= 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding= 1),
            nn.ReLU(inplace=True),
        )

        self.encoder_2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding= 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding= 1),
            nn.ReLU(inplace=True),
        )

        self.encoder_3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding= 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding= 1),
            nn.ReLU(inplace=True),
        )

        self.encoder_4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding= 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding= 1),
            nn.ReLU(inplace=True),
        )

        # Bottleneck: 512 encoder channels + 512 input-label embedding channels.
        self.middle_1_0 = nn.Conv2d(1024, 1024, 3, padding= 1)
        self.middle_1_1 = nn.Conv2d(1024, 1024, 3, padding= 1)

        # Decoder. deconvN_0 doubles H and W (kernel 3, stride 2, padding 1,
        # output_padding 1); uconvN_1 consumes deconv output + skip features.
        self.deconv4_0 = nn.ConvTranspose2d(1536, 512, 3, stride=(2,2), padding = 1, output_padding = 1)
        self.uconv4_1 = nn.Conv2d(1024, 512, 3, padding= 1)
        self.uconv4_2 = nn.Conv2d(512, 512, 3, padding= 1)

        self.deconv3_0 = nn.ConvTranspose2d(512, 512, 3, stride=(2,2), padding = 1, output_padding = 1)
        self.uconv3_1 = nn.Conv2d(768, 256, 3, padding= 1)
        self.uconv3_2 = nn.Conv2d(256, 256, 3, padding= 1)

        self.deconv2_0 = nn.ConvTranspose2d(256, 512, 3, stride=(2,2), padding = 1, output_padding = 1)
        self.uconv2_1 = nn.Conv2d(640, 128, 3, padding= 1)
        self.uconv2_2 = nn.Conv2d(128, 128, 3, padding= 1)

        self.deconv1_0 = nn.ConvTranspose2d(128, 512, 3, stride=(2,2), padding = 1, output_padding = 1)
        self.uconv1_1 = nn.Conv2d(576, 192, 3, padding= 1)
        self.uconv1_2 = nn.Conv2d(192, 192, 3, padding= 1)

        self.out_layer = nn.Conv2d(192, 1, 1)

    def forward(self, x, input_labels, target_labels):
        """Run the network.

        Args:
            x: (B, 1, H, W) image batch; H and W divisible by 16.
            input_labels: (B,) long tensor of source classes in [0, 10).
            target_labels: (B,) long tensor of desired classes in [0, 10).

        Returns:
            (B, 1, H, W) tensor of per-pixel probabilities in (0, 1).
        """
        # Shape comments below are (B, C, H, W) for a 32x32 input.
        conv1 = self.encoder_1(x)                 # (B, 64, 32, 32)
        pool1 = self.pool(conv1)
        pool1 = self.dropout(pool1)

        conv2 = self.encoder_2(pool1)             # (B, 128, 16, 16)
        pool2 = self.pool(conv2)
        pool2 = self.dropout(pool2)

        conv3 = self.encoder_3(pool2)             # (B, 256, 8, 8)
        pool3 = self.pool(conv3)
        pool3 = self.dropout(pool3)

        conv4 = self.encoder_4(pool3)             # (B, 512, 4, 4)
        pool4 = self.pool(conv4)
        encoder_out = self.dropout(pool4)         # (B, 512, 2, 2)

        # Condition on the source label: tile its embedding over the bottleneck grid.
        input_label_embedding = self.label_embedding(input_labels).view(input_labels.size(0), 512, 1, 1)
        x1 = torch.cat([encoder_out, input_label_embedding.expand_as(encoder_out)], dim=1)   # (B, 1024, 2, 2)

        convm = self.middle_1_0(x1)
        convm = self.activate(convm)
        convm = self.middle_1_1(convm)
        x2 = self.activate(convm)                 # (B, 1024, 2, 2)

        # Condition the decoder on the target label the same way.
        target_label_embedding = self.label_embedding(target_labels).view(target_labels.size(0), 512, 1, 1)
        x2 = torch.cat([x2, target_label_embedding.expand(x2.size(0), 512, x2.size(2), x2.size(3))], dim=1)   # (B, 1536, 2, 2)

        deconv4 = self.deconv4_0(x2)              # (B, 512, 4, 4)
        uconv4 = torch.cat([deconv4, conv4], 1)   # (B, 1024, 4, 4)
        uconv4 = self.dropout(uconv4)
        uconv4 = self.uconv4_1(uconv4)            # (B, 512, 4, 4)
        uconv4 = self.activate(uconv4)
        uconv4 = self.uconv4_2(uconv4)            # (B, 512, 4, 4)
        uconv4 = self.activate(uconv4)

        deconv3 = self.deconv3_0(uconv4)          # (B, 512, 8, 8)
        uconv3 = torch.cat([deconv3, conv3], 1)   # (B, 768, 8, 8)
        uconv3 = self.dropout(uconv3)
        uconv3 = self.uconv3_1(uconv3)            # (B, 256, 8, 8)
        uconv3 = self.activate(uconv3)
        uconv3 = self.uconv3_2(uconv3)            # (B, 256, 8, 8)
        uconv3 = self.activate(uconv3)

        deconv2 = self.deconv2_0(uconv3)          # (B, 512, 16, 16)
        uconv2 = torch.cat([deconv2, conv2], 1)   # (B, 640, 16, 16)
        uconv2 = self.dropout(uconv2)
        uconv2 = self.uconv2_1(uconv2)            # (B, 128, 16, 16)
        uconv2 = self.activate(uconv2)
        uconv2 = self.uconv2_2(uconv2)            # (B, 128, 16, 16)
        uconv2 = self.activate(uconv2)

        deconv1 = self.deconv1_0(uconv2)          # (B, 512, 32, 32)
        uconv1 = torch.cat([deconv1, conv1], 1)   # (B, 576, 32, 32)
        uconv1 = self.dropout(uconv1)
        uconv1 = self.uconv1_1(uconv1)            # (B, 192, 32, 32)
        uconv1 = self.activate(uconv1)
        uconv1 = self.uconv1_2(uconv1)            # (B, 192, 32, 32)
        uconv1 = self.activate(uconv1)

        out = self.out_layer(uconv1)              # (B, 1, 32, 32) logits
        # The activation must act on the network output, not on the input `x`.
        # A softmax over the single output channel is identically 1.0, so a
        # per-pixel sigmoid is used to get independent probabilities in (0, 1).
        out = self.sigmod(out)

        return out

In [None]:
def perceptual_loss(vgg_model, input_images, output_images):
    """Mean-squared distance between feature maps of two image batches.

    Both batches are passed through every layer of ``vgg_model.features``
    except the last one, and the MSE of the resulting activations is returned.

    The layers are shared with ``vgg_model`` (no copy is made), so they run on
    whatever device the model already lives on. The previous ``.cuda()`` call
    moved those shared layers in place and failed on CPU-only machines.

    Args:
        vgg_model: module exposing an indexable, iterable ``features`` container.
        input_images, output_images: image batches accepted by those layers.

    Returns:
        Scalar MSE tensor.
    """
    feature_extractor = nn.Sequential(*list(vgg_model.features)[:-1])

    input_features = feature_extractor(input_images)
    output_features = feature_extractor(output_images)

    return nn.functional.mse_loss(input_features, output_features)

def total_variation_regularization(images):
    """Anisotropic L1 total variation of an image batch.

    Sums the absolute differences between vertically adjacent pixels (dim 2)
    and horizontally adjacent pixels (dim 3) over the whole batch.

    Args:
        images: 4-D tensor indexed as (batch, channel, row, column).

    Returns:
        Scalar tensor (un-normalised sum).
    """
    vertical = images[:, :, 1:, :] - images[:, :, :-1, :]
    horizontal = images[:, :, :, 1:] - images[:, :, :, :-1]
    return vertical.abs().sum() + horizontal.abs().sum()

def generate_synthetic_digits(digit, count, dataset=None):
    """Draw ``count`` random images of class ``digit`` (with replacement).

    Args:
        digit: class label, as a Python int or a 0-d tensor on any device.
        count: number of images to draw; must be >= 1.
        dataset: indexable dataset with a ``targets`` attribute (tensor or
            list) whose items are ``(image, label)`` pairs. Defaults to the
            notebook's ``train_dataset``.

    Returns:
        Tensor of shape (count, *image_shape) stacked from the dataset images.

    Raises:
        ValueError: if ``count`` < 1 or the dataset has no sample of ``digit``.
    """
    if dataset is None:
        dataset = train_dataset
    if count < 1:
        raise ValueError(f"count must be >= 1, got {count}")

    label = int(digit)  # works for ints and 0-d tensors (CPU or GPU)
    targets = torch.as_tensor(dataset.targets).cpu().numpy()
    digit_indices = np.where(targets == label)[0]

    if len(digit_indices) == 0:
        raise ValueError(f"No samples found for label {label}")

    selected_indices = np.random.choice(digit_indices, count, replace=True)
    return torch.stack([dataset[int(idx)][0] for idx in selected_indices])

# Erode the input images (thins the digit strokes before the target digit is added)
def erode_images(images):
    """Apply one pass of 3x3 elliptical erosion (OpenCV) to every image in a batch.

    Args:
        images: tensor of shape (B, 1, H, W); each image's leading channel
            dim is squeezed before erosion.

    Returns:
        Tensor of shape (B, 1, H, W) on the same device as ``images``
        (previously always moved to CUDA, which failed on CPU-only hosts).
    """
    if images.size(0) == 0:
        return images.detach().clone()

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    eroded_images = [
        cv2.erode(image.squeeze(0).detach().cpu().numpy(), kernel, iterations=1)
        for image in images
    ]
    return torch.from_numpy(np.stack(eroded_images)).unsqueeze(1).to(images.device)

In [None]:
# Training loop: policy-gradient (REINFORCE-style) training of the UNet against the frozen VGG11 classifier.
dnn_criterion = nn.CrossEntropyLoss()


def compute_reward(images, actions, target_labels, epoch, alpha=0.5, beta=0.1,
                   batch_idx=0, save_every=50):
    """Score a batch of sampled images; a higher (less negative) reward is better.

    reward = -(reconstruction + alpha * classification + beta * perceptual), where

    * reconstruction: MSE between the sampled images and a target made of the
      eroded input plus a random training image of the target digit (clamped
      to [0, 1]);
    * classification: cross-entropy of the module-level ``dnn_model`` on the
      sampled images w.r.t. ``target_labels``;
    * perceptual: ``perceptual_loss`` between the inputs and the sampled images.

    Args:
        images: input batch; assumed (B, 1, H, W) in [0, 1] — TODO confirm for non-MNIST data.
        actions: sampled images with the same shape as ``images``. Floating
            tensors are used as-is (values in [0, 1]); integer tensors are
            treated as 0-255 pixel values and rescaled.
        target_labels: (B,) long tensor of desired classes.
        epoch: epoch index; selects the ``./output/<epoch>`` folder for samples.
        alpha, beta: weights of the classification and perceptual terms.
        batch_idx: batch index within the epoch, used in sample file names.
        save_every: save sample images every ``save_every`` batches (0 disables).

    Returns:
        Scalar reward tensor, detached from any autograd graph.
    """
    # Keep the channel dim so the images stay valid input for dnn_model.
    if actions.is_floating_point():
        output_images = actions
    else:
        output_images = actions.float() / 255

    # Generate target images
    eroded_images = erode_images(images)
    synthetic_target_digits = torch.cat(
        [generate_synthetic_digits(d, 1) for d in target_labels]
    ).to(images.device)
    # Eroded source + target digit can exceed 1.0; clamp into the pixel range the outputs can reach.
    target_images = (eroded_images + synthetic_target_digits).clamp(0, 1)

    reconstruction_loss = nn.functional.mse_loss(output_images, target_images)
    classification_loss = dnn_criterion(dnn_model(output_images), target_labels)
    p_loss = perceptual_loss(dnn_model, images, output_images)

    reward = -(reconstruction_loss + alpha * classification_loss + beta * p_loss)

    # Batch-based saving; the old (epoch+1) % 50 test never fired with num_epochs = 24.
    if save_every and (batch_idx + 1) % save_every == 0:
        save_image(images.data, './output/%03d/%04d_recon.png' % (epoch, batch_idx))
        save_image(output_images.data, './output/%03d/%04d_img.png' % (epoch, batch_idx))
        save_image(target_images.data, './output/%03d/%04d_target.png' % (epoch, batch_idx))

    # The reward is a constant w.r.t. the policy parameters in REINFORCE.
    return reward.detach()


if Train_Unet:
    # Initialize model and the frozen reward classifier
    model = UNet().to(DEVICE)
    dnn_model = VGG11().to(DEVICE)
    dnn_model.load_state_dict(torch.load(DNN_PATH, map_location=DEVICE))
    dnn_model.eval()  # disable Dropout so the reward is deterministic for a given sample

    # Freeze the DNN classifier weights
    for param in dnn_model.parameters():
        param.requires_grad = False

    # Hyperparameters
    alpha = 0.5
    beta = 0.3
    tv_weight = 0.01  # TODO: total_variation_regularization is not yet part of the reward

    criterion = nn.BCELoss()  # NOTE(review): not used by the policy-gradient update below
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    for epoch in range(num_epochs):
        print('e:' , epoch)
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            # Shift each label by a random offset in 1..9 so the target always differs
            # from the source and all nine other classes are reachable (randint's high is exclusive).
            offsets = torch.randint(1, 10, size=(labels.size(0),), device=DEVICE)
            target_labels = (labels + offsets) % 10

            # The UNet yields one probability per pixel; sample a binary image from it.
            action_probs = model(images, labels, target_labels)
            policy = torch.distributions.Bernoulli(probs=action_probs)
            actions = policy.sample()

            reward = compute_reward(images, actions, target_labels, epoch,
                                    alpha=alpha, beta=beta, batch_idx=i)

            # REINFORCE: maximise E[reward] by minimising -log pi(action) * reward
            log_probs = policy.log_prob(actions)
            policy_gradient = -torch.mean(log_probs * reward)

            # Update the model using the policy gradient
            optimizer.zero_grad()
            policy_gradient.backward()
            optimizer.step()

            if (i + 1) % 50 == 0:
                print("reward:", reward.item())

    torch.save(model.state_dict(), UNET_PATH)