In [66]:
import functools
from typing import Tuple, List

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

In [None]:
@functools.lru_cache(maxsize=None)
def load_detection_model(model_type='yolo'):
    """Load a detection model, memoizing it per ``model_type``.

    Loading is expensive (hub/transformers instantiation and weight loading),
    so the result is cached. Repeated calls return the SAME model instance;
    do not mutate it if other callers rely on its default settings.

    Args:
        model_type: ``'yolo'`` for YOLOv5s via ``torch.hub`` or ``'detr'`` for
            ``facebook/detr-resnet-50`` via Hugging Face ``transformers``.

    Returns:
        A ``(model, processor)`` tuple; ``processor`` is ``None`` for YOLO.

    Raises:
        ValueError: if ``model_type`` is not ``'yolo'`` or ``'detr'``
            (exceptions are not cached by ``lru_cache``).
    """
    if model_type == 'yolo':
        model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
        return model, None
    if model_type == 'detr':
        # Imported lazily so the YOLO path does not require transformers.
        from transformers import DetrImageProcessor, DetrForObjectDetection
        processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
        return model, processor
    # Previously this fell through to an unbound ``model`` (UnboundLocalError).
    raise ValueError(
        f"Unsupported model_type: {model_type!r} (expected 'yolo' or 'detr')"
    )

# Детекция
def detect_parrot(image, model_type='yolo'):
    """Detect a parrot (COCO class "bird") in an image.

    Args:
        image: RGB image to search. Assumed to be a ``PIL.Image``: the DETR
            branch reads ``image.size`` as ``(width, height)``.
        model_type: ``'yolo'`` or ``'detr'``.

    Returns:
        ``(bbox, detector)`` where ``bbox`` is ``[xmin, ymin, xmax, ymax]``
        (ints) of the highest-confidence bird and ``detector`` is the model
        type string, or ``(None, "Parrot not detected")`` if no bird is found.

    Raises:
        ValueError: if ``model_type`` is not ``'yolo'`` or ``'detr'``.
        RuntimeError: wrapping any error raised while loading or running
            the detector.
    """
    # Validate up front: an unknown type used to fall through and return a
    # bare ``None``, which broke tuple-unpacking callers with a TypeError.
    if model_type not in ('yolo', 'detr'):
        raise ValueError(
            f"Unsupported model_type: {model_type!r} (expected 'yolo' or 'detr')"
        )

    try:
        model, processor = load_detection_model(model_type)

        if model_type == 'yolo':
            detections = model([image]).pandas().xyxy[0]
            birds = detections[detections['name'] == 'bird']
            if birds.empty:
                return None, "Parrot not detected"
            # Pick the most confident bird explicitly instead of relying on
            # the row order of the results table.
            best = birds.loc[birds['confidence'].idxmax()]
            bbox = best[['xmin', 'ymin', 'xmax', 'ymax']].values.astype(int)
            return bbox.tolist(), 'yolo'

        # DETR branch.
        bird_label = 16  # "bird" id in the COCO label map (model.config.id2label)
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            outputs = model(**inputs)
        # post-processing expects (height, width); PIL size is (width, height).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.7
        )[0]
        bird_mask = results["labels"] == bird_label
        if not bool(bird_mask.any()):
            return None, "Parrot not detected"
        # DETR results are in query order, not score order: take the top score.
        bird_scores = results["scores"][bird_mask]
        bird_boxes = results["boxes"][bird_mask]
        bbox = bird_boxes[int(bird_scores.argmax())].int().tolist()
        return bbox, 'detr'

    except Exception as e:
        raise RuntimeError(f"Detection error: {str(e)}") from e

# Обрезка
def crop_image(image, bbox):
    """Crop an image to a bounding box, clamped to the image bounds.

    Args:
        image: Image object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method (e.g. ``PIL.Image.Image``).
        bbox: Sequence ``[xmin, ymin, xmax, ymax]`` in pixel coordinates.

    Returns:
        The result of ``image.crop`` on the clamped ``(xmin, ymin, xmax, ymax)``.

    Raises:
        ValueError: if ``bbox`` does not have exactly 4 coordinates, or the
            box has no area once clamped to the image.
    """
    if len(bbox) != 4:
        raise ValueError(f"Expected 4 bbox coordinates, got {len(bbox)}")
    width, height = image.size
    xmin, ymin, xmax, ymax = bbox
    # Detector boxes may extend past the frame; PIL fills out-of-bounds crop
    # regions with zeros (black), which would pollute the classifier input.
    xmin, xmax = max(0, xmin), min(width, xmax)
    ymin, ymax = max(0, ymin), min(height, ymax)
    if xmax <= xmin or ymax <= ymin:
        raise ValueError(f"Bounding box {list(bbox)} has no area inside the image")
    return image.crop((xmin, ymin, xmax, ymax))

# Классификация
class ParrotClassifier:
    """Classify a (cropped) image into one of a fixed set of labels.

    Args:
        model_path: Path to a full pickled ``torch.nn.Module`` (saved with
            ``torch.save(model, path)``), not a bare state_dict.
        classes: Optional sequence of label names indexed by the model's
            output logits. Defaults to ``DEFAULT_CLASSES``.

    Raises:
        TypeError: if the checkpoint does not contain an ``nn.Module``.
    """

    # Labels: Macaw, Cockatoo, Amazon, Not a parrot.
    DEFAULT_CLASSES = ('Ара', 'Какаду', 'Амазон', 'Не попугай')

    def __init__(self, model_path='parrot_resnet.pth', classes=None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # weights_only=False is required to unpickle a full nn.Module: since
        # PyTorch 2.6 torch.load defaults to weights_only=True and rejects it.
        # SECURITY: this unpickles arbitrary objects - only load trusted files.
        model = torch.load(model_path, map_location=self.device, weights_only=False)
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"Expected a pickled torch.nn.Module in {model_path!r}, "
                f"got {type(model).__name__}"
            )
        self.model = model.to(self.device)
        self.model.eval()
        # Standard ImageNet preprocessing (resize, 224 center crop, normalize).
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.classes = list(self.DEFAULT_CLASSES if classes is None else classes)

    def predict(self, image):
        """Return the label with the highest model score for ``image``.

        Args:
            image: Input accepted by ``self.transform`` (a PIL image for the
                default torchvision pipeline).

        Returns:
            The entry of ``self.classes`` at the argmax of the model output.
        """
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
        # Ties resolve to the first maximal index, as with torch.max.
        index = int(logits.argmax(dim=1).item())
        return self.classes[index]

# Полный пайплайн
def parrot_pipeline(image_path, detection_model='yolo', classifier_model='parrot_resnet.pth'):
    """Run the full parrot recognition pipeline on one image.

    Steps: load the image as RGB, detect a bird, crop to its box, classify.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        detection_model: Detector type forwarded to ``detect_parrot``
            (``'yolo'`` or ``'detr'``).
        classifier_model: Checkpoint path forwarded to ``ParrotClassifier``.

    Returns:
        ``"Parrot not detected"`` if the detector finds nothing, otherwise
        ``"Result: <label> (Detected with <DETECTOR>)"``.

    Raises:
        ValueError: if the image cannot be loaded or cropped.
        RuntimeError: if classification fails.
        Exceptions from ``detect_parrot`` and from constructing
        ``ParrotClassifier`` propagate unchanged.
    """
    # 1. Load the image; the context manager releases the file handle.
    #    convert() returns a fully loaded copy, so it stays valid after close.
    try:
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image: {str(e)}") from e

    # 2. Detection.
    bbox, detector_type = detect_parrot(image, detection_model)
    if bbox is None:
        return "Parrot not detected"

    # 3. Crop to the detected box.
    try:
        cropped_image = crop_image(image, bbox)
    except Exception as e:
        raise ValueError(f"Cropping error: {str(e)}") from e

    # 4. Classification.
    classifier = ParrotClassifier(classifier_model)
    try:
        result = classifier.predict(cropped_image)
    except Exception as e:
        raise RuntimeError(f"Classification error: {str(e)}") from e

    return f"Result: {result} (Detected with {detector_type.upper()})"