diff --git a/YOCO/README.md b/YOCO/README.md index e61e137c..ab3f3bec 100644 --- a/YOCO/README.md +++ b/YOCO/README.md @@ -1,6 +1,168 @@ -# YOCO +# You Only Cache Once: Decoder-Decoder Architectures for Large Language Models -- May 2024: Code release -- May 2024: release preprint [YOCO](https://arxiv.org/abs/) +## Approach +
+![YOCO decoder-decoder architecture](./imgs/arch.png)
+
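+For intuition, below is a minimal, illustrative PyTorch sketch of the decoder-decoder idea shown above: a *self-decoder* first encodes the sequence with an efficient attention module (sliding-window attention or gated retention in this repo) and produces a single set of global key-value states, and a *cross-decoder* then reuses that one cache in every layer via cross-attention. The toy code is **not** the implementation in this repository (see ```yoco/models/decoder/``` for the real modules); all class names and sizes below are illustrative assumptions.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Block(nn.Module):
+    """Pre-norm attention + FFN block. When `cache` is given, the block acts as a
+    cross-decoder layer and reads keys/values from the shared cache instead of
+    projecting its own."""
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.heads, self.head_dim = heads, dim // heads
+        self.q = nn.Linear(dim, dim, bias=False)
+        self.kv = nn.Linear(dim, 2 * dim, bias=False)   # used only when run as a self-decoder layer
+        self.out = nn.Linear(dim, dim, bias=False)
+        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
+        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
+
+    def attend(self, q, k, v):
+        b, t, _ = q.shape
+        split = lambda x: x.reshape(b, -1, self.heads, self.head_dim).transpose(1, 2)
+        o = F.scaled_dot_product_attention(split(q), split(k), split(v), is_causal=True)
+        return self.out(o.transpose(1, 2).reshape(b, t, -1))
+
+    def forward(self, x, cache=None):
+        h = self.norm1(x)
+        if cache is None:                    # self-decoder: build K/V from its own input
+            k, v = self.kv(h).chunk(2, dim=-1)
+        else:                                # cross-decoder: reuse the single global cache
+            k, v = cache
+        x = x + self.attend(self.q(h), k, v)
+        return x + self.ffn(self.norm2(x))
+
+class ToyYOCO(nn.Module):
+    """Self-decoder stack -> one global KV cache -> cross-decoder stack."""
+    def __init__(self, vocab, dim=256, heads=4, n_self=4, n_cross=4):
+        super().__init__()
+        self.embed = nn.Embedding(vocab, dim)
+        self.self_decoder = nn.ModuleList([Block(dim, heads) for _ in range(n_self)])
+        self.cross_decoder = nn.ModuleList([Block(dim, heads) for _ in range(n_cross)])
+        self.to_cache = nn.Linear(dim, 2 * dim, bias=False)  # the global K/V are produced here, once
+        self.lm_head = nn.Linear(dim, vocab, bias=False)
+
+    def forward(self, tokens):
+        x = self.embed(tokens)
+        for layer in self.self_decoder:      # efficient attention in the real model
+            x = layer(x)
+        cache = self.to_cache(x).chunk(2, dim=-1)   # "you only cache once"
+        for layer in self.cross_decoder:     # every layer attends to the same cache
+            x = layer(x, cache=cache)
+        return self.lm_head(x)
+
+logits = ToyYOCO(vocab=1000)(torch.randint(0, 1000, (2, 32)))
+print(logits.shape)  # torch.Size([2, 32, 1000])
+```
+
+During decoding, only this single global cache (plus the constant-size state of the self-decoder) needs to be kept, instead of one KV cache per layer, which is where the memory savings come from.
+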
-## Getting Started +
+![YOCO inference pipeline](./imgs/inference.png)
+
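+At inference time, the decoder-decoder split also changes how generation is scheduled: the prompt is prefilled in chunks through the self-decoder only, and the cross-decoder runs just for the tokens that are actually generated. The sketch below condenses the call pattern used by ```generate()``` in ```yoco/criterions/needle_haystack.py``` (greedy decoding, batch size 1); the ```model(...)``` keyword arguments are the ones that appear in that file, while everything else is simplified for illustration.
+
+```python
+import torch
+
+@torch.no_grad()
+def greedy_generate(model, tokenizer, prompt_tokens, max_new_tokens=20, chunk_length=32768):
+    """Two-phase YOCO inference: chunked prefill (self-decoder only), then decoding."""
+    tokens = prompt_tokens.clone()          # shape: (1, prompt_len)
+    incremental_state = {}                  # holds the reusable global cache / recurrent state
+
+    # Phase 1: prefill. `skip_cross_decoder=True` means the prompt only pays for the
+    # self-decoder; the global K/V produced from it are stored in `incremental_state`.
+    prompt_len = tokens.shape[1]
+    for start in range(0, prompt_len - 1, chunk_length):
+        end = min(start + chunk_length, prompt_len - 1)
+        model(tokens[:, start:end], incremental_state=incremental_state,
+              start_pos=start, skip_cross_decoder=True, is_prefilling=True)
+
+    # Phase 2: decoding. Each step feeds one token through the full decoder-decoder stack.
+    for pos in range(prompt_len - 1, prompt_len - 1 + max_new_tokens):
+        logits, _ = model(tokens[:, pos:pos + 1], incremental_state=incremental_state,
+                          start_pos=pos, skip_cross_decoder=False, is_prefilling=False)
+        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+        tokens = torch.cat([tokens, next_token], dim=1)
+        if next_token.item() == tokenizer.eos_id:
+            break
+    return tokens
+```
+
+Deferring cross-attention over the whole prompt until tokens are generated is what makes long-context prefilling cheaper than in a comparable decoder-only Transformer.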
+
+## Performance
+### Harness Eval
+Training with 1T Tokens:
+
+| **Model**                  | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
+|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|---------|
+| OpenLLaMA-3B-v2            | 0.339     | 0.676     | 0.657     | **0.700**         | 0.260    | 0.767    | 0.629          | 0.924    | 0.619   |
+| StableLM-base-alpha-3B-v2  | 0.324     | 0.673     | 0.646     | 0.686             | 0.264    | 0.760    | 0.621          | 0.921    | 0.612   |
+| StableLM-3B-4E1T           | ---       | 0.666     | ---       | ---               | ---      | **0.768**| 0.632          | 0.914    | ---     |
+| YOCO-3B                    | **0.379** | **0.731** | 0.645     | 0.689             | **0.298**| 0.763    | 0.639          | 0.924    | **0.634**|
+
+Training with 1.6T Tokens:
+
+| **Model**                  | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
+|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|---------|
+| StableLM-3B-4E1T           | ---       | 0.688     | ---       | ---               | ---      | 0.762    | 0.627          | 0.913    | ---     |
+| YOCO-3B                    | 0.396     | 0.733     | **0.644** | 0.698             | 0.300    | 0.764    | 0.631          | 0.921    | 0.636   |
+| YOCO-3B-1M                 | **0.413** | **0.747** | 0.638     | **0.705**         | 0.300    | **0.773**| **0.651**      | **0.932**| **0.645**|
+
+### Needle In A Haystack
+
+![Needle-in-a-haystack retrieval results up to 1M context length](./imgs/1m_retrieval.png)
+
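+Each needle test hides a "special magic {city} number" sentence at a controlled depth inside long filler text and then asks the model to recall the number; a retrieval counts as correct if the number appears in the model's continuation. The snippet below shows how such a prompt is assembled, using the templates from ```yoco/criterions/needle_haystack.py```; the character-level split and the placeholder filler text are simplifications (the real criterion measures depth in tokens and samples filler from ```--needle-file-path```).
+
+```python
+import random
+
+# Templates adapted from yoco/criterions/needle_haystack.py
+PREAMBLE = ("There is a special magic number inside a lot of irrelevant text. "
+            "Find it and memorize them. I will quiz you about the magic number there. ")
+NEEDLE = "The special magic {city} number is {rnd_number} . "
+QUESTION = "What is the special magic {city} number? The special magic {city} number is "
+
+def build_needle_prompt(filler: str, depth_ratio: float, city: str = "Chicago"):
+    """Insert one needle at `depth_ratio` (0 = start, 1 = end) of the filler text."""
+    magic_number = random.randint(1, 50000)
+    cut = int(len(filler) * depth_ratio)
+    context = "\n".join([
+        filler[:cut],
+        NEEDLE.format(city=city, rnd_number=magic_number),
+        filler[cut:],
+        QUESTION.format(city=city),
+    ])
+    return PREAMBLE + context, str(magic_number)
+
+prompt, answer = build_needle_prompt("lorem ipsum " * 2000, depth_ratio=0.3)
+```
+
+The multi-needle variant scatters several such city-number pairs through the context and queries one of them.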
+
+### Multi-Needle Eval
+| **Model**               | **Size** | **N=1** | **N=2** | **N=4** | **N=8** |
+|-------------------------|----------|---------|---------|---------|---------|
+| GPT-4-128K              | --       | 1.00    | 1.00    | 0.98    | 1.00    |
+| MiniCPM-128K            | 2.4B     | 1.00    | 1.00    | 0.54    | 0.56    |
+| ChatGLM3-128K           | 6B       | 0.94    | 0.72    | 0.52    | 0.44    |
+| YaRN-Mistral-128K       | 7B       | 0.02    | 0.12    | 0.08    | 0.20    |
+| LWM-1M-text             | 7B       | 1.00    | 0.90    | 0.76    | 0.62    |
+| YOCO-3B-1M              | 3B       | 0.98    | 0.98    | 0.84    | 0.56    |
+
+## Setup
+
+To install the required packages, use the following command:
+
+```bash
+pip install -r requirements.txt
+```
+
+In addition, [Apex](https://github.com/NVIDIA/apex) and [Flash-Attention](https://github.com/Dao-AILab/flash-attention) should be installed separately, following their official installation guides.
+
+## Harness Eval
+
+To evaluate models on Harness tasks, use the script in ```scripts/eval_task.sh```:
+```bash
+cd fairseq/
+TASK='harness_boolq'
+
+torchrun --master-port=29505 --nproc_per_node=1 validate.py \
+    --data-dir ../harness_data/ \
+    --criterion harness_eval \
+    --task harness_eval \
+    --batch-size 4 \
+    --eval-data ${TASK} \
+    --log-format simple --log-interval 10 \
+    --bf16 \
+    --tokenizer-pad-to-multiple 8 \
+    --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 4096
+```
+
+## Needle In A Haystack Evaluation
+The evaluation uses city-number pairs for long-sequence retrieval. To get results up to a given maximum length, use the script in ```scripts/eval_needle.sh```:
+```bash
+cd fairseq/
+torchrun --master-port=29504 --nproc_per_node=1 validate.py \
+    --task pseudo \
+    --criterion needle_haystack \
+    --batch-size 1 \
+    --max-epoch 1 \
+    --no-save \
+    --tiktoken-model cl100k_base \
+    --bf16 \
+    --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576
+```
+
+To run Multi-Needle experiments, replace ```--criterion needle_haystack``` with ```--criterion multi_needle --needle-num {num}```.
+
+## Pretraining From Scratch
+To support distributed training, our implementation uses infinibatch to read data iteratively. The overall data directory should be organized as follows:
+```
+Data/
+├── json/
+│   ├── train.json
+│   ├── CC.json
+│   ├── StarCoder.json
+│   └── ...
+└── shard/
+    ├── CC/
+    │   ├── 00000.jsonl
+    │   ├── 00001.jsonl
+    │   └── ...
+    └── StarCoder/
+        ├── 00000.jsonl
+        ├── 00001.jsonl
+        └── ...
+```
+
+We recommend that each sharded data file contain no more than 10K lines, with one JSON dict per line. Each jsonl file, such as ```Data/shard/CC/00000.jsonl```, should look like this:
+```json
+{"text": "File 1 is here..."}
+{"text": "File 2 is here..."}
+...
+```
+
+Then, for each source, a JSON file lists the paths of all its jsonl files. Take ```Data/json/CC.json``` as an example:
+```json
+[
+    "/path_to_data/Data/shard/CC/00000.jsonl",
+    "/path_to_data/Data/shard/CC/00001.jsonl",
+    ...
+]
+```
+
+Finally, ```train.json``` records each source's name and sampling weight:
+```json
+[
+    {
+        "name": "CC",
+        "weight": 0.5
+    },
+    {
+        "name": "StarCoder",
+        "weight": 0.2
+    },
+    ...
+] +``` + + ```scripts/train.sh```: +```bash +cd fairseq/ +torchrun --nproc-per-node=1 train.py /path_to_data \ + --save-interval-updates 5000 \ + --no-epoch-checkpoints \ + --arch yoco_base \ + --criterion cross_entropy \ + --task gpt \ + --tokens-per-sample 2048 \ + --tokenizer-pad-to-multiple 8 \ + --pad-to-max-len \ + --optimizer adam --adam-betas "(0.9, 0.95)" \ + --adam-eps 1e-06 \ + --clip-norm 2.0 \ + --lr 0.00015 \ + --lr-scheduler polynomial_decay \ + --warmup-updates 50 \ + --weight-decay 0.05 \ + --batch-size 1 \ + --model-parallel-size 1 \ + --update-freq 1 \ + --batch-read-ahead 1000 \ + --total-num-update 300000 \ + --log-format simple --log-interval 10 --disable-validation \ + --tiktoken-model cl100k_base \ + --save-interval-updates 5000 \ + --bf16 # bf16 is encouraged in pre-training +``` diff --git a/YOCO/imgs/1m_retrieval.png b/YOCO/imgs/1m_retrieval.png new file mode 100644 index 00000000..9fb8d949 Binary files /dev/null and b/YOCO/imgs/1m_retrieval.png differ diff --git a/YOCO/imgs/arch.png b/YOCO/imgs/arch.png new file mode 100644 index 00000000..15240637 Binary files /dev/null and b/YOCO/imgs/arch.png differ diff --git a/YOCO/imgs/inference.png b/YOCO/imgs/inference.png new file mode 100644 index 00000000..0751e0a6 Binary files /dev/null and b/YOCO/imgs/inference.png differ diff --git a/YOCO/requirements.txt b/YOCO/requirements.txt new file mode 100644 index 00000000..2e133623 --- /dev/null +++ b/YOCO/requirements.txt @@ -0,0 +1,12 @@ +torch>=2.2.0 +triton>=2.2.0 +numpy==1.23.0 +fairscale +tiktoken +sentencepiece +ninja +boto3 +iopath +git+https://github.com/sunyt32/fairseq.git@moe3#egg=fairseq +git+https://github.com/shumingma/infinibatch.git#egg=infinibatch +git+https://github.com/microsoft/torchscale.git#egg=torchscale \ No newline at end of file diff --git a/YOCO/scripts/eval_needle.sh b/YOCO/scripts/eval_needle.sh new file mode 100644 index 00000000..a6277901 --- /dev/null +++ b/YOCO/scripts/eval_needle.sh @@ -0,0 +1,11 @@ +cd yoco/ +torchrun --master-port=29504 --nproc_per_node=1 validate.py \ + --task pseudo \ + --criterion multi_needle --needle-num 4 \ + --batch-size 1 \ + --max-epoch 1 \ + --no-save \ + --tiktoken-model cl100k_base \ + --bf16 \ + --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /data/yutao/ckpt_opensource/YOCO-3B-1M/checkpoint.pth --yoco-model /data/yutao/ckpt_opensource/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576 + diff --git a/YOCO/scripts/eval_task.sh b/YOCO/scripts/eval_task.sh new file mode 100644 index 00000000..07b70593 --- /dev/null +++ b/YOCO/scripts/eval_task.sh @@ -0,0 +1,17 @@ +TASK='harness_boolq' +# TASK='hendrycksTest-abstract_algebra' + +cd yoco/ +torchrun --master-port=29505 --nproc_per_node=1 validate.py \ + --data-dir ../harness_data/ \ + --criterion harness_eval \ + --task harness_eval \ + --batch-size 4 \ + --eval-data ${TASK} \ + --log-format simple --log-interval 10 \ + --bf16 \ + --tokenizer-pad-to-multiple 8 \ + --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /data/yutao/ckpt_opensource/YOCO-3B-1M/checkpoint.pth --yoco-model /data/yutao/ckpt_opensource/YOCO-3B-1M --tokens-per-sample 4096 + # --arch llama_from_ckpt --llama-model /data/yutao/llama/llama-2-7b --load-ckpt /data/yutao/llama/llama-2-7b/consolidated.00.pth --tokens-per-sample 4096 + + diff --git a/YOCO/scripts/train.sh b/YOCO/scripts/train.sh new file mode 100644 index 00000000..28c13f7b --- /dev/null +++ b/YOCO/scripts/train.sh @@ -0,0 +1,27 @@ +cd yoco/ +torchrun --master-port=29501 --nproc-per-node=1 train.py 
/mnt/nlcredstone/shaohanh/data/redstone_v4_21_config \ + --save-interval-updates 5000 \ + --no-epoch-checkpoints \ + --arch yoco_base \ + --criterion cross_entropy \ + --task gpt \ + --tokens-per-sample 2048 \ + --tokenizer-pad-to-multiple 8 \ + --pad-to-max-len \ + --optimizer adam --adam-betas "(0.9, 0.95)" \ + --adam-eps 1e-06 \ + --clip-norm 2.0 \ + --lr 0.00015 \ + --lr-scheduler polynomial_decay \ + --warmup-updates 50 \ + --weight-decay 0.05 \ + --batch-size 1 \ + --model-parallel-size 1 \ + --update-freq 1 \ + --batch-read-ahead 1000 \ + --total-num-update 300000 \ + --log-format simple --log-interval 10 --disable-validation \ + --tiktoken-model cl100k_base \ + --no-save \ + --bf16 \ + diff --git a/YOCO/yoco/__init__.py b/YOCO/yoco/__init__.py new file mode 100644 index 00000000..3ae31e25 --- /dev/null +++ b/YOCO/yoco/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] diff --git a/YOCO/yoco/criterions/__init__.py b/YOCO/yoco/criterions/__init__.py new file mode 100644 index 00000000..9901f275 --- /dev/null +++ b/YOCO/yoco/criterions/__init__.py @@ -0,0 +1,8 @@ +import importlib +import os + +# automatically import any Python files in the criterions/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("criterions." + file_name) \ No newline at end of file diff --git a/YOCO/yoco/criterions/harness_eval.py b/YOCO/yoco/criterions/harness_eval.py new file mode 100644 index 00000000..8aed18e3 --- /dev/null +++ b/YOCO/yoco/criterions/harness_eval.py @@ -0,0 +1,86 @@ +import torch +import torch.nn.functional as F + +from fairseq import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +@register_criterion("harness_eval", dataclass=FairseqDataclass) +class HarnessEvalCriterion(FairseqCriterion): + def __init__(self, cfg, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + model.eval() + net_output, _ = model(sample["net_input"]["src_tokens"]) + net_output = net_output[:, :-1, :] + targets = sample["net_input"]["src_tokens"][:, 1:] + loss_mask = sample["net_input"]["gpt_loss_mask"][:, 1:] + label_length = sample["net_input"]["label_length"] + loss = F.cross_entropy( + net_output.float().reshape(-1, net_output.size(-1)), + targets.reshape(-1), + reduction="none", + ignore_index=self.padding_idx, + ).reshape(targets.size(0), -1) + loss = loss * loss_mask.int() + loss_norm = loss.sum(-1) / label_length.float() + loss = loss.sum(-1) + + option_num = self.task.harness_task.class_num + labels = sample["targets"].view(-1) + + assert sample["targets"].size(0) % option_num == 0 + sample_size = sample["ntokens"] + + pred_label = torch.argmin(loss.view(-1, option_num), dim=1) + pred_norm_label = torch.argmin(loss_norm.view(-1, option_num), dim=1) + target_label = labels.view(-1, option_num)[:, 0] + + logging_output = {} + + logging_output.update( + { + "loss": 0, + "nsentences": pred_label.size(0), + "sample_size": pred_label.size(0), + "ncorrect": (pred_label == target_label).sum().item(), + "ncorrect_norm": (pred_norm_label == target_label).sum().item(), + } + ) + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss = sum(log.get("loss", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + ncorrect_norm = sum(log.get("ncorrect_norm", 0) for log in logging_outputs) + metrics.log_scalar( + "loss", loss / nsentences, nsentences, round=3 + ) + metrics.log_scalar( + "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=2 + ) + metrics.log_scalar( + "accuracy_norm", 100.0 * ncorrect_norm / nsentences, nsentences, round=2 + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True \ No newline at end of file diff --git a/YOCO/yoco/criterions/multi_needle.py b/YOCO/yoco/criterions/multi_needle.py new file mode 100644 index 00000000..f1b564ec --- /dev/null +++ b/YOCO/yoco/criterions/multi_needle.py @@ -0,0 +1,181 @@ +import os +import random +import math +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F + +from fairseq import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + +OURS_TEMPLATE = "There is a special magic number inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the magic number there. 
{context} " +RANDOM_NEEDLE_CITIES = [ + 'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City', + 'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar', + 'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman', + 'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco', + 'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali', + 'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki', + 'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo', + 'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis', + 'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta', + 'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels', + 'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul', + 'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta' +] +QUESTION_TEMPLATE = "What is the special magic {city} number? The special magic {city} number is " +NEEDLE_TEMPLATE = "The special magic {city} number is: {rnd_number}" +@dataclass +class NeedleEvalConfig(FairseqDataclass): + needle_num: int = field( + default=4, + metadata={"help":"needle number"} + ) + tokens_per_sample: int = field( + default=16384, + ) + interval: int = field( + default=1024, + ) + needle_file_path: str = field( + default="/mnt/msranlp/yutao/data/PaulGrahamEssays", + ) + +def random_partition(total, n): + cuts = random.sample(range(1, total), n - 1) + cuts.sort() + cuts = [0] + cuts + [total] + parts = [cuts[i+1] - cuts[i] for i in range(n)] + return parts + +@register_criterion("multi_needle", dataclass=NeedleEvalConfig) +class NeedleEvalCriterion(FairseqCriterion): + def __init__(self, cfg: NeedleEvalConfig, task): + super().__init__(task) + self.cfg = cfg + self.essay_list = os.listdir(cfg.needle_file_path) * 5000 + + def generate_garbage(self, length): + current_text = "" + current_length = 0 + while True: + essay = random.choice(self.essay_list) + essay = open(os.path.join(self.cfg.needle_file_path, essay)).read().splitlines() + for line in essay: + tokens = self.task.tokenizer.encode(line + " ") + if current_length + len(tokens) > length: + return current_text + current_text += line + " " + current_length += len(tokens) + + def generate_prompt_landmark(self, first_length_list, second_length_list, final_length): + """Generates a text file and inserts an passkey at a random position.""" + lines = [] + citys = random.sample(RANDOM_NEEDLE_CITIES, self.cfg.needle_num) + for length in first_length_list: + lines.append(self.generate_garbage(length)) + city = citys.pop() + magic_number = random.randint(1, 50000) + information_line = NEEDLE_TEMPLATE.format(city=city, rnd_number=magic_number) + lines.append(information_line) + + final_question, answer = QUESTION_TEMPLATE.format(city=city), magic_number + + for length in second_length_list: + lines.append(self.generate_garbage(length)) + city = citys.pop() + magic_number = random.randint(1, 50000) + information_line = NEEDLE_TEMPLATE.format(city=city, rnd_number=magic_number) + lines.append(information_line) + + + lines.append(self.generate_garbage(final_length)) + lines.append(final_question) + context = "\n".join(lines) + return 
OURS_TEMPLATE.format(context=context), str(answer) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + model.eval() + all_retrieval_result = {} + random.seed(42) + for context_length in range(self.cfg.interval, self.cfg.tokens_per_sample + 1, self.cfg.interval): + all_length = (context_length - 150) + local_retrieval_result = [] + for depth_ratio in range(1, 11): + prefix_length = int(all_length * depth_ratio / 11) + suffix_length = all_length - prefix_length + n_correct = 0 + for _ in range(5): + if self.cfg.needle_num > 1: + first_needle_num = random.randint(1, self.cfg.needle_num - 1) + second_needle_num = self.cfg.needle_num + 1 - first_needle_num + first_length_list = random_partition(prefix_length, first_needle_num) + second_length_list = random_partition(suffix_length, second_needle_num) + final_length = second_length_list.pop() + else: + first_length_list = [prefix_length] + second_length_list = [] + final_length = suffix_length + prompt, pass_key = self.generate_prompt_landmark(first_length_list, second_length_list, final_length) + prompt_tokens = self.task.tokenizer.encode(prompt, bos=True) + prompt_tokens = torch.tensor([prompt_tokens], device="cuda") + print(prompt_tokens.shape) + output = self.generate(model, prompt_tokens) + pred = self.task.tokenizer.decode(output[0, prompt_tokens.shape[1]:]) + print("Answer: ", pass_key) + print("Pred: ", pred) + if pass_key in pred: + n_correct += 1 + local_retrieval_result.append(n_correct / 5) + all_retrieval_result[context_length] = local_retrieval_result + + print(all_retrieval_result) + return 0, 1, {"loss": 0} + + def generate(self, model, net_input, generate_tokens=20, chunk_length = 32768): + output_tokens = torch.cat((net_input, torch.full((net_input.shape[0], generate_tokens), self.task.tokenizer.pad_id).long().cuda()), dim=1) + begin_pad_index = torch.where(output_tokens == self.task.tokenizer.pad_id)[1].min().item() + incremental_state = {} + eos_reached = torch.tensor([False] * net_input.shape[0], device="cuda") + # prefilling + for begin_index in range(0, begin_pad_index - 1, chunk_length): + end_index = min(begin_index + chunk_length, begin_pad_index - 1) + _, _ = model(output_tokens[:, begin_index : end_index], incremental_state=incremental_state, start_pos=begin_index, skip_cross_decoder=True, is_prefilling=True) + # generation + for index in range(begin_pad_index, output_tokens.shape[1]): + generation_net_output, _ = model(output_tokens[:, index - 1].unsqueeze(-1), incremental_state=incremental_state, start_pos=index - 1, skip_cross_decoder=False, is_prefilling=False) + generation_net_output[:, :, self.task.tokenizer.bos_id] = -math.inf + generation_net_output[:, :, self.task.tokenizer.pad_id] = -math.inf + next_tokens = torch.argmax(generation_net_output[:, -1, :], dim=-1) + pad_tokens = output_tokens[:, index] + next_tokens = torch.where((pad_tokens == self.task.tokenizer.pad_id) & ~eos_reached, next_tokens, pad_tokens) + output_tokens[:, index] = next_tokens + eos_reached |= ( + next_tokens == self.task.tokenizer.eos_id + ) + if all(eos_reached): + break + + return output_tokens + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + pass + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across 
workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True \ No newline at end of file diff --git a/YOCO/yoco/criterions/needle_haystack.py b/YOCO/yoco/criterions/needle_haystack.py new file mode 100644 index 00000000..5cc9f231 --- /dev/null +++ b/YOCO/yoco/criterions/needle_haystack.py @@ -0,0 +1,169 @@ +import os +import random +import math +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F + +from fairseq import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + +OURS_TEMPLATE = "There is a special magic number inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the magic number there. {context} " +RANDOM_NEEDLE_CITIES = [ + 'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City', + 'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar', + 'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman', + 'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco', + 'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali', + 'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki', + 'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo', + 'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis', + 'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta', + 'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels', + 'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul', + 'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta' +] +QUESTION_TEMPLATE = "What is the special magic {city} number? The special magic {city} number is " +# NEEDLE_TEMPLATE = "The special magic {city} number is {rnd_number} . Remember it. The special magic {city} number is {rnd_number} . " +NEEDLE_TEMPLATE = "The special magic {city} number is {rnd_number} . 
" +@dataclass +class NeedleHaystackEvalConfig(FairseqDataclass): + max_len_b: int = field( + default=5, + metadata={"help":"max_len_b"} + ) + tokens_per_sample: int = field( + default=16384, + ) + interval: int = field( + default=1024, + ) + needle_file_path: str = field( + default="/mnt/msranlp/yutao/data/PaulGrahamEssays", + ) + +@register_criterion("needle_haystack", dataclass=NeedleHaystackEvalConfig) +class NeedleHaystackEvalCriterion(FairseqCriterion): + def __init__(self, cfg: NeedleHaystackEvalConfig, task): + super().__init__(task) + self.cfg = cfg + self.essay_list = os.listdir(cfg.needle_file_path) * 5000 + + def generate_garbage(self, length): + current_text = "" + current_length = 0 + while True: + essay = random.choice(self.essay_list) + essay = open(os.path.join(self.cfg.needle_file_path, essay)).read().splitlines() + for line in essay: + tokens = self.task.tokenizer.encode(line + " ") + if current_length + len(tokens) > length: + return current_text + current_text += line + " " + current_length += len(tokens) + + def generate_prompt_landmark(self, prefix_length, suffix_length): + """Generates a text file and inserts an passkey at a random position.""" + city = random.choice(RANDOM_NEEDLE_CITIES) + magic_number = random.randint(1, 50000) + garbage_prefix = self.generate_garbage(prefix_length) + garbage_suffix = self.generate_garbage(suffix_length) + information_line = NEEDLE_TEMPLATE.format(city=city, rnd_number=magic_number) + final_question = QUESTION_TEMPLATE.format(city=city) + lines = [ + garbage_prefix, + information_line, + garbage_suffix, + final_question, + ] + context = "\n".join(lines) + return OURS_TEMPLATE.format(context=context), str(magic_number) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + model.eval() + all_retrieval_result = {} + random.seed(0) + for context_length in range(self.cfg.interval, self.cfg.tokens_per_sample + 1, self.cfg.interval): + all_length = (context_length - 150) + local_retrieval_result = [] + depth_number = 10 + for depth_ratio in range(0, depth_number + 1): + prefix_length = int(all_length * depth_ratio / depth_number) + suffix_length = all_length - prefix_length + n_correct = 0 + times = 10 + for _ in range(times): + prompt, pass_key = self.generate_prompt_landmark(prefix_length, suffix_length) + prompt_tokens = self.task.tokenizer.encode(prompt, bos=True) + prompt_tokens = torch.tensor([prompt_tokens], device="cuda") + print(prompt_tokens.shape) + output = self.generate(model, prompt_tokens) + pred = self.task.tokenizer.decode(output[0, prompt_tokens.shape[1]:]) + print("Answer: ", pass_key) + print("Pred: ", pred) + if pass_key in pred: + n_correct += 1 + local_retrieval_result.append(n_correct / times) + all_retrieval_result[context_length] = local_retrieval_result + + print(all_retrieval_result) + return 0, 1, {"loss": 0} + + def generate(self, model, net_input, generate_tokens=20, chunk_length = 32768): + output_tokens = torch.cat((net_input, torch.full((net_input.shape[0], generate_tokens), self.task.tokenizer.pad_id).long().cuda()), dim=1) + begin_pad_index = torch.where(output_tokens == self.task.tokenizer.pad_id)[1].min().item() + incremental_state = {} + eos_reached = torch.tensor([False] * net_input.shape[0], device="cuda") + # prefilling + for begin_index in range(0, begin_pad_index - 1, chunk_length): + end_index = min(begin_index + chunk_length, begin_pad_index - 1) + _, _ = model(output_tokens[:, begin_index : end_index], incremental_state=incremental_state, start_pos=begin_index, skip_cross_decoder=True, is_prefilling=True) + # generation + for index in range(begin_pad_index, output_tokens.shape[1]): + generation_net_output, _ = model(output_tokens[:, index - 1].unsqueeze(-1), incremental_state=incremental_state, start_pos=index - 1, skip_cross_decoder=False, is_prefilling=False) + generation_net_output[:, :, self.task.tokenizer.bos_id] = -math.inf + generation_net_output[:, :, self.task.tokenizer.pad_id] = -math.inf + next_tokens = torch.argmax(generation_net_output[:, -1, :], dim=-1) + pad_tokens = output_tokens[:, index] + next_tokens = torch.where((pad_tokens == self.task.tokenizer.pad_id) & ~eos_reached, next_tokens, pad_tokens) + output_tokens[:, index] = next_tokens + eos_reached |= ( + next_tokens == self.task.tokenizer.eos_id + ) + if all(eos_reached): + break + + return output_tokens + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + metric_sum = sum(log.get("metric", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + metrics.log_scalar( + "loss", loss_sum / ntokens, ntokens, round=3 + ) + metrics.log_scalar( + "metric", metric_sum / nsentences, nsentences, round=3 + ) + + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. 
Setting this + to True will improves distributed training speed. + """ + return True \ No newline at end of file diff --git a/YOCO/yoco/models/__init__.py b/YOCO/yoco/models/__init__.py new file mode 100644 index 00000000..1ff184f3 --- /dev/null +++ b/YOCO/yoco/models/__init__.py @@ -0,0 +1,41 @@ +import argparse +import importlib +import os + +try: + from torch._six import inf +except: + import sys + import torch + sys.modules["torch._six"] = torch + torch.string_classes = str + +MODEL_REGISTRY = {} +MODEL_DATACLASS_REGISTRY = {} +ARCH_MODEL_REGISTRY = {} +ARCH_MODEL_NAME_REGISTRY = {} +ARCH_MODEL_INV_REGISTRY = {} +ARCH_CONFIG_REGISTRY = {} + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("models." + model_name) + + # extra `model_parser` for sphinx + if model_name in MODEL_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_archs = parser.add_argument_group("Named architectures") + group_archs.add_argument( + "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] + ) + group_args = parser.add_argument_group("Additional command-line arguments") + MODEL_REGISTRY[model_name].add_args(group_args) + globals()[model_name + "_parser"] = parser \ No newline at end of file diff --git a/YOCO/yoco/models/decoder/__init__.py b/YOCO/yoco/models/decoder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/YOCO/yoco/models/decoder/cross_attention.py b/YOCO/yoco/models/decoder/cross_attention.py new file mode 100644 index 00000000..09c31a89 --- /dev/null +++ b/YOCO/yoco/models/decoder/cross_attention.py @@ -0,0 +1,46 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, +) + +from .model_parallel_init import init_method +from .kernel.rotary import apply_rotary_emb +from flash_attn import flash_attn_func + +class CrossAttention(nn.Module): + def __init__( + self, + args, + ): + super().__init__() + self.args = args + self.embed_dim = args.dim + self.num_heads = args.n_attn_heads // args.model_parallel_size + self.num_kv_heads = args.n_attn_kv_heads // args.model_parallel_size + + self.head_dim = args.dim // args.n_attn_heads + self.q_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=init_method) + self.out_proj = RowParallelLinear(args.dim, args.dim, bias=False, input_is_parallel=True, init_method=init_method) + + def forward( + self, + x, + key, + value, + rel_pos + ): + bsz, tgt_len, _ = x.size() + + q = self.q_proj(x) + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim) + q = apply_rotary_emb(q, *rel_pos, interleaved=True) + + attn = flash_attn_func(q, key, value, causal=True) + attn = attn.view(bsz, tgt_len, self.head_dim * self.num_heads) + + attn = self.out_proj(attn) + return attn \ No newline at end of file diff --git a/YOCO/yoco/models/decoder/feedforward_network.py b/YOCO/yoco/models/decoder/feedforward_network.py new file mode 100644 index 00000000..3972068f --- /dev/null +++ b/YOCO/yoco/models/decoder/feedforward_network.py @@ -0,0 +1,33 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from 
fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, +) + +from .kernel.swiglu import swiglu +from .model_parallel_init import init_method + +class FeedForwardNetwork(nn.Module): + def __init__( + self, + embed_dim, + ffn_dim, + load_checkpoint=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.fc1 = ColumnParallelLinear(self.embed_dim, ffn_dim, bias=False, gather_output=False, init_method=init_method) + self.gate = ColumnParallelLinear(self.embed_dim, ffn_dim, bias=False, gather_output=False, init_method=init_method) + self.fc2 = RowParallelLinear(ffn_dim, self.embed_dim, bias=False, input_is_parallel=True, init_method=init_method) + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x.size(-1)) + x = self.fc2(swiglu(self.fc1(x), self.gate(x))) + output = x.view(x_shape) + return output \ No newline at end of file diff --git a/YOCO/yoco/models/decoder/gate_retention.py b/YOCO/yoco/models/decoder/gate_retention.py new file mode 100644 index 00000000..089164c7 --- /dev/null +++ b/YOCO/yoco/models/decoder/gate_retention.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, +) + +from .rms_norm import RMSNorm + +from .kernel.gate_recurrent import chunk_gate_retention, recurrent_gate_retention +from .kernel.rotary import apply_rotary_emb +from .kernel.swiglu import swiglu + +from .model_parallel_init import qkvg_init_method, out_init_method + +class GateRetention(nn.Module): + + def __init__( + self, + args, + gate_logit_normalizer: int = 16, + ): + super().__init__() + self.args = args + self.embed_dim = args.dim + self.num_heads = args.n_self_heads // args.model_parallel_size + self.head_dim = args.dim // args.n_self_heads + + self.q_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=qkvg_init_method) + self.k_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=qkvg_init_method) + self.v_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=qkvg_init_method) + self.g_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=qkvg_init_method) + self.gt_proj = ColumnParallelLinear(args.dim, args.n_self_heads, bias=False, gather_output=False, init_method=qkvg_init_method) + + self.out_proj = RowParallelLinear(args.dim, args.dim, bias=False, input_is_parallel=True, init_method=out_init_method) + + self.subln = RMSNorm(self.head_dim, elementwise_affine=False, eps=args.norm_eps) + + self.gate_logit_normalizer = gate_logit_normalizer + + def forward( + self, + x, + rel_pos, + incremental_state=None, + is_prefilling=False, + ): + bsz, tgt_len, _ = x.size() + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + g = self.g_proj(x) + gt = self.gt_proj(x) + + qr = q.view(bsz, tgt_len, self.num_heads, self.head_dim) + kr = k.view(bsz, tgt_len, self.num_heads, self.head_dim) + v = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + gt = gt.view(bsz, tgt_len, self.num_heads).transpose(1, 2) + + qr = apply_rotary_emb(qr, *rel_pos, interleaved=True).transpose(1, 2) + kr = apply_rotary_emb(kr, *rel_pos, interleaved=True).transpose(1, 2) + gt = (F.logsigmoid(gt) / self.gate_logit_normalizer) + + if incremental_state is not None and not is_prefilling: + o = recurrent_gate_retention(qr, kr, v, gt, incremental_state) + else: + if 
incremental_state is not None: + index_mask = incremental_state["index_mask"] + gt_sum = gt.float().masked_fill(index_mask, 0).sum(dim=-1, keepdim=True) + gt_mask = (gt_sum - gt.float().cumsum(dim=-1)).exp().masked_fill(index_mask, 0) + next_hidden_state = (kr.transpose(-1, -2) * (self.head_dim ** -0.5)) @ (v * gt_mask.to(v.dtype).unsqueeze(-1)) + if "last_hidden_state" in incremental_state: + last_hidden_state = incremental_state["last_hidden_state"] + next_hidden_state += last_hidden_state * gt_sum.exp().unsqueeze(-1).to(v.dtype) if last_hidden_state is not None else 0 + else: + last_hidden_state = None + incremental_state["last_hidden_state"] = next_hidden_state + o = chunk_gate_retention(qr, kr, v, gt, chunk_size=256, last_hidden_state=last_hidden_state) + else: + o = chunk_gate_retention(qr, kr, v, gt, chunk_size=256) + + o = self.subln(o).transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * self.head_dim) + o = swiglu(g, o) + o = self.out_proj(o) + return o diff --git a/YOCO/yoco/models/decoder/kernel/gate_recurrent.py b/YOCO/yoco/models/decoder/kernel/gate_recurrent.py new file mode 100644 index 00000000..304131cc --- /dev/null +++ b/YOCO/yoco/models/decoder/kernel/gate_recurrent.py @@ -0,0 +1,302 @@ +import time +from typing import Optional + +import torch +import triton +import triton.language as tl + +torch.backends.cudnn.allow_tf32 = True + +@triton.jit +def _fwd_recurrence( + S, d, + O, + NUM_HEAD, NUM_BLOCK, + D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr, + BLOCK_MODEL_K: tl.constexpr, BLOCK_MODEL_V: tl.constexpr, + last_kv: Optional[tl.tensor] + ): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + + O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + + if last_kv is not None: + last_kv = last_kv + offset_bh * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + acc = tl.load(last_kv).to(tl.float32) + else: + acc = tl.zeros([BLOCK_MODEL_K, BLOCK_MODEL_V], dtype=tl.float32) + + tl.store(O, acc.to(O.dtype.element_ty)) + O += D_MODEL_K * D_MODEL_V + d = d + offset_bh * NUM_BLOCK + for i in range(NUM_BLOCK-1): + d_i = tl.load(d) + S_i = tl.load(S) + acc = acc * d_i + S_i + tl.store(O, acc.to(O.dtype.element_ty)) + d += 1 + S += D_MODEL_K * D_MODEL_V + O += D_MODEL_K * D_MODEL_V + + +## NUM_SPLIT_K/V. 
K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL +@triton.jit +def _bwd_recurrence( + S, d, + DI, DG, DL, DS, + NUM_HEAD, NUM_BLOCK, + D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr, + BLOCK_MODEL_K: tl.constexpr, BLOCK_MODEL_V: tl.constexpr, + + ): + offset_bh = tl.program_id(0) + offset_d = tl.program_id(1) + offset_s = tl.program_id(2) + + # offset_h = offset_bh % NUM_HEAD + NUM_K = D_MODEL_K // BLOCK_MODEL_K + NUM_V = D_MODEL_V // BLOCK_MODEL_V + # skip the last chunk because it is never used + S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V + + DI = DI + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V + + # start from the last chunk + DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V + + DG = DG + offset_bh * NUM_BLOCK * NUM_K * NUM_V + offset_d * NUM_V + offset_s + (NUM_BLOCK - 2) * NUM_K * NUM_V + + d = d + offset_bh * NUM_BLOCK + (NUM_BLOCK - 1) + + Dacc = tl.zeros([BLOCK_MODEL_K, BLOCK_MODEL_V], dtype=tl.float32) + + # ignore the first chunk + for i in range(NUM_BLOCK - 1): + S_i = tl.load(S) + DS_i = tl.load(DS) + d_i = tl.load(d) + Dacc = Dacc * d_i + DS_i + DG_i = tl.sum(Dacc * S_i.to(tl.float32)) + + tl.store(DG, DG_i.to(DG.dtype.element_ty)) + tl.store(DI, Dacc.to(DI.dtype.element_ty)) + + S -= D_MODEL_K * D_MODEL_V + DI -= D_MODEL_K * D_MODEL_V + DS -= D_MODEL_K * D_MODEL_V + DG -= NUM_K * NUM_V + d -= 1 + + DL = DL + offset_bh * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + DS_i = tl.load(DS) + d_i = tl.load(d) + Dacc = Dacc * d_i + DS_i + tl.store(DL, Dacc.to(DL.dtype.element_ty)) + +class ChunkGateRecurrent(torch.autograd.Function): + @staticmethod + def forward(ctx, kv, cross_decay, last_kv=None): + cross_decay = cross_decay.contiguous() + kv = kv.contiguous() + + B, H, N, D_k, D_v = kv.shape + output = torch.empty_like(kv) + BLOCK_MODEL_K = 64 + BLOCK_MODEL_V = 16 + + assert D_k % BLOCK_MODEL_K == 0 + assert D_v % BLOCK_MODEL_V == 0 + + grid = (B*H, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V) + ctx.grid = grid + ctx.have_last_kv = last_kv is not None + ctx.BLOCK_MODEL_K = BLOCK_MODEL_K + ctx.BLOCK_MODEL_V = BLOCK_MODEL_V + + _fwd_recurrence[grid]( + kv, + cross_decay, + output, + D_MODEL_K=D_k, D_MODEL_V=D_v, + NUM_BLOCK=N, NUM_HEAD=H, + BLOCK_MODEL_K=BLOCK_MODEL_K, + BLOCK_MODEL_V=BLOCK_MODEL_V, + last_kv=last_kv + ) + + ctx.save_for_backward(output, cross_decay) + return output + + @staticmethod + def backward(ctx, DO): + DO = DO.contiguous() + + output, cross_decay = ctx.saved_tensors + + B, H, N, D_k, D_v = output.shape + + BLOCK_MODEL_K = 64 + BLOCK_MODEL_V = 16 + + grid = (B*H, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V) + + DI = torch.empty_like(DO) + DG = torch.empty(B*H, N, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V, device=cross_decay.device, dtype=cross_decay.dtype) + DL = torch.empty(B, H, D_k, D_v, device=output.device, 
dtype=output.dtype) + _bwd_recurrence[grid]( + output, cross_decay, + DI, DG, DL, DO, + NUM_HEAD=H, NUM_BLOCK = N, + D_MODEL_K = D_k, + D_MODEL_V = D_v, + BLOCK_MODEL_K=BLOCK_MODEL_K, + BLOCK_MODEL_V=BLOCK_MODEL_V, + ) + + DI[:, :, -1] = 0 + DG[:, -1] = 0 + DG = DG.view(B, H, N, -1).sum(dim=-1) + return DI, DG, DL if ctx.have_last_kv else None + +def cross_chunk(q, k, v, g, last_hidden_state=None): + kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None].to(v.dtype)) + cross_decay = g[:, :, :, -1].exp().to(kv.dtype) + S = chunk_gate_recurrent(kv, cross_decay, last_hidden_state) + cross = (q * g[..., None].exp().to(q.dtype)) @ S + return cross + +@torch.compile +def inner_chunk(q, k, v, g): + attn = q @ k.transpose(-1, -2) + causal_mask = torch.full([q.shape[-2], q.shape[-2]], float("-inf"), device=q.device).triu(1).type_as(q) + attn = attn * (g[..., None] - g[..., None, :] + causal_mask).exp().to(attn.dtype) + inner = attn @ v + return inner + +def chunk_gate_retention(q, k, v, g, chunk_size=64, last_hidden_state=None): + bsz, num_head, tgt_len, key_dim = q.shape + head_dim = v.shape[-1] + num_chunk = tgt_len // chunk_size + q = q.view(bsz, num_head, num_chunk, chunk_size, key_dim) + k = k.view(bsz, num_head, num_chunk, chunk_size, key_dim) * (key_dim ** -0.5) + v = v.view(bsz, num_head, num_chunk, chunk_size, head_dim) + g = g.view(bsz, num_head, num_chunk, chunk_size) + g = g.float().cumsum(-1) + cross = cross_chunk(q, k, v, g, last_hidden_state=last_hidden_state) + inner = inner_chunk(q, k, v, g) + o = cross + inner + return o.view(bsz, num_head, tgt_len, head_dim) + +# for long sequence parallelism +def hier_chunk_gate_retention(q, k, v, g, chunk_size=64, hier_chunk_size=16384): + bsz, num_head, tgt_len, key_dim = q.shape + head_dim = v.shape[-1] + num_hier_chunk = tgt_len // hier_chunk_size + assert tgt_len == num_hier_chunk * hier_chunk_size + + q = q.view(bsz, num_head, num_hier_chunk, hier_chunk_size, key_dim) + k = k.view(bsz, num_head, num_hier_chunk, hier_chunk_size, key_dim) + v = v.view(bsz, num_head, num_hier_chunk, hier_chunk_size, head_dim) + g = g.view(bsz, num_head, num_hier_chunk, hier_chunk_size) + hier_cross = cross_chunk(q, k * (key_dim ** -0.5), v, g.float().cumsum(-1)).view(bsz, num_head, tgt_len, head_dim) + + qi = q.transpose(1, 2).reshape(bsz * num_hier_chunk, num_head, hier_chunk_size, key_dim) + ki = k.transpose(1, 2).reshape(bsz * num_hier_chunk, num_head, hier_chunk_size, key_dim) + vi = v.transpose(1, 2).reshape(bsz * num_hier_chunk, num_head, hier_chunk_size, head_dim) + gi = g.transpose(1, 2).reshape(bsz * num_hier_chunk, num_head, hier_chunk_size) + inner_cross = chunk_gate_retention(qi, ki, vi, gi, chunk_size) + + inner_cross = inner_cross.view(bsz, num_hier_chunk, num_head, hier_chunk_size, head_dim).transpose(1, 2).reshape(bsz, num_head, tgt_len, head_dim) + o = hier_cross + inner_cross + return o + +def recurrent_gate_retention(q, k, v, g, incremental_state): + bsz, num_head, _, key_dim = q.shape + k *= key_dim ** -0.5 + g = g.view(bsz, num_head, 1, 1).float().exp() + kv = k.transpose(-1, -2) * v + if "last_hidden_state" in incremental_state: + prev_kv = incremental_state["last_hidden_state"] + kv += prev_kv * g.to(prev_kv.dtype) + + incremental_state["last_hidden_state"] = kv + o = q @ kv + return o + +def parallel_gate_retention(q, k, v, g): + k = k * (q.shape[-1] ** -0.5) + causal_mask = torch.full([q.shape[-2], q.shape[-2]], float("-inf"), device=q.device).triu(1).type_as(q) + g = g.float().cumsum(-1) + mask = g[..., None] - 
g[..., None, :] + causal_mask + mask = mask.exp() + + attn = q @ k.transpose(-1, -2) + attn = attn * mask.to(attn.dtype) + o = attn @ v + return o + +def naive_kv_recurrent(kv, cross_decay, last_kv=None): + BSZ, NUM_HEAD, NUM_BLOCK, D_MODEL_K, D_MODEL_V = kv.shape + kv_recurrent = [] + kv_state = torch.zeros(BSZ, NUM_HEAD, D_MODEL_K, D_MODEL_V, dtype=kv.dtype, device="cuda") if last_kv is None else last_kv + # accumulate kv by loop + for i in range(NUM_BLOCK): + kv_recurrent.append(kv_state) + kv_state = kv_state * cross_decay[:, :, i, None, None] + kv[:, :, i] + + kv_recurrent = torch.stack(kv_recurrent, dim=2) + return kv_recurrent + +chunk_gate_recurrent = ChunkGateRecurrent.apply + +def main(): + BSZ = 4 + NUM_HEAD = 4 + NUM_BLOCK = 16 + D_MODEL_K = 256 + D_MODEL_V = 432 + dtype = torch.float16 + kv = torch.randn(BSZ, NUM_HEAD, NUM_BLOCK, D_MODEL_K, D_MODEL_V, dtype=dtype, device="cuda") + last_kv = torch.randn(BSZ, NUM_HEAD, D_MODEL_K, D_MODEL_V, dtype=dtype, device="cuda") + kv_triton = kv.clone().detach() + last_kv_triton = last_kv.clone().detach() + cross_decay = torch.randn(BSZ, NUM_HEAD, NUM_BLOCK, dtype=dtype, device="cuda") + cross_decay = torch.sigmoid(cross_decay) + cross_decay_triton = cross_decay.clone().detach() + grad_weight = torch.randn(BSZ, NUM_HEAD, NUM_BLOCK, D_MODEL_K, D_MODEL_V, dtype=dtype, device="cuda") + kv.requires_grad = True + kv_triton.requires_grad = True + last_kv.requires_grad = True + last_kv_triton.requires_grad = True + cross_decay.requires_grad = True + cross_decay_triton.requires_grad = True + + start = time.time() + kv_recurrent = naive_kv_recurrent(kv, cross_decay, last_kv) + kv_recurrent.mul(grad_weight).sum().backward() + print("naive time:", time.time() - start) + + start = time.time() + kv_recurrent_triton = chunk_gate_recurrent(kv_triton, cross_decay_triton, last_kv_triton) + kv_recurrent_triton.mul(grad_weight).sum().backward() + print("triton time:", time.time() - start) + + print(torch.allclose(kv_recurrent, kv_recurrent_triton, atol=1e-3)) + print((kv_recurrent - kv_recurrent_triton).abs().max(), (kv_recurrent - kv_recurrent_triton).abs().mean()) + + print(torch.allclose(kv.grad, kv_triton.grad, atol=1e-3)) + print((kv.grad - kv_triton.grad).abs().max(), (kv.grad - kv_triton.grad).abs().mean()) + + print(torch.allclose(last_kv.grad, last_kv_triton.grad, atol=1e-3)) + print((last_kv.grad - last_kv_triton.grad).abs().max(), (last_kv.grad - last_kv_triton.grad).abs().mean()) + + print(torch.allclose(cross_decay.grad, cross_decay_triton.grad, atol=1e-3)) + print((cross_decay.grad - cross_decay_triton.grad).abs().max(), (cross_decay.grad - cross_decay_triton.grad).abs().mean()) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/YOCO/yoco/models/decoder/kernel/rotary.py b/YOCO/yoco/models/decoder/kernel/rotary.py new file mode 100644 index 00000000..8ee2cb93 --- /dev/null +++ b/YOCO/yoco/models/decoder/kernel/rotary.py @@ -0,0 +1,332 @@ +# Copyright (c) 2023, Tri Dao. 
+ +from typing import Optional, Union + +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 2}), +# triton.Config({"BLOCK_M": 4}), +# triton.Config({"BLOCK_M": 8}), +# triton.Config({"BLOCK_M": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT + X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load( + COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 + ).to(tl.float32) + sin = tl.load( + SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x0 = tl.load( + X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x1 = tl.load( + X + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) + tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + OUT + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right outputs for the even + # and for the odd indices. 
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + rk_repeat = tl.arange(0, BLOCK_K) // 2 + X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + COS, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + SIN, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( + tl.float32 + ) + x1 = tl.load( + X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 + ).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) + tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). + cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 + if rotary_dim <= 32 + else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + ) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
+ with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # key for triton cache (limit number of compilations) + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) + return output + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + # Can't save int with save_for_backward + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + return ApplyRotaryEmb.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) diff --git a/YOCO/yoco/models/decoder/kernel/swiglu.py b/YOCO/yoco/models/decoder/kernel/swiglu.py new file mode 100644 index 00000000..d57589d2 --- /dev/null +++ b/YOCO/yoco/models/decoder/kernel/swiglu.py @@ -0,0 +1,32 @@ +import torch + + +swiglu_fwd_codestring = """ +template T swiglu_fwd(T x, T y) { + return float(x) * float(y) / (1.0f + ::exp(-float(x))); +} +""" +swiglu_bwd_codestring = """ +template T swiglu_bwd(T x, T y, T g, T& dx, T& dy) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); + dy = float(x) * x_sigmoid * float(g); +} +""" +swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) +swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) + + +class SwiGLUFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return swiglu_fwd(x, y) + + @staticmethod + def backward(ctx, dout): + x, y = ctx.saved_tensors + return swiglu_bwd(x, y, dout) + +swiglu = SwiGLUFunction.apply diff --git a/YOCO/yoco/models/decoder/model_parallel_init.py b/YOCO/yoco/models/decoder/model_parallel_init.py new file mode 100644 index 00000000..3eb50a85 --- /dev/null +++ b/YOCO/yoco/models/decoder/model_parallel_init.py @@ -0,0 +1,16 @@ +import math + +import torch +import torch.nn as nn + +def init_method(tensor, **kwargs): + nn.init.kaiming_uniform_(tensor, a=math.sqrt(5)) + +def qkvg_init_method(tensor, **kwargs): + nn.init.xavier_uniform_(tensor, gain = 2 ** -2.5) + +def out_init_method(tensor, **kwargs): + nn.init.xavier_uniform_(tensor, gain = 2 ** -1) + +def vocab_init_method(tensor, **kwargs): + torch.nn.init.normal_(tensor, mean=0, std=tensor.shape[1] ** -0.5) diff --git a/YOCO/yoco/models/decoder/rms_norm.py b/YOCO/yoco/models/decoder/rms_norm.py new file mode 100644 index 00000000..fccb027e --- /dev/null +++ b/YOCO/yoco/models/decoder/rms_norm.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True): + super().__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter('weight', None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + + def extra_repr(self) -> str: + return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}' + \ No newline at end of file diff --git a/YOCO/yoco/models/decoder/sliding_window_attention.py b/YOCO/yoco/models/decoder/sliding_window_attention.py new file mode 100644 index 00000000..3d744956 --- /dev/null +++ b/YOCO/yoco/models/decoder/sliding_window_attention.py @@ -0,0 +1,68 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, +) + +from .model_parallel_init import init_method +from .kernel.rotary import apply_rotary_emb + +from flash_attn import flash_attn_func + +class SlidingWindowAttention(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + 
self.embed_dim = args.dim + self.num_heads = args.n_self_heads // args.model_parallel_size + self.window_size = args.sliding_window - 1 # compatible with flash attention + + self.head_dim = args.dim // args.n_self_heads + + self.q_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=init_method) + self.k_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=init_method) + self.v_proj = ColumnParallelLinear(args.dim, args.dim, bias=False, gather_output=False, init_method=init_method) + self.out_proj = RowParallelLinear(args.dim, args.dim, bias=False, input_is_parallel=True, init_method=init_method) + + def forward( + self, + x, + rel_pos, + start_pos=0, + incremental_state=None, + ): + bsz, tgt_len, embed_dim = x.size() + src_len = tgt_len + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim) + k = k.view(bsz, src_len, self.num_heads, self.head_dim) + v = v.view(bsz, src_len, self.num_heads, self.head_dim) + + q = apply_rotary_emb(q, *rel_pos, interleaved=True) + k = apply_rotary_emb(k, *rel_pos, interleaved=True) + if incremental_state is not None: + if "prev_key" not in incremental_state: + incremental_state["prev_key"] = torch.empty(self.args.max_batch_size, self.window_size, self.num_heads, self.head_dim, device=x.device, dtype=x.dtype) + incremental_state["prev_value"] = torch.empty(self.args.max_batch_size, self.window_size, self.num_heads, self.head_dim, device=x.device, dtype=x.dtype) + + key = torch.cat([incremental_state["prev_key"][:bsz, :start_pos], k], dim=1) + value = torch.cat([incremental_state["prev_value"][:bsz, :start_pos], v], dim=1) + if key.shape[1] > self.window_size: + incremental_state["prev_key"][:bsz] = key[:, -self.window_size:] + incremental_state["prev_value"][:bsz] = value[:, -self.window_size:] + else: + incremental_state["prev_key"][:bsz, start_pos : start_pos + tgt_len] = k + incremental_state["prev_value"][:bsz, start_pos : start_pos + tgt_len] = v + + attn = flash_attn_func(q, k, v, causal=True, window_size=(self.window_size - 1, 0)) + attn = attn.reshape(bsz, tgt_len, self.head_dim * self.num_heads) + + attn = self.out_proj(attn) + return attn \ No newline at end of file diff --git a/YOCO/yoco/models/decoder/transformer.py b/YOCO/yoco/models/decoder/transformer.py new file mode 100644 index 00000000..f41edf58 --- /dev/null +++ b/YOCO/yoco/models/decoder/transformer.py @@ -0,0 +1,251 @@ +import json +import math +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +import torch +from torch import nn + +from flash_attn import flash_attn_func + +from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, + copy_to_model_parallel_region, + VocabParallelEmbedding +) + +from fairscale.nn import checkpoint_wrapper + +from .rms_norm import RMSNorm +from .kernel.rotary import apply_rotary_emb +from .model_parallel_init import init_method, vocab_init_method + + +def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + return freqs + + +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + + max_batch_size: 
int = 0 + max_seq_len: int = -1 + model_parallel_size: int = 1 + load_checkpoint: bool = False + rope_theta: float = 10000.0 + sliding_window: Optional[int] = None + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.dim = args.dim + self.head_dim = args.head_dim + self.hidden_dim = args.n_heads * args.head_dim + self.key_value_dim = args.n_kv_heads * args.head_dim + self.n_heads = args.n_heads // args.model_parallel_size + self.n_kv_heads = args.n_kv_heads // args.model_parallel_size + self.activate_sliding_window = args.sliding_window is not None + self.cache_len = args.sliding_window - 1 if self.activate_sliding_window else args.max_seq_len + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + self.wq = ColumnParallelLinear(self.dim, self.hidden_dim, bias=False, gather_output=False, init_method=init_method) + self.wk = ColumnParallelLinear(self.dim, self.key_value_dim, bias=False, gather_output=False, init_method=init_method) + self.wv = ColumnParallelLinear(self.dim, self.key_value_dim, bias=False, gather_output=False, init_method=init_method) + self.wo = RowParallelLinear(self.hidden_dim, self.dim, bias=False, input_is_parallel=True, init_method=init_method) + + def forward( + self, + x: torch.Tensor, + rel_pos: Tuple[torch.Tensor, torch.Tensor], + start_pos: int, + incremental_state = None, + ) -> torch.Tensor: + bsz, seqlen, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xq = apply_rotary_emb(xq, *rel_pos) + xk = apply_rotary_emb(xk, *rel_pos) + if incremental_state is not None: + if "cache_k" not in incremental_state: + incremental_state["cache_k"] = torch.zeros( + ( + self.args.max_batch_size, + self.cache_len, + self.n_kv_heads, + self.head_dim, + ) + ).to(xk) + incremental_state["cache_v"] = torch.zeros( + ( + self.args.max_batch_size, + self.cache_len, + self.n_kv_heads, + self.head_dim, + ) + ).to(xv) + key = torch.cat([incremental_state["cache_k"][:, :start_pos], xk], dim=1) + value = torch.cat([incremental_state["cache_v"][:, :start_pos], xv], dim=1) + if key.shape[1] > self.cache_len: + incremental_state["cache_k"][:bsz] = key[:, -self.cache_len:] + incremental_state["cache_v"][:bsz] = value[:, -self.cache_len:] + else: + incremental_state["cache_k"][:bsz, start_pos : start_pos + seqlen] = xk + incremental_state["cache_v"][:bsz, start_pos : start_pos + seqlen] = xv + else: + key, value = xk, xv + + output = flash_attn_func(xq, key, value, causal=True, window_size=(self.args.sliding_window - 1, 0) if self.activate_sliding_window else (-1, -1)) + + return self.wo(output.view(bsz, seqlen, self.n_heads * self.head_dim)) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.w1 = ColumnParallelLinear(args.dim, args.hidden_dim, bias=False, gather_output=False, init_method=init_method) + self.w2 = RowParallelLinear(args.hidden_dim, args.dim, bias=False, input_is_parallel=True, init_method=init_method) + self.w3 = ColumnParallelLinear(args.dim, args.hidden_dim, bias=False, gather_output=False, init_method=init_method) + + def forward(self, x) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = 
args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + self.feed_forward: nn.Module + self.feed_forward = FeedForward(args=args) + + def forward( + self, x: torch.Tensor, rel_pos: Tuple[torch.Tensor, torch.Tensor], start_pos: int, incremental_state = None + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), rel_pos, start_pos, incremental_state) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +class Transformer(nn.Module): + def __init__( + self, + args: ModelArgs, + mp_rank: int = 0, + checkpoint_activations: bool = False + ): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + self._precomputed_freqs_cis: Optional[torch.Tensor] = None + self._window_precomputed_freqs_cis: Optional[torch.Tensor] = None + self._global_precomputed_freqs_cis: Optional[torch.Tensor] = None + assert self.vocab_size > 0 + self.mp_rank = mp_rank + self.checkpoint_activations = checkpoint_activations + self.tok_embeddings = VocabParallelEmbedding( + args.vocab_size, args.dim, -1, init_method=vocab_init_method + ) + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size // args.model_parallel_size, bias=False) + # Initialize all layers but slice off those not of this rank. + layers = [TransformerBlock(args=args) for idx in range(args.n_layers)] + if checkpoint_activations: + layers = [checkpoint_wrapper(layer) for layer in layers] + self.layers = nn.ModuleList(layers) + self.n_local_layers = len(self.layers) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def build_rel_pos(self, x, start_pos): + if self._precomputed_freqs_cis is None: + theta = self.args.rope_theta + self._precomputed_freqs_cis = precompute_freqs_cis( + self.args.head_dim, self.args.max_seq_len, theta + ) + if self._precomputed_freqs_cis.device != self.device: + self._precomputed_freqs_cis = self._precomputed_freqs_cis.to( + device=self.device + ) + cos = torch.cos(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)]) + sin = torch.sin(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)]) + rel_pos = (cos.to(x.dtype), sin.to(x.dtype)) + return rel_pos + + def forward_partial( + self, + input_ids: torch.Tensor, + start_pos: Optional[int] = 0, + incremental_state = None, + ) -> torch.Tensor: + h = self.tok_embeddings(input_ids) + rel_pos = self.build_rel_pos(h, start_pos) + for local_layer_id, layer in enumerate(self.layers): + if incremental_state is not None: + if local_layer_id not in incremental_state: + incremental_state[local_layer_id] = {} + h = layer(h, rel_pos, start_pos, incremental_state=incremental_state[local_layer_id] if incremental_state is not None else None) + + return self.norm(h) + + def forward( + self, + input_ids: torch.Tensor, + start_pos: Optional[int] = 0, + incremental_state = None, + ) -> torch.Tensor: + h = self.forward_partial(input_ids, start_pos, incremental_state) + if self.args.model_parallel_size > 1: + h = copy_to_model_parallel_region(h) + outs = self.output(h) + return outs.float(), None + + def load_state_dict(self, state_dict, strict=False, assign=False): + state_to_load = {} + for k, v in state_dict.items(): + if k.startswith("tok_embeddings") or 
k.startswith("output"): + state_to_load[k] = v.view(self.args.model_parallel_size, self.vocab_size // self.args.model_parallel_size, self.args.dim)[self.mp_rank] + elif "wq" in k or "wk" in k or "wv" in k or "w1" in k or "w3" in k: + state_to_load[k] = v.view(self.args.model_parallel_size, -1, v.shape[1])[self.mp_rank] + elif "wo" in k or "w2" in k: + state_to_load[k] = v.view(v.shape[0], self.args.model_parallel_size, -1)[:, self.mp_rank] + else: + state_to_load[k] = v + super().load_state_dict(state_to_load, strict=False, assign=assign) + print("Loaded state dict from checkpoint.") diff --git a/YOCO/yoco/models/decoder/yoco.py b/YOCO/yoco/models/decoder/yoco.py new file mode 100644 index 00000000..6fb0d01e --- /dev/null +++ b/YOCO/yoco/models/decoder/yoco.py @@ -0,0 +1,294 @@ +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairscale.nn import checkpoint_wrapper + +from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + copy_to_model_parallel_region, + VocabParallelEmbedding +) + +from .gate_retention import GateRetention +from .sliding_window_attention import SlidingWindowAttention +from .cross_attention import CrossAttention +from .feedforward_network import FeedForwardNetwork, init_method +from .rms_norm import RMSNorm + +from .kernel.rotary import apply_rotary_emb +from .model_parallel_init import vocab_init_method, init_method + + +@dataclass +class YOCOArgs: + dim: int + n_layers: int + hidden_dim: int + n_self_heads: int + n_attn_heads: int + n_attn_kv_heads: int + vocab_size: int + + max_batch_size: int = 0 + max_seq_len: int = -1 + model_parallel_size: int = 1 + load_checkpoint: bool = False + rope_theta: float = 10000.0 + norm_eps: float = 1e-5 + sliding_window: Optional[int] = None + +class DecoderLayer(nn.Module): + def __init__( + self, + args: YOCOArgs, + is_cross_layer=False + ): + super().__init__() + self.args = args + self.is_cross_layer = is_cross_layer + + if is_cross_layer: + self.mixer = CrossAttention(args) + elif args.sliding_window is not None: + self.mixer = SlidingWindowAttention(args) + else: + self.mixer = GateRetention(args) + + self.mixer_layer_norm = RMSNorm(args.dim, eps=args.norm_eps) + + self.ffn = FeedForwardNetwork( + args.dim, + args.hidden_dim, + args.load_checkpoint + ) + + self.final_layer_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x, + start_pos=0, + key=None, + value=None, + rel_pos=None, + incremental_state=None, + is_prefilling=False, + ): + residual = x + x = self.mixer_layer_norm(x) + + if self.is_cross_layer: + x = self.mixer( + x, + key, + value, + rel_pos=rel_pos, + ) + elif self.args.sliding_window is not None: + x = self.mixer( + x, + rel_pos=rel_pos, + start_pos=start_pos, + incremental_state=incremental_state, + ) + else: + x = self.mixer( + x, + rel_pos=rel_pos, + incremental_state=incremental_state, + is_prefilling=is_prefilling,) + + x = x + residual + residual = x + x = self.final_layer_norm(x) + + x = self.ffn(x) + + x = x + residual + return x + +class SelfDecoder(nn.Module): + def __init__( + self, + args: YOCOArgs, + checkpoint_activations: bool = False + ): + super().__init__() + self.args = args + layers = [DecoderLayer(args, is_cross_layer=False,) for idx in range(args.n_layers // 2)] + if checkpoint_activations: + layers = [checkpoint_wrapper(layer) for layer in layers] + self.layers = nn.ModuleList(layers) + self.head_dim = args.dim // args.n_self_heads + self.block_size 
= 256 + self._precomputed_freqs_cis = None + + def build_rel_pos(self, x, start_pos): + if self._precomputed_freqs_cis is None: + angle = 1.0 / (self.args.rope_theta ** torch.linspace(0, 1, self.head_dim // 2, dtype=torch.float, device=x.device)) + index = torch.arange(self.args.max_seq_len).to(angle) + self._precomputed_freqs_cis = index[:, None] * angle + + cos = torch.cos(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)]) + sin = torch.sin(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)]) + rel_pos = (cos.to(x.dtype), sin.to(x.dtype)) + return rel_pos + + def get_index_mask(self, x, length, pad_length): + return torch.arange(pad_length, device=x.device) >= length + + def forward( + self, + x, + incremental_state=None, + is_prefilling=False, + start_pos=0 + ): + if is_prefilling and x.size(1) % self.block_size != 0 and self.args.sliding_window is None: + padding_len = self.block_size - x.size(1) % self.block_size + x = F.pad(x, (0, 0, 0, padding_len), value=0) + else: + padding_len = 0 + + if incremental_state is not None and is_prefilling: + index_mask = self.get_index_mask(x, x.size(1) - padding_len, x.size(1)) + + rel_pos = self.build_rel_pos(x, start_pos) + for idx, layer in enumerate(self.layers): + if incremental_state is not None: + if idx not in incremental_state: + incremental_state[idx] = {} + if is_prefilling: + incremental_state[idx]["index_mask"] = index_mask + x = layer( + x, + start_pos=start_pos, + rel_pos=rel_pos, + incremental_state=incremental_state[idx] if incremental_state is not None else None, + is_prefilling=is_prefilling,) + + x = x[:, :x.size(1) - padding_len, :] + return x + +class CrossDecoder(nn.Module): + def __init__( + self, + args: YOCOArgs, + checkpoint_activations: bool = False + ): + super().__init__() + self.args = args + self.num_heads = args.n_attn_kv_heads + self.head_dim = args.dim // args.n_attn_heads + self.k_proj = ColumnParallelLinear(args.dim, self.head_dim * args.n_attn_kv_heads, bias=False, gather_output=False, init_method=init_method) + self.v_proj = ColumnParallelLinear(args.dim, self.head_dim * args.n_attn_kv_heads, bias=False, gather_output=False, init_method=init_method) + self.kv_layer_norm = RMSNorm(args.dim, eps=args.norm_eps) + layers = [DecoderLayer(args, is_cross_layer=True) for idx in range(args.n_layers // 2)] + if checkpoint_activations: + layers = [checkpoint_wrapper(layer) for layer in layers] + self.layers = nn.ModuleList(layers) + self._precomputed_freqs_cis = None + + def build_rel_pos(self, x, start_pos): + if self._precomputed_freqs_cis is None: + angle = 1.0 / (self.args.rope_theta ** torch.linspace(0, 1, self.head_dim // 2, dtype=torch.float, device=x.device)) + index = torch.arange(self.args.max_seq_len).to(angle) + self._precomputed_freqs_cis = index[:, None] * angle + + cos = torch.cos(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)]) + sin = torch.sin(self._precomputed_freqs_cis[start_pos:start_pos+x.size(1)]) + rel_pos = (cos.to(x.dtype), sin.to(x.dtype)) + return rel_pos + + def forward( + self, + x, + incremental_state=None, + start_pos=0, + skip_cross_decoder=False, + ): + bsz, seqlen, embed_dim = x.size() + x_norm = self.kv_layer_norm(x) + key, value = self.k_proj(x_norm), self.v_proj(x_norm) + key = key.view(bsz, seqlen, self.num_heads, self.head_dim) + value = value.view(bsz, seqlen, self.num_heads, self.head_dim) + rel_pos = self.build_rel_pos(x, start_pos) + key = apply_rotary_emb(key, *rel_pos, interleaved=True) + if incremental_state is not None: + if "prev_key" not in 
incremental_state: + incremental_state["prev_key"] = torch.empty(bsz, self.args.max_seq_len, self.num_heads, self.head_dim, device=x.device, dtype=x.dtype) + incremental_state["prev_value"] = torch.empty(bsz, self.args.max_seq_len, self.num_heads, self.head_dim, device=x.device, dtype=x.dtype) + incremental_state["prev_key"][:, start_pos : start_pos + seqlen] = key + incremental_state["prev_value"][:, start_pos : start_pos + seqlen] = value + key = incremental_state["prev_key"][:, : start_pos + seqlen] + value = incremental_state["prev_value"][:, : start_pos + seqlen] + + if skip_cross_decoder: + return torch.zeros(bsz, 1, embed_dim, device=x.device, dtype=x.dtype) + for layer in self.layers: + x = layer( + x, + key=key, + value=value, + rel_pos=rel_pos) + + return x + +class YOCO(nn.Module): + def __init__( + self, + args: YOCOArgs, + checkpoint_activations: bool = False, + share_input_output_embed: bool = False, + ): + super().__init__() + self.args = args + self.embed_scale = math.sqrt(args.dim) + self.embed_tokens = VocabParallelEmbedding( + args.vocab_size, args.dim, -1, init_method=vocab_init_method + ) + self.output_projection = nn.Linear(args.dim, args.vocab_size, bias=False) + if share_input_output_embed: + self.output_projection.weight = self.embed_tokens.weight + + self.self_decoder = SelfDecoder(args, checkpoint_activations) + self.cross_decoder = CrossDecoder(args, checkpoint_activations) + self.layer_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x, + start_pos=0, + incremental_state=None, + is_prefilling=True, + skip_cross_decoder=False + ): + x = self.embed_scale * self.embed_tokens(x) + + x = self.self_decoder( + x, + incremental_state=incremental_state, + is_prefilling=is_prefilling, + start_pos=start_pos, + ) + + x = self.cross_decoder( + x, + start_pos=start_pos, + incremental_state=incremental_state, + skip_cross_decoder=skip_cross_decoder, + ) + + x = self.layer_norm(x) + x = self.output_layer(x) + + return x, None + + def output_layer(self, features): + if self.args.model_parallel_size > 1: + features = copy_to_model_parallel_region(features) + return self.output_projection(features) \ No newline at end of file diff --git a/YOCO/yoco/models/transformer.py b/YOCO/yoco/models/transformer.py new file mode 100644 index 00000000..3fcaa78d --- /dev/null +++ b/YOCO/yoco/models/transformer.py @@ -0,0 +1,141 @@ +import json +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch + +from fairseq.model_parallel.megatron.mpu import ( + initialize_model_parallel, + model_parallel_is_initialized, + get_model_parallel_rank +) + +from fairseq.dataclass import FairseqDataclass +from fairseq.models import ( + FairseqIncrementalDecoder, + FairseqLanguageModel, + register_model, + register_model_architecture, +) + +from omegaconf import II + +from .decoder.transformer import ModelArgs, Transformer + +DEFAULT_MAX_TARGET_POSITIONS = 4096 + +@dataclass +class LanguageConfig(FairseqDataclass): + llama_model: Optional[str] = field( + default=None, + metadata={"help": "path to load tokenizer and config"}, + ) + load_ckpt: Optional[str] = field( + default=None, + metadata={"help": "path to load checkpoint from"}, + ) + init_from_config: bool = field( + default=False, + ) + dim: int = field( + default=1024, + ) + n_layers: int = field( + default=8, + ) + n_heads: int = field( + default=8, + ) + n_kv_heads: int = field( + default=2, + ) + batch_size: int = field( + default=1, + ) + rope_theta: Optional[float] = field( + 
default=10000.0, + ) + checkpoint_activations: bool = field( + default=False, metadata={"help": "checkpoint activations at each layer"} + ) + tokens_per_sample: int = II("task.tokens_per_sample") + model_parallel_size: int = II("common.model_parallel_size") + +@register_model("llama", dataclass=LanguageConfig) +class LanguageModel(FairseqLanguageModel): + def __init__(self, args, decoder, tokenizer): + self.args = args + self.tokenizer = tokenizer + super().__init__(decoder) + + @classmethod + def build_model(cls, args, task): + if not model_parallel_is_initialized(): + initialize_model_parallel(args.model_parallel_size) + + if not args.init_from_config: + params = { + "dim": args.dim, + "n_layers": args.n_layers, + "n_heads": args.n_heads, + "head_dim": args.dim // args.n_heads, + "n_kv_heads": args.n_kv_heads, + "hidden_dim": int(args.dim * 8 / 3), + "vocab_size": task.tokenizer.n_words, + "max_batch_size": args.batch_size, + "max_seq_len": args.tokens_per_sample, + "model_parallel_size": args.model_parallel_size, + "load_checkpoint": args.load_ckpt is not None, + "rope_theta": args.rope_theta, + } + model_args: ModelArgs = ModelArgs( + **params, + ) + else: + with open(os.path.join(args.llama_model, "params.json"), "r") as f: + params = json.load(f) + model_args = ModelArgs(**params) + model_args.max_batch_size = args.batch_size + model_args.max_seq_len = args.tokens_per_sample + model_args.model_parallel_size = args.model_parallel_size + model_args.load_checkpoint = args.load_ckpt is not None + model = Transformer( + model_args, + mp_rank=get_model_parallel_rank(), + checkpoint_activations=args.checkpoint_activations, + ) + if args.load_ckpt is not None: + loaded = torch.load(args.load_ckpt, mmap=True) + model.load_state_dict(loaded, assign=True) + + model = LLaMA(model) + return cls(args, model, task.tokenizer) + +class LLaMA(FairseqIncrementalDecoder): + def __init__(self, model): + super().__init__(None) + self.model = model + + def forward(self, src_tokens, start_pos = 0, **kwargs): + padding = src_tokens < 0 + src_tokens = torch.where(padding, torch.zeros_like(src_tokens), src_tokens) + return self.model.forward(src_tokens, start_pos, **kwargs) + + def max_positions(self): + return self.model.args.max_seq_len + +@register_model_architecture("llama", "llama_from_scratch") +def llama_from_scratch(args): + args.init_from_config = getattr(args, "init_from_config", False) + args.dim = getattr(args, "dim", 1024) + args.n_layers = getattr(args, "n_layers", 8) + args.n_heads = getattr(args, "n_heads", 8) + args.n_kv_heads = getattr(args, "n_kv_heads", 2) + +@register_model_architecture("llama", "llama_from_ckpt") +def llama_from_ckpt(args): + args.init_from_config = getattr(args, "init_from_config", True) + + + \ No newline at end of file diff --git a/YOCO/yoco/models/yoco.py b/YOCO/yoco/models/yoco.py new file mode 100644 index 00000000..580d15bd --- /dev/null +++ b/YOCO/yoco/models/yoco.py @@ -0,0 +1,158 @@ +import os +import json +import logging +from dataclasses import dataclass, field +from typing import Optional + +import torch +from fairseq import distributed_utils, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.models import ( + FairseqIncrementalDecoder, + FairseqLanguageModel, + register_model, + register_model_architecture, +) + +from omegaconf import II + +from fairseq.model_parallel.megatron.mpu import ( + initialize_model_parallel, + model_parallel_is_initialized +) +from .decoder.yoco import YOCO, YOCOArgs + +DEFAULT_MAX_TARGET_POSITIONS = 4096 +logger 
= logging.getLogger(__name__) + + +@dataclass +class LanguageConfig(FairseqDataclass): + yoco_model: Optional[str] = field( + default=None, + metadata={"help": "path to load params from"}, + ) + load_ckpt: Optional[str] = field( + default=None, + metadata={"help": "path to load checkpoint from"}, + ) + dim: int = field( + default=1024, + ) + hidden_dim: int = field( + default=3072, + ) + n_layers: int = field( + default=24, + ) + n_self_heads: int = field( + default=4, + ) + n_attn_heads: int = field( + default=8, + ) + n_attn_kv_heads: Optional[int] = field( + default=None, + ) + batch_size: int = field( + default=1, + ) + share_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + sliding_window: Optional[bool] = field( + default=None, + ) + rope_theta: Optional[float] = field( + default=10000.0, + ) + checkpoint_activations: bool = field( + default=False, metadata={"help": "checkpoint activations at each layer"} + ) + tokens_per_sample: int = II("task.tokens_per_sample") + model_parallel_size: int = II("common.model_parallel_size") + + +@register_model("yoco", dataclass=LanguageConfig) +class LanguageModel(FairseqLanguageModel): + def __init__(self, args, decoder, tokenizer): + self.args = args + self.tokenizer = tokenizer + super().__init__(decoder) + + @classmethod + def build_model(cls, args, task): + if not model_parallel_is_initialized(): + initialize_model_parallel(args.model_parallel_size) + + if args.yoco_model is None: + params = { + "dim": args.dim, + "n_layers": args.n_layers, + "n_self_heads": args.n_self_heads, + "n_attn_heads": args.n_attn_heads, + "n_attn_kv_heads": args.n_attn_kv_heads, + "hidden_dim": args.hidden_dim, + "vocab_size": task.tokenizer.n_words, + "max_batch_size": args.batch_size, + "max_seq_len": args.tokens_per_sample, + "model_parallel_size": args.model_parallel_size, + "load_checkpoint": args.load_ckpt is not None, + "rope_theta": args.rope_theta, + } + model_args: YOCOArgs = YOCOArgs( + **params, + ) + else: + with open(os.path.join(args.yoco_model, "params.json"), "r") as f: + params = json.load(f) + model_args = YOCOArgs(**params) + model_args.max_batch_size = args.batch_size + model_args.max_seq_len = args.tokens_per_sample + model_args.model_parallel_size = args.model_parallel_size + model_args.load_checkpoint = args.load_ckpt is not None + + model = YOCO( + model_args, + checkpoint_activations=args.checkpoint_activations, + ) + if args.load_ckpt is not None: + loaded = torch.load(args.load_ckpt, mmap=True) + model.load_state_dict(loaded, assign=True) + model = YOCOModel(model) + return cls(args, model, task.tokenizer) + +class YOCOModel(FairseqIncrementalDecoder): + def __init__(self, model): + super().__init__(None) + self.model = model + + def forward(self, src_tokens, **kwargs): + return self.model.forward(src_tokens, **kwargs) + + def max_positions(self): + return self.model.args.max_seq_len + +def default(args): + args.n_attn_kv_heads = getattr(args, "n_attn_kv_heads", args.n_attn_heads) + args.sliding_window = getattr(args, "sliding_window", False) + args.rope_theta = getattr(args, "rope_theta", 10000.0) + args.share_input_output_embed = getattr( + args, "share_input_output_embed", False + ) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + + +@register_model_architecture("yoco", "yoco_3b") +def yoco_3b(args): + args.dim = getattr(args, "dim", 3072) + args.hidden_dim = getattr(args, "hidden_dim", 8192) + args.n_layers = getattr(args, "n_layers", 
26) + args.n_self_heads = getattr(args, "n_self_heads", 24) + args.n_attn_heads = getattr(args, "n_attn_heads", 24) + args.n_attn_kv_heads = getattr(args, "n_attn_kv_heads", 8) + default(args) + + + + diff --git a/YOCO/yoco/tasks/__init__.py b/YOCO/yoco/tasks/__init__.py new file mode 100644 index 00000000..1da9d123 --- /dev/null +++ b/YOCO/yoco/tasks/__init__.py @@ -0,0 +1,32 @@ +import argparse +import importlib +import os + +# register dataclass +TASK_DATACLASS_REGISTRY = {} +TASK_REGISTRY = {} +TASK_CLASS_NAMES = set() + +# automatically import any Python files in the tasks/ directory +tasks_dir = os.path.dirname(__file__) +for file in os.listdir(tasks_dir): + path = os.path.join(tasks_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + task_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("tasks." + task_name) + + # expose `task_parser` for sphinx + if task_name in TASK_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_task = parser.add_argument_group("Task name") + # fmt: off + group_task.add_argument('--task', metavar=task_name, + help='Enable this task with: ``--task=' + task_name + '``') + # fmt: on + group_args = parser.add_argument_group("Additional command-line arguments") + TASK_REGISTRY[task_name].add_args(group_args) + globals()[task_name + "_parser"] = parser diff --git a/YOCO/yoco/tasks/data/__init__.py b/YOCO/yoco/tasks/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/YOCO/yoco/tasks/data/basic_loader.py b/YOCO/yoco/tasks/data/basic_loader.py new file mode 100644 index 00000000..d6f06f2a --- /dev/null +++ b/YOCO/yoco/tasks/data/basic_loader.py @@ -0,0 +1,75 @@ +import torch +from infinibatch.iterators import CheckpointableIterator + +from . 
import utils + + +class BaseBatchGen(CheckpointableIterator): + """ + This is a base class for batch generators that use infinibatch + """ + + def __init__(self): + self._iter = None + self.epoch = 1 + self.next_epoch_idx = 1 + self.sharded_checkpoint = True + self.should_close_after_finished = True + + def _build_iter(self): + """ + Build infinibatch iterator and assign to self._iter + """ + raise NotImplementedError() + + def _move_to_tensor(self, batch): + def to_tensor(x): + return torch.tensor(x) + + return utils.apply_to_sample(to_tensor, batch) + + @property + def iterator(self): + if self._iter is None: + raise NotImplementedError("_build_iter() must called first") + return self._iter + + def __iter__(self): + if self._iter is None: + raise NotImplementedError("_build_iter() must called first") + return self._iter + + def __next__(self): + return next(self._iter) + + def setstate(self, value): + self._iter.setstate(value) + + def getstate(self): + return self._iter.getstate() + + def close(self): + self._iter.close() + + def __len__(self) -> int: + return 819200000 + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + return self + + def end_of_epoch(self) -> bool: + return False + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + return self.getstate() + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + self.setstate(state_dict) + + @property + def first_batch(self): + return "DUMMY" diff --git a/YOCO/yoco/tasks/data/llama_tokenizer.py b/YOCO/yoco/tasks/data/llama_tokenizer.py new file mode 100644 index 00000000..fad3d206 --- /dev/null +++ b/YOCO/yoco/tasks/data/llama_tokenizer.py @@ -0,0 +1,38 @@ +from pathlib import Path +from sentencepiece import SentencePieceProcessor +from typing import List + + +class LLaMATokenizer: + def __init__(self, model_path: str): + assert Path(model_path).exists(), model_path + self._model = SentencePieceProcessor(model_file=model_path) + assert self._model.vocab_size() == self._model.get_piece_size() + + @property + def n_words(self) -> int: + return self._model.vocab_size() + + @property + def bos_id(self) -> int: + return self._model.bos_id() + + @property + def eos_id(self) -> int: + return self._model.eos_id() + + @property + def pad_id(self) -> int: + return self._model.pad_id() + + def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]: + assert isinstance(s, str) + t = self._model.encode(s) + if bos: + t = [self.bos_id, *t] + if eos: + t = [*t, self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + return self._model.decode(t) \ No newline at end of file diff --git a/YOCO/yoco/tasks/data/lm_loader.py b/YOCO/yoco/tasks/data/lm_loader.py new file mode 100644 index 00000000..825a8223 --- /dev/null +++ b/YOCO/yoco/tasks/data/lm_loader.py @@ -0,0 +1,303 @@ +import os +import random +import math +import numpy as np +import json + +from infinibatch import iterators +from .utils import FixedBlockwiseShuffleIterator, NativeCheckpointableIterator, WeightNoRandomStateIterator +from .basic_loader import BaseBatchGen + + +class LMLoader(BaseBatchGen): + def __init__( + self, + args, + dataset, + tokenizer, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + epoch=1, + num_shards=1, + shard_id=0, + reject_sampling=1, + ): + super().__init__() + self.args = args + self.data = 
dataset.data + self.data_dir = dataset.data_dir + self.shuffle = dataset.shuffle + self.tokenizer = tokenizer + + self.max_tokens = max_tokens + self.max_sentences = max_sentences + self.max_positions = max_positions + self.tokens_per_sample = args.tokens_per_sample + self.mlm_cut_length = getattr(args, "mlm_cut_length", 0) + self.mlm_tokens_proportion = getattr(args, "mlm_tokens_proportion", 0) + self.pad_to_max_len = getattr(args, "pad_to_max_len", False) + self.ignore_invalid_inputs = ignore_invalid_inputs + self.required_batch_size_multiple = required_batch_size_multiple + self.seed = str(seed) + self.epoch = epoch + self.num_shards = num_shards + self.shard_id = shard_id + + self.batch_read_ahead = args.batch_read_ahead + self.sharded_checkpoint = True + + self._build_iter() + + def _build_iter(self): + tokenized_lines = self._tokenize() + self.padded_batches = self._batchify(tokenized_lines) + + prefetch_batches = iterators.PrefetchIterator( + self.padded_batches, + buffer_size=10, + buffer_in_main_process=True, + log_empty_buffer_warning=True and self.shard_id == 0, + ) + + prefetch_batches = iterators.MapIterator( + prefetch_batches, self._move_to_tensor + ) + + self._iter = prefetch_batches + + def _tokenize(self): + ''' + data: + { + 'source': list[Path], + } + ''' + dataset = list(zip(self.data['source'])) + + if self.shuffle: + chunk_files = \ + iterators.InfinitePermutationSourceIterator( + dataset, + seed=self.seed, + shuffle=self.shuffle, + num_instances=self.num_shards, + instance_rank=self.shard_id, + ) + else: + chunk_files = \ + iterators.ChunkedSourceIterator( + dataset, + num_instances=self.num_shards, + instance_rank=self.shard_id, + ) + + tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files)) + tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed) + + return tokenized_lines + + def getstate(self): + state = super().getstate() + state["epoch"] = self.epoch + state["iterations_in_epoch"] = None + return state + + def _batchify(self, lines): + + if self.max_sentences is not None: + if self.batch_read_ahead > 0: + lines = FixedBlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed) + batches = iterators.FixedBatchIterator(lines, self.max_sentences) + else: + # - + def dynamic_batch_size(sample): + lengths = [len(x) for x in sample] + batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple + return max(1, batch_size) + + batches = iterators.BucketedReadaheadBatchIterator( + lines, + read_ahead=self.batch_read_ahead, + key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None, + batch_size=dynamic_batch_size, + shuffle=self.shuffle, + seed=self.seed, + ) + + def collate(batch): + batch_size = len(batch) + gpt_max_length = max([len(x[0]) for x in batch]) + if self.pad_to_max_len: + gpt_max_length = self.tokens_per_sample + 1 + + gpt_source_ids = np.full(shape=(batch_size, gpt_max_length-1), dtype=np.int32, + fill_value=self.tokenizer.pad_id) + gpt_target_ids = np.full(shape=(batch_size, gpt_max_length-1), dtype=np.int32, + fill_value=self.tokenizer.pad_id) + gpt_input_mask_all = np.full(shape=(batch_size, gpt_max_length-1), dtype=np.int32, fill_value=0) + gpt_loss_mask_all = np.full(shape=(batch_size, gpt_max_length-1), dtype=np.int32, fill_value=1) + + for i, (gpt_ids, gpt_input_mask, gpt_loss_mask) in enumerate(batch): + gpt_source_ids[i, :len(gpt_ids)-1] = gpt_ids[:-1] + gpt_target_ids[i, 
:len(gpt_ids)-1] = gpt_ids[1:] + gpt_input_mask_all[i, :len(gpt_ids)-1] = gpt_input_mask[:-1] + gpt_loss_mask_all[i, :len(gpt_ids)-1] = gpt_loss_mask[1:] + + ret_batch = { + 'net_input': { + 'src_tokens': gpt_source_ids.astype(np.int64), + }, + 'target': gpt_target_ids.astype(np.int64), + 'nsentences': batch_size, + 'ntokens': sum([len(x[0]) for x in batch]), + } + + return ret_batch + + padded_batches = iterators.MapIterator( + batches, collate + ) + + return padded_batches + + def _prepare(self, doc): + gpt_input_mask = [0] * len(doc) + gpt_loss_mask = [1] * len(doc) + full_tokens = doc + return full_tokens, gpt_input_mask, gpt_loss_mask + + def _tokenize(self): + multilingual_iters = [] + weights = [] + + for data in self.data: + multilingual_iters.append( + self._tokenize_foreach_lang(data) + ) + if 'weight' in data: + weights.append(float(data['weight'])) + else: + weights.append(int(data['count'])) + + if len(multilingual_iters) == 1: + return multilingual_iters[0] + + sampling_iterator = WeightNoRandomStateIterator(weights, self.seed) + control_iterator = NativeCheckpointableIterator(sampling_iterator) + tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters) + + return tokenized_lines + + def _tokenize_foreach_lang(self, data): + # if 'epoch' in data: + _random = random.Random(self.seed) + if 'source' not in data or len(data['source']) == 0: + # load source from single file, format: self.data_dir/json/{name}.json + file_path = os.path.join(self.data_dir, 'json', f"{data['name']}.json") + if not os.path.exists(file_path): + raise FileNotFoundError(f"file {file_path} not exists") + with open(file_path, 'r', encoding='utf8') as f: + data_source = json.load(f) + data['source'] = data_source + data_source = data['source'] + epoch_num = 50 + temp_list = math.ceil(epoch_num) * data_source + _random.shuffle(temp_list) + dataset = list(zip(temp_list)) + # print('data name: ', data['name'], 'len(dataset): ', len(dataset)) + chunk_files = iterators.ChunkedSourceIterator( + dataset, + num_instances=self.num_shards, + instance_rank=self.shard_id,) + + tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files)) + tokenized_lines = iterators.MapIterator(tokenized_lines, self._prepare) + + return tokenized_lines + + @staticmethod + def _doc_to_ids(text, tokenizer=None): + tokenized_ids = [] # list of list of ids + lines = text.split('\n\n') + for line_idx, line in enumerate(lines): + suffix = '\n\n' if line_idx != len(lines) - 1 else '' + if len(line) == 0: + continue + + sublines = line.split('\n') + for idx, subline in enumerate(sublines): + if len(subline) > 200000: + continue + if len(subline) == 0: + continue + if idx == len(sublines) - 1: + tokenized_ids.append(tokenizer.encode(subline + suffix)) + else: + tokenized_ids.append(tokenizer.encode(subline + '\n')) + + tokenized_ids[-1].append(tokenizer.eos_id) + return tokenized_ids + + def _read_lines(self, file_path): + try: + with open(file_path, 'r', encoding='utf8') as f: + lines = f.read().strip().split('\n') + except: + return iter([]) # skip bad file + return lines + + def _read_from_files(self, source_file): + data = [] + if self.args.absolute_path: + file_path = source_file + else: + file_path = os.path.join(self.data_dir, source_file) + + if not os.path.exists(file_path): + print('| file {} not exists'.format(file_path), flush=True) + return iter([]) # skip bad file + + lines = self._read_lines(file_path) + + tokenized_ids = [] + for doc_jsonstr in lines: + try: 
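+                # One jsonl line per document; accept the common field names
+                # ("text", "content", "raw_content_lines") and skip any line that
+                # fails to parse or tokenize.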
+ json_obj = json.loads(doc_jsonstr) + + if 'text' in json_obj: + text = json_obj['text'] + elif 'content' in json_obj: + text = json_obj['content'] + elif 'raw_content_lines' in json_obj: + text = "\n".join(json_obj['raw_content_lines']) + else: + print('no text in json_obj') + + if len(text) == 0: + continue + ret = LMLoader._doc_to_ids(text, self.tokenizer) + tokenized_ids.extend(ret) + except Exception as e: + print(source_file, flush=True) + print(e, flush=True) + + # ################################################### + + doc = [self.tokenizer.bos_id] + for ids in tokenized_ids: + if len(doc) + len(ids) > self.tokens_per_sample + 1: + doc.extend(ids) + doc = doc[:self.tokens_per_sample + 1] + data.append(doc) + doc = [self.tokenizer.bos_id] + else: + doc.extend(ids) + + # if len(doc) > 1 and len(doc) <= self.tokens_per_sample + 1: + # data.append(doc) + return data + diff --git a/YOCO/yoco/tasks/data/tiktoken_tokenizer.py b/YOCO/yoco/tasks/data/tiktoken_tokenizer.py new file mode 100644 index 00000000..3a041cc0 --- /dev/null +++ b/YOCO/yoco/tasks/data/tiktoken_tokenizer.py @@ -0,0 +1,81 @@ +import tiktoken +from typing import List + + +class TiktokenTokenizer: + def __init__(self, + tiktoken_model: str, + tokenizer_pad_to_multiple: int = 8, + bos="", + pad="", + eos="", + unk="", + ): + self.symbols = [bos, pad, eos, unk] + self.indices = {s: i for i, s in enumerate(self.symbols)} + self.tokenizer_pad_to_multiple = tokenizer_pad_to_multiple + cl100k_base = tiktoken.get_encoding(tiktoken_model) + self._model = tiktoken.Encoding( + # If you're changing the set of special tokens, make sure to use a different name + # It should be clear from the name what behaviour to expect. + name="cl100k_im", + pat_str=cl100k_base._pat_str, + mergeable_ranks=cl100k_base._mergeable_ranks, + special_tokens={ + **cl100k_base._special_tokens, + "": 100264, + "": 100265, + "": 100266, + "": 100267, + "": 100268, + "": 100269, + "": 100270, + "": 100271, + "": 100272, + "": 100273, + "": 100274, + "": 100275, + "": 100276, + "": 100277, + "": 100278, + "": 100279, + "": 100280, + "": 100281, + } + ) + + @property + def n_words(self) -> int: + n_words = self._model.n_vocab + len(self.symbols) + n_words = (n_words + self.tokenizer_pad_to_multiple - 1) // self.tokenizer_pad_to_multiple * self.tokenizer_pad_to_multiple + return n_words + + @property + def bos_id(self) -> int: + return self.indices[""] + + @property + def eos_id(self) -> int: + return self.indices[""] + + @property + def pad_id(self) -> int: + return self.indices[""] + + @property + def unk_id(self) -> int: + return self.indices[""] + + def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]: + assert isinstance(s, str) + t = self._model.encode(s, allowed_special="all") + t = [i + len(self.symbols) for i in t] + if bos: + t = [self.bos_id, *t] + if eos: + t = [*t, self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + t = [i - len(self.symbols) for i in t if i >= len(self.symbols)] + return self._model.decode(t) \ No newline at end of file diff --git a/YOCO/yoco/tasks/data/utils.py b/YOCO/yoco/tasks/data/utils.py new file mode 100644 index 00000000..fd850d73 --- /dev/null +++ b/YOCO/yoco/tasks/data/utils.py @@ -0,0 +1,267 @@ +import collections +from random import Random +from typing import Dict, Iterable, Optional + +import torch +import numpy as np +from infinibatch import iterators +from infinibatch.iterators import CheckpointableIterator, FixedBatchIterator, SelectManyIterator, MapIterator + +from 
fairseq.data import BaseWrapperDataset, FairseqDataset, data_utils + +def apply_to_sample(f, sample): + if hasattr(sample, "__len__") and len(sample) == 0: + return {} + + def _apply(x): + if isinstance(x, np.ndarray): + return f(x) + elif isinstance(x, collections.OrderedDict): + # OrderedDict has attributes that needs to be preserved + od = collections.OrderedDict( + (key, _apply(value)) for key, value in x.items() + ) + od.__dict__ = x.__dict__ + return od + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + elif isinstance(x, tuple): + return tuple(_apply(x) for x in x) + elif isinstance(x, set): + return {_apply(x) for x in x} + else: + return x + + return _apply(sample) + + +class NativeCheckpointableIterator(iterators.CheckpointableIterator): + def __init__(self, iterable: Iterable): + self._input_iterable = iterable + self.setstate(None) + + def getstate(self) -> Dict: + return {"num_items_yielded": self._num_items_yielded} + + def setstate(self, checkpoint: Optional[Dict]): + self._iterator = iter(self._input_iterable) + self._num_items_yielded = ( + iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"]) + if checkpoint is not None + else 0 + ) + + def __next__(self): + item = next(self._iterator) + self._num_items_yielded += 1 + return item + + def close(self): + pass + + +class WeightIterator(object): + def __init__(self, weights, seed): + self.weights = weights + self.seed = seed + self.control_index = list(range(len(weights))) + self.setstate(None) + + def __iter__(self): + return self + + def getstate(self): + return {"random_state": self._random_state} + + def setstate(self, checkpoint): + self._random_state = checkpoint["random_state"] if checkpoint else None + self._random = ( + None # this will trigger the lazy initialization in self.__next__ + ) + + def __next__(self): + if self._random is None: + self._random = Random(self.seed) + if self._random_state is not None: + self._random.setstate(self._random_state) + idx = self._random.choices(self.control_index, self.weights)[0] + self._random_state = self._random.getstate() + return idx + + def close(self): + pass + + +def FixedBlockwiseShuffleIterator(source_iterator: CheckpointableIterator, block_size: int, seed: int=0): + """ + Shuffles a sequence of items by grouping consecutive items in blocks of fixed size, shuffling + each block, and yielding the shuffled items of all blocks as a flat sequence. + + E.g. [1, 2, 3, 4, 5, 6, 7, 8] with block_size = 3 may yield [3, 1, 2, 4, 6, 5, 8, 7]. 
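+    Shuffling each block with a fresh Random(seed) keeps the permutation deterministic
+    across runs, so no random state needs to be checkpointed.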
+ + Args: + source_iterator: checkpointable iterator or restartable iterable over input items to shuffle + block_size: size of the buffer in number of items used for shuffling + seed: random seed used for shuffling (or None) + """ + # This is implemented as a pipeline: + # - group N consecutive items together + # - shuffle them + # - flatten the result + blocks = FixedBatchIterator(source_iterator, batch_size=block_size) + def shuffle_block_fn(block): + _random = Random(seed) + _random.shuffle(block) + return block + shuffled_blocks = MapIterator(blocks, transform=shuffle_block_fn) + # samples = SelectManyNoSkipIterator(shuffled_blocks, collection_selector=lambda shuffled_block: iter(shuffled_block)) + samples = SelectManyIterator(shuffled_blocks, collection_selector=lambda shuffled_block: iter(shuffled_block)) + return samples + + +class IndexIterator(object): + def __init__(self, num): + self.num = num + self.setstate(None) + + def __iter__(self): + return self + + def getstate(self): + return {'num_items_yielded': self._num_items_yielded} + + def setstate(self, checkpoint): + self._num_items_yielded =checkpoint['num_items_yielded'] if checkpoint is not None else 0 + + def __next__(self): + item = self._num_items_yielded % self.num + self._num_items_yielded += 1 + return item + + def close(self): + pass + + +class WeightNoRandomStateIterator(object): + def __init__(self, weights, seed): + self.weights = weights + self.seed = seed + self.control_index = list(range(len(weights))) + self.setstate(None) + + def __iter__(self): + return self + + def getstate(self): + return {'num_items_yielded': self._num_items_yielded} + + def setstate(self, checkpoint): + self._num_items_yielded =checkpoint['num_items_yielded'] if checkpoint is not None else 0 + + def __next__(self): + self._random = Random(int(self.seed) + self._num_items_yielded) + idx = self._random.choices(self.control_index, self.weights)[0] + self._num_items_yielded += 1 + return idx + + def close(self): + pass + + +class SelectManyNoSkipIterator(CheckpointableIterator): + """ + Projects each element of a source sequence to a sequence and flattens the resulting sequences into one sequence. + """ + def __init__(self, source_iterator: CheckpointableIterator, collection_selector=None): + """ + Args: + source_iterator: iterator over the items to pass to collection_selector() + collection_selector: user callback that maps an item into an Iterable, whose items will be yielded. + The returned Iterator is used only once. Hence, it is also allowed to + return self-iterables, such as iterators and generator expressions. + If None is given, no callback is applied. 
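+
+        Note: the checkpoint fast-forward logic is disabled in this "NoSkip" variant
+        (see the commented-out block in setstate), so restoring a checkpoint resumes
+        from the beginning of the in-flight source item.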
+ """ + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError('source_iterator has to be a CheckpointableIterator') + self._source_iterator = source_iterator # type: CheckpointableIterator + self._collection_selector = collection_selector + self.setstate(None) + + def getstate(self) -> Dict: + return {'source_state': self._source_state, + 'flattened_items_yielded': self._flattened_items_yielded} + + def setstate(self, checkpoint: Optional[Dict]): + self._source_state = checkpoint['source_state'] if checkpoint else None + self._flattened_items_yielded = 0 + self._source_iterator.setstate(self._source_state) + def _generate(): + skip_to_checkpoint = self._flattened_items_yielded + # main loop over source source_items + for source_item in self._source_iterator: + if self._collection_selector is not None: + data = iter(self._collection_selector(source_item)) + else: + data = iter(source_item) + self._flattened_items_yielded = 0 + # if skip_to_checkpoint: + # #print("Skipping to index", skip_to_checkpoint, file=sys.stderr) + # self._flattened_items_yielded += _advance_iterator(data, skip_to_checkpoint) + # skip_to_checkpoint = 0 + # main loop over lines + for item in data: + self._flattened_items_yielded += 1 + yield item + self._source_state = self._source_iterator.getstate() + self._iterator = _generate() + + def __next__(self): + return next(self._iterator) + + def close(self): + self._source_iterator.close() + + +class RawArrayDataset(FairseqDataset): + + def __init__(self, dataset, datatype="token"): + super().__init__() + self.dataset = dataset + self.datatype = datatype + if hasattr(dataset, 'sizes'): + self._sizes = dataset.sizes + else: + try: + self._sizes = np.array([len(x) for x in self.dataset]) + except: + self._sizes = np.array([1 for x in self.dataset]) + + def __getitem__(self, index): + if type(self.dataset[index][0]) != list: + if self.datatype == "token": + return torch.Tensor(self.dataset[index]).long() + else: + return torch.Tensor(self.dataset[index]).bool() + else: + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if hasattr(self.dataset, 'collater'): + return self.dataset.collater(samples) + else: + raise NotImplementedError() + + @property + def sizes(self): + return self._sizes + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) diff --git a/YOCO/yoco/tasks/gpt.py b/YOCO/yoco/tasks/gpt.py new file mode 100644 index 00000000..70dd1028 --- /dev/null +++ b/YOCO/yoco/tasks/gpt.py @@ -0,0 +1,176 @@ +import os +from typing import Optional +import json +from argparse import Namespace +import torch + +from fairseq.tasks import register_task, FairseqDataclass, FairseqTask +from dataclasses import dataclass, field +from omegaconf import II + +from .data.lm_loader import LMLoader +from .data.tiktoken_tokenizer import TiktokenTokenizer +from .data.llama_tokenizer import LLaMATokenizer + + +@dataclass +class GPTLanguageModelingConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + llama_model: Optional[str] = field( + default=None, + metadata={"help": "path to load tokenizer and config"}, + ) + 
tiktoken_model: Optional[str] = field( + default=None, + metadata={ + "help": "tiktoken model to tokenize the data" + }, + ) + batch_read_ahead: int = field( + default=10000, + metadata={"help": "batch read ahead size for infinibatch"}, + ) + pad_to_max_len: bool = field( + default=False, + metadata={"help": "pad each sentence to max length"}, + ) + absolute_path: bool = field( + default=False, + metadata={"help": "use absolute path in data config"}, + ) + tokenizer_pad_to_multiple: int = field( + default=8, + metadata={"help": "pad to multiple of this value"}, + ) + seed: int = II("common.seed") + batch_size: Optional[int] = II("dataset.batch_size") + + +@register_task('gpt', dataclass=GPTLanguageModelingConfig) +class GPTPretrainingTask(FairseqTask): + def __init__(self, args, tokenizer): + super().__init__(args) + self.cfg = args + self.tokenizer = tokenizer + + @classmethod + def setup_task(cls, cfg, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + if cfg.llama_model is not None: + tokenizer = LLaMATokenizer(os.path.join(cfg.llama_model, "tokenizer.model")) + elif cfg.tiktoken_model is not None: + tokenizer = TiktokenTokenizer(cfg.tiktoken_model, cfg.tokenizer_pad_to_multiple) + else: + raise ValueError("No tokenizer model provided") + + return cls(cfg, tokenizer) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + self.datasets[split] = { + 'data': json.load(open(f'{self.cfg.data}/json/{split}.json')), + 'data_dir': self.cfg.data, + 'shuffle': True if split == 'train' else False, + } + self.datasets[split] = Namespace(**self.datasets[split]) + + def dataset(self, split): + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + + return self.datasets[split] + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False + ): + return LMLoader( + self.cfg, + dataset, + self.tokenizer, + max_tokens=max_tokens, + max_sentences=max_sentences, + max_positions=max_positions, + ignore_invalid_inputs=ignore_invalid_inputs, + required_batch_size_multiple=required_batch_size_multiple, + seed=seed, + epoch=epoch, + num_shards=num_shards, + shard_id=shard_id, + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + """ + Do forward and backward, and return the loss as computed by *criterion* + for the given *model* and *sample*. + + Args: + sample (dict): the mini-batch. The format is defined by the + :class:`~fairseq.data.FairseqDataset`. 
+ model (~fairseq.models.BaseFairseqModel): the model + criterion (~fairseq.criterions.FairseqCriterion): the criterion + optimizer (~fairseq.optim.FairseqOptimizer): the optimizer + update_num (int): the current update + ignore_grad (bool): multiply loss by 0 if this is set to True + + Returns: + tuple: + - the loss + - the sample size, which is used as the denominator for the + gradient + - logging outputs to display while training + """ + model.train() + model.set_num_updates(update_num) + with torch.autograd.profiler.record_function("forward"): + loss, sample_size, logging_output = criterion(model, sample) + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output + + @property + def target_dictionary(self): + padding_idx = self.tokenizer.pad_id + class Dict: + def pad(self): + return padding_idx + dictionary = Dict() + return dictionary + diff --git a/YOCO/yoco/tasks/harness_eval.py b/YOCO/yoco/tasks/harness_eval.py new file mode 100644 index 00000000..0b0621ae --- /dev/null +++ b/YOCO/yoco/tasks/harness_eval.py @@ -0,0 +1,151 @@ +import os +from typing import Optional +import logging + +from fairseq.data import ( + IdDataset, + NumSamplesDataset, + NumelDataset, + NestedDictionaryDataset, + NumelDataset, + RightPadDataset, + RawLabelDataset, +) + +from fairseq.tasks import register_task, FairseqDataclass, LegacyFairseqTask +from dataclasses import dataclass, field + +from .data.tiktoken_tokenizer import TiktokenTokenizer +from .data.llama_tokenizer import LLaMATokenizer +from .data.utils import RawArrayDataset + +from .harness_task import HarnessAnlir1, HarnessAnlir2, HarnessAnlir3, HarnessArc_challenge, HarnessArc_easy, HarnessBoolq, HarnessCopa, HarnessOpenbookqa, HarnessPiqa, HarnessRte, HarnessWic, HarnessWinogrande, HarnessHellaswag, HarnessRecord, HarnessTruthfullqaMC1, HarnessTruthfullqaMC2, HarnessSCIQ +from .harness_task import HarnessArc_challenge25s, HarnessHellaswag10s + + +logger = logging.getLogger(__name__) + +task_map = { + "harness_anli_r1": HarnessAnlir1, + "harness_anli_r2": HarnessAnlir2, + "harness_anli_r3": HarnessAnlir3, + "harness_boolq": HarnessBoolq, + "harness_copa": HarnessCopa, + "harness_openbookqa": HarnessOpenbookqa, + "harness_piqa": HarnessPiqa, + "harness_rte": HarnessRte, + "harness_wic": HarnessWic, + "harness_winogrande": HarnessWinogrande, + "harness_hellaswag": HarnessHellaswag, + "harness_arc_challenge": HarnessArc_challenge, + "harness_arc_easy": HarnessArc_easy, + "harness_record": HarnessRecord, + "harness_truthfullqa_mc1": HarnessTruthfullqaMC1, + "harness_truthfullqa_mc2": HarnessTruthfullqaMC2, + "harness_arc_challenge_25s": HarnessArc_challenge25s, + "harness_hellaswag_10s": HarnessHellaswag10s, + "harness_sciq": HarnessSCIQ, +} + +from .mmlu_task import create_mmlu_tasks +mmlu_tasks = create_mmlu_tasks() +task_map.update(mmlu_tasks) + +@dataclass +class HarnessEvalConfig(FairseqDataclass): + data_dir: str = field( + default="/mnt/msranlp/shaohanh/data/fs_eval/harness/", + metadata={"help": "path to data directory"}, + ) + eval_data: str = field(default="", metadata={"help": "dataset name"}) + tokens_per_sample: int = field( + default=2048, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + max_target_positions: 
Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + llama_model: Optional[str] = field( + default=None, + metadata={"help": "path to load tokenizer and config"}, + ) + tiktoken_model: Optional[str] = field( + default=None, + metadata={ + "help": "tiktoken model to tokenize the data" + }, + ) + tokenizer_pad_to_multiple: int = field( + default=8, + metadata={"help": "pad to multiple of this value"}, + ) + + +@register_task('harness_eval', dataclass=HarnessEvalConfig) +class HarnessEval(LegacyFairseqTask): + + def __init__(self, cfg, tokenizer): + super().__init__(cfg) + self.cfg = cfg + self.tokenizer = tokenizer + self.harness_task = task_map[self.cfg.eval_data](tokenizer=self.tokenizer, data_dir=cfg.data_dir, tokens_per_sample=cfg.tokens_per_sample) + + @classmethod + def setup_task(cls, cfg, **kwargs): + if cfg.llama_model is not None: + tokenizer = LLaMATokenizer(os.path.join(cfg.llama_model, "tokenizer.model")) + elif cfg.tiktoken_model is not None: + tokenizer = TiktokenTokenizer(cfg.tiktoken_model, cfg.tokenizer_pad_to_multiple) + else: + raise ValueError("No tokenizer model provided") + + return cls(cfg, tokenizer) + + def load_dataset(self, split, combine=False, **kwargs): + src_tokens, gpt_loss_mask, label_length, labels = self.harness_task.get_data_for_evaluation() + + src_tokens = RawArrayDataset(src_tokens) + gpt_loss_mask = RawArrayDataset(gpt_loss_mask, datatype="mask") + label_length = RawLabelDataset(label_length) + label_ids = RawLabelDataset(labels) + ''' + Input format: src_tokens + option_tokens + ''' + data_dict = { + 'id': IdDataset(), + 'net_input': { + 'src_tokens': RightPadDataset( + src_tokens, + pad_idx=self.tokenizer.pad_id, + ), + 'gpt_loss_mask': RightPadDataset( + gpt_loss_mask, + pad_idx=False, + ), + 'label_length': label_length, + 'src_lengths': NumelDataset(src_tokens, reduce=False), + }, + 'targets': label_ids, + 'nsentences': NumSamplesDataset(), + 'ntokens': NumelDataset(src_tokens, reduce=True), + } + dataset = NestedDictionaryDataset( + data_dict, + sizes=[src_tokens.sizes], + ) + + print('| Loaded {} with {} samples'.format(split, len(dataset))) + + self.datasets[split] = dataset + return self.datasets[split] + + @property + def target_dictionary(self): + padding_idx = self.tokenizer.pad_id + class Dict: + def pad(self): + return padding_idx + dictionary = Dict() + return dictionary + + \ No newline at end of file diff --git a/YOCO/yoco/tasks/harness_task.py b/YOCO/yoco/tasks/harness_task.py new file mode 100644 index 00000000..3e87f96a --- /dev/null +++ b/YOCO/yoco/tasks/harness_task.py @@ -0,0 +1,289 @@ +import json +import numpy as np + +class HarnessBaseTask: + def __init__(self, tokenizer, data_dir, tokens_per_sample=1024): + self.tokenizer = tokenizer + self.class_num = 1 + self.tokens_per_sample = tokens_per_sample + self.base_dir = data_dir + self.set_dataname() + self.set_class_num() + self.dataset = self.load_data() + + def load_data(self): + import os + datasets = [] + with open(os.path.join(self.base_dir, self.dataname), "r", encoding='utf-8') as fin: + for line in fin: + obj = json.loads(line) + datasets.append( + { + "text": obj["ctx"] if "ctx" in obj else None, + "label": obj["label"] if "label" in obj else None, + "choices": obj["choices"] if "choices" in obj else [], + "gold": obj["gold"] if "gold" in obj else None, + "raw": obj, + } + ) + return datasets + + def set_class_num(self): + raise NotImplementedError + + def set_dataname(self): + raise NotImplementedError + + def 
preprocess_example(self, example): + raise NotImplementedError + + def get_data_for_evaluation(self): + src_tokens = [] + gpt_loss_mask = [] + label_length = [] + labels = [] + cut_num = 0 + for i, example in enumerate(self.dataset): + input_str, label_str, label = self.preprocess_example(example) + if i < 2: + print(f"input str is {input_str}") + print(f"label str is {label_str}") + + for j in range(len(input_str)): + sub_input_str, sub_label_str = input_str[j], label_str[j] + input_token = self.tokenizer.encode(sub_input_str) + label_token = self.tokenizer.encode(sub_input_str + sub_label_str)[len(input_token):] + if len(input_token) + len(label_token) + 1 >= self.tokens_per_sample: + cut_num += 1 + input_token = input_token[-(self.tokens_per_sample - len(label_token) - 1):] + + src_tokens.append([self.tokenizer.bos_id] + input_token + label_token) + gpt_loss_mask.append([False] * (len(input_token) + 1) + [True] * len(label_token)) + label_length.append(len(sub_label_str.strip())) + labels.append(label) + + if cut_num > 0: + print(f"cut {cut_num} examples") + + return np.array(src_tokens), np.array(gpt_loss_mask), np.array(label_length), np.array(labels) + + +class HarnessAnlir1(HarnessBaseTask): + def set_class_num(self): + self.class_num = 3 + + def set_dataname(self): + self.dataname = "anli_r1" + + def preprocess_example(self, example): + input_str = [example["text"]] * self.class_num + answer_str = [" True", " Neither", " False"] + label = example["label"] + return input_str, answer_str, label + +class HarnessAnlir2(HarnessAnlir1): + def set_dataname(self): + self.dataname = "anli_r2" + +class HarnessAnlir3(HarnessAnlir1): + def set_dataname(self): + self.dataname = "anli_r3" + +class HarnessArc_challenge(HarnessBaseTask): + ''' + using harness to evaluate arc challenge + ''' + def set_class_num(self): + self.class_num = 5 + + def set_dataname(self): + self.dataname = "arc_challenge" + + def preprocess_example(self, example): + input_str = [example["text"]] * len(example["choices"]) + answer_str = [' ' + item for item in example["choices"]] + label = example["gold"] + return input_str, answer_str, label + +class HarnessArc_challenge25s(HarnessBaseTask): + ''' + using harness to evaluate arc challenge + ''' + def set_class_num(self): + self.class_num = 5 + + def set_dataname(self): + self.dataname = "arc_challenge_25s" + + def preprocess_example(self, example): + input_str = [example["text"]] * len(example["choices"]) + answer_str = [' ' + item for item in example["choices"]] + label = example["gold"] + return input_str, answer_str, label + +class HarnessArc_easy(HarnessArc_challenge): + def set_class_num(self): + self.class_num = 5 + + def set_dataname(self): + self.dataname = "arc_easy" + +class HarnessBoolq(HarnessBaseTask): + def set_class_num(self): + self.class_num = 2 + + def set_dataname(self): + self.dataname = "boolq" + + def preprocess_example(self, example): + input_str = [example["text"]] * self.class_num + answer_str = [" no", " yes"] + label = example["label"] + return input_str, answer_str, label + +class HarnessCopa(HarnessBaseTask): + def set_class_num(self): + self.class_num = 2 + + def set_dataname(self): + self.dataname = "copa" + + def preprocess_example(self, example): + input_str = [example["text"]] * self.class_num + answer_str = [' ' + example['raw']['choice1'], ' ' + example['raw']['choice2']] + label = example["label"] + return input_str, answer_str, label + +class HarnessOpenbookqa(HarnessArc_challenge): + def set_class_num(self): + self.class_num = 4 + + 
def set_dataname(self): + self.dataname = "openbookqa" + +class HarnessPiqa(HarnessArc_challenge): + def set_class_num(self): + self.class_num = 2 + + def set_dataname(self): + self.dataname = "piqa" + +class HarnessRte(HarnessBaseTask): + def set_class_num(self): + self.class_num = 2 + + def set_dataname(self): + self.dataname = "rte" + + def preprocess_example(self, example): + input_str = [example["text"]] * self.class_num + answer_str = [' True', ' False'] + label = example["label"] + return input_str, answer_str, label + +class HarnessWic(HarnessRte): + def set_dataname(self): + self.dataname = "wic" + +class HarnessWinogrande(HarnessBaseTask): + def set_class_num(self): + self.class_num = 2 + + def set_dataname(self): + self.dataname = "winogrande" + + def preprocess_example(self, example): + pronoun_loc = example['raw']['sentence'].index("_") + input_str = [] + input_str.append(example['raw']['sentence'][:pronoun_loc].strip() + ' ' + example['raw']['option1']) + input_str.append(example['raw']['sentence'][:pronoun_loc].strip() + ' ' + example['raw']['option2']) + answer_str = [" " + example['raw']["sentence"][pronoun_loc + 1:].strip()] * self.class_num + label = int(example['raw']['answer']) - 1 + return input_str, answer_str, label + +class HarnessHellaswag(HarnessBaseTask): + def set_class_num(self): + self.class_num = 4 + + def set_dataname(self): + self.dataname = "hellaswag" + + def preprocess_example(self, example): + input_str = [example["text"]] * self.class_num + answer_str = [' ' + item for item in example["choices"]] + label = example["gold"] + return input_str, answer_str, label + + +class HarnessHellaswag10s(HarnessBaseTask): + def set_class_num(self): + self.class_num = 4 + + def set_dataname(self): + self.dataname = "hellaswag_10s" + + def preprocess_example(self, example): + input_str = [example["text"]] * self.class_num + answer_str = [' ' + item for item in example["choices"]] + label = example["gold"] + return input_str, answer_str, label + + +class HarnessTruthfullqaMC1(HarnessBaseTask): + def set_class_num(self): + self.class_num = 1 + + def set_dataname(self): + self.dataname = "truthfulqa_mc" + + def preprocess_example(self, example): + input_str = [example["text"]] * len(example["raw"]["mc1_targets"]["choices"]) + answer_str = [' ' + item for item in example["raw"]["mc1_targets"]["choices"]] + label = 0 # dummy label + return input_str, answer_str, label + + + +class HarnessTruthfullqaMC2(HarnessBaseTask): + def set_class_num(self): + self.class_num = 1 + + def set_dataname(self): + self.dataname = "truthfulqa_mc" + + def preprocess_example(self, example): + input_str = [example["text"]] * len(example["raw"]["mc2_targets"]["choices"]) + answer_str = [' ' + item for item in example["raw"]["mc2_targets"]["choices"]] + label = 0 # dummy label + return input_str, answer_str, label + + +class HarnessRecord(HarnessBaseTask): + def set_class_num(self): + self.class_num = 1 + + def set_dataname(self): + self.dataname = "record" + + def preprocess_example(self, example): + input_str = [example["text"]] * len(example["raw"]["entities"]) + answer_str = [f' - {example["raw"]["query"]}'.replace("@placeholder", item) for item in example["raw"]["entities"]] + label = 0 # dummy label + return input_str, answer_str, label + +class HarnessSCIQ(HarnessBaseTask): + def set_class_num(self): + self.class_num = 4 + + def set_dataname(self): + self.dataname = "sciq" + + def preprocess_example(self, example): + input_str = [example["text"]] * self.class_num + answer_str = [' ' + 
example["raw"]["distractor1"], + ' ' + example["raw"]["distractor2"], + ' ' + example["raw"]["distractor3"], + ' ' + example["raw"]["correct_answer"] + ] + label = 3 + return input_str, answer_str, label \ No newline at end of file diff --git a/YOCO/yoco/tasks/mmlu_task.py b/YOCO/yoco/tasks/mmlu_task.py new file mode 100644 index 00000000..a93476c7 --- /dev/null +++ b/YOCO/yoco/tasks/mmlu_task.py @@ -0,0 +1,92 @@ +from .harness_task import HarnessBaseTask + + +SUBJECTS = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_medicine", + "college_physics", + "computer_security", + "conceptual_physics", + "econometrics", + "electrical_engineering", + "elementary_mathematics", + "formal_logic", + "global_facts", + "high_school_biology", + "high_school_chemistry", + "high_school_computer_science", + "high_school_european_history", + "high_school_geography", + "high_school_government_and_politics", + "high_school_macroeconomics", + "high_school_mathematics", + "high_school_microeconomics", + "high_school_physics", + "high_school_psychology", + "high_school_statistics", + "high_school_us_history", + "high_school_world_history", + "human_aging", + "human_sexuality", + "international_law", + "jurisprudence", + "logical_fallacies", + "machine_learning", + "management", + "marketing", + "medical_genetics", + "miscellaneous", + "moral_disputes", + "moral_scenarios", + "nutrition", + "philosophy", + "prehistory", + "professional_accounting", + "professional_law", + "professional_medicine", + "professional_psychology", + "public_relations", + "security_studies", + "sociology", + "us_foreign_policy", + "virology", + "world_religions", +] + + +def create_mmlu_tasks(): + """Creates a dictionary of tasks from a list of subjects + :return: {task_name: task} + e.g. 
{hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task} + """ + return {f"hendrycksTest-{sub}": create_task(f"hendrycksTest-{sub}") for sub in SUBJECTS} + + +def create_task(subject): + class HendrycksTest(GeneralHendrycksTest): + def set_dataname(self): + self.dataname = f"{subject}" + + return HendrycksTest + +class GeneralHendrycksTest(HarnessBaseTask): + def set_class_num(self): + self.class_num = 4 + + def preprocess_example(self, example): + # find the last occurence of "Queston:" in example["text"], and remove everything before it + # this is to remove the context + # last_question = example["text"].rfind("Question:") + # example["text"] = example["text"][last_question:] + input_str = [example["text"]] * self.class_num + answer_str = [' ' + item for item in example["choices"]] + label = example["gold"] + return input_str, answer_str, label diff --git a/YOCO/yoco/tasks/pseudo.py b/YOCO/yoco/tasks/pseudo.py new file mode 100644 index 00000000..87b51a12 --- /dev/null +++ b/YOCO/yoco/tasks/pseudo.py @@ -0,0 +1,202 @@ +import os +from typing import Optional +import torch + +from fairseq.data import FairseqDataset +from fairseq.tasks import register_task, FairseqDataclass, LegacyFairseqTask +from dataclasses import dataclass, field +from omegaconf import II + +from .data.tiktoken_tokenizer import TiktokenTokenizer +from .data.llama_tokenizer import LLaMATokenizer + + +class PseudoIterator(FairseqDataset): + def __init__(self, batch_size, length, vocab_size): + super().__init__() + self.batch_size = batch_size + self.length = length + self.vocab_size = vocab_size + + self.epoch = 1 + self.next_epoch_idx = 1 + self.sharded_checkpoint = True + self.should_close_after_finished = True + + def __iter__(self): + while True: + yield self.__next__() + + def __next__(self): + net_input = torch.randint(size=(self.batch_size, self.length), dtype=torch.long, low=0, high=self.vocab_size - 1) + return { + "net_input": {"src_tokens": net_input}, + "target": net_input, + "ntokens": self.batch_size * self.length, + } + + def __len__(self) -> int: + return 819200000 + + def next_epoch_itr(self, **kwargs): + return self + + @property + def first_batch(self): + return "DUMMY" + + def end_of_epoch(self) -> bool: + return False + + def state_dict(self): + return None + + def load_state_dict(self, state_dict): + pass + + def setstate(self, value): + pass + + def getstate(self): + pass + + def close(self): + pass + +@dataclass +class PseudoConfig(FairseqDataclass): + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + llama_model: Optional[str] = field( + default=None, + metadata={"help": "path to load tokenizer and config"}, + ) + tiktoken_model: Optional[str] = field( + default=None, + metadata={ + "help": "tiktoken model to tokenize the data" + }, + ) + batch_read_ahead: int = field( + default=10000, + metadata={"help": "batch read ahead size for infinibatch"}, + ) + pad_to_max_len: bool = field( + default=False, + metadata={"help": "pad each sentence to max length"}, + ) + absolute_path: bool = field( + default=False, + metadata={"help": "use absolute path in data config"}, + ) + tokenizer_pad_to_multiple: int = field( + default=8, + metadata={"help": "pad to multiple of this value"}, + ) + seed: int = II("common.seed") + batch_size: Optional[int] = II("dataset.batch_size") + + 
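+# NOTE: the pseudo task has no real corpus. PseudoIterator above simply yields
+# batches of uniformly random token ids of length ``tokens_per_sample``, which is
+# presumably sufficient for evaluation criteria that construct their own inputs
+# (e.g. the needle-in-a-haystack criteria) or for benchmarking the forward pass.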
+
+@register_task('pseudo', dataclass=PseudoConfig)
+class PseudoTask(LegacyFairseqTask):
+    def __init__(self, args, tokenizer):
+        super().__init__(args)
+        self.cfg = args
+        self.tokenizer = tokenizer
+
+    @classmethod
+    def setup_task(cls, cfg, **kwargs):
+        """Setup the task (e.g., build the tokenizer).
+
+        Args:
+            cfg (PseudoConfig): parsed task configuration
+        """
+        if cfg.llama_model is not None:
+            tokenizer = LLaMATokenizer(os.path.join(cfg.llama_model, "tokenizer.model"))
+        elif cfg.tiktoken_model is not None:
+            tokenizer = TiktokenTokenizer(cfg.tiktoken_model, cfg.tokenizer_pad_to_multiple)
+        else:
+            raise ValueError("No tokenizer model provided")
+
+        return cls(cfg, tokenizer)
+
+    def load_dataset(self, split, **kwargs):
+        pass
+        # self.datasets[split] = None
+
+    def dataset(self, split):
+        return None
+
+    def get_batch_iterator(
+        self,
+        dataset,
+        max_tokens=None,
+        max_sentences=None,
+        max_positions=None,
+        ignore_invalid_inputs=False,
+        required_batch_size_multiple=1,
+        seed=1,
+        num_shards=1,
+        shard_id=0,
+        num_workers=0,
+        epoch=1,
+        data_buffer_size=0,
+        disable_iterator_cache=False,
+        skip_remainder_batch=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False
+    ):
+        return PseudoIterator(max_sentences, self.cfg.tokens_per_sample, 10000)
+
+    def train_step(
+        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+    ):
+        """
+        Do forward and backward, and return the loss as computed by *criterion*
+        for the given *model* and *sample*.
+
+        Args:
+            sample (dict): the mini-batch. The format is defined by the
+                :class:`~fairseq.data.FairseqDataset`.
+            model (~fairseq.models.BaseFairseqModel): the model
+            criterion (~fairseq.criterions.FairseqCriterion): the criterion
+            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
+            update_num (int): the current update
+            ignore_grad (bool): multiply loss by 0 if this is set to True
+
+        Returns:
+            tuple:
+                - the loss
+                - the sample size, which is used as the denominator for the
+                  gradient
+                - logging outputs to display while training
+        """
+        model.train()
+        model.set_num_updates(update_num)
+        with torch.autograd.profiler.record_function("forward"):
+            loss, sample_size, logging_output = criterion(model, sample)
+        if ignore_grad:
+            loss *= 0
+        with torch.autograd.profiler.record_function("backward"):
+            optimizer.backward(loss)
+        return loss, sample_size, logging_output
+
+    def valid_step(self, sample, model, criterion):
+        model.eval()
+        with torch.no_grad():
+            loss, sample_size, logging_output = criterion(model, sample)
+        return loss, sample_size, logging_output
+
+    @property
+    def target_dictionary(self):
+        padding_idx = self.tokenizer.pad_id
+        class Dict:
+            def pad(self):
+                return padding_idx
+        dictionary = Dict()
+        return dictionary
\ No newline at end of file
diff --git a/YOCO/yoco/train.py b/YOCO/yoco/train.py
new file mode 100644
index 00000000..ee6615d8
--- /dev/null
+++ b/YOCO/yoco/train.py
@@ -0,0 +1,7 @@
+import models
+import tasks
+import criterions
+from fairseq_cli.train import cli_main
+
+if __name__ == "__main__":
+    cli_main()
diff --git a/YOCO/yoco/validate.py b/YOCO/yoco/validate.py
new file mode 100644
index 00000000..e3815ca6
--- /dev/null
+++ b/YOCO/yoco/validate.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Validate a trained model on one or across multiple GPUs.
+""" +import models +import tasks +import criterions + +import argparse +import logging +import math +import os +import sys +from typing import Any, Callable, Dict, List, Optional, Tuple + +# We need to setup root logger before importing any fairseq libraries. +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.train") + +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf + +from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils +from fairseq.data import data_utils, iterators +from fairseq.data.plasma_utils import PlasmaStore +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap +from fairseq.distributed import utils as distributed_utils +from fairseq.file_io import PathManager +from fairseq.logging import meters, metrics, progress_bar + + +def main(cfg: FairseqConfig) -> None: + if isinstance(cfg, argparse.Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + + if ( + distributed_utils.is_master(cfg.distributed_training) + and "job_logging_cfg" in cfg + ): + # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) + logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) + + assert ( + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" + metrics.reset() + + if cfg.common.log_file is not None: + handler = logging.FileHandler(filename=cfg.common.log_file) + logger.addHandler(handler) + + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + # if distributed_utils.is_master(cfg.distributed_training): + # checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) + + # Print args + logger.info(cfg) + + if cfg.checkpoint.write_checkpoints_asynchronously: + try: + import iopath # noqa: F401 + except ImportError: + logging.exception( + "Asynchronous checkpoint writing is specified but iopath is " + "not installed: `pip install iopath`" + ) + return + + # Setup task, e.g., translation, language modeling, etc. + task = tasks.setup_task(cfg.task) + + assert cfg.criterion, "Please specify criterion to train a model" + + # Build model and criterion + if cfg.distributed_training.ddp_backend == "fully_sharded": + with fsdp_enable_wrap(cfg.distributed_training): + model = fsdp_wrap(task.build_model(cfg.model)) + else: + model = task.build_model(cfg.model) + criterion = task.build_criterion(cfg.criterion) + + tpu = cfg.common.tpu + cuda = torch.cuda.is_available() and not cfg.common.cpu and not tpu + if cuda: + device = torch.device("cuda") + elif tpu: + device = utils.get_tpu_device() + else: + device = torch.device("cpu") + if cfg.common.fp16: + criterion = criterion.half() + model = model.half() + elif cfg.common.bf16: + criterion = criterion.to(dtype=torch.bfloat16) + model = model.to(dtype=torch.bfloat16) + criterion = criterion.to(device) + model = model.to(device) + + logger.info(model) + logger.info("task: {}".format(task.__class__.__name__)) + logger.info("model: {}".format(model.__class__.__name__)) + logger.info("criterion: {}".format(criterion.__class__.__name__)) + logger.info( + "num. 
shared model params: {:,} (num. trained: {:,})".format( + sum( + p.numel() for p in model.parameters() if not getattr(p, "expert", False) + ), + sum( + p.numel() + for p in model.parameters() + if not getattr(p, "expert", False) and p.requires_grad + ), + ) + ) + + logger.info( + "num. expert model params: {} (num. trained: {})".format( + sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)), + sum( + p.numel() + for p in model.parameters() + if getattr(p, "expert", False) and p.requires_grad + ), + ) + ) + + # Load valid dataset (we load training data below, based on the latest checkpoint) + # We load the valid dataset AFTER building the model + data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) + if cfg.dataset.combine_valid_subsets: + task.load_dataset("valid", combine=True, epoch=1) + else: + for valid_sub_split in cfg.dataset.valid_subset.split(","): + task.load_dataset(valid_sub_split, combine=False, epoch=1) + + # Load the latest checkpoint if one is available and restore the + # corresponding train iterator + # try: + # state_dict = torch.load(cfg.checkpoint.restore_file) + # model.load_state_dict(state_dict['model']) + # print(f"Loaded model from {cfg.checkpoint.restore_file}") + # except Exception as e: + # print(e) + # print(f"No checkpoint found from {cfg.checkpoint.restore_file}") + + valid_subsets = cfg.dataset.valid_subset.split(",") + logger.info("Start validating") + + validate( + cfg, task, model, criterion, valid_subsets, + ) + +@torch.no_grad() +def validate( + cfg: DictConfig, + task: tasks.FairseqTask, + model, + criterion, + subsets: List[str], +) -> List[Optional[float]]: + """Evaluate the model on the validation set(s) and return the losses.""" + if cfg.dataset.fixed_validation_seed is not None: + # set fixed seed for every validation + utils.set_torch_seed(cfg.dataset.fixed_validation_seed) + + valid_losses = [] + for subset in subsets: + logger.info('begin validation on "{}" subset'.format(subset)) + + # Initialize data iterator + itr = task.get_batch_iterator( + dataset=task.dataset(subset), + max_tokens=cfg.dataset.max_tokens_valid, + max_sentences=cfg.dataset.batch_size_valid, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_workers=cfg.dataset.num_workers_valid, + # always pass a fixed "epoch" to keep validation data consistent + # across training epochs + epoch=1, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr( + shuffle=False, set_dataset_epoch=False # use a fixed valid set + ) + if cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + prefix=f"valid on '{subset}' subset", + tensorboard_logdir=( + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + ), + ) + + # create a new root metrics aggregator so validation metrics + # don't pollute other aggregators (e.g., train meters) + with metrics.aggregate(new_root=True) as agg: + logging_outputs = [] + for i, sample in enumerate(progress): + if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps: + break + sample = 
utils.move_to_cuda(sample) + _, _, inner_logging_outputs = task.valid_step( + sample, model, criterion + ) + logging_outputs.append(inner_logging_outputs) + task.reduce_metrics(logging_outputs, criterion) + + # with metrics.aggregate(new_root=True) as agg: + # for i, sample in enumerate(progress): + # if ( + # cfg.dataset.max_valid_steps is not None + # and i > cfg.dataset.max_valid_steps + # ): + # break + # trainer.valid_step(sample) + + stats = get_valid_stats(cfg, agg.get_smoothed_values()) + + progress.print(stats, tag=subset) + + valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) + return valid_losses + + +def get_valid_stats( + cfg: DictConfig, stats: Dict[str, Any] +) -> Dict[str, Any]: + if hasattr(checkpoint_utils.save_checkpoint, "best"): + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min + stats[key] = best_function( + checkpoint_utils.save_checkpoint.best, + stats[cfg.checkpoint.best_checkpoint_metric], + ) + return stats + + +def cli_main( + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None +) -> None: + parser = options.get_training_parser() + args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + + cfg = convert_namespace_to_omegaconf(args) + + if cfg.common.use_plasma_view: + server = PlasmaStore(path=cfg.common.plasma_path) + logger.info( + f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}" + ) + + if args.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, main) + else: + distributed_utils.call_main(cfg, main) + + # if cfg.common.use_plasma_view: + # server.server.kill() + + +if __name__ == "__main__": + cli_main() \ No newline at end of file