In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
# Import PyTorch
import torch # import main library
import torch.nn as nn # import modules like nn.ReLU()
import torch.nn.functional as F # import torch functions like F.relu() and F.relu_()

In [2]:
def get_memory_allocated(device, inplace = False, size = 10000):
    '''
    Measure the CUDA memory allocated by an element-wise op chain, out-of-place vs in-place.

    A random (size x size) float tensor is created on `device`. Then either the
    out-of-place chain t.mul(2).div(5).mul(7).div(11) (each step allocates a new
    tensor) or the in-place chain t.mul_(2).div_(5).mul_(7).div_(11) is applied.
    F.relu / F.relu_ would be the single-op version of the same comparison.

    INPUT:
      - device: CUDA device (torch.device or a string such as 'cuda:0') to run the operation on
      - inplace: True - run the in-place ops, False - run the out-of-place ops
      - size: side length of the square test tensor (default 10000, about 381 MiB of float32)
    OUTPUT:
      - tuple (allocated memory delta, peak allocated memory delta), both in MiB
    RAISES:
      - ValueError if `device` is not a CUDA device (the measurement relies on torch.cuda statistics)
    '''
    device = torch.device(device)
    if device.type != 'cuda':
        raise ValueError('get_memory_allocated needs a CUDA device, got {}'.format(device))

    # Create a large tensor
    t = torch.randn(size, size, device=device)

    # Measure allocated memory on *this* device (not whatever device is current).
    torch.cuda.synchronize(device)
    # Reset the peak counter so the peak delta reflects only this call. Otherwise a
    # larger peak left over from an earlier call hides this call's peak.
    torch.cuda.reset_peak_memory_stats(device)
    start_max_memory = torch.cuda.max_memory_allocated(device) / 1024**2
    start_memory = torch.cuda.memory_allocated(device) / 1024**2

    if inplace:
        t.mul_(2).div_(5).mul_(7).div_(11)
    else:
        # Keep `output` referenced so its storage is still allocated at the second measurement.
        output = t.mul(2).div(5).mul(7).div(11)

    # Measure allocated memory after the call
    torch.cuda.synchronize(device)
    end_max_memory = torch.cuda.max_memory_allocated(device) / 1024**2
    end_memory = torch.cuda.memory_allocated(device) / 1024**2

    return end_memory - start_memory, end_max_memory - start_max_memory

In [3]:
# Use the first GPU when CUDA is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device("cpu")

In [4]:
# Out-of-place run: every step of the op chain allocates a fresh tensor.
memory_allocated, max_memory_allocated = get_memory_allocated(device, inplace=False)
print(f'Allocated memory: {memory_allocated}')
print(f'Allocated max memory: {max_memory_allocated}')

Allocated memory: 382.0
Allocated max memory: 764.0


In [5]:
# In-place run: the op chain overwrites the input tensor instead of allocating new ones.
memory_allocated_inplace, max_memory_allocated_inplace = get_memory_allocated(device, inplace=True)
print(f'Allocated memory: {memory_allocated_inplace}')
print(f'Allocated max memory: {max_memory_allocated_inplace}')

Allocated memory: 0.0
Allocated max memory: 0.0


In [6]:
# Plain float32 tensor holding the value 1. The torch.tensor factory replaces the legacy
# torch.Tensor(...) constructor, which is ambiguous: torch.Tensor(1) gives an
# *uninitialized* 1-element tensor, not the value 1. A plain tensor does not track gradients.
a = torch.tensor([1.0])

In [8]:
# No backward() call appears before this cell, so the gradient slot is still empty.
print(f"{a.grad}")

None


In [12]:
# On a fresh top-to-bottom run, `a` is still the plain tensor created earlier, whose
# requires_grad is False. The saved output "True" comes from out-of-order execution:
# this cell (In [12]) was run after the nn.Parameter cell (In [11]) that rebinds `a`.
print(f"{a.requires_grad}")

True


In [11]:
# Wrapping a tensor in nn.Parameter enables gradient tracking (requires_grad=True).
# torch.tensor([1.0]) replaces the legacy torch.Tensor([1]) constructor and holds the same float32 value.
a = torch.nn.Parameter(torch.tensor([1.0]))