In [None]:
from ultralytics import YOLO
from PIL import Image
import cv2
from transformers import AutoImageProcessor, ResNetModel
from torchvision.transforms.functional import pil_to_tensor
import torch
import numpy as np

In [None]:
# Checkpoint identifiers, gathered in one place so they are easy to retarget.
SEG_WEIGHTS = '../segmentation/dronuniver_yolov8nseg.pt'
PROCESSOR_CHECKPOINT = "microsoft/resnet-50"
# NOTE(review): "mask_recognition" is presumably a local directory (or hub id)
# holding fine-tuned ResNet weights -- confirm where it is expected to live.
EMBEDDER_CHECKPOINT = "mask_recognition"

# Segmentation model that produces the boxes and masks used further down.
seg_model = YOLO(SEG_WEIGHTS)
# Image processor and ResNet backbone used to embed the masked crops.
# The processor and the model are loaded from different checkpoints.
emb_processor = AutoImageProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
emb_model = ResNetModel.from_pretrained(EMBEDDER_CHECKPOINT)

In [None]:
def get_pred(img_path):
    """Run the segmentation model on a single image.

    Args:
        img_path: Path of the image file, opened with PIL.

    Returns:
        The first element of whatever ``seg_model`` returns for the image.
    """
    # Open inside a context manager so the file handle is released as soon as
    # inference has finished instead of lingering until garbage collection.
    with Image.open(img_path) as img:
        return seg_model(img)[0]

In [None]:
def get_boxes(preds):
    """Return the predicted boxes in ``xywh`` layout, truncated to integers.

    Args:
        preds: Prediction object exposing ``boxes.xywh`` as a tensor.

    Returns:
        The ``xywh`` tensor converted with ``.int()``.

    NOTE(review): Ultralytics documents ``Boxes.xywh`` as
    (x_center, y_center, width, height), so the first two columns are box
    centres, not top-left corners -- confirm against the installed version.
    """
    xywh = preds.boxes.xywh
    return xywh.int()

In [None]:
def get_masks(preds):
    """Return the mask outlines of a prediction as integer point arrays.

    Args:
        preds: Prediction object exposing ``masks.xy``, an iterable of
            NumPy point arrays (one per detected instance).

    Returns:
        A list with one integer array per mask. The list is empty when the
        prediction carries no masks (``preds.masks`` is ``None``), which is
        what happens when nothing was segmented in the image.
    """
    # No detections -> there is no masks object to read polygons from.
    if preds.masks is None:
        return []
    return [polygon.astype(int) for polygon in preds.masks.xy]

In [None]:
def get_crops(img_path, preds):
    """Cut one masked crop per detected instance out of the image.

    For every detection the bounding-box region is copied from the image and
    every pixel outside the instance's mask polygon is set to zero.

    Args:
        img_path: Path of the image the prediction was computed on.
        preds: Prediction object exposing ``boxes.xyxy`` (tensor of corner
            coordinates) and ``masks.xy`` (per-instance polygons in image
            coordinates); ``masks`` may be ``None`` when nothing was found.

    Returns:
        A list of BGR ``uint8`` arrays, one per usable detection. Detections
        whose box collapses to zero area after integer truncation, or whose
        polygon is empty, are skipped, so the list can be shorter than the
        number of boxes.

    Raises:
        ValueError: If the file exists but cannot be decoded as an image.
    """
    # Nothing segmented: no polygons to draw, so there is nothing to crop.
    if preds.masks is None:
        return []

    # np.fromfile + imdecode instead of cv2.imread: imread returns None
    # silently on failure and can choke on non-ASCII paths on some platforms.
    img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"Could not decode image: {img_path!r}")
    img_h, img_w = img.shape[:2]

    # Corner coordinates are used on purpose: the xywh layout is centre-based,
    # so slicing from its first two values would shift every crop by half a box.
    corners = preds.boxes.xyxy.int().tolist()

    crops = []
    for (x1, y1, x2, y2), polygon in zip(corners, preds.masks.xy):
        # Keep the slice inside the image (negative indices would wrap around).
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, img_w), min(y2, img_h)
        if x2 <= x1 or y2 <= y1 or len(polygon) == 0:
            continue

        crop = img[y1:y2, x1:x2].copy()
        # Move the polygon from image coordinates into the crop's own frame so
        # the mask lines up with the pixels that were actually cut out.
        local_polygon = (np.asarray(polygon) - (x1, y1)).astype(np.int32)
        mask = np.zeros(crop.shape[:2], np.uint8)
        cv2.drawContours(mask, [local_polygon], -1, 255, -1)
        crops.append(cv2.bitwise_and(crop, crop, mask=mask))
    return crops

In [None]:
def get_embeddings(crops):
    """Embed every crop with the ResNet model.

    Args:
        crops: Iterable of BGR image arrays (as produced by OpenCV).

    Returns:
        A list holding the model's ``pooler_output`` tensor for each crop,
        in input order.

    NOTE(review): each crop is resized to 32x32 and scaled to [0, 1] before
    it is handed to ``emb_processor``. If the processor applies its own
    rescaling (and resizing) on top, the pixels are scaled twice -- confirm
    that this matches the preprocessing the model was trained with.
    """
    def _embed(bgr_crop):
        # OpenCV stores channels as BGR; PIL and the processor expect RGB.
        rgb = Image.fromarray(cv2.cvtColor(bgr_crop, cv2.COLOR_BGR2RGB))
        thumbnail = rgb.resize((32, 32))
        pixels = torch.clamp(pil_to_tensor(thumbnail) / 255, 0, 1)
        batch = emb_processor(pixels, return_tensors="pt")
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            return emb_model(**batch).pooler_output

    return [_embed(crop) for crop in crops]

In [None]:
# Segment the sample image once; `pred` is reused by the cells below.
pred = get_pred('ф.png')

In [None]:
# Integer xywh boxes of the prediction; the bare name displays the tensor.
boxes = get_boxes(pred)
boxes

In [None]:
# Mask polygons as integer point arrays, one list entry per mask.
# NOTE(review): this displays every polygon in full, which can be long.
masks = get_masks(pred)
masks

In [None]:
# Masked crops cut from the same image the prediction was computed on.
crops = get_crops('ф.png', pred)

In [None]:
# One embedding tensor per crop.
emb = get_embeddings(crops)

In [None]:
# Display the list of embedding tensors.
emb