In [1]:
import os
import pickle
import subprocess
import sys
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
import simplekml

from smount_predictors import SeamountHelp

In [2]:
class PipelinePredictor:
    """Classify points with a fitted model, then cluster them on (lon, lat, class).

    Instances are presumably what gets unpickled from the model file below: pickle
    resolves the class by name and restores the ``model`` / ``clusterer`` attributes,
    so keep both names stable.
    """

    def __init__(self, model, clusterer):
        # model: anything exposing ``predict(frame) -> array-like`` (presumably a
        # scikit-learn estimator — confirm against the training script).
        self.model = model
        # clusterer: anything exposing ``fit_predict(X) -> labels`` (sklearn ClusterMixin
        # contract). Label -1 is presumably noise (DBSCAN-style) — verify.
        self.clusterer = clusterer

    def predict(self, data):
        """Return a copy of ``data`` with ``class`` and ``cluster`` columns added.

        Parameters
        ----------
        data : pandas.DataFrame
            Must contain ``lon`` and ``lat`` plus whatever feature columns
            ``self.model`` expects. It is not modified.

        Returns
        -------
        pandas.DataFrame
            A copy of ``data`` with ``class`` (model prediction) and ``cluster``
            (clusterer label) columns.
        """
        # Work on a copy: mutating the caller's frame meant a second call on the same
        # frame passed the previously added 'class' column back into the model.
        result = data.copy()
        result['class'] = self.model.predict(data)
        # fit_predict returns the same labels it stores in ``labels_``.
        result['cluster'] = self.clusterer.fit_predict(result[['lon', 'lat', 'class']])
        return result
    

# Load the pickled pipeline (presumably a PipelinePredictor instance).
# NOTE: pickle.load can execute arbitrary code — only load model files this project produced.
with open('out/script_accuracy_balenced_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

# Bounding box handed to readAndFilterGRD as points[:2] = (40, 55) and points[2:] = (-160, -130);
# presumably (lat_min, lat_max) then (lon_min, lon_max) — TODO confirm against SeamountHelp.
points = (40, 55, -160, -130)
data = SeamountHelp.readAndFilterGRD(Path('data/swot_masked.grd'), points[:2], points[2:]).to_dataframe().reset_index()
predictions = model.predict(data)

In [None]:
# Classifier output per grid point. Passing the frame plus column names (as the other
# plotting cells do) lets plotly label the axes and colour bar from the column names.
px.scatter(
    predictions,
    x='lon',
    y='lat',
    color='class',
    title='Predicted class per grid point',
).update_layout(
    width=800,
    height=800,
)

In [None]:
# Every point whose cluster label is not 0, coloured by label on a cyclic HSV scale.
# NOTE(review): label -1 is kept here, while the centers cell drops both 0 and -1 —
# confirm whether showing -1 points is intended.
clustered = predictions.loc[predictions['cluster'].ne(0)]
fig_clusters = px.scatter(clustered, x='lon', y='lat', color='cluster')
fig_clusters.update_layout(width=800, height=800)
fig_clusters.update_coloraxes(colorscale='HSV')

In [None]:
# Input field coloured by its 'z' column (presumably the gridded SWOT value), for visual
# comparison with the class/cluster plots.
raw = px.scatter(predictions, x='lon', y='lat', color='z')
raw.update_layout(width=800, height=800)
raw.show()

In [None]:
# Cluster centroids: mean lon/lat and point count per cluster label, one row per label.
# Labels 0 and -1 are excluded (-1 is presumably the clusterer's noise label; 0 is presumably
# the background cluster — TODO confirm). errors='ignore' because a run may produce no -1
# (or no 0) label, which previously made drop() raise KeyError.
# Aggregating only lon/lat also avoids averaging unused (possibly non-numeric) columns.
centers = (
    predictions.groupby('cluster')
    .agg(lon=('lon', 'mean'), lat=('lat', 'mean'), size=('lon', 'size'))
    .drop([0, -1], errors='ignore')
)

# Overlay the centroids (red) on the input field coloured by 'z'.
mounts = go.Figure()
mounts.add_trace(go.Scatter(
    x=predictions['lon'],
    y=predictions['lat'],
    mode='markers',
    marker=dict(color=predictions['z'], colorscale='Viridis'),
    name='grid (z)',
))
mounts.add_trace(go.Scatter(
    x=centers['lon'],
    y=centers['lat'],
    mode='markers',
    marker=dict(color='red'),
    name='cluster centers',
))
mounts.update_layout(
    width=800,
    height=800,
    title='Predicted seamount cluster centers',
    xaxis_title='lon',
    yaxis_title='lat',
)

In [None]:
# Export one KML placemark per cluster center, named after its cluster label.
kml = simplekml.Kml()
for label, center in centers.iterrows():
    point_coords = [(center['lon'], center['lat'])]  # KML coordinate order is (lon, lat)
    kml.newpoint(name=f'{label}', coords=point_coords)
kml.save('out/predicted_mounts.kml')

In [None]:
# Open the exported KML with the OS default application. The bare `open` shell command
# only exists on macOS, so dispatch on the platform instead.
kml_path = Path('out/predicted_mounts.kml')
if sys.platform == 'darwin':
    subprocess.run(['open', str(kml_path)], check=True)
elif sys.platform.startswith('win'):
    os.startfile(kml_path)
else:
    subprocess.run(['xdg-open', str(kml_path)], check=True)