-
Notifications
You must be signed in to change notification settings - Fork 3
/
io.py
202 lines (168 loc) · 7.68 KB
/
io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import io
import logging
import os
from typing import Sequence, Optional, Tuple, List, Any
import matplotlib.figure
from matplotlib import pyplot as plt
import pandas as pd
log = logging.getLogger(__name__)
class ResultWriter:
log = log.getChild(__qualname__)
def __init__(self, result_dir, filename_prefix="", enabled: bool = True, close_figures: bool = False):
"""
:param result_dir:
:param filename_prefix:
:param enabled: whether the result writer is enabled; if it is not, it will create neither files nor directories
:param close_figures: whether to close figures that are passed by default
"""
self.result_dir = result_dir
self.filename_prefix = filename_prefix
self.enabled = enabled
self.close_figures_default = close_figures
if self.enabled:
os.makedirs(result_dir, exist_ok=True)
def child_with_added_prefix(self, prefix: str) -> "ResultWriter":
"""
Creates a derived result writer with an added prefix, i.e. the given prefix is appended to this
result writer's prefix
:param prefix: the prefix to append
:return: a new writer instance
"""
return ResultWriter(self.result_dir, filename_prefix=self.filename_prefix + prefix, enabled=self.enabled,
close_figures=self.close_figures_default)
def child_for_subdirectory(self, dir_name: str) -> "ResultWriter":
result_dir = os.path.join(self.result_dir, dir_name)
return ResultWriter(result_dir, filename_prefix=self.filename_prefix, enabled=self.enabled,
close_figures=self.close_figures_default)
def path(self, filename_suffix: str, extension_to_add=None, valid_other_extensions: Optional[Sequence[str]] = None) -> str:
"""
:param filename_suffix: the suffix to add (which may or may not already include a file extension)
:param extension_to_add: if not None, the file extension to add (without the leading ".") unless
the extension to add or one of the extenions in valid_extensions is already present
:param valid_other_extensions: a sequence of valid other extensions (without the "."), only
relevant if extensionToAdd is specified
:return: the full path
"""
if extension_to_add is not None:
add_ext = True
valid_extensions = set(valid_other_extensions) if valid_other_extensions is not None else set()
valid_extensions.add(extension_to_add)
if valid_extensions is not None:
for ext in valid_extensions:
if filename_suffix.endswith("." + ext):
add_ext = False
break
if add_ext:
filename_suffix += "." + extension_to_add
path = os.path.join(self.result_dir, f"{self.filename_prefix}{filename_suffix}")
return path
def write_text_file(self, filename_suffix: str, content: str):
p = self.path(filename_suffix, extension_to_add="txt")
if self.enabled:
self.log.info(f"Saving text file {p}")
with open(p, "w") as f:
f.write(content)
return p
def write_text_file_lines(self, filename_suffix: str, lines: List[str]):
p = self.path(filename_suffix, extension_to_add="txt")
if self.enabled:
self.log.info(f"Saving text file {p}")
write_text_file_lines(lines, p)
return p
def write_data_frame_text_file(self, filename_suffix: str, df: pd.DataFrame):
p = self.path(filename_suffix, extension_to_add="df.txt", valid_other_extensions="txt")
if self.enabled:
self.log.info(f"Saving data frame text file {p}")
with open(p, "w") as f:
f.write(df.to_string())
return p
def write_data_frame_csv_file(self, filename_suffix: str, df: pd.DataFrame, index=True, header=True):
p = self.path(filename_suffix, extension_to_add="csv")
if self.enabled:
self.log.info(f"Saving data frame CSV file {p}")
df.to_csv(p, index=index, header=header)
return p
def write_figure(self, filename_suffix: str, fig: plt.Figure, close_figure: Optional[bool] = None):
"""
:param filename_suffix: the filename suffix, which may or may not include a file extension, valid extensions being {"png", "jpg"}
:param fig: the figure to save
:param close_figure: whether to close the figure after having saved it; if None, use default passed at construction
:return: the path to the file that was written (or would have been written if the writer was enabled)
"""
p = self.path(filename_suffix, extension_to_add="png", valid_other_extensions=("jpg",))
if self.enabled:
self.log.info(f"Saving figure {p}")
fig.savefig(p, bbox_inches="tight")
must_close_figure = close_figure if close_figure is not None else self.close_figures_default
if must_close_figure:
plt.close(fig)
return p
def write_figures(self, figures: Sequence[Tuple[str, matplotlib.figure.Figure]], close_figures=False):
for name, fig in figures:
self.write_figure(name, fig, close_figure=close_figures)
def write_pickle(self, filename_suffix: str, obj: Any):
from .pickle import dump_pickle
p = self.path(filename_suffix, extension_to_add="pickle")
if self.enabled:
self.log.info(f"Saving pickle {p}")
dump_pickle(obj, p)
return p
def write_text_file_lines(lines: List[str], path):
"""
:param lines: the lines to write (without a trailing newline, which will be added)
:param path: the path of the text file to write to
"""
with open(path, "w") as f:
for line in lines:
f.write(line)
f.write("\n")
def read_text_file_lines(path, strip=True, skip_empty=True) -> List[str]:
"""
:param path: the path of the text file to read from
:param strip: whether to strip each line, removing whitespace/newline characters
:param skip_empty: whether to skip any lines that are empty (after stripping)
:return: the list of lines
"""
lines = []
with open(path, "r") as f:
for line in f.readlines():
if strip:
line = line.strip()
if not skip_empty or line != "":
lines.append(line)
return lines
def is_s3_path(path: str):
return path.startswith("s3://")
class S3Object:
def __init__(self, path):
assert is_s3_path(path)
self.path = path
self.bucket, self.object = self.path[5:].split("/", 1)
class OutputFile:
def __init__(self, s3_object: "S3Object"):
self.s3Object = s3_object
self.buffer = io.BytesIO()
def write(self, obj: bytes):
self.buffer.write(obj)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.s3Object.put(self.buffer.getvalue())
def get_file_content(self):
return self._get_s3_object().get()['Body'].read()
def open_file(self, mode):
assert mode in ("wb", "rb")
if mode == "rb":
content = self.get_file_content()
return io.BytesIO(content)
elif mode == "wb":
return self.OutputFile(self)
else:
raise ValueError(mode)
def put(self, obj: bytes):
self._get_s3_object().put(Body=obj)
def _get_s3_object(self):
import boto3
session = boto3.session.Session(profile_name=os.getenv("AWS_PROFILE"))
s3 = session.resource("s3")
return s3.Bucket(self.bucket).Object(self.object)