Skip to content

Commit

Permalink
Merge 0da824b into 2fcc50b
Browse files Browse the repository at this point in the history
  • Loading branch information
satyakesav committed Nov 16, 2018
2 parents 2fcc50b + 0da824b commit e985db1
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions autokeras/utils.py
Expand Up @@ -22,34 +22,36 @@ def __init__(self, message):


def ensure_dir(directory):
"""Create directory if it does not exist"""
"""Create directory if it does not exist."""
if not os.path.exists(directory):
os.makedirs(directory)


def ensure_file_dir(path):
"""Create path if it does not exist"""
"""Create path if it does not exist."""
ensure_dir(os.path.dirname(path))


def has_file(path):
"""Check if the given path exists."""
return os.path.exists(path)


def pickle_from_file(path):
"""Load the pickle file from the provided path and returns the object."""
return pickle.load(open(path, 'rb'))


def pickle_to_file(obj, path):
"""Save the pickle file to the specified path."""
pickle.dump(obj, open(path, 'wb'))


def get_device():
""" If Cuda is available, use Cuda device, else use CPU device
When choosing from Cuda devices, this function will choose the one with max memory available
Returns: string device name
""" If CUDA is available, use CUDA device, else use CPU device.
When choosing from CUDA devices, this function will choose the one with max memory available.
Returns: string device name.
"""
# TODO: could use gputil in the future
device = 'cpu'
Expand Down Expand Up @@ -82,13 +84,15 @@ def get_device():


def temp_folder_generator():
"""Create and return a temporary directory with the path name '/temp_dir_name/autokeras' (E:g:- /tmp/autokeras)."""
sys_temp = tempfile.gettempdir()
path = os.path.join(sys_temp, 'autokeras')
ensure_dir(path)
return path


def download_file(file_link, file_path):
"""Download the file specified in `file_link` and saves it in `file_path`."""
if not os.path.exists(file_path):
with open(file_path, "wb") as f:
print("Downloading %s" % file_path)
Expand All @@ -109,6 +113,7 @@ def download_file(file_link, file_path):


def download_file_with_extract(file_link, file_path, extract_path):
"""Download the file specified in `file_link`, save to `file_path` and extract to the directory `extract_path`."""
if not os.path.exists(extract_path):
download_file(file_link, file_path)
zip_ref = zipfile.ZipFile(file_path, 'r')
Expand All @@ -120,6 +125,7 @@ def download_file_with_extract(file_link, file_path, extract_path):


def verbose_print(new_father_id, new_graph):
"""Print information about the operation performed on father model to obtain current model and father's id."""
cell_size = [24, 49]
header = ['Father Model ID', 'Added Operation']
line = '|'.join(str(x).center(cell_size[i]) for i, x in enumerate(header))
Expand All @@ -137,7 +143,7 @@ def verbose_print(new_father_id, new_graph):


def validate_xy(x_train, y_train):
"""Check `x_train`'s type and the shape of `x_train`, `y_train`."""
"""Validate `x_train`'s type and the shape of `x_train`, `y_train`."""
try:
x_train = x_train.astype('float64')
except ValueError:
Expand All @@ -151,7 +157,7 @@ def validate_xy(x_train, y_train):


def read_csv_file(csv_file_path):
"""Read the csv file and returns two separate list containing files name and their labels.
"""Read the csv file and returns two separate list containing file names and their labels.
Args:
csv_file_path: Path to the CSV file.
Expand All @@ -172,6 +178,7 @@ def read_csv_file(csv_file_path):


def read_image(img_path):
"""Read the image contained in the provided path `image_path`."""
img = imageio.imread(uri=img_path)
return img

Expand Down Expand Up @@ -226,14 +233,12 @@ def resize_image_data(data, h, w):


def get_system():
"""
Get the current system environment. If the current system is not supported,
raise an exception.
"""Get the current system environment. If the current system is not supported, raise an exception.
Returns:
a string to represent the current os name
posix stands for Linux and Mac or Solaris architecture
nt stands for Windows system
A string to represent the current OS name.
"posix" stands for Linux, Mac or Solaris architecture.
"nt" stands for Windows system.
"""
print(os.name)
if 'google.colab' in sys.modules:
Expand Down

0 comments on commit e985db1

Please sign in to comment.