Training siamese (biencoder) based transformer model with gradient checkpointing throws error #23801

Closed
sachinya00 opened this issue May 26, 2023 · 6 comments

Comments

@sachinya00

sachinya00 commented May 26, 2023

System Info

PyTorch Lightning Version 1.6.5
Torch 1.13.0
Python version 3.8
CUDA Version: 11.4
4 NVIDIA A100-SXM4-40GBs
transformers 4.24.0

Reproduction

After adding model.gradient_checkpointing_enable() to the training code, training throws the error below:
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations

A workaround is to add use_reentrant=False to the checkpoint call in the file below:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L600

layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer_module),
    hidden_states,
    attention_mask,
    layer_head_mask,
    encoder_hidden_states,
    encoder_attention_mask,
    use_reentrant=False,
)

What's the best way to fix this, instead of adding the above flag manually in the source code?
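
One possible workaround (untested here, and not an official transformers API) is to patch torch.utils.checkpoint.checkpoint globally before training starts, assuming modeling_bert.py resolves the function through the module attribute at call time, which it does in 4.24.0:

import functools
import torch.utils.checkpoint

# Hypothetical monkey-patch: force every checkpoint call made by transformers
# to use the non-reentrant implementation, without editing modeling_bert.py.
torch.utils.checkpoint.checkpoint = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)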

Expected behavior

Adding model.gradient_checkpointing_enable() shouldn't throw any error.

Code to reproduce

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
import pytorch_lightning as pl


# Sample data
class SampleDataset(Dataset):
    def __init__(self):
        self.data = [
            ("I love coding", "I enjoy programming", 1),
            ("Python is great", "Java is popular", 0),
            ("Deep learning is fascinating", "Machine learning is interesting", 1),
            ("I prefer cats", "I like dogs", 0)
        ]
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text1, text2, label = self.data[idx]
        encoded_text1 = self.tokenizer.encode_plus(text1, add_special_tokens=True, padding='max_length', max_length=128, truncation=True)
        encoded_text2 = self.tokenizer.encode_plus(text2, add_special_tokens=True, padding='max_length', max_length=128, truncation=True)
        input_ids1 = torch.tensor(encoded_text1['input_ids'])
        attention_mask1 = torch.tensor(encoded_text1['attention_mask'])
        input_ids2 = torch.tensor(encoded_text2['input_ids'])
        attention_mask2 = torch.tensor(encoded_text2['attention_mask'])
        return (input_ids1, attention_mask1), (input_ids2, attention_mask2), label


# Define your LightningModule
class SiameseBiEncoder(pl.LightningModule):
    def __init__(self):
        super(SiameseBiEncoder, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.hidden_size = self.bert.config.hidden_size
        self.cosine_similarity = nn.CosineSimilarity(dim=1)
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        return pooled_output

    def training_step(self, batch, batch_idx):
        (input_ids1, attention_mask1), (input_ids2, attention_mask2), labels = batch
        embeddings1 = self.forward(input_ids1, attention_mask1)
        embeddings2 = self.forward(input_ids2, attention_mask2)
        similarity_scores = self.cosine_similarity(embeddings1, embeddings2)
        loss = self.criterion(similarity_scores, labels.float())
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        (input_ids1, attention_mask1), (input_ids2, attention_mask2), labels = batch
        embeddings1 = self.forward(input_ids1, attention_mask1)
        embeddings2 = self.forward(input_ids2, attention_mask2)
        similarity_scores = self.cosine_similarity(embeddings1, embeddings2)
        loss = self.criterion(similarity_scores, labels.float())
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-5)
        return optimizer

# Create the LightningDataModule
class SampleDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=4):
        super(SampleDataModule, self).__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = SampleDataset()
        self.val_dataset = SampleDataset()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

# Create an instance of your LightningModule
model = SiameseBiEncoder()
model.bert.gradient_checkpointing_enable()
print(f"Gradient Checkpointing: {model.bert.is_gradient_checkpointing}")
# Create the LightningDataModule instance
data_module = SampleDataModule()

# Create a Trainer instance
trainer = pl.Trainer(
    max_epochs=3, devices=2, accelerator="gpu", strategy="ddp")
trainer.fit(model, data_module)
@sgugger
Collaborator

sgugger commented May 26, 2023

cc @ArthurZucker and @younesbelkada

@KadriMufti

@sachinya00 What does your code look like, including training setup and training args?

@sachinya00
Author

I've updated the post with the code to reproduce the issue.

@ArthurZucker
Collaborator

Hey, thanks for providing a reproduction script.
Based on the provided traceback, it seems the issue lies with DDP, which is asking you to use _set_static_graph(). Did that work for you?
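
For reference, a minimal sketch of trying the static-graph route with Lightning, assuming DDPStrategy forwards extra keyword arguments to torch.nn.parallel.DistributedDataParallel (it does in Lightning 1.6):

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

# Sketch: ask DDP to treat the graph as static instead of calling
# _set_static_graph() by hand; the kwarg is passed through to DDP.
trainer = pl.Trainer(
    max_epochs=3,
    devices=2,
    accelerator="gpu",
    strategy=DDPStrategy(static_graph=True),
)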

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jul 4, 2023
@sachinya00
Author

This issue is very similar to the one below, and I'm not able to make it work even with _set_static_graph():
https://discuss.pytorch.org/t/distributed-data-parallel-with-triplet-transformer-model/137347/4
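
For anyone landing here later, a hedged sketch of the non-reentrant route, which avoids the reentrant backward passes DDP complains about. It assumes a newer transformers release whose gradient_checkpointing_enable accepts gradient_checkpointing_kwargs (not available in 4.24.0):

# Hypothetical usage on a newer transformers release: request the
# non-reentrant checkpoint implementation directly, so the two encoder
# passes per step no longer mark the same parameters ready twice.
model.bert.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)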
