In [2]:
import torch
#from torch.autograd import Variable

In [3]:
from torch.utils.checkpoint import checkpoint_sequential
import torch.nn as nn

In [5]:
model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
    nn.ReLU()
)

In [8]:
input_var = torch.randn(1,100, requires_grad=True)

In [11]:
segments=2
modules = [module for k, module in model._modules.items()]

out = checkpoint_sequential(modules, segments, input_var)

In [12]:
model.zero_grad()
out.sum().backward()

In [13]:
output_checkpointed = out.data.clone()
grad_checkpointed = {}
for name, param in model.named_parameters():
    grad_checkpointed[name] = param.grad.data.clone()

In [15]:
original = model
x = input_var.clone().detach().requires_grad_(True)

In [16]:
out = original(x)
out_not_checkpointed = out.data.clone()

In [17]:
original.zero_grad()
out.sum().backward()
grad_not_checkpointed = {}
for name, param in model.named_parameters():
    grad_not_checkpointed[name] = param.grad.data.clone()

In [21]:
output_checkpointed, out_not_checkpointed

(tensor([[0.1718, 0.0000, 0.2658, 0.0000, 0.0030]]),
 tensor([[0.1718, 0.0000, 0.2658, 0.0000, 0.0030]]))

In [23]:
for name in grad_checkpointed:
    print(grad_checkpointed[name]==grad_not_checkpointed[name])

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True])
tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         

In [25]:
# define the model 
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.nhid = nhid
        self.nlayers = nlayers

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

In [27]:
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.nhid = nhid
        self.nlayers = nlayers
        
        def run_function(self, start, end):
            def custom_forward(*inputs):
                output, hidden = self.rnn(
                    inputs[0][start:(end+1)], (inputs[1], inputs[2])
                )
                return output, hidden[0], hidden[1]
        return custom_forward
    
    def forward(self, input, hidden, segments):
        emb = self.drop(self.encoder(input))
        # checkpoint self.rnn() computation
        output = []
        segment_size = len(modules) // segments
        for start in range(0, segment_size * (segments - 1), segment_size):
            end = start + segment_size - 1
            # Note that if there are multiple inputs, we pass them as as is without 
            # wrapping in a tuple etc.
            out = checkpoint.checkpoint(
                self.run_function(start, end), emb, hidden[0], hidden[1]
            )
            output.append(out[0])
            hidden = (out[1], out[2])
        out = checkpoint.checkpoint(
            self.run_function(end + 1, len(modules) - 1), emb, hidden[0], hidden[1]
        )
        output.append(out[0])
        hidden = (out[1], out[2])
        
        output = torch.cat(output, 0)
        hidden = (out[1], out[2])
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
import torch.utils.checkpoint as checkpoint

class ConvBNReLU(nn.Module):
    
    def __init__(self, in_planes, out_planes):
        
        super(ConvBNReLU, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu1 = nn.ReLU(inplace=True)
    
    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        return out

class DummyNet(nn.Module):
    def __init__(self):
        super(DummyNet, self).__init__()
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(16)),
            ('relu1', nn.ReLU(inplace=True)),
        ]))

        # The module that we want to checkpoint
        self.module = ConvBNReLU(16, 64) 
        
        self.final_module = ConvBNReLU(64, 64)

In [37]:
def ppp(arg1, arg2, arg3=1):
    print(arg1)
    print(arg2)
    print(arg3)
    
def p(arg1):
    print(arg1)

def h(d, args):
    ppp(*args)

In [40]:
h(1, (1,2))

1
2
1
