In [None]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models


import glob
from net import VGG11, UNet
from utils import write_log, loss_info, find_nearest_neighbor, build_features_dict

In [None]:
# Number of training epochs; one output sub-directory is prepared per epoch.
num_epochs = 20

# exist_ok=True makes the cell idempotent: no exists()/mkdir() race and no
# error when the notebook is re-run.
os.makedirs("./output", exist_ok=True)

for epoch in range(num_epochs):
    epoch_dir = "./output/%03d" % epoch
    os.makedirs(epoch_dir, exist_ok=True)

    # Remove PNGs left over from a previous run so each epoch directory starts
    # clean (a freshly created directory simply yields no matches).
    for f in glob.glob("./output/%03d/*.png" % epoch):
        os.remove(f)

In [None]:
# ---- Paths and hyper-parameters ----
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)  # no error if the directory already exists

latent_size = 100
image_channels = 1
distance_reg_weight = 1
BATCH_SIZE = 512

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- MNIST datasets and loaders ----
# Both splits share the same preprocessing: Resize(32) followed by ToTensor().
_mnist_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])

# Settings common to both loaders: fixed sample order, incomplete last batch dropped.
# NOTE(review): the training loader is not shuffled and the test loader drops its
# final partial batch -- confirm both are intentional.
_loader_options = dict(batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

train_dataset = datasets.MNIST('data', train=True, download=True, transform=_mnist_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, **_loader_options)

test_dataset = datasets.MNIST('data', train=False, transform=_mnist_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, **_loader_options)

In [None]:
# Convolutional stack of an ImageNet-pretrained VGG16, in eval mode, handed to
# build_features_dict together with the training set.
# Placed with .to(device) rather than .cuda() so the cell also runs on CPU-only
# machines, matching the `device` value passed to build_features_dict below.
# The misspelled name `pre_trian_vgg` is kept because other cells may refer to it.
# NOTE(review): `pretrained=True` is deprecated in recent torchvision releases in
# favour of the `weights=` argument; kept for compatibility with older versions.
pre_trian_vgg = models.vgg16(pretrained=True).features.eval().to(device)
train_features_dict = build_features_dict(train_dataset, pre_trian_vgg, BATCH_SIZE, device)

In [None]:
# Networks: one UNet plus two VGG11 classifiers (cnn_real / cnn_fake).
# .to(device) instead of .cuda() keeps the cell runnable on CPU-only machines and
# consistent with the `device` selected in the configuration cell.
unet = UNet().to(device)
cnn_real = VGG11().to(device)
cnn_fake = VGG11().to(device)

# One Adam optimizer per network, all with the same learning rate.
unet_optimizer = optim.Adam(unet.parameters(), lr=1e-4)
cnn_real_optimizer = optim.Adam(cnn_real.parameters(), lr=1e-4)
cnn_fake_optimizer = optim.Adam(cnn_fake.parameters(), lr=1e-4)

In [None]:
# Loss functions used by the training cells.
# criterion_CNN: cross-entropy between class logits and integer class labels.
criterion_CNN = torch.nn.CrossEntropyLoss()
# criterion_unet: binary cross-entropy; inputs and targets must lie in [0, 1].
# NOTE(review): presumably the UNet output is already squashed to [0, 1] -- confirm.
criterion_unet = torch.nn.BCELoss()

In [None]:
def train_unet(unet, images, target_labels, input_labels, criterion_unet):
    """Compute the UNet loss against nearest-neighbour target images.

    The UNet is run on ``(images, input_labels, target_labels)``. A target batch
    is then built by cloning ``images`` and overwriting every slot ``j`` with
    ``find_nearest_neighbor(images, target_labels[j], train_features_dict)``
    (``train_features_dict`` is read from module scope).

    Args:
        unet: callable taking ``(images, input_labels, target_labels)``.
        images: input image batch.
        target_labels: labels the outputs should be translated to.
        input_labels: labels of ``images``; its first dimension sets how many
            target slots are filled.
        criterion_unet: loss comparing the UNet output with the target batch.

    Returns:
        ``(loss, target_images)`` where ``loss`` is the criterion value scaled
        by 10. No optimizer step is taken here.
    """
    predictions = unet(images, input_labels, target_labels)

    batch_size = input_labels.size(0)
    target_images = images.clone()
    # NOTE(review): the whole batch `images` (not `images[idx]`) is handed to
    # find_nearest_neighbor -- confirm this matches the helper's contract.
    for idx in range(batch_size):
        wanted_label = target_labels[idx]
        target_images[idx] = find_nearest_neighbor(images, wanted_label, train_features_dict)

    raw_loss = criterion_unet(predictions, target_images)
    return 10 * raw_loss, target_images

def _weight_cosine_similarity(trainable_net, reference_net):
    """Sum the cosine similarity between positionally paired parameters.

    Only parameters with more than one dimension are compared. The parameters
    of ``reference_net`` are detached, so back-propagating through the result
    produces gradients for ``trainable_net`` only.

    Returns a 0-dim tensor, or ``None`` if no parameter pair qualified.
    """
    total = None
    for p_train, p_ref in zip(trainable_net.parameters(), reference_net.parameters()):
        if p_train.dim() > 1:
            similarity = torch.nn.functional.cosine_similarity(
                p_train.reshape(1, -1), p_ref.detach().reshape(1, -1)
            ).mean()
            total = similarity if total is None else total + similarity
    return total


def train(unet, cnn_fake, cnn_real, unet_optimizer, cnn_fake_optimizer, cnn_real_optimizer, 
          criterion_CNN, criterion_unet,
          distance_reg_weight, num_epochs):
    """Jointly train the UNet and the two classifiers for ``num_epochs`` epochs.

    Per batch of the module-level ``train_loader``:
      1. draw target labels that differ from the input labels,
      2. update the UNet with the loss returned by ``train_unet``,
      3. update ``cnn_fake`` on the UNet outputs and ``cnn_real`` on the real
         images, both against the input labels,
      4. apply a weight-space regulariser: ``cnn_fake`` is stepped to increase
         and ``cnn_real`` to decrease the cosine similarity of their weights.

    Also relies on module-level ``device``, ``loss_info``, ``write_log`` and
    ``log_dir``.

    Args:
        unet, cnn_fake, cnn_real: the networks to train.
        unet_optimizer, cnn_fake_optimizer, cnn_real_optimizer: their optimizers.
        criterion_CNN: classification loss for both classifiers.
        criterion_unet: loss forwarded to ``train_unet``.
        distance_reg_weight: scale of the weight-similarity regulariser.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        The ``stats`` dict. ``unet_losses`` / ``fake_losses`` / ``real_losses``
        hold one last-batch loss per epoch; ``unet_loss`` / ``fake_loss`` /
        ``real_loss`` hold the values of the most recent epoch.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches in an epoch.
    """
    stats = {
        'unet_losses': [], 'fake_losses': [], 'real_losses': []
    }

    for epoch in range(num_epochs):
        cnn_fake.train()
        cnn_real.train()
        unet.train()

        step = 0

        for (images, input_labels) in train_loader:
            images, input_labels = images.to(device), input_labels.to(device)

            # Target labels differ from the input labels: add an offset in 1..9
            # modulo 10. torch.randint's upper bound is exclusive, so it must be
            # 10 for every other class to be reachable. Sampling on the labels'
            # device avoids a hard CUDA dependency.
            offsets = torch.randint(1, 10, size=(input_labels.size(0),),
                                    device=input_labels.device)
            target_labels = (input_labels + offsets) % 10

            # Train the UNet
            unet_loss, target_images = train_unet(unet, images, target_labels, input_labels, criterion_unet)

            unet_optimizer.zero_grad()
            unet_loss.backward()
            unet_optimizer.step()

            # Generate synthetic samples with the updated UNet. Detached: the
            # classifier loss must not back-propagate into the UNet (those
            # gradients were computed and then discarded before).
            recon_img = unet(images, input_labels, target_labels).detach()

            # Train cnn_fake on synthetic samples
            recon_y_pred = cnn_fake(recon_img)
            fake_loss = criterion_CNN(recon_y_pred, input_labels)  # Use real labels as targets
            cnn_fake_optimizer.zero_grad()
            fake_loss.backward()
            cnn_fake_optimizer.step()

            # Train cnn_real on real samples
            real_y_pred = cnn_real(images)
            real_loss = criterion_CNN(real_y_pred, input_labels)
            cnn_real_optimizer.zero_grad()
            real_loss.backward()
            cnn_real_optimizer.step()

            # Weight-space regulariser. Each network gets its own objective with
            # the other network's weights detached, and both backward passes run
            # before either optimizer steps. This avoids back-propagating through
            # parameters already modified in place, and keeps stale real_loss
            # gradients out of the cnn_real step.
            sim_for_fake = _weight_cosine_similarity(cnn_fake, cnn_real)
            sim_for_real = _weight_cosine_similarity(cnn_real, cnn_fake)
            if sim_for_fake is not None:
                cnn_fake_optimizer.zero_grad()
                cnn_real_optimizer.zero_grad()
                (-distance_reg_weight * sim_for_fake).backward()  # cnn_fake: raise similarity
                (distance_reg_weight * sim_for_real).backward()   # cnn_real: lower similarity
                cnn_fake_optimizer.step()
                cnn_real_optimizer.step()
                distance_metric = sim_for_fake.detach()
            else:
                # No multi-dimensional parameters to compare.
                distance_metric = torch.zeros((), device=images.device)

            # Report every second step.
            if ((step + 1) % 2 == 0):
                fake_img = recon_img[0].detach().cpu().numpy().squeeze()
                real_img = images[0].detach().cpu().numpy().squeeze()
                target_img = target_images[0].detach().cpu().numpy().squeeze()
                loss_info(step, len(train_loader), epoch, unet_loss, fake_loss, real_loss, distance_metric, fake_img, real_img, target_img)

            step += 1

        if step == 0:
            raise RuntimeError("train_loader yielded no batches")

        # Record the last-batch losses: append to the per-epoch histories and
        # keep the scalar keys that were already being written.
        epoch_losses = {
            'unet_loss': unet_loss.item(),
            'fake_loss': fake_loss.item(),
            'real_loss': real_loss.item(),
        }
        stats['unet_losses'].append(epoch_losses['unet_loss'])
        stats['fake_losses'].append(epoch_losses['fake_loss'])
        stats['real_losses'].append(epoch_losses['real_loss'])
        stats.update(epoch_losses)

        write_log(epoch, num_epochs, stats, log_dir)

    return stats