Commit ec984c6

refactor

Sotirios Anagnostidis committed Feb 11, 2023
1 parent 6d569a5 commit ec984c6
Showing 3 changed files with 10 additions and 83 deletions.
78 changes: 4 additions & 74 deletions model/supervised_finetuning/custom_datasets/dialogue_collator.py
@@ -16,83 +16,13 @@ class DialogueDataCollator:
    Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        flatten_messages = []
        label_masks = []

        for messages in features:
            messages = list(messages)

            # Add a way for the model to terminate generation:
            # when we predict the start of a new expected question, we want to be able to stop generation.
            messages.append(QA_SPECIAL_TOKENS["Question"])

            flatten_message = self.tokenizer(
                "".join(messages),
                truncation=True,
                max_length=self.max_length,
                return_offsets_mapping=True,
            )

            message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
            # For each token, an integer indicating the index of the message it belongs to, used only to build the
            # label mask. The label mask is true when predicting a token that is part of an answer, false otherwise.
            # TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
            # MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
            # LABEL_MASK: 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0

            # If next() finds no matching message boundary, we are predicting the final termination token(s).
            message_indices = list(
                map(
                    lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2),
                    list(map(lambda x: x[1], flatten_message["offset_mapping"])),
                )
            )
            label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
            try:
                label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
            except IndexError:
                # Due to truncation, we might not have the last termination token.
                label_mask[-1] = False

            label_masks.append(label_mask)

            flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})

        batch = self.tokenizer.pad(
            flatten_messages,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        dim = batch["input_ids"].shape[-1]

        batch["label_masks"] = torch.stack(
            [F.pad(torch.tensor(x), (0, dim - len(x)), value=False) for x in label_masks]
        )
        batch["targets"] = torch.roll(batch["input_ids"], -1, -1)

        return batch


@dataclass
class TrainDialogueDataCollator:
    """
    Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    mix_length_threshold: Optional[int] = 256
    mix_probability: Optional[float] = 0.6
    pad_to_multiple_of: Optional[int] = None
    samples_mixing: Optional[bool] = False

    def __call__(self, features):
        flatten_messages = []
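
To make the index arithmetic above concrete, here is a minimal, self-contained sketch of the label-mask construction (the same logic survives in the merged collator). It substitutes a toy whitespace tokenizer for the real PreTrainedTokenizerBase offset mapping; the message strings and the `<question>` terminator are illustrative stand-ins for real data and QA_SPECIAL_TOKENS["Question"].

import numpy as np

# Toy dialogue: [question, answer, question, answer] plus the appended terminator.
messages = ["Q: hi ", "A: hello ", "Q: bye ", "A: bye! ", "<question>"]
text = "".join(messages)

# Stand-in for the tokenizer's offset_mapping: one "token" per whitespace-separated word.
offsets, pos = [], 0
for word in text.split(" "):
    offsets.append((pos, pos + len(word)))
    pos += len(word) + 1

# Character position at which each message (except the terminator) ends.
message_change_indices = np.cumsum([len(m) for m in messages[:-1]])

# For every token, the index of the message its end offset falls into; -2 marks the terminator.
message_indices = [
    next((i for i, val in enumerate(message_change_indices) if val >= end), -2)
    for _, end in offsets
]

# Odd message indices are answers; shift left by one because the mask marks
# positions whose *next* token should contribute to the loss.
label_mask = np.roll([i % 2 == 1 for i in message_indices], -1)
# Also learn to predict the terminator itself from the last answer token.
label_mask[message_indices.index(-2) - 1] = True

for token, idx, masked in zip(text.split(" "), message_indices, label_mask):
    print(f"{token:12} message={idx:3} mask={masked}")
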
@@ -134,15 +64,15 @@ def __call__(self, features):
                label_mask[-1] = False

            label_masks.append(label_mask)
            if len(flatten_message["input_ids"]) < self.mix_length_threshold:
            if len(flatten_message["input_ids"]) < self.mix_length_threshold and self.samples_mixing:
                total_short_context += len(flatten_message["input_ids"])
            flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})
        # packing
        if total_short_context > 2:
        if total_short_context > 2 and self.samples_mixing:
            _flatten_messages, _label_masks = [], []
            prev_short_msg, prev_short_mask = None, None
            for flatten_msg, label_mask in zip(flatten_messages, label_masks):
                if len(flatten_msg["input_ids"]) < self.mix_length_threshold and random.random() > 0.6:
                if len(flatten_msg["input_ids"]) < self.mix_length_threshold and random.random() > self.mix_probability:
                    if prev_short_msg is not None:
                        for key in flatten_msg.keys():
                            flatten_msg[key] += prev_short_msg[key]
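The packing branch above is cut off by the collapsed diff context, so the tail of the loop is not visible. The following is a hedged, self-contained sketch of the idea under stated assumptions: the function name pack_short_samples, the hold-back handling, and the flush at the end are guesses for illustration, not code taken from the commit.

import random

def pack_short_samples(samples, masks, threshold=256, mix_probability=0.6):
    """Occasionally concatenate pairs of short tokenized samples (dicts of lists)
    so that short dialogues fill more of the max_length budget."""
    packed_samples, packed_masks = [], []
    prev_sample, prev_mask = None, None
    for sample, mask in zip(samples, masks):
        short = len(sample["input_ids"]) < threshold
        if short and random.random() > mix_probability:
            if prev_sample is None:
                # Hold this short sample back and try to pair it with a later one.
                prev_sample, prev_mask = sample, mask
                continue
            # Append the held-back sample onto this one, key by key,
            # mirroring `flatten_msg[key] += prev_short_msg[key]` above.
            for key in sample:
                sample[key] = sample[key] + prev_sample[key]
            mask = mask + prev_mask
            prev_sample, prev_mask = None, None
        packed_samples.append(sample)
        packed_masks.append(mask)
    if prev_sample is not None:  # flush an unpaired held-back sample (assumption)
        packed_samples.append(prev_sample)
        packed_masks.append(prev_mask)
    return packed_samples, packed_masks

samples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}, {"input_ids": [6]}]
masks = [[False, True, True], [False, True], [True]]
print(pack_short_samples(samples, masks, threshold=4))
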
4 changes: 2 additions & 2 deletions model/supervised_finetuning/trainer.py
@@ -195,7 +195,7 @@ def argument_parsing(notebook=False, notebook_args=None):

    tokenizer = get_tokenizer(training_conf)
    model = get_model(training_conf, tokenizer)
    train, evals, collate_fn, train_collate_fn = get_dataset(training_conf, tokenizer)
    train, evals, train_collate_fn, eval_collate_fn = get_dataset(training_conf, tokenizer)
    sampler = PerDatasetSampler.build_sampler_from_config(training_conf, train.datasets)
    metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
    optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
@@ -252,7 +252,7 @@ def argument_parsing(notebook=False, notebook_args=None):
        poly_eps=training_conf.poly_eps,
        train_dataset=train,
        eval_dataset=evals,
        data_collator=collate_fn,
        data_collator=eval_collate_fn,
        tokenizer=tokenizer,
        compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
11 changes: 4 additions & 7 deletions model/supervised_finetuning/utils.py
@@ -6,7 +6,7 @@
import transformers
import yaml
from custom_datasets import get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator, TrainDialogueDataCollator
from custom_datasets.dialogue_collator import DialogueDataCollator
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
from losses import CrossEntropyLoss, PolyLoss
from models import freeze_top_n_layers, get_specific_model
@@ -252,13 +252,10 @@ def get_dataset(conf, tokenizer):

    train = ConcatDataset(train_datasets)

    collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
    train_collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length, samples_mixing=conf.samples_mixing)
    eval_collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length, samples_mixing=False)

    train_collate_fn = (
        TrainDialogueDataCollator(tokenizer, max_length=conf.max_length) if conf.samples_mixing else collate_fn
    )

    return train, evals, collate_fn, train_collate_fn
    return train, evals, train_collate_fn, eval_collate_fn


def get_loss(loss, poly_eps):
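As a usage sketch of the now-unified collator: the constructor calls mirror the two lines added above, while the gpt2 tokenizer choice, the pad-token workaround, and the sample texts are illustrative assumptions, not code from the commit.

from transformers import AutoTokenizer
from custom_datasets.dialogue_collator import DialogueDataCollator

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

train_collate_fn = DialogueDataCollator(tokenizer, max_length=512, samples_mixing=True)
eval_collate_fn = DialogueDataCollator(tokenizer, max_length=512, samples_mixing=False)

# Each feature is a [question, answer, ...] sequence of strings, per the class docstring.
batch = eval_collate_fn([["Q: hi ", "A: hello "], ["Q: bye ", "A: bye! "]])
print(batch["input_ids"].shape, batch["label_masks"].shape, batch["targets"].shape)
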
