In [None]:
import abc
from datetime import datetime, timedelta, timezone
from os import listdir
from os.path import isfile, join

import matplotlib.pyplot as plt
import netCDF4

In [5]:
# from ocean_navigation_simulator.problem import C3Problem

In [None]:
import ocean_navigation_simulator.utils.plotting_utils as plot_utils
from ocean_navigation_simulator.utils.simulation_utils import convert_to_lat_lon_time_bounds, get_current_data_subset

class BaseProblem(metaclass=abc.ABCMeta):
    """A path planning problem for a Planner to solve.

    Attributes:
        x_0:
            The starting state as a list [lon, lat, battery_level, t_0_posix].
            The constructor receives (lon, lat, battery_level) and appends the
            absolute start time as a POSIX timestamp.
        x_T:
            The target state, represented as (lon, lat).
            # TODO: currently we do point-2-point navigation though ultimately we'd like
            this to be a set representation (point-2-region) because that is the more general formulation.
        plan_on_gt:
            Flag stored from the constructor; __repr__ reports "Planning on GT." when it is True
            and summarises the forecast files otherwise.
        dyn_dict:
            Relative platform dynamics derived from platform_config_dict,
            see derive_platform_dynamics.
        hindcast_grid_dict, forecasts_dicts, data_access, local_hindcast_file:
            Initialised to None here; child classes are expected to fill them in.

        # TO IMPLEMENT USAGE/FOR FUTURE
        forecast_delay_in_h:
            The hours of delay when a forecast becomes available
            e.g. forecast starts at 1st of Jan but only available from HYCOM 48h later on 3rd of January
        noise:
            # TODO: optionally implement a way to add noise to the hindcasts

        # TO REVIEW/THINK ABOUT
        x_t_tol:
            Radius around x_T that when reached counts as "target reached"
            # Note: not used currently as the sim config has that value too.
    """

    def __init__(self, x_0, x_T, t_0, platform_config_dict, plan_on_gt= False,
                 forecast_delay_in_h=0., noise=None, x_t_tol=0.1):
        """Validates the inputs and stores the problem definition.

        Args:
            x_0: sequence (lon, lat, battery_level) of the starting state.
            x_T: sequence (lon, lat) of the target.
            t_0: datetime of the absolute start time. A naive datetime is assumed to be UTC;
                an aware datetime must have a UTC offset of zero.
            platform_config_dict: dict with the platform parameters,
                see the repos 'configs/platform.yaml' as example.
            plan_on_gt: see class docstring.
            forecast_delay_in_h: see class docstring.
            noise: stored as-is, currently unused.
            x_t_tol: stored as-is, currently unused.

        Raises:
            ValueError: if x_0 / x_T have the wrong length or t_0 is not in UTC.
        """
        # Plan on GT
        self.plan_on_gt = plan_on_gt

        # Need to be derived in the child_classes
        self.hindcast_grid_dict = None
        self.forecasts_dicts = None
        self.data_access = None
        self.local_hindcast_file = None

        # Basic check of inputs
        if len(x_0) != 3 or len(x_T) != 2:
            # NOTE(review): the message now states (lon, lat) to match the x_T docstring and
            # the plotting code (index 0 on the x-axis) -- confirm against the simulation utils.
            raise ValueError("x_0 should be (lon, lat, battery) and x_T (lon, lat)")

        # check t_0: compare the UTC offset rather than the tzinfo object so that any
        # tzinfo implementation representing UTC is accepted.
        if t_0.tzinfo is None or t_0.utcoffset() is None:
            print("Assuming input t_0 is in UTC time.")
            t_0 = t_0.replace(tzinfo=timezone.utc)
        elif t_0.utcoffset() != timedelta(0):
            raise ValueError("Please provide t_0 as UTC or naive datetime object.")

        # Log start, goal and forecast delay.
        # list(...) makes tuple inputs work and avoids aliasing the caller's list.
        self.x_0 = list(x_0) + [t_0.timestamp()]
        self.x_T = x_T
        self.forecast_delay_in_h = forecast_delay_in_h

        # Kept for future use (see class docstring); previously these were silently dropped.
        self.noise = noise
        self.x_t_tol = x_t_tol

        # derive relative battery dynamics variables from config_dict
        self.dyn_dict = self.derive_platform_dynamics(platform_config_dict)

    def __repr__(self):
        """Returns the string representation of a Problem, to be used for debugging.

        Returns:
            A multi-line String describing start/goal, the GT time range (if known)
            and what data the planner plans on.
        """
        start_time = datetime.fromtimestamp(self.x_0[3], tz=timezone.utc)
        lines = ["Navigate from {} at time {} to {}.".format(self.x_0[:3], start_time, self.x_T)]

        # hindcast_grid_dict may still be None (not derived yet) or empty (no files found)
        gt_t_range = (self.hindcast_grid_dict or {}).get('gt_t_range')
        if gt_t_range:
            lines.append("Simulate with GT current files from {} to {}".format(
                gt_t_range[0], gt_t_range[1]))
        else:
            lines.append("No GT current files available.")

        if self.plan_on_gt:
            lines.append("Planning on GT.")
        elif self.forecasts_dicts:
            lines.append("Planning on {} Forecast files starting from {} to {}".format(
                len(self.forecasts_dicts),
                self.forecasts_dicts[0]['t_range'][0],
                self.forecasts_dicts[-1]['t_range'][0]))
        else:
            lines.append("No Forecast files available.")

        return "\n".join(lines)

    def viz(self, time=None, video=False, filename=None, cut_out_in_deg=0.8, html_render=None,
            temp_horizon_viz_in_h=120):
        """Visualizes the Hindcast file with the ocean currents in a plot or a gif for a specific time or time range.

        Input Parameters:
        - time: the time to visualize the ocean currents as a datetime.datetime object if
                None, the visualization is at the t_0 time of the problem.
        - video: if True (and time is None) a matplotlib animation is created; filename and
                html_render are handed to the plotting utility.
        - filename: a string for filepath and name with ending either '.gif' or '.mp4' under which it is saved
        - cut_out_in_deg: if None, the full fieldset is visualized, otherwise provide a float e.g. 0.5 to plot only
                a box of the x_0 and x_T including a 0.5 degrees outer buffer.
        - html_render: passed through unchanged to plot_utils.viz_current_animation.
        - temp_horizon_viz_in_h: temporal horizon in hours handed to convert_to_lat_lon_time_bounds.
                NOTE(review): this used to be read from an undefined name; the default of
                120 h is an assumption -- confirm the intended value.

        Returns:
            None
        """
        # Step 0: Find the time, lat, lon bounds for data_subsetting
        t_interval, lat_interval, lon_interval = convert_to_lat_lon_time_bounds(
            self.x_0, self.x_T,
            deg_around_x0_xT_box=cut_out_in_deg,
            temp_horizon_in_h=temp_horizon_viz_in_h)

        print("Note only the GT file is currently visualized")
        # Step 1: get the data_subset for plotting (flexible query can be with local or C3 file)
        grids_dict, u_data, v_data = get_current_data_subset(t_interval, lat_interval, lon_interval,
                                                            data_type='H', access=self.data_access,
                                                            file=self.local_hindcast_file)

        def add_ax_func(ax, time=None, x_0=self.x_0[:2], x_T=self.x_T[:2]):
            """Marks start and goal on the given axes."""
            # `time` is accepted but unused -- presumably required by the animation callback signature.
            del time
            ax.scatter(x_0[0], x_0[1], c='r', marker='o', s=200, label='start')
            ax.scatter(x_T[0], x_T[1], c='g', marker='*', s=200, label='goal')
            # draw the legend on the axes we were given instead of the implicit current axes
            ax.legend(loc='upper right')

        # if we want to visualize with video
        if time is None and video:
            # create animation with extra func
            plot_utils.viz_current_animation(grids_dict['t_grid'], grids_dict, u_data, v_data,
                                             interval=200, ax_adding_func=add_ax_func, html_render=html_render,
                                             save_as_filename=filename)
            return

        # otherwise plot static image
        if time is None:
            # use this problem's own start time (not a global `problem` object)
            time = datetime.fromtimestamp(self.x_0[3], tz=timezone.utc)
        # plot underlying currents at time
        ax = plot_utils.visualize_currents(time.timestamp(), grids_dict, u_data, v_data, autoscale=True, plot=False)
        # add the start and goal position to the plot
        add_ax_func(ax)
        plt.show()

    def get_hindcast_grid_dict(self):
        """Placeholder returning None; child classes may override it."""
        return None

    @abc.abstractmethod
    def get_forecast_dicts(self, forecast_files_list):
        """ Takes in a list of files and returns a list of tuples with:
        (start_time_posix, end_time_posix, grids, file) sorted according to start_time_posix
        """

    def derive_platform_dynamics(self, platform_specs):
        """Derives the relative battery capacity dynamics (from 0-1) based on absolute physical values.
        Input:
            platform_specs      a dictionary containing the keys 'battery_cap', 'drag_factor',
                                'motor_efficiency', 'avg_solar_power' and 'u_max'.
        Returns:
            A dictionary of settings for the Problem, i.e. {'charge': __, 'energy': __, 'u_max': __}
        Raises:
            KeyError if one of the required keys is missing.
        """
        # battery_cap is scaled by 3600 -- presumably Wh -> Joule, confirm units in the platform config
        cap_in_joule = platform_specs['battery_cap'] * 3600
        # both factors are expressed relative to the full battery capacity
        energy_coeff = (platform_specs['drag_factor'] * (1 / platform_specs['motor_efficiency'])) / cap_in_joule
        charge_factor = platform_specs['avg_solar_power'] / cap_in_joule

        return {'charge': charge_factor, 'energy': energy_coeff, 'u_max': platform_specs['u_max']}

In [None]:
# The C3 File-Based approach
class C3Problem(BaseProblem):
    """Problem whose hindcast and forecast file references are queried from the C3 database.

    Relies on a `c3` handle being available in the global namespace (as in a C3 notebook kernel).
    """

    def __init__(self, x_0, x_T, t_0, platform_config_dict, plan_on_gt = False,
                 forecast_delay_in_h=0., noise=None, x_t_tol=0.1):
        """Initialises the base problem and derives the hindcast/forecast dicts from the C3 DB.

        Arguments are passed unchanged to BaseProblem.__init__.
        """
        # initialize the Base Problem
        super().__init__(x_0, x_T, t_0, platform_config_dict, plan_on_gt, forecast_delay_in_h, noise, x_t_tol)

        # Derive the hindcast_grid_dict and the list forecast_dicts
        self.hindcast_grid_dict = self.get_hindcast_grid_dict()
        if not self.plan_on_gt:
            self.forecasts_dicts = self.get_forecast_dicts()

        # log variables for data_access
        self.data_access = 'C3'

    @staticmethod
    def _consecutive_time_range(starts_list, ends_list):
        """Returns [first start, end of the last file of the leading gap-free run].

        Two files count as consecutive when the next start is exactly one hour after
        the previous end. A single file yields its own [start, end].
        """
        last_idx = 0
        for idx in range(len(starts_list) - 1):
            if ends_list[idx] + timedelta(hours=1) != starts_list[idx + 1]:
                # gap found: the run ends with file `idx`
                break
            last_idx = idx + 1
        return [starts_list[0], ends_list[last_idx]]

    def get_hindcast_grid_dict(self):
        """Helper function to create the hindcast grid dict from the multiple files in the C3 DB.
        The idea is: how many consecutive daily hindcast files do we have starting from t_0.

        Returns:
            {"gt_t_range": [start, end], "gt_y_range": [lat_start, lat_end], "gt_x_range": [lon_start, lon_end]}
            or an empty dict if the query returned no files.
        """
        # Step 1: get required file references and data from C3 file DB
        # Step 1.1: Getting time and formatting for the db query
        start = datetime.fromtimestamp(self.x_0[3], timezone.utc)

        # Step 1.2: Getting correct range of nc files from database
        filter_string = 'start>=' + '"' + start.strftime("%Y-%m-%d") + '"' + \
                        ' && status==' + '"' + 'downloaded' + '"'
        objs_list = c3.HindcastFile.fetch({'filter': filter_string, "order": "start"}).objs

        # some basic sanity checks (covers both None and an empty result list)
        if not objs_list:
            print("No Hindcast files in DB")
            return {}

        # get spatial coverage from the first file
        coverage = objs_list[0].subsetOptions.geospatialCoverage
        y_range = [coverage.start.latitude, coverage.end.latitude]
        x_range = [coverage.start.longitude, coverage.end.longitude]

        # get time_range by iterating over the returned elements that are consecutive/exactly 1h apart
        time_range = self._consecutive_time_range([obj.start for obj in objs_list],
                                                  [obj.end for obj in objs_list])

        # create dict
        return {"gt_t_range": time_range, "gt_y_range": y_range, "gt_x_range": x_range}

    def get_forecast_dicts(self):
        """Helper function to create a list of dicts of the available forecasts from the C3 DB.
        Starting from the most recent at t_0 in the future.

        Returns:
            A list of {'t_range': [start, end], 'file': url} dicts sorted by t_range start.

        Raises:
            ValueError: if no forecast run is returned or the first run does not cover t_0.
        """
        t_0_as_datetime = datetime.fromtimestamp(self.x_0[3], timezone.utc)

        # Step 1: get required file references and data from C3 file DB
        # get all relevant forecasts including most recent at t_0 and all the ones after
        from_run_onwards = t_0_as_datetime - timedelta(days=1)
        filter_string = 'runDate>=' + '"' + from_run_onwards.strftime("%Y-%m-%dT%H:%M:%S") + '"'
        objs_list = c3.HycomFMRC.fetch(spec={'include': "[this, fmrcFiles.file]",
                                             'filter': filter_string,
                                             "order": "runDate"}
                                      ).objs

        # basic sanity check (covers both None and an empty result list)
        if not objs_list:
            raise ValueError("No forecast runs in the database for and after the selected start_time")
        # check if first forecast file contains t_0
        first_coverage = objs_list[0].timeCoverage
        if not (first_coverage.start < t_0_as_datetime < first_coverage.end):
            raise ValueError("First Forecast File retreived does not contain t_0.")

        # Step 2: create a list of dicts with one dict for each run
        forecast_dicts = [{'t_range': [run.timeCoverage.start, run.timeCoverage.end],
                           'file': run.fmrcFiles[0].file.url}
                          for run in objs_list]

        # sorting after t_range start (doubling because already in db query but to be safe)
        forecast_dicts.sort(key=lambda forecast: forecast['t_range'][0])

        return forecast_dicts

In [None]:
# Helper just to code everything up here, in the file it's imported from simulation utils
import numpy as np
def get_abs_time_grid_for_hycom_file(f, data_type):
    """Converts the relative time axis of a HYCOM file into absolute POSIX timestamps.

    Input:
        f           an opened file object exposing f.variables['time']
        data_type   'H' reads the origin from the 'time_origin' attribute, any other value
                    parses it from the 'units' attribute ("hours since ... UTC")
    Returns:
        A numpy array with one POSIX timestamp per entry of the time variable.
    """
    # The raw time axis; note that this is in hours from the origin for HYCOM data.
    hours_since_origin = f.variables['time'][:]

    # Pick attribute and parse format of the origin (very tailored to the HYCOM data).
    time_attrs = f.variables['time'].__dict__
    if data_type == 'H':
        origin_text = time_attrs['time_origin']
        origin_format = '%Y-%m-%d %H:%M:%S %z'
    else:
        origin_text = time_attrs['units']
        origin_format = 'hours since %Y-%m-%d %H:%M:%S.000 UTC %z'
    # Appending an explicit zero offset makes the parsed origin timezone-aware (UTC).
    origin = datetime.strptime(origin_text + ' +0000', origin_format)

    # for time indexing transform to POSIX time
    posix_times = []
    for hours in hours_since_origin.data:
        posix_times.append((origin + timedelta(hours=hours)).timestamp())
    return np.array(posix_times)

In [None]:
# The local file-based approach
# import get_abs_time_grid_for_hycom_file

class Problem(BaseProblem):
    """Problem whose hindcast and forecast data come from local netCDF files."""

    def __init__(self, x_0, x_T, t_0, platform_config_dict, hindcast_file, forecast_folder=None,
                 plan_on_gt = False, forecast_delay_in_h=0., noise=None, x_t_tol=0.1):
        """Initialises the base problem and derives the grid/forecast dicts from local files.

        Args:
            hindcast_file: path of the local hindcast netCDF file.
            forecast_folder: folder containing the forecast files; required unless plan_on_gt is True.
            All other arguments are passed unchanged to BaseProblem.__init__.

        Raises:
            ValueError: if plan_on_gt is False and no forecast_folder is given.
        """
        # initialize the Base Problem
        super().__init__(x_0, x_T, t_0, platform_config_dict, plan_on_gt, forecast_delay_in_h, noise, x_t_tol)

        # Derive the hindcast_grid_dict
        self.local_hindcast_file = hindcast_file
        self.hindcast_grid_dict = self.get_grid_dict_from_file(hindcast_file)

        # Derive the list forecasts_dicts
        if not self.plan_on_gt:
            self.forecasts_dicts = self.get_forecast_dicts(forecast_folder)

        # log variables for data_access
        self.data_access = 'local'

    def get_grid_dict_from_file(self, file, data_type='H'):
        """Helper function to create a grid dict from a local file.

        Input:
            file        path of the netCDF file
            data_type   'H' for hindcast or 'F' for forecast, forwarded to
                        get_abs_time_grid_for_hycom_file
        Returns:
            {"gt_t_range": [first, last] as naive UTC datetimes,
             "gt_y_range": [first, last] entry of the 'lat' variable,
             "gt_x_range": [first, last] entry of the 'lon' variable}
        """
        # the context manager closes the file handle again (one is opened per forecast file)
        with netCDF4.Dataset(file) as f:
            # get the time coverage in POSIX
            t_grid = get_abs_time_grid_for_hycom_file(f, data_type=data_type)
            lat_grid = f.variables['lat'][:]
            lon_grid = f.variables['lon'][:]

        # create dict
        # NOTE(review): utcfromtimestamp yields naive datetimes (kept for backward compatibility);
        # timezone-aware values would need every consumer of gt_t_range checked first.
        return {"gt_t_range": [datetime.utcfromtimestamp(t_grid[0]), datetime.utcfromtimestamp(t_grid[-1])],
                "gt_y_range": [lat_grid[0], lat_grid[-1]],
                "gt_x_range": [lon_grid[0], lon_grid[-1]]}

    def get_forecast_dicts(self, forecast_folder):
        """Helper function to create a list of dicts of the available forecasts from the local folder.
        Starting from the most recent at t_0 in the future.

        Returns:
            A list of {'t_range': [start, end], 'file': path} dicts sorted by t_range start.

        Raises:
            ValueError: if forecast_folder is None.
        """
        if forecast_folder is None:
            # listdir(None) would silently list the current working directory
            raise ValueError("forecast_folder is required when plan_on_gt is False.")

        # get a list of files from the folder; join() also works without a trailing separator
        candidate_paths = [join(forecast_folder, name) for name in sorted(listdir(forecast_folder))
                           if name != '.DS_Store']
        forecast_files_list = [path for path in candidate_paths if isfile(path)]

        # iterate over all files to extract the t_ranges and put them in an ordered list of dicts
        forecast_dicts = []
        for file in forecast_files_list:
            grid_dict = self.get_grid_dict_from_file(file, data_type='F')
            forecast_dicts.append({'t_range': grid_dict['gt_t_range'], 'file': file})
        # sort the list by forecast start time
        forecast_dicts.sort(key=lambda forecast: forecast['t_range'][0])

        return forecast_dicts

In [None]:
# Set the platform configurations
# NOTE: `timezone` is imported only here although the class definitions in the cells above
# use it, so this cell has to run before a problem object is created.
from datetime import timezone
# Keys consumed by BaseProblem.derive_platform_dynamics:
# battery_cap, drag_factor, motor_efficiency, avg_solar_power and u_max.
platform_config_dict = {'battery_cap': 20.0, 'u_max': 0.1, 'motor_efficiency': 1.0,
                        'avg_solar_power': 10.0, 'drag_factor': 10.0}

# Create the navigation problem
# Alternative start times kept for reference (naive datetimes are treated as UTC by BaseProblem):
# t_0 = datetime(2021, 6, 1, 12, 10, 10)
# t_0 = datetime(2021, 11, 12, 12, 10, 10)
t_0 = datetime(2021, 11, 12, 14, 10, 10, tzinfo=timezone.utc)
x_0 = [-88.0, 25.0, 1]  # lon, lat, battery
x_T = [-88.2, 26.3]  # lon, lat of the target
# `c3` is not defined in this notebook -- presumably injected by the C3 kernel.
# NOTE(review): what get_all_data() does is not visible here; confirm it is required.
c3.get_all_data()

# Local data locations; only used by the file-based `Problem` variant commented out below.
hindcast_file = "./data/" + "2021_06_1-05_hourly.nc4"
forecast_folder = "./data/forecast_folder/"

# Database-backed problem; swap with the commented line to use the local files instead.
problem = C3Problem(x_0, x_T, t_0, platform_config_dict, plan_on_gt=False)
# problem = Problem(x_0, x_T, t_0, platform_config_dict, hindcast_file, forecast_folder, plan_on_gt=False)
# problem = Problem(x_0, x_T, t_0, platform_config_dict, hindcast_file, forecast_folder, plan_on_gt=False)

In [53]:
# Sanity check: show which C3Problem class object is currently bound in the kernel.
C3Problem

__main__.C3Problem

In [54]:
# Inspect the forecast dicts ({'t_range': [start, end], 'file': ...}) derived for the problem.
problem.forecasts_dicts

[{'t_range': [datetime.datetime(2021, 11, 12, 12, 0, tzinfo=datetime.timezone.utc),
   datetime.datetime(2021, 11, 17, 12, 0, tzinfo=datetime.timezone.utc)],
  'file': 'hycom-test/fmrc/GOMu0.04_901m000_FMRC_RUN_2021-11-12T12:00:00Z-2021-11-12T12:00:00Z-2021-11-17T12:00:00Z.nc/GOMu0.04_901m000_FMRC_RUN_2021-11-12T12:00:00Z-2021-11-12T12:00:00Z-2021-11-17T12:00:00Z.nc'},
 {'t_range': [datetime.datetime(2021, 11, 13, 12, 0, tzinfo=datetime.timezone.utc),
   datetime.datetime(2021, 11, 19, 0, 0, tzinfo=datetime.timezone.utc)],
  'file': 'hycom-test/fmrc/GOMu0.04_901m000_FMRC_RUN_2021-11-13T12:00:00Z-2021-11-13T12:00:00Z-2021-11-19T00:00:00Z.nc/GOMu0.04_901m000_FMRC_RUN_2021-11-13T12:00:00Z-2021-11-13T12:00:00Z-2021-11-19T00:00:00Z.nc'},
 {'t_range': [datetime.datetime(2021, 11, 14, 12, 0, tzinfo=datetime.timezone.utc),
   datetime.datetime(2021, 11, 20, 0, 0, tzinfo=datetime.timezone.utc)],
  'file': 'hycom-test/fmrc/GOMu0.04_901m000_FMRC_RUN_2021-11-14T12:00:00Z-2021-11-14T12:00:00Z-2021-

In [8]:
problem.forecasts_dicts

# Get the index of the most recent forecast for t_0.
# This sketch was written as if inside a method (it used `self`); at notebook level it
# operates on the `problem` object instead.
# NOTE(review): t_range entries must be comparable with t_0 (both timezone-aware or both naive).
# NOTE(review): the cutoff keeps the original `t_0 + delay` arithmetic -- confirm whether the
# delay should rather be subtracted (a forecast becoming available `delay` hours after its start).
forecast_cutoff = t_0 + timedelta(hours=problem.forecast_delay_in_h)

# Default: no forecast starts after the cutoff, so the last one is the most recent.
problem.most_recent_forecast_idx = len(problem.forecasts_dicts) - 1
for i, dic in enumerate(problem.forecasts_dicts):
    # Note: this assumes the list is ordered according to time-values
    if dic['t_range'][0] > forecast_cutoff:
        # i - 1 is -1 when even the first forecast starts after the cutoff
        problem.most_recent_forecast_idx = i - 1
        break

[{'t_range': [datetime.datetime(2021, 6, 2, 12, 0),
   datetime.datetime(2021, 6, 8, 0, 0)],
  'file': './data/forecast_folder/GOMu0.04_901m000_FMRC_RUN_2021-06-02T12_00_00Z.nc4'}]

In [6]:
# With time=None and video=True, viz() builds an animation of the hindcast currents and
# hands `filename` to the plotting utility as the file to save it under.
problem.viz(time=None, video=True, filename='for_sure.gif', cut_out_in_deg=0.8, html_render=None)

Note only the GT file is currently visualized
Subsetted data from 2021-06-01 12:00:00 UTC to 2021-06-06 11:00:00 UTC in 121 time steps
