In [3]:
import sys
import torch
import ot
import argparse
import os

import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
import pandas as pd
import cartopy.crs as ccrs

from tqdm.auto import trange
from scipy.stats import gaussian_kde
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as mcolors
from datasets import *
# Make the repository's `src/` directory (two levels above this notebook) importable, and put it
# FIRST on sys.path so the local `utils` / `methods` packages win over any other copy on the path.
# NOTE(review): with `append`, the recorded ImportError resolved `utils.nf` from
# .../Projects/s3w/s3wd/src, a different checkout than this notebook's working directory
# (.../Projects/Official S3W/...). Path priority is the likely cause — confirm.
sys.path.insert(0, os.path.abspath("../.."))

from utils.nf import realnvp  # NOTE(review): not used in this notebook; drop if it still fails to import
from utils.misc import *
from utils.nf.exp_map import create_NF
from utils.vmf import rand_vmf
from methods.swd import swd

# Preferred GPU is cuda:2. Fall back to CPU when CUDA (or that device index) is unavailable,
# so the notebook still runs on other machines.
device = 'cuda:2' if torch.cuda.is_available() and torch.cuda.device_count() > 2 else 'cpu'

ImportError: cannot import name 'realnvp' from 'utils.nf' (/home/tranh4/Projects/s3w/s3wd/src/utils/nf/__init__.py)

In [3]:
# Show where the kernel is running: relative paths below ("../..", "./new_weights") resolve from here.
import os

cwd = os.getcwd()
print(f"Current Working Directory: {cwd}")


Current Working Directory: /home/tranh4/Projects/Official S3W/s3wd/src/experiments/Density Estimation


In [None]:
# Load the ensemble of trained normalizing flows (one file per training seed) for one dataset.
# dataset = "quakes_all"
dataset = "flood"
# dataset = "fire"

N_RUNS = 5                 # number of independently trained seeds to look for
NF_ARGS = (3, 48, 100)     # positional args to create_NF; must match the training configuration
WEIGHTS_DIR = "./new_weights"
# Alternative tags used previously: "ari50_lr0e-05_epochs20001";
# single-file run: "./new_weights/nf_density_quakes_all_ssw_lr0e-1_epochs20000_0.model"
METHOD_TAG = "ri_s3w_lr0e-05_epochs20001"

models = []
for seed in range(N_RUNS):
    model_file = os.path.join(WEIGHTS_DIR, f"nf_density_{dataset}_{METHOD_TAG}_{seed}.model")
    if not os.path.exists(model_file):
        print(f"File {model_file} not found, skipping.")
        continue
    # Only build the network when there are weights to load into it.
    model = create_NF(*NF_ARGS).to(device)
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.eval()
    models.append(model)

# Individual missing seeds are tolerated, but an empty ensemble would only surface later as
# NaN densities and an undefined scatter handle, so fail here with a clear message.
if not models:
    raise FileNotFoundError(
        f"No model weights found for dataset {dataset!r} (tag {METHOD_TAG!r}) in {WEIGHTS_DIR}"
    )

eps = 1e-5

# Split proportions forwarded to EarthDataHandler.
config = {
    "training_size": 0.7,
    "validation_size": 0,
    "test_size": 0.3,
    "name": dataset
}

In [None]:
# models = []
# # dataset = "quakes_all" 
# # dataset = "flood"
# dataset = "fire"

# model = create_RealNVP(25,10).to(device)
#     # model_file = f"./weights/nf_density_{dataset}_ri_s3w_lr0e-001_{i}.model"
# model_file = f"./weights/nf_density_stereo_{dataset}_stereo_0.model"
# if os.path.exists(model_file):
#     model.load_state_dict(torch.load(model_file, map_location=device))
#     model.eval()
#     models.append(model)
# else:
#     print(f"File {model_file} not found, skipping.")

# eps = 1e-5

# config = {
#     "training_size": 0.7,
#     "validation_size": 0,
#     "test_size": 0.3,
#     "name": dataset
# }

In [None]:
# Split the Earth dataset according to `config` and wrap each split in a loader.
# NOTE(review): the two 15000 arguments look like batch sizes — confirm against EarthDataHandler.
handler = EarthDataHandler(config, eps)
(train_loader,
 val_loader,
 test_loader) = handler.get_dataloaders(15000, 15000)

In [None]:
# Evaluate every flow and a KDE baseline at the SAME test points, so that both can be plotted
# over the test-set locations in the cells below.
#
#   test_densities : array with one row per model, holding each flow's log-density per test point
#   kdes           : planar Gaussian KDE fitted on training (lon, lat), evaluated at test (lon, lat)
#
# NOTE(review): the plotting cells iterate `test_loader` again and assume the same point order,
# i.e. that test_loader does not shuffle — confirm.


def _collect_lonlat(loader):
    """Return (lons, lats) as 1-D numpy arrays covering every point yielded by `loader`."""
    lons, lats = [], []
    for data, _ in loader:
        lat, lon = xyz_to_latlon(data)  # returns (lat, lon), as used elsewhere in this notebook
        lats.append(lat.detach().cpu().numpy())
        lons.append(lon.detach().cpu().numpy())
    return np.concatenate(lons), np.concatenate(lats)


# NOTE(review): not wrapped in torch.no_grad() in case the flow computes its log-det Jacobian
# via autograd; outputs are detached instead. Add no_grad only after confirming that is safe.
test_densities = []
for model in models:
    log_densities = []
    for data, _ in test_loader:
        _, log_density = model(data.to(device))
        log_densities.append(log_density.detach().cpu().numpy())
    test_densities.append(np.concatenate(log_densities))
test_densities = np.array(test_densities)

# Fit on the full training set and evaluate at the test points. The previous version zipped
# train_loader with itself, which evaluated the KDE at training points that do not correspond
# to the test points it is plotted against.
# NOTE(review): KDE treats (lon, lat) as planar coordinates, so it ignores the ±180° wrap.
train_lons, train_lats = _collect_lonlat(train_loader)
test_lons, test_lats = _collect_lonlat(test_loader)
kde = gaussian_kde(np.vstack((train_lons, train_lats)), bw_method=0.1)
kdes = kde(np.vstack((test_lons, test_lats)))

In [None]:
def _minmax(values):
    """Linearly rescale `values` to [0, 1]. A constant array maps to all zeros instead of NaN."""
    values = np.asarray(values, dtype=float)
    lo, hi = np.min(values), np.max(values)
    span = hi - lo
    if span == 0:
        return np.zeros_like(values)
    return (values - lo) / span


# Both estimates are min-max scaled so that they share the [0, 1] colour scale used by the plots.
kdes = _minmax(kdes)

# Ensemble density: average of the per-model densities (exp of the stored log-densities) across
# models, i.e. the equal-weight mixture of the flows at each test point.
densities = _minmax(np.mean(np.exp(np.asarray(test_densities)), axis=0))

In [None]:
# World map of the normalized ensemble flow density, drawn at each test point.
plt.close()

# Gather all test points once. Scattering inside a per-model loop re-drew identical points
# len(models) times (and left `sc` undefined when `models` was empty); per-batch scattering
# would also mismatch `c=densities` whenever there is more than one test batch.
xyz = np.concatenate([data.numpy() for data, _ in test_loader])
lat, lon = xyz_to_latlon(xyz)

fig, ax = plt.subplots(figsize=(40, 15), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_aspect('auto', adjustable='datalim')
ax.set_xlim(-180, 180)
ax.set_ylim(-90, 90)

norm = mcolors.Normalize(vmin=0, vmax=1)  # densities were min-max scaled to [0, 1]
sc = ax.scatter(lon, lat, c=densities, cmap='jet', norm=norm, s=100, transform=ccrs.PlateCarree())
ax.coastlines()

for spine in ax.spines.values():
    spine.set_linewidth(5)

cbar = fig.colorbar(sc, ax=ax, orientation='vertical', pad=0.03, aspect=10)
cbar.ax.tick_params(labelsize=30)

# File name follows the selected dataset (e.g. "RI1_Flood.pdf" for dataset = "flood").
plt.savefig(f'RI1_{dataset.capitalize()}.pdf', bbox_inches='tight', dpi=100)
plt.show()

In [None]:
# World map of the normalized KDE baseline, drawn at each test point.

# Gather all test points once (the old per-model loop re-drew the same scatter repeatedly).
xyz = np.concatenate([data.numpy() for data, _ in test_loader])
lat, lon = xyz_to_latlon(xyz)

fig, ax = plt.subplots(figsize=(40, 15), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_aspect('auto', adjustable='datalim')
ax.set_xlim(-180, 180)
ax.set_ylim(-90, 90)

norm = mcolors.Normalize(vmin=0, vmax=1)  # kdes were min-max scaled to [0, 1]
sc = ax.scatter(lon, lat, c=kdes, cmap='jet', norm=norm, s=100, transform=ccrs.PlateCarree())
ax.coastlines()

for spine in ax.spines.values():
    spine.set_linewidth(5)

cbar = fig.colorbar(sc, ax=ax, orientation='vertical', pad=0.03, aspect=10)
cbar.set_ticks(np.linspace(0, 1, num=6))
cbar.ax.set_yticklabels(['0', '0.2', '0.4', '0.6', '0.8', '1'], fontsize=50)

# Derive the file name from the selected dataset; the hard-coded "KDE_Fire.pdf" mislabelled
# the output whenever another dataset (currently "flood") was selected.
plt.savefig(f'KDE_{dataset.capitalize()}.pdf', bbox_inches='tight', dpi=100)
plt.show()