In [27]:
import torch

device = 'cuda:1'  # 'cuda:0' or 'cpu'

DEVICE = torch.device(device)


# Classes and functions

In [28]:
class ExpHydroCommon:

    def __init__(self):
        pass
        
    def interpolate(self, time_series, interpolator):
        time_series_torch, var_values_torch = interpolator
        print('time_series_torch.device', time_series_torch.device)
        time_series_torch = time_series_torch.to(DEVICE)
        print('time_series_torch.device', time_series_torch.device)
        var_values_torch = var_values_torch.to(DEVICE)
        time_series = time_series.to(DEVICE)

        return self.linear_interpolation(time_series, time_series_torch, var_values_torch)

    @staticmethod
    def linear_interpolation(x, xp, fp):
        """
        Perform linear interpolation on a 1D torch tensor.

        Parameters:
        x (torch.Tensor): The x-coordinates at which to evaluate the interpolated values.
        xp (torch.Tensor): The x-coordinates of the data points, must be increasing.
        fp (torch.Tensor): The y-coordinates of the data points, same length as xp.

        Returns:
        torch.Tensor: The interpolated values, same shape as x.
        """
        slopes = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
        indices = torch.searchsorted(xp, x, right=True) - 1
        indices = indices.clamp(0, len(slopes) - 1)

        return fp[indices] + slopes[indices] * (x - xp[indices])

# Define and prepare data

In [29]:
# Create an instance of ExpHydroCommon
exp_hydro_common = ExpHydroCommon()

# Define interpolator
time_series_torch = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32).to(DEVICE)
var_values_torch = torch.tensor([0, 1, 4, 9, 16], dtype=torch.float32).to(DEVICE)
interpolator = (time_series_torch, var_values_torch)

# Define time_series to interpolate
time_series = torch.tensor([0.5, 0.95, 1.05, 1.5, 1.95, 2.05, 2.5, 2.95, 3.05, 3.5, 3.95, 4.05], dtype=torch.float32).to(DEVICE)

# Results

In [30]:
# Perform interpolation
result = exp_hydro_common.interpolate(time_series, interpolator)
print(result)

time_series_torch.device cuda:1
time_series_torch.device cuda:1
tensor([ 0.5000,  0.9500,  1.1500,  2.5000,  3.8500,  4.2500,  6.5000,  8.7500,
         9.3500, 12.5000, 15.6500, 16.3500], device='cuda:1')
