# Example 1: Generative Modeling with Moran's Auxiliary Task (MAT)


## Setup

Load required dependencies

In [None]:
import os
import datetime
import sys
import requests
from urllib.request import urlretrieve
import urllib.request, json 

import torch
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler 
from sklearn.preprocessing import normalize
from sklearn import metrics

Check GPU device

In [None]:
# Report which accelerator will be used. torch.cuda.get_device_name raises
# when no CUDA device is present, so guard it to keep CPU-only runs working.
device_name = (
    torch.cuda.get_device_name(0)
    if torch.cuda.is_available()
    else "No CUDA device available; running on CPU"
)
device_name

Helper functions, as found in `src/utils.py`

In [None]:
def normal(tensor, min_val=-1):
  """Min-max rescale `tensor` to [-1, 1] (default) or [0, 1].

  Args:
    tensor: input torch tensor of any shape; it is not modified.
    min_val: lower bound of the target range: -1 gives [-1, 1], 0 gives [0, 1].

  Returns:
    A new, detached tensor with the same shape and dtype as `tensor`.
    A constant input (including all zeros) has no range to rescale and is
    returned as zeros instead of dividing by zero, which would produce NaN.

  Raises:
    ValueError: if `min_val` is neither -1 nor 0.
  """
  t_min = torch.min(tensor)
  t_max = torch.max(tensor)
  if t_min == t_max:
    return torch.zeros_like(tensor)
  if min_val not in (-1, 0):
    raise ValueError("min_val must be -1 or 0, got %r" % (min_val,))
  tensor_norm = (tensor - t_min) / (t_max - t_min)
  if min_val == -1:
    tensor_norm = 2 * tensor_norm - 1
  # tensor_norm is already a fresh tensor; detach so that, like the former
  # torch.tensor(...) copy, the result never carries autograd history.
  return tensor_norm.detach()

def lw_tensor_local_moran(y, w_sparse, na_to_zero=True, norm=True, norm_min_val=-1):
  """Compute the local Moran's I of `y` using a precomputed weight matrix.

  Light-weight variant: the spatial weight matrix is passed in rather than
  rebuilt on every call.

  Args:
    y: torch tensor of values; it is flattened, so its element order must
      match the row/column order of `w_sparse`.
    w_sparse: (n, n) spatial weight matrix with n = y.numel(); a scipy sparse
      matrix, scipy sparse array or dense numpy array.
    na_to_zero: replace NaN results (e.g. from a constant `y`, whose standard
      deviation is 0) with 0.
    norm: rescale the result with `normal`.
    norm_min_val: lower bound passed to `normal` (-1 or 0).

  Returns:
    1-D tensor of length n with one local Moran's I value per location.
  """
  y = y.reshape(-1)
  n = y.numel()
  z = y - y.mean()
  z = z / y.std()
  den = (z * z).sum()
  # Spatial lag W @ z. `@` is a matrix product for scipy sparse matrices,
  # sparse arrays and dense arrays alike, whereas `*` is a matrix product only
  # for the legacy spmatrix class (elementwise for the others). SciPy cannot
  # consume torch tensors directly, so use a detached CPU numpy copy.
  lag = w_sparse @ z.detach().cpu().numpy()
  zl = torch.as_tensor(np.asarray(lag), device=z.device)
  mi = (n - 1) * z * zl / den
  if na_to_zero:
    mi[torch.isnan(mi)] = 0
  if norm:
    mi = normal(mi, min_val=norm_min_val)
  return mi

def batch_lw_tensor_local_moran(y_batch, w_sparse, na_to_zero=True, norm=True, norm_min_val=-1):
  """Batch version of `lw_tensor_local_moran`.

  Computes the (optionally normalized) local Moran's I of every sample in a
  batch by delegating to `lw_tensor_local_moran`, so both share one
  implementation.

  Args:
    y_batch: tensor indexed as (batch, channel, height, width). Only a single
      channel is supported: each sample is treated as one height x width grid.
    w_sparse: spatial weight matrix for height * width locations.
    na_to_zero, norm, norm_min_val: forwarded to `lw_tensor_local_moran`.

  Returns:
    Tensor of the same shape as `y_batch`, created with torch.zeros (default
    dtype, CPU), holding each sample's local Moran's I in channel 0.
  """
  batch_size = y_batch.shape[0]
  height, width = y_batch.shape[2], y_batch.shape[3]
  mi_y_batch = torch.zeros(y_batch.shape)
  for i in range(batch_size):
    mi = lw_tensor_local_moran(y_batch[i], w_sparse, na_to_zero=na_to_zero,
                               norm=norm, norm_min_val=norm_min_val)
    mi_y_batch[i, 0, :, :] = mi.reshape(height, width)
  return mi_y_batch

Load sparse spatial weight matrices

In [None]:
%%capture

# Fetch the precomputed sparse spatial weight matrices for every grid side
# length used in this notebook (64 down to 4) into the working directory.
for grid_size in (64, 32, 16, 8, 4):
    file_name = 'w_sparse_%d.npz' % grid_size
    urlretrieve('https://github.com/konstantinklemmer/sxl/raw/master/data/w/' + file_name, file_name)

In [None]:
# `import scipy` alone does not load the `scipy.sparse` submodule on older
# SciPy releases (here it may only have worked because scikit-learn imported
# it as a side effect), so import it explicitly.
import scipy.sparse

# Sparse spatial weight matrices, one per grid side length; they are used
# with flattened side x side grids, i.e. (side*side) x (side*side) matrices.
w_sparse_64 = scipy.sparse.load_npz('w_sparse_64.npz')
w_sparse_32 = scipy.sparse.load_npz('w_sparse_32.npz')
w_sparse_16 = scipy.sparse.load_npz('w_sparse_16.npz')
w_sparse_8 = scipy.sparse.load_npz('w_sparse_8.npz')
w_sparse_4 = scipy.sparse.load_npz('w_sparse_4.npz')

## Data

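As is customary for GAN training, the data is normalized to the range `[-1,1]`. The local Moran's I of each sample is also computed at this step and stored as a second channel, so it does not have to be recomputed later. (Note: the MAT discriminators defined below compute the local Moran's I of their inputs on the fly, so the training loop in this notebook uses only the first channel.)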

### (1) d1: Toy Example - Gaussian peak / dip (32x32)

Generate data

In [None]:
%%capture

def random_peak_dip(X, Y, s=7):
    """Return a surface with one Gaussian peak and one mirrored Gaussian dip.

    In the rescaled coordinates (9*X, 9*Y) the peak is centred at (a, b) and
    the dip at (10 - a, 10 - b). Here a and b are integers drawn uniformly
    from 0..9 (each via torch.randperm(10)[:1]). Both bumps have amplitude
    0.75. `s` controls their width: use e.g. s = 2 or 3 for narrower bumps
    and larger values for wider ones.
    """
    def bump(centre_x, centre_y):
        return .75 * torch.exp(-((9 * X - centre_x).pow(2) + (9 * Y - centre_y).pow(2)) / s)

    peak_x = torch.randperm(10)[:1].float()
    peak_y = torch.randperm(10)[:1].float()
    return bump(peak_x, peak_y) - bump(10 - peak_x, 10 - peak_y)

N = 32    # grid side length
t = 7000  # number of data samples

# Flattened (N*N, 2) table of (x, y) grid coordinates in [0, 1].
xv1, yv1 = torch.meshgrid([torch.linspace(0, 1, N), torch.linspace(0, 1, N)])
train_x = torch.cat((
    xv1.contiguous().view(xv1.numel(), 1),
    yv1.contiguous().view(yv1.numel(), 1)),
    dim=1
)

# Set seed for reproducibility
torch.manual_seed(99)
# Sample t noise-free peak/dip surfaces, each normalized to [-1, 1].
# When a == b == 5 the dip centre (10 - a, 10 - b) coincides with the peak,
# the two Gaussians cancel and the surface is all zeros, so redraw it.
d1 = torch.zeros(t, 2, N, N)
for i in range(t):
    train_y_t = torch.zeros(N * N)
    while bool((train_y_t == 0).all()):
        f = random_peak_dip(train_x[:, 0], train_x[:, 1])
        train_y_t = normal(f.reshape(-1), min_val=-1)
    d1[i, 0, :, :] = train_y_t.reshape(N, N)
# Channel 1: normalized local Moran's I of each sample.
d1[:, 1, :, :] = batch_lw_tensor_local_moran(d1[:, 0, :, :].reshape(t, 1, N, N), w_sparse_32, norm_min_val=-1).reshape(t, N, N)

### (2) d2: Petrel grid (32x32)

Download and prepare data

In [None]:
# Download the petrel grids (a JSON list of 2-D arrays, reshaped to N x N below).
with urllib.request.urlopen("https://github.com/konstantinklemmer/sxl/raw/master/data/list_petrel.json") as url:
    train_y = np.array(json.loads(url.read().decode()))

N = 32
t = train_y.shape[0]
d2 = torch.zeros(t, 2, N, N)
# Channel 0: each grid min-max normalized to [-1, 1]. Iterate over every
# sample (range(t)); stopping at t - 1 would leave the last slot all zeros.
for i in range(t):
    d2[i, 0, :, :] = normal(torch.tensor(train_y[i, :, :]).reshape(-1)).reshape(N, N)
# Channel 1: normalized local Moran's I of each sample.
d2[:, 1, :, :] = batch_lw_tensor_local_moran(d2[:, 0, :, :].reshape(t, 1, N, N), w_sparse_32, norm_min_val=-1).reshape(t, N, N)

### (3) d3: DEM (32x32)

Download and prepare data

In [None]:
# Download the DEM tiles (a JSON list of 2-D arrays, reshaped to N x N below).
with urllib.request.urlopen("https://github.com/konstantinklemmer/sxl/raw/master/data/list_dem.json") as url:
    train_y = np.array(json.loads(url.read().decode()))

N = 32
t = train_y.shape[0]
d3 = torch.zeros(t, 2, N, N)
# Channel 0: each tile min-max normalized to [-1, 1]. Iterate over every
# sample (range(t)); stopping at t - 1 would leave the last slot all zeros.
for i in range(t):
    d3[i, 0, :, :] = normal(torch.tensor(train_y[i, :, :]).reshape(-1)).reshape(N, N)
# Channel 1: normalized local Moran's I of each sample.
d3[:, 1, :, :] = batch_lw_tensor_local_moran(d3[:, 0, :, :].reshape(t, 1, N, N), w_sparse_32, norm_min_val=-1).reshape(t, N, N)

### (4) d4: Tree canopy (64x64)



In [None]:
# Download the tree canopy tiles (a JSON list of 2-D arrays, reshaped to N x N below).
with urllib.request.urlopen("https://github.com/konstantinklemmer/sxl/raw/master/data/list_tree.json") as url:
    train_y = np.array(json.loads(url.read().decode()))

N = 64
t = train_y.shape[0]
# Only the first 1800 samples are kept for this experiment (NOTE(review): the
# reason for the cap is not stated here; presumably to bound memory/training
# time, so confirm). Only those samples are processed, every one of them
# (stopping at t - 1 could leave the last kept slot all zeros).
n_keep = min(t, 1800)
d4 = torch.zeros(n_keep, 2, N, N)
for i in range(n_keep):
    # Standardize, then min-max rescale to [-1, 1] with normal().
    standardized = StandardScaler().fit_transform(train_y[i, :, :].reshape(-1, 1))
    d4[i, 0, :, :] = normal(torch.tensor(standardized).reshape(-1)).reshape(N, N)
# Channel 1: normalized local Moran's I of each sample.
d4[:, 1, :, :] = batch_lw_tensor_local_moran(d4[:, 0, :, :].reshape(n_keep, 1, N, N), w_sparse_64, norm_min_val=-1).reshape(n_keep, N, N)

## Training


Define the model architectures for Discriminator (**D**) and Generator (**G**)

In [None]:
###
# VANILLA GAN
###

class Discriminator_VanillaGAN_MAT(nn.Module):
    """
        Simple Discriminator w/ MLP, augmented with Moran's Auxiliary Task (MAT).

        A shared MLP trunk (`layer`) embeds both the flattened input images and
        their local Moran's I maps; `output_t1` scores the images and
        `output_t2` scores the Moran's I maps, each through a sigmoid.

        Args:
            input_size: pixels per image. Defaults to N*N using the global N at
                *construction* time. A default written as `N*N` in the
                signature would be frozen when the class is defined and would
                not follow later changes of N.
            num_classes: output units of each head.
    """
    def __init__(self, input_size=None, num_classes=1):
        super(Discriminator_VanillaGAN_MAT, self).__init__()
        if input_size is None:
            input_size = N * N
        self.layer = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2)
        )
        self.output_t1 = nn.Sequential(
            nn.Linear(256, num_classes),
            nn.Sigmoid()
        )
        self.output_t2 = nn.Sequential(
            nn.Linear(256, num_classes),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Pick the weight matrix from the actual image size instead of the
        # global N; any other size is rejected explicitly rather than failing
        # later with an unbound local variable.
        side = x.size(-1)
        if side == 32:
            w_sparse = w_sparse_32
        elif side == 64:
            w_sparse = w_sparse_64
        else:
            raise ValueError("No spatial weight matrix for %dx%d inputs (supported: 32, 64)" % (side, side))
        # Moran's I is computed on CPU with the scipy weights, then moved back.
        mi_x = batch_lw_tensor_local_moran(x.detach().cpu(), w_sparse).to(x.device)
        y_ = self.layer(x.view(x.size(0), -1))
        mi_y_ = self.layer(mi_x.view(mi_x.size(0), -1))
        y_ = self.output_t1(y_)
        mi_y_ = self.output_t2(mi_y_)
        return y_, mi_y_

class Generator_VanillaGAN(nn.Module):
    """
        Simple Generator w/ MLP.

        Maps a latent vector through a BatchNorm MLP to `num_classes` pixels
        (tanh-bounded to [-1, 1]) and reshapes them to (batch, 1, side, side).

        Args:
            input_size: latent dimension.
            num_classes: output pixels; must be a perfect square side*side.
                Defaults to N*N using the global N at *construction* time. A
                signature default `N*N` would be frozen at class-definition
                time, and the reshape would break when N changes.

        Raises:
            ValueError: if `num_classes` is not a perfect square.
    """
    def __init__(self, input_size=100, num_classes=None):
        super(Generator_VanillaGAN, self).__init__()
        if num_classes is None:
            num_classes = N * N
        side = int(round(num_classes ** 0.5))
        if side * side != num_classes:
            raise ValueError("num_classes must be a perfect square (side*side pixels), got %d" % num_classes)
        self.side = side
        self.layer = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, num_classes),
            nn.Tanh()
        )

    def forward(self, x):
        y_ = self.layer(x)
        return y_.view(x.size(0), 1, self.side, self.side)

###
# DCGAN 
###

class Discriminator_DCGAN_MAT_32(nn.Module):
    """
        DeepConv Discriminator for 32x32 inputs with Moran's Auxiliary Task.

        A shared convolutional trunk (`conv`, two stride-2 conv blocks) embeds
        both the input batch and its local Moran's I maps. Each task then has
        its own head: `output_t1` + `fc_t1` score the inputs, and `output_t2` +
        `fc_t2` score the Moran's I maps, each ending in a sigmoid.
        `num_classes` and the `y` argument of `forward` are accepted but unused.
    """
    def __init__(self, in_channel=1, num_classes=1):
        super(Discriminator_DCGAN_MAT_32, self).__init__()

        def down(c_in, c_out):
            # 3x3 conv with stride 2 halves the spatial size.
            return [
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]

        # 32 -> 16 -> 8 in the trunk, -> 4 in each head, then 4x4 average pooling -> 1x1.
        self.conv = nn.Sequential(*down(in_channel, 512), *down(512, 256))
        self.output_t1 = nn.Sequential(*down(256, 128), nn.AvgPool2d(4))
        self.output_t2 = nn.Sequential(*down(256, 128), nn.AvgPool2d(4))
        self.fc_t1 = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())
        self.fc_t2 = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())

    def _score(self, inp, head, fc):
        """Run `inp` through the shared trunk, one task head and its classifier."""
        features = head(self.conv(inp))
        return fc(features.view(features.size(0), -1))

    def forward(self, x, y=None):
        # Local Moran's I is computed on CPU with the sparse weights, then moved to DEVICE.
        mi_x = batch_lw_tensor_local_moran(x.detach().cpu(), w_sparse_32).to(DEVICE)
        return self._score(x, self.output_t1, self.fc_t1), self._score(mi_x, self.output_t2, self.fc_t2)

class Generator_DCGAN_32(nn.Module):
    """
        Convolutional Generator producing 32x32 images.

        A linear layer maps the latent vector to a 512 x 4 x 4 feature map.
        Three stride-2 transposed convolutions then upsample it
        4 -> 8 -> 16 -> 32, and a final tanh bounds the output to [-1, 1].
        The `y` argument of `forward` is accepted but unused.
    """
    def __init__(self, input_size=100, out_channel=1):
        super(Generator_DCGAN_32, self).__init__()
        self.fc = nn.Sequential(nn.Linear(input_size, 4 * 4 * 512), nn.ReLU())
        layers = []
        for c_in, c_out in ((512, 256), (256, 128)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        layers += [
            nn.ConvTranspose2d(128, out_channel, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x, y=None):
        seed_maps = self.fc(x.view(x.size(0), -1))
        return self.conv(seed_maps.view(seed_maps.size(0), 512, 4, 4))

class Discriminator_DCGAN_MAT_64(nn.Module):
    """
        Convolutional Discriminator for 64x64 inputs with Moran's Auxiliary Task.

        A shared trunk (`conv`, three stride-2 conv blocks: 64 -> 32 -> 16 -> 8)
        embeds both the inputs and their local Moran's I maps. Each task has its
        own head (one more stride-2 conv block plus global average pooling),
        followed by a sigmoid classifier (`fc_t1` for the inputs, `fc_t2` for the
        Moran's I maps). `num_classes` and the `y` argument are unused.
    """
    def __init__(self, in_channel=1, num_classes=1):
        super(Discriminator_DCGAN_MAT_64, self).__init__()
        trunk = []
        for c_in, c_out in ((in_channel, 512), (512, 256), (256, 128)):
            trunk.extend([
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        self.conv = nn.Sequential(*trunk)
        self.output_t1 = self._make_head()
        self.output_t2 = self._make_head()
        self.fc_t1 = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())
        self.fc_t2 = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())

    @staticmethod
    def _make_head():
        """Task head: stride-2 conv block followed by global average pooling."""
        return nn.Sequential(
            nn.Conv2d(128, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
        )

    def forward(self, x, y=None):
        # Local Moran's I is computed on CPU with the sparse weights, then moved to DEVICE.
        mi_x = batch_lw_tensor_local_moran(x.detach().cpu(), w_sparse_64)
        mi_x = mi_x.to(DEVICE)
        branches = (
            (x, self.output_t1, self.fc_t1),
            (mi_x, self.output_t2, self.fc_t2),
        )
        y_, mi_y_ = [fc(head(self.conv(inp)).flatten(1)) for inp, head, fc in branches]
        return y_, mi_y_

class Generator_DCGAN_64(nn.Module):
    """
        Convolutional Generator.

        A linear layer maps the latent vector to a 512-channel map of
        (height/16, width/16). A same-size 3x3 conv block and four stride-2
        transposed convolutions then upsample it by 16x, and a final tanh
        bounds the output to [-1, 1].

        Args:
            out_channel: output image channels.
            input_size: latent dimension.
            image_dim: (height, width, ...) of the generated images. Defaults to
                the global IMAGE_DIM, read at construction time.

        Raises:
            ValueError: if height or width is not divisible by 16 (the
                upsampling factor). This was previously an `assert`, which is
                skipped under `python -O`, and it only checked the height.
    """
    def __init__(self, out_channel=1, input_size=100, image_dim=None):
        super(Generator_DCGAN_64, self).__init__()
        if image_dim is None:
            image_dim = IMAGE_DIM
        height, width = image_dim[0], image_dim[1]
        if height % 2**4 != 0 or width % 2**4 != 0:
            raise ValueError("Image height and width should be divisible by 16, got %dx%d" % (height, width))
        self.init_dim = (height // 2**4, width // 2**4)
        self.fc = nn.Sequential(
            nn.Linear(input_size, self.init_dim[0]*self.init_dim[1]*512),
            nn.ReLU(),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, out_channel, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x, y=None):
        x = x.view(x.size(0), -1)
        y_ = self.fc(x)
        y_ = y_.view(y_.size(0), 512, self.init_dim[0], self.init_dim[1])
        y_ = self.conv(y_)
        return y_

###
# EDGAN
###

class Discriminator_EDGAN_MAT_32(nn.Module):
    """
        Convolutional Discriminator for 32x32 inputs with Moran's Auxiliary Task.

        A shared trunk (`conv`) of three stride-2 4x4 conv blocks
        (ndf1, 2*ndf1, 4*ndf1 channels; 32 -> 16 -> 8 -> 4) embeds both the
        inputs and their local Moran's I maps. Two separate 4x4 conv + sigmoid
        heads (`output_t1` for the inputs, `output_t2` for the Moran's I maps)
        reduce each 4x4 map to one score. The `y` argument is unused.
    """
    def __init__(self, nc=1, ndf1=32):
        super(Discriminator_EDGAN_MAT_32, self).__init__()
        widths = [nc, ndf1, ndf1 * 2, ndf1 * 4]
        trunk = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            trunk += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.conv = nn.Sequential(*trunk)
        self.output_t1 = nn.Sequential(
            nn.Conv2d(ndf1 * 4, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        )
        self.output_t2 = nn.Sequential(
            nn.Conv2d(ndf1 * 4, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x, y=None):
        # Local Moran's I is computed on CPU with the sparse weights, then moved to DEVICE.
        mi_x = batch_lw_tensor_local_moran(x.detach().cpu(), w_sparse_32).to(DEVICE)
        scores = []
        for inp, head in ((x, self.output_t1), (mi_x, self.output_t2)):
            score = head(self.conv(inp))
            scores.append(score.view(score.size(0), -1))
        return tuple(scores)

class Generator_EDGAN_32(nn.Module):
    """
        Encoder-Decoder Generator.

        `fc` projects the latent vector to N*N values (N is the global grid
        size, read at construction and in `forward`), which are reshaped to a
        1 x N x N image. `encoder` downsamples it with three stride-2 4x4 convs
        (ngf, 2*ngf, 4*ngf channels) and `decoder` upsamples it back with three
        stride-2 transposed convs, ending in tanh. The `y` argument is unused.

        NOTE(review): the default `ngf=N` is evaluated once, when this class
        statement runs, not when an instance is created. `init_dim` is
        computed from the global IMAGE_DIM but not used by this class.
    """
    def __init__(self, input_size=100, nc=1, ngf=N):
        super(Generator_EDGAN_32, self).__init__()
        assert IMAGE_DIM[0] % 2**4 == 0, 'Should be divided 16'
        self.init_dim = tuple(side // 2**4 for side in IMAGE_DIM[:2])
        self.fc = nn.Sequential(nn.Linear(input_size, N * N), nn.ReLU())

        widths = (nc, ngf, ngf * 2, ngf * 4)
        encoder_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = []
        for c_in, c_out in ((ngf * 4, ngf * 2), (ngf * 2, ngf)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, y=None):
        canvas = self.fc(x.view(x.size(0), -1))
        canvas = canvas.view(canvas.size(0), 1, N, N)
        return self.decoder(self.encoder(canvas))

class Discriminator_EDGAN_MAT_64(nn.Module):
    """
        Convolutional Discriminator for 64x64 inputs with Moran's Auxiliary Task.

        A shared trunk (`conv`) of four stride-2 4x4 convs (ndf, 2*ndf, 4*ndf,
        8*ndf channels; 64 -> 32 -> 16 -> 8 -> 4, BatchNorm on all but the
        first) embeds both the inputs and their local Moran's I maps. Two 4x4
        conv + sigmoid heads (`output_t1` for the inputs, `output_t2` for the
        Moran's I maps) reduce each 4x4 map to one score. `y` is unused.
    """
    def __init__(self, nc=1, ndf=64):
        super(Discriminator_EDGAN_MAT_64, self).__init__()
        trunk = [
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in ((ndf, ndf * 2), (ndf * 2, ndf * 4), (ndf * 4, ndf * 8)):
            trunk += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.conv = nn.Sequential(*trunk)
        self.output_t1 = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        )
        self.output_t2 = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x, y=None):
        # Local Moran's I is computed on CPU with the sparse weights, then moved to DEVICE.
        mi_x = batch_lw_tensor_local_moran(x.detach().cpu(), w_sparse_64).to(DEVICE)
        image_score = self.output_t1(self.conv(x))
        moran_score = self.output_t2(self.conv(mi_x))
        return image_score.flatten(1), moran_score.flatten(1)

class Generator_EDGAN_64(nn.Module):
    """
        Encoder-Decoder Generator.

        `fc` projects the latent vector to N*N values (N is the global grid
        size, read at construction and in `forward`), which are reshaped to a
        1 x N x N image. `encoder` downsamples it with three stride-2 4x4 convs
        (ngf, 2*ngf, 4*ngf channels) and `decoder` upsamples it back with three
        stride-2 transposed convs, ending in tanh. The `y` argument is unused.

        NOTE(review): the default `ngf=N` is evaluated once, when this class
        statement runs, not when an instance is created. `init_dim` is
        computed from the global IMAGE_DIM but not used by this class.
    """
    def __init__(self, input_size=100, nc=1, ngf=N):
        super(Generator_EDGAN_64, self).__init__()
        assert IMAGE_DIM[0] % 2**4 == 0, 'Should be divided 16'
        self.init_dim = tuple(side // 2**4 for side in IMAGE_DIM[:2])
        self.fc = nn.Sequential(nn.Linear(input_size, N * N), nn.ReLU())

        widths = (nc, ngf, ngf * 2, ngf * 4)
        encoder_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = []
        for c_in, c_out in ((ngf * 4, ngf * 2), (ngf * 2, ngf)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, y=None):
        canvas = self.fc(x.view(x.size(0), -1))
        canvas = canvas.view(canvas.size(0), 1, N, N)
        return self.decoder(self.encoder(canvas))

Define training configuration:

- `experiment`: Dataset (Toy="d1", PetrelGrid="d2", DEM="d3", TreeCanopy="d4")
- `train_split`: fraction of the data used for training (e.g. `0.8` = 80%); the remainder is held out, e.g. for evaluation
- `model`: MAT-augmented GAN model to be used; the available models are a VanillaGAN, a DCGAN and an EDGAN
- `batch_size`: training batch size
- `lambda_`: weight of the auxiliary-task loss. We experimented (and got good results) with values `[0.01, 0.1, 1]`
- `loss_method`: standard binary cross-entropy loss (`"N"`) or Wasserstein loss (`"W"`)

In [None]:
from decimal import Decimal, getcontext
from torch.utils.data import TensorDataset, DataLoader

# NOTE: this sets the precision of the global decimal context to 3
# significant digits; every later Decimal computation in this session
# (including the train/test split arithmetic) is rounded accordingly.
getcontext().prec = 3
torch.manual_seed(99)

### DEFINE EXPERIMENT SETTINGS ###
experiment = "d2"           # dataset: one of "d1", "d2", "d3", "d4"
model = "EDGAN_MAT"         # choose from ["VanillaGAN_MAT","DCGAN_MAT","EDGAN_MAT"]
loss_method = "N"           # "N" = standard (BCE) loss, "W" = Wasserstein loss
train_split = Decimal(0.8)  # fraction of samples used for training (80%)
batch_size = 32             # training batch size
lambda_ = 0.1               # weight of the Moran's Auxiliary Task loss
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # train on GPU when available
###
###

Define training loop and run the model!

In [None]:
## PREPARATION
# Select the dataset tensor named in the configuration cell. A globals()
# lookup gives the same tensor as eval(experiment) for a valid name without
# executing an arbitrary string; an unknown name raises KeyError.
data = globals()[experiment]
test_split = Decimal(1 - train_split)
n = data.shape[0]
N = data.shape[3]
IMAGE_DIM = (N, N, 1)
# Split sizes: floor of the training fraction in ordinary float arithmetic,
# with the remainder held out, so the two lengths always sum to n as
# random_split requires. (Truncating two Decimal products computed under the
# 3-digit context separately could give lengths that do not add up to n.)
n_train = int(n * float(train_split))
train_set, val_set = torch.utils.data.random_split(data, [n_train, n - n_train])
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, drop_last=True)

# Training epochs per dataset.
num_epochs = {"d1": 40, "d2": 300, "d3": 100, "d4": 100}[experiment]

# Model. Sizes are passed explicitly to the VanillaGAN classes so they match
# the current grid, not whatever N was when the classes were defined.
if model == "VanillaGAN_MAT":
    D = Discriminator_VanillaGAN_MAT(input_size=N * N).to(DEVICE)
    G = Generator_VanillaGAN(num_classes=N * N).to(DEVICE)
else:
    architectures = {
        ("DCGAN_MAT", 32): (Discriminator_DCGAN_MAT_32, Generator_DCGAN_32),
        ("DCGAN_MAT", 64): (Discriminator_DCGAN_MAT_64, Generator_DCGAN_64),
        ("EDGAN_MAT", 32): (Discriminator_EDGAN_MAT_32, Generator_EDGAN_32),
        ("EDGAN_MAT", 64): (Discriminator_EDGAN_MAT_64, Generator_EDGAN_64),
    }
    if (model, N) not in architectures:
        raise ValueError("Unsupported model/grid size combination: %s with N=%d" % (model, N))
    D_class, G_class = architectures[(model, N)]
    D = D_class().to(DEVICE)
    G = G_class().to(DEVICE)

# Labels: real samples target 0.9 (label smoothing), generated samples target 0.
D_labels = torch.ones([batch_size, 1]).to(DEVICE) - 0.1
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE)
# Optimizer
criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))
# Other
g_step = 0
step = 0
n_noise = 100

### TRAINING
for e in range(num_epochs):
    for x_batch in train_loader:
        # Channel 0 holds the images. Channel 1 (precomputed Moran's I) is not
        # used: the MAT discriminators compute Moran's I of their inputs.
        x = x_batch[:, 0, :, :].reshape(batch_size, 1, N, N).to(DEVICE)

        # --- Train Discriminator ---
        x_outputs, mi_x_outputs = D(x)
        z = torch.randn(batch_size, n_noise).to(DEVICE)
        # Detach: the discriminator update needs no gradients with respect to
        # G's parameters, so don't backpropagate through G here.
        z_gen = G(z).detach()
        z_outputs, mi_z_outputs = D(z_gen)
        if loss_method == "W":
            D_x_loss = torch.mean(x_outputs)
            D_z_loss = torch.mean(z_outputs)
            D_mi_x_loss = torch.mean(mi_x_outputs)
            D_mi_z_loss = torch.mean(mi_z_outputs)
            D_loss = (D_z_loss - D_x_loss) + lambda_ * (D_mi_z_loss - D_mi_x_loss)
        else:
            D_x_loss = criterion(x_outputs, D_labels)
            D_z_loss = criterion(z_outputs, D_fakes)
            D_mi_x_loss = criterion(mi_x_outputs, D_labels)
            D_mi_z_loss = criterion(mi_z_outputs, D_fakes)
            D_loss = D_x_loss + D_z_loss + lambda_ * (D_mi_x_loss + D_mi_z_loss)
        D.zero_grad()
        D_loss.backward()
        D_opt.step()
        if loss_method == "W":
            # Weight clipping for the Wasserstein critic.
            for p in D.parameters():
                p.data.clamp_(-0.01, 0.01)

        # --- Train Generator ---
        g_step += 1
        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_gen = G(z)
        z_outputs, mi_z_outputs = D(z_gen)
        if loss_method == "W":
            G_loss = -torch.mean(z_outputs)
        else:
            G_z_loss = criterion(z_outputs, D_labels)
            G_loss = G_z_loss
        G.zero_grad()
        G_loss.backward()
        G_opt.step()

        # Print progress
        if step % 100 == 0:
            print("Epoch [%d/%d] - G Loss: %f - D Loss: %f" % (e, num_epochs, G_loss.item(), D_loss.item()))
        step = step + 1

Plot a 5×5 grid of example images sampled from the trained generator:

In [None]:
def get_sample_image(G, n_noise):
    """
        Sample 25 images from generator G and tile them into a 5x5 mosaic.

        Args:
            G: generator mapping a (25, n_noise) latent batch to images whose
               last two dimensions are (height, width) with one channel.
            n_noise: latent dimension.

        Returns:
            float64 numpy array of shape (5*height, 5*width); the tile in
            mosaic row j, column k holds sample j*5 + k.
    """
    first_param = next(G.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        # Draw the noise on the CPU generator, then move it (same RNG stream
        # as sampling on CPU and calling .to(device)).
        z = torch.randn(25, n_noise).to(device)
        samples = G(z)
    height, width = samples.shape[-2], samples.shape[-1]
    result = samples.reshape(25, height, width).cpu().numpy()
    # (25, H, W) -> (5 rows, 5 cols, H, W) -> (rows, H, cols, W) -> mosaic.
    mosaic = result.reshape(5, 5, height, width).transpose(0, 2, 1, 3)
    return mosaic.reshape(5 * height, 5 * width).astype(np.float64)

# Switch the trained generator to inference mode (e.g. BatchNorm uses its
# running statistics), then display a 5x5 mosaic of samples.
G.eval()
sample_mosaic = get_sample_image(G, n_noise)
plt.imshow(sample_mosaic)