In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import ViTImageProcessor
import os
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Compose, Normalize, ToTensor, Resize

from app.slow_classificator import ResClassifier, Classificator, VitClassifier

dataset = CustomImageDataset(root_dir="/home/user1/hack/train_data_rkn/dataset")
# train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True,num_workers=4)
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        
        self.id2label = {k: v for k, v in enumerate(sorted(os.listdir(root_dir)))}
        self.label2id = {v: k for k, v in self.id2label.items()}
        
        self.image_paths = []
        self.labels = []

        self.improcessor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        
        self.size = self.improcessor.size["height"]
        self.normalize = Normalize(
            mean=self.improcessor.image_mean,
            std=self.improcessor.image_std
        )

        self._transforms = Compose([
            Resize((self.size, self.size)),
            ToTensor(),
            self.normalize
        ])

        for cls in self.id2label.values():
            cls_folder = os.path.join(root_dir, cls)
            if os.path.isdir(cls_folder):
                for img_name in os.listdir(cls_folder):
                    img_path = os.path.join(cls_folder, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(cls)
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        return {
            "pixel_values": self.improcessor(
                images=Image.open(self.image_paths[idx]).convert("RGB")).pixel_values[0].squeeze(), # .squeeze()
            "labels": self.label2id[self.labels[idx]]
        }

In [52]:
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torchvision as tv
import tqdm
import onnxruntime as ort
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchmetrics.classification import AveragePrecision
from transformers import (
    DefaultDataCollator,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    ViTForImageClassification,
    ViTImageProcessor,
)
from huggingface_hub import hf_hub_download


import os
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Compose, Normalize, ToTensor, Resize
# import numpy as np




class ClassificatorONNX:
    def __init__(self, model_path: str="vit_v4.onnx", device="cpu"):
        self.device = device
        hf_hub_download(repo_id="alan3333/hack_rkn_onnx", filename=model_path, local_dir="./")

        self.session = ort.InferenceSession(model_path)
        
        self.image_processor = tv.transforms.Compose(
            [
                tv.transforms.Resize((224, 224)),
                tv.transforms.ToTensor(),
                tv.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

    def _transform(self, image: Image.Image) -> np.ndarray:
        processed_image = self.image_processor(image).unsqueeze(0)
        return processed_image.numpy().astype(np.float32)

    def predict_proba_class(self, image: Image.Image) -> List[np.ndarray]:
        processed_image = self._transform(image)
    
        outputs = self.session.run(None, {"pixel_values": processed_image})
        probabilities = (np.exp(outputs)/np.sum(np.exp(outputs)))[0][0]
        predicted_class = probabilities.argmax()
       
        return int(predicted_class), float(probabilities[predicted_class])

    def predict(self, image: Image.Image) -> int:
        return self.predict_proba_class(image)[0]

    def predict_embedding(self, image: Image.Image) -> np.ndarray:
        processed_image = self._transform(image)
        outputs = self.session.run(None, {"pixel_values": processed_image})
  
        embedding = outputs[0][0]
        return embedding.squeeze().tolist()

    def predict_result(self, image: Image.Image) -> Dict[str, np.ndarray]:
        predicted_class, probabilities = self.predict_proba_class(image)
        embedding = self.predict_embedding(image)

        return {
            "class": predicted_class,
            "probs_class": probabilities,
            "embedding": embedding,
        }


classificator_instance = ClassificatorONNX()

In [53]:
classifier = ClassificatorONNX("vit_v4.onnx")

In [54]:
classifier.predict_proba_class(Image.open(dataset.image_paths[1000]))

(6, 0.8637479543685913)

In [56]:
classifier.predict_result(Image.open(dataset.image_paths[1000]))

{'class': 6,
 'probs_class': 0.8637479543685913,
 'embedding': [-0.14666026830673218,
  -0.02052750252187252,
  -0.6322766542434692,
  0.08765003085136414,
  -0.12225610017776489,
  -0.04616245627403259,
  6.7528791427612305,
  -0.1726534515619278,
  2.1956520080566406,
  0.17181754112243652,
  -0.7886292338371277,
  -0.28554877638816833,
  -0.04729550704360008,
  -0.9365891218185425,
  -0.01777680218219757,
  0.2869463562965393,
  0.38678357005119324,
  -0.13066191971302032,
  0.6296558380126953,
  -0.3088721036911011,
  0.05625346302986145,
  0.3677195906639099,
  -0.722286581993103,
  -0.1251567155122757,
  0.4465736150741577,
  1.6457087993621826,
  -0.09712395071983337,
  -0.700269877910614,
  0.2704384922981262,
  -0.2038898468017578,
  -0.45533400774002075,
  0.31055569648742676,
  -0.8964690566062927,
  -0.7936358451843262,
  -0.14049625396728516,
  -0.37208789587020874,
  -0.6660160422325134,
  0.3401380777359009,
  -1.1864463090896606,
  -0.13739338517189026,
  2.505615949630

In [30]:
hf_hub_download(repo_id="alan3333/hack_rkn_onnx", filename="README.md", local_dir="./")


README.md:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

'README.md'

In [60]:

import json 
json.dumps({k: v for k, v in enumerate(sorted(os.listdir("/home/user1/hack/train_data_rkn/dataset")))})


'{"0": ".DS_Store", "1": "Accordion", "2": "Adhesive tape", "3": "Aircraft", "4": "Airplane", "5": "Alarm clock", "6": "Alpaca", "7": "Ambulance", "8": "Animal", "9": "Ant", "10": "Apple", "11": "Artichoke", "12": "Banana", "13": "Barge", "14": "Bathtub", "15": "Belt", "16": "Binoculars", "17": "Bottle", "18": "Bow and arrow", "19": "Bread", "20": "Briefcase", "21": "Broccoli", "22": "Camera", "23": "Cannon", "24": "Cassette deck", "25": "Cat", "26": "Cello", "27": "Christmas tree", "28": "Coin", "29": "Common fig", "30": "Cosmetics", "31": "Cucumber", "32": "Cutting board", "33": "Earrings", "34": "Elephant", "35": "Fedora", "36": "Flashlight", "37": "Frying pan", "38": "Glasses", "39": "Glove", "40": "Goat", "41": "Goldfish", "42": "Grape", "43": "Harp", "44": "Hat", "45": "Helmet", "46": "High heels", "47": "Hippopotamus", "48": "Honeycomb", "49": "Horse", "50": "Insect", "51": "Invertebrate", "52": "Ipod", "53": "Isopod", "54": "Jacket", "55": "Jet ski", "56": "Koala", "57": "Land 