In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


In [2]:
from src.utils_RCNN import *

train_dataset = WildlifeYOLODataset(
    images_dir="data/format_rgb/images/train",
    labels_dir="data/format_rgb/labels/train",
    transforms=get_transform(train=True),
)

val_dataset = WildlifeYOLODataset(
    images_dir="data/format_rgb/images/val",
    labels_dir="data/format_rgb/labels/val",
    transforms=get_transform(train=False),
)

def collate_fn(batch):
    return tuple(zip(*batch))

train_loader = DataLoader(
    train_dataset,
    batch_size=2,      # CPU/Mac → chiquito mejor
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)


In [3]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device("cpu")
print("Entrenando en", device)


# modelo base
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# reemplazar el predictor para 3 clases + background
num_classes = 3 + 1  # Cow, Deer, Horse + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

model.to(device)


Entrenando en cpu
Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to C:\Users\isiva/.cache\torch\hub\checkpoints\fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


100.0%


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [4]:
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

num_epochs = 10  # podés subir después

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")


KeyboardInterrupt: 

In [None]:
model.eval()

from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

# agarramos una imagen del dataset de validación
img, _ = val_dataset[0]  # img es tensor CxHxW
img_device = img.to(device)

with torch.no_grad():
    prediction = model([img_device])[0]

# filtramos por score
scores = prediction["scores"].cpu()
keep = scores > 0.5

boxes = prediction["boxes"][keep].cpu()
labels = prediction["labels"][keep].cpu()

# nombres de clases (recordá que 0 es background)
class_names = ["__background__", "Cow", "Deer", "Horse"]
label_names = [class_names[i] for i in labels]

img_uint8 = (img * 255).to(torch.uint8)
img_boxes = draw_bounding_boxes(img_uint8, boxes, labels=label_names)

plt.imshow(F.to_pil_image(img_boxes))
plt.axis("off")
plt.show()


NotImplementedError: The operator 'torchvision::nms' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash Unknown. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.