In [None]:
from datetime import datetime, timedelta
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib

# Load the ensemble forecast and normalise the columns the plots rely on.
df = pd.read_csv('ensemble_predictions_2050.csv')
df['time'] = pd.to_datetime(df['time'])

# Depth is plotted as a non-positive value (below the surface). Forcing every
# value negative, rather than flipping the whole column when *any* value is
# positive, keeps mixed-sign input consistent instead of turning
# already-negative depths positive. All-positive and all-non-positive columns
# end up exactly as before.
df['depth'] = -df['depth'].abs()


def _column_range(column):
    """Return ``(min, max)`` of ``df[column]``, skipping missing values."""
    values = df[column]
    return values.min(), values.max()


# Dataset-wide extents. Every monthly frame reuses these, so the axes and the
# colour scale stay fixed across the whole image sequence.
global_min_lat, global_max_lat = _column_range('latitude')
global_min_lon, global_max_lon = _column_range('longitude')
global_min_depth, global_max_depth = _column_range('depth')
global_min_temp, global_max_temp = _column_range('sea_water_temperature')



# Render one 3-D scatter per calendar month from June 2023 through December 2050.
# Every frame reuses the dataset-wide limits computed above, so the images share
# identical axes and colour scale and can be compared side by side.
start_date = datetime(2023, 6, 1)
end_date = datetime(2050, 12, 31)

# Build the output path portably and make sure the folder exists. A literal
# 'dir\file' string only works as a path on Windows, and "\s" is an invalid
# escape sequence in a normal Python string.
output_dir = Path('sea_temp_img_2050')
output_dir.mkdir(parents=True, exist_ok=True)

# Bucket rows by calendar month in a single pass instead of re-scanning the
# whole frame with .dt.year / .dt.month on each of the 331 monthly iterations.
# Rows with a missing timestamp fall into no bucket, so they are never plotted.
frames_by_month = {
    period: frame
    for period, frame in df.groupby(df['time'].dt.to_period('M'))
}

for period in pd.period_range(start=start_date, end=end_date, freq='M'):
    month_data = frames_by_month.get(period)
    if month_data is None:
        # No rows for this month: skip it rather than saving an empty plot.
        continue

    current_date = period.to_timestamp()  # first day of the month, for labels

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Colour by temperature on the global scale, so a given colour means the
    # same temperature in every frame.
    sc = ax.scatter(
        month_data['longitude'],
        month_data['latitude'],
        month_data['depth'],
        c=month_data['sea_water_temperature'],
        cmap='coolwarm',
        vmin=global_min_temp,
        vmax=global_max_temp,
    )
    fig.colorbar(sc, ax=ax)

    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_zlabel('Depth')

    ax.set_xlim(global_min_lon, global_max_lon)
    ax.set_ylim(global_min_lat, global_max_lat)
    ax.set_zlim(global_min_depth, global_max_depth)
    ax.set_title(f'Sea Water Temperature for {current_date.strftime("%B %Y")}')

    fig.savefig(
        output_dir / f'sea_temperature_{current_date.strftime("%Y_%m")}.png',
        bbox_inches='tight',
    )
    # Close explicitly. Otherwise hundreds of figures would accumulate in memory.
    plt.close(fig)

print("Plots generated.")