In [7]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
import torchvision

class CustomDataset(Dataset):
    def __init__(self, data_dir, image_files, annotations, transform=None):
        self.data_dir = data_dir
        self.image_files = image_files
        self.annotations = annotations
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.data_dir, 'images', self.image_files[idx])
        image = Image.open(image_path).convert('RGB')

        filename = self.image_files[idx]
        annotation = self.annotations.get(filename, {})

        if self.transform is not None:
            image = self.transform(image)

        return image, annotation

In [8]:
class ResizeWithAnnotations(T.Compose):
    def __call__(self, image, annotation):
        image = self.transforms[0](image)

        original_width, original_height = image.size
        bbox_coords = annotation["bbox"]
        scale_factor_x = image.width / original_width
        scale_factor_y = image.height / original_height
        new_bbox_coords = [coord * scale_factor_x if idx % 2 == 0 else coord * scale_factor_y for idx, coord in enumerate(bbox_coords)]
        annotation["bbox"] = new_bbox_coords

        return image, annotation
    
# Define the transformations for the dataset
transform = ResizeWithAnnotations([
    T.Resize((800, 800)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [9]:
annotations_dict = {
    "Screenshot_1.png": [912, 411, 1022, 543],
    "Screenshot_10.png": [180, 142, 408, 383],
    "Screenshot_100.png": [202, 226, 300, 348],
    "Screenshot_101.png": [266, 159, 369, 285],
    "Screenshot_102.png": [334, 275, 431, 411],
    "Screenshot_103.png": [386, 314, 501, 474],
    "Screenshot_104.png": [707, 373, 796, 522],
    "Screenshot_105.png": [852, 201, 950, 315],
    "Screenshot_106.png": [669, 258, 718, 330],
    "Screenshot_108.png": [534, 193, 626, 298],
    "Screenshot_109.png": [476, 205, 569, 321],
    "Screenshot_11.png": [185, 140, 1051, 756],
    "Screenshot_110.png": [255, 361, 360, 475],
    "Screenshot_111.png": [320, 373, 412, 479],
    "Screenshot_112.png": [406, 415, 503, 532],
    "Screenshot_12.png": [671, 284, 1036, 530],
    "Screenshot_13.png": [514, 113, 614, 523],
    "Screenshot_14.png": [421, 628, 584, 781],
    "Screenshot_15.png": [862, 270, 932, 354],
    "Screenshot_16.png": [112, 334, 211, 462],
    "Screenshot_17.png": [777, 95, 834, 174],
    "Screenshot_18.png": [581, 133, 632, 213],
    "Screenshot_19.png": [667, 77, 742, 143],
    "Screenshot_2.png": [798, 475, 871, 579],
    "Screenshot_20.png": [688, 424, 769, 511],
    "Screenshot_21.png": [755, 187, 810, 241],
    "Screenshot_22.png": [524, 171, 570, 221],
    "Screenshot_23.png": [405, 416, 472, 486],
    "Screenshot_24.png": [459, 211, 515, 284],
    "Screenshot_25.png": [1415, 397, 1475, 475],
    "Screenshot_26.png": [521, 290, 647, 468],
    "Screenshot_27.png": [523, 396, 625, 496],
    "Screenshot_28.png": [185, 198, 1268, 443],
    "Screenshot_29.png": [345, 104, 593, 433],
    "Screenshot_3.png": [288, 336, 406, 475],
    "Screenshot_30.png": [140, 145, 1797, 337],
    "Screenshot_31.png": [339, 252, 428, 352],
    "Screenshot_32.png": [922, 218, 999, 316],
    "Screenshot_33.png": [1138, 374, 1233, 490],
    "Screenshot_34.png": [1176, 263, 1235, 328],
    "Screenshot_35.png": [161, 198, 229, 260],
    "Screenshot_36.png": [197, 128, 258, 182],
    "Screenshot_37.png": [671, 172, 729, 231],
    "Screenshot_38.png": [314, 183, 1460, 382],
    "Screenshot_39.png": [226, 180, 1029, 705],
    "Screenshot_4.png": [383, 326, 447, 410],
    "Screenshot_40.png": [285, 238, 359, 307],
    "Screenshot_41.png": [328, 293, 412, 378],
    "Screenshot_42.png": [991, 375, 1092, 481],
    "Screenshot_43.png": [935, 217, 1003, 281],
    "Screenshot_44.png": [266, 143, 343, 198],
    "Screenshot_45.png": [483, 264, 550, 335],
    "Screenshot_46.png": [885, 255, 942, 314],
    "Screenshot_47.png": [1226, 125, 1282, 188],
    "Screenshot_48.png": [643, 417, 707, 499],
    "Screenshot_49.png": [421, 269, 479, 332],
    "Screenshot_5.png": [72, 301, 155, 405],
    "Screenshot_50.png": [604, 233, 666, 291],
    "Screenshot_51.png": [1305, 244, 1372, 317],
    "Screenshot_52.png": [543, 97, 588, 161],
    "Screenshot_53.png": [787, 324, 845, 414],
    "Screenshot_54.png": [613, 288, 737, 376],
    "Screenshot_55.png": [585, 359, 684, 481],
    "Screenshot_56.png": [977, 213, 1015, 271],
    "Screenshot_57.png": [215, 402, 254, 461],
    "Screenshot_58.png": [942, 113, 1001, 183],
    "Screenshot_59.png": [396, 287, 465, 354],
    "Screenshot_6.png": [122, 106, 844, 673],
    "Screenshot_60.png": [540, 236, 653, 337],
    "Screenshot_61.png": [896, 284, 1020, 437],
    "Screenshot_62.png": [761, 318, 995, 591],
    "Screenshot_63.png": [40, 203, 1397, 404],
    "Screenshot_64.png": [72, 150, 1464, 348],
    "Screenshot_65.png": [253, 221, 1333, 509],
    "Screenshot_66.png": [633, 114, 699, 171],
    "Screenshot_68.png": [659, 213, 715, 288],
    "Screenshot_69.png": [13, 443, 77, 515],
    "Screenshot_7.png": [246, 163, 1120, 555],
    "Screenshot_70.png": [1053, 216, 1102, 273],
    "Screenshot_71.png": [15, 154, 114, 231],
    "Screenshot_72.png": [584, 157, 695, 248],
    "Screenshot_73.png": [794, 489, 923, 691],
    "Screenshot_74.png": [966, 237, 1005, 288],
    "Screenshot_75.png": [1016, 278, 1074, 337],
    "Screenshot_76.png": [575, 267, 633, 331],
    "Screenshot_77.png": [1172, 329, 1255, 423],
    "Screenshot_78.png": [1114, 350, 1226, 482],
    "Screenshot_79.png": [127, 577, 306, 793],
    "Screenshot_8.png": [17, 324, 104, 420],
    "Screenshot_80.png": [968, 215, 1073, 305],
    "Screenshot_81.png": [1059, 213, 1163, 312],
    "Screenshot_82.png": [194, 198, 294, 553],
    "Screenshot_83.png": [300, 473, 422, 594],
    "Screenshot_84.png": [57, 398, 166, 508],
    "Screenshot_85.png": [254, 284, 339, 369],
    "Screenshot_86.png": [203, 505, 349, 637],
    "Screenshot_87.png": [263, 254, 401, 373],
    "Screenshot_88.png": [452, 309, 525, 383],
    "Screenshot_89.png": [207, 220, 308, 302],
    "Screenshot_9.png": [45, 149, 1140, 450],
    "Screenshot_90.png": [139, 381, 232, 468],
    "Screenshot_91.png": [630, 170, 760, 295],
    "Screenshot_92.png": [971, 269, 1122, 409],
    "Screenshot_93.png": [856, 427, 990, 594],
    "Screenshot_94.png": [275, 311, 379, 457],
    "Screenshot_95.png": [514, 169, 637, 291],
    "Screenshot_96.png": [208, 158, 323, 296],
    "Screenshot_97.png": [490, 173, 631, 290],
    "Screenshot_98.png": [644, 266, 764, 414],
    "Screenshot_99.png": [499, 109, 618, 208]
}

In [10]:
# Create the dataset
data_dir = 'dataset'
dataset = CustomDataset(data_dir, annotations_dict, transform=transform)

In [11]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

def collate_fn(batch):
    images = [img.to(device) for img, _ in batch]
    annotations_list = [annotations for _, annotations in batch]

    return images, annotations_list

# Create the data loader with the custom collate function
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# Load the pre-trained ResNet backbone
backbone = torchvision.models.detection.backbone_utils.resnet_fpn_backbone('resnet50', pretrained=True)

# Define the RPN anchor generator
rpn_anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                       aspect_ratios=((0.5, 1.0, 2.0),) * 5)

# Define the number of classes (including the background class)
num_classes = 4 + 1  # +1 for the background class

# Create the Faster R-CNN model
model = FasterRCNN(backbone,
                   num_classes=num_classes,  # Number of object classes + background
                   rpn_anchor_generator=rpn_anchor_generator)

# Move the model to GPU if available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Define optimizer and learning rate scheduler
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # You can adjust the learning rate and momentum


In [12]:
num_epochs = 10  # Define the number of training epochs
for epoch in range(num_epochs):
    model.train()
    for images, annotations_list in dataloader:  # Iterate over your dataset batches
        images = [transform(img).to(device) for img in images]
        
        processed_annotations = []
        for annotations_dict in annotations_list:
            for bbox_coords in annotations_dict["bbox"]:
                x_min, y_min, x_max, y_max = bbox_coords
                label = 1  # Assuming a single class for demonstration
                
                processed_annotation = {
                    "boxes": torch.tensor([[x_min, y_min, x_max, y_max]], dtype=torch.float32),
                    "labels": torch.tensor([label], dtype=torch.int64)
                }
                processed_annotations.append(processed_annotation)
        
        targets = processed_annotations
        
        # Forward pass and calculate losses
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        
        # Backpropagation and optimization
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

TypeError: __call__() missing 1 required positional argument: 'annotation'

In [None]:

# Save the trained model
torch.save(model.state_dict(), 'object_detection_model.pth')

In [None]:
# Set the model to evaluation mode
model.eval()

# Define a transformation for the test images (similar to the training transform)
test_transform = T.Compose([
    T.Resize((800, 800)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create a dataset for the test images
test_dataset = CustomDataset('dataset', transform=test_transform)

# Create a data loader for the test images
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# List to store results
all_predictions = []

# Iterate over test images and perform inference
with torch.no_grad():
    for images, _ in test_dataloader:
        images = list(img.to(device) for img in images)

        # Perform object detection
        predictions = model(images)

        # Store the predictions
        all_predictions.append(predictions)

# Process the predictions as needed
for predictions in all_predictions:
    # Process the prediction to get the detected objects and their bounding boxes
    # You can access predictions[0]['boxes'], predictions[0]['labels'], predictions[0]['scores']
    boxes = predictions[0]['boxes'].cpu().numpy()
    labels = predictions[0]['labels'].cpu().numpy()
    scores = predictions[0]['scores'].cpu().numpy()

    # Process and visualize the results
    for box, label, score in zip(boxes, labels, scores):
        # Process the bounding box, label, and score as needed
        print("Label:", label, "Score:", score, "Box:", box)