In [1]:
import torch
import torch.nn as nn

In [2]:
def cuda_summary():
    """Print current and peak CUDA memory usage, then reset the peak counter.

    The two figures are read from ``torch.cuda.memory_allocated`` and
    ``torch.cuda.max_memory_allocated``, divided by 1E9 (i.e. shown in GB)
    and printed with three significant digits on a single line.

    Because the peak statistics are reset on the way out, the ``max_alloc``
    printed by the *next* call covers only what happened since this call.
    """
    bytes_per_gb = 1E9
    current_gb = torch.cuda.memory_allocated() / bytes_per_gb
    peak_gb = torch.cuda.max_memory_allocated() / bytes_per_gb
    print(f'alloc: {current_gb:.3g}G, max_alloc: {peak_gb:.3g}G')  # (GB)
    # Start a fresh measurement window for whoever calls this next.
    torch.cuda.reset_peak_memory_stats()

def test_cuda_memory(ks, stride, half=False, in_ch=96, out_ch=192,
                     batch_size=16, size=320):
    """Report CUDA memory used by a single ``nn.Conv2d`` forward pass.

    A random input of shape ``(batch_size, in_ch, size, size)`` and a
    convolution layer are placed on the GPU. Memory usage is printed once
    after this setup and once after an inference-mode forward pass, using
    ``cuda_summary`` (which also resets the peak counter after each report).

    Args:
        ks: Convolution kernel size.
        stride: Convolution stride.
        half: If true, convert both the input and the layer with ``.half()``
            before measuring.
        in_ch: Number of input channels (default 96).
        out_ch: Number of output channels (default 192).
        batch_size: Number of samples in the input batch (default 16).
        size: Height and width of the square input (default 320).

    Returns:
        None. Results are printed.
    """
    # Drop any peak left over from earlier GPU work so that the first report
    # below only reflects allocations made by this call.
    torch.cuda.reset_peak_memory_stats()

    x = torch.randn(batch_size, in_ch, size, size).cuda()
    # For odd kernel sizes this padding keeps the spatial size at stride 1.
    pad = (ks - 1) // 2
    model = nn.Conv2d(in_ch, out_ch, kernel_size=ks, stride=stride, padding=pad).cuda()

    if half:
        # .half() builds a converted copy while the tensor created above is
        # still referenced, so the peak reported for this stage covers both
        # copies even though only the converted one remains afterwards.
        x = x.half()
        model = model.half()

    print('Mem. usage of x and model:')
    cuda_summary()
    # Inference only: no autograd graph is recorded and the output is
    # discarded, so the interesting number afterwards is the peak.
    with torch.no_grad():
        model(x)
    print('Mem. usage of forward pass:')
    cuda_summary()

In [3]:
# 3x3 kernel, stride 1, without the half() conversion.
test_cuda_memory(ks=3, stride=1, half=False)

Mem. usage of x and model:
alloc: 0.63G, max_alloc: 0.63G
Mem. usage of forward pass:
alloc: 0.63G, max_alloc: 6.14G


In [4]:
# 3x3 kernel, stride 1, with input and layer converted via half().
test_cuda_memory(ks=3, stride=1, half=True)

Mem. usage of x and model:
alloc: 0.315G, max_alloc: 0.944G
Mem. usage of forward pass:
alloc: 0.315G, max_alloc: 1.89G


In [5]:
# 5x5 kernel, stride 1, without the half() conversion.
test_cuda_memory(ks=5, stride=1, half=False)

Mem. usage of x and model:
alloc: 0.631G, max_alloc: 0.631G
Mem. usage of forward pass:
alloc: 0.631G, max_alloc: 2.01G


In [6]:
# 5x5 kernel, stride 1, with input and layer converted via half().
test_cuda_memory(ks=5, stride=1, half=True)

Mem. usage of x and model:
alloc: 0.315G, max_alloc: 0.946G
Mem. usage of forward pass:
alloc: 0.315G, max_alloc: 1.89G
