From b11a351845ddb84679c501a460909ddaae381522 Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:12:43 -0400 Subject: [PATCH 01/11] initial class and tests --- environment.yml | 1 + src/hisp/scenario.py | 57 +++++++++++++++++++++ test/test_scenario_python.py | 96 ++++++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+) create mode 100644 src/hisp/scenario.py create mode 100644 test/test_scenario_python.py diff --git a/environment.yml b/environment.yml index 741eaa10..892b0044 100644 --- a/environment.yml +++ b/environment.yml @@ -6,6 +6,7 @@ dependencies: - fenics-dolfinx=0.9.0 - matplotlib - scipy + - pandas - pint # need to install pint with conda to avoid conflicts (HTM) - pip - pip: diff --git a/src/hisp/scenario.py b/src/hisp/scenario.py new file mode 100644 index 00000000..d514ef32 --- /dev/null +++ b/src/hisp/scenario.py @@ -0,0 +1,57 @@ +import pandas as pd +from typing import List + +class Pulse: + pulse_type: str + nb_pulses: int + ramp_up: float + steady_state: float + ramp_down: float + waiting: float + + def __init__(self, pulse_type: str, nb_pulses: int, ramp_up: float, steady_state: float, ramp_down: float, waiting: float): + self.pulse_type = pulse_type + self.nb_pulses = nb_pulses + self.ramp_up = ramp_up + self.steady_state = steady_state + self.ramp_down = ramp_down + self.waiting = waiting + + @property + def total_duration(self) -> float: + return self.ramp_up + self.steady_state + self.ramp_down + self.waiting + +class Scenario: + def __init__(self, pulses = None): + self._pulses = pulses if pulses is not None else [] + + @property + def pulses(self) -> List[Pulse]: + return self._pulses + + def to_txt_file(self, filename: str): + df = pd.DataFrame([{ + "pulse_type": pulse.pulse_type, + "nb_pulses": pulse.nb_pulses, + "ramp_up": pulse.ramp_up, + "steady_state": pulse.steady_state, + "ramp_down": pulse.ramp_down, + "waiting": pulse.waiting + } for pulse in self.pulses]) + df.to_csv(filename, index=False) + + @staticmethod + def from_txt_file(filename: str): + df = pd.read_csv(filename) + pulses = [ + Pulse( + pulse_type=row["pulse_type"], + nb_pulses=int(row["nb_pulses"]), + ramp_up=float(row["ramp_up"]), + steady_state=float(row["steady_state"]), + ramp_down=float(row["ramp_down"]), + waiting=float(row["waiting"]), + ) + for _, row in df.iterrows() + ] + return Scenario(pulses) diff --git a/test/test_scenario_python.py b/test/test_scenario_python.py new file mode 100644 index 00000000..70fd8269 --- /dev/null +++ b/test/test_scenario_python.py @@ -0,0 +1,96 @@ +from hisp.scenario import Scenario, Pulse + + +def test_scenario(): + scenario = Scenario() + assert len(scenario.pulses) == 0 + + scenario.pulses.append( + Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + ) + + assert len(scenario.pulses) == 1 + assert scenario.pulses[0].pulse_type == "FP" + assert scenario.pulses[0].nb_pulses == 2 + assert scenario.pulses[0].ramp_up == 0.1 + assert scenario.pulses[0].steady_state == 0.2 + assert scenario.pulses[0].ramp_down == 0.3 + assert scenario.pulses[0].waiting == 0.4 + + scenario.to_txt_file("test_scenario.txt") + scenario2 = Scenario.from_txt_file("test_scenario.txt") + + assert len(scenario2.pulses) == 1 + assert scenario2.pulses[0].pulse_type == "FP" + assert scenario2.pulses[0].nb_pulses == 2 + assert scenario2.pulses[0].ramp_up == 0.1 + assert scenario2.pulses[0].steady_state == 0.2 + assert scenario2.pulses[0].ramp_down == 0.3 + assert scenario2.pulses[0].waiting == 0.4 + +def test_scenario_several_pulses(): + scenario = Scenario() + assert len(scenario.pulses) == 0 + + scenario.pulses.append( + Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + ) + + scenario.pulses.append( + Pulse( + pulse_type="ICWC", + nb_pulses=3, + ramp_up=0.5, + steady_state=0.6, + ramp_down=0.7, + waiting=0.8, + ) + ) + + assert len(scenario.pulses) == 2 + assert scenario.pulses[0].pulse_type == "FP" + assert scenario.pulses[0].nb_pulses == 2 + assert scenario.pulses[0].ramp_up == 0.1 + assert scenario.pulses[0].steady_state == 0.2 + assert scenario.pulses[0].ramp_down == 0.3 + assert scenario.pulses[0].waiting == 0.4 + + assert scenario.pulses[1].pulse_type == "ICWC" + assert scenario.pulses[1].nb_pulses == 3 + assert scenario.pulses[1].ramp_up == 0.5 + assert scenario.pulses[1].steady_state == 0.6 + assert scenario.pulses[1].ramp_down == 0.7 + assert scenario.pulses[1].waiting == 0.8 + + scenario.to_txt_file("test_scenario.txt") + scenario2 = Scenario.from_txt_file("test_scenario.txt") + + assert len(scenario2.pulses) == 2 + assert scenario2.pulses[0].pulse_type == "FP" + assert scenario2.pulses[0].nb_pulses == 2 + assert scenario2.pulses[0].ramp_up == 0.1 + assert scenario2.pulses[0].steady_state == 0.2 + assert scenario2.pulses[0].ramp_down == 0.3 + assert scenario2.pulses[0].waiting == 0.4 + + assert scenario2.pulses[1].pulse_type == "ICWC" + assert scenario2.pulses[1].nb_pulses == 3 + assert scenario2.pulses[1].ramp_up == 0.5 + assert scenario2.pulses[1].steady_state == 0.6 + assert scenario2.pulses[1].ramp_down == 0.7 + assert scenario2.pulses[1].waiting == 0.8 + \ No newline at end of file From 3575686a50f9ca8cbea9d4817db117035cd9ac63 Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:39:30 -0400 Subject: [PATCH 02/11] ported existing tests --- src/hisp/scenario.py | 105 +++++++++++++++++++++++++++++++++++ test/test_scenario_python.py | 95 ++++++++++++++++++++++++++++++- 2 files changed, 198 insertions(+), 2 deletions(-) diff --git a/src/hisp/scenario.py b/src/hisp/scenario.py index d514ef32..9814688c 100644 --- a/src/hisp/scenario.py +++ b/src/hisp/scenario.py @@ -21,6 +21,10 @@ def __init__(self, pulse_type: str, nb_pulses: int, ramp_up: float, steady_state def total_duration(self) -> float: return self.ramp_up + self.steady_state + self.ramp_down + self.waiting + @property + def duration_no_waiting(self) -> float: + return self.total_duration - self.waiting + class Scenario: def __init__(self, pulses = None): self._pulses = pulses if pulses is not None else [] @@ -55,3 +59,104 @@ def from_txt_file(filename: str): for _, row in df.iterrows() ] return Scenario(pulses) + + def get_row(self, t:float) -> int: + """Returns the row of the scenario file that corresponds to the time t. + + Args: + t: the time in seconds + + Returns: + int: the row index of the scenario file corresponding to the time t + """ + current_time = 0 + for i, pulse in enumerate(self.pulses): + phase_duration = pulse.nb_pulses * pulse.total_duration + if t <= current_time + phase_duration: + return i + else: + current_time += phase_duration + + raise ValueError( + f"Time t {t} is out of bounds of the scenario file. Maximum time is {self.get_maximum_time()}" + ) + + def get_pulse(self, t: float) -> Pulse: + """Returns the pulse at time t. + + Args: + t: the time in seconds + + Returns: + Pulse: the pulse at time t + """ + row_idx = self.get_row(t) + return self.pulses[row_idx] + + def get_pulse_type(self, t: float) -> str: + """Returns the pulse type as a string at time t. + + Args: + t: time in seconds + + Returns: + pulse type (eg. FP, ICWC, RISP, GDC, BAKE) + """ + return self.get_pulse(t).pulse_type + + def get_maximum_time(self) -> float: + """Returns the maximum time of the scenario in seconds. + + Returns: + the maximum time of the scenario in seconds + """ + return sum([pulse.nb_pulses * pulse.total_duration for pulse in self.pulses]) + + def get_time_start_current_pulse(self, t: float): + """Returns the time (s) at which the current pulse started. + + Args: + t: the time in seconds + + Returns: + the time at which the current pulse started + """ + current_pulse = self.get_pulse(t) + pulse_index = self.pulses.index(current_pulse) + return sum([pulse.nb_pulses * pulse.total_duration for pulse in self.pulses[:pulse_index]]) + + # TODO this is the same as get_time_start_current_pulse, remove + def get_time_till_row(self, row:int) -> float: + """Returns the time (s) until the row in the scenario file. + + Args: + row: the row index in the scenario file + + Returns: + the time until the row in the scenario file + """ + return sum([pulse.nb_pulses * pulse.total_duration for pulse in self.pulses[:row]]) + + # TODO remove + def get_pulse_duration_no_waiting(self, row:int) -> float: + """Returns the total duration (without the waiting time) of a pulse in seconds for a given row in the file. + + Args: + row: the row index in the scenario file + + Returns: + the total duration of the pulse in seconds + """ + return self.pulses[row].duration_no_waiting + + # TODO remove + def get_pulse_duration(self, row:int) -> float: + """Returns the total duration of a pulse in seconds for a given row in the file. + + Args: + row: the row index in the scenario file + + Returns: + the total duration of the pulse in seconds + """ + return self.pulses[row].total_duration \ No newline at end of file diff --git a/test/test_scenario_python.py b/test/test_scenario_python.py index 70fd8269..9daa0187 100644 --- a/test/test_scenario_python.py +++ b/test/test_scenario_python.py @@ -1,5 +1,5 @@ from hisp.scenario import Scenario, Pulse - +import pytest def test_scenario(): scenario = Scenario() @@ -93,4 +93,95 @@ def test_scenario_several_pulses(): assert scenario2.pulses[1].steady_state == 0.6 assert scenario2.pulses[1].ramp_down == 0.7 assert scenario2.pulses[1].waiting == 0.8 - \ No newline at end of file + + +def test_maximum_time(): + # BUILD + + pulse1 = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=455, + steady_state=455, + ramp_down=650, + waiting=1000, + ) + pulse2 = Pulse( + pulse_type="ICWC", + nb_pulses=2, + ramp_up=36, + steady_state=36, + ramp_down=180, + waiting=1000, + ) + my_scenario = Scenario([pulse1, pulse2]) + + + expected_maximum_time = 2 * (455 + 455 + 650 + 1000) + 2 * (36 + 36 + 180 + 1000) + + # RUN + computed_maximum_time = my_scenario.get_maximum_time() + + # TEST + assert computed_maximum_time == expected_maximum_time + + +pulse1 = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=455, + steady_state=455, + ramp_down=650, + waiting=1000, + ) +pulse2 = Pulse( + pulse_type="ICWC", + nb_pulses=2, + ramp_up=36, + steady_state=36, + ramp_down=180, + waiting=1000, +) +@pytest.mark.parametrize("t, expected_pulse", [(0, pulse1), (6000, pulse2), (1e5, None)]) +def test_get_pulse(t, expected_pulse): + + my_scenario = Scenario([pulse1, pulse2]) + + if expected_pulse is None: + with pytest.raises(ValueError): + my_scenario.get_pulse(t=t) + else: + pulse = my_scenario.get_pulse(t=t) + assert pulse == expected_pulse + +@pytest.mark.parametrize("t, expected_pulse", [(100, pulse1)]) +def test_one_pulse_scenario(t, expected_pulse): + my_scenario = Scenario([expected_pulse]) + + pulse = my_scenario.get_pulse(t=t) + + assert pulse == expected_pulse + +@pytest.mark.parametrize("row, expected_duration", [(0, 2560), (1, 1252)]) +def test_get_pulse_duration(row, expected_duration): + my_scenario = Scenario([pulse1, pulse2]) + + duration = my_scenario.get_pulse_duration(row=row) + + assert duration == expected_duration + +@pytest.mark.parametrize("row, expected_duration", [(0, 1560), (1, 252)]) +def test_get_pulse_duration_no_waiting(row, expected_duration): + my_scenario = Scenario([pulse1, pulse2]) + + duration = my_scenario.get_pulse_duration_no_waiting(row=row) + + assert duration == expected_duration + +@pytest.mark.parametrize("row, expected_time", [(0, 0.0), (1, 5120.0)]) +def test_get_time_till_row(row, expected_time): + my_scenario = Scenario([pulse1, pulse2]) + + elapsed_time = my_scenario.get_time_till_row(row=row) + + assert elapsed_time == expected_time \ No newline at end of file From d4174243359c57bf5bb5986e11f8cd484bb1875f Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:40:55 -0400 Subject: [PATCH 03/11] one final test --- test/test_scenario_python.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/test/test_scenario_python.py b/test/test_scenario_python.py index 9daa0187..31c00351 100644 --- a/test/test_scenario_python.py +++ b/test/test_scenario_python.py @@ -1,5 +1,7 @@ from hisp.scenario import Scenario, Pulse import pytest +import matplotlib.pyplot as plt +import numpy as np def test_scenario(): scenario = Scenario() @@ -184,4 +186,29 @@ def test_get_time_till_row(row, expected_time): elapsed_time = my_scenario.get_time_till_row(row=row) - assert elapsed_time == expected_time \ No newline at end of file + assert elapsed_time == expected_time + +def test_reading_a_file(): + my_scenario = Scenario([pulse1, pulse2]) + + times = np.linspace(0, my_scenario.get_maximum_time(), 1000) + pulse_types = [] + for t in times: + pulse_type = my_scenario.get_pulse_type(t) + pulse_types.append(pulse_type) + + # color the line based on the pulse type + color = { + "FP": "red", + "ICWC": "blue", + "RISP": "green", + "GDC": "orange", + "BAKE": "purple", + } + + colors = [color[pulse_type] for pulse_type in pulse_types] + + for i in range(len(times) - 1): + plt.plot(times[i : i + 2], np.ones_like(times[i : i + 2]), c=colors[i]) + # plt.xscale("log") + # plt.show() \ No newline at end of file From 9e7d754376269a35176d1db578467f32519c8fa2 Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:43:47 -0400 Subject: [PATCH 04/11] added temporary warning for RISP --- src/hisp/scenario.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/hisp/scenario.py b/src/hisp/scenario.py index 9814688c..b72ffe8e 100644 --- a/src/hisp/scenario.py +++ b/src/hisp/scenario.py @@ -1,5 +1,6 @@ import pandas as pd from typing import List +import warnings class Pulse: pulse_type: str @@ -19,6 +20,10 @@ def __init__(self, pulse_type: str, nb_pulses: int, ramp_up: float, steady_state @property def total_duration(self) -> float: + all_zeros = self.ramp_up == 0 and self.steady_state == 0 and self.ramp_down == 0 and self.waiting == 0 + if self.pulse_type == "RISP" and all_zeros: + warnings.warn("RISP pulse has all zeros for ramp_up, steady_state, ramp_down, waiting. Using hardcoded values. Please check the values in the scenario file.") + return 10 + 250 + 10 + 1530 return self.ramp_up + self.steady_state + self.ramp_down + self.waiting @property From aab1449f9ff8270612d725798fd9f84c9a4e81a0 Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:49:49 -0400 Subject: [PATCH 05/11] Use the new scenario class --- mb_scenario.py | 5 +- src/hisp/__init__.py | 3 +- src/hisp/helpers.py | 2 +- src/hisp/scenario.py | 23 +++- test/test_scenario.py | 227 ++++++++++++++++++++++++++--------- test/test_scenario_python.py | 214 --------------------------------- 6 files changed, 201 insertions(+), 273 deletions(-) delete mode 100644 test/test_scenario_python.py diff --git a/mb_scenario.py b/mb_scenario.py index 57ddeeb2..d44362aa 100644 --- a/mb_scenario.py +++ b/mb_scenario.py @@ -10,7 +10,8 @@ import dolfinx.fem as fem import dolfinx -from hisp.helpers import PulsedSource, Scenario +from hisp.helpers import PulsedSource +from hisp.scenario import Scenario from hisp import CustomProblem # dolfinx.log.set_log_level(dolfinx.log.LogLevel.INFO) @@ -42,7 +43,7 @@ def gaussian_distribution(x, mod=ufl): def make_mb_model(nb_mb, scenario_file): ############# Input Flux, Heat Data ############# - my_scenario = Scenario(scenario_file) + my_scenario = Scenario.from_txt_file(scenario_file, old_format=True) my_model = CustomProblem() diff --git a/src/hisp/__init__.py b/src/hisp/__init__.py index 391c7d92..7666f43d 100644 --- a/src/hisp/__init__.py +++ b/src/hisp/__init__.py @@ -1,3 +1,4 @@ -from .helpers import PulsedSource, Scenario +from .helpers import PulsedSource from .h_transport_class import CustomProblem +from .scenario import Scenario, Pulse diff --git a/src/hisp/helpers.py b/src/hisp/helpers.py index 43b18b2b..083d2ac7 100644 --- a/src/hisp/helpers.py +++ b/src/hisp/helpers.py @@ -36,7 +36,7 @@ def update(self, t: float): self.flux_fenics.value = self.flux(t) -class Scenario: +class ScenarioOld: def __init__(self, filename: str): self.filename = filename data = np.genfromtxt(filename, dtype=str, comments="#") diff --git a/src/hisp/scenario.py b/src/hisp/scenario.py index b72ffe8e..14680ff6 100644 --- a/src/hisp/scenario.py +++ b/src/hisp/scenario.py @@ -50,7 +50,28 @@ def to_txt_file(self, filename: str): df.to_csv(filename, index=False) @staticmethod - def from_txt_file(filename: str): + def from_txt_file(filename: str, old_format=False): + if old_format: + pulses = [] + with open(filename, "r") as f: + for line in f: + # skip first line + if line.startswith("#"): + continue + + # assume this is the format + pulse_type, nb_pulses, ramp_up, steady_state, ramp_down, waiting = line.split() + pulses.append( + Pulse( + pulse_type=pulse_type, + nb_pulses=int(nb_pulses), + ramp_up=float(ramp_up), + steady_state=float(steady_state), + ramp_down=float(ramp_down), + waiting=float(waiting), + ) + ) + return Scenario(pulses) df = pd.read_csv(filename) pulses = [ Pulse( diff --git a/test/test_scenario.py b/test/test_scenario.py index 2ad8fe06..31c00351 100644 --- a/test/test_scenario.py +++ b/test/test_scenario.py @@ -1,17 +1,124 @@ -from hisp import Scenario -import os +from hisp.scenario import Scenario, Pulse import pytest -import numpy as np import matplotlib.pyplot as plt +import numpy as np -current_dir = os.path.dirname(__file__) -scenario_path = os.path.join(current_dir, "scenario_test.txt") -one_line_scenario_path = os.path.join(current_dir, "one_line_scenario.txt") +def test_scenario(): + scenario = Scenario() + assert len(scenario.pulses) == 0 + + scenario.pulses.append( + Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + ) + + assert len(scenario.pulses) == 1 + assert scenario.pulses[0].pulse_type == "FP" + assert scenario.pulses[0].nb_pulses == 2 + assert scenario.pulses[0].ramp_up == 0.1 + assert scenario.pulses[0].steady_state == 0.2 + assert scenario.pulses[0].ramp_down == 0.3 + assert scenario.pulses[0].waiting == 0.4 + + scenario.to_txt_file("test_scenario.txt") + scenario2 = Scenario.from_txt_file("test_scenario.txt") + + assert len(scenario2.pulses) == 1 + assert scenario2.pulses[0].pulse_type == "FP" + assert scenario2.pulses[0].nb_pulses == 2 + assert scenario2.pulses[0].ramp_up == 0.1 + assert scenario2.pulses[0].steady_state == 0.2 + assert scenario2.pulses[0].ramp_down == 0.3 + assert scenario2.pulses[0].waiting == 0.4 + +def test_scenario_several_pulses(): + scenario = Scenario() + assert len(scenario.pulses) == 0 + + scenario.pulses.append( + Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + ) + + scenario.pulses.append( + Pulse( + pulse_type="ICWC", + nb_pulses=3, + ramp_up=0.5, + steady_state=0.6, + ramp_down=0.7, + waiting=0.8, + ) + ) + + assert len(scenario.pulses) == 2 + assert scenario.pulses[0].pulse_type == "FP" + assert scenario.pulses[0].nb_pulses == 2 + assert scenario.pulses[0].ramp_up == 0.1 + assert scenario.pulses[0].steady_state == 0.2 + assert scenario.pulses[0].ramp_down == 0.3 + assert scenario.pulses[0].waiting == 0.4 + + assert scenario.pulses[1].pulse_type == "ICWC" + assert scenario.pulses[1].nb_pulses == 3 + assert scenario.pulses[1].ramp_up == 0.5 + assert scenario.pulses[1].steady_state == 0.6 + assert scenario.pulses[1].ramp_down == 0.7 + assert scenario.pulses[1].waiting == 0.8 + + scenario.to_txt_file("test_scenario.txt") + scenario2 = Scenario.from_txt_file("test_scenario.txt") + + assert len(scenario2.pulses) == 2 + assert scenario2.pulses[0].pulse_type == "FP" + assert scenario2.pulses[0].nb_pulses == 2 + assert scenario2.pulses[0].ramp_up == 0.1 + assert scenario2.pulses[0].steady_state == 0.2 + assert scenario2.pulses[0].ramp_down == 0.3 + assert scenario2.pulses[0].waiting == 0.4 + + assert scenario2.pulses[1].pulse_type == "ICWC" + assert scenario2.pulses[1].nb_pulses == 3 + assert scenario2.pulses[1].ramp_up == 0.5 + assert scenario2.pulses[1].steady_state == 0.6 + assert scenario2.pulses[1].ramp_down == 0.7 + assert scenario2.pulses[1].waiting == 0.8 def test_maximum_time(): # BUILD - my_scenario = Scenario(scenario_path) + + pulse1 = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=455, + steady_state=455, + ramp_down=650, + waiting=1000, + ) + pulse2 = Pulse( + pulse_type="ICWC", + nb_pulses=2, + ramp_up=36, + steady_state=36, + ramp_down=180, + waiting=1000, + ) + my_scenario = Scenario([pulse1, pulse2]) + + expected_maximum_time = 2 * (455 + 455 + 650 + 1000) + 2 * (36 + 36 + 180 + 1000) # RUN @@ -21,75 +128,87 @@ def test_maximum_time(): assert computed_maximum_time == expected_maximum_time -@pytest.mark.parametrize("t, expected_row", [(0, 0), (6000, 1), (1e5, ValueError)]) -def test_get_pulse_row(t, expected_row): - my_scenario = Scenario(scenario_path) - - if isinstance(expected_row, type) and issubclass(expected_row, Exception): - with pytest.raises(expected_row): - my_scenario.get_row(t=t) +pulse1 = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=455, + steady_state=455, + ramp_down=650, + waiting=1000, + ) +pulse2 = Pulse( + pulse_type="ICWC", + nb_pulses=2, + ramp_up=36, + steady_state=36, + ramp_down=180, + waiting=1000, +) +@pytest.mark.parametrize("t, expected_pulse", [(0, pulse1), (6000, pulse2), (1e5, None)]) +def test_get_pulse(t, expected_pulse): + + my_scenario = Scenario([pulse1, pulse2]) + + if expected_pulse is None: + with pytest.raises(ValueError): + my_scenario.get_pulse(t=t) else: - pulse_row = my_scenario.get_row(t=t) - assert pulse_row == expected_row - - -def test_reading_a_file(): - my_scenario = Scenario(scenario_path) - - times = np.linspace(0, my_scenario.get_maximum_time(), 1000) - pulse_types = [] - for t in times: - pulse_type = my_scenario.get_pulse_type(t) - pulse_types.append(pulse_type) + pulse = my_scenario.get_pulse(t=t) + assert pulse == expected_pulse - # color the line based on the pulse type - color = { - "FP": "red", - "ICWC": "blue", - "RISP": "green", - "GDC": "orange", - "BAKE": "purple", - } - - colors = [color[pulse_type] for pulse_type in pulse_types] - - for i in range(len(times) - 1): - plt.plot(times[i : i + 2], np.ones_like(times[i : i + 2]), c=colors[i]) - # plt.xscale("log") - # plt.show() +@pytest.mark.parametrize("t, expected_pulse", [(100, pulse1)]) +def test_one_pulse_scenario(t, expected_pulse): + my_scenario = Scenario([expected_pulse]) + pulse = my_scenario.get_pulse(t=t) -@pytest.mark.parametrize("t, expected_row", [(100, 0)]) -def test_one_line_scenario(t, expected_row): - my_scenario = Scenario(one_line_scenario_path) - - pulse_row = my_scenario.get_row(t=t) - - assert pulse_row == expected_row - + assert pulse == expected_pulse @pytest.mark.parametrize("row, expected_duration", [(0, 2560), (1, 1252)]) def test_get_pulse_duration(row, expected_duration): - my_scenario = Scenario(scenario_path) + my_scenario = Scenario([pulse1, pulse2]) duration = my_scenario.get_pulse_duration(row=row) assert duration == expected_duration - @pytest.mark.parametrize("row, expected_duration", [(0, 1560), (1, 252)]) def test_get_pulse_duration_no_waiting(row, expected_duration): - my_scenario = Scenario(scenario_path) + my_scenario = Scenario([pulse1, pulse2]) duration = my_scenario.get_pulse_duration_no_waiting(row=row) assert duration == expected_duration - @pytest.mark.parametrize("row, expected_time", [(0, 0.0), (1, 5120.0)]) def test_get_time_till_row(row, expected_time): - my_scenario = Scenario(scenario_path) + my_scenario = Scenario([pulse1, pulse2]) elapsed_time = my_scenario.get_time_till_row(row=row) assert elapsed_time == expected_time + +def test_reading_a_file(): + my_scenario = Scenario([pulse1, pulse2]) + + times = np.linspace(0, my_scenario.get_maximum_time(), 1000) + pulse_types = [] + for t in times: + pulse_type = my_scenario.get_pulse_type(t) + pulse_types.append(pulse_type) + + # color the line based on the pulse type + color = { + "FP": "red", + "ICWC": "blue", + "RISP": "green", + "GDC": "orange", + "BAKE": "purple", + } + + colors = [color[pulse_type] for pulse_type in pulse_types] + + for i in range(len(times) - 1): + plt.plot(times[i : i + 2], np.ones_like(times[i : i + 2]), c=colors[i]) + # plt.xscale("log") + # plt.show() \ No newline at end of file diff --git a/test/test_scenario_python.py b/test/test_scenario_python.py deleted file mode 100644 index 31c00351..00000000 --- a/test/test_scenario_python.py +++ /dev/null @@ -1,214 +0,0 @@ -from hisp.scenario import Scenario, Pulse -import pytest -import matplotlib.pyplot as plt -import numpy as np - -def test_scenario(): - scenario = Scenario() - assert len(scenario.pulses) == 0 - - scenario.pulses.append( - Pulse( - pulse_type="FP", - nb_pulses=2, - ramp_up=0.1, - steady_state=0.2, - ramp_down=0.3, - waiting=0.4, - ) - ) - - assert len(scenario.pulses) == 1 - assert scenario.pulses[0].pulse_type == "FP" - assert scenario.pulses[0].nb_pulses == 2 - assert scenario.pulses[0].ramp_up == 0.1 - assert scenario.pulses[0].steady_state == 0.2 - assert scenario.pulses[0].ramp_down == 0.3 - assert scenario.pulses[0].waiting == 0.4 - - scenario.to_txt_file("test_scenario.txt") - scenario2 = Scenario.from_txt_file("test_scenario.txt") - - assert len(scenario2.pulses) == 1 - assert scenario2.pulses[0].pulse_type == "FP" - assert scenario2.pulses[0].nb_pulses == 2 - assert scenario2.pulses[0].ramp_up == 0.1 - assert scenario2.pulses[0].steady_state == 0.2 - assert scenario2.pulses[0].ramp_down == 0.3 - assert scenario2.pulses[0].waiting == 0.4 - -def test_scenario_several_pulses(): - scenario = Scenario() - assert len(scenario.pulses) == 0 - - scenario.pulses.append( - Pulse( - pulse_type="FP", - nb_pulses=2, - ramp_up=0.1, - steady_state=0.2, - ramp_down=0.3, - waiting=0.4, - ) - ) - - scenario.pulses.append( - Pulse( - pulse_type="ICWC", - nb_pulses=3, - ramp_up=0.5, - steady_state=0.6, - ramp_down=0.7, - waiting=0.8, - ) - ) - - assert len(scenario.pulses) == 2 - assert scenario.pulses[0].pulse_type == "FP" - assert scenario.pulses[0].nb_pulses == 2 - assert scenario.pulses[0].ramp_up == 0.1 - assert scenario.pulses[0].steady_state == 0.2 - assert scenario.pulses[0].ramp_down == 0.3 - assert scenario.pulses[0].waiting == 0.4 - - assert scenario.pulses[1].pulse_type == "ICWC" - assert scenario.pulses[1].nb_pulses == 3 - assert scenario.pulses[1].ramp_up == 0.5 - assert scenario.pulses[1].steady_state == 0.6 - assert scenario.pulses[1].ramp_down == 0.7 - assert scenario.pulses[1].waiting == 0.8 - - scenario.to_txt_file("test_scenario.txt") - scenario2 = Scenario.from_txt_file("test_scenario.txt") - - assert len(scenario2.pulses) == 2 - assert scenario2.pulses[0].pulse_type == "FP" - assert scenario2.pulses[0].nb_pulses == 2 - assert scenario2.pulses[0].ramp_up == 0.1 - assert scenario2.pulses[0].steady_state == 0.2 - assert scenario2.pulses[0].ramp_down == 0.3 - assert scenario2.pulses[0].waiting == 0.4 - - assert scenario2.pulses[1].pulse_type == "ICWC" - assert scenario2.pulses[1].nb_pulses == 3 - assert scenario2.pulses[1].ramp_up == 0.5 - assert scenario2.pulses[1].steady_state == 0.6 - assert scenario2.pulses[1].ramp_down == 0.7 - assert scenario2.pulses[1].waiting == 0.8 - - -def test_maximum_time(): - # BUILD - - pulse1 = Pulse( - pulse_type="FP", - nb_pulses=2, - ramp_up=455, - steady_state=455, - ramp_down=650, - waiting=1000, - ) - pulse2 = Pulse( - pulse_type="ICWC", - nb_pulses=2, - ramp_up=36, - steady_state=36, - ramp_down=180, - waiting=1000, - ) - my_scenario = Scenario([pulse1, pulse2]) - - - expected_maximum_time = 2 * (455 + 455 + 650 + 1000) + 2 * (36 + 36 + 180 + 1000) - - # RUN - computed_maximum_time = my_scenario.get_maximum_time() - - # TEST - assert computed_maximum_time == expected_maximum_time - - -pulse1 = Pulse( - pulse_type="FP", - nb_pulses=2, - ramp_up=455, - steady_state=455, - ramp_down=650, - waiting=1000, - ) -pulse2 = Pulse( - pulse_type="ICWC", - nb_pulses=2, - ramp_up=36, - steady_state=36, - ramp_down=180, - waiting=1000, -) -@pytest.mark.parametrize("t, expected_pulse", [(0, pulse1), (6000, pulse2), (1e5, None)]) -def test_get_pulse(t, expected_pulse): - - my_scenario = Scenario([pulse1, pulse2]) - - if expected_pulse is None: - with pytest.raises(ValueError): - my_scenario.get_pulse(t=t) - else: - pulse = my_scenario.get_pulse(t=t) - assert pulse == expected_pulse - -@pytest.mark.parametrize("t, expected_pulse", [(100, pulse1)]) -def test_one_pulse_scenario(t, expected_pulse): - my_scenario = Scenario([expected_pulse]) - - pulse = my_scenario.get_pulse(t=t) - - assert pulse == expected_pulse - -@pytest.mark.parametrize("row, expected_duration", [(0, 2560), (1, 1252)]) -def test_get_pulse_duration(row, expected_duration): - my_scenario = Scenario([pulse1, pulse2]) - - duration = my_scenario.get_pulse_duration(row=row) - - assert duration == expected_duration - -@pytest.mark.parametrize("row, expected_duration", [(0, 1560), (1, 252)]) -def test_get_pulse_duration_no_waiting(row, expected_duration): - my_scenario = Scenario([pulse1, pulse2]) - - duration = my_scenario.get_pulse_duration_no_waiting(row=row) - - assert duration == expected_duration - -@pytest.mark.parametrize("row, expected_time", [(0, 0.0), (1, 5120.0)]) -def test_get_time_till_row(row, expected_time): - my_scenario = Scenario([pulse1, pulse2]) - - elapsed_time = my_scenario.get_time_till_row(row=row) - - assert elapsed_time == expected_time - -def test_reading_a_file(): - my_scenario = Scenario([pulse1, pulse2]) - - times = np.linspace(0, my_scenario.get_maximum_time(), 1000) - pulse_types = [] - for t in times: - pulse_type = my_scenario.get_pulse_type(t) - pulse_types.append(pulse_type) - - # color the line based on the pulse type - color = { - "FP": "red", - "ICWC": "blue", - "RISP": "green", - "GDC": "orange", - "BAKE": "purple", - } - - colors = [color[pulse_type] for pulse_type in pulse_types] - - for i in range(len(times) - 1): - plt.plot(times[i : i + 2], np.ones_like(times[i : i + 2]), c=colors[i]) - # plt.xscale("log") - # plt.show() \ No newline at end of file From c0050e8523381d8b8ef47852f41c3534baad81ef Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:50:30 -0400 Subject: [PATCH 06/11] removed class --- src/hisp/helpers.py | 118 -------------------------------------------- 1 file changed, 118 deletions(-) diff --git a/src/hisp/helpers.py b/src/hisp/helpers.py index 083d2ac7..dfe0f7d6 100644 --- a/src/hisp/helpers.py +++ b/src/hisp/helpers.py @@ -1,7 +1,6 @@ import festim as F from dolfinx.fem.function import Constant import ufl -import numpy as np class PulsedSource(F.ParticleSource): @@ -34,120 +33,3 @@ def create_value_fenics(self, mesh, temperature, t: Constant): def update(self, t: float): self.flux_fenics.value = self.flux(t) - - -class ScenarioOld: - def __init__(self, filename: str): - self.filename = filename - data = np.genfromtxt(filename, dtype=str, comments="#") - if isinstance(data[0], str): - self.data = [data] - else: - self.data = data - - def get_row(self, t: float): - """Returns the row of the scenario file that corresponds to the time t. - - Args: - t (float): the time in seconds - - Returns: - int: the row index of the scenario file corresponding to the time t - """ - current_time = 0 - for i, row in enumerate(self.data): - nb_pulses = int(row[1]) - phase_duration = nb_pulses * self.get_pulse_duration(i) - if t <= current_time + phase_duration: - return i - else: - current_time += phase_duration - - raise ValueError( - f"Time t {t} is out of bounds of the scenario file. Maximum time is {self.get_maximum_time()}" - ) - - def get_pulse_type(self, t: float) -> str: - """Returns the pulse type as a string at time t. - - Args: - t (float): time in seconds - - Returns: - str: pulse type (eg. FP, ICWC, RISP, GDC, BAKE) - """ - row_idx = self.get_row(t) - return self.data[row_idx][0] - - def get_pulse_duration(self, row: int) -> float: - """Returns the total duration of a pulse in seconds for a given row in the file. - - Args: - row (int): the row index in the scenario file - - Returns: - float: the total duration of the pulse in seconds - """ - row_data = self.data[row] - pulse_type = row_data[0] - if pulse_type == "RISP": # hard coded because it's zero in the files - ramp_up = 10 - steady_state = 250 - ramp_down = 10 - waiting = 1530 - total_duration = ramp_up + steady_state + ramp_down + waiting - return total_duration - - ramp_up = float(row_data[2]) - steady_state = float(row_data[4]) - ramp_down = float(row_data[3]) - waiting = float(row_data[5]) - - total_duration = ramp_up + steady_state + ramp_down + waiting - return total_duration - - def get_pulse_duration_no_waiting(self, row: int) -> float: - """Returns the total duration (without the waiting time) of a pulse in seconds for a given row in the file. - - Args: - row (int): the row index in the scenario file - - Returns: - float: the total duration of the pulse in seconds - """ - row_data = self.data[row] - pulse_type = row_data[0] - if pulse_type == "RISP": # hard coded because it's zero in the files - waiting_time = 1530 - else: - waiting_time = float(row_data[5]) - - duration = self.get_pulse_duration(row) - waiting_time - return duration - - def get_time_till_row(self, row:int) -> float: - """Returns the time that has elapsed in scenario up until start of current row. - - Args: - row (int): the row index in the scenario file - - Returns: - float: the time that has elapsed in scenario until and not including input row. - """ - time_elapsed = 0 - for prev_row_id in range(0,row): - nb_pulses = int(self.data[prev_row_id][1]) - time_elapsed += nb_pulses * self.get_pulse_duration(prev_row_id) - return time_elapsed - - def get_maximum_time(self) -> float: - """Returns the maximum time in seconds for the scenario file. - - Returns: - float: the maximum time in seconds - """ - max_time = 0 - for i, row in enumerate(self.data): - nb_pulses = int(row[1]) - max_time += nb_pulses * self.get_pulse_duration(i) - return max_time From 7029ed67126fb176138c9ec36cf0f3a25a188e6e Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:51:40 -0400 Subject: [PATCH 07/11] black formatting --- src/hisp/scenario.py | 77 +++++++++++++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 23 deletions(-) diff --git a/src/hisp/scenario.py b/src/hisp/scenario.py index 14680ff6..69e62334 100644 --- a/src/hisp/scenario.py +++ b/src/hisp/scenario.py @@ -2,6 +2,7 @@ from typing import List import warnings + class Pulse: pulse_type: str nb_pulses: int @@ -10,19 +11,34 @@ class Pulse: ramp_down: float waiting: float - def __init__(self, pulse_type: str, nb_pulses: int, ramp_up: float, steady_state: float, ramp_down: float, waiting: float): + def __init__( + self, + pulse_type: str, + nb_pulses: int, + ramp_up: float, + steady_state: float, + ramp_down: float, + waiting: float, + ): self.pulse_type = pulse_type self.nb_pulses = nb_pulses self.ramp_up = ramp_up self.steady_state = steady_state self.ramp_down = ramp_down self.waiting = waiting - + @property def total_duration(self) -> float: - all_zeros = self.ramp_up == 0 and self.steady_state == 0 and self.ramp_down == 0 and self.waiting == 0 + all_zeros = ( + self.ramp_up == 0 + and self.steady_state == 0 + and self.ramp_down == 0 + and self.waiting == 0 + ) if self.pulse_type == "RISP" and all_zeros: - warnings.warn("RISP pulse has all zeros for ramp_up, steady_state, ramp_down, waiting. Using hardcoded values. Please check the values in the scenario file.") + warnings.warn( + "RISP pulse has all zeros for ramp_up, steady_state, ramp_down, waiting. Using hardcoded values. Please check the values in the scenario file." + ) return 10 + 250 + 10 + 1530 return self.ramp_up + self.steady_state + self.ramp_down + self.waiting @@ -30,23 +46,29 @@ def total_duration(self) -> float: def duration_no_waiting(self) -> float: return self.total_duration - self.waiting + class Scenario: - def __init__(self, pulses = None): + def __init__(self, pulses=None): self._pulses = pulses if pulses is not None else [] - + @property def pulses(self) -> List[Pulse]: return self._pulses def to_txt_file(self, filename: str): - df = pd.DataFrame([{ - "pulse_type": pulse.pulse_type, - "nb_pulses": pulse.nb_pulses, - "ramp_up": pulse.ramp_up, - "steady_state": pulse.steady_state, - "ramp_down": pulse.ramp_down, - "waiting": pulse.waiting - } for pulse in self.pulses]) + df = pd.DataFrame( + [ + { + "pulse_type": pulse.pulse_type, + "nb_pulses": pulse.nb_pulses, + "ramp_up": pulse.ramp_up, + "steady_state": pulse.steady_state, + "ramp_down": pulse.ramp_down, + "waiting": pulse.waiting, + } + for pulse in self.pulses + ] + ) df.to_csv(filename, index=False) @staticmethod @@ -60,7 +82,9 @@ def from_txt_file(filename: str, old_format=False): continue # assume this is the format - pulse_type, nb_pulses, ramp_up, steady_state, ramp_down, waiting = line.split() + pulse_type, nb_pulses, ramp_up, steady_state, ramp_down, waiting = ( + line.split() + ) pulses.append( Pulse( pulse_type=pulse_type, @@ -86,7 +110,7 @@ def from_txt_file(filename: str, old_format=False): ] return Scenario(pulses) - def get_row(self, t:float) -> int: + def get_row(self, t: float) -> int: """Returns the row of the scenario file that corresponds to the time t. Args: @@ -137,7 +161,7 @@ def get_maximum_time(self) -> float: the maximum time of the scenario in seconds """ return sum([pulse.nb_pulses * pulse.total_duration for pulse in self.pulses]) - + def get_time_start_current_pulse(self, t: float): """Returns the time (s) at which the current pulse started. @@ -149,10 +173,15 @@ def get_time_start_current_pulse(self, t: float): """ current_pulse = self.get_pulse(t) pulse_index = self.pulses.index(current_pulse) - return sum([pulse.nb_pulses * pulse.total_duration for pulse in self.pulses[:pulse_index]]) + return sum( + [ + pulse.nb_pulses * pulse.total_duration + for pulse in self.pulses[:pulse_index] + ] + ) # TODO this is the same as get_time_start_current_pulse, remove - def get_time_till_row(self, row:int) -> float: + def get_time_till_row(self, row: int) -> float: """Returns the time (s) until the row in the scenario file. Args: @@ -161,10 +190,12 @@ def get_time_till_row(self, row:int) -> float: Returns: the time until the row in the scenario file """ - return sum([pulse.nb_pulses * pulse.total_duration for pulse in self.pulses[:row]]) + return sum( + [pulse.nb_pulses * pulse.total_duration for pulse in self.pulses[:row]] + ) # TODO remove - def get_pulse_duration_no_waiting(self, row:int) -> float: + def get_pulse_duration_no_waiting(self, row: int) -> float: """Returns the total duration (without the waiting time) of a pulse in seconds for a given row in the file. Args: @@ -176,7 +207,7 @@ def get_pulse_duration_no_waiting(self, row:int) -> float: return self.pulses[row].duration_no_waiting # TODO remove - def get_pulse_duration(self, row:int) -> float: + def get_pulse_duration(self, row: int) -> float: """Returns the total duration of a pulse in seconds for a given row in the file. Args: @@ -185,4 +216,4 @@ def get_pulse_duration(self, row:int) -> float: Returns: the total duration of the pulse in seconds """ - return self.pulses[row].total_duration \ No newline at end of file + return self.pulses[row].total_duration From b73566cbe0a22bcd20b4ceedbdd4052c48c51e5a Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:54:39 -0400 Subject: [PATCH 08/11] docs --- src/hisp/scenario.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/hisp/scenario.py b/src/hisp/scenario.py index 69e62334..4f9a2810 100644 --- a/src/hisp/scenario.py +++ b/src/hisp/scenario.py @@ -48,7 +48,12 @@ def duration_no_waiting(self) -> float: class Scenario: - def __init__(self, pulses=None): + def __init__(self, pulses: List[Pulse] = None): + """Initializes a Scenario object containing several pulses. + + Args: + pulses: The list of pulses in the scenario. Each pulse is a Pulse object. + """ self._pulses = pulses if pulses is not None else [] @property @@ -72,7 +77,7 @@ def to_txt_file(self, filename: str): df.to_csv(filename, index=False) @staticmethod - def from_txt_file(filename: str, old_format=False): + def from_txt_file(filename: str, old_format=False) -> "Scenario": if old_format: pulses = [] with open(filename, "r") as f: From 75f61a687d8fff05c1b4def2c743d89b49b77336 Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 13:56:27 -0400 Subject: [PATCH 09/11] black + removed txt files --- test/one_line_scenario.txt | 2 -- test/scenario_test.txt | 3 --- test/test_scenario.py | 32 +++++++++++++++++++++----------- 3 files changed, 21 insertions(+), 16 deletions(-) delete mode 100644 test/one_line_scenario.txt delete mode 100644 test/scenario_test.txt diff --git a/test/one_line_scenario.txt b/test/one_line_scenario.txt deleted file mode 100644 index 328054bc..00000000 --- a/test/one_line_scenario.txt +++ /dev/null @@ -1,2 +0,0 @@ -# PULSE TYPE, NUMBER OF PULSES, RAMP UP, RAMP DOWN, STEADY STATE, WAITING -FP 2 455 455 650 1000 \ No newline at end of file diff --git a/test/scenario_test.txt b/test/scenario_test.txt deleted file mode 100644 index 15ff0dc0..00000000 --- a/test/scenario_test.txt +++ /dev/null @@ -1,3 +0,0 @@ -# PULSE TYPE, NUMBER OF PULSES, RAMP UP, RAMP DOWN, STEADY STATE, WAITING -FP 2 455 455 650 1000 -ICWC 2 36 36 180 1000 \ No newline at end of file diff --git a/test/test_scenario.py b/test/test_scenario.py index 31c00351..aefb7d31 100644 --- a/test/test_scenario.py +++ b/test/test_scenario.py @@ -3,6 +3,7 @@ import matplotlib.pyplot as plt import numpy as np + def test_scenario(): scenario = Scenario() assert len(scenario.pulses) == 0 @@ -37,6 +38,7 @@ def test_scenario(): assert scenario2.pulses[0].ramp_down == 0.3 assert scenario2.pulses[0].waiting == 0.4 + def test_scenario_several_pulses(): scenario = Scenario() assert len(scenario.pulses) == 0 @@ -118,7 +120,6 @@ def test_maximum_time(): ) my_scenario = Scenario([pulse1, pulse2]) - expected_maximum_time = 2 * (455 + 455 + 650 + 1000) + 2 * (36 + 36 + 180 + 1000) # RUN @@ -129,13 +130,13 @@ def test_maximum_time(): pulse1 = Pulse( - pulse_type="FP", - nb_pulses=2, - ramp_up=455, - steady_state=455, - ramp_down=650, - waiting=1000, - ) + pulse_type="FP", + nb_pulses=2, + ramp_up=455, + steady_state=455, + ramp_down=650, + waiting=1000, +) pulse2 = Pulse( pulse_type="ICWC", nb_pulses=2, @@ -144,18 +145,23 @@ def test_maximum_time(): ramp_down=180, waiting=1000, ) -@pytest.mark.parametrize("t, expected_pulse", [(0, pulse1), (6000, pulse2), (1e5, None)]) + + +@pytest.mark.parametrize( + "t, expected_pulse", [(0, pulse1), (6000, pulse2), (1e5, None)] +) def test_get_pulse(t, expected_pulse): my_scenario = Scenario([pulse1, pulse2]) - if expected_pulse is None: + if expected_pulse is None: with pytest.raises(ValueError): my_scenario.get_pulse(t=t) else: pulse = my_scenario.get_pulse(t=t) assert pulse == expected_pulse + @pytest.mark.parametrize("t, expected_pulse", [(100, pulse1)]) def test_one_pulse_scenario(t, expected_pulse): my_scenario = Scenario([expected_pulse]) @@ -164,6 +170,7 @@ def test_one_pulse_scenario(t, expected_pulse): assert pulse == expected_pulse + @pytest.mark.parametrize("row, expected_duration", [(0, 2560), (1, 1252)]) def test_get_pulse_duration(row, expected_duration): my_scenario = Scenario([pulse1, pulse2]) @@ -172,6 +179,7 @@ def test_get_pulse_duration(row, expected_duration): assert duration == expected_duration + @pytest.mark.parametrize("row, expected_duration", [(0, 1560), (1, 252)]) def test_get_pulse_duration_no_waiting(row, expected_duration): my_scenario = Scenario([pulse1, pulse2]) @@ -180,6 +188,7 @@ def test_get_pulse_duration_no_waiting(row, expected_duration): assert duration == expected_duration + @pytest.mark.parametrize("row, expected_time", [(0, 0.0), (1, 5120.0)]) def test_get_time_till_row(row, expected_time): my_scenario = Scenario([pulse1, pulse2]) @@ -188,6 +197,7 @@ def test_get_time_till_row(row, expected_time): assert elapsed_time == expected_time + def test_reading_a_file(): my_scenario = Scenario([pulse1, pulse2]) @@ -211,4 +221,4 @@ def test_reading_a_file(): for i in range(len(times) - 1): plt.plot(times[i : i + 2], np.ones_like(times[i : i + 2]), c=colors[i]) # plt.xscale("log") - # plt.show() \ No newline at end of file + # plt.show() From 0836e1e65baff713cd21c6c0adcd022304e94c99 Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 14:03:29 -0400 Subject: [PATCH 10/11] added tests for Pulse + better handling of RISP --- src/hisp/scenario.py | 13 +++++--- test/test_pulse.py | 75 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 test/test_pulse.py diff --git a/src/hisp/scenario.py b/src/hisp/scenario.py index 4f9a2810..aa330122 100644 --- a/src/hisp/scenario.py +++ b/src/hisp/scenario.py @@ -36,10 +36,15 @@ def total_duration(self) -> float: and self.waiting == 0 ) if self.pulse_type == "RISP" and all_zeros: - warnings.warn( - "RISP pulse has all zeros for ramp_up, steady_state, ramp_down, waiting. Using hardcoded values. Please check the values in the scenario file." - ) - return 10 + 250 + 10 + 1530 + msg = "RISP pulse has all zeros for ramp_up, steady_state, ramp_down, waiting. " + msg += "Setting hardcoded values. Please check the values in the scenario file." + warnings.warn(msg, UserWarning) + + self.ramp_up = 10 + self.steady_state = 250 + self.ramp_down = 10 + self.waiting = 1530 + return self.ramp_up + self.steady_state + self.ramp_down + self.waiting @property diff --git a/test/test_pulse.py b/test/test_pulse.py new file mode 100644 index 00000000..631fdf67 --- /dev/null +++ b/test/test_pulse.py @@ -0,0 +1,75 @@ +from hisp import Pulse + +import pytest + +def test_pulse_initialization(): + pulse = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + + assert pulse.pulse_type == "FP" + assert pulse.nb_pulses == 2 + assert pulse.ramp_up == 0.1 + assert pulse.steady_state == 0.2 + assert pulse.ramp_down == 0.3 + assert pulse.waiting == 0.4 + + +def test_pulse_total_duration(): + pulse = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + + assert pulse.total_duration == 1.0 + + +def test_pulse_total_duration_with_zeros(): + pulse = Pulse( + pulse_type="RISP", + nb_pulses=1, + ramp_up=0.0, + steady_state=0.0, + ramp_down=0.0, + waiting=0.0, + ) + + assert pulse.total_duration == 1800.0 + +def test_pulse_risp_with_zeros_raises_warning(): + with pytest.warns(UserWarning): + pulse = Pulse( + pulse_type="RISP", + nb_pulses=1, + ramp_up=0.0, + steady_state=0.0, + ramp_down=0.0, + waiting=0.0, + ) + pulse.total_duration + assert pulse.ramp_up != 0.0 + assert pulse.steady_state != 0.0 + assert pulse.ramp_down != 0.0 + assert pulse.waiting != 0.0 + +def test_pulse_duration_no_waiting(): + pulse = Pulse( + pulse_type="FP", + nb_pulses=2, + ramp_up=0.1, + steady_state=0.2, + ramp_down=0.3, + waiting=0.4, + ) + + assert pulse.duration_no_waiting == 0.6 + assert pulse.duration_no_waiting == pulse.total_duration - pulse.waiting \ No newline at end of file From c2e3d6fd06cde5ce261562f78d6e0a5aeb0f9cc9 Mon Sep 17 00:00:00 2001 From: RemDelaporteMathurin Date: Sat, 2 Nov 2024 14:22:22 -0400 Subject: [PATCH 11/11] added one test for no waiting --- test/test_pulse.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/test_pulse.py b/test/test_pulse.py index 631fdf67..e5445bef 100644 --- a/test/test_pulse.py +++ b/test/test_pulse.py @@ -45,6 +45,18 @@ def test_pulse_total_duration_with_zeros(): assert pulse.total_duration == 1800.0 +def test_pulse_total_duration_no_waiting_with_zeros(): + pulse = Pulse( + pulse_type="RISP", + nb_pulses=1, + ramp_up=0.0, + steady_state=0.0, + ramp_down=0.0, + waiting=0.0, + ) + + assert pulse.duration_no_waiting == 270.0 + def test_pulse_risp_with_zeros_raises_warning(): with pytest.warns(UserWarning): pulse = Pulse(