In [None]:
from transformers import ViTForImageClassification

temp_model_dir = "./vit-base-beans"
ckpt = "google/vit-base-patch16-224"

# Download the pretrained PyTorch ViT classifier and write a local copy of
# its weights + config to temp_model_dir.
model = ViTForImageClassification.from_pretrained(ckpt)
# NOTE: `saved_model=True` is an option of the TensorFlow
# TFPreTrainedModel.save_pretrained; it has no effect for this PyTorch model,
# so it is not passed here.
model.save_pretrained(temp_model_dir)

In [None]:
# Print the model's textual representation (its nested submodule structure).
print(repr(model))

In [9]:
from transformers import AutoImageProcessor

# Load the image processor that matches the checkpoint; its `size`,
# `image_mean` and `image_std` drive the custom preprocessing code.
processor = AutoImageProcessor.from_pretrained(ckpt)
processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}



In [13]:
import torch

import base64
import torch
from PIL import Image
from io import BytesIO
import torchvision.transforms as transforms

# Dictionary key for the batched image tensor; Hugging Face image processors
# and ViT's forward() both name this input `pixel_values`.
CONCRETE_INPUT = "pixel_values"
# Side length for the square resize; the pipeline uses (SIZE, SIZE), so make
# sure the processor's target size really is square.
assert processor.size["height"] == processor.size["width"], "expected a square target size"
SIZE = processor.size["height"]

def normalize_img(img, mean=processor.image_mean, std=processor.image_std):
    """Rescale a [0, 255] image tensor to [0, 1] and standardize it per channel.

    Args:
        img: Image tensor, assumed channel-first (C, H, W) with values in
            [0, 255]. Integer tensors (e.g. uint8) are promoted to float by the
            division.
        mean: Per-channel means; defaults to the processor's ``image_mean``.
        std: Per-channel standard deviations; defaults to ``image_std``.

    Returns:
        Tensor of the same shape computed as ``(img / 255 - mean) / std``.
    """
    scaled = img / 255.0  # [0, 255] -> [0, 1]
    # Create the statistics on the image's device so the (C, 1, 1) broadcast
    # also works when img lives on a GPU (a CPU (C,1,1) tensor would not).
    mean_t = torch.as_tensor(mean, device=scaled.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, device=scaled.device).view(-1, 1, 1)
    return (scaled - mean_t) / std_t


def preprocess(string_input):
    """Decode one base64-encoded image and turn it into a normalized model input.

    Args:
        string_input: Base64-encoded contents of an image file that PIL can
            open (e.g. JPEG).

    Returns:
        Float tensor of shape (3, SIZE, SIZE): RGB, channel-first, resized to
        SIZE x SIZE and normalized with the processor's mean/std.
    """
    decoded_input = base64.b64decode(string_input)
    with Image.open(BytesIO(decoded_input)) as raw_img:
        img = raw_img.convert("RGB")  # convert() returns a new image, safe after close

    # PILToTensor keeps raw uint8 values in [0, 255] with (C, H, W) layout,
    # which is exactly what normalize_img expects (it divides by 255 itself).
    # Do not use ToTensor() here: it already rescales to [0, 1], which would
    # divide the pixels by 255 twice. No permute is needed either — the
    # tensor is already channel-first.
    transform_pipeline = transforms.Compose([
        transforms.Resize((SIZE, SIZE)),
        transforms.PILToTensor(),
        normalize_img,
    ])
    return transform_pipeline(img)

def preprocess_fn(string_inputs):
    """Preprocess a batch of base64-encoded images into the model's input dict.

    Args:
        string_inputs: Iterable of base64-encoded images.

    Returns:
        Dict with a single key, CONCRETE_INPUT, holding the tensors produced
        by ``preprocess`` stacked along a new leading batch dimension.
    """
    tensors = []
    for encoded in string_inputs:
        tensors.append(preprocess(encoded))
    batch = torch.stack(tensors)
    return {CONCRETE_INPUT: batch}
