In [1]:
from models.faceboxes import FaceBoxes
import torch.backends.cudnn as cudnn


In [2]:
import torch
x = torch.ones((1, 3, 360, 440)).cuda()

def remove_prefix(state_dict, prefix):
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}


In [3]:
def load_model(model, pretrained_path):
    device = torch.cuda.current_device()
    pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    # check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model


In [4]:
net = FaceBoxes(phase='test', size=None, num_classes=2)    # initialize detector
net = load_model(net, "weights/FaceBoxes.pth")
cudnn.benchmark = True
device = torch.device("cpu" if False else "cuda")
net = net.to(device)
net.eval()

FaceBoxes(
  (conv1): CRelu(
    (conv): Conv2d(3, 24, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv2): CRelu(
    (conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (inception1): Inception(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch1x1_2): BasicConv2d(
      (conv): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_reduce): BasicConv2d(
      (conv): Conv2d(128, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(24, eps=1

In [5]:
from torch2trt import torch2trt

In [6]:
model_trt = torch2trt(net, [x])

In [11]:
from datetime import datetime
start = datetime.now()
res = model_trt(x)
end = datetime.now()
print((end - start).total_seconds())

0.005239


In [16]:
from datetime import datetime
start = datetime.now()
res = net(x)
end = datetime.now()
print((end - start).total_seconds())

0.043538


In [19]:
trt_path = "weights/model_trt.pth"
torch.save(model_trt.state_dict(), trt_path)
