In [1]:
import random, os, torch, numpy as np


def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
#seed_everything()

In [2]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import defaultdict
import pandas as pd
#import numpy as np
import torch.fft
import subprocess
import logging
#import random
import shutil
import psutil
import sklearn
import scipy
#import torch
import copy
import yaml
import time
import tqdm
import sys
import gc

import segmentation_models_pytorch as smp

from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
#import torch

#from holodecml.data import PickleReader, UpsamplingReader, XarrayReader, XarrayReaderLabels
#from holodecml.propagation import InferencePropagator
from holodecml.transforms import LoadTransformations
from holodecml.models import load_model
from holodecml.losses import load_loss

#import os
import warnings
warnings.filterwarnings("ignore")
import lpips

In [3]:
import sklearn, sklearn.metrics

def man_metrics(results):
    result = {}
    for metric in ["f1", "auc", 'pod', "far", "csi"]: #"man_prec", "man_recall",
        if metric == 'f1':
            score = sklearn.metrics.f1_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'prec':
            score = sklearn.metrics.precision_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'recall':
            score = sklearn.metrics.recall_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'auc':
            try:
                score = sklearn.metrics.roc_auc_score(results["true"], results["pred"], average = "weighted")
            except:
                score = 1.0
        elif metric ==  "csi":
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = TP / (TP + FN + FP)
            except:
                score = 1
        elif metric == 'far':
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = FP / (TP + FP)
            except:
                score = 1
        elif metric == 'pod':
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = TP / (TP + FN)
            except: 
                score = 1
        result[metric] = score
        #print(metric, round(score, 3))
    return result

In [4]:
def apply_transforms(transforms, image):
    im = {"image": image}
    for image_transform in transforms:
        im = image_transform(im)
    image = im["image"]
    return image

In [5]:
import sys
sys.path.append("ML-for-Derecho")

import torch_funcs
import torch_s2s_dataset

import numpy as np
import pandas as pd
import xarray as xr
from datetime import timedelta
import matplotlib.pyplot as plt
import xskillscore as xs

In [6]:
def reverse_negone(ds, minv, maxv):
    return (((ds + 1) / 2) * (maxv - minv)) + minv

In [7]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

available_ncpus = len(psutil.Process().cpu_affinity())

# Set up the GPU
is_cuda = torch.cuda.is_available()
device = torch.device("cpu") if not is_cuda else torch.device("cuda")

In [8]:
print(device, available_ncpus)

cuda 8


In [9]:
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

In [10]:
config = "gan.yml" #"../config/gan.yml"
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [11]:
# Set seeds for reproducibility
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

save_loc = conf["save_loc"]
os.makedirs(save_loc, exist_ok = True)
os.makedirs(os.path.join(save_loc, "images"), exist_ok = True)
if not os.path.isfile(os.path.join(save_loc, "model.yml")):
    shutil.copyfile(config, os.path.join(save_loc, "model.yml"))

# Trainer params
train_batch_size = conf["trainer"]["train_batch_size"]
valid_batch_size = conf["trainer"]["valid_batch_size"]

epochs = conf["trainer"]["epochs"]
batches_per_epoch = conf["trainer"]["batches_per_epoch"]
Tensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor
adv_loss = conf["trainer"]["adv_loss"]
lambda_gp = conf["trainer"]["lambda_gp"]
mask_penalty = conf["trainer"]["mask_penalty"]
regression_penalty = conf["trainer"]["regression_penalty"]
train_gen_every = conf["trainer"]["train_gen_every"]
train_disc_every = conf["trainer"]["train_disc_every"]
threshold = conf["trainer"]["threshold"]

In [12]:
# Load the preprocessing transforms
if "Normalize" in conf["transforms"]["training"]:
    conf["transforms"]["validation"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]
    conf["transforms"]["inference"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]

train_transforms = LoadTransformations(conf["transforms"]["training"])
valid_transforms = LoadTransformations(conf["transforms"]["validation"])

In [13]:
var = 'tas2m'
wks = 2

train = torch_s2s_dataset.S2SDataset(
    
    week=wks, variable=var, norm='minmax', region='fixed',
    
    minv=None, maxv=None, mnv=None, stdv=None,
    
    lon0=250., lat0=30., dxdy=32., feat_topo=True, feat_lats=True, feat_lons=True,
    
    startdt='1999-02-01', enddt='2015-12-31', homedir='/glade/scratch/molina/'
)

valid = torch_s2s_dataset.S2SDataset(
    
    week=wks, variable=var, norm='minmax', region='fixed',
    
    minv=train.min_val, maxv=train.max_val, mnv=None, stdv=None,
    
    lon0=250., lat0=30., dxdy=32., feat_topo=True, feat_lats=True, feat_lons=True,
    
    startdt='2016-01-01', enddt='2017-12-31', homedir='/glade/scratch/molina/'
)

tests = torch_s2s_dataset.S2SDataset(
    
    week=wks, variable=var, norm='minmax', region='fixed',
    
    minv=train.min_val, maxv=train.max_val, mnv=None, stdv=None,
    
    lon0=250., lat0=30., dxdy=32., feat_topo=True, feat_lats=True, feat_lons=True,
    
    startdt='2018-01-01', enddt='2020-12-31', homedir='/glade/scratch/molina/'
)

In [14]:
# train_loader = torch.utils.data.DataLoader(
#     train,
#     batch_size=train_batch_size,
#     num_workers=available_ncpus//2,
#     pin_memory=True,
#     shuffle=True)

# valid_loader = torch.utils.data.DataLoader(
#     valid,
#     batch_size=train_batch_size,
#     num_workers=available_ncpus//2,
#     pin_memory=True,
#     shuffle=False)


train_loader = DataLoader(train, batch_size=train_batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid, batch_size=train_batch_size, shuffle=False, drop_last=False)
tests_loader = DataLoader(tests, batch_size=train_batch_size, shuffle=False, drop_last=False)

### Load models

In [15]:
generator = load_model(conf["generator"]).to(device) 
discriminator = load_model(conf["discriminator"]).to(device)

In [16]:
# generator = smp.Unet(
#     encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=4,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,                      # model output channels (number of classes in your dataset)
# )

In [17]:
adv_loss = conf["trainer"]["adv_loss"]
if adv_loss == "bce":
    adversarial_loss = torch.nn.BCELoss().to(device)
    
perceptual_alex = lpips.LPIPS(net='alex').to(device)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /glade/work/schreck/py37/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth


In [18]:
optimizer_G = torch.optim.Adam(
    filter(lambda p: p.requires_grad, generator.parameters()),
    lr = conf["optimizer_G"]["learning_rate"],
    betas = (conf["optimizer_G"]["b0"], conf["optimizer_G"]["b1"]))

optimizer_D = torch.optim.Adam(
    filter(lambda p: p.requires_grad, discriminator.parameters()), 
    lr = conf["optimizer_D"]["learning_rate"], 
    betas = (conf["optimizer_D"]["b0"], conf["optimizer_D"]["b1"]))

In [19]:
def compute_gradient_penalty(discriminator, real_imgs, gen_imgs):
    """Calculates the gradient penalty loss for WGAN GP"""
    alpha = torch.rand(real_imgs.size(0), 1, 1, 1).cuda().expand_as(real_imgs)
    interpolated = Variable(alpha * real_imgs.data + (1 - alpha) * gen_imgs.data, requires_grad=True)
    out = discriminator(interpolated)[1]
    grad = torch.autograd.grad(outputs=out,
                               inputs=interpolated,
                               grad_outputs=torch.ones(out.size()).cuda(),
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]
    grad = grad.view(grad.size(0), -1)
    grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
    d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
    return d_loss_gp

In [20]:
lr_G_decay = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=30, gamma=0.2)
lr_D_decay = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=30, gamma=0.2)

In [None]:
results = defaultdict(list)
for epoch in range(epochs):
    
    ### Train
    dual_iter = tqdm.tqdm(
        enumerate(train_loader),
        total = batches_per_epoch, 
        leave = True)
    
    train_results = defaultdict(list)
    for i, x in dual_iter:
        
        img = x["input"].squeeze(2)[:, :, :32, :32]
        label = x["label"].squeeze(2)[:, :, :32, :32]
                            
        # Adversarial ground truths
        valid = Variable(Tensor(img.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(img.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(img.type(Tensor))
                
        # Sample noise as generator input
        #z = Variable(Tensor(np.random.normal(0, 1.0, (holo_img.shape[0], 512))))
        b, c, l, w = img.shape
        z = Tensor(np.random.normal(0, 1.0, (b, 1, l, w)))
        # C-GAN-like input using the synthethic image as conditional input
        #gen_input = torch.cat([synthethic_imgs, z], 1)
        # Generate a batch of images
        
        gen_imgs = generator(z)
        
        # Discriminate the fake images
        _, verdict = discriminator(gen_imgs)
        
        # -----------------
        #  Train Generator
        # -----------------
            
        if (i + 1) % train_gen_every == 0:
            
            optimizer_G.zero_grad()
            requires_grad(generator, True)
            requires_grad(discriminator, False)
            
            # measure the generator's ability to fool the discriminator
            if adv_loss == 'wgan-gp':
                g_loss = -verdict.mean()
            elif adv_loss == 'hinge':
                g_loss = -verdict.mean()
            elif adv_loss == 'bce':
                g_loss = adversarial_loss(verdict, valid)
            
            # compute mask loss reg term
            train_results["g_loss"].append(g_loss.item())
                
            g_loss.backward()
            optimizer_G.step()
            
            # compute perception scores
            p_score_real = perceptual_alex(gen_imgs[:, :3, :, :], real_imgs[:, :3, :, :]).mean()
            train_results["p_real"].append(p_score_real.item())
            
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        if (i + 1) % train_disc_every == 0:
        
            optimizer_D.zero_grad()
            requires_grad(generator, False)
            requires_grad(discriminator, True)

            # Measure discriminator's ability to classify real from generated samples
            _, disc_real = discriminator(real_imgs)
            _, disc_synth = discriminator(gen_imgs.detach())
            
            train_results["real_acc"].append(((disc_real > threshold) == valid).float().mean().item())
            train_results["fake_acc"].append(((disc_synth > threshold) == fake).float().mean().item())
            
            if adv_loss == 'wgan-gp':
                real_loss = -torch.mean(disc_real) 
                fake_loss = disc_synth.mean() 
            elif adv_loss == 'hinge':
                real_loss = torch.nn.ReLU()(1.0 - disc_real).mean() 
                fake_loss = torch.nn.ReLU()(1.0 + disc_synth).mean()             
            elif adv_loss == 'bce':
                real_loss = adversarial_loss(disc_real, valid) 
                fake_loss = adversarial_loss(disc_synth, fake) 
                
            d_loss = real_loss + fake_loss 
            train_results["d_loss"].append(d_loss.item())
            
            if adv_loss == 'wgan-gp':
                # Compute gradient penalty
                alpha = torch.rand(real_imgs.size(0), 1, 1, 1).cuda().expand_as(real_imgs)
                interpolated = Variable(alpha * real_imgs.data + (1 - alpha) * gen_imgs.data, requires_grad=True)
                out = discriminator(interpolated)[1]

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
                d_loss_reg = lambda_gp * d_loss_gp
                d_loss += d_loss_reg
                train_results["d_reg"].append(d_loss_reg.item())
            
            d_loss.backward()
            optimizer_D.step()

        print_str =  f'Epoch {epoch}'
        print_str += f' D_loss {np.mean(train_results["d_loss"]):.6f}'
        if adv_loss == 'wgan-gp':
            print_str += f' D_reg {np.mean(train_results["d_reg"]):.6f}'
        print_str += f' G_loss {np.mean(train_results["g_loss"]):6f}'
        print_str += f' G_reg {np.mean(train_results["g_reg"]):6f}'
        print_str += f' p_real {np.mean(train_results["p_real"]):.4f}'
        print_str += f' real_acc {np.mean(train_results["real_acc"]):.4f}'
        print_str += f' fake_acc {np.mean(train_results["fake_acc"]):.4f}'
        dual_iter.set_description(print_str)
        dual_iter.refresh()
        
        if i == batches_per_epoch and i > 0:
            break
        
    # Save the dataframe to disk
    results["epoch"].append(epoch)
    results["d_loss"].append(np.mean(train_results["d_loss"]))
    if adv_loss == 'wgan-gp':
        results["d_loss_reg"].append(np.mean(train_results["d_reg"]))
    results["g_loss"].append(np.mean(train_results["g_loss"]))
    results["g_reg"].append(np.mean(train_results["g_reg"]))
    results["perception"].append(np.mean(train_results["p_real"]))
    results["real_acc"].append(np.mean(train_results["real_acc"]))
    results["fake_acc"].append(np.mean(train_results["syn_acc"]))

    metric = "custom"
    metric_value = results["perception"][-1]
    results[metric].append(metric_value)
    
    df = pd.DataFrame.from_dict(results).reset_index()
    df.to_csv(f'{conf["save_loc"]}/training_log.csv', index=False)
    
    save_image(real_imgs.data[:9], f'{conf["save_loc"]}/images/real_{epoch}.png', nrow=3, normalize=True)
    save_image(gen_imgs.data[:9], f'{conf["save_loc"]}/images/pred_{epoch}.png', nrow=3, normalize=True)

    # Save the model
    state_dict = {
        'epoch': epoch,
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'generator_state_dict': generator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
    }
    torch.save(state_dict, f'{conf["save_loc"]}/best.pt')
    
    # Anneal learning rates 
    lr_G_decay.step(epoch)
    lr_D_decay.step(epoch)

Epoch 0 D_loss -0.044851 D_reg 241.917496 G_loss -0.279516 G_reg    nan p_real 0.5877 real_acc 0.2344 fake_acc 0.8750:  13%|█▎        | 13/100 [00:29<03:20,  2.30s/it]
Epoch 1 D_loss 0.010148 D_reg 534.407715 G_loss -0.275837 G_reg    nan p_real 0.5952 real_acc 0.2344 fake_acc 0.7344:  13%|█▎        | 13/100 [00:26<02:55,  2.02s/it]
Epoch 2 D_loss 0.038002 D_reg 277.065369 G_loss -0.284760 G_reg    nan p_real 0.5966 real_acc 0.1562 fake_acc 0.7500:  13%|█▎        | 13/100 [00:25<02:52,  1.98s/it]
Epoch 3 D_loss -0.029194 D_reg 110.951462 G_loss -0.293972 G_reg    nan p_real 0.5921 real_acc 0.3438 fake_acc 0.7656:  13%|█▎        | 13/100 [00:26<02:55,  2.02s/it]
Epoch 4 D_loss 0.030986 D_reg 295.695953 G_loss -0.293794 G_reg    nan p_real 0.5877 real_acc 0.2031 fake_acc 0.7031:  13%|█▎        | 13/100 [00:26<02:57,  2.04s/it]
Epoch 5 D_loss -0.031382 D_reg 146.958588 G_loss -0.265491 G_reg    nan p_real 0.5840 real_acc 0.2500 fake_acc 0.7656:  13%|█▎        | 13/100 [00:26<02:58,  2.05s