Skip to content

Commit

Permalink
minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
obouchaara committed Jun 11, 2024
1 parent d01b399 commit ab9333a
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions src/mechpy/core/symbolic/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,13 @@ def create_linear(cls, coord_system=None, data=None, field_params=None):
)
return cls(coord_system, scalar_field, field_params)

def plot(self, x_range=(-10, 10), y_range=(-10, 10), z_range=(-10, 10), num_points=20):
def plot(
self, x_range=(-10, 10), y_range=(-10, 10), z_range=(-10, 10), num_points=20
):
if not isinstance(self.coord_system, SymbolicCartesianCoordSystem):
raise NotImplementedError("Plotting is only implemented for Cartesian coordinates")
raise NotImplementedError(
"Plotting is only implemented for Cartesian coordinates"
)

# Create a meshgrid for the plot
x = np.linspace(x_range[0], x_range[1], num_points)
Expand All @@ -204,7 +208,9 @@ def plot(self, x_range=(-10, 10), y_range=(-10, 10), z_range=(-10, 10), num_poin

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
plt.subplots_adjust(left=0.1, bottom=0.25 + 0.05 * len(self.field_params)) # Adjust space for sliders
plt.subplots_adjust(
left=0.1, bottom=0.25 + 0.05 * len(self.field_params)
) # Adjust space for sliders

sliders = {}
slider_axes = []
Expand All @@ -214,8 +220,17 @@ def plot(self, x_range=(-10, 10), y_range=(-10, 10), z_range=(-10, 10), num_poin
if values is None:
raise ValueError(f"the param {param} values in not defined")
if isinstance(values, set):
ax_slider = plt.axes([0.1, 0.1 + 0.05 * i, 0.65, 0.03], facecolor="lightgoldenrodyellow")
slider = Slider(ax_slider, str(param), min(values), max(values), valinit=min(values), valstep=list(values))
ax_slider = plt.axes(
[0.1, 0.1 + 0.05 * i, 0.65, 0.03], facecolor="lightgoldenrodyellow"
)
slider = Slider(
ax_slider,
str(param),
min(values),
max(values),
valinit=min(values),
valstep=list(values),
)
sliders[param] = slider
slider_axes.append(ax_slider)
else:
Expand Down Expand Up @@ -270,18 +285,19 @@ def create(cls, coord_system=None, data=None, field_params=None):
f1 = f1(*basis)
f2 = f2(*basis)
f3 = f3(*basis)
data = sp.ImmutableDenseNDimArray([f1, f2, f3])
data = sp.NDimArray([f1, f2, f3])
else:
if coord_system is None:
coord_system = SymbolicCartesianCoordSystem()
try:
components = sp.NDimArray(data, shape=(3, 3))
if not all(isinstance(_, (sp.Expr, sp.Number)) for _ in components):
is_symbolic = lambda _: isinstance(_, (sp.Number, sp.Symbol, sp.Expr)) # to validation module
if not all(is_symbolic(_) for _ in components):
raise ValueError("data type error")
except:
raise ValueError("Conversion error")

data = sp.ImmutableDenseNDimArray(data)
data = sp.NDimArray(data)

return cls(data, coord_system, field_params)

Expand Down

0 comments on commit ab9333a

Please sign in to comment.