In [80]:
import torch
from transformers import LlamaTokenizer, AutoTokenizer
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from transformers import BitsAndBytesConfig

# bitsandbytes quantization settings. This object is passed as
# `quantization_config` when the saved checkpoint is reloaded into `m2`.
#   - load_in_4bit=True             : request 4-bit weight loading
#   - bnb_4bit_quant_type="nf4"     : use the "nf4" 4-bit quantization type
#   - bnb_4bit_use_double_quant     : double quantization is turned off
#   - bnb_4bit_compute_dtype        : compute dtype is bfloat16
quant_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=False,
   bnb_4bit_compute_dtype=torch.bfloat16
)


In [81]:
quant_config

BitsAndBytesConfig {
  "_load_in_4bit": true,
  "_load_in_8bit": false,
  "bnb_4bit_compute_dtype": "bfloat16",
  "bnb_4bit_quant_storage": "uint8",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}

In [82]:
# Read the model hyperparameters from a local JSON file (path is relative to
# the notebook's working directory).
model_config = LlamaConfig.from_pretrained("./llama-10m.json")


In [83]:

# Build the model directly from the config (no `from_pretrained` call here)
# and move it to the GPU.
model = LlamaForCausalLM(model_config).to("cuda")
# The tokenizer is fetched from the Llama-2 chat checkpoint.
# NOTE(review): assumes the vocab_size in ./llama-10m.json matches this
# tokenizer's vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") 



In [84]:
# Write the model to the local "llama-10m" directory so the next cell can
# reload it through `from_pretrained` with a quantization config.
model.save_pretrained("llama-10m")

In [85]:
# Reload the checkpoint saved above, this time with `quant_config` applied.
# NOTE(review): the SFTTrainer cell further down is given `model`, not `m2` —
# confirm which of the two is meant to be trained/benchmarked.
m2 = LlamaForCausalLM.from_pretrained("llama-10m", quantization_config=quant_config)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [86]:
m2

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 128, padding_idx=31999)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=128, out_features=128, bias=False)
          (k_proj): Linear4bit(in_features=128, out_features=128, bias=False)
          (v_proj): Linear4bit(in_features=128, out_features=128, bias=False)
          (o_proj): Linear4bit(in_features=128, out_features=128, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=128, out_features=352, bias=False)
          (up_proj): Linear4bit(in_features=128, out_features=352, bias=False)
          (down_proj): Linear4bit(in_features=352, out_features=128, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Llam

In [87]:
# Reuse the tokenizer's unknown token as its padding token.
tokenizer.pad_token = tokenizer.unk_token

In [88]:
#@title Show final memory and time stats
def get_max_memory_reserved():
    """Report the peak CUDA memory reserved so far.

    Reads ``torch.cuda.max_memory_reserved()`` (a byte count), converts it by
    dividing by 1024**3, rounds to three decimals, prints the value and
    returns it.

    Returns:
        The rounded peak reserved memory as a number.
    """
    bytes_per_gb = 1024 ** 3
    peak_gb = round(torch.cuda.max_memory_reserved() / bytes_per_gb, 3)
    print(f"Peak reserved memory = {peak_gb} GB.")
    return peak_gb


In [89]:
# Dummy batch for a manual forward pass: batch size 1, sequence length 16.
# `in_features` is assigned here but not referenced by any other line in
# this cell.
bs, seqlen, in_features = 1, 16, 4096

# Random token ids drawn from [0, vocab_size), created directly on the GPU.
input_ids = torch.randint(0, model.config.vocab_size, (bs, seqlen), device="cuda")
# Labels are an independent copy of the inputs.
labels = input_ids.detach().clone()
# All-ones mask; no dtype is passed, so torch.ones uses its default dtype.
attention_mask = torch.ones((bs, seqlen), device="cuda")

# Print and display the peak reserved CUDA memory at this point.
get_max_memory_reserved()

Peak reserved memory = 0.213 GB.


0.213

In [90]:
import types
from typing import List, Optional, Tuple, Union

from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss

def forward_fused(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> "Union[Tuple, CausalLMOutputWithPast]":
    """Replacement ``forward`` for a causal LM, intended to be bound with ``types.MethodType``.

    The object it is bound to must expose ``self.config``, ``self.model`` and
    ``self.lm_head``. The optional ``config.use_fused_cel`` flag is read with
    a default of ``False``, so the function also works on a config that was
    never patched with that key.

    Note: when ``use_fused_cel`` is truthy this function only prints a
    message; the loss below is still computed with ``CrossEntropyLoss``.

    Args:
        input_ids, attention_mask, position_ids, past_key_values,
        inputs_embeds, use_cache, cache_position: forwarded unchanged to
            ``self.model``.
        labels: if given, a next-token loss is computed (logits are shifted
            left by one position relative to the labels).
        output_attentions, output_hidden_states, return_dict: fall back to
            the corresponding ``self.config`` values when ``None``.

    Returns:
        If ``return_dict`` is falsy: a tuple ``(logits, *outputs[1:])``,
        prefixed with ``loss`` when ``labels`` was given.
        Otherwise a ``CausalLMOutputWithPast``.
    """
    config = self.config
    if output_attentions is None:
        output_attentions = config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = config.output_hidden_states
    if return_dict is None:
        return_dict = config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    if config.pretraining_tp > 1:
        # Project with row-slices of the LM head and concatenate the pieces.
        # Uses the fully qualified torch.nn.functional.linear: the notebook
        # never imports an `F` alias, so the old `F.linear` raised NameError.
        lm_head_slices = self.lm_head.weight.split(config.vocab_size // config.pretraining_tp, dim=0)
        logits = [torch.nn.functional.linear(hidden_states, head_slice) for head_slice in lm_head_slices]
        logits = torch.cat(logits, dim=-1)
    else:
        # getattr with a default: the flag is only present when the config
        # was explicitly updated with it.
        if getattr(config, "use_fused_cel", False):
            print("Using fused cross entropy loss")
        logits = self.lm_head(hidden_states)
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


In [91]:
mem_stats = torch.cuda.memory_stats_as_nested_dict()
mem_stats.keys()

dict_keys(['num_alloc_retries', 'num_ooms', 'max_split_size', 'num_sync_all_streams', 'num_device_alloc', 'num_device_free', 'allocation', 'segment', 'active', 'inactive_split', 'allocated_bytes', 'reserved_bytes', 'active_bytes', 'inactive_split_bytes', 'requested_bytes', 'oversize_allocations', 'oversize_segments'])

In [92]:
# model.config.update({
#     "use_fused_cel": True})
# model.forward = types.MethodType(forward_fused, model)

In [9]:
#out = model(input_ids, labels=labels, attention_mask=attention_mask)

Using fused cross entropy loss


In [10]:
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token
def formatting_prompts_func(examples):
    """Render a batch of rows into Alpaca-style prompt strings.

    Args:
        examples: mapping with the keys "instruction", "input" and "output",
            each holding a sequence with one entry per row.

    Returns:
        A dict with a single key "text": one formatted prompt per row, each
        followed by EOS_TOKEN.
    """
    columns = (examples["instruction"], examples["input"], examples["output"])
    # EOS_TOKEN must be appended, otherwise generation will go on forever.
    rendered = [alpaca_prompt.format(*row) + EOS_TOKEN for row in zip(*columns)]
    return { "text" : rendered, }


In [26]:

from datasets import load_dataset
# Load the train split of the Alpaca-cleaned dataset.
dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
# Apply the prompt formatter batch-wise; it returns a "text" entry per row,
# which is the field the SFTTrainer cell reads via dataset_text_field.
dataset = dataset.map(formatting_prompts_func, batched = True,)

In [27]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  86137 KiB | 124095 KiB | 127737 KiB |  41599 KiB |
|       from large pool |  77248 KiB | 110016 KiB | 110016 KiB |  32768 KiB |
|       from small pool |   8889 KiB |  14079 KiB |  17721 KiB |   8831 KiB |
|---------------------------------------------------------------------------|
| Active memory         |  86137 KiB | 124095 KiB | 127737 KiB |  41599 KiB |
|       from large pool |  77248 KiB | 110016 KiB | 110016 KiB |  32768 KiB |
|       from small pool |   8889 KiB |  14079 KiB |  17721 KiB |   8831 KiB |
|---------------------------------------------------------------

In [93]:
from trl import SFTTrainer
from transformers import TrainingArguments, Trainer
max_seq_length = 256

# Short benchmark-style run: batch size 1, no gradient accumulation, 5 steps.
training_args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1,
        # NOTE(review): warmup_steps equals max_steps, so presumably the whole
        # run stays inside the warmup phase — confirm this is intended.
        warmup_steps = 5,
        max_steps = 5,
        learning_rate = 2e-4,
        # Exactly one of fp16 / bf16 is enabled, depending on hardware support.
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        # Metrics: memory metrics are not skipped, and input-token and
        # tokens-per-second counters are enabled.
        skip_memory_metrics=False,
        include_num_input_tokens_seen=True,
        include_tokens_per_second=True,
    )
 
# trainer = Trainer(
#     model = model,
#     tokenizer = tokenizer,
#     train_dataset = dataset,
#     dataset_text_field = "text",
#     max_seq_length = max_seq_length,
#     dataset_num_proc = 2,
#     packing = False, # Can make training 5x faster for short sequences.
#     args = training_args)

In [94]:
# Supervised fine-tuning trainer over the formatted "text" column.
# Note that this trains `model` (the instance built from the config), not the
# quantized reload `m2`.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = training_args
)

max_steps is given, it will override any value given in num_train_epochs


In [95]:
#trainer.args

In [96]:
trainer_stats = trainer.train()

Step,Training Loss
1,10.3789
2,10.4117
3,10.3826
4,10.3294
5,10.178


In [97]:
trainer.log_metrics("train", trainer_stats.metrics)

***** train metrics *****
  before_init_mem_cpu        =     1101MB
  before_init_mem_gpu        =      179MB
  epoch                      =     0.0001
  init_mem_cpu_alloc_delta   =        0MB
  init_mem_cpu_peaked_delta  =        0MB
  init_mem_gpu_alloc_delta   =        0MB
  init_mem_gpu_peaked_delta  =        0MB
  num_input_tokens_seen      =        908
  total_flos                 =       24GF
  train_loss                 =    10.3361
  train_mem_cpu_alloc_delta  =        0MB
  train_mem_cpu_peaked_delta =        0MB
  train_mem_gpu_alloc_delta  =       18MB
  train_mem_gpu_peaked_delta =      112MB
  train_runtime              = 0:00:01.49
  train_samples_per_second   =      3.344
  train_steps_per_second     =      3.344
  train_tokens_per_second    =    357.846


In [98]:
trainer.save_metrics("train", trainer_stats.metrics)

In [38]:
trainer.num_examples(trainer.get_train_dataloader())

51760

In [39]:
trainer.num_tokens(trainer.get_train_dataloader())

9433794

In [41]:
trainer.model is trainer.model_wrapped

True

In [42]:
trainer.accelerator.state

Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16

In [61]:
from accelerate.utils import get_mixed_precision_context_manager, convert_outputs_to_fp32

In [55]:
# Experiment: compare accelerate's mixed-precision context with plain
# torch.autocast.
# NOTE(review): the first positional parameter of
# get_mixed_precision_context_manager appears to be a boolean `native_amp`
# flag rather than a dtype name; if so, "bf16" is only acting as a truthy
# value here — confirm against the accelerate API docs.
autocast_context = get_mixed_precision_context_manager("bf16")
# Two 128x128 float16 operands on the GPU.
a, b = torch.randn(128, 128, device="cuda", dtype=torch.float16), torch.randn(128, 128, device="cuda", dtype=torch.float16)
# Function under test: a single matrix multiplication.
test_fn = lambda a, b: a @ b


In [62]:
# `wrapped`: test_fn wrapped by accelerate's fp32 output conversion, then run
# under accelerate's mixed-precision context (recorded output dtype below:
# torch.float32).
wrapped = autocast_context(convert_outputs_to_fp32(test_fn))
# `wrapped_bf16`: test_fn decorated only with torch.autocast in bfloat16
# (recorded output dtype below: torch.bfloat16).
wrapped_bf16 = torch.autocast(device_type="cuda", dtype=torch.bfloat16)(test_fn)

In [63]:
out = wrapped(a, b)
out_bf16 = wrapped_bf16(a, b)

In [64]:
out.dtype

torch.float32

In [65]:
out_bf16.dtype

torch.bfloat16

In [99]:
from torch.utils.flop_counter import FlopCounterMode

#fc= FlopCounterMode(model)
batch = next(iter(trainer.get_train_dataloader()))

In [100]:
# Count the FLOPs of one forward pass over a single training batch.
# Note: `out` rebinds the name used by the autocast experiment above.
with FlopCounterMode() as fc:
    out = model(**batch)

Module                                            FLOP    % Total
-------------------------------------------  ---------  ---------
Global                                       1071.808M    100.00%
 - aten.mm                                   1048.347M     97.81%
 - aten.bmm                                     0.014M      0.00%
 - aten._scaled_dot_product_flash_attention    23.448M      2.19%


In [101]:
# Scale the single-forward FLOP count by 5 and by 3.
# NOTE(review): presumably 5 = max_steps from training_args and 3 = forward
# plus backward (backward taken as ~2x forward) — confirm. Batches in other
# steps may have different sequence lengths, so treat this as an estimate.
fc.get_total_flops() * 5 * 3

16077118080

In [None]:
from transformers.utils import flop