diff --git a/README.md b/README.md index 32e202d..df700c4 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,16 @@ Artificial neural networks is a popular field of research in artificial intellig 2. Create a neural network model and process it. An example of this process is given in `examples/process_mnist_model.py` on [MNIST](http://yann.lecun.com/exdb/mnist/) data. 3. Start the visualization tool `start_tool.py` and select the neural network via `Load Processed Network` to render the representation of the neural network. +Or + +1. Run `start_tool.py --demo` to download data of an already processed model and render it. + Multiple scripts are located in `examples`, which can be adapted to create and process neural networks. `examples/evaluation_plots.py` for example can be used to recreate the evaluation data and plots of my thesis. +### Sample Model Importance + +A processed model can be found [here](https://drive.google.com/file/d/1LiVzBfB7LPrR95q_VO44wx4MyGNTj6vD/view?usp=sharing). + ## Rendering Tool The visualization tool `start_tool.py` can be used to render and/or process neural networks. Instead of existing ones, you can also generate random networks and process them of various sizes. For neural networks the visualization results in a more structured view of a neural network in regards to their trained parameters compared to the most common ones. diff --git a/VR_TOOL.md b/VR_TOOL.md index 185da0f..6577bc5 100644 --- a/VR_TOOL.md +++ b/VR_TOOL.md @@ -21,6 +21,10 @@ This tool can be used to render a processed neural network in VR. * `python start_tool_vr.py` +Or + +* Run `start_tool.py --demo` to download data of an already processed model and render it. + ## Controls Using Oculus Quest 2 controller: @@ -35,7 +39,7 @@ Using Oculus Quest 2 controller: ### GUI -See [README.md](./README.md)) for information on the desktop GUI +See [README.md](./README.md) for information on the desktop GUI ## Used Systems diff --git a/configs/processing.json b/configs/processing.json index 6211ef9..a2d8343 100644 --- a/configs/processing.json +++ b/configs/processing.json @@ -4,7 +4,7 @@ "layer_distance": 0.4, "layer_width": 1.0, "node_bandwidth_reduction": 0.95, - "prune_percentage": 0.0, + "prune_percentage": 0.9, "sampling_rate": 15.0, "smoothing": true, "smoothing_iterations": 8 diff --git a/configs/window.json b/configs/window.json index d1f7e69..91f8809 100644 --- a/configs/window.json +++ b/configs/window.json @@ -5,8 +5,8 @@ "monitor_id": 0, "screen_height": 900, "screen_width": 1600, - "screen_x": 2765, - "screen_y": 127, + "screen_x": 1644, + "screen_y": 237, "title": "NNVis Render", "width": 1600 } \ No newline at end of file diff --git a/gui/ui_window.py b/gui/ui_window.py index b9188d9..0bccf8d 100644 --- a/gui/ui_window.py +++ b/gui/ui_window.py @@ -79,10 +79,11 @@ def save_processed_nn_file(self) -> None: def open_processed_nn_file(self) -> None: filename = filedialog.askopenfilename(initialdir=DATA_PATH, title='Select A File', filetypes=(('processed nn files', '*.pro.npz'),)) - data_loader: ProcessedNNHandler = ProcessedNNHandler(filename) - self.settings['network_name'] = ntpath.basename( - filename) + '_processed' - self.update_layer(data_loader.layer_data, processed_nn=data_loader) + if filename != '': + data_loader: ProcessedNNHandler = ProcessedNNHandler(filename) + self.settings['network_name'] = ntpath.basename( + filename) + '_processed' + self.update_layer(data_loader.layer_data, processed_nn=data_loader) def open_importance_file(self) -> None: filename = filedialog.askopenfilename(initialdir=DATA_PATH, title='Select A File', diff --git a/requirements.txt b/requirements.txt index 24eb263..58b0dd7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,9 +2,10 @@ glfw==2.0.0 matplotlib==3.3.3 numpy==1.23.5 pandas==1.2.0 -Pillow==9.0.1 +Pillow==9.3.0 progressbar2==3.53.1 PyOpenGL==3.1.5 pyrr==0.10.3 scikit_learn==0.24.0 -tensorflow==2.9.2 +tensorflow==2.9.3 +wget==3.2 diff --git a/requirements_vr.txt b/requirements_vr.txt index f910941..3d8a188 100644 --- a/requirements_vr.txt +++ b/requirements_vr.txt @@ -3,9 +3,10 @@ matplotlib==3.3.3 numpy==1.23.5 openvr pandas==1.2.0 -Pillow==9.0.1 +Pillow==9.3.0 progressbar2==3.53.1 PyOpenGL==3.1.5 pyrr==0.10.3 scikit_learn==0.24.0 -tensorflow==2.9.2 +tensorflow==2.9.3 +wget==3.2 diff --git a/start_tool.py b/start_tool.py index 3030ee9..37a64b5 100644 --- a/start_tool.py +++ b/start_tool.py @@ -1,11 +1,17 @@ import logging +import ntpath +import os import threading import time +import zipfile +from argparse import ArgumentParser from typing import Optional +import wget from OpenGL.GL import GL_MAJOR_VERSION, GL_MINOR_VERSION, glGetIntegerv -from definitions import CameraPose +from data.data_handler import ProcessedNNHandler +from definitions import DATA_PATH, CameraPose from gui.constants import StatisticLink from gui.ui_window import OptionGui from opengl_helper.screenshot import create_screenshot @@ -15,13 +21,34 @@ from utility.performance import track_time from utility.window import Window, WindowHandler -global options_gui -options_gui = OptionGui() -setup_logger('tool') + +def download_and_unzip_sample() -> str: + output_directory = DATA_PATH + filename = wget.download( + 'https://drive.google.com/uc?export=download&id=1LiVzBfB7LPrR95q_VO44wx4MyGNTj6vD', out=output_directory) + zip_filepath = os.path.join(output_directory, filename) + with zipfile.ZipFile(zip_filepath, 'r') as zip_ref: + zip_ref.extractall(DATA_PATH) + return os.path.join(DATA_PATH, 'sample_model.npz') + + +def open_processed_network(option_gui: OptionGui, filename: str) -> None: + data_loader: ProcessedNNHandler = ProcessedNNHandler(filename) + option_gui.processing_config['prune_percentage'] = 0.9 + option_gui.processing_setting.set() + option_gui.settings['network_name'] = ntpath.basename( + filename) + '_processed' + option_gui.update_layer(data_loader.layer_data, processed_nn=data_loader) def compute_render(some_name: str) -> None: global options_gui + global use_sample + + if use_sample: + global sample_filepath + logging.info('Loading sample model...') + open_processed_network(options_gui, sample_filepath) width, height = 1920, 1200 @@ -72,6 +99,7 @@ def frame() -> None: if not options_gui.settings['Closed']: print('Start building network: ' + str(options_gui.settings['current_layer_data'])) + options_gui.settings['update_model'] = False network_processor = NetworkProcessor(options_gui.settings['current_layer_data'], options_gui.processing_config, importance_data=options_gui.settings['importance_data'], @@ -154,9 +182,39 @@ def frame() -> None: options_gui.destroy() -compute_render_thread: threading.Thread = threading.Thread( - target=compute_render, args=(1,)) -compute_render_thread.setDaemon(True) -compute_render_thread.start() +def parse_args() -> bool: + parser = ArgumentParser(prog='Start nn_vis tool') + parser.add_argument('--demo', action='store_true', + help='Download sample of a processed model and render it with 90% pruned edges instead of generating a random model.') + args = parser.parse_args() + return args.demo + -options_gui.start() +if __name__ == '__main__': + global options_gui + options_gui = OptionGui() + + global sample_filepath + sample_filepath = 'sample_model.npz' + + global use_sample + use_sample = parse_args() + + setup_logger('tool') + + if use_sample: + expected_sample_path = os.path.join(DATA_PATH, sample_filepath) + if not os.path.exists(expected_sample_path): + logging.info( + f'Downloading sample model to "{expected_sample_path}". This might take a minute ...') + sample_filepath = download_and_unzip_sample() + else: + logging.info( + f'Using sample model at "{expected_sample_path}"') + sample_filepath = expected_sample_path + compute_render_thread: threading.Thread = threading.Thread( + target=compute_render, args=(1,)) + compute_render_thread.setDaemon(True) + compute_render_thread.start() + + options_gui.start() diff --git a/start_tool_vr.py b/start_tool_vr.py index 26e90e7..d7b8898 100644 --- a/start_tool_vr.py +++ b/start_tool_vr.py @@ -1,10 +1,17 @@ import logging +import ntpath +import os import threading import time +import zipfile +from argparse import ArgumentParser from typing import List, Optional, Tuple +import wget from OpenGL.GL import GL_MAJOR_VERSION, GL_MINOR_VERSION, glGetIntegerv +from data.data_handler import ProcessedNNHandler +from definitions import DATA_PATH from gui.constants import StatisticLink from gui.ui_window import OptionGui from processing.network_processing import NetworkProcessor @@ -13,15 +20,36 @@ from utility.performance import track_time from vr.vr_handler import VRHandler -global options_gui -options_gui = OptionGui() -setup_logger('tool') - RENDER_MODES: List[Tuple[int, int]] = [(3, 2), (4, 1), (1, 1), (2, 2)] +def download_and_unzip_sample() -> str: + output_directory = DATA_PATH + filename = wget.download( + 'https://drive.google.com/uc?export=download&id=1LiVzBfB7LPrR95q_VO44wx4MyGNTj6vD', out=output_directory) + zip_filepath = os.path.join(output_directory, filename) + with zipfile.ZipFile(zip_filepath, 'r') as zip_ref: + zip_ref.extractall(DATA_PATH) + return os.path.join(DATA_PATH, 'sample_model.npz') + + +def open_processed_network(option_gui: OptionGui, filename: str) -> None: + data_loader: ProcessedNNHandler = ProcessedNNHandler(filename) + option_gui.processing_config['prune_percentage'] = 0.9 + option_gui.processing_setting.set() + option_gui.settings['network_name'] = ntpath.basename( + filename) + '_processed' + option_gui.update_layer(data_loader.layer_data, processed_nn=data_loader) + + def compute_render(_: str) -> None: global options_gui + global use_sample + + if use_sample: + global sample_filepath + logging.info('Loading sample model...') + open_processed_network(options_gui, sample_filepath) FileHandler().read_statistics() @@ -101,6 +129,7 @@ def frame() -> None: 'Start building network: ' + str(options_gui.settings['current_layer_data']) ) + options_gui.settings['update_model'] = False network_processor = NetworkProcessor( options_gui.settings['current_layer_data'], options_gui.processing_config, @@ -191,10 +220,39 @@ def frame() -> None: options_gui.destroy() -compute_render_thread: threading.Thread = threading.Thread( - target=compute_render, args=(1,) -) -compute_render_thread.setDaemon(True) -compute_render_thread.start() +def parse_args() -> bool: + parser = ArgumentParser(prog='Start nn_vis tool for VR') + parser.add_argument('--demo', action='store_true', + help='Download sample of a processed model and render it with 90% pruned edges instead of generating a random model.') + args = parser.parse_args() + return args.demo + + +if __name__ == '__main__': + global options_gui + options_gui = OptionGui() + + global sample_filepath + sample_filepath = 'sample_model.npz' + + global use_sample + use_sample = parse_args() + + setup_logger('tool_vr') + + if use_sample: + expected_sample_path = os.path.join(DATA_PATH, sample_filepath) + if not os.path.exists(expected_sample_path): + logging.info( + f'Downloading sample model to "{expected_sample_path}". This might take a minute ...') + sample_filepath = download_and_unzip_sample() + else: + logging.info( + f'Using sample model at "{expected_sample_path}"') + sample_filepath = expected_sample_path + compute_render_thread: threading.Thread = threading.Thread( + target=compute_render, args=(1,)) + compute_render_thread.setDaemon(True) + compute_render_thread.start() -options_gui.start() + options_gui.start()