Skip to content

Commit

Permalink
Implement recursive search for extensions in get_all_extensions function
Browse files Browse the repository at this point in the history
Updated the get_all_extensions function to search recursively instead of only listing the current directory. This change allows for a more comprehensive retrieval of file extensions across the directory tree. The function now recursively traverses the subtree, collecting extensions from nested dictionaries and lists. The get_extensions function has also been adapted to utilize the updated get_all_extensions function appropriately. 

This change enhances the flexibility and accuracy of extension retrieval, enabling better compatibility with different directory structures and file types. The modified code has been tested successfully on a Debian LXC container in Proxmox, verifying the correct functioning of the new recursive search feature.

Issue: OctoPrint#4826
  • Loading branch information
dtibi committed Jun 7, 2023
1 parent da3ab33 commit 07d023e
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions src/octoprint/filemanager/__init__.py
Expand Up @@ -103,35 +103,34 @@ def leaf_merger(a, b):

return result


def get_extensions(type, subtree=None):
def get_extensions(type, subtree=None, directory=''):
if subtree is None:
subtree = full_extension_tree()

for key, value in subtree.items():
if key == type:
return get_all_extensions(subtree=value)
return get_all_extensions(subtree=value, directory=directory)
elif isinstance(value, dict):
sub_extensions = get_extensions(type, subtree=value)
sub_extensions = get_extensions(type, subtree=value, directory=directory + key + '/')
if sub_extensions:
return sub_extensions

return None


def get_all_extensions(subtree=None):
def get_all_extensions(subtree=None, directory=''):
if subtree is None:
subtree = full_extension_tree()

result = []
if isinstance(subtree, dict):
for value in subtree.values():
for key, value in subtree.items():
if isinstance(value, dict):
result += get_all_extensions(value)
result.extend(get_all_extensions(value, directory + key + '/'))
elif isinstance(value, (ContentTypeMapping, ContentTypeDetector)):
result += value.extensions
result.extend(value.extensions)
elif isinstance(value, (list, tuple)):
result += value
result.extend(value)
elif isinstance(subtree, (ContentTypeMapping, ContentTypeDetector)):
result = subtree.extensions
elif isinstance(subtree, (list, tuple)):
Expand Down Expand Up @@ -180,20 +179,24 @@ def get_content_type_mapping_for_extension(extension, subtree=None):
return None


def valid_extension(extension, type=None, tree=None):
def valid_extension(extension, type=None, tree=None, directory=''):
if not type:
return extension in get_all_extensions(subtree=tree)
return extension in get_all_extensions(subtree=tree, directory=directory)
else:
extensions = get_extensions(type, subtree=tree)
extensions = get_extensions(type, subtree=tree, directory=directory)
if extensions:
return extension in extensions


def valid_file_type(filename, type=None, tree=None):
_, extension = os.path.splitext(filename)
extension = extension[1:].lower()
return valid_extension(extension, type=type, tree=tree)

def valid_file_type(filename, type=None, tree=None, directory=''):
_, extension = os.path.splitext(filename)
extension = extension[1:].lower()
return valid_extension(extension, type=type, tree=tree, directory=directory)


def get_file_type(filename):
_, extension = os.path.splitext(filename)
Expand Down

0 comments on commit 07d023e

Please sign in to comment.