In [39]:
import torch
from huggingface_hub import login
from transformers import AutoConfig, LlamaForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import DataCollatorForLanguageModeling, TrainingArguments, Trainer

In [19]:
login(token="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")
model_1b = 'meta-llama/Llama-3.2-1B'
model_8b = 'meta-llama/Llama-3-8B'
# Load the configuration for the model (without weights)
config = AutoConfig.from_pretrained(model_1b)

# Initialize the model with this configuration
model = LlamaForCausalLM(config)
tokenizer = AutoTokenizer.from_pretrained(model_1b)
tokenizer.pad_token = tokenizer.eos_token


In [3]:
dataset = load_dataset('Salesforce/wikitext', 'wikitext-103-v1')

Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 68826.92 examples/s]
Generating train split: 100%|██████████| 1801350/1801350 [00:04<00:00, 363745.12 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 276153.66 examples/s]


In [54]:
def tokenize_function(examples, tokenizer, max_length):
    return tokenizer(examples["text"], padding='max_length', max_length=max_length, truncation=True)


In [55]:
class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize the text on-the-fly
        encoding = tokenize_function(self.texts[idx], self.tokenizer, self.max_length)
        
        return {
            "input_ids": encoding['input_ids'],
            "labels": encoding['input_ids'],
            "attention_mask": encoding['attention_mask'],
        }
train_data = TextDataset(dataset['train'], tokenizer, 512)
eval_data = TextDataset(dataset['validation'], tokenizer, 512)
test_data = TextDataset(dataset['test'], tokenizer, 512)


In [32]:

import wandb
#os.environ["WANDB_API_KEY"] = "ae05f44c8d5afe19940ef81e6f5cf69063392241"
wandb.login()  # Log in directly without setting env variable
wandb.init(project='llama-training', entity='fjiang7-ucsd')




In [56]:
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="steps",
    eval_steps=30,
    learning_rate=2e-5,
    per_device_train_batch_size=8, #maybe try larger batch size
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.1,
    gradient_accumulation_steps=6,
    report_to="wandb",
    logging_dir='./logs',  # directory for storing logs
    logging_steps=50,
    save_steps=500,
    fp16=True,
    gradient_checkpointing=True,
)



In [57]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_data,
    eval_dataset=eval_data,
    data_collator=data_collator
)

  trainer = Trainer(


In [58]:
train_stats = trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss
30,No log,8.344831
60,8.970200,7.386399
90,8.970200,6.913524
120,7.315200,6.694263
150,6.859100,6.557484
180,6.859100,6.465035
210,6.690500,6.385707
240,6.690500,6.316482
270,6.558500,6.249198
300,6.436500,6.198709


KeyboardInterrupt: 

In [16]:
config

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-3.2-1B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 128256
}

In [59]:
wandb.finish()

0,1
eval/loss,█▅▄▃▃▃▂▂▂▂▂▁▁▁▁▁
eval/runtime,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
eval/samples_per_second,███████████████▁
eval/steps_per_second,███████████████▁
train/epoch,▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train/global_step,▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train/grad_norm,▁█▆█▆▆▄▃▂▅
train/learning_rate,█▇▆▆▅▄▃▃▂▁
train/loss,█▄▃▂▂▂▂▁▁▁

0,1
eval/loss,5.94853
eval/runtime,127.6027
eval/samples_per_second,29.466
eval/steps_per_second,3.683
train/epoch,0.01332
train/global_step,500.0
train/grad_norm,1.83157
train/learning_rate,2e-05
train/loss,6.1436
