# Example 4: Learning auxiliary task loss weights
## Setup

Install / load required dependencies.

In [None]:
import os
import datetime
import sys
import requests
from urllib.request import urlretrieve
import urllib.request, json 

import torch
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import numpy as np
import pandas as pd
import scipy
from decimal import Decimal, getcontext

import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler 
from sklearn.preprocessing import normalize
from sklearn import metrics

Check GPU device availability.

In [None]:
# Show the name of the first CUDA device. On machines without CUDA,
# torch.cuda.get_device_name raises, so fall back to a plain message instead
# (training further below can still run on the CPU via DEVICE).
torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No CUDA device available - running on CPU"

Helper functions, as found in `src/utils.py`.

In [None]:
#Normalize data values. Default is [-1,1] range (min_val = -1); if min_val = 0, range is [0,1]
def normal(tensor,min_val=-1):
  """Min-max rescale `tensor` to [-1,1] (min_val=-1) or [0,1] (min_val=0).

  Args:
    tensor: torch tensor to rescale (any shape).
    min_val: lower bound of the target range; only -1 and 0 are supported.

  Returns:
    A new, detached tensor of the same shape. Constant inputs (max == min,
    including all-zero inputs) have no range to rescale and yield zeros.

  Raises:
    ValueError: if `min_val` is neither -1 nor 0.
  """
  if min_val not in (-1, 0):
    raise ValueError("min_val must be -1 (range [-1,1]) or 0 (range [0,1])")
  t_min = torch.min(tensor)
  t_max = torch.max(tensor)
  if t_min == t_max:
    # Avoid 0/0 -> NaN when every value is identical
    return torch.zeros_like(tensor)
  scaled = (tensor - t_min) / (t_max - t_min)
  if min_val == -1:
    scaled = 2 * scaled - 1
  # `scaled` is already a fresh tensor; detach instead of copy-constructing it again
  return scaled.detach()

#Light-weight Local Moran's I for tensor data, requiring a sparse weight matrix input. 
#This can be used when there is no need to re-compute the weight matrix at each step
def lw_tensor_local_moran(y,w_sparse,na_to_zero=True,norm=True,norm_min_val=-1):
  """Compute the local Moran's I of every element of `y`.

  Args:
    y: CPU tensor without gradient tracking; it is flattened to n values.
    w_sparse: (n, n) scipy sparse spatial weight matrix.
    na_to_zero: replace NaN statistics (e.g. from a constant input) with 0.
    norm: rescale the result with `normal`.
    norm_min_val: `min_val` forwarded to `normal` when `norm` is set.

  Returns:
    1-D tensor with n entries: (n-1) * z_i * (W z)_i / sum(z^2), where z is
    the standardized input.
  """
  y = y.reshape(-1)
  n_1 = len(y) - 1
  z = (y - y.mean()) / y.std()
  den = (z * z).sum()
  # Spatial lag W z; `@` is a matrix product for sparse matrices and sparse arrays alike
  zl = torch.as_tensor(w_sparse @ z.numpy())
  mi = n_1 * z * zl / den
  if na_to_zero:
    mi[torch.isnan(mi)] = 0
  if norm:
    mi = normal(mi,min_val=norm_min_val)
  # `mi` is freshly computed above, so no extra torch.tensor() copy is needed
  return mi

#Batch version of lw_tensor_local_moran
#Computes the (normalized) local Moran's I for an input batch
def batch_lw_tensor_local_moran(y_batch,w_sparse,na_to_zero=True,norm=True,norm_min_val=-1):
  """Apply `lw_tensor_local_moran` to every sample of a batch.

  Args:
    y_batch: 4-D tensor indexed as (batch, channel, height, width); each
      sample must hold height*width values (i.e. a single channel).
    w_sparse: sparse spatial weight matrix matching the flattened grid.
    na_to_zero, norm, norm_min_val: forwarded to `lw_tensor_local_moran`.

  Returns:
    Zero-initialized tensor shaped like `y_batch` whose channel 0 holds the
    local Moran's I map of each sample.
  """
  height, width = y_batch.shape[2], y_batch.shape[3]
  mi_y_batch = torch.zeros(y_batch.shape)
  for i, y in enumerate(y_batch):
    # Delegate to the single-sample routine rather than duplicating its body
    mi = lw_tensor_local_moran(y,w_sparse,na_to_zero=na_to_zero,norm=norm,norm_min_val=norm_min_val)
    mi_y_batch[i,0,:,:] = mi.reshape(height,width)
  return mi_y_batch

#2x2 average pooling used to build the coarser grids for the multi-res Moran's I
downsample = nn.AvgPool2d(kernel_size=2, stride=2)

Load sparse spatial weight matrices

In [None]:
%%capture

# Fetch the pre-computed sparse weight matrices; each file is saved under its remote file name
for _url in (
    'https://github.com/konstantinklemmer/sxl/raw/master/data/w/w_sparse_64.npz',
    'https://github.com/konstantinklemmer/sxl/raw/master/data/w/w_sparse_32.npz',
    'https://github.com/konstantinklemmer/sxl/raw/master/data/w/w_sparse_16.npz',
    'https://github.com/konstantinklemmer/sxl/raw/master/data/w/w_sparse_8.npz',
    'https://github.com/konstantinklemmer/sxl/raw/master/data/w/w_sparse_4.npz',
):
    urlretrieve(_url, _url.rsplit('/', 1)[-1])

In [None]:
# Load the downloaded weight matrices, one per grid resolution
w_sparse_64, w_sparse_32, w_sparse_16, w_sparse_8, w_sparse_4 = (
    scipy.sparse.load_npz(_fname)
    for _fname in (
        'w_sparse_64.npz',
        'w_sparse_32.npz',
        'w_sparse_16.npz',
        'w_sparse_8.npz',
        'w_sparse_4.npz',
    )
)

## Data

As customary for GAN training, data is normalized in the range `[-1,1]`. The local Moran's I of the data can be computed at this step already, to avoid further computational burden during training.


### Petrel grid (32x32)

Download and prepare data. We will work with the *PetrelGrid* dataset here.

In [None]:
# Download the PetrelGrid samples and parse the JSON payload into an array
with urllib.request.urlopen("https://github.com/konstantinklemmer/sxl/raw/master/data/list_petrel.json") as url:
    train_y = np.array(json.loads(url.read().decode()))

N = 32
t = train_y.shape[0]
# Channel 0 holds the normalized sample, channel 1 its local Moran's I
data = torch.zeros(t,2,N,N)
# Iterate over ALL t samples; range(t-1) left the last sample as an all-zero grid
for i in range(t):
    train_y_t = torch.tensor(train_y[i,:,:])
    # normal() already returns a tensor, so no extra torch.tensor() wrapper is needed
    train_y_t = normal(train_y_t.reshape(-1))
    data[i,0,:,:] = train_y_t.reshape(N,N)
data[:,1,:,:] = batch_lw_tensor_local_moran(data[:,0,:,:].reshape(t,1,N,N),w_sparse_32,norm_min_val=-1).reshape(t,N,N)

## Training


Define the model architectures for Discriminator (**D**) and Generator (**G**). In this example we use an EDGAN with MRES-MAT.

In [None]:
###
# EDGAN
###

class Discriminator_EDGAN_MRES_MAT_32(nn.Module):
    """
        Convolutional Discriminator with multi-resolution Moran's I auxiliary tasks.

        A shared convolutional trunk (`conv`) feeds four sigmoid heads:
        `output_t1` scores the input itself, while `output_t2`..`output_t4`
        score the local Moran's I maps of the input at full resolution and
        after one and two rounds of 2x2 average pooling.

        Args:
            nc: number of input channels.
            ndf1: number of feature maps of the first convolution.
    """
    def __init__(self,nc=1,ndf1=32):
        super(Discriminator_EDGAN_MRES_MAT_32,self).__init__()
        # Three stride-2 convolutions: each halves the spatial size
        self.conv = nn.Sequential(
            nn.Conv2d(nc,ndf1,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(ndf1),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf1,ndf1*2,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(ndf1*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf1*2,ndf1*4,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(ndf1*4),
            nn.LeakyReLU(0.2,inplace=True)
        )
        # One identical head per task; creation order is kept (t1..t4)
        self.output_t1 = self._make_head(ndf1*4)
        self.output_t2 = self._make_head(ndf1*4)
        self.output_t3 = self._make_head(ndf1*4)
        self.output_t4 = self._make_head(ndf1*4)

    @staticmethod
    def _make_head(in_channels):
        """Build one task head: a 4x4 convolution to a single channel followed by a sigmoid."""
        return nn.Sequential(
            nn.Conv2d(in_channels,1,kernel_size=4,stride=1,padding=0),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Score `x` and its multi-resolution Moran's I maps.

        Args:
            x: input batch; the weight matrices w_sparse_32 / w_sparse_16 /
                w_sparse_8 are applied to the full, once- and twice-pooled grids.
            y: unused.

        Returns:
            Tuple `(y_, mi_y_, mi_y_d1, mi_y_d2)` of per-sample scores, each
            flattened to (batch, -1).
        """
        # Follow the input's device instead of relying on a global DEVICE
        device = x.device
        x_d1 = downsample(x)
        x_d2 = downsample(x_d1)
        # Moran's I is computed on detached CPU copies, so no gradient flows through these maps
        mi_x = batch_lw_tensor_local_moran(x.detach().cpu(),w_sparse_32).to(device)
        mi_x_d1 = batch_lw_tensor_local_moran(x_d1.detach().cpu(),w_sparse_16).to(device)
        mi_x_d2 = batch_lw_tensor_local_moran(x_d2.detach().cpu(),w_sparse_8).to(device)
        # Upsample the coarse maps back to the input resolution so the shared trunk can process them
        mi_x_d1 = nn.functional.interpolate(mi_x_d1,scale_factor=2,mode="nearest")
        mi_x_d2 = nn.functional.interpolate(mi_x_d2,scale_factor=4,mode="nearest")
        y_ = self.output_t1(self.conv(x))
        mi_y_ = self.output_t2(self.conv(mi_x))
        mi_y_d1 = self.output_t3(self.conv(mi_x_d1))
        mi_y_d2 = self.output_t4(self.conv(mi_x_d2))
        y_ = y_.view(y_.size(0), -1)
        mi_y_ = mi_y_.view(mi_y_.size(0), -1)
        mi_y_d1 = mi_y_d1.view(mi_y_d1.size(0), -1)
        mi_y_d2 = mi_y_d2.view(mi_y_d2.size(0), -1)
        return y_, mi_y_, mi_y_d1, mi_y_d2

class Generator_EDGAN_32(nn.Module):
    """
        Encoder-Decoder Generator

        A fully connected layer maps the noise vector to an N x N single-channel
        grid, which is encoded by three stride-2 convolutions and decoded by
        three stride-2 transposed convolutions ending in Tanh.
        `N` and `IMAGE_DIM` are read from module globals; the default of `ngf`
        is bound to `N` when the class is defined.
    """
    def __init__(self, input_size=100, nc=1, ngf=N):
        super().__init__()
        assert IMAGE_DIM[0] % 16 == 0, 'Should be divided 16'
        self.init_dim = (IMAGE_DIM[0] // 16, IMAGE_DIM[1] // 16)

        def down(c_in, c_out):
            # Strided convolution block: halves the spatial size
            return [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        def up(c_in, c_out):
            # Transposed convolution block: doubles the spatial size
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        # Noise vector -> flat N*N grid
        self.fc = nn.Sequential(nn.Linear(input_size, N*N), nn.ReLU())
        self.encoder = nn.Sequential(
            *down(nc, ngf),
            *down(ngf, ngf*2),
            *down(ngf*2, ngf*4),
        )
        self.decoder = nn.Sequential(
            *up(ngf*4, ngf*2),
            *up(ngf*2, ngf),
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x, y=None):
        # Flatten the noise, project it to a grid, then encode and decode it (y is unused)
        flat = x.view(x.size(0), -1)
        grid = self.fc(flat)
        grid = grid.view(grid.size(0), 1, N, N)
        return self.decoder(self.encoder(grid))

Define training configuration:

- `train_split`: % of data to use for training (in case held-out data is needed for evaluation)
- `batch_size`: training batch size
- `num_epochs`: number of training epochs 

In [None]:
# Fixed seed for reproducibility; Decimal arithmetic is limited to 3 significant digits
torch.manual_seed(99)
getcontext().prec = 3

### DEFINE EXPERIMENT SETTINGS ###
# Train on the first GPU when available, otherwise on the CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
num_epochs = 500 # Number of training epochs
batch_size = 32 # Training batch size
train_split = Decimal(0.8) # Share of the data used for training (80%)
###

Define the training loop and train the model! To learn $\lambda$ we need to tweak the Discriminator loss function in order to avoid extreme values and allow convergence:

$$\min_G \max_D \mathcal{L}_{MRES-MAT}(D,G) = \mathcal{L}_{GAN}(D,G) + \frac{1}{2\exp(\lambda)} \left(\mathcal{L}_{AT_{1}}^{(D)} + \dots + \mathcal{L}_{AT_{N}}^{(D)}\right) + \lambda$$

We can then add $\lambda$ to the parameter list of the optimizer and learn it throughout training.

In [None]:
#Prepare input
test_split = Decimal(1 - train_split)
n = data.shape[0]
N = data.shape[3]
IMAGE_DIM = (N,N,1)
#Size the test split as the remainder: rounding n*train_split and n*test_split
#independently can miss n, in which case random_split raises
n_train = int(n * train_split)
n_test = n - n_train
train_set, test_set = torch.utils.data.random_split(data, [n_train, n_test])
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,drop_last=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True,drop_last=True)
#Set Discriminator and Generator
D = Discriminator_EDGAN_MRES_MAT_32().to(DEVICE)
G = Generator_EDGAN_32().to(DEVICE)
#Prepare labels
D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real
D_labels = D_labels - 0.1 #Real target becomes 0.9; this can be skipped
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake
#Prepare training
criterion = nn.BCELoss() #Binary cross entropy loss
#Initiate lambda on the training device (a hard-coded "cuda" breaks the CPU fallback)
lamb = torch.tensor([0.1],requires_grad=True, device=DEVICE)
lr = 0.001 #Set learning rate
#Set Parameters
#NOTE: D's parameters live under `conv` and `output_t1..4`, so the 'classifier'
#groups below stay empty; the grouping is kept to preserve the optimizer layout
weight_list = []
bias_list = []
last_weight_list = []
last_bias_list = []
loss_weight_list = [lamb]
for name, value in D.named_parameters():
    if 'classifier' in name:
        if 'weight' in name:
            last_weight_list.append(value)
        elif 'bias' in name:
            last_bias_list.append(value)
    else:
        if 'weight' in name:
            weight_list.append(value)
        elif 'bias' in name:
            bias_list.append(value)
#Initiate optimizer (lambda is learned together with the Discriminator)
D_opt = torch.optim.Adam([{'params': weight_list, 'lr': lr},
                              {'params': bias_list, 'lr': lr},
                              {'params': last_weight_list, 'lr': lr},
                              {'params': last_bias_list, 'lr': lr},
                              {'params': loss_weight_list, 'lr': lr}],betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
# Utilities
step = 0
n_noise = 100
loss_d = []
loss_g = []
lambdas = []
### TRAINING
for e in range(num_epochs):
    # Within each iteration, we will go over each minibatch of data
    for x_batch in train_loader:
        # Get data (channel 0 holds the normalized sample)
        x = x_batch[:,0,:,:]
        x = x.reshape(batch_size,1,N,N).to(DEVICE)
        ### Training Discriminator
        x_outputs, mi_x_outputs, mi_x_d1_outputs, mi_x_d2_outputs = D(x)
        z = torch.randn(batch_size, n_noise).to(DEVICE)
        # Detach the fake batch: the Discriminator step does not need gradients through G
        z_gen = G(z).detach()
        z_outputs, mi_z_outputs, mi_z_d1_outputs, mi_z_d2_outputs = D(z_gen)
        D_x_loss = criterion(x_outputs, D_labels)
        D_z_loss = criterion(z_outputs, D_fakes)
        D_mi_x_loss = criterion(mi_x_outputs, D_labels)
        D_mi_z_loss = criterion(mi_z_outputs, D_fakes)
        D_mi_x_d1_loss = criterion(mi_x_d1_outputs, D_labels)
        D_mi_z_d1_loss = criterion(mi_z_d1_outputs, D_fakes)
        D_mi_x_d2_loss = criterion(mi_x_d2_outputs, D_labels)
        D_mi_z_d2_loss = criterion(mi_z_d2_outputs, D_fakes)
        #Discriminator loss: L_GAN + 1/(2*exp(lambda)) * (sum of auxiliary losses) + lambda
        lambda_ = torch.exp(lamb)
        aux_loss = (D_mi_x_loss + D_mi_z_loss + D_mi_x_d1_loss + D_mi_z_d1_loss
                    + D_mi_x_d2_loss + D_mi_z_d2_loss)
        D_loss = D_x_loss + D_z_loss + (1 / (2 * lambda_)) * aux_loss + lamb
        # Zero via the optimizer so lamb's gradient is cleared too
        # (D.zero_grad() only resets the module's own parameters)
        D_opt.zero_grad()
        D_loss.backward()
        D_opt.step()
        ### Train Generator
        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_gen = G(z)
        z_outputs, mi_z_outputs, mi_z_d1_outputs, mi_z_d2_outputs = D(z_gen)
        G_z_loss = criterion(z_outputs, D_labels)
        #Generator loss
        G_loss = G_z_loss
        G_opt.zero_grad()
        G_loss.backward()
        G_opt.step()
        step = step + 1
        #Save losses / lambdas as plain floats (keeps no autograd graph alive, plots directly)
        loss_d.append(D_loss.item())
        loss_g.append(G_loss.item())
        lambdas.append(lambda_.item())
        #Print progress
        if step % 250 == 0:
            print('Epoch: [%d/%d] - G Loss: %f - D Loss: %f - Lambda: %f' % (e+1, num_epochs, G_loss.item(), D_loss.item(), lambda_.item()))

Plot losses and lambda throughout training.

In [None]:
# Convert the recorded exp(lambda) values to plain floats first: one-element
# tensors (possibly on the GPU or attached to the autograd graph) cannot be
# turned into a NumPy array by matplotlib; float() handles both tensors and numbers
_lambda_values = [float(lam) for lam in lambdas]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 3))
# Left panel: Generator and Discriminator loss per training step
ax1.plot(loss_g, "orange",alpha=.65)
ax1.plot(loss_d, "green",alpha=.65)
ax1.set_ylim(0,15)
ax1.set_title("Losses", fontsize=15, fontweight='bold')
ax1.legend(('G Loss', 'D Loss'),loc='upper left')

# Right panel: learned auxiliary-loss weight exp(lambda) per training step
ax2.plot(_lambda_values, "blue",alpha=.8,label = r'$exp(\lambda)$')
ax2.set_ylim(0.2,1.5)
ax2.set_title("Lambda", fontsize=15, fontweight='bold')
ax2.legend(loc='upper left')