temp testing

kddubey committed Aug 13, 2024
1 parent 2d8910c commit da18962
Showing 2 changed files with 43 additions and 19 deletions.
4 changes: 2 additions & 2 deletions experiment_mini.sh
@@ -26,8 +26,8 @@ python run.py \
 --per_device_eval_batch_size_classification 4
 
 # python run.py \
-# --lm_type mistral-lora-sft-tiny \
-# --run_name cpu-test-mistral-lora-sft-tiny \
+# --lm_type mistral-lora-zero-shot-tiny \
+# --run_name cpu-test-mistral-lora-zero-shot-tiny \
 # --dataset_names ag_news SetFit/amazon_counterfactual_en \
 # --num_subsamples 2 \
 # --num_train 10 \
58 changes: 41 additions & 17 deletions src/pretrain_on_test/_dum.py
@@ -75,14 +75,17 @@ class _Message(TypedDict):
     content: str
 
 
+Chat = list[_Message]
+
+
 def _get_apply_chat_template(
     tokenizer: PreTrainedTokenizerBase,
-) -> Callable[[list[_Message]], str]:
+) -> Callable[[Chat], str]:
     if tokenizer.chat_template is not None:
         return partial(tokenizer.apply_chat_template, tokenize=False)
     else:
 
-        def join_content(messages: list[_Message]) -> str:
+        def join_content(messages: Chat) -> str:
             return "".join(message["content"] for message in messages)
 
         return join_content
@@ -94,10 +97,10 @@ def _create_chats(
     class_names_unique: tuple[str, ...],
     task_description: str,
     system_role: str | None = "system",
-) -> list[list[_Message]]:
+) -> list[Chat]:
     instruction = _instruction_formatter(class_names_unique, task_description)
 
-    def instruction_and_query(text: str) -> list[_Message]:
+    def instruction_and_query(text: str) -> Chat:
         if system_role is None:
             return [{"role": "user", "content": instruction + _query_formatter(text)}]
         else:
@@ -135,7 +138,7 @@ def chat_text_post_processor(tokenizer: PreTrainedTokenizerBase, chat_text: str)


 def _formatter(
-    chats: list[list[_Message]],
+    chats: list[Chat],
     tokenizer: PreTrainedTokenizerBase,
     chat_text_post_processor: Callable[[str], str] | None = None,
 ) -> list[str]:
@@ -146,6 +149,10 @@
     return chat_texts
 
 
+def _formatter_nothing(texts: list[str], *args, **kwargs):
+    return texts
+
+
 def _model_name(tokenizer: PreTrainedTokenizerBase) -> str:
     model_name = cast(str, tokenizer.name_or_path)
     if model_name.startswith("_"):
@@ -203,6 +210,10 @@ def load_model(
     return AutoModelForCausalLM.from_pretrained(**loading_kwargs)
 
 
+def _batch(texts: list[str], batch_size: int) -> list[list[str]]:
+    return [texts[i : (i + batch_size)] for i in range(0, len(texts), batch_size)]
+
+
 def train(
     texts: list[str],
     class_names: list[str],
@@ -222,21 +233,32 @@ def train(
     is_pretrained_fresh: bool = False,
     device_map: str = "auto",
     chat_text_post_processor: Callable[[str], str] | None = None,
+    pack: bool = True,
 ) -> tuple[tuple[PreTrainedModel, PreTrainedTokenizerBase], TrainOutput]:
     """
     Returns a finetuned model and its tokenizer.
     """
-    dataset = Dataset.from_dict(
-        {
-            "chat": _create_chats(
-                texts,
-                class_names,
-                class_names_unique,
-                task_description,
-                system_role=_system_role(tokenizer),
-            )
-        }
-    )
+    if pack:
+        dataset = Dataset.from_dict(
+            {
+                "chat": [
+                    "\n\n".join([_query_formatter(text) for text in batch])
+                    for batch in _batch(texts, per_device_train_batch_size)
+                ]
+            }
+        )
+    else:
+        dataset = Dataset.from_dict(
+            {
+                "chat": _create_chats(
+                    texts,
+                    class_names,
+                    class_names_unique,
+                    task_description,
+                    system_role=_system_role(tokenizer),
+                )
+            }
+        )
 
     # Set up model
     model = load_model(
@@ -269,6 +291,8 @@
     tokenizer.padding_side = "right"
 
     # Set up trainer. The data_collator defines the objective.
+    formatter = _formatter_nothing if pack else _formatter
+    per_device_train_batch_size = 1 if pack else per_device_train_batch_size
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
@@ -285,7 +309,7 @@
             disable_tqdm=False,
         ),
         train_dataset=dataset,
-        formatting_func=lambda batch: _formatter(
+        formatting_func=lambda batch: formatter(
             batch["chat"],
             tokenizer,
             chat_text_post_processor=chat_text_post_processor,
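For readers skimming the diff, here is a minimal sketch of what the new pack=True path does to the training data. The _query_formatter stand-in below is hypothetical (the real formatter lives elsewhere in _dum.py and may produce different text); _batch mirrors the helper added in this commit.

def _batch(texts: list[str], batch_size: int) -> list[list[str]]:
    # Same as the helper added above: consecutive chunks of size batch_size.
    return [texts[i : (i + batch_size)] for i in range(0, len(texts), batch_size)]

def _query_formatter(text: str) -> str:
    # Hypothetical stand-in; the real implementation in _dum.py may differ.
    return f"Text: {text}"

texts = ["great phone", "terrible battery", "screen is fine"]
per_device_train_batch_size = 2

# With pack=True, each dataset row is a group of texts joined into one string:
packed = [
    "\n\n".join([_query_formatter(text) for text in batch])
    for batch in _batch(texts, per_device_train_batch_size)
]
# packed == ["Text: great phone\n\nText: terrible battery", "Text: screen is fine"]

The trainer then uses _formatter_nothing as a pass-through and runs with per_device_train_batch_size = 1, so each optimizer step presumably still covers the same number of texts, just as one packed sequence rather than a padded batch of separate chats.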
