In [None]:
from pathlib import Path
import pickle
from joblib import Parallel, delayed
from itertools import product
from smount_predictors.src.SeamountHelp import PipelinePredictor
from smount_predictors import SeamountHelp
import numpy as np
import pandas as pd
import simplekml
import xarray as xr

In [None]:
# Tile each axis into 20-degree (start, end) bands.
longitude_pairs = [(west, west + 20) for west in range(-180, 180, 20)]
# NOTE(review): the last latitude band is (80, 100), i.e. it extends past 90 —
# confirm the grid reader tolerates an upper bound beyond the pole.
latitude_pairs = [(south, south + 20) for south in range(-60, 90, 20)]

# One prediction zone per (longitude band, latitude band) combination,
# longitude-major: all latitude bands of the first longitude band come first.
latlons = list(product(longitude_pairs, latitude_pairs))

In [None]:
# Load the tuned model. The context manager closes the file handle as soon as
# unpickling finishes instead of leaving it open for the life of the kernel.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load model files that come from a trusted source.
with open('out/cluster_tuned_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

# Grid file that predict_zone reads each zone from.
data_p = Path('data/vgg_swot_masked.grd')

In [None]:
def recluster(model, predictions):
    """Re-cluster the high-``z`` part of every non-circular cluster.

    Steps:

    1. For each cluster label compute the ratio of its latitude extent to its
       longitude extent. Clusters whose ratio lies within one standard
       deviation (of all ratios) around 1 are treated as circular and are
       returned unchanged, as are rows labelled -1.
    2. For the remaining rows, min-max normalise ``z`` into ``norm_z``.
    3. Rows with ``norm_z > 0.5`` are fed to ``model.clusterer.fit_predict``
       (on ``lon``/``lat``); their new labels are shifted past the largest
       label in ``predictions`` so they cannot collide with existing ids.
       Rows the clusterer labels -1 stay -1 instead of being shifted.
    4. The other rows of those clusters are relabelled -1.

    Parameters
    ----------
    model
        Object exposing ``clusterer`` with ``fit_predict`` and ``labels_``.
    predictions : pd.DataFrame
        Must contain ``lon``, ``lat``, ``z`` and ``cluster`` columns. It is
        not modified. Row labels may repeat; rows are selected by position.

    Returns
    -------
    pd.DataFrame
        Untouched rows followed by the re-processed rows. The re-processed
        rows carry an extra ``norm_z`` column (NaN for untouched rows).
    """
    noise = -1

    def elongated_rows(data: pd.DataFrame) -> np.ndarray:
        """Positional mask of rows whose cluster is neither -1 nor roughly circular."""
        grouped = data.groupby('cluster')
        lat_extent = (grouped['lat'].max() - grouped['lat'].min()).abs()
        lon_extent = (grouped['lon'].max() - grouped['lon'].min()).abs()
        ratios = lat_extent / lon_extent
        # NOTE(review): the -1 group's ratio is part of `ratios` and therefore
        # influences `spread` — confirm that this is intended.
        spread = np.std(ratios)
        circular = ratios[(ratios > 1 - spread) & (ratios < 1 + spread)]
        labels = data['cluster']
        return ((~labels.isin(circular.index)) & (labels != noise)).to_numpy()

    # Split by position rather than by index label: with repeated row labels an
    # index-based lookup would also discard unrelated rows sharing a label.
    candidate_rows = elongated_rows(predictions)
    untouched = predictions.loc[~candidate_rows]
    # Explicit copy so the columns written below never touch `predictions`.
    clust_filt = predictions.loc[candidate_rows].copy()

    z = clust_filt['z']
    clust_filt['norm_z'] = (z - z.min()) / (z.max() - z.min())
    # Every candidate row already has cluster != -1, so the upper half of the
    # normalised z range is the only remaining criterion.
    peak_rows = (clust_filt['norm_z'] > 0.5).to_numpy()

    # Skip the clusterer when nothing qualifies; fitting on an empty frame fails.
    if peak_rows.any():
        model.clusterer.fit_predict(clust_filt.loc[peak_rows, ['lon', 'lat']])
        fresh = np.asarray(model.clusterer.labels_)
        offset = predictions['cluster'].max() + 1
        # Keep -1 as -1: shifting it would land exactly on the current maximum
        # label and merge those rows into an existing cluster.
        clust_filt.loc[peak_rows, 'cluster'] = np.where(fresh == noise, noise, fresh + offset)
    clust_filt.loc[~peak_rows, 'cluster'] = noise

    return pd.concat([untouched, clust_filt])

In [None]:
def predict_zone(zone):
    """Predict classes and clusters for a single zone of the grid.

    Parameters
    ----------
    zone : sequence
        ``zone[0]`` is the longitude entry and ``zone[1]`` the latitude entry;
        both are forwarded unchanged to ``SeamountHelp.readAndFilterGRD``.

    Returns
    -------
    The zone's grid as a DataFrame with ``class`` = 0 and ``cluster`` = -1
    when every ``z`` value is zero; otherwise whatever ``model.predict``
    returns for the ``lon``/``lat``/``z`` columns.
    """
    lon_band, lat_band = zone[0], zone[1]
    grid = SeamountHelp.readAndFilterGRD(data_p, lat_band, lon_band).to_dataframe().reset_index()

    has_signal = not (grid['z'] == 0).all()
    if has_signal:
        return model.predict(grid[['lon', 'lat', 'z']])

    # Only zeros in this zone: label every row as background (class 0,
    # cluster -1) without calling the model.
    grid['class'] = 0
    grid['cluster'] = -1
    return grid

# Run every zone in parallel; the result is one frame per zone, in `latlons` order.
predictions = Parallel(n_jobs=-1)(delayed(predict_zone)(zone) for zone in latlons)

# Ensure non-overlapping cluster numbers: shift each zone's real clusters
# (label != -1) past the highest id handed out so far. A running offset
# guarantees disjoint ranges, which a per-zone formula based on the zone
# position and its label count does not.
# Assumes real cluster labels inside a zone are >= 0 — TODO confirm.
next_cluster_id = 0
for df in predictions:
    is_cluster = df['cluster'] != -1
    if not is_cluster.any():
        continue
    # Single .loc assignment: chained indexing (df['cluster'][mask] = ...) may
    # write to a temporary and leave `df` unchanged.
    df.loc[is_cluster, 'cluster'] = df.loc[is_cluster, 'cluster'] + next_cluster_id
    next_cluster_id = df.loc[is_cluster, 'cluster'].max() + 1

# ignore_index gives every row a unique label; without it each zone restarts
# its labels and index-based lookups downstream would mix rows across zones.
predictions = pd.concat(predictions, ignore_index=True)
predictions = recluster(model, predictions)

# NOTE(review): this conversion is applied to every row, including rows from
# all-zero zones that predict_zone returns without calling the model — confirm
# those rows carry lon/lat in the same unit as the model's output.
predictions['lon'] = np.degrees(predictions['lon'])
predictions['lat'] = np.degrees(predictions['lat'])

global_predictions = xr.Dataset.from_dataframe(predictions.set_index(['lon', 'lat']))
global_predictions.to_netcdf('out/global_predictions.nc')

In [None]:
# One row per cluster, holding the mean of every column over the cluster's rows.
# Label -1 marks rows that belong to no cluster, so it is dropped first: its
# mean position is not a detection and must not show up as a placemark.
mounts = (
    predictions.loc[predictions['cluster'] != -1]
    .groupby('cluster')
    .mean()
    .reset_index()
)

# Write one KML point per cluster at its mean (lon, lat, z), named by cluster id.
kml = simplekml.Kml()
for _, row in mounts.iterrows():
    kml.newpoint(name=f'{row.cluster}', coords=[(row.lon, row.lat, row.z)])
kml.save('out/global_mounts.kml')

In [None]:
# Shell escape: hand the KML file written to out/global_mounts.kml to the
# system `open` command for viewing.
# NOTE(review): `open` is the macOS launcher; other platforms may need a
# different command — confirm before running elsewhere.
! open out/global_mounts.kml

In [None]:
# (rows, columns) of the per-cluster summary: one row per cluster label in `mounts`.
mounts.shape

In [None]:
# Read the space-separated, header-less table data/all.xyhrdnc and display it
# as the cell output (the result is not stored in a variable).
# NOTE(review): eight column names are supplied while the extension
# `.xyhrdnc` has seven letters — confirm the file really has a `z` column and
# that the names line up with the file's columns.
pd.read_csv('data/all.xyhrdnc', sep=' ', header=None, names=['lon', 'lat', 'z', 'h', 'r', 'd', 'n', 'c'])