# GAN trained on data from a mixture of GRFs on a 2D domain

The inputs to the generative model are samples of a Gaussian random field (GRF).
The training data is sampled from a mixture of GRFs.

In [None]:
import torch
import numpy as np
import pylab as plt
import torch.nn.functional as F
import torch.nn as nn
from random_fields import *
from scipy.stats import binned_statistic

#### Parameters

In [None]:
ndim = 60 - 8  # spatial resolution (fields are sampled on an ndim x ndim grid)
in_value_dim = 1  # dimension of the co-domain of input functions (number of channels)
ntrain = 8192  # number of training samples
width = 32  # base channel width of the CNNs
lr = 0.5 * 1e-5  # learning rate of both Adam optimizers
# Prefer the third GPU as before; fall back to the default GPU or the CPU when it
# does not exist, instead of failing with an invalid device ordinal.
if torch.cuda.device_count() > 2:
    device = 'cuda:2'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
epochs = 400
λ_grad = 10.0  # weight (Lagrange coefficient) of the gradient penalty in the critic loss
n_critic = 15  # the generator is updated once every n_critic critic iterations
batch_size = 128


In [None]:
# Normalization with a single global (scalar) mean and standard deviation.
class InputNormalizer(object):
    """Standardize data using one scalar mean and one scalar std.

    The statistics are computed per sample over dims (1, 2, 3) and then averaged
    over the sample dimension, so ``x`` must have at least 4 dimensions
    (presumably (batch, H, W, channels) as elsewhere in this notebook - TODO confirm).
    """

    def __init__(self, x, eps=0.00001):
        super().__init__()
        # Per-sample mean/std over dims 1-3, averaged over dim 0 -> 0-d tensors.
        self.mean = torch.mean(x, dim=(1, 2, 3)).mean(dim=0)
        self.std = torch.std(x, dim=(1, 2, 3)).mean(dim=0)
        self.eps = eps

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`.

        Args:
            x: standardized tensor.
            sample_idx: optional index selecting entries of non-scalar statistics.
                The statistics built by ``__init__`` are 0-d and shared by every
                sample, so in that case ``sample_idx`` is ignored (indexing a 0-d
                tensor would raise).

        Raises:
            ValueError: if ``sample_idx`` has more dimensions than the statistics.
        """
        if sample_idx is None or self.mean.dim() == 0:
            std = self.std + self.eps
            mean = self.mean
        elif len(self.mean.shape) == len(sample_idx[0].shape):
            std = self.std[sample_idx] + self.eps
            mean = self.mean[sample_idx]
        elif len(self.mean.shape) > len(sample_idx[0].shape):
            std = self.std[:, sample_idx] + self.eps
            mean = self.mean[:, sample_idx]
        else:
            raise ValueError("sample_idx has more dimensions than the normalizer statistics")
        return x * std + mean

    def cuda(self, target=None):
        """Move the statistics to ``target`` (default: the notebook-level ``device``)."""
        target = device if target is None else target
        self.mean = self.mean.to(target)
        self.std = self.std.to(target)
        return self

    def cpu(self):
        """Move the statistics to the CPU."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self
        
def compute_acovf(z):
    """Radially binned circular autocovariance of a batch of 2-D fields.

    Args:
        z: tensor of shape (batch, H, W); callers pass fields with the trailing
            channel dimension squeezed away.

    Returns:
        (lags, means): the 49 left bin edges of ``np.linspace(0, max(H, W), 50)``
        and the mean autocovariance of all lags falling in each radial bin
        (NaN for bins that contain no lag).
    """
    h, w = z.shape[-2], z.shape[-1]
    z_hat = torch.fft.rfft2(z)
    # Circular autocorrelation = inverse FFT of the power spectrum. Passing `s`
    # keeps odd sizes correct (irfft2 otherwise returns an even last dimension).
    acf = torch.fft.irfft2(torch.conj(z_hat) * z_hat, s=(h, w))
    # Move zero lag to the centre of the spatial dims only, average over the
    # batch and normalize by the number of grid points per field.
    acf = torch.fft.fftshift(acf, dim=(-2, -1)).mean(dim=0) / z[0].numel()
    acf_r = acf.reshape(-1).cpu().detach().numpy()

    lags_x, lags_y = torch.meshgrid(
        torch.arange(h) - h // 2, torch.arange(w) - w // 2, indexing='ij'
    )
    lags_r = torch.sqrt((lags_x**2 + lags_y**2).float()).reshape(-1).numpy()

    bin_means, bin_edges, _ = binned_statistic(
        lags_r, acf_r, 'mean', bins=np.linspace(0.0, max(h, w), 50)
    )
    return bin_edges[:-1], bin_means

### Samples of input GRF

In [None]:
# Build the GRF whose samples are fed to the generator, and show four draws.
grf = GaussianRF_idct(2, ndim, alpha=1.5, tau=1.0, device=device)
z = grf.sample(10000)
print(z.mean(), z.std())
z = z.detach().cpu()
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
for i, panel in enumerate(ax):
    bar = panel.imshow(z[i])
plt.colorbar(bar)
plt.show()

### Samples of training data (mixture of GRFs)

In [None]:
# Training data: a two-component GRF mixture obtained by shifting half of the
# samples by +1 and the other half by -1, shuffling, then standardizing globally.
grf_x = GaussianRF_idct(2, ndim, alpha=3.0, tau=5.0, device=device)
x_train = grf_x.sample(ntrain)
x_train[:ntrain//2] += 1.0
x_train[ntrain//2:] += -1.0
idx = torch.randperm(x_train.shape[0])
x_train = x_train[idx]
print(x_train.mean(), x_train.std())
x_train -= x_train.mean()
x_train /= x_train.std()
# Add a trailing channel dim (presumably giving (ntrain, ndim, ndim, 1) - TODO confirm).
x_train = x_train.detach().cpu().unsqueeze(-1)
fig, ax = plt.subplots(1,4, figsize=(16,4))
for i in range(4):
    bar = ax[i].imshow(x_train[i])
plt.colorbar(bar)
plt.show()

# Pointwise value histogram of the data (also reused as the training reference).
x_hist = x_train.view(-1).cpu().detach().numpy()
plt.hist(x_hist, bins=100, histtype='step', density=True)
plt.show()

# Reference radial autocovariance of the data.
lags_ref, acf_ref = compute_acovf(x_train.squeeze())
plt.plot(lags_ref, acf_ref)
plt.show()

# Reshuffle every epoch: with shuffle=False every epoch saw identical minibatches,
# and the generator steps always landed on the same batches.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train), batch_size=batch_size, shuffle=True
)

In [None]:
# Preview of the untrained networks. `G`/`D` were never defined; the networks are
# `G_net`/`D_net`, which are built in a later cell, so skip this preview when it
# runs before they exist instead of raising NameError.
if 'G_net' in globals() and 'D_net' in globals():
    with torch.no_grad():  # inference only: no autograd graph needed
        z = grf.sample(4).unsqueeze(-1)
        x = G_net(z)
        fig, ax = plt.subplots(1, 4, figsize=(16, 4))
        for i in range(4):
            ax[i].imshow(x[i].cpu().numpy())
        plt.show()
        print(D_net(x))

In [None]:
def weights_init(m):
    """DCGAN-style initializer, meant for ``module.apply(weights_init)``.

    * Modules whose class name contains ``'Conv'``: weight ~ N(0, 0.02); bias untouched.
    * Modules whose class name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.
    * Everything else is left unchanged.

    Missing or ``None`` parameters (e.g. ``BatchNorm(affine=False)`` or a container
    class named like ``ConvBlock``) are skipped instead of raising AttributeError.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)



### Generator and discriminator networks

In [None]:

class Generator_net(nn.Module):
    """Resolution-preserving convolutional generator.

    Input and output are channels-last tensors of shape (batch, H, W, C). Every
    convolution has stride 1 and padding (kernel_size - 1) / 2, so H and W are
    unchanged. ReLU follows the first four convolutions; the last one is linear.

    Args:
        in_channels: C, the co-domain dimension of the fields. Defaults to the
            notebook-level ``in_value_dim``.
        base_width: channel width of the first hidden layer. Defaults to the
            notebook-level ``width``.
    """

    def __init__(self, in_channels=None, base_width=None):
        super().__init__()
        c = in_value_dim if in_channels is None else in_channels
        w = width if base_width is None else base_width
        # Attribute names (including the jump Conv4 -> Conv7) are kept so that
        # state_dicts saved from earlier versions of this notebook still load.
        self.Conv1 = nn.Conv2d(c, w, kernel_size=7, stride=1, padding=3)
        self.Conv2 = nn.Conv2d(w, 2 * w, kernel_size=7, stride=1, padding=3)
        self.Conv3 = nn.Conv2d(2 * w, 2 * w, kernel_size=7, stride=1, padding=3)
        self.Conv4 = nn.Conv2d(2 * w, w, kernel_size=7, stride=1, padding=3)
        self.Conv7 = nn.Conv2d(w, c, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)  # channels-last -> channels-first for Conv2d
        for conv in (self.Conv1, self.Conv2, self.Conv3, self.Conv4):
            x = F.relu(conv(x))
        x = self.Conv7(x)
        return x.permute(0, 2, 3, 1)  # back to channels-last




class Discriminator_net(nn.Module):
    """Strided convolutional critic producing one score per sample.

    Input is a channels-last tensor of shape (batch, H, W, C) with H == W ==
    ``resolution``; output has shape (batch, 1).

    Args:
        in_channels: C. Defaults to the notebook-level ``in_value_dim``.
        base_width: channel width of the first layer. Defaults to ``width``.
        resolution: spatial size of the inputs. Defaults to ``ndim``.
    """

    def __init__(self, in_channels=None, base_width=None, resolution=None):
        super().__init__()
        c = in_value_dim if in_channels is None else in_channels
        w = width if base_width is None else base_width
        n = ndim if resolution is None else resolution

        self.Conv1 = nn.Conv2d(c, w, kernel_size=5, stride=2, padding=2)
        self.Conv2 = nn.Conv2d(w, w, kernel_size=5, stride=2, padding=2)
        self.Conv3 = nn.Conv2d(w, 2 * w, kernel_size=7, stride=2, padding=3)
        self.Conv5 = nn.Conv2d(2 * w, w, kernel_size=7, stride=3, padding=3)

        # The flattened size was hard-coded as 288 (= 32 channels * 3 * 3), which
        # only holds for width=32 and ndim=52; derive it from the conv geometry.
        side = n
        for conv in (self.Conv1, self.Conv2, self.Conv3, self.Conv5):
            side = (side + 2 * conv.padding[0] - conv.kernel_size[0]) // conv.stride[0] + 1
        self.fc = nn.Linear(w * side * side, 1)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)  # channels-last -> channels-first for Conv2d
        for conv in (self.Conv1, self.Conv2, self.Conv3, self.Conv5):
            x = F.leaky_relu(conv(x))
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)



In [None]:
# Instantiate both networks on the target device and apply the custom init.
D_net = Discriminator_net().to(device)
G_net = Generator_net().to(device)
D_net.apply(weights_init)
G_net.apply(weights_init)

# Report the number of trainable parameters of each network.
for label, net in (("discriminator", D_net), ("generator", G_net)):
    nn_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"Number {label} parameters: ", nn_params)

### Samples generated with the initialized generator

In [None]:
# Inference-only preview of the freshly initialized generator and critic;
# no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
    z = grf.sample(4).unsqueeze(-1)
    print(z.shape)
    x = G_net(z)
    print(x.shape)
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    for i in range(4):
        ax[i].imshow(x[i].cpu().numpy())
    plt.show()
    print(D_net(x))

In [None]:
# One Adam optimizer per player, both at the shared learning rate (no weight decay).
D_optimizer_net = torch.optim.Adam(D_net.parameters(), lr=lr)
G_optimizer_net = torch.optim.Adam(G_net.parameters(), lr=lr)

# Switch both networks to training mode before optimization starts.
D_net.train()
G_net.train()

In [None]:
def calculate_gradient_penalty(model, real_images, fake_images, device):
    """Gradient penalty term of the WGAN-GP critic loss.

    Draws one weight ``alpha ~ U[0, 1)`` per sample, evaluates the critic at
    ``alpha * real + (1 - alpha) * fake``, and penalizes the squared deviation of
    the L2 norm of the critic's input gradient from ``1 / sqrt(H * W)``.

    Args:
        model: critic mapping a batch to one score per sample.
        real_images, fake_images: tensors of identical shape (batch, H, W, C),
            the channels-last layout used throughout this notebook.
        device: device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor; built with ``create_graph=True`` so it can be
        backpropagated into the critic's parameters.
    """
    batch = real_images.size(0)
    # Uniform weights keep the samples on the segment between real and fake
    # (torch.randn here would extrapolate far outside it).
    alpha = torch.rand((batch,) + (1,) * (real_images.dim() - 1), device=device)
    interpolates = (
        alpha * real_images.detach() + (1 - alpha) * fake_images.detach()
    ).requires_grad_(True)

    scores = model(interpolates)
    gradients = torch.autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # also retains the graph
    )[0]
    gradients = gradients.reshape(batch, -1)
    # Target norm 1/sqrt(H*W) instead of the usual 1; presumably this accounts for
    # the grid discretization of a function-space norm - confirm with the method.
    n_points = real_images.shape[1] * real_images.shape[2]
    return torch.mean((gradients.norm(2, dim=1) - 1 / np.sqrt(n_points)) ** 2)

In [None]:
def train_WGANO(D, G, train_data, epochs, D_optim, G_optim, scheduler=None):
    """Train critic ``D`` and generator ``G`` with the WGAN-GP objective.

    Each minibatch performs one critic step; every ``n_critic``-th minibatch also
    performs one generator step. Generator inputs are fresh samples from the
    notebook-level ``grf``. After each epoch, diagnostics are plotted and
    ``storedata`` is called.

    Args:
        D, G: critic and generator modules.
        train_data: iterable whose items hold a batch of real samples at index 0
            (e.g. the DataLoader over a TensorDataset built above).
        epochs: number of passes over ``train_data``.
        D_optim, G_optim: optimizers for ``D`` and ``G`` (the original body used
            the undefined globals ``D_optimizer``/``G_optimizer``).
        scheduler: optional object with a ``step()`` method, called once at the
            end of every epoch.

    Returns:
        (losses_D, losses_G): numpy arrays of the mean critic / generator loss per
        epoch. ``losses_G`` is NaN for an epoch without any generator step.
    """
    losses_D = np.zeros(epochs)
    losses_G = np.zeros(epochs)
    for i in range(epochs):
        loss_D = 0.0
        loss_G = 0.0
        n_D_steps = 0
        n_G_steps = 0
        for j, data in enumerate(train_data):
            # ---- critic step ----
            x = data[0].to(device)
            D_optim.zero_grad()

            # The critic step never backpropagates into G, so skip its graph.
            with torch.no_grad():
                x_syn = G(grf.sample(x.shape[0]).unsqueeze(-1))

            W_loss = -torch.mean(D(x)) + torch.mean(D(x_syn))
            gradient_penalty = calculate_gradient_penalty(D, x, x_syn, device)
            loss = W_loss + λ_grad * gradient_penalty
            loss.backward()
            D_optim.step()
            loss_D += loss.item()
            n_D_steps += 1

            # ---- generator step, once every n_critic critic steps ----
            if (j + 1) % n_critic == 0:
                G_optim.zero_grad()
                x_syn = G(grf.sample(x.shape[0]).unsqueeze(-1))
                loss = -torch.mean(D(x_syn))
                loss.backward()  # also fills D's grads; cleared by D_optim.zero_grad()
                G_optim.step()
                loss_G += loss.item()
                n_G_steps += 1

        # Average over the steps actually taken (the old G normalization multiplied
        # by n_critic instead of dividing by it).
        losses_D[i] = loss_D / n_D_steps if n_D_steps else np.nan
        losses_G[i] = loss_G / n_G_steps if n_G_steps else np.nan

        if scheduler is not None:
            scheduler.step()

        with torch.no_grad():
            x = G(grf.sample(1000).unsqueeze(-1))
            fig, ax = plt.subplots(1, 5, figsize=(20, 4))
            for k in range(3):
                ax[k].imshow(x[k].cpu().numpy())
            lags, acf = compute_acovf(x.squeeze())
            ax[3].hist(x.reshape(-1).cpu().numpy(), bins=100, histtype='step', density=True)
            ax[3].hist(x_hist, bins=100, histtype='step', density=True)
            ax[4].plot(lags_ref, acf_ref, c='r')
            ax[4].plot(lags, acf, c='k')
            plt.show()
        print(i, "D: ", losses_D[i], "G: ", losses_G[i], "mean: ", x.mean().item(), "std: ", x.std().item())
        storedata(i, G)

    return losses_D, losses_G

In [None]:
def _save_field_row(fields, path):
    """Plot ``fields`` side by side on [0, 1]^2 with one shared colorbar, save to ``path`` and close."""
    fig, axes = plt.subplots(1, len(fields), figsize=(16, 4), squeeze=False)
    axes = axes[0]
    for panel, field in zip(axes, fields):
        bar = panel.imshow(field, extent=[0, 1, 0, 1])
    last = axes[-1].get_position()
    cax = fig.add_axes([last.x1 + 0.01, last.y0, 0.02, last.height])
    fig.colorbar(bar, cax=cax)
    fig.savefig(path)
    plt.close(fig)


def storedata(ite, G, out_dir='~/GANO/Figures/GAN_Mix_GRF', numb_fig=5):
    """Save diagnostic PDFs for epoch ``ite`` into ``out_dir``.

    Files written (names unchanged from before):
      * ``{ite}GRF.pdf``: the first ``numb_fig`` training samples;
      * ``{ite}GRF_mix_GAN.pdf``: ``numb_fig`` samples from ``G``;
      * ``{ite}GRF_mix_HistogramGAN.pdf``: value histograms of the training data
        and of 1000 generated samples;
      * ``{ite}GRF_mix_AutoCorrelationGAN.pdf``: radial autocovariance of both.

    Reads the notebook globals ``x_train``, ``x_hist``, ``lags_ref``, ``acf_ref``
    and ``grf``. Every figure is closed after saving so they do not accumulate.
    """
    import os

    # matplotlib neither expands "~" nor creates missing directories.
    out_dir = os.path.expanduser(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    def path(suffix):
        return os.path.join(out_dir, '{}{}'.format(ite, suffix))

    # Ground-truth samples.
    _save_field_row([x_train[i] for i in range(numb_fig)], path('GRF.pdf'))

    # Generated samples.
    with torch.no_grad():
        x = G(grf.sample(numb_fig).unsqueeze(-1))
    _save_field_row([x[j].cpu().numpy() for j in range(numb_fig)], path('GRF_mix_GAN.pdf'))

    # Statistics of a larger generated batch.
    with torch.no_grad():
        x = G(grf.sample(1000).unsqueeze(-1))
        lags, acf = compute_acovf(x.squeeze())

    fig, ax = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True)
    ax.hist(x_hist, bins=1000, histtype='step', density=True, color='#ff7f0e', label='Ground truth')
    ax.hist(x.reshape(-1).cpu().numpy(), bins=1000, color='#1f77b4', histtype='step', density=True, label='GAN')
    ax.set_xlabel('value')
    ax.set_ylabel('Histogram')
    ax.legend()
    fig.savefig(path('GRF_mix_HistogramGAN.pdf'))
    plt.close(fig)

    fig, ax = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True)
    scale = max(lags_ref)  # normalize lags to [0, 1)
    ax.plot(lags_ref / scale, acf_ref, c='#ff7f0e', label='Ground truth')
    ax.plot(lags / scale, acf, c='#1f77b4', label='GAN')
    ax.set_xlabel('Position')
    ax.set_ylabel('Auto correlation')
    ax.legend()
    fig.savefig(path('GRF_mix_AutoCorrelationGAN.pdf'))
    plt.close(fig)





In [None]:
# Train with the WGAN-GP loop `train_WGANO` (no function named `train_GAN` exists).
losses_D, losses_G = train_WGANO(D_net, G_net, train_loader, epochs, D_optimizer_net, G_optimizer_net)

In [None]:
# Per-epoch critic (black) and generator (blue) losses.
epoch_axis = np.arange(epochs)
for curve, colour, name in ((losses_D, 'k', 'D'), (losses_G, 'b', 'G')):
    plt.plot(epoch_axis, curve, c=colour, label=name)
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()

In [None]:
# Colour limits (used by the generated-sample figure) and panel count.
high = 1.6
low = -1.6
numb_fig = 5

# First numb_fig training samples on [0, 1]^2 with a colorbar right of the last panel.
fig, ax = plt.subplots(1, numb_fig, figsize=(16, 4))
for i, panel in enumerate(ax):
    bar = panel.imshow(x_train[i], extent=[0, 1, 0, 1])
pos = ax[-1].get_position()
cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
plt.colorbar(bar, cax=cax)

In [None]:
with torch.no_grad():
    # Draw numb_fig generator outputs and show them on a common colour scale.
    z = grf.sample(numb_fig).unsqueeze(-1)
    x = G_net(z)
    fig, ax = plt.subplots(1, numb_fig, figsize=(16, 4))
    for j, panel in enumerate(ax):
        bar = panel.imshow(x[j].cpu().detach().numpy(), extent=[0, 1, 0, 1])
        bar.set_clim(low, high)
    pos = ax[-1].get_position()
    cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
    plt.colorbar(bar, cax=cax)
    plt.show()