In [5]:
from PIL import Image
import numpy as np

import ivy
from ivy_models.transformers.perceiver_io import PerceiverIOSpec, PerceiverIO

In [2]:
# Select torch as the ivy backend for the cells that follow.
# NOTE(review): presumably this must run before any ivy arrays or models are created - confirm.
ivy.set_backend("torch")

In [11]:
# params
# Switches between the full-size architecture (depth 8, 6 latent attention
# blocks per layer) and a minimal one (depth 1, 1 block) for a quick demo.
# NOTE(review): no visible cell actually loads a checkpoint, so with this
# notebook as-is the flag only changes the architecture size - confirm.
load_pretrained_weights = False

input_dim = 3
num_input_axes = 2
output_dim = 1000
network_depth = 8 if load_pretrained_weights else 1
num_lat_att_per_layer = 6 if load_pretrained_weights else 1
device = "cpu"
# Plain bool on purpose: the forward-pass cell branches on the truthiness of
# this value, and a one-element list such as [False] would still be truthy.
learn_query = True
batch_shape = [1]
img_dims = [224, 224]
queries_dim = 1024

In [7]:
# prepare image for classification
# Force 3-channel RGB before resizing so grayscale, palette or RGBA files
# still match the 3-element mean/std used by the normalization step.
img_raw = Image.open("n01443537_goldfish.jpeg").convert("RGB").resize((224, 224))

def normalize_and_standardize(img_raw: "Image.Image", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Scale pixel values by 1/255, then standardize each channel.

    Args:
        img_raw: Image (or any array-like) convertible with ``np.asarray``.
            ``mean`` and ``std`` are broadcast against its trailing axis, so
            a channel-last layout is expected.
        mean: Per-channel value subtracted after the 1/255 scaling.
        std: Per-channel value the centred result is divided by.

    Returns:
        A new float32 ``np.ndarray`` with the same shape as the input; the
        input itself is left untouched.
    """
    # Work in float32 throughout so the result dtype never depends on the
    # dtype of the mean/std arguments.
    scaled = np.asarray(img_raw, dtype="float32") / 255

    mean = np.asarray(mean, dtype="float32")
    std = np.asarray(std, dtype="float32")

    return (scaled - mean) / std

# Standardize the pixels, then add a leading batch axis and wrap as an ivy array.
img_arr = normalize_and_standardize(img_raw)

img = ivy.array(
    np.expand_dims(img_arr, axis=0),
    dtype='float32',
    device=device,
)  # (1, 224, 224, 3)

(224, 224, 3)


In [13]:
# Collect the settings from the params cell into a spec, then build the model from it.
model_spec = PerceiverIOSpec(
    input_dim=input_dim,
    num_input_axes=num_input_axes,
    output_dim=output_dim,
    queries_dim=queries_dim,
    network_depth=network_depth,
    learn_query=learn_query,
    query_shape=[1],
    num_fourier_freq_bands=64,
    num_lat_att_per_layer=num_lat_att_per_layer,
    device=device,
)
model = PerceiverIO(model_spec)

In [14]:
# Pass no explicit queries when learn_query is truthy; otherwise draw random
# ones of shape batch_shape + [1, queries_dim].
if learn_query:
    queries = None
else:
    queries = ivy.random_uniform(shape=batch_shape + [1, queries_dim], device=device)

logits = model(img, queries=queries)  # (1, 1, 1000)

In [21]:
# Pick the highest-scoring ImageNet class ID and its softmax probability.
class_logits = logits[0][0]
predicted_class = ivy.argmax(class_logits)
class_probabilities = ivy.softmax(class_logits)
predicted_probability = class_probabilities[predicted_class]

In [23]:
# Report the predicted class and the probability assigned to it.
summary = "Models predicts class {} with {} probability".format(predicted_class, predicted_probability)
print(summary)

Models predicts class ivy.array(792) with ivy.array(0.0112668) probability
