In [1]:
import torch
import os
from tqdm import trange
import torchvision
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score
import time

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

In [2]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [3]:
path="/content/drive/MyDrive/DSLab2"
os.chdir(path)
os.listdir(path)

['utils.py', 'generate.py', 'data', 'checkpoints_kl', 'GAN_with_FID.ipynb']

In [4]:
from utils import save_models, load_model

In [5]:
class Args:
  def __init__(self, epochs=100, lr=0.0002, batch_size=64):
    self.epochs=epochs
    self.lr=lr
    self.batch_size=batch_size


def build_data_loader(batch_size):
  # Data Pipeline
  print('Dataset loading...')
  # MNIST Dataset
  transform = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize(mean=(0.5), std=(0.5))])

  train_dataset = datasets.MNIST(root='data/MNIST/', train=True, transform=transform, download=True)
  test_dataset = datasets.MNIST(root='data/MNIST/', train=False, transform=transform, download=False)

  train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=True)
  test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size, shuffle=False)
  print('Dataset Loaded.')

  return train_loader, test_loader

In [6]:
args = Args()
train_loader, test_loader = build_data_loader(args.batch_size)

Dataset loading...
Dataset Loaded.


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    def __init__(self, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))

class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return self.fc4(x)
        # return torch.sigmoid(self.fc4(x))

# f-GAN



In [24]:
import torch

class F_divergence:
    def __init__(self, name):
        self.name = name

        if name == 'JS':
            self.fdiv = lambda t: torch.log(torch.ones(t.shape, device=t.device) * 2) - torch.log(1 + torch.exp(-t))
            self.fenchel = lambda t: -torch.log(2 - torch.exp(t))
            self.threshold = 0

        elif name == 'KL':
            self.fdiv = lambda t: t
            self.fenchel = lambda t: torch.exp(t - 1)
            self.threshold = 1

        elif name == 'RKL':
            self.fdiv = lambda t: -torch.exp(-t)
            self.fenchel = lambda t: -1 - torch.log(-t)
            self.threshold = -1

        else:
            raise ValueError(f"Unknown divergence type: {name}")

    def activation(self, t):
        return self.fdiv(t)

    def f_star(self, t):
        return self.fenchel(t)


In [18]:
def D_train(x, G, D, D_optimizer, fdiv):
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real
    x_real, y_real = x, torch.ones(x.shape[0], 1)
    x_real, y_real = x_real.cuda(), y_real.cuda()

    D_real_output = fdiv.activation(D(x_real))
    D_real_loss = torch.mean(D_real_output)

    # train discriminator on facke
    z = torch.randn(x.shape[0], 100).cuda()
    x_fake, y_fake = G(z), torch.zeros(x.shape[0], 1).cuda()

    D_fake_output = fdiv.activation(D(x_fake))
    D_fake_loss = -torch.mean(fdiv.f_star(D_fake_output))

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss

    real_correct = (D_real_output >= fdiv.threshold).float().mean().item()
    fake_correct = (D_fake_output < fdiv.threshold).float().mean().item()
    D_accuracy = 0.5 * (real_correct + fake_correct)

    (-D_loss).backward()
    D_optimizer.step()

    return  D_loss.data.item(), D_accuracy


def G_train(x, G, G_optimizer, fdiv):
    #=======================Train the generator=======================#
    G.zero_grad()

    z = torch.randn(x.shape[0], 100).cuda()

    G_output = fdiv.activation(G(z))
    G_loss = torch.mean(fdiv.f_star(G_output))

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()

    return G_loss.data.item()


# Training

In [26]:
# Load Model
print('Model Loading...')
mnist_dim = 784
G = torch.nn.DataParallel(Generator(g_output_dim = mnist_dim)).cuda()
D = torch.nn.DataParallel(Discriminator(mnist_dim)).cuda()
print('Model loaded.')

Model Loading...
Model loaded.


In [27]:
# Define f-divergence
fdiv = F_divergence('KL')  # Choose the f-divergence type

G_optimizer = optim.Adam(G.parameters(), lr=args.lr)
D_optimizer = optim.Adam(D.parameters(), lr=args.lr)

# Lists to store loss and accuracy values for plotting
D_losses = []
G_losses = []
D_accuracies = []

# Training loop
print('Start Training:')
n_epoch = 10
for epoch in trange(1, n_epoch + 1, leave=True):
    D_epoch_loss = 0
    G_epoch_loss = 0
    D_epoch_accuracy = 0
    batch_count = 0

    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.view(-1, mnist_dim)
        G_loss = G_train(x, G, G_optimizer, fdiv)
        D_loss, D_accuracy = D_train(x, G, D, D_optimizer, fdiv)

        D_epoch_loss += D_loss
        G_epoch_loss += G_loss
        D_epoch_accuracy += D_accuracy
        batch_count += 1

    # Average the losses and accuracy for this epoch
    D_losses.append(D_epoch_loss / batch_count)
    G_losses.append(G_epoch_loss / batch_count)
    D_accuracies.append(D_epoch_accuracy / batch_count)

    # Save models periodically
    if epoch % 10 == 0:
        timestamp = time.time()
        save_models(G, D, 'checkpoints_kl')

print('Training done')

print('D_losses:', D_losses)
print('G_losses:', G_losses)
print('D_accuracies:', D_accuracies)

Start Training:


100%|██████████| 10/10 [02:58<00:00, 17.85s/it]

Training done
D_losses: [65.95773908721486, -885900.198186848, 108.68676505668331, 108.54585565026127, 108.79338639631455, 108.71101897103446, 108.58577578408378, 108.99024419235522, 109.39752266605271, 110.08816882135517]
G_losses: [0.13871643422191332, 0.135336252433786, 0.1353356159889876, 0.13533542564174514, 0.1353353451151075, 0.13533530309637473, 0.13533528557400715, 0.13533527450139587, 0.1353352619354913, 0.1353352546278856]
D_accuracies: [0.777393723347548, 0.8755996801705757, 0.6803787979744137, 0.6803371535181236, 0.6803954557569296, 0.6803538113006397, 0.6803454824093816, 0.6803621401918977, 0.6803954557569296, 0.6804537579957356]





# Generate

In [None]:
# Load Model
print('Model Loading...')
model = Generator(g_output_dim=mnist_dim).cuda()
model = load_model(model, 'checkpoints_js')
model = torch.nn.DataParallel(model).cuda()
model.eval()
print('Model loaded.')

Model Loading...
Model loaded.


  ckpt = torch.load(os.path.join(folder,'G.pth'))


In [None]:
# Gernerate Samples
sample_path = 'samples_f'
os.makedirs(sample_path, exist_ok=True)

n_samples = 0
with torch.no_grad():
    while n_samples<10000:
        z = torch.randn(args.batch_size, 100).cuda()
        x = model(z)
        x = x.reshape(args.batch_size, 28, 28)
        for k in range(x.shape[0]):
            if n_samples<10000:
                torchvision.utils.save_image(x[k:k+1], os.path.join(sample_path, f'{n_samples}.png'))
                n_samples += 1
