In [1]:
from pathlib import Path
import pickle
from joblib import Parallel, delayed
from itertools import product
from smount_predictors.src.SeamountHelp import PipelinePredictor
from smount_predictors import SeamountHelp
import numpy as np
import pandas as pd
import simplekml
import xarray as xr

In [2]:
longitude_pairs = []
for lon in range(-180, 180, 20):
    longitude_pairs.append((lon, lon + 20))

latitude_pairs = []
for lat in range(-60, 90, 20):
    latitude_pairs.append((lat, lat + 20))


latlons = list(product(longitude_pairs, latitude_pairs))
latlons = latlons[:4]

In [3]:
model = pickle.load(open('out/cluster_tuned_model.pkl', 'rb'))
data_p = Path('data/vgg_swot_masked.grd')

In [47]:
def recluster(model: PipelinePredictor, predictions: pd.DataFrame):
    def filter_clust_size(data: pd.DataFrame):
        def circle_ratio(data: pd.DataFrame):
            if abs(data['lon'].max() - data['lon'].min()) == 0:
                return 0
            if data.loc[0, 'cluster'] == -1:
                return 1
            circle = abs(data['lat'].max() - data['lat'].min()) / abs(data['lon'].max() - data['lon'].min())
            mass = (abs(data['lat'].max() - data['lat'].min()) * abs(data['lon'].max() - data['lon'].min())) / data.shape[0]
            mass = mass / (np.sqrt(np.pi) / 2)
            return circle * mass
        circle_range = data
        divs = circle_range.groupby('cluster').apply(circle_ratio)
        divs = divs[(divs > np.mean(divs) - np.std(divs)) & (divs < np.mean(divs) + np.std(divs))]
        circle_range = circle_range.loc[(~circle_range['cluster'].isin(divs.index)) & (circle_range['cluster'] != -1)]
        return circle_range

    def norm_z(data: pd.DataFrame):
        data['norm_z'] = (data['z'] - data['z'].min()) / (data['z'].max() - data['z'].min())
        return data
    predictions = predictions.reset_index()
    clust_filt = predictions.copy()

    clust_filt = filter_clust_size(clust_filt)
    clust_filt = norm_z(clust_filt)
    recluster_index = (clust_filt['cluster'] != -1) & (clust_filt['norm_z'] > 0.5)
    clust_pred = clust_filt.loc[recluster_index]


    model.clusterer.fit_predict(clust_pred[['lon', 'lat']])
    clust_filt.loc[recluster_index, 'cluster'] = model.clusterer.labels_ + predictions['cluster'].max() + 1
    clust_filt.loc[~recluster_index, 'cluster'] = -1
    clust_filt.set_index(['lon', 'lat'], inplace=True)
    predictions.set_index(['lon', 'lat'], inplace=True)
    predictions.loc[clust_filt.index, 'cluster'] = clust_filt.loc[:, 'cluster']
    return predictions.reset_index()

In [None]:
def predict_zone(zone):
    lon = zone[0]
    lat = zone[1]
    data = SeamountHelp.readAndFilterGRD(data_p, lat, lon).to_dataframe().reset_index()
    if np.all(data['z'] == 0):
        data['class'] = 0
        data['cluster'] = -1
        return data
    zone_pred = model.predict(data[['lon', 'lat', 'z']])
    return zone_pred

predictions = Parallel(n_jobs=-1)(delayed(predict_zone)(zone) for zone in latlons)
for idx, df in enumerate(predictions):  # ensure non-overlapping cluster numbers
    assert isinstance(df, pd.DataFrame)  # assertion for code linter typing features
    df.loc[df['cluster'] != -1, 'cluster'] = df.loc[df['cluster'] != -1, 'cluster'] + ((idx ** 2) * len(np.unique(df['cluster'])))
predictions = pd.concat(predictions)
predictions = recluster(model, predictions)
predictions['lon'] = np.degrees(predictions['lon'])
predictions['lat'] = np.degrees(predictions['lat'])
global_predictions = xr.Dataset.from_dataframe(predictions.set_index(['lon', 'lat']))
# global_predictions.to_netcdf('out/global_predictions.nc')

In [44]:
predictions = recluster(model, predictions)
predictions['lon'] = np.degrees(predictions['lon'])
predictions['lat'] = np.degrees(predictions['lat'])
global_predictions = xr.Dataset.from_dataframe(predictions.set_index(['lon', 'lat']))
# global_predictions.to_netcdf('out/global_predictions.nc')

              lon       lat          z  class  cluster
0       -3.141447 -1.047052  29.121763      0       -1
1       -3.141447 -1.046761  22.904421      1       -1
2       -3.141447 -1.046470  13.345173      1       -1
3       -3.141447 -1.046179   6.889595      0       -1
4       -3.141447 -1.045889   4.297480      0       -1
...           ...       ...        ...    ...      ...
5759995 -2.792672  0.347757  11.134361      0       -1
5759996 -2.792672  0.348048   9.470730      0       -1
5759997 -2.792672  0.348339   9.232473      0       -1
5759998 -2.792672  0.348630   6.752796      0       -1
5759999 -2.792672  0.348920   3.071087      0       -1

[5358224 rows x 5 columns]
            lon       lat          z  class  cluster
48    -3.141447 -1.033089  18.658752      1        0
49    -3.141447 -1.032799  28.955509      1        0
50    -3.141447 -1.032508  38.195141      1        0
51    -3.141447 -1.032217  39.845409      1        0
53    -3.141447 -1.031635  31.929838      1    

  divs = circle_range.groupby('cluster').apply(circle_ratio)


In [45]:
mounts = predictions.groupby('cluster').mean().reset_index()
kml = simplekml.Kml()
for i, row in mounts.iterrows():
    kml.newpoint(name=f'{row.cluster}', coords=[(row.lon, row.lat, row.z)])
kml.save('out/global_mounts.kml')

In [46]:
! open out/global_mounts.kml