In [None]:
import os
import numpy as np
import torch
from PIL import Image
import xml.etree.ElementTree as ET
from torchvision import transforms
from tqdm import tqdm 

def create_bbox_coords(bbox):
    """Read the corner coordinates stored in a bounding-box XML element.

    Args:
        bbox: Element whose ``find`` method returns the ``xmin``, ``ymin``,
            ``xmax`` and ``ymax`` children; the text of each child must parse
            as a float.

    Returns:
        list[float]: ``[xmin, ymin, xmax, ymax]``.
    """
    corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    return [float(bbox.find(tag).text) for tag in corner_tags]

def create_mask(plasmodium_img, bbox):
    """Build a binary mask that is 1 inside the bounding box and 0 elsewhere.

    Args:
        plasmodium_img: Image-like object whose ``size`` is ``(width, height)``.
        bbox: Bounding-box XML element accepted by ``create_bbox_coords``.

    Returns:
        numpy.ndarray: ``uint8`` array of shape ``(height, width)``.
    """
    xmin, ymin, xmax, ymax = create_bbox_coords(bbox)
    width, height = plasmodium_img.size[0], plasmodium_img.size[1]
    mask = np.zeros((height, width), dtype=np.uint8)

    # A negative coordinate would be read by NumPy as "count from the end" and
    # silently select the wrong region, so clamp the bounds at 0. Values past
    # the image edge need no handling: slicing already clips them.
    left, top, right, bottom = (max(int(value), 0) for value in (xmin, ymin, xmax, ymax))
    mask[top:bottom, left:right] = 1
    return mask

    
class MalariaPlasmodiumDataset(torch.utils.data.Dataset):
    """Detection dataset of images that contain at least one annotated plasmodium.

    Every ``.jpg`` file in ``directory_root`` is expected to have an XML
    annotation file with the same base name next to it. Only images whose
    annotation holds at least one ``object`` element are kept.
    """

    def __init__(self, directory_root, images_transforms=None):
        """Index the images in ``directory_root`` that have annotated objects.

        Args:
            directory_root: Directory holding the ``.jpg`` images and their
                ``.xml`` annotation files.
            images_transforms: Optional callable applied to the RGB image
                returned by ``__getitem__``.

        Raises:
            OSError: If the directory or an annotation file cannot be read.
            xml.etree.ElementTree.ParseError: If an annotation is malformed.
        """
        # Remember where the files live and how to transform the loaded images.
        self.directory_root = directory_root
        self.images_transforms = images_transforms

        # Sorted so that an index always maps to the same image.
        all_image_files = sorted(
            name for name in os.listdir(directory_root) if name.endswith(".jpg")
        )

        # Some images show no plasmodium at all. A detection is recorded as an
        # "object" element in the XML, so keep only the images that have one.
        self.imgs_with_plasmodium = []
        for img_file in all_image_files:
            xml_file = self._annotation_path(os.path.join(self.directory_root, img_file))
            if ET.parse(xml_file).findall('object'):
                self.imgs_with_plasmodium.append(img_file)

    @staticmethod
    def _annotation_path(image_path):
        """Return the path of the XML annotation stored next to ``image_path``."""
        # Swap only the final extension: str.replace would also rewrite any
        # earlier ".jpg" that happens to occur in a directory or file name.
        return os.path.splitext(image_path)[0] + ".xml"

    def __getitem__(self, idx: int):
        """Load one image together with its detection target.

        Args:
            idx: Position in ``self.imgs_with_plasmodium``.

        Returns:
            tuple: ``(image, target)``. ``image`` is the RGB image after
            ``images_transforms`` (or the PIL image when no transform is set).
            ``target`` is a dict with ``boxes`` (float32, one
            ``[xmin, ymin, xmax, ymax]`` row per object), ``labels`` (int64,
            all ones), ``image_id`` (tensor holding ``idx``), ``area`` (box
            areas), ``iscrowd`` (int64 zeros) and ``masks`` (uint8, one
            image-sized mask per object).
        """
        single_plasmodium_img_path = self.get_single_plasmodium_path(idx)
        single_annotation_file_path = self._annotation_path(single_plasmodium_img_path)

        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixels have been decoded.
        with Image.open(single_plasmodium_img_path) as raw_img:
            plasmodium_img = raw_img.convert("RGB")

        annotations = ET.parse(single_annotation_file_path)
        boxes = []
        masks = []
        for detected_plasmodium in annotations.findall('object'):
            bbox = detected_plasmodium.find('bndbox')
            boxes.append(create_bbox_coords(bbox))
            masks.append(create_mask(plasmodium_img, bbox))

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # Stack into a single array first; one conversion instead of one per mask.
        masks = torch.as_tensor(np.array(masks), dtype=torch.uint8)

        object_count = len(boxes)
        image_id = torch.tensor([idx])
        # Single foreground class: every object gets label 1.
        labels = torch.ones((object_count,), dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((object_count,), dtype=torch.int64)

        if self.images_transforms is not None:
            transformed_plasmodium_img = self.images_transforms(plasmodium_img)
        else:
            transformed_plasmodium_img = plasmodium_img

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
            "masks": masks,
        }
        return transformed_plasmodium_img, target

    def get_single_plasmodium_path(self, idx):
        """Return the full path of the image stored at position ``idx``."""
        return os.path.join(self.directory_root, self.imgs_with_plasmodium[idx])

    def __len__(self):
        """Return the number of images that contain at least one object."""
        return len(self.imgs_with_plasmodium)

In [None]:
import torchvision.transforms as T
from PIL import Image, ImageDraw

def draw_bounding_boxes(image_path, bboxes, scores=None, color=(255, 0, 0), return_pt = False):
    """Draw box outlines on top of an image loaded from disk.

    Args:
        image_path: Path of the image to open.
        bboxes: Iterable of ``(xmin, ymin, xmax, ymax)`` boxes.
        scores: Optional per-box values indexed in step with ``bboxes``; each is
            scaled by 255 and used as the outline's alpha channel. Without
            scores every outline is fully opaque.
        color: RGB tuple used for the outlines.
        return_pt: When true, return the result converted with ``T.ToTensor()``
            instead of a PIL image.

    Returns:
        The composited RGB image, as a PIL image or as a tensor.
    """
    base = Image.open(image_path).convert("RGBA")
    # Outlines go on a transparent layer so their alpha survives compositing.
    overlay = Image.new('RGBA', base.size, (255, 255, 255, 0))
    pen = ImageDraw.Draw(overlay)

    for position, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
        opacity = 255 if scores is None else int(255 * scores[position])
        pen.rectangle([xmin, ymin, xmax, ymax], outline=color + (opacity,), width=2)

    composed = Image.alpha_composite(base, overlay).convert("RGB")
    if return_pt:
        return T.ToTensor()(composed)
    return composed


In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Mask R-CNN whose box and mask heads are replaced to predict two classes:
# background and plasmodium.
num_classes = 2
hidden_layer = 256

model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2()

# Box head: keep the input width of the stock classifier, change the class count.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=in_features,
    num_classes=num_classes,
)

# Mask head: same idea, with `hidden_layer` channels in the intermediate layer.
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
model.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_channels=in_features_mask,
    dim_reduced=hidden_layer,
    num_classes=num_classes,
)


In [None]:
# `device` is used by the evaluation cells below, so define it here, before the
# weights are loaded: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (recent PyTorch versions offer weights_only=True for this).
state_dict = torch.load("best_model.pth", map_location=device)
model.load_state_dict(state_dict)

# Inputs are moved to `device` during evaluation, so the model must live there too.
model.to(device)
model.eval();

In [None]:
def _collate_detection_batch(batch):
    """Turn ``[(image, target), ...]`` into ``((image, ...), (target, ...))``."""
    return tuple(zip(*batch))


# Convert the PIL image to a tensor first, then normalise each channel with the
# ImageNet statistics.
# https://stackoverflow.com/questions/58151507/why-pytorch-officially-use-mean-0-485-0-456-0-406-and-std-0-229-0-224-0-2
# NOTE(review): torchvision detection models also normalise their inputs
# internally - confirm that training used this same Normalize step.
images_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

dataset = MalariaPlasmodiumDataset(
    "../plasmodium-phonecamera/test/",
    images_transforms=images_transforms,
)

# Images are kept as a tuple instead of being stacked into a single tensor.
test_data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=1,
    collate_fn=_collate_detection_batch,
)

In [None]:
from torchmetrics.detection.mean_ap import MeanAveragePrecision
# Mean average precision over bounding boxes at a single IoU threshold of 0.5.
test_metric = MeanAveragePrecision(iou_type="bbox", iou_thresholds = [0.5])

# Feed the data to whichever device the model's weights live on, so this cell
# does not depend on a `device` variable having been defined elsewhere.
device = next(model.parameters()).device

with torch.no_grad():
    for images, targets in tqdm(test_data_loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        output = model(images)
        test_metric.update(output, targets)
        # Release cached GPU blocks between images.
        torch.cuda.empty_cache()
print(test_metric.compute()['map'])

In [None]:
import random 
# Run the model on one random test image and collect what is needed to draw
# the ground-truth and predicted boxes.
with torch.no_grad():
    # Use the device the model's weights live on, so this cell does not depend
    # on a `device` variable having been defined elsewhere.
    device = next(model.parameters()).device

    image, target = random.choice(dataset)
    target = {k: v.to(device) for k, v in target.items()}
    output = model([image.to(device)])

    # Zero the score of every detection below 0.9 confidence.
    output[0]['scores'][output[0]['scores'] < 0.9] = 0

    # The sample is deliberately not passed to test_metric.update(): it is
    # already part of the test set, and adding it again (with altered scores)
    # would change the accumulated metric on every run of this cell.
    torch.cuda.empty_cache()

    bboxes_true = target['boxes']
    bboxes_predicted = output[0]['boxes']
    scores = output[0]['scores']
    img_id = target['image_id']
    # image_id is a one-element tensor; turn it into a plain int before it is
    # used as a list index.
    img = dataset.get_single_plasmodium_path(img_id.item())
    

