In [22]:
from torch.utils.data import DataLoader
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
import numpy as np
import onnx
from PIL import Image
import PIL
import torchvision


In [23]:


def show_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))

In [123]:
mtcnn = MTCNN(image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,)

emb_model = InceptionResnetV1(pretrained="vggface2").eval()

In [124]:
img.size

(160, 160)

In [125]:
img = Image.open(r"C:\Users\Gram\Desktop\projects\facenet_speaker\data\test\n000002\0001_01.jpg")
# img
# img = VGG_Faces2_Triplet(root=r"C:\Users\Gram\Desktop\projects\facenet_speaker\data\test")._load_image(r"C:\Users\Gram\Desktop\projects\facenet_speaker\data\test\n000002\0001_01.jpg")
# img._size
# img
cropped = mtcnn(img)

In [126]:
cropped.shape
# np.array(img).shape

torch.Size([3, 160, 160])

In [127]:
show_model_size(mtcnn)
show_model_size(emb_model)

model size: 1.892MB
model size: 106.583MB


In [175]:
import os
import numpy as np
import collections
import glob
import torchvision.transforms as transforms

class VGG_Faces2_Triplet(torch.utils.data.Dataset):
    mean_bgr = np.array([91.4953, 103.8827, 131.0912])  # from resnet50_ft.prototxt
    fallback_transform = transforms.Compose([
        transforms.Resize((160, 160)),  # match MTCNN output size
        transforms.ToTensor(),
        transforms.Normalize(mean=VGG_Faces2_Triplet.mean_bgr[::-1] / 255.0, std=[1.0, 1.0, 1.0])
    ])
    
    def __init__(self, root, transform=True, horizontal_flip=False, upper=None):
        assert os.path.exists(root), f"root: {root} not found."
        self.root = root
        self._transform = transform
        self.horizontal_flip = horizontal_flip

        self.class_to_images = collections.defaultdict(list)
        self.img_info = []

        for i, img_file in enumerate(glob.glob(os.path.join(root, "*/*.jpg"))):
            class_id = img_file.split("\\")[-2]  # or split("/")[-2] on Linux
            label = class_id

            info = {'cid': class_id, 'img': img_file, 'lbl': label}
            self.img_info.append(info)
            self.class_to_images[class_id].append(info)

            if i % 1000 == 0:
                print(f"processing: {i} images")
            if upper and i == upper - 1:
                break

        # List of all class IDs
        self.class_ids = list(self.class_to_images.keys())

    def __len__(self):
        return len(self.img_info)

    def __getitem__(self, index):
        anchor_info = self.img_info[index]
        anchor_class = anchor_info['cid']

        # Sample positive
        positives = self.class_to_images[anchor_class]
        positive_info = anchor_info
        while positive_info == anchor_info and len(positives) > 1:
            positive_info = np.random.choice(positives)

        # Sample negative from different class
        negative_class = anchor_class
        while negative_class == anchor_class:
            negative_class = np.random.choice(self.class_ids)
        negative_info = np.random.choice(self.class_to_images[negative_class])

        # Load and transform images
        anchor_img = self._load_image(anchor_info['img'])
        positive_img = self._load_image(positive_info['img'])
        negative_img = self._load_image(negative_info['img'])

        return anchor_img, positive_img, negative_img

    def _load_image(self, img_path):
        img = PIL.Image.open(img_path)
        img_crop = mtcnn(img)
        if img_crop is not None:
            return img_crop
        return self.fallback_transform(img)

In [176]:
root = r"C:\Users\Gram\Desktop\projects\facenet_speaker\data\test"
ds = VGG_Faces2_Triplet(root=root)
# glob.glob(os.path.join(root, "*/*.jpg"))

processing: 0 images


In [177]:
ds_loader = DataLoader(ds, batch_size=128)

In [178]:
triplet_loss_fn = torch.nn.TripletMarginLoss(margin=1.0, p=2)


In [179]:
from tqdm import tqdm

total_loss = 0
with torch.no_grad():
    for data in tqdm(ds_loader):
        acnhor = emb_model(data[0])
        pos = emb_model(data[1])
        neg = emb_model(data[2])
        loss = triplet_loss_fn(acnhor, pos, neg)
        total_loss += loss.item()
print(total_loss/len(ds_loader))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:48<00:00, 21.63s/it]

0.38907732367515563





In [180]:
print(total_loss/len(ds_loader))

0.38907732367515563


In [188]:
ds_loader = DataLoader(ds, batch_size=128)

In [189]:
torch.onnx.export(
    emb_model,              
    (next(iter(ds_loader))[0],),
    "resnet50_vggface2.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=11
)

In [190]:
onnx_model = onnx.load("resnet50_vggface2.onnx")
onnx.checker.check_model(onnx_model)

In [191]:

model_path = "resnet50_vggface2.onnx"
size_bytes = os.path.getsize(model_path)
size_mb = size_bytes / (1024 * 1024)

print(f"ONNX model size: {size_mb:.2f} MB")

ONNX model size: 89.57 MB


In [202]:
import onnxruntime as ort
providers = ["CPUExecutionProvider"]
ort_sess = ort.InferenceSession('resnet50_vggface2.onnx', providers=providers)

In [195]:
test_data = next(iter(ds_loader))[0].numpy()

In [205]:
import time
start = time.time()
results = ort_sess.run(["output"], {"input": test_data})
time.time() - start

1.1868040561676025

In [197]:
from onnxruntime.quantization import quantize_dynamic, QuantType
import onnxruntime as ort

base_model = "resnet50_vggface2.onnx"
quant_model = "resnet50_vggface2_quantized16.onnx"
quantize_dynamic(base_model, quant_model, weight_type=QuantType.QUInt8)

In [198]:

model_path = "resnet50_vggface2_quantized16.onnx"
size_bytes = os.path.getsize(model_path)
size_mb = size_bytes / (1024 * 1024)

print(f"ONNX model size: {size_mb:.2f} MB")

ONNX model size: 22.66 MB


In [206]:
providers = ["CPUExecutionProvider"]
ort_sess = ort.InferenceSession('resnet50_vggface2_quantized16.onnx', providers=providers)


In [207]:
start = time.time()
results = ort_sess.run(["output"], {'input': test_data})
# olrder cpu? not optimized for quantized models?
time.time() - start

2.510838031768799

In [107]:
from tqdm import tqdm

total_loss = 0
with torch.no_grad():
    for data in tqdm(ds_loader):
        acnhor = ort_sess.run(["output"], {'input': data[0].numpy()})[0]
        pos = ort_sess.run(["output"], {'input': data[1].numpy()})[0]
        neg = ort_sess.run(["output"], {'input': data[2].numpy()})[0]
        loss = triplet_loss_fn(torch.Tensor(acnhor), torch.Tensor(pos), torch.Tensor(neg))
        total_loss += loss.item()
print(total_loss/len(ds_loader))

 80%|█████████████████████████████████████████████████████████████████████████████████████████████████▌                        | 4/5 [03:07<00:46, 46.75s/it]


InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: input for the following indices
 index: 0 Got: 88 Expected: 128
 Please fix either the inputs or the model.

In [117]:
# last iteration didn't work since the dimusion for the last batch changed
# smaller loss probably due to last batch that wasn't being loaded
print(total_loss/4)


0.9721526801586151


In [118]:
data = next(iter(ds_loader))[0]

In [121]:
start = time.time()
emb_model(data)
time.time() - start

26.560136318206787

In [120]:
import time

start = time.time()
ort_sess.run(["output"], {'input': data.numpy()})[0]
time.time() - start

14.256854057312012