Skip to content

Commit

Permalink
use AfterValidator
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey committed Aug 20, 2024
1 parent 1de6c66 commit c8d231e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 35 deletions.
70 changes: 37 additions & 33 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import os
from datetime import datetime
from functools import partial
from typing import Any, Callable, Collection, Literal
from typing import Annotated, Any, Callable, Collection, Literal

from pydantic import BaseModel, ConfigDict, Field
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
from tap import tapify
import torch
from transformers import AutoModelForCausalLM, BertForMaskedLM, GPT2LMHeadModel
Expand All @@ -29,7 +29,7 @@
"gpt2", # section 7 and 8
"mistral-qlora-zero-shot", # section 9
"mistral-qlora-zero-shot-packing", # section 9.1
"mistral-qlora-sft",
"mistral-qlora-sft", # causes OOMs. Maybe it's b/c merge_and_unload dequantizes?
# For quick CPU tests
"bert-tiny",
"gpt2-tiny",
Expand Down Expand Up @@ -165,6 +165,29 @@
}


def _check_dataset_names(dataset_names: Collection[str] | None) -> list[str]:
if dataset_names is None:
dataset_names = list(
pretrain_on_test.data.hf_dataset_name_to_classification_dataset_info.keys()
)

def remove_owner(dataset_name: str) -> str:
return dataset_name.split("/")[-1]

dataset_names_without_owners = [
remove_owner(dataset_name) for dataset_name in dataset_names
]
if len(set(dataset_names_without_owners)) < len(dataset_names_without_owners):
raise ValueError(
"Some datasets have the same name. They may have different owners. But "
"that's still not allowed."
)
return sorted(dataset_names, key=remove_owner)


DatasetNames = Annotated[list[str] | None, AfterValidator(_check_dataset_names)]


_field_for_config = partial(Field, json_schema_extra={"is_for_config": True})


Expand All @@ -179,7 +202,7 @@ class Experiment(BaseModel):
lm_type: LMType = Field(
description=(
"Type of language model. *-tiny models have random weights and should only "
"be used for testing."
"be used for testing"
)
)
run_name: str = Field(
Expand All @@ -189,7 +212,7 @@ class Experiment(BaseModel):
"this name gets appended to the run ID string: run-{timestamp}-{run_name}"
),
)
dataset_names: list[str] | None = Field(
dataset_names: DatasetNames = Field(
default=None,
description=(
"Space-separated list of HuggingFace datasets, e.g., "
Expand All @@ -201,11 +224,16 @@ class Experiment(BaseModel):
default=50, description="Number of subsamples to draw from the dataset"
)
num_train: int = Field(
default=100, description="Number of observations for classification training"
default=100,
description=(
"Number of observations for classification training, i.e., m in the paper"
),
)
num_test: int = Field(
default=200,
description="Number of observations for pretraining and for evaluation",
description=(
"Number of observations for pretraining and eval, i.e., n in the paper"
),
)
# Model-independent arguments which are passed to the config
per_device_train_batch_size_pretrain: int = _field_for_config(
Expand Down Expand Up @@ -234,26 +262,6 @@ class Experiment(BaseModel):
)


def _check_dataset_names(dataset_names: Collection[str] | None) -> list[str]:
if dataset_names is None:
dataset_names = list(
pretrain_on_test.data.hf_dataset_name_to_classification_dataset_info.keys()
)

def remove_owner(dataset_name: str) -> str:
return dataset_name.split("/")[-1]

dataset_names_without_owners = [
remove_owner(dataset_name) for dataset_name in dataset_names
]
if len(set(dataset_names_without_owners)) < len(dataset_names_without_owners):
raise ValueError(
"Some datasets have the same name. (They may have different owners. "
"But that's still not allowed.)"
)
return sorted(dataset_names, key=remove_owner)


def run(
experiment: Experiment,
create_logger: cloud.CreateLogger = cloud.create_logger_local,
Expand Down Expand Up @@ -312,8 +320,7 @@ def run(
)

# Upload experiment settings
if not os.path.exists(run_id):
os.makedirs(run_id)
os.makedirs(run_id)
with open(os.path.join(run_id, "experiment.json"), "w") as json_file:
experiment_as_dict = experiment.model_dump()
json.dump(experiment_as_dict, json_file, indent=4)
Expand All @@ -334,13 +341,10 @@ def run(
**model_independent_kwargs
)

# Check that the dataset names don't conflict w/ each other
dataset_names = _check_dataset_names(experiment.dataset_names)

# Run experiment on each dataset
_ = torch.manual_seed(123)
torch.cuda.manual_seed_all(123)
for dataset_name in dataset_names:
for dataset_name in experiment.dataset_names:
classification_dataset_info = (
pretrain_on_test.data.hf_dataset_name_to_classification_dataset_info[
dataset_name
Expand Down
3 changes: 1 addition & 2 deletions src/pretrain_on_test/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from typing import Annotated, Callable

from datasets import load_dataset
from pydantic import Field, BaseModel, ConfigDict
from pydantic.functional_validators import AfterValidator
from pydantic import AfterValidator, Field, BaseModel, ConfigDict
import numpy as np
import pandas as pd

Expand Down

0 comments on commit c8d231e

Please sign in to comment.