# In [None]:
import xarray as xr
import dask.array as da
import numpy as np
import pandas as pd
from sklearn.ensemble import IsolationForest
import plotly.graph_objs as go
import plotly.express as px
import ipywidgets as widgets
from IPython.display import display

# Step 1: Open the NetCDF file lazily (dask-backed) through xarray.
# One chunk per (time, level) slice covering the full 277 x 349 grid;
# tune the chunk sizes to the memory available on your system.
chunk_sizes = dict(time=1, level=1, y=277, x=349)
ds = xr.open_mfdataset(
    "Z:/NARR/air/air.200112.nc",
    combine="by_coords",
    parallel=True,
    chunks=chunk_sizes,
)

# The 'air' variable is the field analysed in the rest of the script.
air = ds['air']

# Step 2: Preprocess the Data
#
# Draw a reproducible 10% random sample of grid cells, look up the
# coordinates of every sampled cell, and assemble a NaN-free feature matrix
# (`features`) alongside a matching DataFrame (`df`).

# Dimension order used for flattening. The same tuple drives both the stack
# and the shape handed to np.unravel_index, so a flat position always maps
# back to the right (time, level, y, x) indices even if `air` stores its
# dimensions in a different order.
stack_dims = ('time', 'level', 'y', 'x')

# Stack the data to create a single 'sample' dimension
air_reshaped = air.stack(sample=stack_dims)

# Grid shape in stacking order: (time, level, y, x)
original_shape = tuple(air.sizes[dim] for dim in stack_dims)

# Total number of samples
total_samples = air_reshaped.size

# Define sampling fraction
sampling_fraction = 0.1  # 10% sampling
sample_size = int(total_samples * sampling_fraction)

# Randomly sample flat indices without replacement
np.random.seed(42)  # For reproducibility
sample_indices = np.random.choice(
    total_samples, size=sample_size, replace=False)
# Ascending order keeps the sample identical as a set but lets a chunked
# (dask) array be read chunk by chunk instead of jumping between chunks for
# every element. Rows of `df` therefore follow ascending flat index.
sample_indices.sort()

# Extract sampled data.
# Inspect `.data`, not `.values`: `.values` always materialises the whole
# array as NumPy, so a dask check on it can never succeed and the full
# dataset would be loaded into memory.
flat_data = air_reshaped.data
if isinstance(flat_data, da.Array):
    # Only the chunks touched by the sample are computed.
    air_sampled = flat_data[sample_indices].compute()
else:
    air_sampled = air_reshaped.values[sample_indices]

# Convert flat indices back to multi-dimensional indices
time_idx, level_idx, y_idx, x_idx = np.unravel_index(
    sample_indices, original_shape)

# Extract corresponding coordinates.
# NOTE(review): lat/lon are indexed as 2-D (y, x) arrays - confirm against
# the dataset's coordinate layout.
time_sampled = ds['time'].values[time_idx]
level_sampled = ds['level'].values[level_idx]
lat_sampled = ds['lat'].values[y_idx, x_idx]
lon_sampled = ds['lon'].values[y_idx, x_idx]

# Create a DataFrame with the sampled data
df = pd.DataFrame({
    'air': air_sampled,
    'time': pd.to_datetime(time_sampled),
    'level': level_sampled,
    'lat': lat_sampled,
    'lon': lon_sampled
})

# Feature Engineering
# Convert time to numerical format (seconds since epoch)
df['time_num'] = df['time'].astype(np.int64) // 10**9

# Select features for Isolation Forest
features = df[['air', 'level', 'time_num', 'lat', 'lon']].values

# Handle missing values by removing any rows with NaNs; the same mask is
# applied to `df` so both stay row-aligned.
mask = ~np.isnan(features).any(axis=1)
features = features[mask]
df = df[mask].reset_index(drop=True)

# Step 3: Fit Isolation Forest
iso_forest = IsolationForest(
    n_estimators=100,
    contamination='auto',
    random_state=42,
    n_jobs=-1
)
iso_forest.fit(features)

# Step 4: Get Anomaly Scores
# score_samples returns the raw anomaly score (lower = more abnormal); this
# is the value stored on the DataFrame.
anomaly_scores = iso_forest.score_samples(features)

# decision_function is defined as score_samples - offset_, so derive it from
# the scores already computed instead of traversing the forest a second time.
scores = anomaly_scores - iso_forest.offset_

# Add anomaly scores back to the DataFrame
df['anomaly_score'] = anomaly_scores

# Step 5: Visualization with Plotly and Interactive Widgets

# Sorted distinct values present in the sampled data; these populate the
# two controls below.
unique_times = np.sort(df['time'].unique())
unique_levels = np.sort(df['level'].unique())

# Time control: a wide horizontal slider that only reports a new value once
# the handle is released (continuous_update=False).
time_slider_options = dict(
    options=unique_times,
    value=unique_times[0],
    description='Time',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    layout=dict(width='800px'),
)
time_widget = widgets.SelectionSlider(**time_slider_options)

# Level control: a plain dropdown starting at the first (smallest) level.
level_dropdown_options = dict(
    options=unique_levels,
    value=unique_levels[0],
    description='Level',
    disabled=False,
)
level_widget = widgets.Dropdown(**level_dropdown_options)

# Define the update function for the widgets


def update_plot(time, level):
    """Draw the anomaly-score map for one time step and one level.

    Filters the module-level ``df`` to the rows matching ``time`` and
    ``level`` and shows them as a geographic scatter plot coloured by
    ``anomaly_score``. If no rows match, a message is printed and nothing
    is drawn.

    Args:
        time: Timestamp to display. Anything ``pd.Timestamp`` accepts works
            (``np.datetime64`` as delivered by the slider, ``datetime``,
            ``pd.Timestamp``).
        level: Level value to display; compared for equality against
            ``df['level']``.

    Returns:
        None. The figure is rendered with ``fig.show()``.
    """
    # Filter the DataFrame based on the selected time and level
    df_subset = df[(df['time'] == time) & (df['level'] == level)]

    if df_subset.empty:
        print(f"No data available for Time: {time}, Level: {level}")
        return

    # The slider hands over np.datetime64 values, which have no strftime();
    # normalise through pd.Timestamp before formatting. Building the title
    # on its own line also avoids a replacement field that spans lines,
    # which is a syntax error before Python 3.12.
    time_label = pd.Timestamp(time).strftime('%Y-%m-%d')
    title = (
        "Isolation Forest Anomaly Scores<br>"
        f"Time: {time_label}, Level: {level}"
    )

    # Create the scatter geo plot
    fig = px.scatter_geo(
        df_subset,
        lon='lon',
        lat='lat',
        color='anomaly_score',
        color_continuous_scale='RdBu',
        scope='north america',
        title=title,
        labels={'anomaly_score': 'Anomaly Score'},
        hover_data=['air', 'anomaly_score']
    )

    # Update the layout for better visualization.
    # 'conic conformal' is plotly's name for the Lambert conformal conic
    # projection; 'lambert' is not an accepted projection type.
    fig.update_layout(
        geo=dict(
            projection=go.layout.geo.Projection(type='conic conformal'),
            showland=True,
            landcolor="lightgray",
            coastlinecolor="black"
        )
    )

    fig.show()


# Bind both controls to the plotting callback and put the resulting
# widget container on screen.
interactive_plot = widgets.interactive(
    update_plot,
    time=time_widget,
    level=level_widget,
)
display(interactive_plot)

# Draw once using the controls' starting values.
# NOTE(review): widgets.interactive may already render an initial plot on
# its own - confirm whether this explicit call produces a duplicate figure.
update_plot(time=time_widget.value, level=level_widget.value)