In [36]:
import os
import torch
import torchvision
from torchvision import transforms
from PIL import Image

In [37]:
class_to_idx = {
    'green': 1,
    'red': 2,
    'yeloow': 3
}

In [38]:

idx_to_class = {v: k for k, v in class_to_idx.items()}

# Set device
device = torch.device("cpu")

# Load a pre-trained model (using Faster R-CNN for example)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)


in_features = model.roi_heads.box_predictor.cls_score.in_features

num_classes = len(class_to_idx) + 1
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)


model.eval()


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [40]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Define the input and output directories
input_dir = "NEW_IMAGES/cropped_boxes/"
output_dir = "classified_images/"


In [48]:

class_count = {cls: 0 for cls in class_to_idx.keys()}

max_images_per_class = 50

for image_name in os.listdir(input_dir):
    print(f"Processing image: {image_name}")
    image_path = os.path.join(input_dir, image_name)
    
    # Load and transform the image
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Make predictions
    with torch.no_grad():
        predictions = model(image_tensor)

    # Extract predicted labels and scores
    predicted_labels = predictions[0]['labels'].cpu().numpy()
    predicted_scores = predictions[0]['scores'].cpu().numpy()

    threshold = 0.4
    for i, score in enumerate(predicted_scores):
        if score > threshold:
            predicted_class_idx = predicted_labels[i]
            predicted_class = idx_to_class[predicted_class_idx.item()]
            print("Ss")
            
            class_folder = os.path.join(output_dir, predicted_class)
            os.makedirs(class_folder, exist_ok=True)


            output_image_path = os.path.join(class_folder, image_name)
            image.save(output_image_path)

            class_count[predicted_class] += 1

    for count in class_count.values():
        if all(count >= max_images_per_class):
        break

print("Image classification and sorting completed.")

Processing image: box_dayClip1--00000.jpg
Processing image: box_dayClip1--00001.jpg
Processing image: box_dayClip1--00002.jpg
Ss
Processing image: box_dayClip1--00003.jpg
Ss
Processing image: box_dayClip1--00004.jpg
Processing image: box_dayClip1--00005.jpg
Ss
Processing image: box_dayClip1--00006.jpg
Processing image: box_dayClip1--00007.jpg
Processing image: box_dayClip1--00008.jpg
Processing image: box_dayClip1--00009.jpg
Processing image: box_dayClip1--00010.jpg
Processing image: box_dayClip1--00011.jpg
Processing image: box_dayClip1--00012.jpg
Ss
Processing image: box_dayClip1--00013.jpg
Ss
Processing image: box_dayClip1--00014.jpg
Processing image: box_dayClip1--00015.jpg
Processing image: box_dayClip1--00016.jpg
Ss
Processing image: box_dayClip1--00017.jpg
Ss
Processing image: box_dayClip1--00018.jpg
Ss
Processing image: box_dayClip1--00019.jpg
Ss
Processing image: box_dayClip1--00020.jpg
Processing image: box_dayClip1--00021.jpg
Processing image: box_dayClip1--00022.jpg
Ss
Ss
P

KeyboardInterrupt: 