In [1]:
# import packages and functions
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import netCDF4 as nc
import gsw

# load the indices that we need

In [2]:
# Root directory; the dataset file names used in the cells below are appended to it.
path = "/scratch/mmurakami/WAOM/"

In [3]:
# First dataset (years 1-2): `a` is Xgrid at the last index along its first dimension.
filename = path + "waom_2years/ocean_flt_select.nc"
ds1 = xr.open_dataset(filename)
a = ds1.Xgrid[-1].values

# Second dataset: `b` is Xgrid at index 11 along its first dimension.
filename = path + "waom_6years/ocean_flt.nc"
ds3 = xr.open_dataset(filename)
b = np.array(ds3.variables['Xgrid'][11])

# For every value of `b`, look up the positions in `a` holding exactly the same
# value. Values of `b` with no exact match are skipped; when a value occurs
# several times in `a`, only its first position is recorded.
hits_per_value = (np.where(a == val)[0] for val in b)
indices_in_a = np.array([hits[0] for hits in hits_per_value if hits.size > 0])

# Show the matched positions (NumPy abbreviates long arrays when printing).
print(f"indices_in_a: {indices_in_a}")

indices_in_a: [   0    1    2 ... 3892 3893 3894]


In [4]:
# Integer indices read from a text file; used below to index into `arr`.
new_locs = np.loadtxt('/scratch/mmurakami/WAOM/new_locs.txt', dtype=int)

# Parameters of the repeating pattern
size = 40664          # total number of elements
reset_interval = 104  # the pattern restarts after this many elements
start = 11            # first value of each repetition
increment = 672       # step between consecutive values within a repetition

# Element i is start + (i % reset_interval) * increment, i.e. the sequence
# 11, 683, 1355, ... which starts over every 104 elements.
arr = start + (np.arange(size, dtype=int) % reset_interval) * increment

# Quick sanity check on the result
print(f"Array shape: {arr.shape}")

# Pick the entries listed in new_locs, then the subset given by indices_in_a
times = arr[new_locs][indices_in_a]

Array shape: (40664,)


In [5]:
# `times` was selected from `arr` via `new_locs` and `indices_in_a`.
# Only the last expression of a cell is displayed, so the bare `times` line
# below shows nothing; the cell output is the maximum.
times # this is the first value of the items in the points we want to select from
max(times)

69227

# load the saved datasets

In [6]:
# Open the first of the two previously written dataset files (ds_2years.nc)
# and display it as a visual check that it was saved correctly.
path = "/scratch/mmurakami/WAOM/"
filename = path + "ds_2years.nc"
ds_a = xr.open_dataset(filename)
ds_a

In [7]:
# Open the second previously written dataset file (ds_4years.nc) and display it.
filename = path + "ds_4years.nc"
ds_b = xr.open_dataset(filename)
ds_b

In [9]:
# max_ocn_time = max(ds_b.ocean_time.values)
# Smallest ocean_time value in ds_b, offset by 11.
min_ocn_time = min(ds_b.ocean_time.values)+11
# NOTE(review): the value computed above is immediately replaced by this
# hard-coded number - presumably the value the expression evaluated to; confirm,
# or drop the override so the cell follows the data.
min_ocn_time = 70092
# orig_data_end

In [20]:
# Worked example for a single drifter (index 0): keep its Xgrid series from
# ocean_time index 11 up to a per-drifter cut-off, then append NaNs so that the
# series runs on past the cut-off.
i_start = times[0]
max_time = max(times)
max_ocn_time = max(ds_b.ocean_time.values)

# The number of padded samples is the gap between this drifter's `times` value
# and the largest one; the cut-off lies that many units before the last
# ocean_time. (Assumes `times` and ocean_time share units - TODO confirm.)
num_nans = max_time - i_start
orig_data_end = max_ocn_time - num_nans

# Position of the cut-off along ocean_time. np.searchsorted requires the
# ocean_time values to be sorted in ascending order.
ocean_time_indices = ds_b.ocean_time.values
i_end_vardata = np.searchsorted(ocean_time_indices, orig_data_end)

# Integer-based slicing: samples 11 .. i_end_vardata (inclusive) are kept.
original_data = ds_b['Xgrid'].isel(drifter=0)
original_data_touse = original_data.isel(ocean_time=slice(11, i_end_vardata+1))

# The padding coordinates continue directly after the last retained ocean_time
# value (unit steps), so the concatenated series has one continuous time axis.
# This is the same construction pad_dataset uses.
last_kept_time = ocean_time_indices[i_end_vardata]
nan_time_coords = np.arange(last_kept_time + 1, last_kept_time + num_nans + 1)

# NaN block with the same dimension and variable name as the retained data
nan_padding = xr.DataArray(
    np.full((num_nans,), np.nan),
    dims=original_data_touse.dims,
    coords={
        "ocean_time": nan_time_coords,
    },
    name="Xgrid",
)

padded_drifter_data = xr.concat([original_data_touse, nan_padding], dim="ocean_time")


In [23]:
# backfill the last ds_b with nan, then can check if we have the same number of nans in each drifter
def pad_dataset(ds, times, n_skip=11):
    """Drop the first samples of every drifter series and NaN-fill its tail.

    For drifter ``i`` the number of blanked trailing samples is
    ``max(times) - times[i]``. The last retained sample is located with
    ``np.searchsorted`` at ``max(ds.ocean_time) - (max(times) - times[i])``;
    every later sample is replaced by NaN. All drifters are processed in one
    masked selection, so every drifter and every variable is present in the
    result and they all share the same time axis, ``ds.ocean_time[n_skip:]``.

    Assumptions (not checked here - confirm against the data):
      * ``ds.ocean_time`` is sorted in ascending order (needed by
        ``np.searchsorted``);
      * ``times`` is expressed in the same units as ``ds.ocean_time``;
      * the data variables carry both the ``drifter`` and ``ocean_time``
        dimensions (variables lacking one are broadcast by ``Dataset.where``).

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``drifter`` and ``ocean_time`` dimensions.
    times : array-like
        One value per drifter, in drifter order. Only the first
        ``ds.sizes["drifter"]`` entries are paired with drifters; the maximum
        is taken over all entries.
    n_skip : int, optional
        Number of leading ``ocean_time`` samples to drop (default 11).

    Returns
    -------
    xarray.Dataset
        ``ds`` without its first ``n_skip`` time samples and with NaN after
        each drifter's last retained sample. Integer variables are promoted
        to float by the NaN fill.

    Raises
    ------
    ValueError
        If ``times`` has fewer entries than ``ds`` has drifters.
    """
    release_times = np.asarray(times)
    n_drifters = ds.sizes["drifter"]
    if release_times.shape[0] < n_drifters:
        raise ValueError(
            f"times has {release_times.shape[0]} entries but the dataset "
            f"has {n_drifters} drifters"
        )

    ocean_time_values = ds.ocean_time.values
    n_times = ocean_time_values.shape[0]
    max_time = release_times.max()
    max_ocn_time = ocean_time_values.max()

    # Per-drifter count of trailing samples to blank, and the index of the
    # last sample that is kept (inclusive).
    num_nans = max_time - release_times[:n_drifters]
    orig_data_end = max_ocn_time - num_nans
    last_valid_idx = np.searchsorted(ocean_time_values, orig_data_end)

    trimmed = ds.isel(ocean_time=slice(n_skip, None))

    # Compare positions rather than coordinate values. The two helper arrays
    # carry no coordinates, so xarray broadcasts them purely by dimension name
    # into an (ocean_time, drifter) boolean mask.
    time_idx = xr.DataArray(np.arange(n_skip, n_times), dims="ocean_time")
    last_valid = xr.DataArray(last_valid_idx, dims="drifter")
    keep = time_idx <= last_valid

    # where() keeps values where the mask is True and writes NaN elsewhere.
    return trimmed.where(keep)

In [None]:
# Pad the ds_4years dataset using the per-point `times` values.
padded_ds_b = pad_dataset(ds_b,times)

In [None]:
# Write the padded dataset to a new file under `path`; the input file
# ds_4years.nc is left unchanged.
padded_ds_b.to_netcdf(path + "ds_4years_fillna.nc")

In [None]:
import xarray as xr  # NOTE(review): redundant - the notebook's first cell already imports xarray as xr

# Uses the `ds_a` and `ds_b` datasets opened in earlier cells.
# NOTE(review): this combines the un-padded `ds_b`, not the `padded_ds_b`
# produced by pad_dataset - confirm which one is intended.

# Concatenate along the ocean_time dimension
ds_combined = xr.concat([ds_a, ds_b], dim="ocean_time")

# Save to a new NetCDF file
# NOTE(review): relative file name, so the file is written to the current
# working directory rather than under `path` like the other outputs - confirm.
ds_combined.to_netcdf("combined_dataset.nc")
