Description
I am using DeepSpeed with ZeRO optimization (stage 2) to train a custom model on multiple GPUs, and I want to compute gradients with respect to the input for explainability. However, I am facing challenges when integrating input-gradient computation in this setup: memory usage increases significantly, and I lose the memory savings typically achieved by DeepSpeed.
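For reference, this is the kind of input-gradient (saliency) computation I mean. A minimal plain-PyTorch sketch without DeepSpeed, using a toy stand-in model rather than my actual architecture:

```python
import torch
from torch import nn

# Toy stand-in for the real model: maps the 5-dim auxiliary input to a scalar "loss".
model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 1))

aux = torch.rand(1, 5, requires_grad=True)  # the input I want gradients for
loss = model(aux).sum()
loss.backward()                             # populates aux.grad
saliency = aux.grad.abs()                   # per-feature attribution scores
print(saliency)
```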
Below is the relevant DeepSpeed configuration I use, passed to the Hugging Face Trainer via the deepspeed argument:
DeepSpeed JSON Configuration (./scripts/zero2.json):
```json
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "train_micro_batch_size_per_gpu": "auto",
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto",
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto"
  }
}
```
Below is a minimal example to reproduce the issue. To launch the script, I use the following command:
```bash
/home/user/miniconda3/envs/proejct/bin/python -m deepspeed.launcher.launch --world_info=eyIxMjcuMC4wLjEiOiBbMCwgMSwgMiwgMywgNCwgNSwgNiwgN119 --master_addr=127.0.0.1 --master_port=4242 --no_local_rank /home/user/project/explain/explain.py --deepspeed ./scripts/zero2.json
```
```python
import torch
from torch import nn
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class ExampleDataset(Dataset):
    def __init__(self, tokenizer, size=10000, max_length=128):
        self.tokenizer = tokenizer
        self.texts = [f"This is example text {i}" for i in range(size)]
        self.labels = torch.randint(0, 2, (size,))  # Binary classification
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokenized = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        aux = torch.rand((5,)) * 100  # Random auxiliary input
        return {
            "aux": aux,
            "input_ids": input_ids,
            "attention_mask": tokenized["attention_mask"].squeeze(0),
            "labels": self.labels[idx],
        }


class CustomModel(nn.Module):
    def __init__(self, base_model_name="bert-base-uncased", num_labels=2):
        super().__init__()
        self.bert = BertForSequenceClassification.from_pretrained(base_model_name, num_labels=num_labels)
        self.embedding_dim = self.bert.config.hidden_size
        self.aux_linear = nn.Linear(5, self.embedding_dim)

    def forward(self, input_ids, attention_mask, aux, labels=None):
        aux_embedded = self.aux_linear(aux)  # Embed auxiliary input
        input_embeddings = self.bert.bert.embeddings(input_ids)
        input_embeddings[:, 0, :] += aux_embedded  # Modify embeddings
        outputs = self.bert(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs


def compute_saliency_maps(trainer, loader, device, repeat_factor=1000):
    """
    Compute gradients for the auxiliary input tensor (aux) in a DeepSpeed-enabled setting.
    """
    model = trainer._wrap_model(trainer.model, training=False, dataloader=loader)
    model.eval()
    for _ in range(repeat_factor):
        for batch in tqdm(loader, desc="Computing Saliency Maps"):
            batch = {key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            batch["aux"].requires_grad_(True)  # Enable gradients for aux
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                aux=batch["aux"],
                labels=batch["labels"],
            )
            loss = outputs.loss
            loss.backward()  # Compute gradients
            grads = batch["aux"].grad
            print(f"Gradient norm: {grads.norm().item()}")


if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = CustomModel()
    dataset = ExampleDataset(tokenizer, size=200)
    loader = DataLoader(dataset, batch_size=1000)
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_eval_batch_size=8,
        do_train=False,
        do_eval=True,
        logging_dir="./logs",
        deepspeed="./scripts/zero2.json",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=dataset,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    compute_saliency_maps(trainer, loader, device, repeat_factor=3)
```
Observations:
- Without gradient computation for aux: the model works as expected, and DeepSpeed successfully reduces memory usage.
- With gradient computation for aux: memory usage increases significantly, negating the benefits of ZeRO stage 2.
- Increased memory usage in the multi-GPU setting: while the toy example fits in memory, my actual model OOMs when gradients are enabled for aux.
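For comparison, here is a minimal sketch (plain PyTorch, no DeepSpeed engine, toy stand-in model) of computing only the input gradient with torch.autograd.grad while the parameters are frozen, so parameter .grad buffers are never populated. I am unsure how to achieve the equivalent through the DeepSpeed-wrapped model without losing the ZeRO stage 2 memory savings:

```python
import torch
from torch import nn

# Toy stand-in for the DeepSpeed-wrapped model.
model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 1))
for p in model.parameters():
    p.requires_grad_(False)  # parameters frozen; only the input needs gradients

aux = torch.rand(4, 5, requires_grad=True)
loss = model(aux).sum()

# Gradient w.r.t. the input only; no parameter .grad buffers are allocated.
(aux_grad,) = torch.autograd.grad(loss, aux)
print(aux_grad.norm())
```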