In [1]:
%%capture 
%run config.ipynb
%run ViT.ipynb

In [2]:
# Path of the sample image that is opened and classified further down.
# NOTE(review): data_dir is not defined in this notebook — presumably it is provided by
# config.ipynb via the %run in the first cell; confirm.
img_path = data_dir / "cat.jpg"

In [3]:
import torch 
from torchvision import transforms
from PIL import Image

In [4]:
# Rebuild the ViT so that its architecture matches the checkpoint loaded below.
# NOTE(review): the upper-case hyperparameters, `device` and `model_dir` are not defined in
# this notebook — presumably they come from config.ipynb; confirm.
model = ViT(
    image_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=IN_CHANNELS,
    n_head=N_HEAD,
    d_model=D_MODEL,
    ffn_hidden=FFN_HIDDEN,
    mlp_hidden=MLP_HIDDEN,
    n_layers=N_LAYERS,
    class_num=CLASS_NUM,
    device=device,
    drop_prob=DROP_PROB,
)

# Load the trained weights. map_location only decides where the checkpoint tensors are
# deserialized; weights_only=True restricts unpickling to tensors and plain containers.
model_path = model_dir / "model_68percent_acc"
state_dict = torch.load(model_path, weights_only=True, map_location=device)
model.load_state_dict(state_dict)

# load_state_dict copies values into the existing parameters and does not relocate the
# module, so place the model on `device` explicitly.
# NOTE(review): this is a no-op if ViT already puts its parameters on `device` — confirm
# against ViT.ipynb.
model.to(device)

# Switch to evaluation mode so the Dropout layers are inactive during inference.
# As the last expression of the cell this also displays the module structure.
model.eval()


ViT(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (patch_emb): PatchEmbedding(
        (emb): Conv2d(3, 400, kernel_size=(4, 4), stride=(4, 4))
      )
      (pos_emb): PositionalEmbedding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-4): 5 x EncoderBlock(
        (norm): LayerNorm((400,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (multihead_attn): MultiheadAttentionBlock(
          (Wq): Linear(in_features=400, out_features=400, bias=True)
          (Wk): Linear(in_features=400, out_features=400, bias=True)
          (Wv): Linear(in_features=400, out_features=400, bias=True)
          (attention): SelfAttentionBlock(
            (softmax): Softmax(dim=-1)
          )
          (Wconcat): Linear(in_features=400, out_features=400, bias=True)
        )
        (ffn): FeedForwardBlock(
          (linear1): Linear(in_features=400, out_features=512, bias=True)
          (gelu): GE

In [5]:
# Preprocessing for a single inference image: resize to IMG_SIZE x IMG_SIZE, convert the
# PIL image to a tensor, then normalize every channel with mean 0.5 and std 0.5.
# NOTE(review): the normalization should match what was used during training — confirm.
img_inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])

# Open the file inside a context manager so the file handle is released as soon as the
# pixels have been copied; convert("RGB") returns a new 3-channel image that stays valid
# after the file is closed.
with Image.open(img_path) as img_file:
    img_rgb = img_file.convert("RGB")

img_tensor = img_inference_transform(img_rgb)

# Add a leading batch dimension of size 1 so the tensor is shaped like a one-image batch.
img = img_tensor.unsqueeze(0)
print(f'shape of img: {img.shape}')

shape of img: torch.Size([1, 3, 32, 32])


In [6]:
# Class names, hard-coded here; each name's position in the list is its class id.
# NOTE(review): presumably this is the label order used when the model was trained — verify.
class_list = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# Map class id -> class name. Built in one expression so no loop variables are left
# behind in the notebook namespace.
class_dict = dict(enumerate(class_list))

print(f'{class_dict}')

# define a helper function 
def tensor2cifarid(tensor):
    """Return the index of the largest value along the last dimension of ``tensor``.

    Args:
        tensor: Tensor of scores; the last dimension is the one that is reduced.

    Returns:
        Tensor of indices with the last dimension removed. When ``tensor`` holds
        per-class scores, each index is the predicted class id.
    """
    return torch.argmax(tensor, dim=-1)

def cifarid2cifarname(cifarid):
    """Look up the class name for a class id in the notebook-level ``class_dict``.

    Args:
        cifarid: Class id; must be one of the keys of ``class_dict``.

    Returns:
        The class name stored under ``cifarid``.

    Raises:
        ValueError: If ``cifarid`` is not a key of ``class_dict``. ``ValueError`` is a
            subclass of ``Exception``, so callers that catch ``Exception`` keep working.
    """
    # A membership test covers negative ids, ids past the end and an empty mapping in one
    # check, without recomputing max() over the keys on every call.
    if cifarid not in class_dict:
        raise ValueError(f"Invalid cifar id {cifarid!r} for cifarid2cifarname function")
    return class_dict[cifarid]
    

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}


In [7]:
def inference(model, img):
    """Run ``model`` on ``img`` and print the predicted class name.

    Args:
        model: Module invoked as ``model(img)``; its output is passed to
            ``tensor2cifarid``.
        img: Input tensor. The prediction is read with ``.item()``, so the model output
            must reduce to exactly one class id (i.e. a single image).

    Returns:
        None. The predicted class name is printed.
    """
    # Send the input to wherever the model's parameters live instead of relying on a
    # notebook-level `device` global, so the function works for any model passed in.
    # A model without parameters leaves the input where it already is.
    first_param = next(model.parameters(), None)
    target_device = img.device if first_param is None else first_param.device

    # No gradients are needed for inference.
    with torch.no_grad():
        scores = model(img.to(target_device))
        cifarid = tensor2cifarid(scores)
        cifarname = cifarid2cifarname(cifarid.item())

        print(f'{cifarname}')

In [8]:
inference(model, img)

cat
