-
Notifications
You must be signed in to change notification settings - Fork 384
/
combine.py
192 lines (157 loc) · 6.44 KB
/
combine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""Combine source and outputs from two notebooks
"""
import re
from copy import copy
from nbformat import NotebookNode
from .cell_metadata import _IGNORE_CELL_METADATA
from .formats import long_form_one_format
from .header import _DEFAULT_NOTEBOOK_METADATA
from .metadata_filter import restore_filtered_metadata
_BLANK_LINE = re.compile(r"^\s*$")
def black_invariant(text, chars=None):
"""Remove characters that may be changed when reformatting the text with black"""
if chars is None:
chars = [" ", "\t", "\n", ",", "'", '"', "(", ")", "\\"]
for char in chars:
text = text.replace(char, "")
return text
def same_content(ref, test, endswith=False):
"""Is the content of two cells the same, up to reformatting by black"""
ref = black_invariant(ref)
test = black_invariant(test)
if endswith and test:
return ref.endswith(test)
return ref == test
def combine_inputs_with_outputs(nb_source, nb_outputs, fmt=None):
"""Return a notebook that combines the text and metadata from the first notebook,
with the outputs and metadata of the second notebook."""
# nbformat version number taken from the notebook with outputs
assert (
nb_outputs.nbformat == nb_source.nbformat
), "The notebook with outputs is in format {}.{}, please upgrade it to {}.x".format(
nb_outputs.nbformat, nb_outputs.nbformat_minor, nb_source.nbformat
)
nb_source.nbformat_minor = nb_outputs.nbformat_minor
fmt = long_form_one_format(fmt)
text_repr = nb_source.metadata.get("jupytext", {}).get("text_representation", {})
ext = fmt.get("extension") or text_repr.get("extension")
format_name = fmt.get("format_name") or text_repr.get("format_name")
notebook_metadata_filter = nb_source.metadata.get("jupytext", {}).get(
"notebook_metadata_filter"
)
if notebook_metadata_filter == "-all":
nb_metadata = nb_outputs.metadata
else:
nb_metadata = restore_filtered_metadata(
nb_source.metadata,
nb_outputs.metadata,
notebook_metadata_filter,
_DEFAULT_NOTEBOOK_METADATA,
)
source_is_md_version_one = (
ext in [".md", ".markdown", ".Rmd"] and text_repr.get("format_version") == "1.0"
)
if nb_metadata.get("jupytext", {}).get("formats") or ext in [
".md",
".markdown",
".Rmd",
]:
nb_metadata.get("jupytext", {}).pop("text_representation", None)
if not nb_metadata.get("jupytext", {}):
nb_metadata.pop("jupytext", {})
if format_name in ["nomarker", "sphinx"] or source_is_md_version_one:
cell_metadata_filter = "-all"
else:
cell_metadata_filter = nb_metadata.get("jupytext", {}).get(
"cell_metadata_filter"
)
outputs_map = map_outputs_to_inputs(nb_source.cells, nb_outputs.cells)
cells = []
for source_cell, j in zip(nb_source.cells, outputs_map):
if j is None:
cells.append(source_cell)
continue
output_cell = nb_outputs.cells[j]
# Outputs and optional attributes are taken from the notebook with outputs
cell = copy(output_cell)
# Cell text is taken from the source notebook
cell.source = source_cell.source
# We also restore the cell metadata that has been filtered
cell.metadata = restore_filtered_metadata(
source_cell.metadata,
output_cell.metadata,
# The 'spin' format does not allow metadata on non-code cells
"-all"
if format_name == "spin" and source_cell.cell_type != "code"
else cell_metadata_filter,
_IGNORE_CELL_METADATA,
)
cells.append(cell)
# We call NotebookNode rather than new_notebook as we don't want to validate
# the notebook (some of the notebook in the collection of test notebooks
# do have some invalid properties - probably inherited from an older version
# of the notebook format).
return NotebookNode(
cells=cells,
metadata=nb_metadata,
nbformat=nb_outputs.nbformat,
nbformat_minor=nb_outputs.nbformat_minor,
)
def map_outputs_to_inputs(cells_inputs, cells_outputs):
"""Returns a map i->(j or None) that maps the cells with outputs to the input cells"""
n_in = len(cells_inputs)
n_out = len(cells_outputs)
outputs_map = [None] * n_in
# First rule: match based on cell type, content, in increasing order, for each cell type
first_unmatched_output_per_cell_type = {}
for i in range(n_in):
cell_input = cells_inputs[i]
for j in range(
first_unmatched_output_per_cell_type.get(cell_input.cell_type, 0), n_out
):
cell_output = cells_outputs[j]
if cell_input.cell_type == cell_output.cell_type and same_content(
cell_input.source, cell_output.source
):
outputs_map[i] = j
first_unmatched_output_per_cell_type[cell_input.cell_type] = j + 1
break
# Second rule: match unused outputs based on cell type and content
# Third rule: is the new cell the final part of a previous cell with outputs?
unused_ouputs = set(range(n_out)).difference(outputs_map)
for endswith in [False, True]:
if not unused_ouputs:
return outputs_map
for i in range(n_in):
if outputs_map[i] is not None:
continue
cell_input = cells_inputs[i]
for j in unused_ouputs:
cell_output = cells_outputs[j]
if cell_input.cell_type == cell_output.cell_type and same_content(
cell_output.source, cell_input.source, endswith
):
outputs_map[i] = j
unused_ouputs.remove(j)
break
# Fourth rule: match based on increasing index (and cell type) for non-empty cells
if not unused_ouputs:
return outputs_map
prev_j = -1
for i in range(n_in):
if outputs_map[i] is not None:
prev_j = outputs_map[i]
continue
j = prev_j + 1
if j not in unused_ouputs:
continue
cell_input = cells_inputs[i]
cell_output = cells_outputs[j]
if (
cell_input.cell_type == cell_output.cell_type
and cell_input.source.strip() != ""
):
outputs_map[i] = j
unused_ouputs.remove(j)
prev_j = j
return outputs_map