Skip to content

Commit

Permalink
enable packing as arg
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey committed Aug 13, 2024
1 parent 1bf1d1e commit db5e418
Show file tree
Hide file tree
Showing 14 changed files with 42 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"lm_type": "mistral-qlora-zero-shot",
"lm_type": "mistral-qlora-zero-shot-packing",
"run_name": "n500_contamination",
"dataset_names": [
"ag_news"
Expand Down
2 changes: 1 addition & 1 deletion analysis/contamination/saved_models_in_hf/experiment.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"lm_type": "mistral-qlora-zero-shot",
"lm_type": "mistral-qlora-zero-shot-packing",
"run_name": "n500_contamination",
"dataset_names": [
"ag_news"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
TQDM_DISABLE=1 python run.py \
--lm_type mistral-qlora-zero-shot \
--lm_type mistral-qlora-zero-shot-packing \
--run_name n100_mistral-qlora-zero-shot-packing_1 \
--dataset_names \
classla/FRENK-hate-en \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
TQDM_DISABLE=1 python run.py \
--lm_type mistral-qlora-zero-shot \
--lm_type mistral-qlora-zero-shot-packing \
--run_name n100_mistral-qlora-zero-shot-packing_2 \
--dataset_names \
ag_news \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
TQDM_DISABLE=1 python run.py \
--lm_type mistral-qlora-zero-shot \
--lm_type mistral-qlora-zero-shot-packing \
--run_name n100_mistral-qlora-zero-shot-packing_3 \
--dataset_names \
app_reviews \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
TQDM_DISABLE=1 python run.py \
--lm_type mistral-qlora-zero-shot \
--lm_type mistral-qlora-zero-shot-packing \
--run_name n100_mistral-qlora-zero-shot-packing_4 \
--dataset_names \
disaster_response_messages \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
TQDM_DISABLE=1 python run.py \
--lm_type mistral-qlora-zero-shot \
--lm_type mistral-qlora-zero-shot-packing \
--run_name n100_mistral-qlora-zero-shot-packing_5 \
--dataset_names \
dair-ai/emotion \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
TQDM_DISABLE=1 python run.py \
--lm_type mistral-qlora-zero-shot \
--lm_type mistral-qlora-zero-shot-packing \
--run_name n100_mistral-qlora-zero-shot-packing_6 \
--dataset_names \
ccdv/patent-classification \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
TQDM_DISABLE=1 python run.py \
--lm_type mistral-qlora-zero-shot \
--lm_type mistral-qlora-zero-shot-packing \
--run_name n100_mistral-qlora-zero-shot-packing_7 \
--dataset_names \
aladar/emo \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
TQDM_DISABLE=1 python run.py \
--lm_type mistral-qlora-zero-shot \
--lm_type mistral-qlora-zero-shot-packing \
--run_name n100_mistral-qlora-zero-shot-packing_8 \
--dataset_names \
movie_rationales \
Expand Down
25 changes: 25 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
"bert", # section 7 of the paper
"gpt2", # section 7 and 8
"mistral-qlora-zero-shot", # section 9
"mistral-qlora-zero-shot-packing", # section 9.1
"mistral-qlora-sft",
# For quick CPU tests
"bert-tiny",
"gpt2-tiny",
"mistral-lora-zero-shot-tiny",
"mistral-lora-zero-shot-packing-tiny",
"mistral-lora-sft-tiny",
"mistral-instruct-lora-sft-tiny",
]
Expand Down Expand Up @@ -62,6 +64,17 @@
**model_independent_kwargs,
),
"mistral-qlora-zero-shot": lambda **model_independent_kwargs: pretrain_on_test.Config(
model_id="mistralai/Mistral-7B-v0.3",
requires_hf_login=True,
model_class_pretrain=AutoModelForCausalLM,
pretrain_method="instructions-with-text",
lora_pretrain=True,
qlora=True,
classification_method="zero-shot",
max_length=512,
**model_independent_kwargs,
),
"mistral-qlora-zero-shot-packing": lambda **model_independent_kwargs: pretrain_on_test.Config(
model_id="mistralai/Mistral-7B-v0.3",
requires_hf_login=True,
model_class_pretrain=AutoModelForCausalLM,
Expand All @@ -70,6 +83,7 @@
qlora=True,
classification_method="zero-shot",
max_length=8192,
pack=True,
**model_independent_kwargs,
),
"mistral-qlora-sft": lambda **model_independent_kwargs: pretrain_on_test.Config(
Expand Down Expand Up @@ -117,6 +131,17 @@
max_length=512,
**model_independent_kwargs,
),
"mistral-lora-zero-shot-packing-tiny": lambda **model_independent_kwargs: pretrain_on_test.Config(
model_id="hf-internal-testing/tiny-random-MistralForCausalLM",
model_class_pretrain=AutoModelForCausalLM,
pretrain_method="instructions-with-text",
lora_pretrain=True,
classification_method="zero-shot",
lora_classification=True,
max_length=8192,
pack=True,
**model_independent_kwargs,
),
"mistral-lora-sft-tiny": lambda **model_independent_kwargs: pretrain_on_test.Config(
model_id="hf-internal-testing/tiny-random-MistralForCausalLM",
model_class_pretrain=AutoModelForCausalLM,
Expand Down
5 changes: 5 additions & 0 deletions src/pretrain_on_test/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,8 @@ def __post_init__(self):
if self.max_length is None: # be explicit about the default
default_max_length = self.tokenizer.model_max_length
object.__setattr__(self, "max_length", default_max_length)
if self.pack and self.pretrain_method != "instructions-with-text":
raise ValueError(
"Currently, packing is only enabled for pretraining on instructions "
"with text"
)
1 change: 1 addition & 0 deletions src/pretrain_on_test/_dum.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _formatter(


def _formatter_nothing(texts: list[str], *args, **kwargs):
breakpoint()
return texts


Expand Down
1 change: 1 addition & 0 deletions src/pretrain_on_test/pretrain_for_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@ def train(
chat_text_post_processor=partial(
_dum.chat_text_post_processor, config.tokenizer
),
pack=config.pack,
)
return train_output

0 comments on commit db5e418

Please sign in to comment.