108 changes: 108 additions & 0 deletions projects/cocomix/README.md
@@ -0,0 +1,108 @@
# CoCoMix

Official PyTorch implementation of "LLM Pretraining with Continuous Concepts".

<p align="center">
<img src="./cocomix.png" width="900">
</p>

## Environment
```
conda create -n cocomix python=3.10 -y
conda activate cocomix

# CoCoMix was developed and tested with torch 2.3.0 + CUDA 12.1
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
```

## Code structure

```
Home
|--conf
|  |--setup
|  |  |--gpt2_69m_ntp.yaml      # config for gpt2 69m pretraining (20B tokens, next token prediction)
|  |  |--gpt2_69m_cocomix.yaml  # config for gpt2 69m pretraining (20B tokens, cocomix)
|  |  |--...
|  |--config.yaml               # general config for training
|  |--config_eval.yaml          # general config for evaluation
|  |--ddp.yaml                  # config for huggingface accelerate DDP
|  |--fsdp_bf16.yaml            # config for huggingface accelerate FSDP with bf16
|--data
|  |--data.py                   # dataset definition / loader
|--model
|  |--sparse_autoencoder
|  |  |--...                    # code for the top-k sparse autoencoder
|  |--__init__.py               # model loading, concept extractor loading
|  |--concept_extractor.py      # GPT2-124M model with SAE
|  |--modeling_gpt2_cocomix.py  # CoCoMix for GPT2
|--train
|  |--train_func
|  |  |--ntp.py                 # next token prediction
|  |  |--cocomix.py             # CoCoMix
|  |--trainer.py                # trainer: optimizer, scheduler, evaluation
|--main.py                      # main entry point: builds the model, dataset, and trainer
|--test.py                      # evaluation (EleutherAI lm-evaluation-harness)
|--utils.py                     # utility functions: loggers
```

## Preparation and configurations

**Dataset**:
- OpenWebText: run `./data/openwebtext_preprocess/prepare.py`; see `./data/openwebtext_preprocess/readme.md` for details
- Set `data_dir` in `./conf/config.yaml` (e.g., `./data/openwebtext_preprocess`)

**WANDB**: To use Weights & Biases (wandb) logging
- Create a wandb account and get your wandb API key
- Set `wandb_key` in `./conf/config.yaml` to your wandb key
- `wandb_project` in `./conf/config.yaml` is the name of your wandb project
- `wandb_entity` in `./conf/config.yaml` is your wandb entity name
- Set `wandb_log` to false if you don't want wandb logging

**Concept related**:
- `insert_layer_index`: layer at which concept labels are predicted and continuous concepts are inserted
- `sae_layer_index`: layer of the pretrained model from which concepts are extracted
- `lam_concept`: concept prediction loss weight (default: 0.1)
- `concept_dim`: number of concepts in the sparse autoencoder (SAE) latent; the pretrained SAE uses 32768 (fixed)
- `concept_num`: number of active concepts (i.e., the TopK value of the sparse activation) in the TopK SAE; the pretrained SAE uses 32 (fixed)
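The `concept_dim`/`concept_num` pair describes a TopK sparse activation: out of 32768 SAE latents, only the 32 largest remain active. A minimal numpy sketch of that operation (an illustration of the idea, not the repo's SAE code):

```python
import numpy as np

def topk_sparse(latent: np.ndarray, k: int) -> np.ndarray:
    """Zero out everything except the k largest activations (TopK SAE)."""
    out = np.zeros_like(latent)
    idx = np.argpartition(latent, -k)[-k:]  # indices of the k largest values
    out[idx] = latent[idx]
    return out

rng = np.random.default_rng(0)
latent = rng.standard_normal(32768)  # one SAE latent vector (concept_dim)
sparse = topk_sparse(latent, k=32)   # concept_num active concepts remain
print(np.count_nonzero(sparse))      # 32
```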

All configurations for next token prediction and CoCoMix are provided in `./conf/setup/`.

## Train code
For all experiments, we used multi-node training. A slurm job submission example is provided in `./slurm_bash`.
- Fill in the details in `./slurm_bash/slurm_multi.sh` before use (e.g., account, env_name)
- The script currently assumes FSDP (to use DDP, change `--config_file` to `./conf/ddp.yaml`)

We also provide a single-node training example (without slurm).\
If OOM occurs, reduce `update_batch_size` (the per-GPU micro batch size is `update_batch_size // num_gpus`) and increase `grad_acc_steps` accordingly.
```
# train gpt2 69m on openwebtext with next token prediction
sbatch ./slurm_bash/slurm_multi.sh setup=gpt2_69m_ntp

# train gpt2 69m on openwebtext with cocomix
sbatch ./slurm_bash/slurm_multi.sh setup=gpt2_69m_cocomix

# train gpt2 69m on single node with FSDP
accelerate launch --config_file ./conf/fsdp_bf16.yaml --num_processes=8 main.py setup=gpt2_69m_ntp

# train gpt2 69m on single node with DDP
accelerate launch --config_file ./conf/ddp.yaml --num_processes=8 main.py setup=gpt2_69m_ntp
```
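A quick sanity check on the OOM advice above: the effective batch (in tokens) is `block_size * update_batch_size * grad_acc_steps`, so halving `update_batch_size` while doubling `grad_acc_steps` leaves it unchanged (values below taken from the provided configs):

```python
def tokens_per_update(block_size: int, update_batch_size: int, grad_acc_steps: int) -> int:
    """Tokens consumed per optimizer update."""
    return block_size * update_batch_size * grad_acc_steps

base = tokens_per_update(1024, 512, 1)    # e.g., the gpt2_386m setup (~0.5M tokens)
halved = tokens_per_update(1024, 256, 2)  # OOM workaround: smaller micro batch, more accumulation
print(base, halved)  # 524288 524288
```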

## Evaluation code
Set `data_dir` in `./conf/config_eval.yaml` to the preprocessed OpenWebText dataset path (e.g., `./data/openwebtext_preprocess`).\
We use [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) for evaluation (except for OpenWebText validation perplexity). To evaluate on different datasets, modify `eval_tasks` in `./conf/config_eval.yaml`.\
Note that `eval_single_ckpt` controls whether to evaluate a single checkpoint or all saved checkpoints at a given frequency (e.g., if checkpoints were saved every 2000 training steps, setting it to false evaluates all of them at once).
```
# two options
# eval_single_ckpt=True or False

# if True, pass the path including the step (e.g., ./logs/.../step_xxx/); this evaluates only that single ckpt
# the eval_results.json will be saved in ./logs/.../step_xxx/
CUDA_VISIBLE_DEVICES=0 python test.py eval_single_ckpt=True load_path=<LOAD_PATH>

# else, pass the path excluding the step (e.g., ./logs/.../), this will evaluate all ckpts with a frequency of eval_freq (e.g., step_2000, step_4000, ...)
# the eval_results.json will be saved in ./logs/.../
CUDA_VISIBLE_DEVICES=0 python test.py load_path=<LOAD_PATH> eval_freq=2000
```
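For reference, the multi-checkpoint mode visits `step_*` directories at multiples of `eval_freq`. A hypothetical sketch of that traversal (directory names follow the docs above; `last_step` is an assumed argument, not a repo option):

```python
def ckpt_dirs(eval_freq: int, last_step: int) -> list[str]:
    """Checkpoint dirs visited when eval_single_ckpt=False: step_2000, step_4000, ..."""
    return [f"step_{s}" for s in range(eval_freq, last_step + 1, eval_freq)]

print(ckpt_dirs(2000, 6000))  # ['step_2000', 'step_4000', 'step_6000']
```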
Binary file added projects/cocomix/cocomix.png
79 changes: 79 additions & 0 deletions projects/cocomix/conf/config.yaml
@@ -0,0 +1,79 @@
wandb_log: true
wandb_entity: null
wandb_project: null
wandb_key: null

defaults:
- _self_
- setup: 'gpt2_69m'

hydra:
run:
dir: .

mode: 'ntp'
seed: 22
rank: 0
suffix: null

# model
base_model: 'openai-community/gpt2'
pretrained_model: 'openai-community/gpt2'
dataset: openwebtext
data_dir: './data/openwebtext_preprocess' # set your data path
n_embd: null
n_layer: null
n_head: null
vocab_size: null

load_path: null
port: 9819
distributed: False
world_size: 1
use_torch_compile: True
compile_dynamo_cache_size_limit: 256

# optimization
lr: 6e-4
lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant',
beta1: 0.9
beta2: 0.95
grad_clip_thresh: 1.
warmup_steps: 2000
min_lr: 6e-5
eps: 1e-8
mixed_precision: null
weight_decay: 0.1
train_steps: 600000 # 600k steps
n_epochs: 0
num_workers: 2

# total batch size = 1024 (context length) * 256 (update_batch_size) * 2 (grad_acc_steps) = 524,288 (~0.5M)
# total number of tokens = train_steps * total batch size = 600k * 0.5M = 300B tokens
update_batch_size: 256 # micro batch size is update_batch_size // num_gpus
grad_acc_steps: 2
block_size: 1024 # context length
dropout: 0.0
bias: False

log_path: null
use_accelerator: True

# saving/evaluation/logging frequency
save_step_freq: 10000
eval_step_freq: 1000
log_step_freq: 50
global_step: 0
val_datasets: ['openwebtext'] # measuring ppl
batch_size_eval: 256
eval_limit: 1000

topK_attri: 4 # TopK for concept label
concept_num: 32 # TopK for SAE activation
concept_dim: 32768 # SAE concept dimension

# sae
sae_location: 'resid_post_mlp'
insert_layer_index: null # CoCoMix model's layer that predicts and inserts the concept
sae_layer_index: null # SAE layer used for concept extraction
lam_concept: 0.1
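The optimization settings above combine linear warmup with cosine decay down to `min_lr`. A minimal sketch of that shape using the config's default values (an approximation of the `cosine_with_min_lr` schedule, not the HF implementation):

```python
import math

def lr_at(step: int, lr: float = 6e-4, min_lr: float = 6e-5,
          warmup: int = 2000, total: int = 600_000) -> float:
    """Linear warmup to lr, then cosine decay down to min_lr."""
    if step < warmup:
        return lr * step / warmup
    progress = (step - warmup) / (total - warmup)
    return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * progress))

# peak at the end of warmup, min_lr at the end of training
print(lr_at(0), lr_at(2000), lr_at(600_000))
```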
17 changes: 17 additions & 0 deletions projects/cocomix/conf/config_eval.yaml
@@ -0,0 +1,17 @@
defaults:
- _self_

hydra:
run:
dir: .

rank: 0
seed: 42
base_model: 'openai-community/gpt2'
data_dir: './data/openwebtext_preprocess' # set your data path
load_path: null
batch_size: 64
eval_freq: 2000
eval_single_ckpt: False
eval_tasks: ['lambada_openai','wikitext','hellaswag','piqa','social_iqa','arc_easy','winogrande']
save_result: True
19 changes: 19 additions & 0 deletions projects/cocomix/conf/ddp.yaml
@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: true
gpu_ids: all
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
26 changes: 26 additions & 0 deletions projects/cocomix/conf/fsdp_bf16.yaml
@@ -0,0 +1,26 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'yes'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true # true for torch.compile
machine_rank: 0
main_process_port: 12345
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
38 changes: 38 additions & 0 deletions projects/cocomix/conf/setup/gpt2_1b_cocomix.yaml
@@ -0,0 +1,38 @@
# @package _global_

mode: 'cocomix'
n_embd: 2048
n_layer: 24
n_head: 16
compile_dynamo_cache_size_limit: 512

# optimization
lr: 2e-4
lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant',
beta1: 0.9
beta2: 0.95
grad_clip_thresh: 1.
warmup_steps: 65
min_lr: 2e-5
eps: 1e-8
mixed_precision: null
weight_decay: 0.1
train_steps: 20000 # 20k steps ~ 20B

# total batch size = 1024 (context length) * 1024 (update_batch_size) * 1 (grad_acc_steps) = (~1.0M)
# total number of tokens = train_steps * total batch size = 20k * 1.0M = 20B tokens
update_batch_size: 1024 # micro batch size is update_batch_size // num_gpus
grad_acc_steps: 1
block_size: 1024

# saving/evaluation/logging frequency
save_step_freq: 1000
eval_step_freq: 500
log_step_freq: 50
val_datasets: ['openwebtext'] # measuring ppl
batch_size_eval: 256
eval_limit: 1000

# sae
insert_layer_index: 5 # CoCoMix model's layer that predicts and inserts the concept
sae_layer_index: 5 # SAE layer used for concept extraction
33 changes: 33 additions & 0 deletions projects/cocomix/conf/setup/gpt2_1b_ntp.yaml
@@ -0,0 +1,33 @@
# @package _global_

mode: 'ntp'
n_embd: 2048
n_layer: 24
n_head: 16

# optimization
lr: 2e-4
lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant',
beta1: 0.9
beta2: 0.95
grad_clip_thresh: 1.
warmup_steps: 65
min_lr: 2e-5
eps: 1e-8
mixed_precision: null
weight_decay: 0.1
train_steps: 20000 # 20k steps ~ 20B

# total batch size = 1024 (context length) * 1024 (update_batch_size) * 1 (grad_acc_steps) = (~1.0M)
# total number of tokens = train_steps * total batch size = 20k * 1.0M = 20B tokens
update_batch_size: 1024 # micro batch size is update_batch_size // num_gpus
grad_acc_steps: 1
block_size: 1024

# saving/evaluation/logging frequency
save_step_freq: 1000
eval_step_freq: 500
log_step_freq: 50
val_datasets: ['openwebtext'] # measuring ppl
batch_size_eval: 256
eval_limit: 1000
38 changes: 38 additions & 0 deletions projects/cocomix/conf/setup/gpt2_386m_cocomix.yaml
@@ -0,0 +1,38 @@
# @package _global_

mode: 'cocomix'
n_embd: 1024
n_layer: 24
n_head: 16
compile_dynamo_cache_size_limit: 512

# optimization
lr: 3e-4
lr_schedule: 'cosine_with_min_lr' # 'cosine' 'constant_with_warmup' 'constant',
beta1: 0.9
beta2: 0.95
grad_clip_thresh: 1.
warmup_steps: 130
min_lr: 3e-5
eps: 1e-8
mixed_precision: null
weight_decay: 0.1
train_steps: 40000 # 40k steps ~ 20B

# total batch size = 1024 (context length) * 512 (update_batch_size) * 1 (grad_acc_steps) = (~0.5M)
# total number of tokens = train_steps * total batch size = 40k * 0.5M = 20B tokens
update_batch_size: 512 # micro batch size is update_batch_size // num_gpus
grad_acc_steps: 1
block_size: 1024

# saving/evaluation/logging frequency
save_step_freq: 2000
eval_step_freq: 1000
log_step_freq: 50
val_datasets: ['openwebtext'] # measuring ppl
batch_size_eval: 256
eval_limit: 1000

# sae
insert_layer_index: 5 # CoCoMix model's layer that predicts and inserts the concept
sae_layer_index: 5 # SAE layer used for concept extraction