diff --git a/keras_core/utils/summary_utils.py b/keras_core/utils/summary_utils.py index 52e71c33b..165194ae2 100644 --- a/keras_core/utils/summary_utils.py +++ b/keras_core/utils/summary_utils.py @@ -170,13 +170,13 @@ def print_summary( break if sequential_like: - line_length = line_length or 84 + default_line_length = 84 positions = positions or [0.45, 0.84, 1.0] # header names for the different log elements header = ["Layer (type)", "Output Shape", "Param #"] alignment = ["left", "left", "right"] else: - line_length = line_length or 108 + default_line_length = 108 positions = positions or [0.3, 0.56, 0.70, 1.0] # header names for the different log elements header = ["Layer (type)", "Output Shape", "Param #", "Connected to"] @@ -186,13 +186,16 @@ def print_summary( relevant_nodes += v if show_trainable: - line_length += 8 + default_line_length += 8 positions = [p * 0.88 for p in positions] + [1.0] header.append("Trainable") alignment.append("center") # Compute columns widths - line_length = min(line_length, shutil.get_terminal_size().columns - 4) + default_line_length = min( + default_line_length, shutil.get_terminal_size().columns - 4 + ) + line_length = line_length or default_line_length column_widths = [] current = 0 for pos in positions: