In [5]:
import torch
import joblib
from torchvision import models, transforms
from PIL import Image
import os
import json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

In [6]:
# Загрузка моделей
def load_models():
    # yolo_v8_model = YOLO('yolov8n.pt')
    faster_rcnn_model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True).eval()
    # detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").eval()
    # return yolo_v8_model, faster_rcnn_model, detr_model
    return faster_rcnn_model


# COCO метки
# coco_labels = {
#     2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck',
#     9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
#     16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear',
#     24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase',
#     34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat',
#     40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass',
#     47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich',
#     55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair',
#     63: 'couch', 64: 'potted plant', 65: 'bed', 66: 'remote',  67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop',
#     74: 'mouse', 75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster',
#     81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear',
#     89: 'hair drier', 90: 'toothbrush'
# }

coco_labels = {
    1: 'person',
    2: 'bicycle',
    3: 'car',
    4: 'motorcycle',
    5: 'airplane',
    6: 'bus',
    7: 'train',
    8: 'truck',
    9: 'boat',
    10: 'traffic light',
    11: 'fire hydrant',
    13: 'stop sign',
    14: 'parking meter',
    15: 'bench',
    16: 'bird',
    17: 'cat',
    18: 'dog',
    19: 'horse',
    20: 'sheep',
    21: 'cow',
    22: 'elephant',
    23: 'bear',
    24: 'zebra',
    25: 'giraffe',
    27: 'backpack',
    28: 'umbrella',
    31: 'handbag',
    32: 'tie',
    33: 'suitcase',
    34: 'frisbee',
    35: 'skis',
    36: 'snowboard',
    37: 'sports ball',
    38: 'kite',
    39: 'baseball bat',
    40: 'baseball glove',
    41: 'skateboard',
    42: 'surfboard',
    43: 'tennis racket',
    44: 'bottle',
    46: 'wine glass',
    47: 'cup',
    48: 'fork',
    49: 'knife',
    50: 'spoon',
    51: 'bowl',
    52: 'banana',
    53: 'apple',
    54: 'sandwich',
    55: 'orange',
    56: 'broccoli',
    57: 'carrot',
    58: 'hot dog',
    59: 'pizza',
    60: 'donut',
    61: 'cake',
    62: 'chair',
    63: 'couch',
    64: 'potted plant',
    65: 'bed',
    67: 'dining table',
    70: 'toilet',
    72: 'tv',
    73: 'laptop',
    74: 'mouse',
    75: 'remote',
    76: 'keyboard',
    77: 'cell phone',
    78: 'microwave',
    79: 'oven',
    80: 'toaster',
    81: 'sink',
    82: 'refrigerator',
    84: 'book',
    85: 'clock',
    86: 'vase',
    87: 'scissors',
    88: 'teddy bear',
    89: 'hair drier',
    90: 'toothbrush'
}


# Функция для обработки изображения через Faster R-CNN
def detect_objects_faster(faster_rcnn_model, image_path):
    image = Image.open(image_path).convert("RGB")
    image_tensor = transforms.ToTensor()(image).unsqueeze(0)
    
    with torch.no_grad():
        faster_rcnn_results = faster_rcnn_model(image_tensor)[0]

    all_results = []
    for i in range(len(faster_rcnn_results["boxes"])):
        box = faster_rcnn_results["boxes"][i].tolist()
        label = coco_labels.get(int(faster_rcnn_results["labels"][i].item()), "unknown")
        score = faster_rcnn_results["scores"][i].item()
        all_results.append({"box": box, "label": label, "score": score, "source": "Faster R-CNN"})
    
    return all_results

# Функция для обработки всех изображений в заданных директориях
def process_images_in_directories(paths, model, output_dir):
    for path in paths:
        full_res = []
        for dirpath, dirnames, filenames in os.walk(path):
            for dirname in dirnames:
                dir_path = os.path.join(dirpath, dirname)
                if 'desc' not in dir_path:
                    current_dir_data = []
                    for _, _, filenames in os.walk(dir_path):
                        for filename in filenames:
                            file = os.path.join(dir_path, filename)
                            try:
                                current_dir_data.append(detect_objects_faster(model, file))
                            except Exception as e:
                                print(f"Error processing {file}: {e}")
                    full_res.append(current_dir_data)
        name = path.rsplit('/', 1)[-1]
        save_results(full_res, name, output_dir)

# Функция для сохранения результатов
def save_results(data, name, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    with open(f'{output_dir}/full_{name}.txt', 'w') as fw:
        json.dump(data, fw)

# Фильтрация данных без изображений людей
def filter_data(input_dir, output_dir):
    os.makedirs(output_dir, exist_ok=True) 
    for dirpath, dirnames, filenames in os.walk(input_dir):
        for filename in filenames:
            objs = []
            path = os.path.join(dirpath, filename)
            with open(path, 'r') as fr:
                lst = json.load(fr)
                f = []
                for i in lst:
                    objs = []
                    for b in i:
                        for c in b:
                            objs.append(c['label'])
                    f.append(objs)
            name = path.split('full_')[1]
            with open(os.path.join(output_dir, name), 'w') as fw:
                json.dump(f, fw)

# Загрузка данных и обучение модели
def load_data(directory_path):
    data = []
    labels = []
    for category in os.listdir(directory_path):
        category_path = os.path.join(directory_path, category)
        if os.path.isfile(category_path):
            with open(category_path, 'r') as file:
                lists_of_words = json.load(file)
                for word_list in lists_of_words:
                    data.append(' '.join(word_list))
                    labels.append(category)
    return data, labels

# Обучение модели классификации
def train_classifier(data, labels):
    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(data)
    X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.2, random_state=42)
    
    random_forest_model = RandomForestClassifier(n_estimators=100, random_state=42)
    random_forest_model.fit(X_train, y_train)
    
    y_pred_rf = random_forest_model.predict(X_test)
    print("Random Forest Accuracy:", accuracy_score(y_test, y_pred_rf))
    
    return random_forest_model, vectorizer

# Предсказание категории для нового списка объектов
def predict_category(word_list, model, vectorizer):
    word_list_joined = ' '.join(word_list)
    vectorized_word_list = vectorizer.transform([word_list_joined])
    return model.predict(vectorized_word_list)[0]

# Функция для сохранения модели и векторизатора
def save_model(model, vectorizer, model_filename, vectorizer_filename):
    joblib.dump(model, model_filename)
    joblib.dump(vectorizer, vectorizer_filename)
    print(f"Model saved as {model_filename} and vectorizer saved as {vectorizer_filename}")

# Функция для загрузки модели и векторизатора
def load_model(model_filename, vectorizer_filename):
    model = joblib.load(model_filename)
    vectorizer = joblib.load(vectorizer_filename)
    print(f"Model loaded from {model_filename} and vectorizer loaded from {vectorizer_filename}")
    return model, vectorizer

In [7]:
# Основная функция пайплайна
def main_pipeline():
    # Шаг 1: Загрузка моделей
    faster_rcnn_model = load_models()

    # Шаг 2: Пути к изображениям
    paths = [
       '/Users/diana/PycharmProjects/OZON/dataset/test'
    ]
    
    output_dir = 'all_data'
    output_dir_last = 'clean_data'

   #  # Шаг 3: Обработка изображений
   #  process_images_in_directories(paths, faster_rcnn_model, output_dir)
    
    # Шаг 4: Фильтрация данных
    filter_data(output_dir, output_dir_last)
    
    # Шаг 5: Загрузка данных и обучение модели
    data, labels = load_data(output_dir_last)
    random_forest_model, vectorizer = train_classifier(data, labels)

    # Шаг 6: Сохранение модели после обучения
    save_model(random_forest_model, vectorizer, 'random_forest_model.pkl', 'vectorizer.pkl')
    
    # Шаг 7: Загрузка модели
    random_forest_model, vectorizer = load_model('random_forest_model.pkl', 'vectorizer.pkl')


    # Шаг 8: Предсказание категории для нового списка объектов
    new_word_list = ["vase", "dining table", "tv", "chair", "laptop", "clock", "umbrella", "book", "cake", "potted plant"]
    predicted_category_rf = predict_category(new_word_list, random_forest_model, vectorizer)
    print("Predicted category (Random Forest):", predicted_category_rf)

In [8]:
if __name__ == "__main__":
    main_pipeline()



Random Forest Accuracy: 0.8541353383458646
Model saved as random_forest_model.pkl and vectorizer saved as vectorizer.pkl
Model loaded from random_forest_model.pkl and vectorizer loaded from vectorizer.pkl
Predicted category (Random Forest): стол.txt
