In [None]:
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
import torch.nn.functional as F
import torchvision.models
import netCDF4 as nc4
from scipy.ndimage import zoom
from datetime import datetime
import utils as ut
import dependency as dep
import models64 as models
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
# torch.device type strings are lowercase; "CPU" is rejected with a RuntimeError.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def visual(Bar=False, save_path=None, **images):
    """Plot 2-D images side by side in one row on a shared color scale.

    Parameters
    ----------
    Bar : bool
        If True, draw one vertical color bar shared by every panel.
    save_path : str or None
        If given, the figure is also written to this path (dpi=1000, tight
        bounding box) before being shown.
    **images : array-like
        Panel title -> 2-D image. Panels appear in keyword order. A title of the
        form ``"base_suffix"`` is rendered with the part after the LAST
        underscore as a subscript, e.g. ``UNET_30`` -> UNET$_{30}$.

    Raises
    ------
    ValueError
        If no images are given.
    """
    if not images:
        raise ValueError("visual() needs at least one image to plot")

    # One common color range so the panels are directly comparable.
    vmin = min(image.min() for image in images.values())
    vmax = max(image.max() for image in images.values())

    # squeeze=False keeps `axes` 2-D even for a single panel; flatten to 1-D.
    fig, axes = plt.subplots(1, len(images), figsize=(16, 5), squeeze=False)
    axes = axes.ravel()

    for ax, (name, image) in zip(axes, images.items()):
        ax.set_xticks([])
        ax.set_yticks([])
        if '_' in name:
            base_name, suffix = name.rsplit('_', 1)  # split name and subscript suffix
            ax.set_title(f"{base_name}$_{{{suffix}}}$")
        else:
            ax.set_title(name)
        img = ax.imshow(image, cmap='viridis', vmin=vmin, vmax=vmax)

    if Bar:
        # Attach the bar to all panels so the space is taken from the row as a
        # whole, rather than shrinking only the last panel.
        cbar = fig.colorbar(img, ax=axes.tolist(), orientation='vertical', fraction=0.02, pad=0.04)
        cbar.ax.tick_params(labelsize=15)

    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=1000)

    plt.show()


In [None]:
# Dataset root directories (placeholders -- point them at local copies before running).
GPM, TIG = (
    '/GPM dataset path/',   # GPM files
    '/TIGG dataset path/',  # TIGG (NWP forecast) files
)

In [None]:
# Evaluate the TIGG (NWP) forecasts against GPM, asking dep.Evaluation to zoom
# both sources to a 64x64 grid. (`GroundTrutn` keeps its original spelling
# because later cells refer to it by that name.)
GroundTrutn, TIGG_Pred, Date_Position, TargetDate = dep.Evaluation(
    GPM,
    TIG,
    GPM_zoom=True,
    GPM_desired_shape=(64, 64),
    TIG_zoom=True,
    TIG_desired_shape=(64, 64),
)

# Climatology baseline for the same target dates, on the same 64x64 GPM grid.
Random_Prediction, Random_GT = dep.RandomEvaluation(
    GPM,
    TargetDate,
    GPM_zoom=True,
    GPM_desired_shape=(64, 64),
)

In [None]:
## Load the trained model, its test loader and the normalization settings
## (Gmean, Gstd, norm mode) that later cells pass to ut.iznorm / ut.inorm.
model, test_loader, Gmean, Gstd, norm = dep.load_model()
# Second trained model, used as an ensemble member. Its own normalization outputs
# are discarded; the prediction cell applies the first model's Gmean/Gstd/norm to both.
# NOTE(review): load_model() is called with no arguments both times -- confirm it
# really returns a different checkpoint here (e.g. via config or an interactive prompt).
model3, test_loader3, _, _, _ = dep.load_model()

In [None]:
# Restrict each model's test set to the dates that also exist in TIGG
# (Date_Position, presumably kept in that order), then wrap each subset in a
# sequential, batch-size-1 loader.

# -- first model --
subset_inputs, subset_targets = dep.extract_ordered_subset_from_loader(test_loader, Date_Position)
subset_dataset = TensorDataset(subset_inputs, subset_targets)
subset_loader = DataLoader(subset_dataset, batch_size=1, shuffle=False)

# -- second (ensemble) model --
subset_inputs3, subset_targets3 = dep.extract_ordered_subset_from_loader(test_loader3, Date_Position)
subset_dataset3 = TensorDataset(subset_inputs3, subset_targets3)
subset_loader3 = DataLoader(subset_dataset3, batch_size=1, shuffle=False)
def ModelPredict(subset_loader, model, norm_mode=None, mean=None, std=None):
    """Run `model` over `subset_loader` and collect channel-0 predictions and targets.

    Parameters
    ----------
    subset_loader : iterable of (input, target) tensor pairs
        Typically a DataLoader. Every sample of every batch is kept, in order.
    model : torch.nn.Module
        Network applied to each input batch. It is switched to eval mode
        (a side effect on the passed-in model).
    norm_mode : str, optional
        'znorm' -> undo with ut.iznorm, 'norm' -> undo with ut.inorm,
        anything else -> values are returned as produced.
        Defaults to the notebook-level `norm`.
    mean, std : optional
        Statistics passed to the inverse-normalization helper. Default to the
        notebook-level `Gmean` / `Gstd`.

    Returns
    -------
    (predictions, targets) : tuple of np.ndarray
        Per-sample channel-0 fields stacked along a new leading axis.
    """
    # Fall back to the notebook-level settings returned by dep.load_model().
    norm_mode = norm if norm_mode is None else norm_mode
    mean = Gmean if mean is None else mean
    std = Gstd if std is None else std

    model.eval()  # inference behaviour for dropout / batch-norm layers
    predictions, targets = [], []
    with torch.no_grad():  # no autograd graph needed for evaluation
        for inputs, target in subset_loader:
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            # Iterate over the batch so no sample is dropped when batch_size > 1.
            for t, p in zip(target[:, 0].cpu().numpy(), output[:, 0].cpu().numpy()):
                if norm_mode == 'znorm':
                    t, p = ut.iznorm(t, mean, std), ut.iznorm(p, mean, std)
                elif norm_mode == 'norm':
                    t, p = ut.inorm(t, mean, std), ut.inorm(p, mean, std)
                predictions.append(p)
                targets.append(t)
    return np.array(predictions), np.array(targets)
# Predict with both models. Ground truth is taken from the first model's subset;
# the second model's targets are discarded.
(ModelPrediction, Model_GT), (ModelPrediction3, _) = (
    ModelPredict(subset_loader, model),
    ModelPredict(subset_loader3, model3),
)

In [None]:
# Weighted ensemble of the TIGG (NWP) forecast and the Night model prediction.
# NWP_WEIGHT is the weight on TIGG_Pred; ModelPrediction3 gets 1 - NWP_WEIGHT.
# Kept separate from `th`, which the scoring cell uses as the rain threshold.
NWP_WEIGHT = 0.1
weights = np.array([NWP_WEIGHT, 1.0 - NWP_WEIGHT])
model_outputs = [TIGG_Pred, ModelPrediction3]
# Leading axis of length 2 = ensemble member; np.stack requires identical shapes.
stacked_outputs = np.stack(model_outputs, axis=0)
print(stacked_outputs.shape)
# Contract the member axis with the weights: sum_k weights[k] * model_outputs[k].
ensemble_avg = np.tensordot(weights, stacked_outputs, axes=(0, 0))
print(ensemble_avg.shape)

# Hand-picked test samples to visualize across all methods.
SAMPLE_INDICES = [49, 14, 34]
# Directory to save each figure in; None only displays the figures.
SAVE_DIR = None  # e.g. './Result'

for i in SAMPLE_INDICES:
    print(i)
    PATH = None if SAVE_DIR is None else os.path.join(SAVE_DIR, f'SamplePredImage{i}.png')
    # Keyword order sets the left-to-right panel order.
    visual(Bar=True, save_path=PATH,
           GT=Model_GT[i], NWP=TIGG_Pred[i], CLIM=Random_Prediction[i],
           UNET_30=ModelPrediction[i], UNET_12=ModelPrediction3[i],
           Ens=ensemble_avg[i])

In [None]:
# Mean absolute error, called as dep.mae(ground_truth, prediction) for each method.
# NOTE(review): the TIGG MAE pairs TIGG_Pred with Model_GT, while the histogram and
# CSI cells pair it with GroundTrutn -- confirm both ground-truth arrays are equivalent.
TIGG_MAE, random_MAE = dep.mae(Model_GT, TIGG_Pred), dep.mae(Random_GT, Random_Prediction)
model_MAE, model_MAE3 = dep.mae(Model_GT, ModelPrediction), dep.mae(Model_GT, ModelPrediction3)
ensemble_MAE = dep.mae(Model_GT, ensemble_avg)
print('MAE results of \n 1. TIGG, 2. Climatology, 3. MidNight Model, 4. Night Model, 5. Ensemble Model')
print(TIGG_MAE, '|', random_MAE, '|', model_MAE, '|', model_MAE3, '|', ensemble_MAE)

In [None]:
# Custom bin edges: the first two bins (-1 to 0 and 0 to 1) are the ones reported
# below, followed by a fine grid from 3 to 5.
bin_edges = [-1, 0, 1, *np.linspace(3, 5, num=48)]

# dep.compute_histogram expects (ground truth, prediction) as numpy arrays of
# shape (N, W, H). Calls run in the same order as the table below.
(hist1, bin_edges1), (hist2, bin_edges2), (hist3, bin_edges3), (hist6, bin_edges6), (hist7, bin_edges7) = [
    dep.compute_histogram(truth, forecast, bin_edges, Print=False, Hist=False)
    for truth, forecast in (
        (GroundTrutn, TIGG_Pred),           # TIGG
        (Random_GT, Random_Prediction),     # Climatology
        (Model_GT, ModelPrediction),        # MidNight model
        (Model_GT, ModelPrediction3),       # Night model
        (Model_GT, ensemble_avg),           # Ensemble
    )
]

print('Number of samples in the range: A: -1 to 0; B: 0 to 1')
print('TIGG:', hist1[0], hist1[1],
      '\n Climatology:', hist2[0], hist2[1],
      '\n MidNight:', hist3[0], hist3[1],
      '\n Night:', hist6[0], hist6[1],
      '\n Ensemble:', hist7[0], hist7[1])

In [None]:
# Categorical scores per method at a rain/no-rain threshold; each result is
# printed under a header listing them as CSI, Precision, Recall, F1.
# Reference point: 1/100 (0.01) of an inch is the first measurable rainfall
# amount reported by the National Weather Service. This cell uses th = 0.5.
# NOTE(review): the units of `th` are not visible here -- confirm.
th = 0.5  # alternative setting: 10
thresholds = [th]  # other values tried: 0.1, 0.2, 1
for header, truth, forecast in (
    ('CSI, Precision, Recall, F1 of TIGG model at th=', GroundTrutn, TIGG_Pred),
    ('CSI, Precision, Recall, F1 of Climatology model at th=', Random_GT, Random_Prediction),
    ('CSI, Precision, Recall, F1 of MidNight model at th=', Model_GT, ModelPrediction),
    ('CSI, Precision, Recall, F1 of Night model at th=', Model_GT, ModelPrediction3),
    ('CSI, Precision, Recall, F1 of Ensemble model at th=', Model_GT, ensemble_avg),
):
    print(header, th)
    print(dep.score_calculate(truth, forecast, thresholds))