In [1]:
!pip install git+https://github.com/brianbt/btorch

In [2]:
import torch
import btorch
from btorch import nn
import btorch.nn.functional as F
from btorch.vision.utils import UnNormalize
from torchvision import transforms, datasets
from tqdm import tqdm
from matplotlib import pyplot as plt

# Load the CIFAR-10 dataset

In [3]:
# Load CIFAR-10. Images are resized to 64x64 to match the DCGAN architecture
# below and normalized to [-1, 1], the same range as the generator's final Tanh.
# Only the trainset is augmented (random horizontal flip).
IMAGE_SIZE = 64
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform_train = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    normalize,
])
trainset = datasets.CIFAR10('./cifar10', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10('./cifar10', train=False, download=True, transform=transform_test)

# Optionally restrict training to a single class, e.g. ONLY_CLASS = 'dog'.
# The index is looked up by name to avoid magic numbers
# (index 7 is 'horse', not 'dog'). None keeps all ten classes.
ONLY_CLASS = None


def _keep_only_class(dataset, class_name):
    """Filter a torchvision CIFAR10 dataset in place down to one class."""
    class_idx = dataset.classes.index(class_name)
    mask = torch.tensor(dataset.targets, dtype=torch.long) == class_idx
    dataset.data = dataset.data[mask.numpy()]
    # Keep targets a plain list of ints, as torchvision stores them.
    dataset.targets = [class_idx] * int(mask.sum())


if ONLY_CLASS is not None:
    _keep_only_class(trainset, ONLY_CLASS)
    _keep_only_class(testset, ONLY_CLASS)

# Create GAN Model

## Generator and Discriminator

In [7]:
class Generator(nn.Module):
    """DCGAN generator: maps a latent noise vector to a 64x64 image in [-1, 1].

    Args:
        latent_dim: length of the noise vector z; inputs are shaped
            ``(batch, latent_dim, 1, 1)``.
        ngf: base number of generator feature maps. The default of 64
            reproduces the original hard-coded widths (and state_dict shapes).
        nc: number of output image channels (3 for RGB).
    """

    def __init__(self, latent_dim, ngf=64, nc=3):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.main = nn.Sequential(
            # input is Z: (latent_dim) x 1 x 1, going into a convolution
            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, x):
        """Map noise of shape (batch, latent_dim, 1, 1) to images (batch, nc, 64, 64)."""
        return self.main(x)

    def sample(self, batch_size=1):
        """Generate ``batch_size`` images from standard-normal noise on the model's device."""
        noise = torch.randn((batch_size, self.latent_dim, 1, 1), device=self.device())
        return self.forward(noise)

In [8]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 64x64 images with a probability of being real.

    Args:
        ndf: base number of discriminator feature maps. The default of 64
            reproduces the original hard-coded widths (and state_dict shapes).
        nc: number of input image channels (3 for RGB).

    ``forward`` maps ``(batch, nc, 64, 64)`` images to ``(batch, 1, 1, 1)``
    scores in [0, 1] (final Sigmoid), as expected by ``nn.BCELoss``.
    """

    def __init__(self, ndf=64, nc=3):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            # output: 1 x 1 x 1 probability per image
        )

    def forward(self, x):
        """Return real/fake probabilities of shape (batch, 1, 1, 1)."""
        return self.main(x)
        

## Let's test the Generator and Discriminator outputs

In [9]:
# Smoke test: print per-layer summaries for a batch of 16, then push a single
# generated sample through the discriminator and show the score's shape.
g = Generator(100)
print(g.summary(input_size=(16, 100, 1, 1)))

d = Discriminator()
print(d.summary(input_size=(16, 3, 64, 64)))

print(d(g.sample(batch_size=1)).shape)

## Create GAN Module

In [10]:
class GAN(nn.Module):
    """DCGAN wrapper holding a generator ``g`` and a discriminator ``d``.

    ``forward`` scores images with the discriminator and ``sample`` draws
    images from the generator. The classmethods below override the per-epoch
    hooks that btorch's ``fit`` presumably calls (train/test/before-epoch).
    """

    # 0-based epoch index from which D is trained (and its test loss computed).
    # Raise it to let G warm up first; 0 trains D from the very first epoch.
    d_start_epoch = 0

    def __init__(self, latent_dim):
        super(GAN, self).__init__()
        self.latent_dim = latent_dim
        self.g = Generator(latent_dim)
        self.d = Discriminator()

    def forward(self, x):
        """Return the discriminator's real/fake score for images ``x``."""
        return self.d(x)

    def sample(self, batch_size=1):
        """Generate ``batch_size`` images with the generator."""
        return self.g.sample(batch_size)

    @staticmethod
    def _mean(total, count):
        """Average ``total`` over ``count`` batches; NaN when there were no batches."""
        return total / count if count else float('nan')

    @staticmethod
    def _generator_loss(net, criterion, fake_inputs):
        """G's loss: how far D is from classifying ``fake_inputs`` as real (label 1)."""
        fool_predicted = net.d(fake_inputs).view(-1)
        return criterion(fool_predicted, torch.ones_like(fool_predicted))

    @staticmethod
    def _discriminator_loss(net, criterion, real_inputs, fake_inputs):
        """D's loss: real images labelled 1 plus fake images labelled 0 (summed).

        Shared by ``train_epoch`` and ``test_epoch`` so both report the same quantity.
        """
        real_predicted = net.d(real_inputs).view(-1)
        D_real_loss = criterion(real_predicted, torch.ones_like(real_predicted))
        fake_predicted = net.d(fake_inputs).view(-1)
        D_fake_loss = criterion(fake_predicted, torch.zeros_like(fake_predicted))
        return D_real_loss + D_fake_loss

    @classmethod
    def train_epoch(cls, net, criterion, trainloader, optimizer, epoch_idx, device='cuda', config=None, **kwargs):
        """Train G, then D, on every batch of ``trainloader`` for one epoch.

        Args:
            net: the GAN (uses ``net.g``, ``net.d`` and ``net.sample``).
            criterion: loss on probabilities, e.g. ``nn.BCELoss()``.
            trainloader: iterable of ``(images, labels)`` batches; labels are ignored.
            optimizer: dict with ``'G'`` and ``'D'`` optimizers.
            epoch_idx: 0-based epoch index (shown 1-based in the progress bar).
            device: device the real images are moved to.
            config: unused here; accepted for interface compatibility.
            **kwargs: ``verbose=0`` hides the progress bar.

        Returns:
            dict: mean ``'D_loss'`` and ``'G_loss'`` per batch (NaN for an empty
            loader). ``D_loss`` is real-loss + fake-loss, the same as ``test_epoch``.
        """
        net.g.train()
        net.d.train()
        train_d = epoch_idx >= cls.d_start_epoch
        G_loss = 0.0
        D_loss = 0.0
        G_curr_loss = 0.0
        D_curr_loss = float('nan')  # shown as NaN until D has been trained once
        n_batches = 0
        pbar = tqdm(enumerate(trainloader), total=len(trainloader), disable=(kwargs.get("verbose", 1) == 0))
        for batch_idx, (inputs, _) in pbar:
            n_batches += 1

            # Train G ###############################
            optimizer['G'].zero_grad()
            fake_inputs = net.sample(inputs.shape[0])
            G_fool_loss = cls._generator_loss(net, criterion, fake_inputs)
            # This backward also deposits gradients in D's parameters; they are
            # zeroed by optimizer['D'].zero_grad() before any D update.
            G_fool_loss.backward()
            optimizer['G'].step()
            G_curr_loss = G_fool_loss.item()
            G_loss += G_curr_loss

            # Train D ###############################
            if train_d:
                optimizer['D'].zero_grad()
                # Detach the fakes so D's loss does not backpropagate into G.
                D_batch_loss = cls._discriminator_loss(
                    net, criterion, inputs.to(device), fake_inputs.detach())
                D_batch_loss.backward()
                optimizer['D'].step()
                D_curr_loss = D_batch_loss.item()
                D_loss += D_curr_loss

            pbar.set_description(
                f"epoch {epoch_idx+1} iter {batch_idx}: D loss {D_curr_loss:.5f}, G loss {G_curr_loss:.5f}.")
        return {'D_loss': cls._mean(D_loss, n_batches), 'G_loss': cls._mean(G_loss, n_batches)}

    @classmethod
    def before_each_train_epoch(cls, net, criterion, optimizer, trainloader, testloader=None, epoch_idx=0, lr_scheduler=None, config=None, **kwargs):
        """Append G's output for the fixed ``config['evol_seed']`` noise to ``config['evol']``.

        G runs in eval mode under ``no_grad``, so BatchNorm running statistics
        are not updated and no autograd graph is kept alive. G's previous
        train/eval mode is restored afterwards. Snapshots are stored on the CPU
        so they do not accumulate in accelerator memory over the run.
        """
        was_training = net.g.training
        net.g.eval()
        try:
            with torch.no_grad():
                snapshot = net.g(config['evol_seed'])
        finally:
            net.g.train(was_training)
        config['evol'].append(snapshot.cpu())

    @classmethod
    def test_epoch(cls, net, criterion, testloader, epoch_idx=0, scoring=None, device='cuda', config=None, **kwargs):
        """Evaluate G and D on ``testloader`` without updating any weights.

        Returns:
            dict: mean ``'D_loss'`` and ``'G_loss'`` per batch (NaN for an empty
            loader), computed exactly like ``train_epoch``'s losses so the
            train and test curves are directly comparable.
        """
        net.g.eval()
        net.d.eval()
        eval_d = epoch_idx >= cls.d_start_epoch
        G_loss = 0.0
        D_loss = 0.0
        n_batches = 0
        with torch.inference_mode():
            for inputs, _ in testloader:
                n_batches += 1
                # Test G ###############################
                fake_inputs = net.sample(inputs.shape[0])
                G_loss += cls._generator_loss(net, criterion, fake_inputs).item()

                # Test D ###############################
                if eval_d:
                    D_loss += cls._discriminator_loss(
                        net, criterion, inputs.to(device), fake_inputs).item()
        return {'D_loss': cls._mean(D_loss, n_batches), 'G_loss': cls._mean(G_loss, n_batches)}

## Initialize weights as in the DCGAN paper: $\mathcal{N}(0, 0.02)$ for conv layers, $\mathcal{N}(1, 0.02)$ for BatchNorm scales

In [11]:
def weights_init(m):
    """DCGAN weight initialization, applied per module via ``Module.apply``.

    Modules whose class name contains 'Conv' get weights ~ N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights ~ N(1, 0.02)
    and a zero bias. Modules that match by name but have no weight tensor
    (e.g. ``BatchNorm2d(affine=False)`` or a container class named
    ``...Conv...``) are skipped instead of raising AttributeError.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    # nn.init functions already run under no_grad, so no `.data` access is needed.
    if 'Conv' in classname:
        nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        nn.init.normal_(weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias, 0)

## Initialize the model

In [12]:
from btorch.utils.load_save import save_model, resume

# Reproducibility: seed weight initialization and the fixed evolution noise.
SEED = 0
torch.manual_seed(SEED)

# Model
latent_dim = 128
gan = GAN(latent_dim)
gan.g.apply(weights_init)
gan.d.apply(weights_init)

# Loss & Optimizer
gan._lossfn = nn.BCELoss()
gan._optimizer = {'D': torch.optim.Adam(gan.d.parameters(), lr=0.0002, betas=(0.5, 0.999)),
                  'G': torch.optim.Adam(gan.g.parameters(), lr=0.0002, betas=(0.5, 0.999))}

# Training config
gan._config['max_epoch'] = 50
gan._config['val_freq'] = 1
# gan._config['save'] = './checkpoints/'
# gan._config['save_every_epoch_checkpoint'] = 20
# gan._config['save_base_on'] = 'G_loss'
# gan._config['tensorboard'] = '500epochs'

# Set GPU first, so tensors that must live next to the model can use its device.
gan.auto_gpu()

# Fixed noise: G's output on it is appended to 'evol' before every epoch
# (see GAN.before_each_train_epoch) to visualize training progress.
# Using gan.device() instead of a hard-coded 'cuda' keeps CPU-only runs working.
gan._config['evol'] = []
gan._config['evol_seed'] = torch.randn((25, latent_dim, 1, 1), device=gan.device())

## Training

In [13]:
# Train for gan._config['max_epoch'] epochs, validating on the testset.
gan.fit(
    trainset,
    validation_data=testset,
    batch_size=512,
    drop_last=True,
    verbose=1,
    workers=4,
)

## Plot the losses

In [14]:
import pandas as pd

# Loss records collected by fit(); presumably one record per epoch
# (val_freq=1), each the dict returned by train_epoch / test_epoch.
history = gan._history[0]
fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
pd.DataFrame(history['train_loss_data']).plot(ax=ax_train, title='train_loss')
pd.DataFrame(history['test_loss_data']).plot(ax=ax_test, title='test_loss')
for ax in (ax_train, ax_test):
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
plt.show()

# Generated Images

In [15]:
# gan.eval() switches g and d (its submodules) to eval mode, so BatchNorm uses
# running statistics. no_grad avoids building an autograd graph for inference.
gan.eval()
with torch.no_grad():
    generated = gan.sample(1)
    btorch.vision.utils.pplot(btorch.vision.utils.img_MinMaxScaler(generated))
    print(gan.d(generated))

# Discriminator on a real image

In [16]:
gan.eval()
# First training image (the train transform may flip it). Indexing [0][0]
# selected only its first colour channel; instead keep all channels and add a
# batch dimension, matching the (1, C, H, W) sample plotted in the cell above.
real_image = trainset[0][0].unsqueeze(0)
with torch.no_grad():
    btorch.vision.utils.pplot(btorch.vision.utils.img_MinMaxScaler(real_image))
    print(gan.d(real_image.to(gan.device()))[0])

# Evolution of the Generator

In [18]:
# Show every 5th snapshot of G's output on the fixed evolution noise.
# Snapshot k is presumably taken at the start of epoch k+1, i.e. after k epochs.
evol = gan._config['evol']
for idx in range(0, len(evol), 5):
    print(f"epoch {idx}")
    btorch.vision.utils.pplot(btorch.vision.utils.img_MinMaxScaler(evol[idx]))