Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bug in EthosBinary dataset class and model directory copying logic in RayTuneReportCallback #1129

Merged
merged 13 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ludwig/datasets/agnews/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class AGNews(UncompressedFileDownloadMixin, MultifileJoinProcessMixin,
def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
super().__init__(dataset_name="agnews", cache_dir=cache_dir)

def read_file(self, filetype, filename):
def read_file(self, filetype, filename, header=0):
file_df = pd.read_csv(
os.path.join(self.raw_dataset_path, filename))
# class_index : number between 1-4 where
Expand Down
16 changes: 11 additions & 5 deletions ludwig/datasets/ethos_binary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ class EthosBinary(UncompressedFileDownloadMixin, IdentityProcessMixin,
def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
super().__init__(dataset_name="ethos_binary", cache_dir=cache_dir)

def load_processed_dataset(self, split):
dataset_csv = os.path.join(self.processed_dataset_path,
self.csv_filename)
data_df = pd.read_csv(dataset_csv, sep=';')
return data_df
def process_downloaded_dataset(self):
super(EthosBinary, self).process_downloaded_dataset()
# replace ; sperator to ,
processed_df = pd.read_csv(os.path.join(self.processed_dataset_path,
self.csv_filename), sep=";")
# convert float labels (0.0, 1.0) to binary labels
processed_df['isHate'] = processed_df['isHate'].astype(int)
processed_df.to_csv(
os.path.join(self.processed_dataset_path, self.csv_filename),
index=False, sep=","
)
2 changes: 1 addition & 1 deletion ludwig/datasets/mixins/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class MultifileJoinProcessMixin:
raw_dataset_path: str
processed_dataset_path: str

def read_file(self, filetype, filename, header):
def read_file(self, filetype, filename, header=0):
if filetype == 'json':
file_df = pd.read_json(
os.path.join(self.raw_dataset_path, filename))
Expand Down
2 changes: 1 addition & 1 deletion ludwig/datasets/yahoo_answers/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ split_filenames:
train_file: train.csv
test_file: test.csv
download_file_type: csv
csv_filename: yelp_answers.csv
csv_filename: yahoo_answers.csv
13 changes: 12 additions & 1 deletion ludwig/hyperopt/execution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import uuid
import copy
import json
import multiprocessing
Expand Down Expand Up @@ -722,7 +723,17 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
if trainer.is_coordinator():
with tune.checkpoint_dir(step=progress_tracker.epoch) as checkpoint_dir:
checkpoint_model = os.path.join(checkpoint_dir, 'model')
shutil.copytree(save_path, checkpoint_model)
# shutil.copytree(save_path, checkpoint_model)
# Note: A previous implementation used shutil.copytree()
# however, this copying method is non atomic
if not os.path.isdir(checkpoint_model):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment here about the fact that the previous implementation was copytree (you can leave that line commented) but that that's not atomic and so we have to do this temp + move thing, otherwise in the future we may forget about it and maybe replace it with copytree again :)

Copy link
Collaborator Author

@ANarayan ANarayan Mar 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great feedback -- addressed in this commit here: c924b9f

copy_id = uuid.uuid4()
tmp_dst = "%s.%s.tmp" % (checkpoint_model, copy_id)
shutil.copytree(save_path, tmp_dst)
try:
os.rename(tmp_dst, checkpoint_model)
except:
shutil.rmtree(tmp_dst)

train_stats, eval_stats = progress_tracker.train_metrics, progress_tracker.vali_metrics
stats = eval_stats or train_stats
Expand Down