How do you train this beast? #8

aegonwolf opened this issue Mar 12, 2023 · 1 comment

Hi there,

thanks a lot for all your great repos and implementations!

I wanted to try this for a segmentation problem, but I've had trouble training on Colab's 40 GB GPU with 256×256 inputs.
The model I want to use is initialized like so:

```python
from x_unet import XUnet

gen = XUnet(
    dim = target_shape,
    channels = 3,
    dim_mults = (1, 2, 4, 4),
    nested_unet_depths = (4, 3, 2, 1),     # nested unet depths, from the unet-squared paper
    consolidate_upsample_fmaps = True,     # whether to consolidate outputs from all upsample blocks, used in the unet-squared paper
).to(device)
```

Is there a trick, or what would you estimate the required memory to be?
I set pin_memory to False, which helped a little, but I still couldn't get through a single forward pass (batch_size = 1).

I also noticed that most of the memory is reserved rather than allocated, irrespective of the input size (always around 35-38 GB).
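
For reference, those numbers come from PyTorch's standard CUDA memory counters (a minimal sketch; `torch.cuda` exposes both values directly):

```python
import torch

# PyTorch's caching allocator reserves large blocks from the driver and
# suballocates tensors from them, so "reserved" can far exceed "allocated".
print(f"allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")
```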

qbeer commented May 13, 2024

Better late than never:

```python
class XUnet(nn.Module):

    @beartype
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        frame_kernel_size = 1,
        dim_mults: MaybeTuple(int) = (1, 2, 4, 8),
        num_blocks_per_stage: MaybeTuple(int) = (2, 2, 2, 2),
        num_self_attn_per_stage: MaybeTuple(int) = (0, 0, 0, 1),
        nested_unet_depths: MaybeTuple(int) = (0, 0, 0, 0),
        nested_unet_dim = 32,
        channels = 3,
        use_convnext = False,
        consolidate_upsample_fmaps = True,
        skip_scale = 2 ** -0.5,
        weight_standardize = False,
        attn_heads: MaybeTuple(int) = 8,
        attn_dim_head: MaybeTuple(int) = 32
    ):
        ...
```

Lower the attention heads (attn_heads) and/or the attention head dimension (attn_dim_head).
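
For example, a sketch of the same configuration with smaller attention settings (the `dim = 64` base width follows the repo's README example and is an assumption, not taken from your config):

```python
from x_unet import XUnet

gen = XUnet(
    dim = 64,                              # base feature width, per the README example
    channels = 3,
    dim_mults = (1, 2, 4, 4),
    nested_unet_depths = (4, 3, 2, 1),
    consolidate_upsample_fmaps = True,
    attn_heads = 4,                        # default is 8
    attn_dim_head = 16                     # default is 32
)
```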
