In [1]:
%%capture 
%run config.ipynb
%run ViT.ipynb

In [2]:
img_path = data_dir / "cat.jpg"

In [3]:
import torch 
from torchvision import transforms
from PIL import Image

In [4]:
model = ViT(
    image_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=IN_CHANNELS,
    n_head=N_HEAD, 
    d_model=D_MODEL, 
    ffn_hidden=FFN_HIDDEN, 
    mlp_hidden=MLP_HIDDEN, 
    n_layers=N_LAYERS, 
    class_num=CLASS_NUM, 
    device=device, 
    drop_prob=DROP_PROB,
)

# load model
model_path = model_dir / "model_68percent_acc"
model.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))

model.eval()


ViT(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (patch_emb): PatchEmbedding(
        (emb): Conv2d(3, 400, kernel_size=(4, 4), stride=(4, 4))
      )
      (pos_emb): PositionalEmbedding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-4): 5 x EncoderBlock(
        (norm): LayerNorm((400,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (multihead_attn): MultiheadAttentionBlock(
          (Wq): Linear(in_features=400, out_features=400, bias=True)
          (Wk): Linear(in_features=400, out_features=400, bias=True)
          (Wv): Linear(in_features=400, out_features=400, bias=True)
          (attention): SelfAttentionBlock(
            (softmax): Softmax(dim=-1)
          )
          (Wconcat): Linear(in_features=400, out_features=400, bias=True)
        )
        (ffn): FeedForwardBlock(
          (linear1): Linear(in_features=400, out_features=512, bias=True)
          (gelu): GE

In [5]:
# prepare single image tensor for input 
img_inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5)
])

img = Image.open(img_path).convert("RGB")

img = img_inference_transform(img)

# set batch_size as 1 using unsqueeze
img = img.unsqueeze(0)
print(f'shape of img: {img.shape}')

shape of img: torch.Size([1, 1, 32, 32])


In [6]:
# read the cifar10 classes.txt
class_list = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
class_dict = dict()
for idx, classname in enumerate(class_list):
    class_dict[idx] = classname

print(f'{class_dict}')

# define a helper function 
def tensor2mnistid(tensor): 
    # convert tensor to foodid 
    # by picking index that has max value 
    # use argmax 
    mnist_id = torch.argmax(tensor, dim=-1)
    return mnist_id

def mnistid2mnistname(mnistid):
    if mnistid < 0 or mnistid > max(class_dict.keys()):
        raise Exception("Invalid food Id for foodid2foodname function")
    return class_dict[mnistid]
    

{0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'}


In [7]:
def inference(model, img):
    with torch.no_grad():
        img = img.to(device)

        out = model(img)
        mnistid = tensor2mnistid(out)
        mnistname = mnistid2mnistname(mnistid.item())

        print(f'{mnistname}')
    return

In [8]:
inference(model, img)

RuntimeError: Given groups=1, weight of size [400, 3, 4, 4], expected input[1, 1, 32, 32] to have 3 channels, but got 1 channels instead