In [11]:
%load_ext autoreload
from transformers import TrainingArguments, Trainer, LlamaTokenizerFast, LlamaTokenizer, LlamaModel, LlamaForCausalLM, LlamaConfig
from transformers import DataCollatorForLanguageModeling
import transformers
import numpy as np
import evaluate
import datasets
from replacer import replace_linears_in_hf
from bitnet1 import BitLinear1B
from bitnet158 import BitLinear158B

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
# tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# tokenizer.pad_token = tokenizer.eos_token

# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

dataset = datasets.load_dataset("wikitext", "wikitext-2-v1", split="train[:10]")
dataset = dataset.train_test_split(test_size=0.2)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

class TokenizeWrapper:
    def __init__(self) -> None:
        self.tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize_function(self, examples):
        return self.tokenizer(examples["text"], padding='max_length', truncation=True,)

tokenize_wrapper = TokenizeWrapper()

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenize_wrapper.tokenizer, mlm=False)

train_dataset = train_dataset.map(tokenize_wrapper.tokenize_function, batched=True, num_proc=4)
eval_dataset = eval_dataset.map(tokenize_wrapper.tokenize_function, batched=True, num_proc=4)

Map (num_proc=4):   0%|          | 0/8 [00:00<?, ? examples/s]

num_proc must be <= 2. Reducing num_proc to 2 for dataset of size 2.


Map (num_proc=2):   0%|          | 0/2 [00:00<?, ? examples/s]

In [28]:
# configuration = LlamaConfig(
#     intermediate_size=1024,
#     hidden_size=1024,
#     num_hidden_layers=4,
#     num_attention_heads=4,
#     max_position_embeddings=1024,
# )

configuration = LlamaConfig(
    intermediate_size=4,
    hidden_size=4,
    num_hidden_layers=4,
    num_attention_heads=4,
    max_position_embeddings=1024,
)

model = LlamaForCausalLM(configuration)
# model = LlamaModel(configuration)

# model = replace_linears_in_hf(model, BitLinear=BitLinear158B)

training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")

metric = evaluate.load("perplexity", module_type="metric")

def compute_metrics(eval_pred: transformers.EvalPrediction):
    print("eval_pred", eval_pred)
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions)

In [29]:

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [31]:
trainer.train()

  0%|          | 0/3 [00:00<?, ?it/s]

eval_pred <transformers.trainer_utils.EvalPrediction object at 0x000001DC34926630>
