### Configuration 1 (baseline)

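* ViT-G/14 config (width 1664, depth 48, MLP 8192, 16 heads) from https://arxiv.org/abs/2106.04560; 336px resolution with 14px patches as in CLIP ViT-L/14@336px, https://arxiv.org/pdf/2103.00020.pdf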
* BART-L encoder config from https://huggingface.co/facebook/bart-large/blob/main/config.json

In [1]:
from tqdm import trange
import torch
from clip import CLIP, LeanTransformerConfig


# Vision tower sized like ViT-G/14 (48 layers, width 1664, 16 heads, MLP 8192).
# Inputs are 336x336 images cut into 14x14 patches, hence (336 // 14) ** 2 absolute positions.
vision_config = LeanTransformerConfig(
    hidden_size=1664,
    num_hidden_layers=48,
    num_attention_heads=16,
    intermediate_size=8192,
    position_embedding_type='absolute',
    max_position_embeddings=(336 // 14) ** 2,
)

# Text tower sized like the BART-large encoder (12 layers, width 1024, 16 heads, FFN 4096).
text_config = LeanTransformerConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
    intermediate_size=4096,
    position_embedding_type='rotary',
)

clip = CLIP(embed_dim=1024, image_resolution=336, vision_patch_size=14, context_length=256, vocab_size=30_000,
            vision_config=vision_config, text_config=text_config)
# note: I could not find the output dimension for ViT-G/14, so embed_dim=1024 is extrapolated from the CLIP paper

GRAD_CHECKPOINTS = False
# enable gradient checkpointing aka rematerialization (recompute activations in backward to save memory)
if GRAD_CHECKPOINTS:
    clip.transformer._get_sequential().gradient_checkpointing = True
    clip.visual.transformer._get_sequential().gradient_checkpointing = True


GPU_PARAMS_MIXED = False
# if enabled, this will emulate a config where (1) gpu params are mostly fp16, but (2) we store fp32 versions in RAM.
# Only tensors with more than 2 ** 16 elements are cast to half; smaller ones keep their dtype.
# This must run before the grads are pre-populated below, so each grad gets the same dtype as its param.
if GPU_PARAMS_MIXED:
    for param in clip.parameters():
        if param.numel() > 2 ** 16:
            param.data = param.data.half()

clip = clip.cuda()

# Pre-populate grads to avoid fragmentation. This is done after .cuda() so the zero buffers are
# allocated directly on the GPU, instead of first materializing a second full copy of the model
# in host RAM and then copying those zeros over.
for param in clip.parameters():
    if param.requires_grad:
        param.grad = torch.zeros_like(param)

opt = torch.optim.SGD(clip.parameters(), lr=1e-3)
# using SGD as a mock-up for offloading. hivemind will offload optimizer to RAM, so the memory usage will be same as SGD


print(f"Total params: {sum(p.numel() for p in clip.parameters()) / 1e9 :.3f}B")
print(f"ViT params: {sum(p.numel() for p in clip.visual.parameters()) / 1e9 :.3f}B")
print(f"Text transformer params: {sum(p.numel() for p in clip.transformer.parameters()) / 1e9 :.3f}B")
print(f"Memory usage (model + grads): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")


Total params: 2.028B
ViT params: 1.845B
Text transformer params: 0.151B
Memory usage (model + grads): 15.1GiB


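### Memory usage (with checkpoints, mixed params)

To reproduce these numbers, run the setup cell above with `GRAD_CHECKPOINTS = True` and `GPU_PARAMS_MIXED = True`.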

In [3]:
# Peak GPU memory of forward + backward on random inputs, for increasing batch sizes.
# The peak counter is reset for every batch size. Each printed number is therefore that batch
# size's own peak (resident params/grads plus activations), not a running maximum over earlier runs.
# There is no optimizer step / zero_grad here: gradients just accumulate into the existing .grad buffers.
# NOTE(review): backward runs inside autocast. PyTorch recommends running backward outside it;
# it is kept inside so the numbers stay comparable with the recorded runs.
for batch_size in (1, 4, 8, 16):
    torch.cuda.reset_peak_memory_stats()
    for _ in trange(10):
        with torch.cuda.amp.autocast():
            image = torch.randn(batch_size, 3, 336, 336, device='cuda')
            text = torch.randint(30_000, size=(batch_size, 256), device='cuda')
            image_features, text_features, tau = clip.forward(image, text)
            # placeholder scalar that depends on both towers' outputs; not the CLIP contrastive loss
            not_a_real_loss = torch.mean(image_features @ text_features.t() * tau)
            not_a_real_loss.backward()

            del image, text, image_features, text_features, tau, not_a_real_loss
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

    print(f"Memory usage (batch={batch_size}): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")

100%|██████████| 10/10 [00:04<00:00,  2.06it/s]


Memory usage (batch=1): 9.1GiB


100%|██████████| 10/10 [00:15<00:00,  1.54s/it]


Memory usage (batch=4): 9.7GiB


100%|██████████| 10/10 [00:29<00:00,  2.99s/it]


Memory usage (batch=8): 11.7GiB


100%|██████████| 10/10 [00:58<00:00,  5.83s/it]


Memory usage (batch=16): 15.8GiB


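### Memory usage (with checkpoints, fp32 params)

To reproduce these numbers, restart the kernel and run the setup cell above with `GRAD_CHECKPOINTS = True` and `GPU_PARAMS_MIXED = False`.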

In [2]:
# note: restart the kernel and re-run the setup cell with GRAD_CHECKPOINTS / GPU_PARAMS_MIXED set to
# match the heading above. The model has to be rebuilt with those flags; the peak-memory counter
# itself is reset below for every batch size.
#
# Peak GPU memory of forward + backward on random inputs, for increasing batch sizes.
# There is no optimizer step / zero_grad here: gradients just accumulate into the existing .grad buffers.
# NOTE(review): backward runs inside autocast. PyTorch recommends running backward outside it;
# it is kept inside so the numbers stay comparable with the recorded runs.
for batch_size in (1, 2, 4, 6):
    torch.cuda.reset_peak_memory_stats()
    for _ in trange(10):
        with torch.cuda.amp.autocast():
            image = torch.randn(batch_size, 3, 336, 336, device='cuda')
            text = torch.randint(30_000, size=(batch_size, 256), device='cuda')
            image_features, text_features, tau = clip.forward(image, text)
            # placeholder scalar that depends on both towers' outputs; not the CLIP contrastive loss
            not_a_real_loss = torch.mean(image_features @ text_features.t() * tau)
            not_a_real_loss.backward()

            del image, text, image_features, text_features, tau, not_a_real_loss
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

    print(f"Memory usage (batch={batch_size}): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")

100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Memory usage (batch=1): 19.4GiB


100%|██████████| 10/10 [00:09<00:00,  1.07it/s]


Memory usage (batch=2): 20.0GiB


100%|██████████| 10/10 [00:16<00:00,  1.63s/it]


Memory usage (batch=4): 21.0GiB


100%|██████████| 10/10 [00:23<00:00,  2.36s/it]


Memory usage (batch=6): 22.0GiB


### Memory usage (no checkpoints, fp32 params)

In [2]:
# note: restart the kernel and re-run the setup cell with GRAD_CHECKPOINTS / GPU_PARAMS_MIXED set to
# match the heading above. The model has to be rebuilt with those flags; the peak-memory counter
# itself is reset below for every batch size.
#
# Peak GPU memory of forward + backward on random inputs, for increasing batch sizes.
# There is no optimizer step / zero_grad here: gradients just accumulate into the existing .grad buffers.
# NOTE(review): backward runs inside autocast. PyTorch recommends running backward outside it;
# it is kept inside so the numbers stay comparable with the recorded runs.
for batch_size in (1, 2, 3):
    torch.cuda.reset_peak_memory_stats()
    for _ in trange(10):
        with torch.cuda.amp.autocast():
            image = torch.randn(batch_size, 3, 336, 336, device='cuda')
            text = torch.randint(30_000, size=(batch_size, 256), device='cuda')
            image_features, text_features, tau = clip.forward(image, text)
            # placeholder scalar that depends on both towers' outputs; not the CLIP contrastive loss
            not_a_real_loss = torch.mean(image_features @ text_features.t() * tau)
            not_a_real_loss.backward()

            del image, text, image_features, text_features, tau, not_a_real_loss
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

    print(f"Memory usage (batch={batch_size}): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")

100%|██████████| 10/10 [00:06<00:00,  1.62it/s]


Memory usage (batch=1): 20.3GiB


100%|██████████| 10/10 [00:08<00:00,  1.23it/s]


Memory usage (batch=2): 21.6GiB


100%|██████████| 10/10 [00:10<00:00,  1.09s/it]

Memory usage (batch=3): 23.0GiB





### Configuration 2: PixelFly

* use PixelFly block-sparse weight matrices from https://arxiv.org/abs/2112.00029
* increase hidden size 1664 -> 2048 to make it compatible with intermediate_size
* otherwise same as configuration 1

In [1]:
from tqdm import trange
import torch
from clip import CLIP, LeanTransformerConfig


# Same vision tower as configuration 1, except for PixelFly block-sparse weights
# (block_size / lowrank_dim) and hidden size raised from 1664 to 2048.
vision_config = LeanTransformerConfig(
    hidden_size=2048,       # <-- changed this line!!!
    block_size=64,          # <-- added this line!!!
    lowrank_dim=64,         # <-- added this line!!!
    num_hidden_layers=48,
    num_attention_heads=16,
    intermediate_size=8192,
    position_embedding_type='absolute',
    max_position_embeddings=(336 // 14) ** 2,
)

# Text tower unchanged from configuration 1 (BART-large-encoder sized).
text_config = LeanTransformerConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
    intermediate_size=4096,
    position_embedding_type='rotary',
)

clip = CLIP(embed_dim=1024, image_resolution=336, vision_patch_size=14, context_length=256, vocab_size=30_000,
            vision_config=vision_config, text_config=text_config)
# note: I could not find the output dimension for ViT-G/14, so embed_dim=1024 is extrapolated from the CLIP paper

GRAD_CHECKPOINTS = False
# enable gradient checkpointing aka rematerialization (recompute activations in backward to save memory)
if GRAD_CHECKPOINTS:
    clip.transformer._get_sequential().gradient_checkpointing = True
    clip.visual.transformer._get_sequential().gradient_checkpointing = True


GPU_PARAMS_MIXED = False
# if enabled, this will emulate a config where (1) gpu params are mostly fp16, but (2) we store fp32 versions in RAM.
# Only tensors with more than 2 ** 16 elements are cast to half; smaller ones keep their dtype.
# This must run before the grads are pre-populated below, so each grad gets the same dtype as its param.
if GPU_PARAMS_MIXED:
    for param in clip.parameters():
        if param.numel() > 2 ** 16:
            param.data = param.data.half()

clip = clip.cuda()

# Pre-populate grads to avoid fragmentation. This is done after .cuda() so the zero buffers are
# allocated directly on the GPU, instead of first materializing a second full copy of the model
# in host RAM and then copying those zeros over.
for param in clip.parameters():
    if param.requires_grad:
        param.grad = torch.zeros_like(param)

opt = torch.optim.SGD(clip.parameters(), lr=1e-3)
# using SGD as a mock-up for offloading. hivemind will offload optimizer to RAM, so the memory usage will be same as SGD


print(f"Total params: {sum(p.numel() for p in clip.parameters()) / 1e9 :.3f}B")
print(f"ViT params: {sum(p.numel() for p in clip.visual.parameters()) / 1e9 :.3f}B")
print(f"Text transformer params: {sum(p.numel() for p in clip.transformer.parameters()) / 1e9 :.3f}B")
print(f"Memory usage (model + grads): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")


Total params: 0.743B
ViT params: 0.559B
Text transformer params: 0.151B
Memory usage (model + grads): 5.6GiB


### Memory usage (no checkpoints, fp32 params)

In [2]:
# Peak GPU memory of forward + backward on random inputs, for increasing batch sizes.
# The peak counter is reset for every batch size. Each printed number is therefore that batch
# size's own peak, not a running maximum over earlier runs.
# There is no optimizer step / zero_grad here: gradients just accumulate into the existing .grad buffers.
# NOTE(review): backward runs inside autocast. PyTorch recommends running backward outside it;
# it is kept inside so the numbers stay comparable with the recorded runs.
for batch_size in (1, 4, 8):
    torch.cuda.reset_peak_memory_stats()
    for _ in trange(10):
        with torch.cuda.amp.autocast():
            image = torch.randn(batch_size, 3, 336, 336, device='cuda')
            text = torch.randint(30_000, size=(batch_size, 256), device='cuda')
            image_features, text_features, tau = clip.forward(image, text)
            # placeholder scalar that depends on both towers' outputs; not the CLIP contrastive loss
            not_a_real_loss = torch.mean(image_features @ text_features.t() * tau)
            not_a_real_loss.backward()

            del image, text, image_features, text_features, tau, not_a_real_loss

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    print(f"Memory usage (batch={batch_size}): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")

# note: throughput is a bit slower since we're using a bigger model (1664 -> 2048)

100%|██████████| 10/10 [00:06<00:00,  1.47it/s]


Memory usage (batch=1): 7.7GiB


100%|██████████| 10/10 [00:19<00:00,  1.96s/it]


Memory usage (batch=4): 13.0GiB


100%|██████████| 10/10 [00:36<00:00,  3.68s/it]

Memory usage (batch=8): 20.1GiB





### Configuration 3: Sharing with adapters

* Share large matrices as in ALBERT: https://arxiv.org/abs/1909.11942
* Each layer has non-shared layer-norm and biases as in https://arxiv.org/abs/2107.11817
* Each layer has non-shared low-dimensional adapter a-la [LoRA](https://arxiv.org/abs/2106.09685)

In [1]:
from tqdm import trange
import torch
from clip import CLIP, LeanTransformerConfig


vision_config = LeanTransformerConfig(
    hidden_size=1664,
    share_large_matrices=True,  # <-- added this line!!!
    adapter_dim=32,             # <-- added this line!!!
    num_hidden_layers=48,
    num_attention_heads=16,
    intermediate_size=8192,
    position_embedding_type='absolute',
    max_position_embeddings=(336 // 14) ** 2,
)  # ViT-G/14 config from https://arxiv.org/abs/2106.04560 (336px / 14px patches as in CLIP, https://arxiv.org/pdf/2103.00020.pdf)

text_config = LeanTransformerConfig(
    hidden_size=1024,
    share_large_matrices=True,  # <-- added this line!!!
    adapter_dim=32,             # <-- added this line!!!
    num_hidden_layers=12,
    num_attention_heads=16,
    intermediate_size=4096,
    position_embedding_type='rotary',
)  # BART-L encoder config from https://huggingface.co/facebook/bart-large/blob/main/config.json

clip = CLIP(embed_dim=1024, image_resolution=336, vision_patch_size=14, context_length=256, vocab_size=30_000,
            vision_config=vision_config, text_config=text_config)
# note: I could not find the output dimension for ViT-G/14, so embed_dim=1024 is extrapolated from the CLIP paper

GRAD_CHECKPOINTS = False
# enable gradient checkpointing aka rematerialization (recompute activations in backward to save memory)
if GRAD_CHECKPOINTS:
    clip.transformer._get_sequential().gradient_checkpointing = True
    clip.visual.transformer._get_sequential().gradient_checkpointing = True


GPU_PARAMS_MIXED = False
# if enabled, this will emulate a config where (1) gpu params are mostly fp16, but (2) we store fp32 versions in RAM.
# Only tensors with more than 2 ** 16 elements are cast to half; smaller ones keep their dtype.
# This must run before the grads are pre-populated below, so each grad gets the same dtype as its param.
if GPU_PARAMS_MIXED:
    for param in clip.parameters():
        if param.numel() > 2 ** 16:
            param.data = param.data.half()

clip = clip.cuda()

# Pre-populate grads to avoid fragmentation. This is done after .cuda() so the zero buffers are
# allocated directly on the GPU, instead of first materializing a second full copy of the model
# in host RAM and then copying those zeros over.
for param in clip.parameters():
    if param.requires_grad:
        param.grad = torch.zeros_like(param)

opt = torch.optim.SGD(clip.parameters(), lr=1e-3)
# using SGD as a mock-up for offloading. hivemind will offload optimizer to RAM, so the memory usage will be same as SGD


print(f"Total params: {sum(p.numel() for p in clip.parameters()) / 1e9 :.3f}B")
print(f"ViT params: {sum(p.numel() for p in clip.visual.parameters()) / 1e9 :.3f}B")
print(f"Text transformer params: {sum(p.numel() for p in clip.transformer.parameters()) / 1e9 :.3f}B")
print(f"Memory usage (model + grads): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")


Total params: 0.140B
ViT params: 0.089B
Text transformer params: 0.019B
Memory usage (model + grads): 1.0GiB


### Memory usage (no checkpoints, fp32 params)

In [2]:
# Peak GPU memory of forward + backward on random inputs, for increasing batch sizes.
# The peak counter is reset for every batch size. Each printed number is therefore that batch
# size's own peak, not a running maximum over earlier runs.
# There is no optimizer step / zero_grad here: gradients just accumulate into the existing .grad buffers.
# NOTE(review): backward runs inside autocast. PyTorch recommends running backward outside it;
# it is kept inside so the numbers stay comparable with the recorded runs.
for batch_size in (1, 4, 8, 12):
    torch.cuda.reset_peak_memory_stats()
    for _ in trange(10):
        with torch.cuda.amp.autocast():
            image = torch.randn(batch_size, 3, 336, 336, device='cuda')
            text = torch.randint(30_000, size=(batch_size, 256), device='cuda')
            image_features, text_features, tau = clip.forward(image, text)
            # placeholder scalar that depends on both towers' outputs; not the CLIP contrastive loss
            not_a_real_loss = torch.mean(image_features @ text_features.t() * tau)
            not_a_real_loss.backward()

            del image, text, image_features, text_features, tau, not_a_real_loss

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    print(f"Memory usage (batch={batch_size}): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")

# note: hidden sizes are the same as in configuration 1 (no 1664 -> 2048 change here). Sharing large
# matrices shrinks parameter and grad memory, but activation memory still grows with batch size.

100%|██████████| 10/10 [00:05<00:00,  1.70it/s]


Memory usage (batch=1): 2.8GiB


100%|██████████| 10/10 [00:13<00:00,  1.40s/it]


Memory usage (batch=4): 7.1GiB


100%|██████████| 10/10 [00:26<00:00,  2.61s/it]


Memory usage (batch=8): 12.8GiB


100%|██████████| 10/10 [00:38<00:00,  3.83s/it]

Memory usage (batch=12): 18.6GiB





### Configuration 4: reversible

* same as configuration 3, but now using reversible layers instead of checkpoints
* same reversible strategy as in [Reformer](https://arxiv.org/abs/2001.04451) (using [revlib](https://github.com/clashluke/revlib) under the hood)
* reversible transformer keeps 2 sets of activations in memory instead of keeping one checkpoint for every layer
* as a result, we can get a __batch size of up to 64__

In [1]:
from tqdm import trange
import torch
from clip import CLIP, LeanTransformerConfig


vision_config = LeanTransformerConfig(
    reversible=True,            # <-- added this line!!!
    hidden_size=1664,
    share_large_matrices=True,  # <-- added this line!!!
    adapter_dim=32,             # <-- added this line!!!
    num_hidden_layers=48,
    num_attention_heads=16,
    intermediate_size=8192,
    position_embedding_type='absolute',
    max_position_embeddings=(336 // 14) ** 2,
)  # ViT-G/14 config from https://arxiv.org/abs/2106.04560 (336px / 14px patches as in CLIP, https://arxiv.org/pdf/2103.00020.pdf)

text_config = LeanTransformerConfig(
    reversible=True,            # <-- added this line!!!
    hidden_size=1024,
    share_large_matrices=True,  # <-- added this line!!!
    adapter_dim=32,             # <-- added this line!!!
    num_hidden_layers=12,
    num_attention_heads=16,
    intermediate_size=4096,
    position_embedding_type='rotary',
)  # BART-L encoder config from https://huggingface.co/facebook/bart-large/blob/main/config.json

clip = CLIP(embed_dim=1024, image_resolution=336, vision_patch_size=14, context_length=256, vocab_size=30_000,
            vision_config=vision_config, text_config=text_config)
# note: I could not find the output dimension for ViT-G/14, so embed_dim=1024 is extrapolated from the CLIP paper

# note: gradient checkpoints are not used if model is reversible!

clip = clip.cuda()

# Pre-populate grads to avoid fragmentation. This is done after .cuda() so the zero buffers are
# allocated directly on the GPU, instead of first materializing a second full copy of the model
# in host RAM and then copying those zeros over.
for param in clip.parameters():
    if param.requires_grad:
        param.grad = torch.zeros_like(param)

opt = torch.optim.SGD(clip.parameters(), lr=1e-3)
# using SGD as a mock-up for offloading. hivemind will offload optimizer to RAM, so the memory usage will be same as SGD


print(f"Total params: {sum(p.numel() for p in clip.parameters()) / 1e9 :.3f}B")
print(f"ViT params: {sum(p.numel() for p in clip.visual.parameters()) / 1e9 :.3f}B")
print(f"Text transformer params: {sum(p.numel() for p in clip.transformer.parameters()) / 1e9 :.3f}B")
print(f"Memory usage (model + grads): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")


Total params: 0.140B
ViT params: 0.089B
Text transformer params: 0.019B
Memory usage (model + grads): 1.0GiB


In [2]:
# Peak GPU memory of a full training step (forward, backward, SGD update) on random inputs,
# for increasing batch sizes. The peak counter is reset for every batch size. Each printed number
# is therefore that batch size's own peak, not a running maximum over earlier runs.
# NOTE(review): backward runs inside autocast. PyTorch recommends running backward outside it;
# it is kept inside so the numbers stay comparable with the recorded runs.
for batch_size in (1, 4, 16, 64):
    torch.cuda.reset_peak_memory_stats()
    for _ in trange(10):
        with torch.cuda.amp.autocast():
            image = torch.randn(batch_size, 3, 336, 336, device='cuda')
            text = torch.randint(30_000, size=(batch_size, 256), device='cuda')
            image_features, text_features, tau = clip.forward(image, text)
            # placeholder scalar that depends on both towers' outputs; not the CLIP contrastive loss
            not_a_real_loss = torch.mean(image_features @ text_features.t() * tau)
            not_a_real_loss.backward()
            opt.step()
            # set_to_none=False keeps the pre-populated .grad buffers alive and zeroes them in place.
            # PyTorch >= 2.0 defaults to set_to_none=True, which would free them every step.
            opt.zero_grad(set_to_none=False)

            del image, text, image_features, text_features, tau, not_a_real_loss

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    print(f"Memory usage (batch={batch_size}): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")

# note the significantly larger batch size (up to 64 per gpu and still some vram left)

100%|██████████| 10/10 [00:06<00:00,  1.44it/s]


Memory usage (batch=1): 1.5GiB


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


Memory usage (batch=4): 2.1GiB


100%|██████████| 10/10 [01:01<00:00,  6.18s/it]


Memory usage (batch=16): 4.4GiB


100%|██████████| 10/10 [04:05<00:00, 24.57s/it]

Memory usage (batch=64): 13.8GiB





### Configuration 5: going crazy


* combine configurations 1-4
* increase model size as much as possible
* this setup only makes sense if you have 50+ collaborators with a 2080Ti or better GPU each

In [1]:
from tqdm import trange
import torch
from clip import CLIP, LeanTransformerConfig


vision_config = LeanTransformerConfig(
    hidden_size=8192,            # <-- changed this line!!!
    share_large_matrices=True,   # <-- added this line!!!
    block_size=64,               # <-- added this line!!!
    lowrank_dim=120,             # <-- added this line!!!
    adapter_dim=8,               # <-- added this line!!!
    num_hidden_layers=64,        # <-- changed this line!!!
    num_attention_heads=64,      # <-- changed this line!!!
    intermediate_size=32768,     # <-- changed this line!!!
    reversible=True,             # <-- added this line!!!
    position_embedding_type='absolute',
    max_position_embeddings=(336 // 14) ** 2,
)   # let's call this vit-ludicrous /14

text_config = LeanTransformerConfig(
    hidden_size=8192,            # <-- changed this line!!!
    share_large_matrices=True,   # <-- added this line!!!
    block_size=64,               # <-- added this line!!!
    lowrank_dim=120,             # <-- added this line!!!
    adapter_dim=8,               # <-- added this line!!!
    num_hidden_layers=64,        # <-- changed this line!!!
    num_attention_heads=64,      # <-- changed this line!!!
    intermediate_size=32768,     # <-- changed this line!!!
    reversible=True,             # <-- added this line!!!
    position_embedding_type='rotary',
)

clip = CLIP(embed_dim=4096, image_resolution=336, vision_patch_size=14, context_length=256, vocab_size=30_000,
            vision_config=vision_config, text_config=text_config)
# note: embed dim is now 4096, mostly for lulz


GPU_PARAMS_MIXED = False
# if enabled, this will emulate a config where (1) gpu params are mostly fp16, but (2) we store fp32 versions in RAM.
# Only tensors with more than 2 ** 16 elements are cast to half; smaller ones keep their dtype.
# This must run before the grads are pre-populated below, so each grad gets the same dtype as its param.
if GPU_PARAMS_MIXED:
    for param in clip.parameters():
        if param.numel() > 2 ** 16:
            param.data = param.data.half()

clip = clip.cuda()

# Pre-populate grads to avoid fragmentation. This is done after .cuda() so the zero buffers are
# allocated directly on the GPU, instead of first materializing a second full copy of the model
# in host RAM and then copying those zeros over.
for param in clip.parameters():
    if param.requires_grad:
        param.grad = torch.zeros_like(param)

opt = torch.optim.SGD(clip.parameters(), lr=1e-3)
# using SGD as a mock-up for offloading. hivemind will offload optimizer to RAM, so the memory usage will be same as SGD


print(f"Total params: {sum(p.numel() for p in clip.parameters()) / 1e9 :.3f}B")
print(f"ViT params: {sum(p.numel() for p in clip.visual.parameters()) / 1e9 :.3f}B")
print(f"Text transformer params: {sum(p.numel() for p in clip.transformer.parameters()) / 1e9 :.3f}B")
print(f"Memory usage (model + grads): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")


Total params: 0.605B
ViT params: 0.183B
Text transformer params: 0.140B
Memory usage (model + grads): 4.5GiB


In [2]:
# Peak GPU memory of a full training step (forward, backward, SGD update) on random inputs,
# for increasing batch sizes. The peak counter is reset for every batch size. Each printed number
# is therefore that batch size's own peak, not a running maximum over earlier runs.
# NOTE(review): backward runs inside autocast. PyTorch recommends running backward outside it;
# it is kept inside so the numbers stay comparable with the recorded runs.
for batch_size in (1, 2, 4, 8):
    torch.cuda.reset_peak_memory_stats()
    for _ in trange(10):
        with torch.cuda.amp.autocast():
            image = torch.randn(batch_size, 3, 336, 336, device='cuda')
            text = torch.randint(30_000, size=(batch_size, 256), device='cuda')
            image_features, text_features, tau = clip.forward(image, text)
            # placeholder scalar that depends on both towers' outputs; not the CLIP contrastive loss
            not_a_real_loss = torch.mean(image_features @ text_features.t() * tau)
            not_a_real_loss.backward()
            opt.step()
            # set_to_none=False keeps the pre-populated .grad buffers alive and zeroes them in place.
            # PyTorch >= 2.0 defaults to set_to_none=True, which would free them every step.
            opt.zero_grad(set_to_none=False)

            del image, text, image_features, text_features, tau, not_a_real_loss

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    print(f"Memory usage (batch={batch_size}): {torch.cuda.max_memory_allocated() / 2 ** 30:.1f}GiB")

# note: technically speaking, batch 16 fits into memory with mixed params, but it takes a minute for each run

100%|██████████| 10/10 [00:56<00:00,  5.65s/it]


Memory usage (batch=1): 5.7GiB


100%|██████████| 10/10 [01:34<00:00,  9.49s/it]


Memory usage (batch=2): 6.5GiB


100%|██████████| 10/10 [03:03<00:00, 18.34s/it]


Memory usage (batch=4): 8.3GiB


100%|██████████| 10/10 [05:46<00:00, 34.69s/it]

Memory usage (batch=8): 11.9GiB



