Load data from HF
iamgroot42 committed Feb 8, 2024
1 parent 7e9c580 commit b136dfc
Showing 3 changed files with 27 additions and 5 deletions.
30 changes: 25 additions & 5 deletions mimir/custom_datasets.py
@@ -22,18 +22,38 @@ def load_pubmed(cache_dir):
return data


def load_cached(cache_dir, path: str, filename: str, min_length: int, max_length: int, n_samples: int, max_tokens: int):
def load_cached(cache_dir, data_split: str, filename: str, min_length: int,
                max_length: int, n_samples: int, max_tokens: int,
                load_from_hf: bool = False):
    """
    Read from cache if available. Used for certain pile sources and xsum
    to ensure fairness in comparison across attacks/runs.
    """
    file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", path, filename + ".jsonl")
    if not os.path.exists(file_path):
        raise ValueError(f"Requested cache file {file_path} does not exist")
    data = load_data(file_path)
    if load_from_hf:
        print("Loading from HuggingFace!")
        data_split = data_split.replace("train", "member")
        data_split = data_split.replace("test", "nonmember")
        ds = datasets.load_dataset("iamgroot42/mimir", name=filename, split=data_split)
        data = collect_hf_data(ds)
        if len(data) != n_samples:
            raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
    else:
        file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", data_split, filename + ".jsonl")
        if not os.path.exists(file_path):
            raise ValueError(f"Requested cache file {file_path} does not exist")
        data = load_data(file_path)
    return data


def collect_hf_data(ds):
    records = [x["text"] for x in ds]
    # Standard DS
    if len(records[0]) == 1:
        records = [x[0] for x in records]
    # Neighbor data
    return records


def load_data(file_path):
    with open(file_path, 'r') as f:
        data = [json.loads(line) for line in f.readlines()]
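For context, a minimal sketch of how the reworked loader might be invoked with the new flag. The concrete values below (cache directory, source name, lengths, sample count) are illustrative assumptions, not taken from this commit:

    # Illustrative call (assumed values): pull the member split of a "pile_cc"
    # source straight from HuggingFace instead of the local JSONL cache.
    data = load_cached(
        cache_dir="cache",       # ignored on the HF path
        data_split="train",      # rewritten to "member" inside load_cached
        filename="pile_cc",      # also used as the HF config name
        min_length=100,
        max_length=200,
        n_samples=1000,
        max_tokens=512,
        load_from_hf=True,
    )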
2 changes: 2 additions & 0 deletions mimir/data_utils.py
@@ -51,6 +51,7 @@ def load_neighbors(
            max_length=self.config.max_words,
            n_samples=self.config.n_samples,
            max_tokens=self.config.max_tokens,
            load_from_hf=self.config.load_from_hf
        )
        return data

@@ -102,6 +103,7 @@ def load(self, train: bool, mask_tokenizer=None, specific_source: str = None):
                max_length=self.config.max_words,
                n_samples=self.config.n_samples,
                max_tokens=self.config.max_tokens,
                load_from_hf=self.config.load_from_hf
            )
            return data
        else:
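To make the flag's effect concrete, here is a minimal sketch of what the HuggingFace path boils down to once self.config.load_from_hf is set. The source name "pile_cc" is an assumed example; the split mapping mirrors the load_cached code above:

    # Assumed example: with load_from_hf enabled, a "train" request for the
    # "pile_cc" source becomes a direct HuggingFace load of its "member" split.
    import datasets

    split = "train".replace("train", "member")   # -> "member"
    ds = datasets.load_dataset("iamgroot42/mimir", name="pile_cc", split=split)
    texts = [x["text"] for x in ds]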
File renamed without changes.
