# Uso de la red neuronal para la detección de monos


## 1. Dependecias

In [1]:
!pip install torch torchvision pillow huggingface_hub




## 2. Descagar modelo y labels

In [2]:
from pathlib import Path
from huggingface_hub import snapshot_download
import os

HF_TOKEN = os.getenv("HF_TOKEN_READ")
assert HF_TOKEN, "⚠️ Debes definir primero la variable de entorno HF_TOKEN_READ"

REPO_ID = "Barearojojuan/monkey-classifier-pytorch"
local_dir = Path(snapshot_download(
    repo_id=REPO_ID,
    allow_patterns=["*.pth", "labels.txt"],
    token=HF_TOKEN
))

WEIGHTS = local_dir / "monkey_classifier_v0.1.pth"
LABELS = local_dir / "labels.txt"
print("Modelo:", WEIGHTS.exists(), "Labels:", LABELS.exists())


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


labels.txt:   0%|          | 0.00/180 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


monkey_classifier_v0.1.pth:   0%|          | 0.00/44.8M [00:00<?, ?B/s]

Modelo: True Labels: True


## 3. Definir clases y transformaciones

In [3]:
import torchvision.transforms as T
import torch

# Clases desde labels.txt
if LABELS.exists():
    classes = [l.strip() for l in LABELS.read_text().splitlines() if l.strip()]
else:
    classes = [
        "Mantled_howler","Patas_monkey","Bald_uakari","Japanese_macaque","Pygmy_marmoset",
        "White_headed_capuchin","Silvery_marmoset","Common squirrel_monkey",
        "Black_headed_night_monkey","Nilgiri_langur"
    ]

# Transformaciones (mismas que en tu código)
mean = [0.4363, 0.4328, 0.3291]
std  = [0.2129, 0.2075, 0.2038]

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(torch.tensor(mean), torch.tensor(std))
])


## 4. Cargar modelo

In [5]:
import torch
import torchvision.models as models
from collections.abc import Mapping

def load_model(weights_path, num_classes):
    # 1) Intentar TorchScript
    try:
        m = torch.jit.load(str(weights_path), map_location="cpu")
        m.eval()
        print("[INFO] Cargado como TorchScript")
        return m
    except Exception as e:
        print("[INFO] No es TorchScript:", e.__class__.__name__)

    # 2) Cargar el checkpoint con torch.load
    obj = torch.load(str(weights_path), map_location="cpu")

    # 2a) Si es un nn.Module pickled -> úsalo directamente
    if isinstance(obj, torch.nn.Module):
        obj.eval()
        print(f"[INFO] Checkpoint es nn.Module pickled: {obj.__class__.__name__}")
        return obj

    # 2b) Si es un dict, puede ser state_dict o contenerlo
    if isinstance(obj, Mapping):
        # A veces viene el módulo en 'model'
        if "model" in obj and isinstance(obj["model"], torch.nn.Module):
            m = obj["model"]
            m.eval()
            print(f"[INFO] Encontrado nn.Module en 'model': {m.__class__.__name__}")
            return m
        # Si no, asumir que es state_dict (posiblemente en 'state_dict')
        sd = obj.get("state_dict", obj)
    else:
        raise TypeError(f"Tipo de checkpoint inesperado: {type(obj)}")

    # 3) Probar backbones comunes con state_dict
    backbones = [
        ("resnet18", models.resnet18(weights=None)),
        ("resnet34", models.resnet34(weights=None)),
        ("efficientnet_b0", models.efficientnet_b0(weights=None)),
    ]

    last_err = None
    for name, model in backbones:
        try:
            if name.startswith("resnet"):
                in_features = model.fc.in_features
                model.fc = torch.nn.Linear(in_features, num_classes)
            else:  # efficientnet
                in_features = model.classifier[-1].in_features
                model.classifier[-1] = torch.nn.Linear(in_features, num_classes)

            missing, unexpected = model.load_state_dict(sd, strict=False)
            print(f"[INFO] Cargado con {name}: missing={len(missing)}, unexpected={len(unexpected)}")
            model.eval()
            return model
        except Exception as e:
            last_err = e
            print(f"[WARN] Falló con {name}: {e}")

    raise RuntimeError(f"No se pudo cargar el modelo con los intentos realizados. Último error: {last_err}")

model = load_model(WEIGHTS, len(classes))



[INFO] No es TorchScript: RuntimeError
[INFO] Checkpoint es nn.Module pickled: ResNet


  obj = torch.load(str(weights_path), map_location="cpu")
