In [None]:
from Network import Generator
from tensorflow import keras
from keras.layers import Activation, BatchNormalization, UpSampling2D, Flatten
from keras.layers import Dense, Input, Conv2D, LeakyReLU, PReLU, add
from keras.models import Model
from keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Softmax
import numpy as np
from tqdm import tqdm
from numpy import load
import tensorflow.keras.backend as K
import tensorflow as tf
from keras.callbacks import Callback, ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import xarray as xr
from tensorflow.keras.utils import to_categorical

# Ask TensorFlow to enable memory growth on every visible GPU.
gpus = tf.config.experimental.list_physical_devices("GPU")
try:
    # Iterating an empty list is a no-op, so no separate "any GPUs?" check
    # is needed.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Report the problem and continue with TensorFlow's default behaviour.
    print(e)

# Number of target classes used by the classification loss and by the
# one-hot encoding of the class labels.
n_classes = 4

# Function to create a weighted categorical cross-entropy loss
def weighted_categorical_crossentropy(weights):
    """Build a class-weighted categorical cross-entropy loss function.

    Args:
        weights: Array-like with one weight per class. The number of classes
            is taken from its length, so it must match the size of the last
            axis of the targets.

    Returns:
        A function ``wcce(y_true, y_pred)`` that returns the element-wise
        categorical cross-entropy multiplied by the weight of the true class.
    """
    # Reshape to (1, 1, 1, n_classes) so the weights broadcast against the
    # targets along the last (class) axis.
    # assumes targets are 4-D, e.g. (batch, height, width, classes) -- TODO confirm
    weights = np.asarray(weights).reshape((1, 1, 1, -1))

    def wcce(y_true, y_pred):
        """Return the weighted cross-entropy for one batch.

        ``y_true`` must be one-hot encoded along the last axis (NOT integer
        category indices): the per-sample weight is obtained by summing
        ``y_true * weights`` over that axis, which only selects the weight of
        the true class when exactly one entry is 1.
        """
        y_true = K.cast(y_true, y_pred.dtype)
        # Cast the weights like the targets so the product below never mixes
        # dtypes (K.constant defaults to the backend float type).
        Kweights = K.cast(K.constant(weights), y_pred.dtype)
        return K.categorical_crossentropy(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)

    return wcce

# Per-class weights for the weighted cross-entropy term, one entry per class
# index (described by the author as the inverse percentage of each class).
class_weights = np.asarray((4, 19, 23, 56))
# Loss callable built once here and handed to generator.compile() in train().
class_loss = weighted_categorical_crossentropy(class_weights)

# Function to create the Fractions Skill Score (FSS) loss
def make_FSS_loss(mask_size, cutoff=0.5, steepness=10, hard_discretization=False):
    """Build a Fractions Skill Score (FSS) style loss function.

    Args:
        mask_size: Side length of the square averaging window used to turn
            the (soft-)binary fields into neighbourhood densities.
        cutoff: Threshold separating "event" from "no event" values.
        steepness: Slope of the sigmoid used for soft discretization; only
            used when ``hard_discretization`` is False.
        hard_discretization: If True, threshold the fields to exactly 0/1
            with ``tf.where``; if False (default), use a sigmoid centred on
            ``cutoff``.
            NOTE(review): hard thresholding is presumably not useful for
            gradient-based training -- confirm before enabling.

    Returns:
        A function ``my_FSS_loss(y_true, y_pred)`` returning the scalar
        ``MSE_n / MSE_n_ref`` computed on the pooled densities.
    """
    def my_FSS_loss(y_true, y_pred):
        # Discretize y_true and y_pred to binary values (0/1) or a smooth
        # approximation of them.
        if hard_discretization:
            y_true_binary = tf.where(y_true > cutoff, 1.0, 0.0)
            y_pred_binary = tf.where(y_pred > cutoff, 1.0, 0.0)
        else:
            y_true_binary = tf.math.sigmoid(steepness * (y_true - cutoff))
            y_pred_binary = tf.math.sigmoid(steepness * (y_pred - cutoff))

        # Neighbourhood densities via stride-1 average pooling; "same"
        # padding keeps the spatial size of the input. The layer has no
        # weights, so one instance serves both fields.
        # assumes inputs are 4-D (batch, height, width, channels) -- TODO confirm
        pool = tf.keras.layers.AveragePooling2D(
            pool_size=(mask_size, mask_size), strides=(1, 1), padding="same"
        )
        y_true_density = pool(y_true_binary)
        y_pred_density = pool(y_pred_binary)
        n_density_pixels = tf.cast(
            tf.shape(y_true_density)[1] * tf.shape(y_true_density)[2], tf.float32
        )

        # Mean squared error between the two density fields.
        MSE_n = tf.keras.losses.MeanSquaredError()(y_true_density, y_pred_density)

        # Reference MSE used for normalization: sum of squared densities of
        # both fields divided by the number of pixels per image.
        # NOTE(review): the sums run over every axis (including batch and
        # channels) while the divisor is only height * width, so the ratio
        # appears to scale with the batch size -- confirm this is intended.
        O_n_squared_sum = tf.reduce_sum(tf.square(y_true_density))
        M_n_squared_sum = tf.reduce_sum(tf.square(y_pred_density))
        MSE_n_ref = (O_n_squared_sum + M_n_squared_sum) / n_density_pixels

        if hard_discretization:
            # A zero reference means both density fields are all zero, hence
            # MSE_n is zero too; divide_no_nan returns 0 in that case, which
            # equals MSE_n, without a Python-level branch on a tensor value.
            return tf.math.divide_no_nan(MSE_n, MSE_n_ref)
        # Soft case: add a small epsilon to avoid division by zero.
        return MSE_n / (MSE_n_ref + tf.keras.backend.epsilon())

    return my_FSS_loss

# Side length of the averaging window handed to make_FSS_loss()
mask_size = 3

# Array shapes of a single sample
image_shape_hr = (96, 132, 1)  # High resolution image shape
image_shape_lr = (8, 11, 13)  # Low resolution image shape
# Ratio between the two grids: 96 / 8 == 132 / 11 == 12
downscale_factor = 12

# Directory holding the .npy files
PATH = "./Data/"  # Change to your own path

# Low-resolution inputs and high-resolution targets (training / validation)
reforecast_train = load(f"{PATH}X_train_ensemble.npy")
yhr_train = load(f"{PATH}y_hr_train.npy")
reforecast_val = load(f"{PATH}X_val_ensemble.npy")
yhr_val = load(f"{PATH}y_hr_val.npy")

# Class labels, loaded and one-hot encoded into n_classes channels
reanalysis_class_train = to_categorical(
    load(f"{PATH}y_class_train.npy"), num_classes=n_classes
)
reanalysis_class_val = to_categorical(
    load(f"{PATH}y_class_val.npy"), num_classes=n_classes
)

# Training function
def train(epochs, batch_size):
    """Train the generator and log per-epoch losses to ``losses.txt``.

    Each epoch runs ``n_train // batch_size`` steps; every step trains on a
    batch of indices drawn at random (with replacement) from the training
    set. After each epoch the model is evaluated on the validation set, one
    line is appended to ``losses.txt`` and, on selected epochs, the model is
    saved as ``gen_model<epoch>.h5``.

    Args:
        epochs: Number of epochs to run (epochs are numbered from 1).
        batch_size: Number of samples per training step; must be positive
            and not larger than the training set.

    Raises:
        ValueError: If ``batch_size`` is not positive or exceeds the number
            of training samples (no training step could run in an epoch).
    """
    x_train_lr = reforecast_train
    y_train_hr = yhr_train
    y_train_class = reanalysis_class_train

    x_val_lr = reforecast_val
    y_val_hr = yhr_val
    y_val_class = reanalysis_class_val

    n_train = x_train_lr.shape[0]
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))
    batch_count = n_train // batch_size
    if batch_count == 0:
        # Without at least one step per epoch there is no training loss to log.
        raise ValueError(
            "batch_size (%d) exceeds the number of training samples (%d)"
            % (batch_size, n_train)
        )

    # Initialize and compile the generator model.
    # assumes the generator has two outputs, class map first and
    # high-resolution field second, matching the loss/target order -- TODO confirm
    generator = Generator(image_shape_lr).generator()
    generator.compile(
        loss=[class_loss, make_FSS_loss(mask_size)],
        optimizer=Adam(learning_rate=0.0001, beta_1=0.9),
        loss_weights=[0.01, 1.0],
        metrics=["mae", "mse"],
    )

    # Create (or truncate) the loss log so each run starts with an empty file.
    with open("losses.txt", "w"):
        pass

    # Training loop
    for e in range(1, epochs + 1):
        print("-" * 15, "Epoch %d" % e, "-" * 15)

        for _ in tqdm(range(batch_count)):
            rand_nums = np.random.randint(0, n_train, size=batch_size)

            x_lr = x_train_lr[rand_nums]
            y_hr = y_train_hr[rand_nums]
            y_class = y_train_class[rand_nums]

            gen_loss = generator.train_on_batch(x_lr, [y_class, y_hr])

        # gen_loss holds the result of the LAST training step of the epoch,
        # not an average over the epoch.
        val_loss = generator.evaluate(x_val_lr, [y_val_class, y_val_hr], verbose=0)

        # Log losses; the context manager closes the file even on error.
        with open("losses.txt", "a") as loss_file:
            loss_file.write(
                "epoch%d : generator_loss = %s; validation_loss = %s\n"
                % (e, gen_loss, val_loss)
            )

        # Save model checkpoints at epoch 5 and at every multiple of 10.
        if e == 5 or e % 10 == 0:
            generator.save("gen_model%d.h5" % e)

# Run training: 5 epochs with 64 samples per step
train(epochs=5, batch_size=64)


I0000 00:00:1734185351.098052   59912 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4080 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1660 Ti with Max-Q Design, pci bus id: 0000:01:00.0, compute capability: 7.5
W0000 00:00:1734185351.326980   63972 gpu_backend_lib.cc:579] Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may result in compilation or runtime failures, if the program we try to run uses routines from libdevice.
Searched for CUDA in the following directories:
  ./cuda_sdk_lib
  ipykernel_launcher.runfiles/cuda_nvcc
  ipykern/cuda_nvcc
  
  /usr/local/cuda
  /home/gmankali/.local/lib/python3.12/site-packages/tensorflow/python/platform/../../../nvidia/cuda_nvcc
  /home/gmankali/.local/lib/python3.12/site-packages/tensorflow/python/platform/../../../../nvidia/cuda_nvcc
  /home/gmankali/.local/lib/python3.12/site-packages/tensorflow/python/platform/../../cuda
  .
You can choose the search directory by setting xla_gpu_c

--------------- Epoch 1 ---------------


  0%|                                                  | 0/319 [00:00<?, ?it/s]I0000 00:00:1734185357.185839   59912 service.cc:148] XLA service 0x557fa0f0e3c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1734185357.185891   59912 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce GTX 1660 Ti with Max-Q Design, Compute Capability 7.5
2024-12-14 17:09:17.288576: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
E0000 00:00:1734185357.844231   59912 cuda_dnn.cc:522] Loaded runtime CuDNN library: 9.1.0 but source was compiled with: 9.3.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E0000 0