Skip to content

Commit

Permalink
[Fix][Graph] Write compiled graph to tempfile first (#392)
Browse files Browse the repository at this point in the history
When a process is running save_compiled_graph(path) and writing the
CompiledGraph zipfile to the disk, another process may think that the
CompiledGraph file can be read since os.path.isfile(path) evaluates to
true.

This change writes the file to a temporary file first to avoid this race
condition.
  • Loading branch information
destefy committed Dec 12, 2023
1 parent b3840d2 commit f3fa023
Showing 1 changed file with 49 additions and 43 deletions.
92 changes: 49 additions & 43 deletions python/hidet/runtime/compiled_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
from dataclasses import dataclass
import warnings
import tempfile

from tabulate import tabulate
import numpy
Expand Down Expand Up @@ -389,49 +390,54 @@ def save_compiled_graph(model: CompiledGraph, path: str, save_dispatch_table: bo
dirname = os.path.dirname(path)
os.makedirs(dirname, exist_ok=True)

with zipfile.ZipFile(path, 'w') as zf:

def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = None):
for root, _, files in os.walk(dir_path):
for file in files:
file_path = os.path.join(root, file)
file_in_zip = os.path.join(dir_in_zip, os.path.relpath(file_path, dir_path))
with zf.open(file_in_zip, 'w') as f1:
if exclude and file in exclude:
continue
with open(file_path, 'rb') as f2:
f1.write(f2.read())

# meta info
with zf.open('meta.json', 'w') as f:
meta_bytes = json.dumps(asdict(model.meta), indent=4).encode('utf-8')
f.write(meta_bytes)

# save the modules
_save_under(model.graph_module.module_dir, 'graph_module/')

# save weights
with zf.open('weights.npz', 'w', force_zip64=True) as f: # force_zip64 is required for >4GB weights
numpy.savez(f, *[weight.cpu().numpy() for weight in model.weights])

# save the kernels (i.e., compiled tasks)
for i, compiled_task in enumerate(model.compiled_tasks):
_save_under(compiled_task.task_dir, 'kernels/{}/'.format(i))

# save graph execution
with zf.open('graph_execution.json', 'w') as f:
ge_bytes = json.dumps(asdict(model.graph_execution), indent=4).encode('utf-8')
f.write(ge_bytes)

# save dispatch table file
if save_dispatch_table and os.path.exists(model.dispatch_table_path):
with zf.open('dispatch_table.txt', 'w') as f:
with open(model.dispatch_table_path, 'rb') as f2:
f.write(f2.read())

# save graph string
with zf.open('graph_string.txt', 'w') as f:
f.write(model.graph_string.encode('utf-8'))
with tempfile.NamedTemporaryFile(dir=dirname, delete=False) as temp_file:
temp_path = temp_file.name

with zipfile.ZipFile(temp_path, 'w') as zf:

def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = None):
for root, _, files in os.walk(dir_path):
for file in files:
file_path = os.path.join(root, file)
file_in_zip = os.path.join(dir_in_zip, os.path.relpath(file_path, dir_path))
with zf.open(file_in_zip, 'w') as f1:
if exclude and file in exclude:
continue
with open(file_path, 'rb') as f2:
f1.write(f2.read())

# meta info
with zf.open('meta.json', 'w') as f:
meta_bytes = json.dumps(asdict(model.meta), indent=4).encode('utf-8')
f.write(meta_bytes)

# save the modules
_save_under(model.graph_module.module_dir, 'graph_module/')

# save weights
with zf.open('weights.npz', 'w', force_zip64=True) as f: # force_zip64 is required for >4GB weights
numpy.savez(f, *[weight.cpu().numpy() for weight in model.weights])

# save the kernels (i.e., compiled tasks)
for i, compiled_task in enumerate(model.compiled_tasks):
_save_under(compiled_task.task_dir, 'kernels/{}/'.format(i))

# save graph execution
with zf.open('graph_execution.json', 'w') as f:
ge_bytes = json.dumps(asdict(model.graph_execution), indent=4).encode('utf-8')
f.write(ge_bytes)

# save dispatch table file
if save_dispatch_table and os.path.exists(model.dispatch_table_path):
with zf.open('dispatch_table.txt', 'w') as f:
with open(model.dispatch_table_path, 'rb') as f2:
f.write(f2.read())

# save graph string
with zf.open('graph_string.txt', 'w') as f:
f.write(model.graph_string.encode('utf-8'))

os.rename(temp_path, path)


def load_compiled_graph(path: str) -> CompiledGraph:
Expand Down

0 comments on commit f3fa023

Please sign in to comment.