In [1]:
import torch
from torchvision import transforms, models
from torch import nn
import shutil
from PIL import Image
import os

#### Setup a model trained earlier that will decide if image contais layout or not

In [2]:
def setup_model(num_classes=2):
    # Load pre-trained ResNet18
    model = models.resnet18(weights='IMAGENET1K_V1')
    
    # Freeze all layers
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace the final fully connected layer
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_features, 256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, num_classes)
    )
    
    # Move model to GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(f"Using device: {device}")
    
    return model, device

#### Function that will move images containing layout to filtered_images directory

In [3]:
def filter_images(model, model_path='layout_classifier.pth', 
                            source_dir='data/raw',
                            filtered_dir='data/filtered_images'):
     
     if isinstance(model, str):
        try:
            model, _ = setup_model() 
            model.load_state_dict(torch.load(model_path))
        except NameError:
             print("Error: setup_model function not defined. Cannot load model from path.")
             return 0, 0
        except FileNotFoundError:
             print(f"Error: Model file not found at {model_path}")
             return 0, 0
     
     model.eval()
     device = next(model.parameters()).device

     val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
     
     os.makedirs(filtered_dir, exist_ok=True)

     total_images = 0
     filtered_images = 0

     for source in os.listdir(source_dir):
          source_path = os.path.join(source_dir, source)
          if not os.path.isdir(source_path):
               continue
          for track in os.listdir(source_path):
               track_path = os.path.join(source_path, track)
               if not os.path.isdir(track_path):
                    continue
               os.makedirs(os.path.join(filtered_dir, track), exist_ok=True)

               for img_name in os.listdir(track_path):
                    img_path = os.path.join(track_path, img_name)
                    total_images += 1

                    try:
                         img = Image.open(img_path).convert('RGB')
                         img_tensor = val_transforms(img).unsqueeze(0).to(device)

                         with torch.no_grad():
                              output = model(img_tensor)
                              _, predicted = torch.max(output, 1)
                         
                         if predicted.item() == 0:
                              filtered_images += 1
                              shutil.copy(img_path, os.path.join(filtered_dir, track, f'{source}_{img_name}'))
                    except Exception as e:
                         print(f"Error processing image {img_path}: {e}")
                    
     print(f"\nProcessing complete!")
     print(f"Total images processed: {total_images}")
     print(f"Images with layout: {filtered_images}")
     
     if total_images > 0:
          print(f"Clean Images rate: {filtered_images/total_images*100:.1f}%")
     else:
          print("Clean Images rate: N/A (No images processed)")
    
     return None

In [4]:
filter_images('layout_classifier.pth')

Using device: cpu


  model.load_state_dict(torch.load(model_path))


Error processing image data/raw/google/Bahrain International Circuit/.DS_Store: cannot identify image file '/Users/bszczesniak/projekty/f1-layout-recognition/data/raw/google/Bahrain International Circuit/.DS_Store'

Processing complete!
Total images processed: 1369
Images with layout: 1109
Clean Images rate: 81.0%


##### Summary
Model that 'cleans' data isn't perfect so it's still not the best, but for now this must work. In future, the feedback loop will help creating more realistic images.