diff --git a/environment.yml b/environment.yml index 741eaa10..892b0044 100644 --- a/environment.yml +++ b/environment.yml @@ -6,6 +6,7 @@ dependencies: - fenics-dolfinx=0.9.0 - matplotlib - scipy + - pandas - pint # need to install pint with conda to avoid conflicts (HTM) - pip - pip: diff --git a/mb_scenario.py b/mb_scenario.py index 57ddeeb2..d44362aa 100644 --- a/mb_scenario.py +++ b/mb_scenario.py @@ -10,7 +10,8 @@ import dolfinx.fem as fem import dolfinx -from hisp.helpers import PulsedSource, Scenario +from hisp.helpers import PulsedSource +from hisp.scenario import Scenario from hisp import CustomProblem # dolfinx.log.set_log_level(dolfinx.log.LogLevel.INFO) @@ -42,7 +43,7 @@ def gaussian_distribution(x, mod=ufl): def make_mb_model(nb_mb, scenario_file): ############# Input Flux, Heat Data ############# - my_scenario = Scenario(scenario_file) + my_scenario = Scenario.from_txt_file(scenario_file, old_format=True) my_model = CustomProblem() diff --git a/src/hisp/__init__.py b/src/hisp/__init__.py index 391c7d92..7666f43d 100644 --- a/src/hisp/__init__.py +++ b/src/hisp/__init__.py @@ -1,3 +1,4 @@ -from .helpers import PulsedSource, Scenario +from .helpers import PulsedSource from .h_transport_class import CustomProblem +from .scenario import Scenario, Pulse diff --git a/src/hisp/helpers.py b/src/hisp/helpers.py index 43b18b2b..dfe0f7d6 100644 --- a/src/hisp/helpers.py +++ b/src/hisp/helpers.py @@ -1,7 +1,6 @@ import festim as F from dolfinx.fem.function import Constant import ufl -import numpy as np class PulsedSource(F.ParticleSource): @@ -34,120 +33,3 @@ def create_value_fenics(self, mesh, temperature, t: Constant): def update(self, t: float): self.flux_fenics.value = self.flux(t) - - -class Scenario: - def __init__(self, filename: str): - self.filename = filename - data = np.genfromtxt(filename, dtype=str, comments="#") - if isinstance(data[0], str): - self.data = [data] - else: - self.data = data - - def get_row(self, t: float): - """Returns the row of the scenario file that corresponds to the time t. - - Args: - t (float): the time in seconds - - Returns: - int: the row index of the scenario file corresponding to the time t - """ - current_time = 0 - for i, row in enumerate(self.data): - nb_pulses = int(row[1]) - phase_duration = nb_pulses * self.get_pulse_duration(i) - if t <= current_time + phase_duration: - return i - else: - current_time += phase_duration - - raise ValueError( - f"Time t {t} is out of bounds of the scenario file. Maximum time is {self.get_maximum_time()}" - ) - - def get_pulse_type(self, t: float) -> str: - """Returns the pulse type as a string at time t. - - Args: - t (float): time in seconds - - Returns: - str: pulse type (eg. FP, ICWC, RISP, GDC, BAKE) - """ - row_idx = self.get_row(t) - return self.data[row_idx][0] - - def get_pulse_duration(self, row: int) -> float: - """Returns the total duration of a pulse in seconds for a given row in the file. - - Args: - row (int): the row index in the scenario file - - Returns: - float: the total duration of the pulse in seconds - """ - row_data = self.data[row] - pulse_type = row_data[0] - if pulse_type == "RISP": # hard coded because it's zero in the files - ramp_up = 10 - steady_state = 250 - ramp_down = 10 - waiting = 1530 - total_duration = ramp_up + steady_state + ramp_down + waiting - return total_duration - - ramp_up = float(row_data[2]) - steady_state = float(row_data[4]) - ramp_down = float(row_data[3]) - waiting = float(row_data[5]) - - total_duration = ramp_up + steady_state + ramp_down + waiting - return total_duration - - def get_pulse_duration_no_waiting(self, row: int) -> float: - """Returns the total duration (without the waiting time) of a pulse in seconds for a given row in the file. - - Args: - row (int): the row index in the scenario file - - Returns: - float: the total duration of the pulse in seconds - """ - row_data = self.data[row] - pulse_type = row_data[0] - if pulse_type == "RISP": # hard coded because it's zero in the files - waiting_time = 1530 - else: - waiting_time = float(row_data[5]) - - duration = self.get_pulse_duration(row) - waiting_time - return duration - - def get_time_till_row(self, row:int) -> float: - """Returns the time that has elapsed in scenario up until start of current row. - - Args: - row (int): the row index in the scenario file - - Returns: - float: the time that has elapsed in scenario until and not including input row. - """ - time_elapsed = 0 - for prev_row_id in range(0,row): - nb_pulses = int(self.data[prev_row_id][1]) - time_elapsed += nb_pulses * self.get_pulse_duration(prev_row_id) - return time_elapsed - - def get_maximum_time(self) -> float: - """Returns the maximum time in seconds for the scenario file. - - Returns: - float: the maximum time in seconds - """ - max_time = 0 - for i, row in enumerate(self.data): - nb_pulses = int(row[1]) - max_time += nb_pulses * self.get_pulse_duration(i) - return max_time diff --git a/src/hisp/scenario.py b/src/hisp/scenario.py new file mode 100644 index 00000000..aa330122 --- /dev/null +++ b/src/hisp/scenario.py @@ -0,0 +1,229 @@ +import pandas as pd +from typing import List +import warnings + + +class Pulse: + pulse_type: str + nb_pulses: int + ramp_up: float + steady_state: float + ramp_down: float + waiting: float + + def __init__( + self, + pulse_type: str, + nb_pulses: int, + ramp_up: float, + steady_state: float, + ramp_down: float, + waiting: float, + ): + self.pulse_type = pulse_type + self.nb_pulses = nb_pulses + self.ramp_up = ramp_up + self.steady_state = steady_state + self.ramp_down = ramp_down + self.waiting = waiting + + @property + def total_duration(self) -> float: + all_zeros = ( + self.ramp_up == 0 + and self.steady_state == 0 + and self.ramp_down == 0 + and self.waiting == 0 + ) + if self.pulse_type == "RISP" and all_zeros: + msg = "RISP pulse has all zeros for ramp_up, steady_state, ramp_down, waiting. " + msg += "Setting hardcoded values. Please check the values in the scenario file." + warnings.warn(msg, UserWarning) + + self.ramp_up = 10 + self.steady_state = 250 + self.ramp_down = 10 + self.waiting = 1530 + + return self.ramp_up + self.steady_state + self.ramp_down + self.waiting + + @property + def duration_no_waiting(self) -> float: + return self.total_duration - self.waiting + + +class Scenario: + def __init__(self, pulses: List[Pulse] = None): + """Initializes a Scenario object containing several pulses. + + Args: + pulses: The list of pulses in the scenario. Each pulse is a Pulse object. + """ + self._pulses = pulses if pulses is not None else [] + + @property + def pulses(self) -> List[Pulse]: + return self._pulses + + def to_txt_file(self, filename: str): + df = pd.DataFrame( + [ + { + "pulse_type": pulse.pulse_type, + "nb_pulses": pulse.nb_pulses, + "ramp_up": pulse.ramp_up, + "steady_state": pulse.steady_state, + "ramp_down": pulse.ramp_down, + "waiting": pulse.waiting, + } + for pulse in self.pulses + ] + ) + df.to_csv(filename, index=False) + + @staticmethod + def from_txt_file(filename: str, old_format=False) -> "Scenario": + if old_format: + pulses = [] + with open(filename, "r") as f: + for line in f: + # skip first line + if line.startswith("#"): + continue + + # assume this is the format + pulse_type, nb_pulses, ramp_up, steady_state, ramp_down, waiting = ( + line.split() + ) + pulses.append( + Pulse( + pulse_type=pulse_type, + nb_pulses=int(nb_pulses), + ramp_up=float(ramp_up), + steady_state=float(steady_state), + ramp_down=float(ramp_down), + waiting=float(waiting), + ) + ) + return Scenario(pulses) + df = pd.read_csv(filename) + pulses = [ + Pulse( + pulse_type=row["pulse_type"], + nb_pulses=int(row["nb_pulses"]), + ramp_up=float(row["ramp_up"]), + steady_state=float(row["steady_state"]), + ramp_down=float(row["ramp_down"]), + waiting=float(row["waiting"]), + ) + for _, row in df.iterrows() + ] + return Scenario(pulses) + + def get_row(self, t: float) -> int: + """Returns the row of the scenario file that corresponds to the time t. + + Args: + t: the time in seconds + + Returns: + int: the row index of the scenario file corresponding to the time t + """ + current_time = 0 + for i, pulse in enumerate(self.pulses): + phase_duration = pulse.nb_pulses * pulse.total_duration + if t <= current_time + phase_duration: + return i + else: + current_time += phase_duration + + raise ValueError( + f"Time t {t} is out of bounds of the scenario file. Maximum time is {self.get_maximum_time()}" + ) + + def get_pulse(self, t: float) -> Pulse: + """Returns the pulse at time t. + + Args: + t: the time in seconds + + Returns: + Pulse: the pulse at time t + """ + row_idx = self.get_row(t) + return self.pulses[row_idx] + + def get_pulse_type(self, t: float) -> str: + """Returns the pulse type as a string at time t. + + Args: + t: time in seconds + + Returns: + pulse type (eg. FP, ICWC, RISP, GDC, BAKE) + """ + return self.get_pulse(t).pulse_type + + def get_maximum_time(self) -> float: + """Returns the maximum time of the scenario in seconds. + + Returns: + the maximum time of the scenario in seconds + """ + return sum([pulse.nb_pulses * pulse.total_duration for pulse in self.pulses]) + + def get_time_start_current_pulse(self, t: float): + """Returns the time (s) at which the current pulse started. + + Args: + t: the time in seconds + + Returns: + the time at which the current pulse started + """ + current_pulse = self.get_pulse(t) + pulse_index = self.pulses.index(current_pulse) + return sum( + [ + pulse.nb_pulses * pulse.total_duration + for pulse in self.pulses[:pulse_index] + ] + ) + + # TODO this is the same as get_time_start_current_pulse, remove + def get_time_till_row(self, row: int) -> float: + """Returns the time (s) until the row in the scenario file. + + Args: + row: the row index in the scenario file + + Returns: + the time until the row in the scenario file + """ + return sum( + [pulse.nb_pulses * pulse.total_duration for pulse in self.pulses[:row]] + ) + + # TODO remove + def get_pulse_duration_no_waiting(self, row: int) -> float: + """Returns the total duration (without the waiting time) of a pulse in seconds for a given row in the file. + + Args: + row: the row index in the scenario file + + Returns: + the total duration of the pulse in seconds + """ + return self.pulses[row].duration_no_waiting + + # TODO remove + def get_pulse_duration(self, row: int) -> float: + """Returns the total duration of a pulse in seconds for a given row in the file. + + Args: + row: the row index in the scenario file + + Returns: + the total duration of the pulse in seconds + """ + return self.pulses[row].total_duration diff --git a/test/one_line_scenario.txt b/test/one_line_scenario.txt deleted file mode 100644 index 328054bc..00000000 --- a/test/one_line_scenario.txt +++ /dev/null @@ -1,2 +0,0 @@ -# PULSE TYPE, NUMBER OF PULSES, RAMP UP, RAMP DOWN, STEADY STATE, WAITING -FP 2 455 455 650 1000 \ No newline at end of file diff --git a/test/scenario_test.txt b/test/scenario_test.txt deleted file mode 100644 index 15ff0dc0..00000000 --- a/test/scenario_test.txt +++ /dev/null @@ -1,3 +0,0 @@ -# PULSE TYPE, NUMBER OF PULSES, RAMP UP, RAMP DOWN, STEADY STATE, WAITING -FP 2 455 455 650 1000 -ICWC 2 36 36 180 1000 \ No newline at end of file diff --git a/test/test_pulse.py b/test/test_pulse.py new file mode 100644 index 00000000..e5445bef --- /dev/null +++ b/test/test_pulse.py @@ -0,0 +1,87 @@ +from hisp import Pulse + +import pytest + +def test_pulse_initialization(): + pulse = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + + assert pulse.pulse_type == "FP" + assert pulse.nb_pulses == 2 + assert pulse.ramp_up == 0.1 + assert pulse.steady_state == 0.2 + assert pulse.ramp_down == 0.3 + assert pulse.waiting == 0.4 + + +def test_pulse_total_duration(): + pulse = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + + assert pulse.total_duration == 1.0 + + +def test_pulse_total_duration_with_zeros(): + pulse = Pulse( + pulse_type="RISP", + nb_pulses=1, + ramp_up=0.0, + steady_state=0.0, + ramp_down=0.0, + waiting=0.0, + ) + + assert pulse.total_duration == 1800.0 + +def test_pulse_total_duration_no_waiting_with_zeros(): + pulse = Pulse( + pulse_type="RISP", + nb_pulses=1, + ramp_up=0.0, + steady_state=0.0, + ramp_down=0.0, + waiting=0.0, + ) + + assert pulse.duration_no_waiting == 270.0 + +def test_pulse_risp_with_zeros_raises_warning(): + with pytest.warns(UserWarning): + pulse = Pulse( + pulse_type="RISP", + nb_pulses=1, + ramp_up=0.0, + steady_state=0.0, + ramp_down=0.0, + waiting=0.0, + ) + pulse.total_duration + assert pulse.ramp_up != 0.0 + assert pulse.steady_state != 0.0 + assert pulse.ramp_down != 0.0 + assert pulse.waiting != 0.0 + +def test_pulse_duration_no_waiting(): + pulse = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + + assert pulse.duration_no_waiting == 0.6 + assert pulse.duration_no_waiting == pulse.total_duration - pulse.waiting \ No newline at end of file diff --git a/test/test_scenario.py b/test/test_scenario.py index 2ad8fe06..aefb7d31 100644 --- a/test/test_scenario.py +++ b/test/test_scenario.py @@ -1,17 +1,125 @@ -from hisp import Scenario -import os +from hisp.scenario import Scenario, Pulse import pytest -import numpy as np import matplotlib.pyplot as plt +import numpy as np -current_dir = os.path.dirname(__file__) -scenario_path = os.path.join(current_dir, "scenario_test.txt") -one_line_scenario_path = os.path.join(current_dir, "one_line_scenario.txt") + +def test_scenario(): + scenario = Scenario() + assert len(scenario.pulses) == 0 + + scenario.pulses.append( + Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + ) + + assert len(scenario.pulses) == 1 + assert scenario.pulses[0].pulse_type == "FP" + assert scenario.pulses[0].nb_pulses == 2 + assert scenario.pulses[0].ramp_up == 0.1 + assert scenario.pulses[0].steady_state == 0.2 + assert scenario.pulses[0].ramp_down == 0.3 + assert scenario.pulses[0].waiting == 0.4 + + scenario.to_txt_file("test_scenario.txt") + scenario2 = Scenario.from_txt_file("test_scenario.txt") + + assert len(scenario2.pulses) == 1 + assert scenario2.pulses[0].pulse_type == "FP" + assert scenario2.pulses[0].nb_pulses == 2 + assert scenario2.pulses[0].ramp_up == 0.1 + assert scenario2.pulses[0].steady_state == 0.2 + assert scenario2.pulses[0].ramp_down == 0.3 + assert scenario2.pulses[0].waiting == 0.4 + + +def test_scenario_several_pulses(): + scenario = Scenario() + assert len(scenario.pulses) == 0 + + scenario.pulses.append( + Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + ) + + scenario.pulses.append( + Pulse( + pulse_type="ICWC", + nb_pulses=3, + ramp_up=0.5, + steady_state=0.6, + ramp_down=0.7, + waiting=0.8, + ) + ) + + assert len(scenario.pulses) == 2 + assert scenario.pulses[0].pulse_type == "FP" + assert scenario.pulses[0].nb_pulses == 2 + assert scenario.pulses[0].ramp_up == 0.1 + assert scenario.pulses[0].steady_state == 0.2 + assert scenario.pulses[0].ramp_down == 0.3 + assert scenario.pulses[0].waiting == 0.4 + + assert scenario.pulses[1].pulse_type == "ICWC" + assert scenario.pulses[1].nb_pulses == 3 + assert scenario.pulses[1].ramp_up == 0.5 + assert scenario.pulses[1].steady_state == 0.6 + assert scenario.pulses[1].ramp_down == 0.7 + assert scenario.pulses[1].waiting == 0.8 + + scenario.to_txt_file("test_scenario.txt") + scenario2 = Scenario.from_txt_file("test_scenario.txt") + + assert len(scenario2.pulses) == 2 + assert scenario2.pulses[0].pulse_type == "FP" + assert scenario2.pulses[0].nb_pulses == 2 + assert scenario2.pulses[0].ramp_up == 0.1 + assert scenario2.pulses[0].steady_state == 0.2 + assert scenario2.pulses[0].ramp_down == 0.3 + assert scenario2.pulses[0].waiting == 0.4 + + assert scenario2.pulses[1].pulse_type == "ICWC" + assert scenario2.pulses[1].nb_pulses == 3 + assert scenario2.pulses[1].ramp_up == 0.5 + assert scenario2.pulses[1].steady_state == 0.6 + assert scenario2.pulses[1].ramp_down == 0.7 + assert scenario2.pulses[1].waiting == 0.8 def test_maximum_time(): # BUILD - my_scenario = Scenario(scenario_path) + + pulse1 = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=455, + steady_state=455, + ramp_down=650, + waiting=1000, + ) + pulse2 = Pulse( + pulse_type="ICWC", + nb_pulses=2, + ramp_up=36, + steady_state=36, + ramp_down=180, + waiting=1000, + ) + my_scenario = Scenario([pulse1, pulse2]) + expected_maximum_time = 2 * (455 + 455 + 650 + 1000) + 2 * (36 + 36 + 180 + 1000) # RUN @@ -21,56 +129,51 @@ def test_maximum_time(): assert computed_maximum_time == expected_maximum_time -@pytest.mark.parametrize("t, expected_row", [(0, 0), (6000, 1), (1e5, ValueError)]) -def test_get_pulse_row(t, expected_row): - my_scenario = Scenario(scenario_path) - - if isinstance(expected_row, type) and issubclass(expected_row, Exception): - with pytest.raises(expected_row): - my_scenario.get_row(t=t) +pulse1 = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=455, + steady_state=455, + ramp_down=650, + waiting=1000, +) +pulse2 = Pulse( + pulse_type="ICWC", + nb_pulses=2, + ramp_up=36, + steady_state=36, + ramp_down=180, + waiting=1000, +) + + +@pytest.mark.parametrize( + "t, expected_pulse", [(0, pulse1), (6000, pulse2), (1e5, None)] +) +def test_get_pulse(t, expected_pulse): + + my_scenario = Scenario([pulse1, pulse2]) + + if expected_pulse is None: + with pytest.raises(ValueError): + my_scenario.get_pulse(t=t) else: - pulse_row = my_scenario.get_row(t=t) - assert pulse_row == expected_row - - -def test_reading_a_file(): - my_scenario = Scenario(scenario_path) - - times = np.linspace(0, my_scenario.get_maximum_time(), 1000) - pulse_types = [] - for t in times: - pulse_type = my_scenario.get_pulse_type(t) - pulse_types.append(pulse_type) - - # color the line based on the pulse type - color = { - "FP": "red", - "ICWC": "blue", - "RISP": "green", - "GDC": "orange", - "BAKE": "purple", - } - - colors = [color[pulse_type] for pulse_type in pulse_types] - - for i in range(len(times) - 1): - plt.plot(times[i : i + 2], np.ones_like(times[i : i + 2]), c=colors[i]) - # plt.xscale("log") - # plt.show() + pulse = my_scenario.get_pulse(t=t) + assert pulse == expected_pulse -@pytest.mark.parametrize("t, expected_row", [(100, 0)]) -def test_one_line_scenario(t, expected_row): - my_scenario = Scenario(one_line_scenario_path) +@pytest.mark.parametrize("t, expected_pulse", [(100, pulse1)]) +def test_one_pulse_scenario(t, expected_pulse): + my_scenario = Scenario([expected_pulse]) - pulse_row = my_scenario.get_row(t=t) + pulse = my_scenario.get_pulse(t=t) - assert pulse_row == expected_row + assert pulse == expected_pulse @pytest.mark.parametrize("row, expected_duration", [(0, 2560), (1, 1252)]) def test_get_pulse_duration(row, expected_duration): - my_scenario = Scenario(scenario_path) + my_scenario = Scenario([pulse1, pulse2]) duration = my_scenario.get_pulse_duration(row=row) @@ -79,7 +182,7 @@ def test_get_pulse_duration(row, expected_duration): @pytest.mark.parametrize("row, expected_duration", [(0, 1560), (1, 252)]) def test_get_pulse_duration_no_waiting(row, expected_duration): - my_scenario = Scenario(scenario_path) + my_scenario = Scenario([pulse1, pulse2]) duration = my_scenario.get_pulse_duration_no_waiting(row=row) @@ -88,8 +191,34 @@ def test_get_pulse_duration_no_waiting(row, expected_duration): @pytest.mark.parametrize("row, expected_time", [(0, 0.0), (1, 5120.0)]) def test_get_time_till_row(row, expected_time): - my_scenario = Scenario(scenario_path) + my_scenario = Scenario([pulse1, pulse2]) elapsed_time = my_scenario.get_time_till_row(row=row) assert elapsed_time == expected_time + + +def test_reading_a_file(): + my_scenario = Scenario([pulse1, pulse2]) + + times = np.linspace(0, my_scenario.get_maximum_time(), 1000) + pulse_types = [] + for t in times: + pulse_type = my_scenario.get_pulse_type(t) + pulse_types.append(pulse_type) + + # color the line based on the pulse type + color = { + "FP": "red", + "ICWC": "blue", + "RISP": "green", + "GDC": "orange", + "BAKE": "purple", + } + + colors = [color[pulse_type] for pulse_type in pulse_types] + + for i in range(len(times) - 1): + plt.plot(times[i : i + 2], np.ones_like(times[i : i + 2]), c=colors[i]) + # plt.xscale("log") + # plt.show()