From 7b5e71b780ad3f8dd65f4de6bdda07f1bf6c4b64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B3=BD=E5=90=89?= Date: Wed, 26 Nov 2025 22:49:16 +0800 Subject: [PATCH 1/4] add gitignore --- F2LLM/.gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 F2LLM/.gitignore diff --git a/F2LLM/.gitignore b/F2LLM/.gitignore new file mode 100644 index 0000000..b42097e --- /dev/null +++ b/F2LLM/.gitignore @@ -0,0 +1 @@ +**/__pycache__/ \ No newline at end of file From 2841a3fa8654f87ddb305f4f5f41bcd1ec6b5379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B3=BD=E5=90=89?= Date: Wed, 26 Nov 2025 22:50:06 +0800 Subject: [PATCH 2/4] adapt universal decoder-only base model --- F2LLM/arguments.py | 32 ++++-- F2LLM/configs/config_llama3_8b.json | 31 ++++++ F2LLM/configs/config_universal.json | 32 ++++++ F2LLM/model.py | 14 ++- F2LLM/model_adapters.py | 66 +++++++++++ F2LLM/model_factory.py | 110 ++++++++++++++++++ F2LLM/run.py | 4 +- F2LLM/tokenize_data_universal.py | 167 ++++++++++++++++++++++++++++ 8 files changed, 442 insertions(+), 14 deletions(-) create mode 100644 F2LLM/configs/config_llama3_8b.json create mode 100644 F2LLM/configs/config_universal.json create mode 100644 F2LLM/model_adapters.py create mode 100644 F2LLM/model_factory.py create mode 100644 F2LLM/tokenize_data_universal.py diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py index b967c8f..970cf62 100644 --- a/F2LLM/arguments.py +++ b/F2LLM/arguments.py @@ -4,31 +4,38 @@ @dataclass class Args: + model_path: str = None + experiment_id: str = None - model_path: str - experiment_id: str + model_config: dict = None + tokenizer_config: dict = None + # save dir - output_dir: str - tb_dir: str - cache_dir: str + output_dir: str = None + tb_dir: str = None + cache_dir: str = None + # training arguments - train_data_path: str + train_data_path: str = None train_batch_size: int = 8 max_seq_length: int = 2048 learning_rate: float = 1e-4 min_lr: float = 1e-6 weight_decay: float = 1e-2 warmup_steps: int = 100 + # 
embedding-related settings num_hard_neg: int = 7 + # train steps take precedence over epochs, set to -1 to disable train_steps: int = -1 train_epochs: int = 5 log_interval: int = 20 checkpointing_steps: int = 100 validation_steps: int = 100 + # just placeholder, for logging purpose - num_processes: int=0 + num_processes: int = 0 def dict(self): return asdict(self) @@ -36,11 +43,20 @@ def dict(self): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str) + parser.add_argument("--config", type=str, required=True, + help="Path to configuration file") arg = parser.parse_args() + with open(arg.config) as f: config = json.load(f) + args = Args(**config) + + # Make sure model_path is set: fall back to the value nested in model_config + if not args.model_path and args.model_config: + args.model_path = args.model_config["model_path"] + args.output_dir = f"{args.output_dir}/{args.experiment_id}" args.tb_dir = f"{args.tb_dir}/{args.experiment_id}" + return args \ No newline at end of file diff --git a/F2LLM/configs/config_llama3_8b.json b/F2LLM/configs/config_llama3_8b.json new file mode 100644 index 0000000..784957a --- /dev/null +++ b/F2LLM/configs/config_llama3_8b.json @@ -0,0 +1,31 @@ +{ + "model_config": { + "model_type": "llama", + "model_path": "meta-llama/Meta-Llama-3-8B", + "model_params": { + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2" + } + }, + "tokenizer_config": { + "add_special_tokens": false, + "padding_side": "right", + "pad_token": null + }, + "experiment_id": "llama3-8b+lr.8e-6+bs.16x32+context.1024+2epochs", + "train_data_path": "training_data/data_tokenized", + "output_dir": "output", + "tb_dir": "output/tb", + "cache_dir": "cache", + "train_batch_size": 16, + "checkpointing_steps": 5000, + "validation_steps": 5000, + "max_seq_length": 1024, + "learning_rate": 8e-6, + "min_lr": 1e-7, + "weight_decay": 0.01, + "warmup_steps": 500, + "train_epochs": 2, + "log_interval": 100, + "num_hard_neg": 7 +} \ No newline at end of file diff --git 
a/F2LLM/configs/config_universal.json b/F2LLM/configs/config_universal.json new file mode 100644 index 0000000..3239ddf --- /dev/null +++ b/F2LLM/configs/config_universal.json @@ -0,0 +1,32 @@ +{ + "model_config": { + "model_type": "qwen", + "model_path": "models/qwen3-4b", + "model_params": { + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "trust_remote_code": true + } + }, + "tokenizer_config": { + "add_special_tokens": false, + "padding_side": "right", + "pad_token": null + }, + "experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs", + "train_data_path": "training_data/data_tokenized", + "output_dir": "output", + "tb_dir": "output/tb", + "cache_dir": "cache", + "train_batch_size": 16, + "checkpointing_steps": 5000, + "validation_steps": 5000, + "max_seq_length": 1024, + "learning_rate": 8e-6, + "min_lr": 1e-7, + "weight_decay": 0.01, + "warmup_steps": 500, + "train_epochs": 2, + "log_interval": 100, + "num_hard_neg": 7 +} \ No newline at end of file diff --git a/F2LLM/model.py b/F2LLM/model.py index d33ade7..42657d5 100644 --- a/F2LLM/model.py +++ b/F2LLM/model.py @@ -1,5 +1,5 @@ import torch -from transformers import AutoModel, AutoTokenizer +from model_factory import ModelFactory class F2LLM: @@ -11,10 +11,15 @@ def __init__(self, self.args = args self.dtype = torch.bfloat16 - self.device = None # set after accelerator.prepare - self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2') + self.device = None # set after accelerator.prepare + + # Use model factory to create adapter + self.adapter = ModelFactory.create_adapter(model_path, max_seq_length, args) + + # Load model and tokenizer + self.lm = self.adapter.load_model() self.lm.config.use_cache = False - self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.tokenizer = self.adapter.load_tokenizer() self.max_seq_length = max_seq_length def set_device(self): @@ -34,4 +39,3 @@ def 
forward(self, batch): 'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]), 'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1) } -
diff --git a/F2LLM/model_adapters.py b/F2LLM/model_adapters.py
new file mode 100644
index 0000000..a6034ea
--- /dev/null
+++ b/F2LLM/model_adapters.py
@@ -0,0 +1,66 @@
+import torch
+from transformers import AutoModel, AutoTokenizer
+import os
+import json
+from abc import ABC, abstractmethod
+
+
+class BaseModelAdapter(ABC):
+    """Abstract adapter that loads a backbone model and its tokenizer from ``model_path``."""
+
+    def __init__(self, model_path, max_seq_length=512, args=None):
+        self.model_path = model_path
+        self.max_seq_length = max_seq_length
+        self.args = args
+        self.dtype = torch.bfloat16
+        self.device = None
+
+    @abstractmethod
+    def load_model(self):
+        """Return the loaded backbone model."""
+        pass
+
+    @abstractmethod
+    def load_tokenizer(self):
+        """Return the loaded tokenizer."""
+        pass
+
+    def get_model_config(self):
+        """Return the parsed ``config.json`` found under ``model_path``, or ``{}`` if there is none."""
+        config_path = os.path.join(self.model_path, 'config.json')
+        if os.path.exists(config_path):
+            with open(config_path) as f:
+                return json.load(f)
+        return {}
+
+
+class QwenAdapter(BaseModelAdapter):
+    """Qwen series model adapter (Qwen, Qwen2, Qwen3)"""
+
+    def load_model(self):
+        return AutoModel.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            torch_dtype=self.dtype,
+            attn_implementation='flash_attention_2'
+        )
+
+    def load_tokenizer(self):
+        return AutoTokenizer.from_pretrained(self.model_path)
+
+
+class LlamaAdapter(BaseModelAdapter):
+    """Llama series model adapter (Llama-2, Llama-3, CodeLlama)"""
+
+    def load_model(self):
+        return AutoModel.from_pretrained(
+            self.model_path,
+            torch_dtype=self.dtype,
+            attn_implementation='flash_attention_2'
+        )
+
+    def load_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        return tokenizer
\ No newline at end of file
diff --git a/F2LLM/model_factory.py b/F2LLM/model_factory.py
new file mode 100644
index 0000000..6c36a1a
--- /dev/null
+++ b/F2LLM/model_factory.py
@@ -0,0 +1,110 @@
+import os
+import json
+from typing import Dict, Type
+from model_adapters import BaseModelAdapter, QwenAdapter, LlamaAdapter
+
+
+class ModelFactory:
+    """Factory that picks a model adapter class from the detected model type."""
+
+    # Mapping of model types to adapters
+    MODEL_ADAPTERS: Dict[str, Type[BaseModelAdapter]] = {
+        'qwen': QwenAdapter,
+        'qwen2': QwenAdapter,
+        'qwen3': QwenAdapter,
+        'llama': LlamaAdapter,
+    }
+
+    @classmethod
+    def create_adapter(cls, model_path: str, max_seq_length: int = 512, args=None) -> BaseModelAdapter:
+        """Instantiate the adapter for ``model_path``; unmapped types fall back to LlamaAdapter."""
+        model_type = cls.detect_model_type(model_path)
+        adapter_class = cls.MODEL_ADAPTERS.get(model_type)
+
+        if not adapter_class:
+            # Use LlamaAdapter as fallback for unknown model types
+            print(f"Warning: Unknown model type '{model_type}', using LlamaAdapter as fallback")
+            adapter_class = LlamaAdapter
+
+        return adapter_class(model_path, max_seq_length, args)
+
+    @staticmethod
+    def _read_json_dict(path: str) -> dict:
+        """Return the JSON object stored at ``path``; ``{}`` if missing, unreadable or not an object."""
+        try:
+            with open(path) as f:
+                data = json.load(f)
+        except (OSError, ValueError):  # ValueError covers json.JSONDecodeError and UnicodeDecodeError
+            return {}
+        return data if isinstance(data, dict) else {}
+
+    @classmethod
+    def detect_model_type(cls, model_path: str) -> str:
+        """Detect the model type of ``model_path``.
+
+        Tries, in order: the ``model_type`` field of ``config.json``, a family name
+        ('qwen' / 'llama') occurring in the path, and the ``tokenizer_class`` field of
+        ``tokenizer_config.json``. Missing or malformed JSON files are skipped.
+
+        Returns:
+            The lower-cased model type, or 'unknown' when nothing matches.
+        """
+        # Method 1: Detect via config file
+        config = cls._read_json_dict(os.path.join(model_path, 'config.json'))
+        model_type = config.get('model_type')
+        if isinstance(model_type, str) and model_type:
+            return model_type.lower()
+
+        # Method 2: Infer from path name (every former keyword contains one of these)
+        path_lower = model_path.lower()
+        for family in ('qwen', 'llama'):
+            if family in path_lower:
+                return family
+
+        # Method 3: Detect via folder structure
+        tokenizer_config = cls._read_json_dict(os.path.join(model_path, 'tokenizer_config.json'))
+        tokenizer_class = tokenizer_config.get('tokenizer_class')
+        if isinstance(tokenizer_class, str):
+            for family in ('qwen', 'llama'):
+                if family in tokenizer_class.lower():
+                    return family
+
+        return 'unknown'
+
+    @classmethod
+    def list_supported_models(cls) -> list:
+        """Return list of supported model types"""
+        return list(cls.MODEL_ADAPTERS.keys())
+
+    @classmethod
+    def get_model_info(cls, model_path: str) -> dict:
+        """Return a summary dict: detected type, adapter class name and key ``config.json`` fields."""
+        model_type = cls.detect_model_type(model_path)
+        adapter_class = cls.MODEL_ADAPTERS.get(model_type)
+
+        info = {
+            'model_path': model_path,
+            'detected_type': model_type,
+            'adapter_class': adapter_class.__name__ if adapter_class else 'Unknown',
+            'is_supported': model_type in cls.MODEL_ADAPTERS
+        }
+
+        # Try to get model configuration info
+        config_path = os.path.join(model_path, 'config.json')
+        if os.path.exists(config_path):
+            try:
+                with open(config_path) as f:
+                    config = json.load(f)
+                info.update({
+                    'model_name': config.get('_name_or_path', 'Unknown'),
+                    'vocab_size': config.get('vocab_size', 0),
+                    'hidden_size': config.get('hidden_size', 0),
+                    'num_layers': config.get('num_hidden_layers', 0),
+                    'num_attention_heads': config.get('num_attention_heads', 0),
+                    'max_position_embeddings': config.get('max_position_embeddings', 0)
+                })
+            except Exception as e:
+                # Any failure while reading the config is recorded in the result rather than raised
+                info['config_error'] = str(e)
+
+        return info
\ No newline at end of file
diff --git a/F2LLM/run.py b/F2LLM/run.py index e40b707..e566ff3 100644 --- a/F2LLM/run.py +++ b/F2LLM/run.py @@ -14,6 +14,7 @@ from torch.nn.utils.rnn import pad_sequence from torch.optim import AdamW from model import F2LLM +from model_factory import ModelFactory 
os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -69,7 +70,8 @@ def collate_fn(batch_raw): train_datasets.append((dataset_name, dataset['train'])) valid_datasets.append((dataset_name, dataset['test'])) -tokenizer = AutoTokenizer.from_pretrained(args.model_path) +adapter = ModelFactory.create_adapter(args.model_path) +tokenizer = adapter.load_tokenizer() train_loaders = { name: DataLoader(ds, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate_fn)
diff --git a/F2LLM/tokenize_data_universal.py b/F2LLM/tokenize_data_universal.py
new file mode 100644
index 0000000..ae661be
--- /dev/null
+++ b/F2LLM/tokenize_data_universal.py
@@ -0,0 +1,167 @@
+from multiprocessing import Pool
+import numpy as np
+import pandas as pd
+import os
+import json
+from transformers import AutoTokenizer
+from tqdm.auto import tqdm
+from model_factory import ModelFactory
+
+
+class UniversalTokenizer:
+    """Universal tokenizer supporting multiple model types"""
+
+    def __init__(self, model_config_path):
+        """Initialize universal tokenizer
+
+        Args:
+            model_config_path: Path to a JSON config file (model or training config) or a model directory
+        """
+        self.model_config = self._load_model_config(model_config_path)
+        self.model_type = self.model_config.get('model_type', 'unknown')
+        self.model_path = self.model_config.get('model_path', model_config_path)
+        self.max_seq_length = self.model_config.get('max_seq_length', 1023)
+
+        # Create adapter and load tokenizer using model factory
+        self.adapter = ModelFactory.create_adapter(self.model_path)
+        self.tokenizer = self.adapter.load_tokenizer()
+
+        # Apply tokenizer configuration
+        self._apply_tokenizer_config()
+
+    def _load_model_config(self, config_path):
+        """Load the configuration dict from a JSON file or a model directory.
+
+        Training configs (e.g. ``configs/config_universal.json``) nest ``model_type`` and
+        ``model_path`` under ``model_config``; those entries are lifted to the top level
+        (top-level keys win). Returns ``{'model_path': config_path}`` if nothing is found.
+        """
+        config_file = None
+        if os.path.isfile(config_path) and config_path.endswith('.json'):
+            config_file = config_path
+        elif os.path.isdir(config_path):
+            # If directory, try to load config.json
+            candidate = os.path.join(config_path, 'config.json')
+            if os.path.exists(candidate):
+                config_file = candidate
+        if config_file is None:
+            return {'model_path': config_path}
+        with open(config_file) as f:
+            config = json.load(f)
+        nested = config.get('model_config')
+        if isinstance(nested, dict):
+            config = {**nested, **config}
+        return config
+
+    def _apply_tokenizer_config(self):
+        """Apply ``padding_side`` / ``pad_token`` from ``tokenizer_config``; pad token defaults to EOS."""
+        tokenizer_config = self.model_config.get('tokenizer_config', {})
+        if 'padding_side' in tokenizer_config:
+            self.tokenizer.padding_side = tokenizer_config['padding_side']
+        pad_token = tokenizer_config.get('pad_token')  # None when absent or explicitly null
+        if pad_token is not None:
+            self.tokenizer.pad_token = pad_token
+        elif self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    def tokenize(self, text):
+        """Tokenize ``text`` (truncated to ``max_seq_length``) and append the EOS token id."""
+        add_special_tokens = self.model_config.get('tokenizer_config', {}).get('add_special_tokens', False)
+
+        tokenizer_outputs = self.tokenizer(
+            text,
+            max_length=self.max_seq_length,
+            truncation=True,
+            add_special_tokens=add_special_tokens
+        )
+        return np.array(tokenizer_outputs.input_ids + [self.tokenizer.eos_token_id])
+
+
+def process_sent(sentence, tokenizer):
+    """Tokenize a single sentence with a ``UniversalTokenizer`` (delegates to ``tokenize``)."""
+    return tokenizer.tokenize(sentence)
+
+
+def process_sent_batch(s, tokenizer):
+    """Process sentences in batch"""
+    return s.apply(lambda x: process_sent(x, tokenizer))
+
+
+def parallelize(data, func, num_of_processes=8, tokenizer=None):
+    """Apply ``func(chunk, tokenizer)`` to ``num_of_processes`` chunks of ``data`` in a pool; concatenate."""
+    # Split by position: ``iloc`` is positional, so index labels are only valid for a default RangeIndex
+    positions = np.array_split(np.arange(len(data)), num_of_processes)
+    data_split = [data.iloc[pos] for pos in positions]
+    with Pool(num_of_processes) as pool:
+        data = pd.concat(pool.starmap(func, [(df, tokenizer) for df in data_split]))
+    return data
+
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Universal data tokenization script')
+    parser.add_argument('--model_config', type=str, required=True,
+                        help='Path to model config file or model directory')
+    parser.add_argument('--data_dir', type=str, default='training_data',
+                        help='Directory containing training data')
+    parser.add_argument('--output_dir', type=str, default='data_tokenized',
+                        help='Output directory for tokenized data')
+    parser.add_argument('--num_processes', type=int, default=8,
+                        help='Number of processes for parallel processing')
+
+    args = parser.parse_args()
+
+    # Initialize universal tokenizer
+    tokenizer = UniversalTokenizer(args.model_config)
+
+    print(f"Model type: {tokenizer.model_type}")
+    print(f"Model path: {tokenizer.model_path}")
+    print(f"Vocab size: {len(tokenizer.tokenizer)}")
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Process data - exactly same as tokenize_data_qwen.py
+    root_dir = args.data_dir
+    for ds_name in tqdm(sorted(os.listdir(root_dir))):
+        if not ds_name.endswith('.parquet'):
+            continue
+
+        print(f"Processing {ds_name}...", flush=True)
+
+        df = pd.read_parquet(os.path.join(root_dir, ds_name))
+
+        # Process query - use parallel processing, no deduplication
+        df['query_input_ids'] = parallelize(df['query'], process_sent_batch, args.num_processes, tokenizer)
+
+        # Determine number of negative samples
+        num_neg = 24 if 'negative_2' in df.columns else 1
+
+        # Collect all passages and negatives (no deduplication for query)
+        ls = df.passage.to_list()
+        for i in range(1, num_neg+1):
+            if f'negative_{i}' in df.columns:
+                ls += df[f'negative_{i}'].to_list()
+
+        # Deduplicate passages and negatives (exactly same as tokenize_data_qwen.py)
+        ls = list(set(ls))
+        df_tmp = pd.DataFrame({'text': ls})
+        df_tmp['input_ids'] = parallelize(df_tmp['text'], process_sent_batch, args.num_processes, tokenizer)
+        df_tmp = df_tmp.set_index('text')
+
+        # Apply 
mappings + df['passage_input_ids'] = df.passage.map(df_tmp.input_ids) + + for i in range(1, num_neg+1): + if f'negative_{i}' in df.columns: + df[f'negative_{i}_input_ids'] = df[f'negative_{i}'].map(df_tmp.input_ids) + + # Save results + output_path = os.path.join(args.output_dir, ds_name) + df.to_parquet(output_path, index=False) + print(f"Saved tokenized data to {output_path}") + + +if __name__ == "__main__": + main() \ No newline at end of file From e6ab2bc370d0f59185a8e9a0ee3e096fa26d39f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B3=BD=E5=90=89?= Date: Thu, 27 Nov 2025 15:49:20 +0800 Subject: [PATCH 3/4] update README.md --- F2LLM/README.md | 7 +++++++ F2LLM/configs/config_llama3_8b.json | 31 ----------------------------- 2 files changed, 7 insertions(+), 31 deletions(-) delete mode 100644 F2LLM/configs/config_llama3_8b.json diff --git a/F2LLM/README.md b/F2LLM/README.md index 6b79819..60681ad 100644 --- a/F2LLM/README.md +++ b/F2LLM/README.md @@ -30,6 +30,13 @@ In this repo we provide a streamlined and efficient script for training embeddin - Modify model path, data path, and other arguments in `configs/config.json`. - Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`. +#### For More Decoder-Only Models +- Setup environment following `requirements.txt`. We note that transformers>=4.51.0 is required for training decoder-only models. +- Download data and backbone models from Hugging Face. +- Run `python tokenize_data_universal.py --model_config MODEL_PATH --data_dir DATA_DIR --output_dir OUTPUT_DIR` to tokenize the downloaded data. +- Modify model path, train_data_path, and other arguments in `configs/config_universal.json`. +- Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`. 
+ Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launch the training code once to generate cache for training data before starting the actual training. For multi-node training, run on the main node: diff --git a/F2LLM/configs/config_llama3_8b.json b/F2LLM/configs/config_llama3_8b.json deleted file mode 100644 index 784957a..0000000 --- a/F2LLM/configs/config_llama3_8b.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "model_config": { - "model_type": "llama", - "model_path": "meta-llama/Meta-Llama-3-8B", - "model_params": { - "torch_dtype": "bfloat16", - "attn_implementation": "flash_attention_2" - } - }, - "tokenizer_config": { - "add_special_tokens": false, - "padding_side": "right", - "pad_token": null - }, - "experiment_id": "llama3-8b+lr.8e-6+bs.16x32+context.1024+2epochs", - "train_data_path": "training_data/data_tokenized", - "output_dir": "output", - "tb_dir": "output/tb", - "cache_dir": "cache", - "train_batch_size": 16, - "checkpointing_steps": 5000, - "validation_steps": 5000, - "max_seq_length": 1024, - "learning_rate": 8e-6, - "min_lr": 1e-7, - "weight_decay": 0.01, - "warmup_steps": 500, - "train_epochs": 2, - "log_interval": 100, - "num_hard_neg": 7 -} \ No newline at end of file From ea5219dcd8f6ad0e192fd45f831ee9cca8f48650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B3=BD=E5=90=89?= Date: Thu, 27 Nov 2025 15:50:38 +0800 Subject: [PATCH 4/4] update README.md --- F2LLM/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/F2LLM/README.md b/F2LLM/README.md index 60681ad..717ff13 100644 --- a/F2LLM/README.md +++ b/F2LLM/README.md @@ -35,7 +35,7 @@ In this repo we provide a streamlined and efficient script for training embeddin - Download data and backbone models from Hugging Face. - Run `python tokenize_data_universal.py --model_config MODEL_PATH --data_dir DATA_DIR --output_dir OUTPUT_DIR` to tokenize the downloaded data. 
- Modify model path, train_data_path, and other arguments in `configs/config_universal.json`. 
-- Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`. +- Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config_universal.json`. Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launch the training code once to generate cache for training data before starting the actual training.