In [1]:
import torch
from PIL import Image
import os
import re
from settings import *

In [2]:
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Build the classifier. The `in22ft1k` suffix presumably means ImageNet-22k
# pretraining with ImageNet-1k fine-tuning (confirm in the timm model card);
# pretrained=True makes timm load the published weights.
# nn.Module.eval() returns the module itself, so it is chained here.
model = timm.create_model('convnext_small_384_in22ft1k', pretrained=True).eval()

# Derive the eval-time preprocessing (resize, centre crop, normalisation) from
# the model's own pretrained config so inputs match what the weights expect.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

In [3]:
# Show the preprocessing pipeline timm derived from the model's config.
transform

Compose(
    Resize(size=384, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(384, 384))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

In [4]:
# Target device for the model and its input batches. DEVICE comes from the
# `settings` star-import (presumably a string such as "cpu" or "cuda" — confirm).
device = torch.device(DEVICE)

In [None]:
# Move the weights to the target device. nn.Module.to returns the module itself;
# assigning it (rather than leaving a bare `model.to(device)` as the cell's last
# expression) keeps Jupyter from printing the entire module repr as output.
model = model.to(device)

In [6]:
# Collect the path of every image file under IMAGE_DIR/<folder> for each folder
# in FOLDERS.
img_files = []

for folder in FOLDERS:
    folder_path = os.path.join(IMAGE_DIR, folder)

    # os.listdir returns entries in arbitrary order; sort so the resulting list
    # (and everything derived from it) is reproducible across runs.
    for name in sorted(os.listdir(folder_path)):
        impath = os.path.join(folder_path, name)
        # Skip sub-directories and hidden files (e.g. .DS_Store): they are not
        # images, and opening one for classification would raise.
        if name.startswith(".") or not os.path.isfile(impath):
            continue
        img_files.append(impath)

In [7]:
def read_labels(fname, encoding="utf-8"):
    """Read a label file (one label per line) into normalised strings.

    Each line is lower-cased. Every run of characters other than ASCII letters,
    digits and spaces is replaced by a single space, and whitespace is then
    collapsed and trimmed. For example, "German shepherd, Alsatian" becomes
    "german shepherd alsatian".

    Normalising *after* the substitution means punctuation never leaves doubled
    or trailing spaces. Two files that punctuate the same label differently
    therefore still produce equal strings.

    Line order is preserved and blank lines are kept (as "") so the result can
    still be indexed by class id.

    Args:
        fname: Path of the text file to read.
        encoding: Text encoding of the file. It is explicit so the result does
            not depend on the platform's default locale.

    Returns:
        list[str]: One normalised label per line of the file.
    """
    with open(fname, "r", encoding=encoding) as f:
        return [" ".join(re.sub("[^0-9a-zA-Z ]+", " ", line.lower()).split()) for line in f]

In [8]:
# ImageNet class names in class-id order; detect_dog indexes this list with the
# model's predicted class ids. The file is presumably one name per line, 1000
# lines for this ImageNet-1k head (confirm).
categories = read_labels("imagenet_classes.txt")

In [9]:
# Normalised names of the classes that count as "dog". detect_dog compares
# these against the names of the top predicted classes.
dog_labels = read_labels("dog_labels.txt")

In [10]:
def detect_dog(img_file, transform, top_k=5):
    """Decide whether an image shows a dog, using the global ImageNet classifier.

    The image is converted to 3-channel RGB and preprocessed with `transform`.
    It is then classified by the module-level `model` on `device`. The `top_k`
    most probable classes are mapped to names via `categories`, and the image
    counts as a dog when any of those names appears in `dog_labels`.

    Args:
        img_file: Path of the image to classify.
        transform: Callable mapping a PIL image to a (C, H, W) tensor, such as
            the timm eval transform built earlier.
        top_k: How many of the highest-probability classes to inspect.

    Returns:
        tuple[bool, set[str]]: Whether a dog class was among the top-k
        predictions, and the set of top-k class names.

    Raises:
        OSError: If the file cannot be opened or identified as an image (PIL's
            UnidentifiedImageError is a subclass of OSError).
    """
    # convert("RGB") guarantees 3 channels. Grayscale, RGBA or palette images
    # would otherwise become 1- or 4-channel tensors, which the 3-channel
    # Normalize step cannot process. The context manager closes the file handle
    # promptly, which matters when iterating over thousands of files.
    with Image.open(img_file) as img:
        input_batch = transform(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_batch)

    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    _, top_ids = torch.topk(probabilities, top_k)
    top_cats = {categories[i] for i in top_ids.tolist()}

    # isdisjoint accepts any iterable, so no per-call set(dog_labels) is needed.
    return not top_cats.isdisjoint(dog_labels), top_cats

In [11]:
# Split the images into those to keep (a dog class is among the top predictions)
# and those to delete. Files that cannot be opened or decoded go to
# `unreadable` as (path, error message) pairs instead of aborting the whole run.
# They are neither kept nor deleted, so they can be inspected by hand.
to_delete = []
to_keep = []
unreadable = []

for img in img_files:
    try:
        is_dog, _ = detect_dog(img, transform)
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        unreadable.append((img, str(exc)))
        continue

    if is_dog:
        to_keep.append(img)
    else:
        to_delete.append(img)

In [12]:
# Number of images with no dog class among the top predictions; these are
# passed to remove_items below.
len(to_delete)

1797

In [13]:
# Number of images in which a dog class was among the top predictions.
len(to_keep)

15921

In [14]:
def remove_items(items):
    """Permanently delete every file path in ``items`` from disk.

    Paths are removed in iteration order with ``os.remove``. The first failure
    (for example a missing file) propagates and leaves the remaining paths
    untouched.
    """
    for path in items:
        os.remove(path)

In [15]:
# WARNING: irreversible. This permanently deletes every image whose top
# predictions contained no dog class. Check the counts above before running.
remove_items(to_delete)