Skip to content

Commit

Permalink
Fixed bug in EthosBinary dataset class and model directory copying lo…
Browse files Browse the repository at this point in the history
…gic in RayTuneReportCallback (#1129)
  • Loading branch information
ANarayan committed Mar 25, 2021
1 parent 860fb1f commit e281f74
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ludwig/datasets/agnews/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class AGNews(UncompressedFileDownloadMixin, MultifileJoinProcessMixin,
def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
super().__init__(dataset_name="agnews", cache_dir=cache_dir)

def read_file(self, filetype, filename):
def read_file(self, filetype, filename, header=0):
file_df = pd.read_csv(
os.path.join(self.raw_dataset_path, filename))
# class_index : number between 1-4 where
Expand Down
16 changes: 11 additions & 5 deletions ludwig/datasets/ethos_binary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ class EthosBinary(UncompressedFileDownloadMixin, IdentityProcessMixin,
def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
super().__init__(dataset_name="ethos_binary", cache_dir=cache_dir)

def load_processed_dataset(self, split):
dataset_csv = os.path.join(self.processed_dataset_path,
self.csv_filename)
data_df = pd.read_csv(dataset_csv, sep=';')
return data_df
def process_downloaded_dataset(self):
super(EthosBinary, self).process_downloaded_dataset()
# replace ; sperator to ,
processed_df = pd.read_csv(os.path.join(self.processed_dataset_path,
self.csv_filename), sep=";")
# convert float labels (0.0, 1.0) to binary labels
processed_df['isHate'] = processed_df['isHate'].astype(int)
processed_df.to_csv(
os.path.join(self.processed_dataset_path, self.csv_filename),
index=False, sep=","
)
2 changes: 1 addition & 1 deletion ludwig/datasets/mixins/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class MultifileJoinProcessMixin:
raw_dataset_path: str
processed_dataset_path: str

def read_file(self, filetype, filename, header):
def read_file(self, filetype, filename, header=0):
if filetype == 'json':
file_df = pd.read_json(
os.path.join(self.raw_dataset_path, filename))
Expand Down
2 changes: 1 addition & 1 deletion ludwig/datasets/yahoo_answers/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ split_filenames:
train_file: train.csv
test_file: test.csv
download_file_type: csv
csv_filename: yelp_answers.csv
csv_filename: yahoo_answers.csv
13 changes: 12 additions & 1 deletion ludwig/hyperopt/execution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import uuid
import copy
import json
import multiprocessing
Expand Down Expand Up @@ -722,7 +723,17 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
if trainer.is_coordinator():
with tune.checkpoint_dir(step=progress_tracker.epoch) as checkpoint_dir:
checkpoint_model = os.path.join(checkpoint_dir, 'model')
shutil.copytree(save_path, checkpoint_model)
# shutil.copytree(save_path, checkpoint_model)
# Note: A previous implementation used shutil.copytree()
# however, this copying method is non atomic
if not os.path.isdir(checkpoint_model):
copy_id = uuid.uuid4()
tmp_dst = "%s.%s.tmp" % (checkpoint_model, copy_id)
shutil.copytree(save_path, tmp_dst)
try:
os.rename(tmp_dst, checkpoint_model)
except:
shutil.rmtree(tmp_dst)

train_stats, eval_stats = progress_tracker.train_metrics, progress_tracker.vali_metrics
stats = eval_stats or train_stats
Expand Down

0 comments on commit e281f74

Please sign in to comment.