# Minimal Example: Inference with a Pretrained Binary Neural Network

In [1]:
from bitorch.models import ResnetE18
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from pathlib import Path
import matplotlib.pyplot as plt
import time
import random
import numpy as np
import os


# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Root of the ImageNet tree (expects `train/` and/or `val/` sub-folders);
# overridable through the IMAGENET_DATA_DIR environment variable.
if "IMAGENET_DATA_DIR" in os.environ:
    data_dir = Path(os.environ["IMAGENET_DATA_DIR"])
else:
    data_dir = Path.home() / "data" / "imagenet"
    print("IMAGENET_DATA_DIR environment variable is not set, using default:", str(data_dir.resolve()))

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class ImageNet(Dataset):
    """Map-style ImageNet dataset backed by ``torchvision.datasets.ImageFolder``.

    Reads images from ``<root_directory>/train`` (when ``is_train`` is true) or
    ``<root_directory>/val`` and applies the matching train/test transform.
    ``len()`` and indexing are delegated to the wrapped ``ImageFolder``, so
    ``dataset[i]`` yields ``(image_tensor, class_index)``.

    Args:
        root_directory: Directory containing the ``train`` and/or ``val``
            split folders. A ``str`` is accepted and converted to ``Path``.
        is_train: Load the training split (with augmentation) if true,
            otherwise the validation split.

    Raises:
        RuntimeError: If the selected split directory does not exist.
    """

    name = "imagenet"
    num_classes = 1000
    # Model input shape: one 3-channel 224x224 image (matches the 224 crops below).
    shape = (1, 3, 224, 224)
    # Per-channel (R, G, B) statistics used by the Normalize transform.
    mean = (0.485, 0.456, 0.406)
    # Standard ImageNet per-channel std (blue channel is 0.225, not 0.255).
    std_dev = (0.229, 0.224, 0.225)
    num_train_samples = 1281167
    num_val_samples = 50000

    def __init__(self, root_directory: Path, is_train: bool = True):
        self.root_directory = Path(root_directory)
        self.is_train = is_train
        # get_dataset() takes no `download` argument: ImageNet must already be on disk.
        self.dataset = self.get_dataset()

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self.dataset)

    def __getitem__(self, index: int):
        """Return the ``(image_tensor, class_index)`` pair at ``index``."""
        return self.dataset[index]

    def get_data_dir(self) -> Path:
        """Return the directory of the selected split (``train`` or ``val``)."""
        split = "train" if self.is_train else "val"
        directory = self.root_directory / split
        return directory

    def get_transform(self) -> transforms.Compose:
        """Return the train transform for the train split, else the test transform."""
        return self.train_transform() if self.is_train else self.test_transform()

    def get_dataset(self) -> Dataset:
        """Build the ``ImageFolder`` for the selected split.

        Raises:
            RuntimeError: If the split directory does not exist.
        """
        directory = self.get_data_dir()
        print("got directory for imagenet:", directory)
        if not directory.is_dir():
            raise RuntimeError(f"ImageNet directory {str(directory.resolve())} does not exist!")
        return ImageFolder(directory, transform=self.get_transform())

    @classmethod
    def get_normalize_transform(cls) -> transforms.Normalize:
        """Return a Normalize transform using the class-level mean/std."""
        return transforms.Normalize(cls.mean, cls.std_dev)

    @classmethod
    def train_transform(cls) -> transforms.Compose:
        """Augmenting transform: random resized 224 crop, random flip, normalize."""
        crop_scale = 0.08
        return transforms.Compose(
            [
                transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                cls.get_normalize_transform(),
            ]
        )

    @classmethod
    def test_transform(cls) -> transforms.Compose:
        """Deterministic transform: resize to 256, center-crop 224, normalize."""
        return transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                cls.get_normalize_transform(),
            ]
        )

# Validation split (deterministic resize + center-crop transform).
test_dataset = ImageNet(root_directory=data_dir, is_train=False)

In [None]:
# Instantiate ResnetE18 via `from_pretrained`, sized for ImageNet inputs/classes,
# and move it to the selected device (Module.to returns the module itself).
model = ResnetE18.from_pretrained(
    input_shape=ImageNet.shape,
    num_classes=ImageNet.num_classes,
).to(device)

In [None]:
from bitorch.layers import convert
from bitorch import RuntimeMode
import bitorch_engine

# Initialize bitorch-engine before converting, so its layer implementations are available.
bitorch_engine.initialize()

# Convert the model for inference. INFERENCE_AUTO presumably lets bitorch pick
# the best available implementation for each layer; confirm in the bitorch docs.
inference_mode = RuntimeMode.INFERENCE_AUTO
model = convert(model, inference_mode, device=device, verbose=True)

In [None]:
def show_picture(picture):
    """Plot a channel-first image array with matplotlib and display it.

    ``picture`` is laid out as (channels, height, width), while ``imshow``
    expects (height, width, channels), so the axes are permuted first.
    """
    channels_last = np.transpose(picture, axes=(1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

# Pick a random validation image; the dataset yields (image_tensor, class_index).
sample_index = random.randint(0, len(test_dataset) - 1)
random_image, true_label = test_dataset[sample_index]
# Add a batch dimension: (C, H, W) -> (1, C, H, W).
random_image = random_image.unsqueeze(0).to(device)

# Inference only: switch layers such as batch norm to eval mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    # CUDA kernels launch asynchronously; synchronize around the forward pass so
    # the wall-clock measurement covers the computation, not just the launch.
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    prediction = model(random_image)
    if device.type == "cuda":
        torch.cuda.synchronize()
    duration = time.perf_counter() - start

# Presumably one score per class (num_classes=1000); report the top-scoring index.
predicted_class = int(prediction.argmax(dim=-1).item())
print(f"Prediction: class {predicted_class} (dataset label: {true_label}) (took {duration:.2f} seconds)")

# Undo the Normalize transform so the image is displayed with its real colors.
channel_mean = torch.tensor(ImageNet.mean).view(-1, 1, 1)
channel_std = torch.tensor(ImageNet.std_dev).view(-1, 1, 1)
display_image = (random_image[0].cpu() * channel_std + channel_mean).clamp(0.0, 1.0)
show_picture(display_image.numpy())