In [1]:
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torchvision as tv
import tqdm
import onnxruntime as ort
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import (
    ViTForImageClassification,
    ViTImageProcessor,
)
import os
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Compose, Normalize, ToTensor, Resize
# import numpy as np


In [2]:
from FAST_classificator import ClassificatorONNX

In [3]:
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        
        self.id2label = {k: v for k, v in enumerate(sorted(os.listdir(root_dir)))}
        self.label2id = {v: k for k, v in self.id2label.items()}
        
        self.image_paths = []
        self.labels = []

        self.improcessor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        
        self.size = self.improcessor.size["height"]
        self.normalize = Normalize(
            mean=self.improcessor.image_mean,
            std=self.improcessor.image_std
        )

        self._transforms = Compose([
            Resize((self.size, self.size)),
            ToTensor(),
            self.normalize
        ])

        for cls in self.id2label.values():
            cls_folder = os.path.join(root_dir, cls)
            if os.path.isdir(cls_folder):
                for img_name in os.listdir(cls_folder):
                    img_path = os.path.join(cls_folder, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(cls)
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        return {
            "pixel_values": self.improcessor(
                images=Image.open(self.image_paths[idx]).convert("RGB")).pixel_values[0].squeeze(), # .squeeze()
            "labels": self.label2id[self.labels[idx]]
        }

In [4]:
import os
dataset = CustomImageDataset(root_dir="/home/user1/hack/train_data_rkn/dataset")

In [5]:
classifier = ClassificatorONNX("vit_v4.onnx")

In [6]:
classifier.predict_proba_class(Image.open(dataset.image_paths[1000]))

(np.int64(6), np.float32(0.86374795))