In [103]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import torchvision.models as models
import pyheif
import os
import os.path
import math
from sklearn import neighbors
import pickle
import numpy as np
from torchvision.utils import save_image
import torchvision

In [22]:
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

In [12]:
resnet = models.resnet50(pretrained=True)

In [117]:
def get_image_fetures_mask(image_paht):
    
    img = Image.open(image_paht).convert('RGB')

    # Definir las transformaciones necesarias para el modelo
    transform = transforms.Compose([
        transforms.Resize(256),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    img = transform(img).unsqueeze(0)

# Cargar el modelo pre-entrenado de Mask R-CNN
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    # Pasar la imagen por el modelo y obtener las características del objeto detectado
    with torch.no_grad():
        predictions = model(img)
        # print(predictions)
        if len(predictions[0]['boxes']) == 0:
            # No se detectaron objetos en la imagen
            return None
        features = predictions[0]
        print(features)

    # print(features)

    

In [87]:
def get_image_features(image_path):
    # Cargar la imagen
    img = Image.open(image_path).convert('RGB')

    # Definir las transformaciones necesarias para el modelo
    transform = transforms.Compose([
        transforms.Resize(256),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    img_transformed = transform(img)
    save_image(img_transformed, 'transformed_image.jpg')
    # Aplicar las transformaciones y agregar una dimensión adicional (batch_size=1)
    img = transform(img).unsqueeze(0)

    # Pasar la imagen por el modelo y obtener el vector característico
    features = resnet(img)
    features = features.detach().numpy()[0]

    return features

In [15]:
def converHEICtoJPEG(image_path , output):
    with open(image_path, 'rb') as f:
        img = pyheif.read(f)

    # Convierte la imagen HEIC a RGB
    rgb_image = Image.frombytes(
        img.mode, 
        img.size, 
        img.data,
        "raw",
        img.mode,
        img.stride,
    )
    # Guarda la imagen RGB en formato JPEG
    rgb_image.save(output, format='JPEG')

In [88]:
def train(train_dir, model_save_path=None,  n_neighbors=None, knn_algo='ball_tree', verbose=False):
    X = []
    y = []
    for class_dir in os.listdir(train_dir):
        if not os.path.isdir(os.path.join(train_dir, class_dir)):
            continue

        for img_file in os.listdir(os.path.join(train_dir, class_dir)):
            # Obtener el path completo de la imagen
            img_path = os.path.join(train_dir, class_dir, img_file)

            # Extraer el vector característico de la imagen
            features = get_image_features(img_path)
            features = features.reshape(1, -1)
            # Agregar el vector característico y la etiqueta al conjunto de datos
            X.append(features[0])
            y.append(class_dir)

    if n_neighbors is None:
        n_neighbors = int(round(math.sqrt(len(X))))
        if verbose:
            print("Chose n_neighbors automatically:", n_neighbors)

    # Create and train the KNN classifier
    knn_clf = neighbors.KNeighborsClassifier(n_neighbors=n_neighbors, algorithm=knn_algo, weights='distance')
    knn_clf.fit(X, y)

    # Save the trained KNN classifier
    if model_save_path is not None:
        with open(model_save_path, 'wb') as f:
            pickle.dump(knn_clf, f)

    return knn_clf

In [89]:
def predict(X_img_path, knn_clf=None, model_path=None, distance_threshold=0.6):
    if not os.path.isfile(X_img_path) or os.path.splitext(X_img_path)[1][1:] not in ALLOWED_EXTENSIONS:
        raise Exception("Invalid image path: {}".format(X_img_path))
    if knn_clf is None and model_path is None:
        raise Exception("Must supply knn classifier either thourgh knn_clf or model_path")
    
    if knn_clf is None:
        with open(model_path, 'rb') as f:
            knn_clf = pickle.load(f)

    X_img = get_image_features(X_img_path)
    X_img = X_img.reshape(1, -1)
    closest_distances = knn_clf.kneighbors(X_img, n_neighbors=3)
    print(closest_distances)
    are_matches = closest_distances[0][0][0] <= distance_threshold
    predit = knn_clf.predict(X_img)
    print(predit)
    # print(closest_distances)
    # print(closest_distances[0][0][0])
    # print(closest_distances)
    # predit = knn_clf.predict(X_img)
    # print(predit)
    # are_matches = closest_distances[0][0][0] <= distance_threshold


In [95]:
print("Training KNN classifier...")
classifier = train("data/train", model_save_path="trained_knn_model.clf", n_neighbors=2)
print("Training complete!")

# Test the model and check the accuracy of KNN classifier


Training KNN classifier...
Training complete!


In [91]:
predict("data/test/image1.jpg", model_path="trained_knn_model.clf", distance_threshold=0.6)

(array([[3.09587466, 3.10638987, 3.16593791]]), array([[1, 0, 2]]))
['Schwartz_Jodi 3.21.23_lgorbitvu_21_BATCHED']


In [72]:
converHEICtoJPEG("Dataset/Ensembles/IMG_5482.HEIC", "data/test/image2.jpg")

In [92]:
predict("data/test/image2.jpg", model_path="trained_knn_model.clf", distance_threshold=0.6)

(array([[2.82101575, 2.87046412, 2.90322464]]), array([[1, 2, 3]]))
['Schwartz_Jodi 3.21.23_lgorbitvu_21_BATCHED']


In [73]:
converHEICtoJPEG("Dataset/Sneakers/IMG_5459.HEIC", "data/test/image3.jpg")

In [94]:
predict("data/test/image4.jpg", model_path="trained_knn_model.clf", distance_threshold=0.6)

(array([[0.        , 2.96609496, 3.19797174]]), array([[5, 3, 2]]))
['Berchtold_Marvelle 3.20.23 Rack 2.1_lgorbitvu_7_BATCHED']


In [118]:
get_image_fetures_mask("data/test/image1.jpg")

KeyError: 'maskrcnn_bbox_features'

In [100]:
import torch
print(torch.hub.get_dir())

/home/jorge/.cache/torch/hub


In [101]:
import shutil
import torch

cache_dir = torch.hub.get_dir()
shutil.rmtree(cache_dir)