Skip to content

Commit

Permalink
improve python code
Browse files Browse the repository at this point in the history
  • Loading branch information
jessica-mitchell committed Aug 21, 2023
1 parent 97527d1 commit 4f39bae
Showing 1 changed file with 36 additions and 28 deletions.
64 changes: 36 additions & 28 deletions doc/htmldoc/_ext/extract_api_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,16 @@
import ast
import json
import re
import glob


"""
Generate a JSON dictionary that stores the module name as key and corresponding
functions as values, along with the ``NestModule`` and the kernel attributes.
Used in a Jinja template to generate the autosummary for each module in
the API documentation (``ref_material/pynest_api/``)
"""

def find_all_variables(file_path):
"""
This function gets the names of all functions listed in ``__all__``
Expand Down Expand Up @@ -64,36 +72,36 @@ def find_all_variables(file_path):

def process_directory(directory):
"""
Get the PyNEST API files and set the keys to the base name
Get the PyNEST API filenames and set the keys to the base name
"""
api_dict = {}
for root, _, files in os.walk(directory):
for file in files:
if file.endswith(".py"):
file_path = os.path.join(root, file)
if "helper" not in file:
if "pynest/nest/__init__" in file_path:
api_name = "nest.NestModule"
all_variables = find_all_variables(file_path)
if all_variables:
api_dict[api_name] = all_variables
if "hl_" in file:
parts = file_path.split(os.path.sep)
nest_index = parts.index("nest")
module_path = ".".join(parts[nest_index + 1 : -1])
module_name = os.path.splitext(parts[-1])[0]
api_name = f"nest.{module_path}.{module_name}"
all_variables = find_all_variables(file_path)
if all_variables:
api_dict[api_name] = all_variables
if "raster_plot" in file or "visualization" in file or "voltage_trace" in file:
parts = file_path.split(os.path.sep)
nest_index = parts.index("nest")
module_name = os.path.splitext(parts[-1])[0]
api_name = f"nest.{module_name}"
all_variables = find_all_variables(file_path)
if all_variables:
api_dict[api_name] = all_variables
api_exception_list = ["raster_plot", "visualization", "voltage_trace"]
files = glob.glob(directory+"**/*.py", recursive = True)

for file in files:

# ignoring the connection_helpers and helper modules
if "helper" in file:
continue

# get the NestModule for the kernel attributes
if "pynest/nest/__init__" in file:
api_name = "nest.NestModule"

parts = file.split(os.path.sep)
nest_index = parts.index("nest")
module_name = os.path.splitext(parts[-1])[0]
# only get high level API modules
if "hl_" in file:
module_path = ".".join(parts[nest_index + 1 : -1])
api_name = f"nest.{module_path}.{module_name}"
for item in api_exception_list:
if item in file:
api_name = f"nest.{module_name}"

all_variables = find_all_variables(file)
if all_variables:
api_dict[api_name] = all_variables

return api_dict

Expand Down

0 comments on commit 4f39bae

Please sign in to comment.