In [None]:
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import cv2
import os
import json
import numpy as np
import random
from tqdm.notebook import tqdm
from torchvision.transforms import functional as FN
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from torchvision import models
from torchsummary import summary
from torch.utils.data._utils.collate import default_collate
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import Subset

from time import time
from IPython.display import clear_output

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Preprocessing applied to every cropped box: resize to a fixed 256x256, then
# convert the PIL image to a float tensor.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

In [None]:
# Rebuild the fine-tuned classifier: a ResNet-50 backbone whose final fully
# connected layer is replaced by a 2-way head, then load the saved weights.
weights_path = "whatever.pt"

model_ft = models.resnet50()
num_ftrs = model_ft.fc.in_features  # input width of the stock ResNet-50 head
model_ft.fc = nn.Linear(num_ftrs, 2)  # 2 classes; class 1 is presumably "pothole" — confirm against training code

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints
# (torch >= 1.13 also accepts weights_only=True to restrict unpickling to tensors).
state_dict = torch.load(weights_path, map_location=device)
model_ft.load_state_dict(state_dict)
model_ft.to(device)
# This notebook only runs inference: switch layers such as BatchNorm to evaluation mode.
model_ft.eval()

In [None]:
class PotholeDataset2(Dataset):
    """Dataset of box crops read from a JSON annotation file.

    The JSON file must contain a list of records. Each record has an ``"image"``
    path, an optional ``"subset"`` label, and a ``"boxes"`` list whose entries
    each have a ``"box"`` of ``[left, top, width, height]``. That layout follows
    from the ``FN.crop`` argument order in ``prepare_dataset``.

    All crops are extracted, and transformed, eagerly at construction time, so
    the whole selected subset is held in memory.

    Args:
        json_file: Path to the JSON annotation file.
        transform: Optional callable applied to each cropped RGB PIL image.
        subset: If not None, only records whose ``"subset"`` equals this value are used.

    Attributes:
        cropped_data: List of ``(crop, image_path, box)`` tuples, in file order.
        original_images: Dict mapping each used image path to its RGB PIL image.
    """

    def __init__(self, json_file, transform=None, subset=None):
        with open(json_file, 'r') as f:
            self.data = json.load(f)

        self.transform = transform
        self.subset = subset
        self.cropped_data = []
        self.original_images = {}
        self.prepare_dataset()

    def prepare_dataset(self):
        """Fill ``cropped_data`` and ``original_images`` from ``self.data``."""
        for item in self.data:
            if self.subset is not None and item.get('subset') != self.subset:
                continue
            image_path = item['image']
            # Open under a context manager so the file handle is released
            # deterministically; convert() returns an independent in-memory image.
            with Image.open(image_path) as source:
                image = source.convert('RGB')
            self.original_images[image_path] = image
            for box_info in item['boxes']:
                box = box_info['box']
                # FN.crop takes (top, left, height, width); box is [left, top, width, height].
                cropped_image = FN.crop(image, box[1], box[0], box[3], box[2])
                if self.transform:
                    cropped_image = self.transform(cropped_image)
                self.cropped_data.append((cropped_image, image_path, box))

    def __len__(self):
        """Number of box crops in the dataset."""
        return len(self.cropped_data)

    def __getitem__(self, idx):
        """Return ``(crop, image_path, box)`` for the ``idx``-th crop."""
        cropped_image, image_path, box = self.cropped_data[idx]
        return cropped_image, image_path, box

In [None]:
# Crops from the 'test' split. Each crop has already been passed through
# `transform` when the dataset was built.
# NOTE: picture_loader is not consumed anywhere else in this notebook.
picture_boxes = PotholeDataset2(
    json_file='processed_images_data.json',
    subset='test',
    transform=transform,
)
picture_loader = DataLoader(
    dataset=picture_boxes,
    batch_size=1,
    shuffle=False,
    num_workers=3,
)

In [None]:

def display_image_with_boxes(image_path, boxes_predictions, edgecolor='g', linewidth=2):
    """Show an image with an unfilled rectangle drawn over each given box.

    Args:
        image_path: Path of the image file to display. Its file name is used as the title.
        boxes_predictions: Iterable of boxes, each ``(xmin, ymin, width, height)``
            in the image's pixel coordinates.
        edgecolor: Matplotlib colour for the rectangle outlines (default green).
        linewidth: Outline width passed to ``Rectangle`` (default 2).
    """
    # Keep the file open only while imshow reads it, so the handle is always released.
    with Image.open(image_path) as image:
        fig, ax = plt.subplots(1)
        ax.imshow(image)
    ax.set_title(os.path.basename(image_path))
    for xmin, ymin, width, height in boxes_predictions:
        rect = patches.Rectangle((xmin, ymin), width, height,
                                 linewidth=linewidth, edgecolor=edgecolor, facecolor='none')
        ax.add_patch(rect)
    plt.show()
    # Free the figure; callers draw one figure per image in a loop.
    plt.close(fig)



In [None]:
# Visualise classifier output on the test split. For each test image, score every
# box crop and draw the boxes predicted as class 1 with high confidence.
PROB_THRESHOLD = 0.7     # minimum softmax probability of class 1 for a box to be drawn
MAX_IMAGES_TO_SHOW = 6   # stop after this many images with at least one drawn box
INFER_BATCH_SIZE = 64    # crops per forward pass; lower it if the device runs out of memory

# Reuse the crops PotholeDataset2 already extracted and transformed, instead of
# re-reading the JSON and re-cropping every box. Group them by source image,
# keeping dataset order (dicts preserve insertion order).
crops_by_image = {}
for crop, image_path, box in picture_boxes.cropped_data:
    crops_by_image.setdefault(image_path, []).append((crop, box))

model_ft.eval()  # inference: BatchNorm must use its running statistics
images_shown = 0
with torch.no_grad():  # no autograd graph is needed for inference
    for image_path, crops in crops_by_image.items():
        boxes_predictions = []
        for start in range(0, len(crops), INFER_BATCH_SIZE):
            chunk = crops[start:start + INFER_BATCH_SIZE]
            batch = torch.stack([c for c, _ in chunk]).to(device)
            # The head has two outputs, so P(class 1) > 0.7 already implies
            # argmax == 1; a separate argmax check would be redundant.
            probabilities = torch.softmax(model_ft(batch), dim=1)[:, 1].tolist()
            boxes_predictions.extend(
                b for (_, b), p in zip(chunk, probabilities) if p > PROB_THRESHOLD
            )

        if boxes_predictions:  # only show images with at least one confident box
            print(image_path)
            display_image_with_boxes(image_path, boxes_predictions)
            images_shown += 1
            if images_shown >= MAX_IMAGES_TO_SHOW:
                break
