
Training Requirements #15

Closed
krantiparida opened this issue May 1, 2021 · 8 comments

Comments

@krantiparida

Hi, I was trying to use your code for training on another dataset for the depth prediction task. I noticed that during training I could not increase the batch size beyond 2: with a batch size of 2 and images of size 224x448 it takes almost 9 GB of memory. Can you comment on the memory requirements? In particular, how did you train the model and how much memory did it take? It would be really helpful if you could share some insights on training.

Thanks

@krantiparida (Author)

Further, I was also trying to run the network in parallel by wrapping the model in torch.nn.DataParallel(), but it seems the model is not correctly parallelized: the weights still reside on one GPU. I have attached a screenshot of the error below.
[Screenshot of the error, 2021-05-02 at 6:44 PM]

@ranftlr (Contributor) commented May 3, 2021

These are big models, so yes, they are memory intensive. We trained with 4 Quadro 6000 cards that have 24 GB of memory each; alternatively, you can distribute across 8 GPUs with 12 GB of memory each. Roughly 100 GB of total GPU memory is needed to train with our settings in terms of number of datasets, resolution, and batch size. If you don't have access to that kind of infrastructure, you could use gradient accumulation on fewer cards, as sketched below.
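A minimal gradient-accumulation sketch (toy model and data, not the DPT training code): with accum_steps = 4 and a per-step micro-batch of 2, the optimizer sees gradients averaged over an effective batch of 8 while peak memory stays at the per-step level.

import torch
import torch.nn as nn

model = nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.L1Loss()
accum_steps = 4  # hypothetical value; tune so accum_steps * micro-batch matches the target batch size

optimizer.zero_grad()
for step in range(100):
    images = torch.randn(2, 16)   # per-step micro-batch (stand-in for real data)
    targets = torch.randn(2, 1)
    loss = criterion(model(images), targets) / accum_steps  # scale so accumulated gradients average
    loss.backward()               # gradients add up across micro-batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()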

The second issue has to do with how we hook into timm: we monkey patch an additional method onto the object, so that we don't need to modify the original library source. See here for an example: https://github.com/intel-isl/DPT/blob/72830e11c7e72f58aee1465ab5207e3d0f0ab9fd/dpt/vit.py#L343
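As a rough illustration of that pattern (not the actual DPT code, which lives in dpt/vit.py), an extra bound method can be attached to a timm model instance with types.MethodType; the method name forward_with_activations below is purely hypothetical:

import types

import timm
import torch

def forward_with_activations(self, x):
    # Hypothetical extra method for illustration only; DPT's real version is
    # forward_flex and additionally collects intermediate activations.
    return self.forward_features(x)

# Attach the function as a bound method without touching the timm sources.
model = timm.create_model("vit_large_patch16_384", pretrained=False)
model.forward_with_activations = types.MethodType(forward_with_activations, model)

out = model.forward_with_activations(torch.randn(1, 3, 384, 384))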

Unfortunately, this strategy doesn't play well with nn.DataParallel. AFAIK this is an open issue with DataParallel. See here for a similar discussion: Cadene/pretrained-models.pytorch#112. I recommend either switching to DistributedDataParallel, where this issue doesn't occur, or rewriting the backbone so that the monkey patching isn't required.
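For reference, a minimal DistributedDataParallel sketch (toy model, not the DPT trainer), assuming one process per GPU launched with torchrun (torchrun --nproc_per_node=4 train.py):

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group("nccl")              # one process per GPU
    local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
    torch.cuda.set_device(local_rank)

    model = nn.Linear(16, 1).cuda(local_rank)    # stand-in for the DPT model
    model = DDP(model, device_ids=[local_rank])

    x = torch.randn(2, 16).cuda(local_rank)
    loss = model(x).mean()
    loss.backward()                              # gradients are all-reduced across ranks

    dist.destroy_process_group()

if __name__ == "__main__":
    main()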

@krantiparida (Author)

Thanks, @ranftlr, for the detailed response. Yes, I later found out the reason for that. Further, I also noticed that the register_forward_hook operation doesn't go well with nn.DataParallel, as mentioned here: register-forward-hook-with-multiple-gpus. I worked around it by using nn.ModuleList to store each of the blocks and then taking the features from the corresponding blocks in the forward pass. This completely bypasses register_forward_hook, and I can now run the model in parallel.
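For context, a minimal illustration (not the DPT code) of the register_forward_hook pattern that was bypassed; the hook writes activations into Python-side state outside the module, which is what clashes with nn.DataParallel replicas in this setting:

import torch
import torch.nn as nn

activations = {}

def save_hook(name):
    def hook(module, inputs, output):
        activations[name] = output  # stored outside the module itself
    return hook

blocks = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
blocks[0].register_forward_hook(save_hook("block0"))

_ = blocks(torch.randn(2, 8))
print(activations["block0"].shape)  # torch.Size([2, 8])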

@ranftlr (Contributor) commented May 4, 2021

Great, thanks for pointing out this solution. I'm closing this issue for now.

ranftlr closed this as completed May 4, 2021
@Tord-Zhang

@krantiparida Hi, I have run into the same issue. Could you please share the rewritten version of the DPT model?

@krantiparida (Author)

@Tord-Zhang I modified the code to suit my own requirements, so I am not sure it will be useful to you. However, I am attaching the DPT model part below. The other functions used are similar to the ones in this repo.

import torch
import torch.nn as nn
import timm

# Note: _resize_pos_embed is the author's adaptation of the helper in dpt/vit.py;
# judging by the call below, it takes start_index directly instead of reading it
# from the module.


class VisualNet(nn.Module):
    def __init__(
        self,
        backbone="vit_large_patch16_384",
        use_readout="project",
        vit_features=1024,
        pretrained=False,
        start_index=1,
    ):
        super(VisualNet, self).__init__()
        self.hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vit_base_patch16_384": [2, 5, 8, 11],
            "vit_large_patch16_384": [5, 11, 17, 23],
        }
        self.model = timm.create_model(backbone, pretrained=pretrained)

        # Keep the transformer blocks in an nn.ModuleList so their outputs can be
        # collected directly in the forward pass, without register_forward_hook.
        self.model.features = nn.ModuleList()
        for blk in self.model.blocks:
            self.model.features.append(blk)

        self.model.start_index = start_index
        self.model.patch_size = [16, 16]
        self.model.hooks = self.hooks[backbone]

    def forward(self, x):
        b, c, h, w = x.shape
        x = x.contiguous(memory_format=torch.channels_last)

        _, activations = forward_flex(self.model, x)

        return (activations[0], activations[1], activations[2], activations[3])


def forward_flex(self, x):
    b, c, h, w = x.shape
    pos_embed = _resize_pos_embed(
        self.start_index, self.pos_embed,
        h // self.patch_size[1], w // self.patch_size[0]
    )
    B = x.shape[0]
    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    # Collect the activations of the hooked blocks while iterating over the
    # ModuleList; this replaces the register_forward_hook mechanism.
    results = []
    for ii, blk in enumerate(self.features):
        x = blk(x)
        if ii in self.hooks:
            results.append(x)
    x = self.norm(x)

    return x, results
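A hedged usage sketch for the snippet above (it assumes timm and the modified _resize_pos_embed helper are available and that more than one GPU is visible): because the blocks live in an nn.ModuleList and no forward hooks are used, nn.DataParallel can replicate the model across GPUs.

import torch
import torch.nn as nn

net = VisualNet(backbone="vit_large_patch16_384", pretrained=False)
net = nn.DataParallel(net).cuda()

dummy = torch.randn(4, 3, 224, 448).cuda()   # batch of 4, split across the visible GPUs
feat_1, feat_2, feat_3, feat_4 = net(dummy)
print(feat_1.shape)  # token features from the first hooked block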

@Tord-Zhang

@krantiparida Hi, as far as I can tell, this code does not solve the problem caused by self.patch_embed, does it?

@krantiparida (Author)

@Tord-Zhang Yes, this will not address the problem with self.patch_embed. In my case, I did not need self.patch_embed. However, I think you can implement the same workaround using nn.ModuleList as well.

ranftlr mentioned this issue Sep 29, 2021
scyonggg added a commit to scyonggg/ICEIC2023 that referenced this issue Oct 17, 2022
Add DDP multi-GPU training code