Skip to content

Commit

Permalink
Merge pull request #8 from iamgroot42/michael/default_load_hf
Browse files Browse the repository at this point in the history
bug fix for loading neighbors
  • Loading branch information
iamgroot42 committed Feb 15, 2024
2 parents 4789787 + 19c1535 commit 18f9085
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
2 changes: 1 addition & 1 deletion mimir/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class ExperimentConfig(Serializable):
"Dump data to cache? Exits program after dumping"
load_from_cache: Optional[bool] = False
"""Load data from cache?"""
load_from_hf: Optional[bool] = False
load_from_hf: Optional[bool] = True
"""Load data from HuggingFace?"""
blackbox_attacks: Optional[List[str]] = field(
default_factory=lambda: None
Expand Down
28 changes: 14 additions & 14 deletions mimir/custom_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@

DATASETS = ['writing', 'english', 'german', 'pubmed']

SOURCES_UPLOADED = [
"arxiv",
"dm_mathematics",
"github",
"hackernews",
"pile_cc",
"pubmed_central",
"wikipedia_(en)",
"full_pile",
"c4",
"temporal_arxiv",
"temporal_wiki"
]


def load_pubmed(cache_dir):
data = datasets.load_dataset('pubmed_qa', 'pqa_labeled', split='train', cache_dir=cache_dir)
Expand Down Expand Up @@ -41,20 +55,6 @@ def load_cached(cache_dir,
if not filename.startswith("the_pile"):
raise ValueError(f"HuggingFace data only available for The Pile.")

SOURCES_UPLOADED = [
"arxiv",
"dm_mathematics",
"github",
"hackernews",
"pile_cc",
"pubmed_central",
"wikipedia_(en)",
"full_pile",
"c4",
"temporal_arxiv",
"temporal_wiki"
]

for source in SOURCES_UPLOADED:
# Got a match
if source in filename and filename.startswith(f"the_pile_{source}"):
Expand Down
4 changes: 2 additions & 2 deletions python_scripts/mimir.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,6 @@ def _generate_examples(self, file_path_dict):
yield id, {
"member": json.loads(member),
"nonmember": json.loads(nonmember),
"member_neighbors": json.loads(member_neighbors)[0],
"nonmember_neighbors": json.loads(nonmember_neighbors)[0],
"member_neighbors": json.loads(member_neighbors),
"nonmember_neighbors": json.loads(nonmember_neighbors),
}
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def get_mia_scores(
else None
),
loss=loss,
batch_siz=4,
batch_size=4,
substr_neighbors=substr_neighbors,
)

Expand Down

0 comments on commit 18f9085

Please sign in to comment.