In [1]:
import sys
import numpy as np
import xarray as xr
import itertools as it
import os.path
import multiprocessing as mp
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.rcsetup as rcsetup
import tensorflow as tf
import sklearn
from keras.layers import *
from keras.regularizers import l1,l2
from keras.models import Model
from keras.callbacks import EarlyStopping
from keras.losses import mean_squared_error
from keras import backend as K
from livelossplot.keras import PlotLossesCallback
from dcgan import DCGAN

# Environment report: host core count followed by the versions of the core libraries.
CPU_COUNT = mp.cpu_count()
print('\n'.join([
    '{} CPUs'.format(CPU_COUNT),
    'numpy version: {}'.format(np.__version__),
    'xarray version: {}'.format(xr.__version__),
    'matplotlib version: {}'.format(matplotlib.__version__),
    'sklearn version: {}'.format(sklearn.__version__),
]))
# TensorFlow session configured with log_device_placement=True.
sess = tf.Session(
    config=tf.ConfigProto(log_device_placement=True),
)

8 CPUs
numpy version: 1.15.4
xarray version: 0.11.3
matplotlib version: 3.0.2
sklearn version: 0.20.1


Using TensorFlow backend.


In [2]:
# --- Experiment configuration ---
use_data_caching = True            # load/save the .npy caches under ./data
keep_original_data_loaded = False  # when False, X_all is deleted once the splits exist
experiment = 'rcp26'

dataset_filename = './data/{0}/{0}_m6_tas_pr.nc'.format(experiment)
X_all_filename = './data/X_{}_all.npy'.format(experiment)
X_train_filename = './data/X_{}_train.npy'.format(experiment)
X_valid_filename = './data/X_{}_valid.npy'.format(experiment)
ls_mask_filename = './data/lsmask_cmip3_144.nc'

# Load the model coordinate into memory so the NetCDF handle can be closed right away.
with xr.open_dataset(dataset_filename) as _ds:
    model_names = _ds.coords['model'].load()

if use_data_caching and os.path.isfile(X_all_filename):
    print('loading cached data file {}'.format(X_all_filename))
    X_all = np.load(X_all_filename)
else:
    print('loading and processing dataset for experiment {}'.format(experiment))
    with xr.open_dataset(dataset_filename) as X_ds:
        # Stack the data variables into a trailing 'variable' axis:
        # (model, time, lat, lon, variable).
        X_arr = X_ds.to_array().transpose('model', 'time', 'lat', 'lon', 'variable')
        X_all = X_arr.values
    # np.product is deprecated (removed in NumPy 2.0); count the NaNs directly.
    nan_count = int(np.count_nonzero(np.isnan(X_all)))
    print('found {} NaN values in data'.format(nan_count))
    assert nan_count == 0, 'dataset contains NaN values'
    # Scale every variable by its global standard deviation (the mean is not removed).
    for i in range(X_all.shape[-1]):
        stddev = np.std(X_all[..., i])
        X_all[..., i] = X_all[..., i] / stddev
    if use_data_caching:
        np.save(X_all_filename, X_all)

N_models, N_time, N_lat, N_lon, N_vars = X_all.shape
BATCH_SIZE = N_models*N_time // 20

# The last two models along the model axis are held out for validation.
N_models_train = N_models - 2
N_models_valid = N_models - N_models_train
N_train = N_models_train*N_time
N_valid = N_models_valid*N_time

# The mask is applied unconditionally to the train/validation arrays further down,
# so a missing file must fail here with a clear message instead of surfacing later
# as a NameError (or silently reusing a stale X_sea_mask left in the kernel).
if not os.path.isfile(ls_mask_filename):
    raise FileNotFoundError('land/sea mask file not found: {}'.format(ls_mask_filename))
print('loading land/sea mask data file')
with xr.open_dataset(ls_mask_filename) as X_ls_mask:
    # Pull the data into memory so X_ls_mask stays usable after the file is closed.
    X_ls_mask.load()
    # Complement of the stored mask, shaped to broadcast over (sample, lat, lon, variable).
    # NOTE(review): presumably the file holds 1 over land and 0 over sea -- confirm.
    X_sea_mask = 1 - X_ls_mask.to_array().values.reshape((1, N_lat, N_lon, 1))

def _load_or_build_split(cache_file, build_message, build):
    """Return the array cached in `cache_file`, or build it (caching it when enabled).

    cache_file    -- path of the .npy cache for this split.
    build_message -- text printed when the split has to be generated.
    build         -- zero-argument callable producing the sample array.
    """
    if use_data_caching and os.path.isfile(cache_file):
        print('loading cached data file {}'.format(cache_file))
        return np.load(cache_file)
    print(build_message)
    samples = build()
    if use_data_caching:
        np.save(cache_file, samples)
    return samples


# Collapse (model, time) into one sample axis. np.hstack joins the per-model arrays
# along axis 1, so after the reshape sample k holds time step k // n_models of model
# k % n_models within the split (time-major, models interleaved).
X_train = _load_or_build_split(
    X_train_filename,
    'generating training data',
    lambda: np.hstack(X_all[:N_models_train]).reshape((N_train, N_lat, N_lon, N_vars)),
)
X_valid = _load_or_build_split(
    X_valid_filename,
    'generating validation data',
    lambda: np.hstack(X_all[N_models_train:]).reshape((N_valid, N_lat, N_lon, N_vars)),
)

# Release the full array unless the user asked to keep it around.
if not keep_original_data_loaded:
    del X_all

# Masked copies of both splits (X_sea_mask broadcasts over samples and variables).
X_train_sea = np.multiply(X_train, X_sea_mask)
X_valid_sea = np.multiply(X_valid, X_sea_mask)

print('Loaded training data with shape: {}'.format(X_train.shape))
print('Loaded validation data with shape: {}'.format(X_valid.shape))

loading cached data file ./data/X_rcp26_all.npy
loading land/sea mask data file
loading cached data file ./data/X_rcp26_train.npy
loading cached data file ./data/X_rcp26_valid.npy
Loaded training data with shape: (11520, 72, 144, 2)
Loaded validation data with shape: (5760, 72, 144, 2)


In [3]:
from mpl_toolkits.basemap import Basemap

def plot_image_map(X, cmap="seismic", title="", min_max=None):
    """Draw `X` on a Basemap(lat_0=0, lon_0=180) map with coastlines, colorbar and title.

    X       -- array handed to Basemap.imshow (drawn with origin='lower').
    cmap    -- colormap passed to imshow.
    title   -- title set on the current axes.
    min_max -- optional pair; min_max[0] / min_max[1] become the colour limits.
    """
    world = Basemap(lat_0=0, lon_0=180)
    image = world.imshow(X, origin='lower', cmap=cmap)
    world.drawcoastlines()
    if min_max is not None:
        lower, upper = min_max[0], min_max[1]
        image.set_clim(vmin=lower, vmax=upper)
    plt.colorbar(fraction=0.035, pad=0.04)
    plt.title(title)

def plot_var_spatial(X, model, name="model", cmap='brg', t=0, c=0):
    """Plot channel `c` of sample `t` side by side with the model's reconstruction.

    X     -- array indexed as X[t] -> (lat, lon, channel).
    model -- object whose predict() maps a batch of samples to reconstructions.
    name  -- label used in both subplot titles.
    cmap  -- colormap for both panels.
    t     -- index of the sample to show.
    c     -- index of the channel to show.

    Both panels share colour limits of mean +/- 2 standard deviations of the
    original field for the plotted channel.
    """
    X_t = np.expand_dims(X[t], axis=0)
    X_pred = model.predict(X_t)
    field = X_t[0,:,:,c]
    # Derive the limits from the plotted channel only; mixing in the other
    # channels would stretch the scale by their (unrelated) value range.
    avg_x = np.mean(field)
    std_x = np.std(field)
    limits = (avg_x - 2*std_x, avg_x + 2*std_x)
    plt.figure(figsize=(16,14))
    plt.subplot(1,2,1)
    plot_image_map(field, cmap=cmap, title='{}, t={}'.format(name, t), min_max=limits)
    plt.subplot(1,2,2)
    plot_image_map(X_pred[0,:,:,c], cmap=cmap, title='{}, t={}, reconstructed'.format(name, t), min_max=limits)
    plt.show()

    
def plot_err_spatial(X, model, target_shape=(1, N_lat, N_lon, N_vars), cmap="Reds", name="", t=0,c=0):
    """Map the absolute reconstruction error of channel `c` for sample `t`.

    X            -- array indexed as X[t] -> (lat, lon, channel).
    model        -- object whose predict() maps a batch of samples to reconstructions.
    target_shape -- unused; kept so existing calls keep working.
    cmap         -- colormap for the error map.
    name         -- label used in the title.
    t            -- index of the sample to evaluate.
    c            -- index of the channel to show.
    """
    plt.figure(figsize=(7,6))
    X_t = np.expand_dims(X[t], axis=0)
    X_pred = model.predict(X_t)
    # Selecting batch entry 0 and channel c already leaves the 2-D (lat, lon)
    # field, so no reshape through the notebook-level grid size is needed and
    # inputs on any grid are accepted.
    X_err = np.abs(X_t - X_pred)[0,:,:,c]
    plot_image_map(X_err, cmap=cmap, title='{}, total absolute error, t={}'.format(name, t))
    
def plot_var_time(X, model, model_name="model", name="", c=0):
    """Plot the spatial mean of channel `c` along the sample axis, original vs. model output.

    X          -- array of shape (sample, lat, lon, channel).
    model      -- object whose predict_on_batch() maps a batch to reconstructions.
    model_name -- legend label for the reconstructed curve.
    name       -- label used in the title.
    c          -- index of the channel to plot.

    Predictions are made in about ten batches that together cover every sample,
    including a shorter final batch when the sample count is not a multiple of ten.
    """
    X_orig = X
    n_samples = X_orig.shape[0]
    # Ceil-divide so the remainder is not dropped, and never use empty batches.
    batch_size = max(1, -(-n_samples // 10))
    batches = [
        model.predict_on_batch(X_orig[start:start + batch_size])
        for start in range(0, n_samples, batch_size)
    ]
    # Join once at the end instead of re-copying the accumulated array per batch.
    if batches:
        X_pred = np.concatenate(batches, axis=0)
    else:
        X_pred = np.zeros((0, *X_orig.shape[1:]))
    steps = range(n_samples)
    plt.plot(steps, np.mean(X_orig, axis=(1,2))[:,c], c='blue')
    plt.plot(steps, np.mean(X_pred, axis=(1,2))[:,c], ':', c='red')
    # A single title: the earlier 'original' title was always overwritten by this one.
    plt.title('{}, global average, reconstructed'.format(name))
    plt.legend(['original', model_name])
    
def show_activations(X, model, layer, output_shape, t=0, name=""):
    """Display the output of `layer` for sample `t` of `X` as an image.

    X            -- input batch fed to the model's inputs.
    model        -- Keras model providing `inputs`.
    layer        -- layer whose `output` tensor is evaluated.
    output_shape -- shape the sample's activations are reshaped to before plotting.
    t            -- index of the sample to show.
    name         -- plot title.
    """
    fn_inputs = [K.learning_phase()] + model.inputs
    evaluate_layer = K.function(fn_inputs, [layer.output])
    # The leading 0 is the value fed for the learning-phase input.
    activations = evaluate_layer([0, X])[0]
    plt.imshow(activations[t].reshape(output_shape), origin='lower')
    plt.title(name)
    plt.colorbar()
    
def show_conv_activations(X, model, layer, output_shape, t=0, c=0, name=""):
    """Show every channel of `layer`'s output for sample `t` in a grid of subplots.

    X            -- input batch fed to the model's inputs.
    model        -- Keras model providing `inputs`.
    layer        -- layer whose `output` tensor is evaluated.
    output_shape -- (rows, cols, channels) shape the sample's activations are reshaped to.
    t            -- index of the sample to show.
    c            -- unused; kept so existing calls keep working.
    name         -- label used in each subplot title.

    8, 16 and 32 channels use fixed 2x4, 4x4 and 4x8 grids; any other channel
    count gets a near-square grid.
    """
    inputs = [K.learning_phase()] + model.inputs
    layer_fn = K.function(inputs, [layer.output])
    # The leading 0 is the value fed for the learning-phase input.
    layer_out = layer_fn([0, X])[0]
    print(layer_out.shape)
    n_channels = layer_out.shape[-1]
    fixed_grids = {32: (4, 8), 16: (4, 4), 8: (2, 4)}
    if n_channels in fixed_grids:
        n_rows, n_cols = fixed_grids[n_channels]
    else:
        # Smallest near-square grid with room for every channel.
        n_cols = max(1, int(np.ceil(np.sqrt(n_channels))))
        n_rows = max(1, int(np.ceil(n_channels / n_cols)))
    z_0 = layer_out[t].reshape(output_shape)
    plt.figure(figsize=(8*n_cols,6*n_rows))
    for i in range(n_channels):
        plt.subplot(n_rows, n_cols, i+1)
        plt.imshow(z_0[:,:,i], origin='lower')
        # Label each panel with its own channel index.
        plt.title('{}, c={}'.format(name, i))
        plt.colorbar(fraction=0.030, pad=0.04)

In [4]:
# Build the DCGAN (imported from `dcgan`) for inputs shaped like a single training
# sample, X_train.shape[1:], with latent_dims=128.
gan = DCGAN(img_shape=X_train.shape[1:], latent_dims=128)

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 36, 72, 32)        608       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 36, 72, 32)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 36, 72, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 18, 36, 32)        9248      
_________________________________________________________________
zero_padding2d_1 (ZeroPaddin (None, 19, 37, 32)        0         
_________________________________________________________________
batch_normalization_1

In [5]:
# Train the GAN on the training split. The run recorded below ended with a
# KeyboardInterrupt.
# NOTE(review): batch_size is hard-coded to 144 here; the BATCH_SIZE constant
# computed in the data-loading cell is not used -- confirm which one is intended.
gan.train(X_train, epochs=100, batch_size=144)

Instructions for updating:
Use tf.cast instead.


  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 1.104907, acc.: 38.19%] [G loss: 0.513509]
1 [D loss: 0.976704, acc.: 46.88%] [G loss: 0.678277]
2 [D loss: 0.836587, acc.: 52.78%] [G loss: 0.779451]
3 [D loss: 0.883030, acc.: 48.96%] [G loss: 0.903299]
4 [D loss: 0.946261, acc.: 42.36%] [G loss: 0.919827]
5 [D loss: 0.969149, acc.: 43.06%] [G loss: 0.849023]
6 [D loss: 1.004154, acc.: 42.01%] [G loss: 0.771782]
7 [D loss: 0.991606, acc.: 37.50%] [G loss: 0.906908]
8 [D loss: 1.066399, acc.: 36.46%] [G loss: 0.976532]
9 [D loss: 1.023476, acc.: 34.03%] [G loss: 0.817208]
10 [D loss: 0.971809, acc.: 42.36%] [G loss: 0.846376]
11 [D loss: 0.950653, acc.: 42.01%] [G loss: 0.798149]
12 [D loss: 1.040205, acc.: 34.72%] [G loss: 0.869416]
13 [D loss: 0.994012, acc.: 40.97%] [G loss: 0.876984]
14 [D loss: 0.901178, acc.: 47.92%] [G loss: 1.008154]
15 [D loss: 0.932318, acc.: 45.83%] [G loss: 0.967163]
16 [D loss: 0.864751, acc.: 51.04%] [G loss: 1.146248]
17 [D loss: 0.832770, acc.: 50.35%] [G loss: 1.006929]
18 [D loss: 0.863254

KeyboardInterrupt: 