In [5]:
import torch
from PIL import Image
from torchvision import transforms
import os
import re

In [6]:
IMAGE_DIR = "images"
FOLDERS = ["happy", "sad", "relaxed", "angry"]

In [7]:
img_files = []

for folder in FOLDERS:
    fname = os.path.join(IMAGE_DIR, folder)
    
    for im in os.listdir(fname):
        impath = os.path.join(fname, im)
        img_files.append(impath)

In [37]:
with open("imagenet_classes.txt", "r") as f:
    categories = [re.sub('[^0-9a-zA-Z ]+', " ", s.lower().strip()) for s in f.readlines()]

In [40]:
with open("dog_labels.txt", "r") as f:
    dog_labels = [re.sub('[^0-9a-zA-Z ]+', " ", s.lower().strip()) for s in f.readlines()]

In [11]:
device = torch.device("mps")

In [28]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnext101_32x8d', pretrained=True)

Using cache found in /Users/vik/.cache/torch/hub/pytorch_vision_v0.10.0
Downloading: "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth" to /Users/vik/.cache/torch/hub/checkpoints/resnext101_32x8d-8ba56ff5.pth


  0%|          | 0.00/340M [00:00<?, ?B/s]

In [29]:
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1

In [65]:
def detect_dog(img_files):
    input_images = [Image.open(img_file) for img_file in img_files]

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    input_tensor = preprocess(input_images)
    input_batch = input_tensor.unsqueeze(0)
    input_batch = input_batch.to(device)

    with torch.no_grad():
        output = model(input_batch)

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top5_prob, top5_catid = torch.topk(probabilities, 8)
    cats = set([categories[c] for c in top5_catid])
    dog_cats = cats.intersection(set(dog_labels))
    return len(dog_cats) > 0, cats

In [56]:
to_delete = []
to_keep = []

for img in img_files:
    dog, classes = detect_dog(img)
    if dog:
        to_keep.append(img)
    else:
        to_delete.append(img)

In [57]:
len(to_delete)

1884

In [58]:
len(to_keep)

4195

In [75]:
def remove_items(items):
    for item in items:
        try:
            os.remove(item)
        except FileNotFoundError:
            print(item)

In [76]:
remove_items(to_delete)