In [109]:
import os
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
import scipy.interpolate

class PropagationEnvironment:
    """
    Propagation environment with automatic computation of missing parameters.

    The environment lives on a staircase mesh whose grid spacing ``h`` is one
    third of the reference wavelength (``c_0 / f / 3``).

    Attributes set by the constructor
    ---------------------------------
    h : float
        Grid spacing [m].
    r_mesh : ndarray, shape (nr,)
        Range of every mesh column [m]. Input topography samples sit at
        multiples of ``h``; extra columns are inserted wherever the seabed moves
        by more than one depth cell between two consecutive samples.
    z_mesh : ndarray, shape (nz,)
        Depth of every mesh row [m], equal to ``h * arange(nz)``.
    topography : ndarray, shape (nr,)
        Seabed depth of every mesh column, always an integer multiple of ``h``.
    ssp : ndarray, shape (nr, nz)
        Sound speed on the mesh [m/s].
    rho : ndarray, shape (nr, nz + 1)
        Density on the mesh [kg/m^3]; the deepest row is duplicated once.
    alpha : ndarray, shape (nr, nz)
        Attenuation on the mesh (labelled dB/λ in ``plot_environment``).
    """

    # Layer properties used when a profile is not supplied by the caller.
    AIR_SOUND_SPEED = 343.0         # m/s
    SEDIMENT_SOUND_SPEED = 1624.0   # m/s
    WATER_DENSITY = 1000.0          # kg/m^3
    SEDIMENT_DENSITY = 1700.0       # kg/m^3
    AIR_DENSITY = 1.21              # kg/m^3
    WATER_ATTENUATION = 50e-3       # low attenuation
    SEDIMENT_ATTENUATION = 10.0
    AIR_ATTENUATION = 20.0          # high attenuation

    DEFAULT_SSP_DIRECTORY = '/Users/matthewaidan/Documents/pythonscripts/acoustic-models/SSP_USGS'

    def __init__(self, f, c_0, z_0=None, r_max=None, z_lims=None,
                 roughness=None, topography=None, ssp=None, rho=None,
                 alpha=None, measurement_points=None, ssp_directory=None,
                 random_state=None):
        """
        INPUTS
        f: frequency of study [Hz], must be positive
        c_0: reference sound speed [m/s], must be positive
        z_0: starting depth for terrain [m]
        r_max: extent of topography [m]
        z_lims: two increasing elements giving ideal bounds for z [m]
        roughness: roughly equivalent to std in depth [m], must be positive
        topography: pre-defined topography array (optional); consecutive
            samples are treated as h = c_0 / f / 3 apart in range
        ssp: sound speed profile (optional, will be computed if None)
        rho: density profile (optional, will be computed if None)
        alpha: attenuation profile (optional, will be computed if None)
        measurement_points: (N, 2) array of (r, z) measurement locations
        ssp_directory: path to directory containing SSP files
        random_state: None, an int seed, or a numpy.random.Generator driving
            the random topography and the random SSP selection. None keeps
            using NumPy's global random state, so np.random.seed still applies.

        RAISES
        ValueError: non-positive f or c_0, missing topography parameters,
            invalid z_lims / roughness, or an empty topography.
        """
        if not (f > 0 and c_0 > 0):
            raise ValueError("'f' and 'c_0' must be positive.")

        # Store basic parameters
        self.f = f
        self.c_0 = c_0
        self.h = c_0 / f / 3  # Grid spacing: 1/3 wavelength
        # None -> fall back to NumPy's global RNG (original behaviour).
        self.random_state = (None if random_state is None
                             else np.random.default_rng(random_state))

        if topography is None:
            # Need parameters to generate topography
            if any(item is None for item in [r_max, z_0, z_lims, roughness]):
                raise ValueError(
                    "'r_max', 'z_0', 'z_lims', and 'roughness' must be defined if no topography is given."
                )
            self.r_max = r_max
            self.z_0 = z_0
            self.z_lims = np.asarray(z_lims, dtype=float)
            self.roughness = roughness
            # Degenerate limits or roughness would produce NaN topography.
            if self.z_lims.shape != (2,) or not self.z_lims[0] < self.z_lims[1]:
                raise ValueError("'z_lims' must contain two increasing values [z_min, z_max].")
            if not roughness > 0:
                raise ValueError("'roughness' must be positive.")
            self.topography_input = self._compute_topography()
        else:
            # Use provided topography; accept any array-like (lists included).
            topography = np.asarray(topography, dtype=float)
            if topography.size == 0:
                raise ValueError("'topography' must contain at least one point.")
            self.topography_input = topography
            self.r_max = r_max if r_max is not None else len(topography) * self.h
            self.z_0 = topography[0] if z_0 is None else z_0
            self.z_lims = None if z_lims is None else np.asarray(z_lims, dtype=float)
            self.roughness = roughness

        # Generate mesh from topography
        self._generate_mesh()

        # Store measurement points and SSP directory
        self.measurement_points = measurement_points
        self.ssp_directory = ssp_directory or self._get_default_ssp_directory()

        # Interpolate/compute physical properties
        self.ssp = self._process_ssp(ssp)
        self.rho = self._process_rho(rho)
        self.alpha = self._process_alpha(alpha)

        print(f"Environment initialized: {len(self.r_mesh)} × {len(self.z_mesh)} grid")

    def _compute_topography(self):
        """
        Randomly generate a topography random walk with skew-normal steps.

        Each step is drawn from a skew-normal centred on the previous depth with
        scale ``roughness``. The skew pushes the walk back towards the centre of
        ``z_lims``, and each sample is clipped to ``z_lims``.

        Returns
        -------
        ndarray, shape (max(int(r_max / h), 1),)
            Depths [m], one per ``h`` of range, starting at ``z_0``.
        """
        n_points = max(int(self.r_max / self.h), 1)  # always keep z_0
        z_min, z_max = self.z_lims
        center = (z_min + z_max) / 2
        width = z_max - z_min

        topography = np.empty(n_points)
        topography[0] = self.z_0

        for i in range(1, n_points):
            # Skew distribution towards center of z_lims
            center_deviation = (topography[i - 1] - center) / width
            skew_factor = np.clip(-center_deviation / self.roughness, -5, 5)  # Prevent extreme values

            sample = scipy.stats.skewnorm.rvs(
                skew_factor,
                loc=topography[i - 1],
                scale=self.roughness,
                random_state=self.random_state,
            )
            # Keep within bounds
            topography[i] = np.clip(sample, z_min, z_max)

        return topography

    def _generate_mesh(self):
        """
        Create the staircase mesh from ``self.topography_input``.

        Every topography sample is snapped to the nearest depth cell (a multiple
        of ``h``). Consecutive samples are ``h`` apart in range. When the seabed
        moves by ``n > 1`` cells between two samples, that range interval is
        split into ``n`` equal sub-steps. This way the seabed never changes by
        more than one cell between neighbouring mesh columns.

        Depths are tracked as integer cell counts, so the slope logic cannot be
        tripped up by floating-point round-off (e.g. ``0.8 - 0.7 < 0.1``).

        Sets ``self.z_mesh``, ``self.r_mesh`` and ``self.topography``.

        Raises
        ------
        ValueError
            If the topography is empty or contains non-finite values.
        """
        topo = np.asarray(self.topography_input, dtype=float).ravel()
        if topo.size == 0:
            raise ValueError("Topography must contain at least one point.")
        if not np.all(np.isfinite(topo)):
            raise ValueError("Topography contains non-finite values.")

        # Seabed depth of every input sample, in whole depth cells.
        cells = np.round(topo / self.h).astype(int)

        # Depth mesh extends at least one cell below the deepest seabed point.
        self.z_mesh = self.h * np.arange(max(int(cells.max()), 0) + 2)

        r_mesh = [0.0]
        topo_cells = [int(cells[0])]
        for i_orig in range(1, cells.size):
            start_cell = int(cells[i_orig - 1])
            jump = int(cells[i_orig]) - start_cell
            direction = (jump > 0) - (jump < 0)
            # A flat interval needs one column; an n-cell jump needs n columns,
            # the last of which lands on the new sample at range i_orig * h.
            n_sub = max(abs(jump), 1)
            r_start = (i_orig - 1) * self.h
            for j in range(1, n_sub + 1):
                r_mesh.append(r_start + self.h * j / n_sub)
                topo_cells.append(start_cell + direction * j)

        # Store final mesh
        self.r_mesh = np.array(r_mesh)
        self.topography = self.h * np.array(topo_cells, dtype=float)

        print(f"Mesh created: {len(self.r_mesh)} range × {len(self.z_mesh)} depth points")

    def _get_default_ssp_directory(self):
        """Get default SSP directory path."""
        return self.DEFAULT_SSP_DIRECTORY

    def _below_seabed_mask(self):
        """Return an (nr, nz) boolean mask, True where the mesh point is deeper than the seabed."""
        return self.z_mesh[np.newaxis, :] > self.topography[:, np.newaxis]

    def _layered_field(self, water, sediment, air):
        """
        Build an (nr, nz) field that is ``water`` everywhere.

        Points below the seabed are then set to ``sediment``, and the surface
        row (index 0) is set to ``air``. Air is applied last, matching the
        sound-speed boundaries.
        """
        field = np.full((len(self.r_mesh), len(self.z_mesh)), float(water))
        field[self._below_seabed_mask()] = sediment
        field[:, 0] = air
        return field

    def _interpolate_to_mesh(self, values, name):
        """
        Linearly interpolate scattered ``values`` at ``measurement_points`` onto the mesh.

        Returns an (nr, nz) array. Mesh points outside the convex hull of the
        measurement points come back as NaN (LinearNDInterpolator's default);
        a warning is printed when that happens.

        Raises ValueError when ``measurement_points`` was not supplied.
        """
        if self.measurement_points is None:
            raise ValueError(f"measurement_points required when providing {name} data")

        interpolator = scipy.interpolate.LinearNDInterpolator(self.measurement_points, values)

        r_grid, z_grid = np.meshgrid(self.r_mesh, self.z_mesh, indexing='ij')
        mesh_points = np.column_stack([r_grid.ravel(), z_grid.ravel()])
        mesh_values = interpolator(mesh_points).reshape(len(self.r_mesh), len(self.z_mesh))

        n_missing = int(np.count_nonzero(np.isnan(mesh_values)))
        if n_missing:
            print(f"Warning: {n_missing} mesh points lie outside measurement_points; "
                  f"their {name} values are NaN")
        return mesh_values

    def _process_ssp(self, ssp_input):
        """
        Process sound speed profile.

        If provided: interpolate to mesh (requires measurement_points).
        If None: load from files or create default.
        """
        if ssp_input is not None:
            return self._interpolate_to_mesh(ssp_input, 'ssp')
        return self._load_or_create_ssp()

    def _apply_ssp_boundaries(self, mesh_ssp):
        """Set sediment (below seabed) and air (surface row) sound speeds in place; return the array."""
        mesh_ssp[self._below_seabed_mask()] = self.SEDIMENT_SOUND_SPEED
        mesh_ssp[:, 0] = self.AIR_SOUND_SPEED
        return mesh_ssp

    def _load_or_create_ssp(self):
        """
        Build the sound speed field from measured SSP files.

        Each mesh column gets a randomly chosen file profile whose usable
        samples (present depth, positive speed) reach deeper than that column's
        seabed. The profile is linearly interpolated onto ``z_mesh``. A column
        with no such profile uses the default Munk profile, and so does every
        column when no files can be loaded. In all cases the sediment below the
        seabed and the surface row then get the sediment / air sound speeds.
        """
        try:
            ssp_arr, depth_arr = self._load_ssp_files()
        except OSError as e:  # includes FileNotFoundError
            print(f"Could not load SSP files: {e}")
            print("Using default Munk profile instead")
            return self._apply_ssp_boundaries(self._create_default_ssp())

        # A sample is usable if its depth is present and its speed positive
        # (padding NaNs compare False without a warning under errstate).
        with np.errstate(invalid='ignore'):
            usable = ~np.isnan(depth_arr) & (ssp_arr > 0)
        # Deepest usable sample per profile (-inf when none is usable).
        max_usable_depth = np.where(usable, depth_arr, -np.inf).max(axis=1)

        rng = np.random if self.random_state is None else self.random_state
        munk = None  # computed lazily, only if some column needs it
        mesh_ssp = np.zeros((len(self.r_mesh), len(self.z_mesh)))

        for i_range, seabed in enumerate(self.topography):
            candidates = np.flatnonzero(max_usable_depth > seabed)
            if candidates.size == 0:
                # No profile is deep enough here - use the default profile.
                if munk is None:
                    munk = self._munk_profile()
                mesh_ssp[i_range, :] = munk
                continue

            choice = rng.choice(candidates)
            z_valid = depth_arr[choice, usable[choice]]
            c_valid = ssp_arr[choice, usable[choice]]
            # np.interp requires increasing sample depths.
            order = np.argsort(z_valid, kind='stable')
            mesh_ssp[i_range, :] = np.interp(self.z_mesh, z_valid[order], c_valid[order])

        return self._apply_ssp_boundaries(mesh_ssp)

    @staticmethod
    def _read_ssp_file(filepath, n_header_lines=3):
        """
        Read one SSP text file into an (n, 2) array of (depth, sound speed).

        The first ``n_header_lines`` lines are skipped. Lines with fewer than
        two whitespace-separated fields are ignored. Only the first two fields
        of the remaining lines are used. Raises ValueError on a non-numeric
        field.
        """
        with open(filepath, 'r', encoding='utf-8', errors='replace') as file:
            lines = file.read().splitlines()[n_header_lines:]
        rows = [
            (float(parts[0]), float(parts[1]))
            for parts in (line.split() for line in lines)
            if len(parts) >= 2
        ]
        return np.array(rows, dtype=float).reshape(-1, 2)

    def _load_ssp_files(self):
        """
        Load all SSP files in ``ssp_directory`` into NaN-padded arrays.

        Files are read in sorted name order, so seeded runs are reproducible.
        Unreadable or malformed files are skipped with a warning.

        Returns
        -------
        ssp_arr, depth_arr : ndarray, shape (n_files, max_length)
            Sound speeds and depths, padded with NaN.

        Raises
        ------
        FileNotFoundError
            If the directory does not exist or contains no usable file.
        """
        if not os.path.isdir(self.ssp_directory):
            raise FileNotFoundError(f"SSP directory not found: {self.ssp_directory}")

        depth_profiles = []
        speed_profiles = []

        for filename in sorted(os.listdir(self.ssp_directory)):
            filepath = os.path.join(self.ssp_directory, filename)
            if not os.path.isfile(filepath):
                continue

            try:
                data = self._read_ssp_file(filepath)
            except (OSError, ValueError) as e:  # ValueError covers decode/parse errors
                print(f"Warning: Could not read {filename}: {e}")
                continue

            if len(data) > 0:
                depth_profiles.append(data[:, 0])
                speed_profiles.append(data[:, 1])

        if not speed_profiles:
            raise FileNotFoundError("No valid SSP files found")

        # Convert to padded arrays
        max_length = max(len(profile) for profile in speed_profiles)
        ssp_arr = np.full((len(speed_profiles), max_length), np.nan)
        depth_arr = np.full((len(speed_profiles), max_length), np.nan)

        for row, (depths, speeds) in enumerate(zip(depth_profiles, speed_profiles)):
            depth_arr[row, :len(depths)] = depths
            ssp_arr[row, :len(speeds)] = speeds

        return ssp_arr, depth_arr

    def _munk_profile(self, z_axis=1000.0, epsilon=0.00737, c_min=1490.0):
        """Return the Munk sound speed profile [m/s] evaluated on ``z_mesh``."""
        eta = 2 * (self.z_mesh - z_axis) / z_axis
        return c_min * (1 + epsilon * (eta + np.exp(-eta) - 1))

    def _create_default_ssp(self):
        """Create an (nr, nz) field with the Munk profile in every range column."""
        return np.tile(self._munk_profile(), (len(self.r_mesh), 1))

    def _process_rho(self, rho_input):
        """
        Process density profile.

        If provided: interpolate to mesh (requires measurement_points).
        If None: water density, with sediment below the seabed and air at the
        surface row.
        """
        if rho_input is not None:
            rho = self._interpolate_to_mesh(rho_input, 'rho')
        else:
            rho = self._layered_field(self.WATER_DENSITY, self.SEDIMENT_DENSITY, self.AIR_DENSITY)

        # Add extra column for density ratio calculations in IFD
        return np.column_stack([rho, rho[:, -1]])

    def _process_alpha(self, alpha_input):
        """
        Process attenuation profile.

        If provided: interpolate to mesh (requires measurement_points).
        If None: low water attenuation, with sediment below the seabed and air
        at the surface row.
        """
        if alpha_input is not None:
            return self._interpolate_to_mesh(alpha_input, 'alpha')
        return self._layered_field(
            self.WATER_ATTENUATION, self.SEDIMENT_ATTENUATION, self.AIR_ATTENUATION
        )

    def plot_environment(self, ssp_clim=(1400, 1500)):
        """
        Visualize sound speed, density and attenuation with the seabed overlaid.

        Parameters
        ----------
        ssp_clim : (vmin, vmax) or None
            Colour limits of the sound speed panel; None lets matplotlib choose.

        Returns
        -------
        fig, axes
            The matplotlib figure and its three axes.
        """
        fig, axes = plt.subplots(3, 1, figsize=(15, 12))

        vmin, vmax = (None, None) if ssp_clim is None else ssp_clim
        panels = [
            (self.ssp, 'Sound Speed [m/s]', {'vmin': vmin, 'vmax': vmax}),
            (self.rho[:, :-1], 'Density [kg/m³]', {}),  # drop the IFD helper column
            (self.alpha, 'Attenuation [dB/λ]', {}),
        ]

        for ax, (field, title, color_limits) in zip(axes, panels):
            im = ax.pcolormesh(self.r_mesh, self.z_mesh, field.T, shading='auto', **color_limits)
            ax.plot(self.r_mesh, self.topography, 'k-', lw=2)
            ax.set_ylabel('Depth [m]')
            ax.set_title(title)
            ax.invert_yaxis()
            plt.colorbar(im, ax=ax)

        axes[-1].set_xlabel('Range [m]')

        plt.tight_layout()
        plt.show()
        return fig, axes