In [6]:
import torch
import torch.utils.data as torchdata
import torch.nn as nn
import tqdm
from timeit import default_timer as timer
import utils

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

# Variables and dataset

In [None]:
model='R110_C10'
load = 'cv/finetuned/R110_C10/ckpt_E_2000_A_0.936_R_1.95E-01_S_16.93_#_469.t7'
data_dir = 'data/'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    
trainset, testset = utils.get_dataset(model, data_dir)

# Accuracy Measure

In [2]:
# compute accuracy
def get_accuracy(preds, target, batch_size):
    corrects = (torch.max(preds, 1)[1].view(target.size()).data == target.data).sum()
    accuracy = 100.0 * corrects/batch_size
    return accuracy.item()

# Inference calculataions

In [3]:
def inference():
    acc = 0.
    matches, policies = [], []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):
            #load inputs and targets to device
            inputs = inputs.to(device)
            targets = targets.to(device)
            
            # Run Policy Network: probs= dropping probabilites for each block of ResNet
            probs, _ = agent(inputs)

            # transform dropping probabilites to dropping strategy (<0.5 drop; else: keep)
            policy = probs.clone()
            policy[policy<0.5] = 0.0
            policy[policy>=0.5] = 0.0
            
            # run ResNet
            preds = rnet.forward_single(inputs, policy.data.squeeze(0))
            
            # save policy data
            policies.append(policy.data)
            
            # calculate accuracy
            acc += get_accuracy(preds, targets, len(targets))
    
    # save values for later printing 
    policies = torch.cat(policies, 0)
    min_blocks = policies.sum(1).min()
    sparsity = policies.sum(1).mean()

    accuracy = acc/batch_idx
    print(f"Accuracy of the model: {acc/batch_idx:4.2f}%", "Blocks used in avg:", sparsity)
    print("Min blocks: ", min_blocks)
    

# Run inference 

In [5]:
num_workers = 4 if torch.device("cuda") else 1
testloader = torchdata.DataLoader(testset, batch_size=1, shuffle=False, num_workers=num_workers)

rnet, agent = utils.get_model(model, device)

# if no model is loaded, use all blocks
agent.logit.weight.data.fill_(0)
agent.logit.bias.data.fill_(10)

print("loading checkpoints")

if load is not None:
    utils.load_checkpoint(rnet, agent, load)

rnet.eval().to(device)
agent.eval().to(device)

start_testing = timer()

inference() 

end_testing = timer()

testing_time = end_testing - start_testing

print("Testing in ms: %.2f"%(1000*testing_time/len(testloader)))


loading checkpoints
loaded resnet from ckpt_E_2000_A_0.936_R_1.95E-01_S_16.93_#_469.t7
loaded agent from ckpt_E_2000_A_0.936_R_1.95E-01_S_16.93_#_469.t7


100%|██████████| 10000/10000 [00:42<00:00, 233.28it/s]

Accuracy of the model: 9.82% Blocks used in avg: tensor(0., device='cuda:0')
Min blocks:  tensor(0., device='cuda:0')
Testing in ms: 4.30





BlockDrop:
Accuracy of the model: 93.56% Blocks used in avg: tensor(16.9326, device='cuda:0')
Min blocks:  tensor(5., device='cuda:0')
Testing in ms: 8.15

Single ResNet:
Accuracy of the model: 93.12% Blocks used in avg: tensor(54., device='cuda:0')
Min blocks:  tensor(54., device='cuda:0')
Testing in ms: 16.51