In [None]:
import urllib
import urllib.request

import cv2
import numpy as np
from super_gradients.common.object_names import Models
from super_gradients.training import models

class ObjectDetector:
    """Run a detection model on video frames and annotate the detections.

    ``model`` must provide ``predict(frame, conf=...)`` returning a prediction
    object (or a list of them) exposing ``class_names`` and
    ``prediction.labels`` / ``prediction.confidence`` /
    ``prediction.bboxes_xyxy`` -- the interface of the super-gradients models
    this file loads with ``models.get(...)``.
    """

    def __init__(self, model):
        self.model = model

    @staticmethod
    def _select_detections(predictions, confidence_threshold, filter_label):
        """Flatten model predictions into dicts, keeping only wanted detections.

        A detection is kept when its confidence is ``>= confidence_threshold``
        and, if ``filter_label`` is not ``None``, its class name equals it.
        Each kept detection is ``{"class_name", "confidence", "bbox"}``.
        """
        if not isinstance(predictions, list):
            predictions = [predictions]

        selected = []
        for prediction in predictions:
            class_names = prediction.class_names
            result = prediction.prediction
            for label, conf, bbox in zip(result.labels, result.confidence, result.bboxes_xyxy):
                class_name = class_names[int(label)]
                wanted_class = filter_label is None or class_name == filter_label
                if wanted_class and conf >= confidence_threshold:
                    selected.append({
                        "class_name": class_name,
                        "confidence": conf,
                        "bbox": bbox,
                    })
        return selected

    def process_frame(self, frame, confidence_threshold, filter_label):
        """Detect objects in ``frame`` and draw the selected ones onto it in place.

        Args:
            frame: image array as read by OpenCV; it is annotated in place.
            confidence_threshold: minimum confidence for a detection to be kept
                (also forwarded to ``model.predict`` as ``conf``).
            filter_label: class name to keep, or ``None`` to keep every class.

        Returns:
            ``(frame, detections)``. ``detections`` is a list of dicts with
            ``class_name``, ``confidence`` and ``bbox`` (x_min, y_min, x_max, y_max).
        """
        predictions = self.model.predict(frame, conf=confidence_threshold)
        frame_predictions = self._select_detections(predictions, confidence_threshold, filter_label)

        for detection in frame_predictions:
            conf = detection["confidence"]
            xmin, ymin, xmax, ymax = map(int, detection["bbox"])
            cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
            cv2.putText(frame, f'{detection["class_name"].capitalize()}: {conf:.2f}', (xmin, ymin - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        return frame, frame_predictions

    @staticmethod
    def _iter_jpeg_chunks(stream, chunk_size=1024):
        """Yield complete JPEG images (as ``bytes``) from a multipart MJPEG stream.

        Images are delimited by the SOI (``FF D8``) and EOI (``FF D9``) markers.
        Bytes between images (multipart headers) are discarded. Iteration ends
        when ``stream.read`` returns no data.
        """
        soi, eoi = b'\xff\xd8', b'\xff\xd9'
        buffer = bytearray()
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                return
            buffer += chunk
            # One read can complete several images; drain all of them.
            while True:
                start = buffer.find(soi)
                if start == -1:
                    # Keep only the last byte: it may be the first half of an SOI.
                    del buffer[:-1]
                    break
                # Search for EOI *after* the SOI so a stray EOI cannot produce
                # an empty/invalid image slice.
                end = buffer.find(eoi, start + 2)
                if end == -1:
                    del buffer[:start]
                    break
                yield bytes(buffer[start:end + 2])
                del buffer[:end + 2]

    @classmethod
    def _iter_stream_frames(cls, stream):
        """Yield decoded frames from an MJPEG HTTP stream, skipping undecodable images."""
        for jpg in cls._iter_jpeg_chunks(stream):
            frame = cv2.imdecode(np.frombuffer(jpg, dtype=np.uint8), cv2.IMREAD_COLOR)
            if frame is not None:
                yield frame

    @staticmethod
    def _iter_capture_frames(cap):
        """Yield frames from ``cap.read()`` until it reports no more frames."""
        while True:
            ret, frame = cap.read()
            if not ret:
                return
            yield frame

    def detect_objects(self, input_source, output_path=None, max_frames=None, confidence_threshold=0.2, filter_label=None):
        """Run detection over a video source and save or display the annotated frames.

        Args:
            input_source: a string starting with ``http`` (read as a multipart
                MJPEG stream), or a video file path / device index passed to
                ``cv2.VideoCapture``.
            output_path: path of the ``mp4v`` video to write at 20 fps, sized
                like the first frame. ``None`` shows frames in an OpenCV window
                instead; press ``q`` to stop.
            max_frames: stop after this many frames; ``None`` means no limit.
            confidence_threshold: minimum confidence of detections to keep.
            filter_label: class name to keep, or ``None`` to keep every class.

        Raises:
            OSError: if ``cv2.VideoCapture`` cannot open ``input_source``.
        """
        if isinstance(input_source, str) and input_source.startswith('http'):
            stream = urllib.request.urlopen(input_source)
            frames = self._iter_stream_frames(stream)
            release_source = stream.close
        else:
            cap = cv2.VideoCapture(input_source)
            if not cap.isOpened():
                cap.release()
                raise OSError(f"Não foi possível abrir a fonte de vídeo: {input_source!r}")
            frames = self._iter_capture_frames(cap)
            release_source = cap.release

        display = output_path is None
        writer = None
        frame_count = 0
        try:
            for frame in frames:
                frame_processed, _ = self.process_frame(frame, confidence_threshold, filter_label)
                if display:
                    cv2.imshow('Live Object Detection', frame_processed)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                else:
                    if writer is None:
                        # Size the writer from the actual frames: OpenCV's
                        # VideoWriter does not write frames whose size differs
                        # from the size it was created with.
                        height, width = frame_processed.shape[:2]
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        writer = cv2.VideoWriter(output_path, fourcc, 20.0, (width, height))
                    writer.write(frame_processed)
                frame_count += 1
                if max_frames is not None and frame_count >= max_frames:
                    break
        finally:
            # Always free the source, the writer and any window, even on error.
            release_source()
            if writer is not None:
                writer.release()
            if display:
                cv2.destroyAllWindows()

        if not display:
            if writer is not None:
                print("Vídeo com previsões salvo com sucesso em:", output_path)
            else:
                print("Nenhum quadro foi lido da entrada; nenhum vídeo foi salvo.")


# Load YOLO-NAS-L with COCO-pretrained weights and wrap it in the detector.
model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
detector = ObjectDetector(model)

# Run configuration:
#   input_source         -- webcam device index (e.g. 0) or path to a video file.
#   output_path          -- where to save the annotated video; None = live view.
#   max_frames           -- maximum number of frames to process (optional).
#   confidence_threshold -- minimum detection confidence (optional).
#   filter_label         -- class name of the objects to keep (optional).
input_source = '/kaggle/input/video/videoplayback (2).mp4'
output_path = "/kaggle/working/saida_com_previsoes.mp4"
max_frames = 1000
confidence_threshold = 0.2
filter_label = 'person'

# Process the input and save the output video (when output_path is set).
detector.detect_objects(
    input_source,
    output_path,
    max_frames=max_frames,
    confidence_threshold=confidence_threshold,
    filter_label=filter_label,
)


 It is your responsibility to determine whether you have permission to use the models for your use case.
 The model you have requested was pre-trained on the coco dataset, published under the following terms: https://cocodataset.org/#termsofuse
[2024-05-11 15:19:49] INFO - checkpoint_utils.py - License Notification: YOLO-NAS pre-trained weights are subjected to the specific license terms and conditions detailed in 
https://github.com/Deci-AI/super-gradients/blob/master/LICENSE.YOLONAS.md
By downloading the pre-trained weight files you agree to comply with these terms.
[2024-05-11 15:19:50] INFO - checkpoint_utils.py - Successfully loaded pretrained weights for architecture yolo_nas_l
[2024-05-11 15:19:50] INFO - pipelines.py - Fusing some of the model's layers. If this takes too much memory, you can deactivate it by setting `fuse_model=False`
[2024-05-11 15:19:51] INFO - pipelines.py - Fusing some of the model's layers. If this takes too much memory, you can deactivate it by setting `f