2 changes: 2 additions & 0 deletions F2LLM/arguments.py
@@ -21,6 +21,8 @@ class Args:
warmup_steps: int = 100
# embedding-related settings
num_hard_neg: int = 7
# gradient accumulation to simulate larger effective batch size
gradient_accumulation_steps: int = 1
# train steps take precedence over epochs, set to -1 to disable
train_steps: int = -1
train_epochs: int = 5
21 changes: 21 additions & 0 deletions F2LLM/configs/demo_accumulation.json
@@ -0,0 +1,21 @@
{
"model_path": "bert-base-uncased",
"experiment_id": "demo-ga",
"output_dir": "output",
"tb_dir": "output/tb",
"cache_dir": "cache",
"train_data_path": "training_data/data_tokenized",
"train_batch_size": 2,
"max_seq_length": 128,
"learning_rate": 1e-4,
"min_lr": 1e-6,
"weight_decay": 1e-2,
"warmup_steps": 10,
"num_hard_neg": 1,
"train_steps": -1,
"train_epochs": 1,
"log_interval": 2,
"checkpointing_steps": 0,
"validation_steps": 1000000,
"gradient_accumulation_steps": 8
}
9 changes: 7 additions & 2 deletions F2LLM/run.py
@@ -115,7 +115,10 @@ def __iter__(self):
# determine training steps
override_train_step = False
if args.train_steps < 0:
args.train_steps = sum(len(v) for v in train_loaders.values()) * args.train_epochs
# interpret train_steps as optimization steps (after accumulation)
total_micro_batches = sum(len(v) for v in train_loaders.values()) * args.train_epochs
accum = max(1, getattr(args, 'gradient_accumulation_steps', 1))
args.train_steps = total_micro_batches // accum
override_train_step = True

accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
@@ -145,7 +148,9 @@ def __iter__(self):

# if training on multiple GPUs, length of dataloader would have changed
if override_train_step:
args.train_steps = len(train_dataloader) * args.train_epochs
total_micro_batches = len(train_dataloader) * args.train_epochs
accum = max(1, getattr(args, 'gradient_accumulation_steps', 1))
args.train_steps = total_micro_batches // accum
accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************")


161 changes: 161 additions & 0 deletions F2LLM/smoke_test_accumulation.py
@@ -0,0 +1,161 @@
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from tqdm import tqdm

# Minimal tokenizer-like object
class DummyTokenizer:
def __init__(self, pad_token_id=0):
self.pad_token_id = pad_token_id

# Dummy model implementing required interface
class DummyModel:
def __init__(self, hidden_size=32, tokenizer=None, device="cpu"):
self.tokenizer = tokenizer or DummyTokenizer()
self.lm = nn.Sequential(
nn.Embedding(30522, hidden_size),
nn.Linear(hidden_size, hidden_size)
)
self._device = torch.device(device)
self.lm.to(self._device)

def set_device(self):
self._device = next(self.lm.parameters()).device

@property
def device(self):
return self._device

def forward(self, batch):
input_ids = batch['input_ids'].to(self.device) # [bs_total, seq]
attention_mask = batch['attention_mask'].to(self.device)
bs = batch['bs']
# Compute simple pooled features
emb = self.lm[0](input_ids) # [bs_total, seq, h]
pooled = (emb * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp_min(1.0)
# split back into query/passages/negatives
num_hard = 1 # keep simple
q = pooled[:bs]
p = pooled[bs:2*bs]
negs = pooled[2*bs:2*bs+bs*num_hard].view(bs, num_hard, -1)
return {
'query_passage_features': q.unsqueeze(1), # [bs,1,h]
'passage_passage_features': p.unsqueeze(1), # [bs,1,h]
'negative_passage_features': negs # [bs,num_hard,h]
}

class SyntheticDataset(Dataset):
def __init__(self, length=64, seq_len=16, vocab=100):
self.length = length
self.seq_len = seq_len
self.vocab = vocab

def __len__(self):
return self.length

def __getitem__(self, idx):
def rand_ids():
return [torch.randint(1, self.vocab, ()).item() for _ in range(self.seq_len)]
return {
'query_input_ids': rand_ids(),
'passage_input_ids': rand_ids(),
'negative_1_input_ids': rand_ids(),
'dataset_name': 'msmarco'
}

def _stack(input_ids, max_len, pad_id):
data = [ids[:max_len] for ids in input_ids]
lens = [len(x) for x in data]
tensor = torch.tensor(sum(data, []))
chunks = tensor.split(lens)
return chunks

def collate_fn(batch_raw, max_seq_length=32, tokenizer=None):
tokenizer = tokenizer or DummyTokenizer()
num_hard_neg = 1
input_ids = _stack(
[s['query_input_ids'] for s in batch_raw]+
[s['passage_input_ids'] for s in batch_raw]+
[s[f'negative_1_input_ids'] for s in batch_raw],
max_seq_length,
tokenizer.pad_token_id
)
seqlens = torch.tensor([ids.size(0) for ids in input_ids])
# pad to batch
input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
attention_masks = input_ids.ne(tokenizer.pad_token_id).long()
return {
'input_ids': input_ids,
'seq_lens': seqlens,
'attention_mask': attention_masks,
'bs': len(batch_raw),
'dataset_name': batch_raw[0]['dataset_name']
}

# Minimal loss helpers adapted from utils.py
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

def inbatch_loss(q, c, criterion, accelerator, temperature=0.05):
bs = q.size(0)
a_norm = F.normalize(q, p=2, dim=-1)
b_cross = accelerator.gather(c)
b_norm = F.normalize(b_cross, p=2, dim=-1)
logits = torch.matmul(a_norm, b_norm.t()) / temperature
labels = torch.arange(bs, device=logits.device) + bs * accelerator.process_index
loss_bs = criterion(logits, labels)
return loss_bs.mean()

def hard_loss(q, c, negs, criterion, accelerator, temperature=0.05):
if negs is None:
return torch.tensor(0.0, device=q.device)
bs = q.size(0)
a = F.normalize(q, p=2, dim=-1)
hard = torch.concat([c.unsqueeze(1), negs], dim=1)
hard = F.normalize(hard, p=2, dim=-1)
logits = (a.unsqueeze(1) * hard).sum(-1) / temperature
return criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean()


def main():
accelerator = Accelerator()
tokenizer = DummyTokenizer()
model = DummyModel(tokenizer=tokenizer, device="cpu")
model.set_device()

ds = SyntheticDataset(length=32, seq_len=8, vocab=100)
loader = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=lambda b: collate_fn(b, max_seq_length=16, tokenizer=tokenizer))
loader = accelerator.prepare(loader)

optimizer = torch.optim.SGD(model.lm.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
criterion = CrossEntropyLoss(reduction='none')

accumulation_steps = 4
total_micro = len(loader)
expected_opt_steps = total_micro // accumulation_steps
completed = 0
local_accum = 0

for batch in tqdm(loader, disable=not accelerator.is_local_main_process):
out = model.forward(batch)
loss_h = hard_loss(out['query_passage_features'].squeeze(1), out['passage_passage_features'].squeeze(1), out['negative_passage_features'], criterion, accelerator)
loss_ib = inbatch_loss(out['query_passage_features'].squeeze(1), out['passage_passage_features'].squeeze(1), criterion, accelerator)
loss = (loss_h + loss_ib) / accumulation_steps
accelerator.backward(loss)
local_accum += 1
if local_accum % accumulation_steps == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
completed += 1
if completed >= expected_opt_steps:
break

print(f"Optimization steps: {completed} (expected {expected_opt_steps})")
assert completed == expected_opt_steps, "Accumulation did not match expected steps"
print("Smoke test passed.")

if __name__ == "__main__":
main()
31 changes: 31 additions & 0 deletions F2LLM/test_gradient_accumulation.py
@@ -0,0 +1,31 @@
import torch
from torch import nn


def run_accumulation_test(accumulation_steps=4, micro_batches=12):
torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

steps = 0
optimizer.zero_grad()
for i in range(micro_batches):
x = torch.randn(8, 10)
y = torch.randn(8, 1)
out = model(x)
loss = nn.functional.mse_loss(out, y)
(loss / accumulation_steps).backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
steps += 1
return steps


if __name__ == "__main__":
s = run_accumulation_test(accumulation_steps=4, micro_batches=12)
print(f"Optimization steps: {s} (expected 3)")
assert s == 3, f"Expected 3 optimization steps, got {s}"
print("Gradient accumulation test passed.")
33 changes: 21 additions & 12 deletions F2LLM/utils.py
@@ -137,6 +137,8 @@ def accelerate_train(args,
criterion = CrossEntropyLoss(reduction='none')
pbar = tqdm(range(args.train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
accumulation_steps = max(1, getattr(args, 'gradient_accumulation_steps', 1))
local_accum_counter = 0
loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}
count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS}
@@ -164,19 +166,26 @@ def accelerate_train(args,
loss = 0.0

loss_total = loss + loss_hard
# scale loss for gradient accumulation
loss_total = loss_total / accumulation_steps

# backward, optimizer, scheduler
accelerator.backward(loss_total)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if optimizer.param_groups[0]['lr'] < args.min_lr:
for i in range(len(optimizer.param_groups)):
optimizer.param_groups[i]['lr'] = args.min_lr

# log
completed_steps += 1
if completed_steps % args.log_interval == 0:
local_accum_counter += 1
stepped = False
if local_accum_counter % accumulation_steps == 0:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
stepped = True
if optimizer.param_groups[0]['lr'] < args.min_lr:
for i in range(len(optimizer.param_groups)):
optimizer.param_groups[i]['lr'] = args.min_lr

# log only on optimization steps
if stepped:
completed_steps += 1
if completed_steps % args.log_interval == 0 and completed_steps > 0:
pbar.update(args.log_interval)

train_log_dict = {"lr": optimizer.param_groups[0]['lr']}
@@ -202,13 +211,13 @@ def accelerate_train(args,
count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()}

# validation
if completed_steps % args.validation_steps == 0:
if completed_steps % args.validation_steps == 0 and completed_steps > 0:
model.lm.eval()
validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer)
model.lm.train()

# step checkpoint
if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0:
if args.checkpointing_steps and completed_steps > 0 and completed_steps % args.checkpointing_steps == 0:
output_dir = os.path.join(args.output_dir, f"step_{completed_steps}")
save_checkpoint(args, accelerator, model, output_dir, lr_scheduler)

23 changes: 23 additions & 0 deletions README.md
@@ -2,6 +2,29 @@

<p align="center">
<img src="https://modelscope.cn/api/v1/models/codefuse-ai/CodeFuse-QWen-14B/repo?Revision=master&FilePath=LOGO.jpg&View=true" width="800"/>
### Gradient Accumulation

To train with larger effective batch sizes on limited GPU memory, we added gradient accumulation.

- New config key: `gradient_accumulation_steps` (default: 1)
- Effective global batch size: `train_batch_size * gradient_accumulation_steps * num_processes`
- `train_steps` counts optimization steps (optimizer updates after accumulation). When not set, it is computed as `(micro_batches_per_epoch * train_epochs) // gradient_accumulation_steps` (see the sketch below).
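
A minimal sketch of this bookkeeping (illustrative only; the helper names and the 1,000 micro-batches per epoch are assumptions, while the config keys match the ones above):

```python
# Sketch: how effective batch size and optimization steps relate under
# gradient accumulation, mirroring the computation in F2LLM/run.py.

def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes):
    # Samples consumed per optimizer update across all processes.
    return train_batch_size * gradient_accumulation_steps * num_processes

def optimization_steps(micro_batches_per_epoch, train_epochs, gradient_accumulation_steps):
    # train_steps counts optimizer updates, not micro-batches; trailing
    # micro-batches that do not fill a full accumulation window are dropped.
    total_micro_batches = micro_batches_per_epoch * train_epochs
    return total_micro_batches // max(1, gradient_accumulation_steps)

# With configs/demo_accumulation.json on a single process:
print(effective_batch_size(2, 8, 1))    # 16
print(optimization_steps(1000, 1, 8))   # 125
```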

Usage:

1. Set in your config JSON:
- `"gradient_accumulation_steps": 8`
2. Run training as usual with `F2LLM/run.py`.

Quick Tests (no real data required):

```bash
python F2LLM/test_gradient_accumulation.py
python F2LLM/smoke_test_accumulation.py
```

The first verifies optimizer step counts; the second runs a small synthetic pipeline on CPU with accumulation.

<p>

Embedding-related repos from CodeFuse, including: