In [None]:
import torchvision.transforms as transforms

# Preprocessing for MNIST images: convert to a float tensor, then normalize
# each (single) channel with mean 0.5 and std 1.0, i.e. subtract 0.5 without
# rescaling.
# NOTE(review): ToTensor yields values in [0, 1], so this produces [-0.5, 0.5];
# if the intent was [-1, 1], std should be 0.5 — confirm.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(1.0,)),
    ]
)

In [None]:
from torchvision.datasets import MNIST

# Directory that torchvision uses to store the MNIST files.
download_root = './MNIST_DATASET'

# Only the held-out split (train=False) is loaded here; the training split
# could be loaded the same way with train=True. download=True fetches the
# files if they are not already present under download_root.
test_dataset = MNIST(
    download_root,
    train=False,
    transform=mnist_transform,
    download=True,
)

# Test: convert a model with data-dependent control flow to Core ML

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class _LoopBody(nn.Module):
    def __init__(self, channels):
        super(_LoopBody, self).__init__()
        conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )
        self.conv = conv

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        return x

class ControlFlowNet(nn.Module):
    """Toy model whose depth depends on its input.

    The shared ``_LoopBody`` is applied twice when the mean over all elements
    of ``x`` is negative, and once otherwise. This gives the scripted model a
    data-dependent branch and loop.

    Args:
        num_channels: channel count passed to ``_LoopBody``.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        self.loop_body = _LoopBody(num_channels)

    def forward(self, x):
        # .item() turns the mean into a Python number, so the repeat count is
        # chosen by ordinary control flow rather than a tensor op.
        repeats = 2 if torch.mean(x).item() < 0 else 1
        for _ in range(repeats):
            x = self.loop_body(x)
        return x

In [2]:
# Script a model with data-dependent control flow and convert it to Core ML.
import coremltools

NUM_CHANNELS = 3
# (N, C, H, W) input shape the converted model will accept.
INPUT_SHAPE = (1, NUM_CHANNELS, 64, 64)
MLMODEL_PATH = "test.mlmodel"

model = ControlFlowNet(num_channels=NUM_CHANNELS)
model.eval()  # put the module in inference mode before export
print(model)

# Scripting (rather than tracing) keeps the `if`/loop in forward() in the graph.
scripted_model = torch.jit.script(model)

mlmodel = coremltools.convert(
    scripted_model,
    inputs=[coremltools.TensorType(shape=INPUT_SHAPE)],
    # Pin the neural-network backend explicitly. Newer coremltools releases
    # default to "mlprogram", which must be saved as a .mlpackage, so saving to
    # a ".mlmodel" path would fail there.
    convert_to="neuralnetwork",
)

mlmodel.save(MLMODEL_PATH)

ControlFlowNet(
  (loop_body): _LoopBody(
    (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)
Converting Frontend ==> MIL Ops:   0%|          | 0/10 [00:00<?, ? ops/s]
Converting Frontend ==> MIL Ops: 0 ops [00:00, ? ops/s]

Converting Frontend ==> MIL Ops: 0 ops [00:00, ? ops/s]

Converting Frontend ==> MIL Ops:  83%|████████▎ | 5/6 [00:00<00:00, 1327.06 ops/s]

Converting Frontend ==> MIL Ops:  83%|████████▎ | 5/6 [00:00<00:00, 1249.42 ops/s]
Converting Frontend ==> MIL Ops:  90%|█████████ | 9/10 [00:00<00:00, 168.04 ops/s]
Running MIL optimization passes: 100%|██████████| 18/18 [00:00<00:00, 9916.91 passes/s]
Translating MIL ==> MLModel Ops:   0%|          | 0/12 [00:00<?, ? ops/s]
Translating MIL ==> MLModel Ops: 0 ops [00:00, ? ops/s]

Translating MIL ==> MLModel Ops: 0 ops [00:00, ? ops/s]

Translating MIL ==> MLModel Ops: 100%|██████████| 2/2 [00:00<00:00, 4760.84 ops/s]

Translating MIL ==> MLModel Ops: 100%|██████████| 9/9 [00:00<00:00, 12450.11 o