# GAN trained on volcano data on a 2D domain

The inputs to the generator are samples of a Gaussian random field (GRF).


In [None]:
import os

import torch
import numpy as np
import pylab as plt
import torch.nn.functional as F
import torch.nn as nn
from random_fields import *

#### Parameters

In [None]:
res = 128 - 8  # spatial resolution: the raw 128 x 128 images are cropped to res x res
ntrain = 4096  # number of training samples
width = 16  # base number of channels of the convolutional generator / critic
lr = 1e-4  # learning rate of the optimizers
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
epochs = 400
λ_grad = 10.0  # weight of the gradient-penalty term in the critic loss
n_critic = 20  # the generator is updated once every n_critic critic iterations
in_value_dim = 1  # dimension of the co-domain of input functions
out_value_dim = 2  # dimension of the co-domain of output functions
batch_size = 64

In [None]:
# Gaussian normalization with dataset-level (not pointwise) statistics.
class InputNormalizer(object):
    """Standardize tensors with a mean and standard deviation estimated from ``x``.

    The statistics are reduced over dims 1-3 of each sample and then averaged
    over dim 0, so ``x`` must have at least 4 dimensions. For a 4-D ``x``,
    both ``mean`` and ``std`` are 0-d tensors, i.e. a single value is used
    for every entry. Note that ``std`` is the mean of the per-sample
    standard deviations, not the standard deviation over all values.
    """

    def __init__(self, x, eps=0.00001):
        super().__init__()
        self.mean = torch.mean(x, dim=(1, 2, 3)).mean(dim=0)
        self.std = torch.std(x, dim=(1, 2, 3)).mean(dim=0)
        self.eps = eps

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`: return ``x * (std + eps) + mean``.

        ``sample_idx`` only matters when the stored statistics are not
        scalars. It then selects entries of ``mean``/``std``: along dim 0
        when their rank equals the rank of ``sample_idx[0]``, or along dim 1
        when it is larger. With scalar statistics the same values apply to
        every sample and ``sample_idx`` is ignored (indexing a 0-d tensor
        would raise).

        Raises:
            ValueError: if non-scalar statistics have a lower rank than
                ``sample_idx[0]``, so that no indexing rule applies.
        """
        std = self.std + self.eps
        mean = self.mean
        if sample_idx is not None and self.mean.dim() > 0:
            stats_rank = len(self.mean.shape)
            idx_rank = len(sample_idx[0].shape)
            if stats_rank == idx_rank:
                std = self.std[sample_idx] + self.eps  # batch*n
                mean = self.mean[sample_idx]
            elif stats_rank > idx_rank:
                std = self.std[:, sample_idx] + self.eps  # T*batch*n
                mean = self.mean[:, sample_idx]
            else:
                raise ValueError(
                    'sample_idx has higher rank than the normalizer statistics')

        x = (x * std) + mean
        return x

    def cuda(self):
        """Move the statistics to the module-level ``device``; return ``self``."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        return self

    def cpu(self):
        """Move the statistics to the CPU; return ``self``."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self
        
def compute_acovf(z):
    """Radially binned circular autocorrelation of a batch of 2-D fields.

    ``z`` is expected to be ``(batch, H, W)``. The per-sample autocorrelation
    is obtained with FFTs (inverse FFT of ``|FFT|^2``), divided by the number
    of grid points and averaged over the batch. Values are then averaged
    over rings of equal lag radius. The fields are not mean-centred, so this
    equals the autocovariance only for zero-mean fields.

    Returns:
        ``(edges, means)``: the left edges of 49 equal-width radial bins
        spanning ``[0, max(H, W)]``, and the mean autocorrelation in each bin
        (NaN for bins that contain no lag).
    """
    from scipy.stats import binned_statistic

    h, w = z.shape[-2], z.shape[-1]
    z_hat = torch.fft.rfft2(z)
    # Pass ``s`` so that odd widths are reconstructed at their true size.
    acf = torch.fft.irfft2(torch.conj(z_hat) * z_hat, s=(h, w)) / z[0].numel()
    # Shift only the spatial axes so that zero lag sits at (h // 2, w // 2).
    acf = torch.fft.fftshift(acf, dim=(-2, -1)).mean(dim=0)
    acf_r = acf.reshape(-1).cpu().detach().numpy()

    lags_x, lags_y = torch.meshgrid(
        torch.arange(h, dtype=torch.float32) - h // 2,
        torch.arange(w, dtype=torch.float32) - w // 2,
        indexing='ij',
    )
    lags_r = torch.sqrt(lags_x ** 2 + lags_y ** 2).reshape(-1).cpu().numpy()

    # binned_statistic does not require sorted input.
    bin_means, bin_edges, _ = binned_statistic(
        lags_r, acf_r, 'mean', bins=np.linspace(0.0, max(h, w), 50))
    return bin_edges[:-1], bin_means

### Load the dataset

In [None]:
import glob
import os

from IPython.display import HTML, display

# Widen the Jupyter notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# glob() performs no tilde expansion, so resolve the home directory explicitly.
# Sort so that the selected subset and its order do not depend on the order in
# which the file system lists entries.
data_glob = os.path.expanduser('~/GANO/volcano_data/**/*.int')
files = sorted(glob.glob(data_glob, recursive=True))[:ntrain]
if not files:
    raise FileNotFoundError('no .int files found matching ' + data_glob)
if len(files) < ntrain:
    print('warning: only {} of {} requested samples found'.format(len(files), ntrain))

# Each file is read as raw float32 values and reshaped to (nline, nsamp, -1);
# channels 0 and 1 are combined as the real and imaginary parts.
dtype = np.float32
nline = 128
nsamp = 128

# One row per file found, so missing files do not leave all-zero samples behind.
x_train = torch.zeros(len(files), res, res, 2).float()
for i, f in enumerate(files):
    img = np.fromfile(f, dtype=dtype).reshape((nline, nsamp, -1))
    phi = np.angle(img[:, :, 0] + img[:, :, 1] * 1j)
    # Store the phase as the unit vector (cos, sin), cropped to res x res.
    x_train[i, :, :, 0] = torch.cos(torch.tensor(phi[:res, :res]))
    x_train[i, :, :, 1] = torch.sin(torch.tensor(phi[:res, :res]))

In [None]:
# Directional statistics
def _phase(x):
    """Angle atan2(x[..., 1], x[..., 0]) of 2-component vectors, in [-pi, pi]."""
    return torch.atan2(x[..., 1], x[..., 0])


def _resultant(phase, k, dim):
    """Mean resultant length and direction of the k-th trigonometric moment."""
    c = torch.cos(k * phase).mean(dim=dim)
    s = torch.sin(k * phase).mean(dim=dim)
    return torch.sqrt(c ** 2 + s ** 2), torch.atan2(s, c)


def circular_var(x, dim=None):
    """Circular variance ``1 - R1`` of the phases in ``x``.

    ``x[..., 0]`` and ``x[..., 1]`` are read as the cosine-like and sine-like
    components, and their angle is taken with atan2. ``R1`` is the mean
    resultant length over ``dim`` of the phase array, which is ``x`` without
    its last axis. The default ``(1, 2)`` reduces over the two spatial axes
    of a ``(batch, H, W, 2)`` tensor, giving one value per sample.
    """
    dims = (1, 2) if dim is None else dim
    r1, _ = _resultant(_phase(x), 1, dims)
    return 1 - r1


def circular_skew(x, dim=None):
    """Circular skewness ``R2 * sin(T2 - 2*T1) / (1 - R1)**1.5`` of the phases in ``x``.

    ``R_k`` and ``T_k`` are the mean resultant length and direction of the
    k-th trigonometric moment over ``dim`` (default ``(1, 2)``, see
    :func:`circular_var`). The result is inf/NaN when all phases coincide
    (``R1 == 1``).
    """
    dims = (1, 2) if dim is None else dim
    phase = _phase(x)
    r1, t1 = _resultant(phase, 1, dims)
    r2, t2 = _resultant(phase, 2, dims)
    return r2 * torch.sin(t2 - 2 * t1) / (1 - r1) ** 1.5


# Per-sample circular variance of the training set, shown as an outline histogram.
var = circular_var(x_train)
var_bins = np.linspace(0.0, 1.0, 25)
plt.hist(var.numpy(), bins=var_bins, histtype='step', linewidth=2.0)

In [None]:
print(x_train.shape)

# Show ten randomly chosen training samples as wrapped phase maps.
for _ in range(10):
    fig, ax = plt.subplots(1, 1, figsize=(20, 4))
    idx = torch.randint(x_train.shape[0], size=(1,))[0]
    sample = x_train[idx]
    phase = torch.atan2(sample[:, :, 1], sample[:, :, 0]).cpu().detach().numpy()
    phase = (phase + np.pi) % (2 * np.pi) - np.pi
    img = ax.imshow(phase, cmap='RdYlBu', vmin=-np.pi, vmax=np.pi, interpolation=None)
    plt.colorbar(img)
plt.show()

# Shuffled mini-batches of the full training tensor.
train_dataset = torch.utils.data.TensorDataset(x_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
numb_fig = 5
ite = 200
fig, ax = plt.subplots(1, numb_fig, figsize=(16, 4))
for j in range(numb_fig):
    phase = torch.atan2(x_train[j, :, :, 1], x_train[j, :, :, 0]).cpu().detach().numpy()
    # Wrap to [-pi, pi) to match the colour limits below.
    phase = (phase + np.pi) % (2 * np.pi) - np.pi
    bar = ax[j].imshow(phase, cmap='RdYlBu', vmin=-np.pi, vmax=np.pi, extent=[0, 1, 0, 1])
# Shared colorbar in a thin axis just to the right of the last panel.
last = ax[numb_fig - 1].get_position()
cax = fig.add_axes([last.x1 + 0.01, last.y0, 0.02, last.height])
plt.colorbar(bar, cax=cax)
# savefig does not expand "~", so resolve it and create the directory. The
# 5 x 4 grid in the "Data visualization" cell is saved as '{ite}real.pdf', so
# this strip gets its own name instead of being overwritten by that figure.
fig_dir = os.path.expanduser('~/GANO/Figures/volcano')
os.makedirs(fig_dir, exist_ok=True)
plt.savefig(os.path.join(fig_dir, '{}real_strip.pdf'.format(ite)))

### Data visualization

In [None]:
datapoint = 1  # index of the first training sample shown
numb_fig = 4
ite = 200
fig, ax = plt.subplots(5, numb_fig, figsize=(14, 16))
for i in range(5):
    for j in range(numb_fig):
        k = datapoint + i * numb_fig + j
        phase = torch.atan2(x_train[k, :, :, 1], x_train[k, :, :, 0]).cpu().detach().numpy()
        phase = (phase + np.pi) % (2 * np.pi) - np.pi
        bar = ax[i, j].imshow(phase, cmap='RdYlBu', vmin=-np.pi, vmax=np.pi, extent=[0, 1, 0, 1])
# savefig does not expand "~", so resolve it and make sure the directory exists.
fig_dir = os.path.expanduser('~/GANO/Figures/volcano')
os.makedirs(fig_dir, exist_ok=True)
plt.savefig(os.path.join(fig_dir, '{}real.pdf'.format(ite)))
plt.show()


### Samples of input GRF

In [None]:
# GRF sampler for the generator inputs (GaussianRF_idct presumably comes from
# random_fields); it is reused by the training code below.
grf = GaussianRF_idct(2, res, alpha=1.5, tau=1.0, device=device)
numb_fig = 5
z = grf.sample(numb_fig)
print(z.mean(), z.std())
z = z.detach().cpu()
fig, ax = plt.subplots(1, numb_fig, figsize=(16, 4))
for i in range(numb_fig):
    bar = ax[i].imshow(z[i], extent=[0, 1, 0, 1])
# Shared colorbar in a thin axis just to the right of the last panel.
last = ax[numb_fig - 1].get_position()
cax = fig.add_axes([last.x1 + 0.01, last.y0, 0.02, last.height])
plt.colorbar(bar, cax=cax)
# savefig does not expand "~", so resolve it and make sure the directory exists.
fig_dir = os.path.expanduser('~/GANO/Figures/volcano')
os.makedirs(fig_dir, exist_ok=True)
plt.savefig(os.path.join(fig_dir, '{}GRF.pdf'.format(ite)))

### Generator and discriminator networks

In [None]:
def weights_init(m):
    """Re-initialize one module in place; meant for ``Module.apply``.

    Modules whose class name contains ``'Conv'`` get N(0, 0.02) weights.
    Modules whose class name contains ``'BatchNorm'`` get N(1, 0.02) weights
    and zero bias. Everything else, including conv biases, keeps its
    default initialization.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)



        
class Generator_net(nn.Module):
    """Fully convolutional generator working on channels-last tensors.

    The forward pass permutes ``(batch, H, W, in_value_dim)`` to channels-first
    for the convolutions and back again at the end. Every convolution has
    stride 1 and padding ``(kernel_size - 1) // 2``, so the output grid has
    the same size as the input grid, with ``out_value_dim`` channels.
    """

    # Hidden convolutions in the order they are applied; each is followed by a ReLU.
    _HIDDEN = ('Conv1', 'Conv2', 'Conv3', 'Conv4', 'Conv4_', 'Conv4__',
               'Conv4___', 'Conv5', 'Conv6')

    def __init__(self):
        super().__init__()
        # Attribute names (and creation order) are kept so that saved
        # state_dicts load and default initialization is reproducible.
        self.Conv1 = nn.Conv2d(in_value_dim, width, kernel_size=7, stride=1, padding=3)
        self.Conv2 = nn.Conv2d(width, 2 * width, kernel_size=7, stride=1, padding=3)
        self.Conv3 = nn.Conv2d(2 * width, 2 * width, kernel_size=7, stride=1, padding=3)
        self.Conv4 = nn.Conv2d(2 * width, 2 * width, kernel_size=7, stride=1, padding=3)
        self.Conv4_ = nn.Conv2d(2 * width, 4 * width, kernel_size=7, stride=1, padding=3)
        self.Conv4__ = nn.Conv2d(4 * width, 4 * width, kernel_size=7, stride=1, padding=3)
        self.Conv4___ = nn.Conv2d(4 * width, 2 * width, kernel_size=7, stride=1, padding=3)
        self.Conv5 = nn.Conv2d(2 * width, 2 * width, kernel_size=7, stride=1, padding=3)
        self.Conv6 = nn.Conv2d(2 * width, width, kernel_size=3, stride=1, padding=1)
        self.Conv7 = nn.Conv2d(width, out_value_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = x.permute(0, 3, 1, 2)
        for layer_name in self._HIDDEN:
            h = F.relu(getattr(self, layer_name)(h))
        # Output bounded to (-1.1, 1.1); presumably the margin lets the
        # [-1, 1] cos/sin targets be reached without saturating tanh.
        h = 1.1 * torch.tanh(self.Conv7(h))
        return h.permute(0, 2, 3, 1)




class Discriminator_net(nn.Module):
    """Convolutional critic mapping a channels-last field to one score per sample.

    Input is ``(batch, H, W, out_value_dim)``. It is permuted to channels-first,
    passed through seven strided convolutions with leaky-ReLU, flattened and
    reduced to shape ``(batch, 1)`` by a linear layer.
    """

    def __init__(self):
        super().__init__()
        self.Conv1 = nn.Conv2d(out_value_dim, width, kernel_size=5, stride=2, padding=2)
        self.Conv2 = nn.Conv2d(width, 2 * width, kernel_size=5, stride=2, padding=2)
        self.Conv2_ = nn.Conv2d(2 * width, 2 * width, kernel_size=5, stride=1, padding=2)
        self.Conv2__ = nn.Conv2d(2 * width, 2 * width, kernel_size=5, stride=1, padding=2)
        self.Conv3 = nn.Conv2d(2 * width, 4 * width, kernel_size=7, stride=2, padding=3)
        self.Conv4 = nn.Conv2d(4 * width, 2 * width, kernel_size=6, stride=4, padding=2)
        self.Conv5 = nn.Conv2d(2 * width, 1 * width, kernel_size=7, stride=3, padding=3)

        # Size of the final feature map for res x res inputs, from the conv
        # output-size formula floor((n + 2p - k) / s) + 1. This keeps the linear
        # head consistent when `res` or `width` change (64 for res=120, width=16).
        side = res
        for conv in self._convs():
            side = (side + 2 * conv.padding[0] - conv.kernel_size[0]) // conv.stride[0] + 1
        self.fc = nn.Linear(self.Conv5.out_channels * side * side, 1)

    def _convs(self):
        """Convolution layers in the order they are applied."""
        return (self.Conv1, self.Conv2, self.Conv2_, self.Conv2__,
                self.Conv3, self.Conv4, self.Conv5)

    def forward(self, x):
        h = torch.permute(x, (0, 3, 1, 2))
        for conv in self._convs():
            h = F.leaky_relu(conv(h))
        h = torch.flatten(h, start_dim=1)
        return self.fc(h)





In [None]:
D_net = Discriminator_net().to(device)
G_net = Generator_net().to(device)

# Re-initialize conv / batch-norm weights of both networks.
for net in (D_net, G_net):
    net.apply(weights_init)

# Report the number of trainable parameters of each network.
for message, net in (("Number discriminator parameters: ", D_net),
                     ("Number generator parameters: ", G_net)):
    nn_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(message, nn_params)



### Samples generated with the initialized (untrained) generator

In [None]:
# Purely diagnostic forward passes, so no autograd graph is built.
with torch.no_grad():
    z = grf.sample(4).unsqueeze(-1)
    x = G_net(z)
    D_net(x)  # smoke test: the critic accepts the generator's output shape
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
for i in range(4):
    phase = torch.atan2(x[i, :, :, 1], x[i, :, :, 0]).cpu().detach().numpy()
    phase = (phase + np.pi) % (2 * np.pi) - np.pi
    bar = ax[i].imshow(phase, cmap='RdYlBu', vmin=-np.pi, vmax=np.pi, interpolation=None)
plt.colorbar(bar)
plt.show()

In [None]:
# One Adam optimizer per player, same learning rate, no weight decay.
adam_kwargs = dict(lr=lr)
D_optimizer_net = torch.optim.Adam(D_net.parameters(), **adam_kwargs)
G_optimizer_net = torch.optim.Adam(G_net.parameters(), **adam_kwargs)

D_net.train()
G_net.train()

In [None]:
def calculate_gradient_penalty(model, real_images, fake_images, device):
    """WGAN-GP gradient penalty evaluated on random real/fake interpolates.

    For every sample, a weight ``alpha ~ U[0, 1)`` is drawn and the critic is
    evaluated at ``alpha * real + (1 - alpha) * fake``. The penalty is the
    batch mean of ``(||dD/dx||_2 - target)**2``, where
    ``target = 1 / sqrt(number of grid points)``. Data are assumed
    channels-last ``(batch, H, W, C)``, giving ``target = 1 / sqrt(H * W)``
    (``1 / res`` on the square training grid). This is presumably the
    discretization factor for a unit-norm gradient in the function-space
    (L2) sense; confirm against the method's derivation.

    Args:
        model: critic returning one score per sample.
        real_images, fake_images: tensors of identical shape; they are
            detached, so no gradient flows back into the generator.
        device: device on which the random weights are created.

    Returns:
        A scalar tensor that is differentiable w.r.t. the critic parameters.
    """
    # Uniform weights keep the interpolates on the segment between real and
    # fake samples (Gaussian weights would mostly extrapolate beyond it).
    alpha_shape = (real_images.size(0),) + (1,) * (real_images.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolates = (alpha * real_images.detach()
                    + (1 - alpha) * fake_images.detach()).requires_grad_(True)

    model_interpolates = model(interpolates)
    gradients = torch.autograd.grad(
        outputs=model_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(model_interpolates),
        create_graph=True,  # the penalty itself is back-propagated
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)

    n_points = real_images[0, ..., 0].numel()
    target_norm = 1.0 / n_points ** 0.5
    gradient_penalty = torch.mean((gradients.norm(2, dim=1) - target_norm) ** 2)
    return gradient_penalty

In [None]:
def _phase_map(field):
    """Wrapped phase atan2(field[:, :, 1], field[:, :, 0]) in [-pi, pi) as a NumPy array."""
    phase = torch.atan2(field[:, :, 1], field[:, :, 0]).cpu().detach().numpy()
    return (phase + np.pi) % (2 * np.pi) - np.pi


def _save_phase_grid(fields, nrows, ncols, figsize, path, colorbar=False):
    """Plot fields[k] (row-major) as phase maps on an nrows x ncols grid and save to path."""
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for k, ax in enumerate(axes.flat):
        im = ax.imshow(_phase_map(fields[k]), cmap='RdYlBu', vmin=-np.pi, vmax=np.pi,
                       extent=[0, 1, 0, 1])
    if colorbar:
        # Shared colorbar in a thin axis just right of the top-right panel.
        pos = axes[0, -1].get_position()
        cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
        fig.colorbar(im, cax=cax)
    fig.savefig(path)
    plt.show()
    plt.close(fig)  # avoid accumulating figures across epochs


def _save_hist_comparison(reference, generated, lo, hi, path):
    """Overlay density histograms (50 edges on [lo, hi]) of data vs. GAN values; save to path."""
    bins = np.linspace(lo, hi, 50)
    fig, ax = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True)
    ax.hist(reference.reshape(-1).cpu().detach().numpy(), bins=bins, histtype='step',
            density=True, color='#ff7f0e', label='Ground truth')
    ax.hist(generated.reshape(-1).cpu().detach().numpy(), bins=bins, histtype='step',
            density=True, color='#1f77b4', label='GAN')
    ax.set_xlim(lo, hi)
    ax.set_xlabel('value')
    ax.set_ylabel('Histogram')
    ax.legend()
    fig.savefig(path)
    plt.show()
    plt.close(fig)


def storedata(ite, G):
    """Save and show diagnostic figures for the generator ``G`` at epoch ``ite``.

    Four PDFs are written to ``~/GANO/Figures/volcano``, which is created if
    missing:

    * ``{ite}GRF.pdf``: phase maps of the first five training samples;
    * ``{ite}GRF_GAN.pdf``: a 5 x 4 grid of phase maps generated from GRF inputs;
    * ``{ite}GRF_varianceGAN.pdf`` / ``{ite}GRF_skewnessGAN.pdf``: density
      histograms of per-sample circular variance / skewness for ``x_train``
      versus 10,000 generated samples.

    Relies on the module-level ``x_train``, ``grf``, ``device``,
    ``circular_var`` and ``circular_skew``.

    Args:
        ite: epoch index, used as the file-name prefix.
        G: generator, called on ``grf.sample(n).unsqueeze(-1)``; the last
            axis of its output holds the two components whose angle is plotted.
    """
    # savefig does not expand "~".
    fig_dir = os.path.expanduser('~/GANO/Figures/volcano')
    os.makedirs(fig_dir, exist_ok=True)

    # NOTE(review): despite its name, this file shows training data. At
    # ite == 200 it also overwrites the GRF-sample figure saved by the
    # "Samples of input GRF" cell.
    _save_phase_grid(x_train, 1, 5, (16, 4),
                     os.path.join(fig_dir, '{}GRF.pdf'.format(ite)), colorbar=True)

    n_samples, chunk = 10000, 100
    var = torch.zeros(n_samples, device=device).float()
    skew = torch.zeros(n_samples, device=device).float()
    with torch.no_grad():
        x = G(grf.sample(5 * 4).unsqueeze(-1))
        _save_phase_grid(x, 5, 4, (14, 16),
                         os.path.join(fig_dir, '{}GRF_GAN.pdf'.format(ite)))
        for start in range(0, n_samples, chunk):
            x = G(grf.sample(chunk).unsqueeze(-1))
            var[start:start + chunk] = circular_var(x)
            skew[start:start + chunk] = circular_skew(x)

    _save_hist_comparison(circular_var(x_train), var, 0.0, 1.0,
                          os.path.join(fig_dir, '{}GRF_varianceGAN.pdf'.format(ite)))
    _save_hist_comparison(circular_skew(x_train), skew, -4.0, 4.0,
                          os.path.join(fig_dir, '{}GRF_skewnessGAN.pdf'.format(ite)))

In [None]:
def train_GAN(D, G, train_data, epochs, D_optim, G_optim, scheduler=None):
    """Train critic ``D`` and generator ``G`` with the WGAN-GP objective.

    Each batch gives one critic step on
    ``-mean(D(real)) + mean(D(fake)) + λ_grad * gradient_penalty``. Every
    ``n_critic``-th batch of an epoch additionally gives one generator step
    on ``-mean(D(G(z)))``, with ``z`` drawn from ``grf``. After each epoch
    the averaged losses are printed and ``storedata`` writes diagnostic
    figures.

    Uses the module-level ``grf``, ``device``, ``λ_grad``, ``n_critic``,
    ``calculate_gradient_penalty`` and ``storedata``.

    Args:
        scheduler: optional object with a ``step()`` method (e.g. a
            ``torch.optim.lr_scheduler``); it is stepped once at the end of
            every epoch.

    Returns:
        ``(losses_D, losses_G, losses_W)``: per-epoch averages over the
        critic steps (total critic loss, Wasserstein term) and over the
        generator steps (generator loss).
    """
    losses_D = np.zeros(epochs)
    losses_G = np.zeros(epochs)
    losses_W = np.zeros(epochs)
    for i in range(epochs):
        loss_D = loss_G = loss_W = 0.0
        n_D_steps = n_G_steps = 0
        for j, data in enumerate(train_data):
            # --- critic step ---
            x = data[0].to(device)
            D_optim.zero_grad()

            # The fake batch is only an input to the critic here.
            with torch.no_grad():
                x_syn = G(grf.sample(x.shape[0]).unsqueeze(-1))

            W_loss = -torch.mean(D(x)) + torch.mean(D(x_syn))
            gradient_penalty = calculate_gradient_penalty(D, x.detach(), x_syn, device)
            loss = W_loss + λ_grad * gradient_penalty
            loss.backward()

            loss_D += loss.item()
            loss_W += W_loss.item()
            n_D_steps += 1

            D_optim.step()

            # --- generator step, once every n_critic critic steps ---
            if (j + 1) % n_critic == 0:
                G_optim.zero_grad()
                x_syn = G(grf.sample(x.shape[0]).unsqueeze(-1))
                loss = -torch.mean(D(x_syn))
                loss.backward()
                loss_G += loss.item()
                n_G_steps += 1
                G_optim.step()

        # Average over the number of steps actually taken (not the batch size).
        losses_D[i] = loss_D / max(n_D_steps, 1)
        losses_G[i] = loss_G / max(n_G_steps, 1)
        losses_W[i] = loss_W / max(n_D_steps, 1)
        print(i, "D: ", losses_D[i], "G: ", losses_G[i], "W: ", losses_W[i])

        if scheduler is not None:
            scheduler.step()

        storedata(i, G)

    return losses_D, losses_G, losses_W

In [None]:
losses_D, losses_G, losses_W = train_GAN(D_net, G_net, train_loader, epochs, D_optimizer_net, G_optimizer_net)
# NOTE(review): this replaces the weights just trained with a saved checkpoint;
# presumably only one of these two lines is meant to run.
G_net.load_state_dict(torch.load("/home/kaazizza/GANO/volcano/G.pt", map_location=device))