In [5]:
import torch
from model_training.network import str2Model
from model_training.dataset import CAGTestDataset
from torchvision import transforms
import time

class MockEvent:
    def __init__(self):
        pass
    def record(self):
        self.time = time.time()
    def elapsed_time(self, end):
        return (end.time - self.time) * 1000

def run(device):
    start = torch.mps.Event(enable_timing=True)
    end = torch.mps.Event(enable_timing=True)

    if device == 'cpu':
        start = MockEvent()
        end = MockEvent()

    model = torch.load('output/checkpoint_50000.pth', map_location=torch.device('cpu'))
    model2 = str2Model(model['model_type'])()
    model = model['network']
    model2.load_state_dict(model)
    model = model2.eval()
    model = model.to(device)

    dataset = CAGTestDataset()

    size = [32, 64, 128, 192, 256, 384, 512]
    l = []

    for i, s in enumerate(size):
        img, label = dataset[i]
        img = transforms.Resize((s, s))(img)
        img = img.to(device)
        img = img.unsqueeze(0)

        for _ in range(5):
            model(torch.rand_like(img))
        start.record()
        model(img)
        end.record()
        if device == 'mps':
            torch.mps.synchronize()
        t = start.elapsed_time(end)
        l.append(t)
    return l

    

In [6]:
run('mps')



[7.694291,
 6.7687919999999995,
 6.10575,
 8.721916,
 9.681208,
 17.98275,
 30.364749999999997]

In [7]:
run('cpu')



[8.825063705444336,
 5.953788757324219,
 20.135164260864258,
 43.57504844665527,
 86.02476119995117,
 205.21283149719238,
 408.39409828186035]