In [1]:
import sys
sys.path.append('../')
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from net.model_factory import model_factory
from solver import ArbitrarySolver
from graph import Segment
from tqdm import tqdm

# Create Model and Specify Input size and cuda device
Here we use the `darts_cifar10` model as an example.

In [2]:
# Settings shared by the rest of the notebook.
arch = 'darts_cifar10'        # key used to look up the constructor in model_factory
device = 'cuda:0'             # device the model and every batch are moved to
input_size = (64, 3, 32, 32)  # presumably (batch, channels, height, width); element 0 is reused as the DataLoader batch size

# Build the network first, then move it onto the target device.
model = model_factory[arch]()
model = model.to(device)

108 108 36
108 144 36
144 144 36
144 144 36
144 144 36
144 144 36
144 144 72
144 288 72
288 288 72
288 288 72
288 288 72
288 288 72
288 288 72
288 288 144
288 576 144
576 576 144
576 576 144
576 576 144
576 576 144
576 576 144


# Create Computation Graph of the Model
Set the model to training mode, create a random tensor of the input size, and pass it to the `model.parse_graph` function. `parse_graph` runs a forward pass with the input tensor and creates the computation graph.

In [3]:
# The graph is parsed with the model in training mode.
model.train()

# A random batch of the configured size stands in for real data.
inp = torch.rand(input_size).to(device)

# parse_graph returns the graph together with its source and target nodes.
G, source, target = model.parse_graph(inp)

# Solve optimal checkpointing for the model

In [4]:
# Solve for a checkpointing plan over the parsed graph G between `source`
# and `target`. The solver returns the graph to execute and the cost of the
# best solution it found (`best_cost` is not used again in this view).
solver = ArbitrarySolver()
run_graph, best_cost = solver.solve(G, source, target, use_tqdm=False)
# Wrap the solved graph in a callable Segment with checkpointing enabled;
# the training loop calls run_segment(inputs) in place of model(inputs).
run_segment = Segment(run_graph, source, target, do_checkpoint=True)
# Release cached, currently-unused CUDA memory before training starts.
# NOTE(review): presumably this frees memory left over from graph parsing
# and solving -- confirm.
torch.cuda.empty_cache()

Building Division Tree
Getting Max Terms
Solving Optimal for Each Max Term


# Create CIFAR-10 dataset and data loader

In [5]:
# Convert PIL images to tensors, then normalise each of the three channels
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: shuffled, batch size taken from input_size[0].
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=input_size[0], shuffle=True, num_workers=0,
)

# Test split: same transform and batch size, fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=input_size[0], shuffle=False, num_workers=0,
)

Files already downloaded and verified
Files already downloaded and verified


# Set up loss and optimizer, train and evaluate for two epochs
The only difference from regular training is that we call `run_segment` instead of the model during the training phase, so the forward and backward passes use checkpointing. Evaluation still calls the model directly.

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

num_epochs = 2      # each epoch is one training pass followed by one evaluation pass
log_interval = 100  # print the averaged training loss every `log_interval` mini-batches

for epoch in range(num_epochs):
    # ---- training phase: forward/backward go through run_segment ----
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        # NOTE(review): presumably the checkpointed segment needs an input
        # that requires grad so autograd tracks the recomputed forward --
        # confirm against Segment.
        inputs.requires_grad = True
        optimizer.zero_grad()

        # use run_segment to do checkpointing forward and backward for training
        outputs = run_segment(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: mean loss over the last `log_interval` mini-batches
        running_loss += loss.item()
        if i % log_interval == log_interval - 1:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / log_interval))
            running_loss = 0.0

    # ---- evaluation phase: plain model forward, no gradients ----
    model.eval()
    eval_loss_sum = 0.0    # sum of per-sample losses over the test set
    eval_sample_count = 0  # number of test samples seen
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # use model to do forward for evaluation
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # criterion returns the batch mean; weight it by the batch size
            # so a smaller final batch does not skew the reported average.
            batch_size = labels.size(0)
            eval_loss_sum += loss.item() * batch_size
            eval_sample_count += batch_size
    print('[%d] loss: %.3f' % (epoch + 1, eval_loss_sum / eval_sample_count))

    # save model weights (the same file is overwritten after every epoch)
    torch.save(model.state_dict(), './checkpoint.pth')

[1,   100] loss: 2.287
[1,   200] loss: 2.131
[1,   300] loss: 1.993
[1,   400] loss: 1.849
[1,   500] loss: 1.781
[1,   600] loss: 1.744
[1,   700] loss: 1.670
[1] loss: 1.632
[2,   100] loss: 1.623
[2,   200] loss: 1.576
[2,   300] loss: 1.533
[2,   400] loss: 1.544
[2,   500] loss: 1.505
[2,   600] loss: 1.497
[2,   700] loss: 1.471
[2] loss: 1.414
