<a href="https://colab.research.google.com/github/kangwonlee/eng-math-2/blob/update-utils/Ch13_03.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [None]:
import dataclasses

import matplotlib.pyplot as plt
import numpy as np

from mpl_toolkits.mplot3d import Axes3D



In [None]:
@dataclasses.dataclass
class EigenFunctionPlotter:
    """Plot the first ``max_n`` eigenfunctions ``sin(n*pi*x/L)`` on ``[0, L]``.

    Each eigenfunction n = 1 .. max_n is drawn in its own row of a
    single-column figure.
    """

    L: float = 1.0  # domain length; x is sampled from 0 to L inclusive
    max_n: int = 10  # number of eigenfunctions to plot (n = 1 .. max_n)
    # NOTE(review): A is not read anywhere in this class.
    A: float = 1.0
    n_points: int = 181  # number of sample points along x

    def __post_init__(self):
        """Precompute the x grid and the scaled argument ``pi*x/L``."""
        self.x = np.linspace(0, self.L, self.n_points)
        self.pi_L = np.pi / self.L
        self.pi_x_L = self.pi_L * self.x

    def calc_eigenfunction(self, n: int):
        """Return ``sin(n*pi*x/L)`` sampled on ``self.x`` (shape ``(n_points,)``)."""
        return np.sin(n * self.pi_x_L)

    def plot(self, row_height: float = 1.8):
        """Draw eigenfunctions n = 1 .. max_n, one subplot per row.

        Parameters
        ----------
        row_height : float
            Figure height per subplot row, in inches.

        Raises
        ------
        ValueError
            If ``max_n`` is less than 1 (there is nothing to plot).
        """
        if self.max_n < 1:
            raise ValueError(f'max_n must be at least 1, got {self.max_n}')

        # Size the figure by the number of rows, not by the number of x
        # samples, so n_points only affects curve resolution.
        fig = plt.figure(figsize=(6, self.max_n * row_height))
        # squeeze=False always yields a 2-D array of Axes; without it,
        # max_n == 1 returns a bare Axes that cannot be iterated.
        axs = fig.subplots(self.max_n, 1, squeeze=False).ravel()

        for n, ax in enumerate(axs, start=1):
            ax.plot(self.x, self.calc_eigenfunction(n))
            ax.set_xlim(0, self.L)
            ax.set_ylabel(f'n={n}')
            ax.grid(True)
        # Only the bottom subplot carries the x-axis label.
        axs[-1].set_xlabel('x')
        fig.tight_layout()



In [None]:
# Eigenfunctions on the unit interval, n = 1 .. 10 (these are the defaults).
p = EigenFunctionPlotter(L=1.0, max_n=10)
p.plot()



In [None]:
@dataclasses.dataclass
class SolutionPlotter(EigenFunctionPlotter):
    """Plot a truncated sine-series solution u(x, t) as a 3-D surface.

    The plotted surface is the partial sum

        u(x, t) = sum_{n=1}^{max_n} A_n * exp(-k * (n*pi/L)**2 * t) * sin(n*pi*x/L)

    evaluated on the (x, t) grid. This has the form of a heat-equation
    series solution with coefficient k.
    """

    k: float = 1.0  # coefficient multiplying (n*pi/L)**2 in the decay exponent
    t_max: float = 1.0  # last time sample (inclusive, up to rounding)
    delta_t: float = 1e-2  # time step of the t grid

    def __post_init__(self):
        """Build the t grid and the 2-D (x, t) mesh on top of the parent's x grid."""
        super().__post_init__()
        # Half a step of slack so that t_max itself is included despite
        # floating-point accumulation in arange.
        self.t = np.arange(
            0.0,
            self.t_max + (self.delta_t * 0.5),
            self.delta_t
        )

        # Default 'xy' indexing: X and T have shape (len(t), n_points).
        self.X, self.T = np.meshgrid(self.x, self.t)

    @staticmethod
    def calc_An(n: int):
        """Return the series coefficient ``200*(1 + (-1)**(n+1)) / (n*pi)``.

        This is 400/(n*pi) for odd n and 0 for even n, i.e. the Fourier sine
        coefficients of the constant 100 on [0, L] (independent of L).
        """
        return (
            (200 * (1 + (-1)**(n+1)))
            /
            (n * np.pi)
        )

    def calc_sol_n(self, n):
        """Return the n-th series term on the (x, t) grid.

        The term is ``A_n * exp(-k*(n*pi/L)**2 * t) * sin(n*pi*x/L)`` with the
        shape of ``self.X``. For L = pi and k = 1 this reduces to
        ``A_n * exp(-n**2 * t) * sin(n*x)``.
        """
        decay_rate = self.k * (n * self.pi_L) ** 2
        # calc_eigenfunction(n) has shape (n_points,) and broadcasts across
        # every time row of self.T, whose shape is (len(t), n_points).
        return (
            self.calc_An(n)
            * np.exp(-decay_rate * self.T)
            * self.calc_eigenfunction(n)
        )

    def cumsum_sol(self):
        """Return the partial sum of ``calc_sol_n(n)`` for n = 1 .. max_n inclusive.

        The result has the shape of ``self.X``. It is all zeros when
        ``max_n < 1``.
        """
        # Start from a zero array rather than the int 0, so an empty range
        # still returns a grid-shaped array that plot_surface can draw.
        return sum(
            map(self.calc_sol_n, range(1, self.max_n + 1)),
            np.zeros_like(self.X),
        )

    def plot(self, elev: float = 30.0, azim: float = 30.0):
        """Draw the partial-sum solution as a 3-D surface over (x, t).

        Parameters
        ----------
        elev, azim : float
            Camera elevation and azimuth in degrees, passed to ``view_init``.

        The surface artist is stored in ``self.surf``.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.view_init(elev=elev, azim=azim)
        self.surf = ax.plot_surface(
            self.X, self.T, self.cumsum_sol(),
            cmap='viridis',
        )
        # Bind the colorbar to this figure/axes explicitly instead of
        # relying on pyplot's "current axes" state.
        fig.colorbar(self.surf, ax=ax)
        ax.set_xlabel('x')
        ax.set_ylabel('t')
        ax.grid(True)



In [None]:
# Series-solution surface on [0, pi] with 100 terms and a time step of 1e-3.
sp = SolutionPlotter(
    L=np.pi,
    max_n=100,
    delta_t=1e-3,
)
sp.plot()

