# Predicting Eigenmode Decompositions in Vibroacoustic Systems

Thank you for your interest in our method!
With this notebook, you can paint an indentation pattern and preview the resulting frequency response curve.

To execute this notebook, a CUDA-capable GPU is necessary. You can also open a version of this notebook in Google Colab.

<a href="https://colab.research.google.com/github/ecker-lab/modeonet/blob/main/notebooks/plate_paint_google_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

IMPORTANT: This notebook does not implement the numerical simulation of the actual vibrations given the plate design. Thus, results displayed here are only based on the deep learning regression model. These results are expected to be biased towards lower values than numerically simulated vibrations.

In [1]:
import ipycanvas
import ipywidgets as widgets
from IPython.display import display
from PIL import Image
import io
import numpy as np

from vibromodes.models import ModeONet
from vibromodes.kirchhoff import DEFAULT_PLATE_PARAMETERS
import torch
import hdf5plugin
import h5py
from matplotlib.colors import rgb2hex

In [2]:
# This notebook requires a CUDA GPU (see the introduction above).
device = torch.device("cuda")

model = ModeONet(32, 16).to(device)
# Map the stored tensors directly onto the target device, independent of the
# device they were saved from. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state dict needs.
model.load_state_dict(torch.load("./model.pth", map_location=device, weights_only=True))
model.eval()

## Plate Paint

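In the following cell, you can draw an indentation pattern. The red dot marks the excitation point and can be dragged to move it. Next to the canvas, the predicted frequency response function and the individual mode responses are shown; they are updated whenever you release the mouse button.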

In [3]:
from ipywidgets import VBox, HBox, Button, Layout
from ipycanvas import MultiCanvas
from ipywidgets import VBox, Button
from IPython.display import display
import bqplot as bq
import matplotlib.pyplot as plt

from vibromodes.kirchhoff import PlateParameter
from vibromodes.velocity_field import field_dict2frf,linear2db


# Hex line color used for the predicted FRF curve drawn by FrfPlotter.
modeonet_color = "#4681B4"

class FrfPlotter:
    """bqplot figure showing the ModeONet-predicted frequency response.

    The figure contains one line for the predicted frequency response
    function (FRF) and one line per predicted mode response. ``update``
    re-runs the model for a new indentation pattern / excitation point and
    replaces the line data in place, so ``self.fig`` can be displayed once
    and then refreshed.

    Relies on the notebook-level globals ``model``, ``device`` and
    ``modeonet_color``.
    """

    # Shape (height, width, RGBA) of the canvas image passed to ``update``.
    PATTERN_SHAPE = (121, 181, 4)
    # Number of mode-response curves; assumed to equal the model's mode count
    # (ModeONet is constructed with 16 as its second argument) -- TODO confirm.
    N_MODES = 16
    # Shift (dB) applied to the mode responses so they sit below the FRF curve.
    MODE_OFFSET_DB = 10

    def __init__(self):
        self.n_freqs = 300
        self.freqs = np.linspace(0, 300, self.n_freqs)

        x_sc = bq.LinearScale()
        y_sc = bq.LinearScale(min=0, max=70)

        # Placeholder y data; overwritten by the initial update() below.
        self.frf_line = bq.Lines(x=self.freqs, y=np.sin(self.freqs),
                                 scales={'x': x_sc, 'y': y_sc},
                                 colors=[modeonet_color])

        # Mode curves in grey tones taken from the middle of the "bone" colormap.
        cmap = plt.get_cmap("bone")
        self.mode_lines = [
            bq.Lines(x=self.freqs, y=np.sin(self.freqs * i),
                     scales={'x': x_sc, 'y': y_sc},
                     colors=[rgb2hex(cmap(i / self.N_MODES * 0.6 + 0.2))])
            for i in range(self.N_MODES)
        ]

        ax_x = bq.Axis(scale=x_sc, label='Frequency (Hz)')
        ax_y = bq.Axis(scale=y_sc, orientation='vertical', label='Velocity (dB)')

        # The FRF line is listed after the mode lines in the marks list.
        self.fig = bq.Figure(marks=self.mode_lines + [self.frf_line], axes=[ax_x, ax_y],
                             layout=Layout(width='400px', height='250px'),
                             fig_margin={'top': 10, 'bottom': 40, 'left': 40, 'right': 10})

        self._setup_batch()
        self.update(self._blank_pattern(), 0.5, 0.5)

    @classmethod
    def _blank_pattern(cls):
        """Return an opaque black RGBA image, matching the black canvas background."""
        pattern = np.zeros(cls.PATTERN_SHAPE, dtype=np.uint8)
        pattern[..., 3] = 255
        return pattern

    def _setup_batch(self):
        """Allocate the model inputs that stay fixed between updates and compile the model."""
        self.phy_para = PlateParameter.from_array(
            torch.tensor(DEFAULT_PLATE_PARAMETERS).unsqueeze(0).float().to(device))
        # One query point on a [-1, 1] grid per plotted frequency.
        self.tr_freqs = torch.linspace(-1, 1, self.n_freqs).to(device).float().unsqueeze(0)
        # Preallocated device buffer the canvas image is copied into on every update.
        self.pattern = torch.empty(self.PATTERN_SHAPE, device=device, dtype=torch.uint8)
        self.model = model
        self.model.eval()
        self.model.compile()

    @torch.no_grad()
    def update(self, pattern, force_x, force_y):
        """Predict the response for a drawn pattern and refresh the plot lines.

        Args:
            pattern: uint8 RGBA image of shape ``PATTERN_SHAPE`` (as returned
                by the drawing canvas' ``get_image_data``).
            force_x: Horizontal excitation position divided by the canvas width.
            force_y: Vertical excitation position divided by the canvas height.
        """
        # Follow the device the buffers live on instead of hard-coding "cuda".
        with torch.autocast(self.tr_freqs.device.type, dtype=torch.bfloat16, enabled=True):
            self.pattern.copy_(torch.from_numpy(pattern), non_blocking=True)
            # Map pixel values to [-1, 1] and average the RGBA channels into a
            # single (height, width) map.
            # NOTE(review): the alpha channel is included in the mean -- confirm
            # this matches the preprocessing the model was trained with.
            plate = (self.pattern.to(torch.float32) / 255) * 2. - 1.
            plate = plate.mean(dim=-1)
            # x and y are switched in the plate parameters.
            self.phy_para.force_x[0] = force_y
            self.phy_para.force_y[0] = force_x

            mode_shapes, mode_responses = self.model.forward_eigenmodes_mode_dynamics(
                plate.unsqueeze(0), self.phy_para.to_dict(), self.tr_freqs)
            pred_field, _ = self.model.superposition(mode_shapes, mode_responses, self.tr_freqs)
            pred_frf = field_dict2frf(pred_field)

            mode_responses = self.model.superposition.mode_response_normed2physical(
                mode_responses, self.tr_freqs)
            # Shift the mode curves down so they stay visually distinct from the FRF.
            mode_responses = linear2db(mode_responses) - self.MODE_OFFSET_DB

        # bqplot takes numpy arrays; cast to float32 first because numpy cannot
        # represent bfloat16 tensors that may come out of the autocast region.
        self.frf_line.y = pred_frf[0].float().cpu().numpy()
        mode_db = mode_responses[0].float().cpu().numpy()
        for i, line in enumerate(self.mode_lines):
            line.y = mode_db[:, i]


class DrawEngine:
    """Interactive canvas for painting an indentation pattern.

    Layer 0 of the canvas holds the painted pattern (strokes on a black
    background); layer 1 holds the draggable red marker for the excitation
    point. When a stroke or drag ends, or the drawing is cleared, the pattern
    and marker position are passed to an ``FrfPlotter``, whose figure is shown
    next to the canvas.
    """

    # Canvas size in pixels (must match the pattern buffer of FrfPlotter).
    WIDTH = 181
    HEIGHT = 121
    # On-screen magnification of the canvas.
    DISPLAY_SCALE = 2

    def __init__(self):
        self.frf_plotter = FrfPlotter()

        # Two layers: layer 0 (drawing), layer 1 (excitation marker).
        self.canvases = MultiCanvas(2, width=self.WIDTH, height=self.HEIGHT, sync_image_data=True)
        self.canvases.layout.width = f'{self.WIDTH * self.DISPLAY_SCALE}px'
        self.canvases.layout.height = f'{self.HEIGHT * self.DISPLAY_SCALE}px'

        self.draw_layer = self.canvases[0]
        # get_image_data() on this layer needs the pixels synced back from the browser.
        self.draw_layer.sync_image_data = True
        self.circle_layer = self.canvases[1]

        # Draw style: round caps/joins make strokes look like a round brush.
        self.draw_layer.stroke_style = "white"
        self.draw_layer.line_width = 10
        self.draw_layer.line_cap = 'round'
        self.draw_layer.line_join = 'round'

        # Color picker buttons; the clicked button's color becomes the stroke color.
        colors = ['black', 'white']
        self.buttons = [
            widgets.Button(
                description='',
                layout=widgets.Layout(width='30px', height='30px', border="3px solid gray"),
                style=dict(button_color=c),
            )
            for c in colors
        ]

        # Excitation marker, in canvas pixel coordinates.
        self.circle_x = 91
        self.circle_y = 61
        self.circle_r = 4
        self.is_dragging = False

        # Freehand drawing state.
        self.is_drawing = False
        self.last_x, self.last_y = None, None

        self.clear_btn = Button(description='Clear Drawing')

    def run(self):
        """Draw the initial canvas, bind the event handlers and display the UI."""
        self.draw_frame()
        self.draw_circle()
        self._fill_background()

        self.clear_btn.on_click(self.clear_clicked)
        for button in self.buttons:
            button.on_click(self.on_color_click)

        # --- Bind events ---
        self.canvases.on_mouse_down(self.handle_mouse_down)
        self.canvases.on_mouse_move(self.handle_mouse_move)
        self.canvases.on_mouse_up(self.handle_mouse_up)
        self.canvases.on_mouse_out(self.handle_mouse_out)

        text = widgets.Text("Select a color: ", layout=Layout(width='120px'), disabled=True)
        colors_box = widgets.HBox([text] + self.buttons)
        display(VBox([colors_box,
                      HBox([self.canvases, self.frf_plotter.fig],
                           layout=Layout(position='relative', width='900px', height='300px')),
                      self.clear_btn]))

        # Make the initial preview reflect the blank canvas and the marker's
        # actual position rather than FrfPlotter's startup defaults.
        self._update_blank()

    # --- draw frame ---
    def draw_frame(self):
        """Hook for drawing a frame around the canvas; currently draws nothing."""

    def _fill_background(self):
        """Paint the whole drawing layer black."""
        self.draw_layer.fill_style = "black"
        self.draw_layer.fill_rect(0, 0, self.draw_layer.width, self.draw_layer.height)

    # --- Circle drawing ---
    def draw_circle(self):
        """Redraw the red excitation marker at its current position."""
        self.circle_layer.clear()
        self.circle_layer.fill_style = "red"
        self.circle_layer.begin_path()
        self.circle_layer.arc(self.circle_x, self.circle_y, self.circle_r, 0, 2 * np.pi)
        self.circle_layer.fill()

    # --- Event handlers ---
    def handle_mouse_down(self, x, y):
        """Start dragging the marker if it was hit, otherwise start a stroke."""
        dx, dy = x - self.circle_x, y - self.circle_y
        if dx * dx + dy * dy <= self.circle_r ** 2:
            self.is_dragging = True
        else:
            self.is_drawing = True
            self.last_x, self.last_y = x, y

    def handle_mouse_move(self, x, y):
        """Move the marker or extend the current stroke."""
        if self.is_dragging:
            self.circle_x, self.circle_y = x, y
            self.draw_circle()
        elif self.is_drawing:
            self.draw_layer.stroke_line(self.last_x, self.last_y, x, y)
            self.last_x, self.last_y = x, y

    def handle_mouse_up(self, x, y):
        """End any stroke/drag and recompute the prediction."""
        self.is_dragging = False
        self.is_drawing = False

        self.draw_frf()

    def handle_mouse_out(self, x, y):
        """Finish an active stroke or drag when the pointer leaves the canvas.

        A button released outside the canvas is not reported to the canvas, so
        without this the drawing/dragging flag would stay set and moving back
        over the canvas would continue the stroke or drag.
        """
        if self.is_drawing or self.is_dragging:
            self.handle_mouse_up(x, y)

    def _force_position(self):
        """Return the marker position divided by the canvas width/height."""
        return self.circle_x / self.WIDTH, self.circle_y / self.HEIGHT

    def draw_frf(self):
        """Run the prediction for the current drawing and marker position."""
        force_x, force_y = self._force_position()
        image = self.draw_layer.get_image_data()

        self.frf_plotter.update(image, force_x, force_y)

    def _update_blank(self):
        """Run the prediction for an empty (all-black) drawing.

        The draw layer's synced image data is updated asynchronously by the
        browser, so inside the same callback get_image_data() may still hold
        the previous drawing; the known blank image is used instead.
        """
        blank = np.zeros((self.HEIGHT, self.WIDTH, 4), dtype=np.uint8)
        blank[..., 3] = 255  # opaque black, as painted by _fill_background
        self.frf_plotter.update(blank, *self._force_position())

    # --- Clear button ---
    def clear_clicked(self, b):
        """Reset the drawing to the black background and refresh the prediction."""
        self.draw_layer.clear()
        self._fill_background()
        self._update_blank()

    # --- Set color button ---
    def on_color_click(self, b):
        """Use the clicked button's color for subsequent strokes."""
        self.draw_layer.stroke_style = b.style.button_color
        self.draw_frame()  # redraw frame after color change

    
# Build the drawing UI (constructing DrawEngine also creates an FrfPlotter,
# which runs an initial prediction) and display it.
engine = DrawEngine()
engine.run()

VBox(children=(HBox(children=(Text(value='Select a color: ', disabled=True, layout=Layout(width='120px')), But…