In [1]:
import json
import os
import warnings

import albumentations as A
import numpy as np
import torch
import torchvision
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from skimage import io, exposure
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_valid_transform():
    """Build the validation-time albumentations pipeline.

    Resizes the image to 512x512 and then converts it to a tensor with ``ToTensorV2``.
    """
    pipeline_steps = [
        A.Resize(512, 512),
        ToTensorV2(),
    ]
    return A.Compose(pipeline_steps)


In [3]:
class LungsAnnotationDataset(Dataset):
    """Inference-only dataset yielding ``(image_tensor, image_id)`` pairs from a flat directory.

    Every regular, non-hidden file directly inside ``image_dir`` is treated as an image,
    in sorted filename order. Each image is converted to RGB, resized to 512x512 with
    PIL, scaled to float32 in [0, 1] and returned channels-first as ``(3, 512, 512)``.

    Args:
        image_dir: Directory containing the images to score.
        transforms: Optional albumentations-style callable, invoked as
            ``transforms(image=hwc_float_array)["image"]`` after resizing/scaling.
            If it returns a ``torch.Tensor`` that tensor is returned as-is; otherwise
            the result is treated as an HWC array and transposed to CHW.
    """

    def __init__(self, image_dir, transforms=None):
        super().__init__()
        self.image_dir = image_dir
        # Sorted so the output order is reproducible (os.listdir order is arbitrary).
        # Sub-directories (e.g. .ipynb_checkpoints) and hidden files (e.g. .DS_Store)
        # are skipped because PIL cannot open them as images.
        self.image_ids = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(image_dir, name))
        )
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Return ``(image, image_id)`` for the ``index``-th file."""
        image_id = self.image_ids[index]
        image_path = os.path.join(self.image_dir, image_id)

        # The context manager closes the file handle; resize() forces the pixel
        # data to load, so `resized` is independent of the closed source image.
        with Image.open(image_path) as src:
            rgb = src if src.mode == 'RGB' else src.convert('RGB')
            resized = rgb.resize((512, 512))

        # PIL 'RGB' mode is 8 bits per channel, so /255 maps values into [0, 1].
        image = np.asarray(resized, dtype=np.float32) / 255.0  # (H, W, 3)

        if self.transforms is not None:
            image = self.transforms(image=image)['image']

        if not isinstance(image, torch.Tensor):
            # HWC -> CHW, the layout torchvision detection models consume.
            image = torch.from_numpy(np.ascontiguousarray(np.transpose(image, (2, 0, 1))))

        return image, image_id  # Return image and image_id

    def __len__(self):
        return len(self.image_ids)

In [4]:
# Directory of test images to run inference on.
VAL_IMAGE_DIR = 'test_resized/'
val_dataset = LungsAnnotationDataset(
    image_dir=VAL_IMAGE_DIR,
    transforms=None,
)

In [5]:
def collate_fn(batch):
    """Stack a batch of ``(image, image_id)`` pairs into one tensor plus a tuple of ids.

    Images already laid out channels-first ``(3, H, W)`` are kept as-is. 3-D images
    whose leading axis is not 3 but whose last axis is 3 (channels-last
    ``(H, W, 3)``) are permuted to ``(3, H, W)``. All images must end up with the
    same shape for ``torch.stack``.

    Returns:
        ``(images, image_ids)`` where ``images`` has shape ``(B, 3, H, W)`` and
        ``image_ids`` is a tuple in batch order.
    """
    images, image_ids = zip(*batch)
    # Moving the trailing channel axis to the front needs permute(2, 0, 1);
    # permute(1, 2, 0) would turn (H, W, 3) into (W, 3, H).
    images = [
        image.permute(2, 0, 1)
        if image.ndim == 3 and image.shape[0] != 3 and image.shape[-1] == 3
        else image
        for image in images
    ]
    return torch.stack(images), image_ids

In [6]:
# Mapping from predicted label id to the human-readable finding name, ids 0..13
# in list order.
# NOTE(review): torchvision detection models reserve label 0 for background and
# never emit it, so with ids starting at 0 'Aortic enlargement' can never be
# predicted, and a label of 14 (possible if the head has 15 outputs) falls back to
# 'Unknown' downstream. Confirm the training label encoding (e.g. class_id + 1)
# before trusting these names.
class_brands = dict(enumerate([
    'Aortic enlargement',
    'Atelectasis',
    'Calcification',
    'Cardiomegaly',
    'Consolidation',
    'ILD',
    'Infiltration',
    'Lung Opacity',
    'Nodule/Mass',
    'Other lesion',
    'Pleural effusion',
    'Pleural thickening',
    'Pneumothorax',
    'Pulmonary fibrosis',
]))

In [7]:
# Batches of 5 in dataset order. The collate function transposes
# [(img, id), ...] into [(img, ...), (id, ...)] rather than stacking, so the
# model later receives a list of per-image tensors.
val_data_loader = DataLoader(
    val_dataset,
    batch_size=5,
    shuffle=False,
    collate_fn=lambda batch: list(zip(*batch)),
)

In [8]:
# Detector head size: one slot per entry in `class_brands` plus index 0, which
# torchvision's Faster R-CNN counts as the background class (14 + 1 = 15).
# Deriving it keeps the head size in sync with the class mapping.
num_classes = len(class_brands) + 1

In [9]:
# Base architecture: torchvision Faster R-CNN with a ResNet-50 FPN backbone,
# built with its pretrained weights.
# NOTE(review): `pretrained=` is deprecated from torchvision 0.13 in favour of
# `weights=`; it is kept here because the installed torchvision version is unknown.
# The fine-tuned checkpoint loaded later reports no missing keys, so its weights
# replace these. The pretrained flag may still matter for building the same layer
# types the checkpoint was trained with (e.g. frozen batch-norm) — confirm before
# dropping it.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Swap in a fresh box predictor whose classifier emits `num_classes` scores,
# reusing the feature width of the existing classification layer.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_features,
    num_classes,
)



In [10]:
def set_device():
    """Return the device string used for inference.

    Always ``"cpu"``; no accelerator detection is performed.
    """
    chosen = "cpu"
    return chosen


device = set_device()

In [11]:
# Load the fine-tuned weights. `strict=False` tolerates key mismatches, so any
# mismatch is surfaced as an explicit warning instead of only the echoed return value.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
# NOTE(review): a previous run reported unexpected keys 'roi_heads.box_predictor.fc.*'.
# The checkpoint's box predictor had an extra `fc` layer that FastRCNNPredictor lacks,
# and those weights are discarded here. If the training-time predictor applied `fc` in
# its forward pass, inference here will not match training — confirm the predictor
# architecture.
load_result = model.load_state_dict(
    torch.load('x_ray_models/model_fasterRCNN_finetuned.pth', map_location=device),
    strict=False,
)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        f"Checkpoint/model mismatch: missing_keys={load_result.missing_keys}, "
        f"unexpected_keys={load_result.unexpected_keys}"
    )
load_result


_IncompatibleKeys(missing_keys=[], unexpected_keys=['roi_heads.box_predictor.fc.weight', 'roi_heads.box_predictor.fc.bias'])

In [12]:
def _detections_to_records(output, class_brands, score_threshold):
    """Turn one detection dict (``boxes``/``labels``/``scores``) into JSON-ready records.

    Keeps detections whose score is ``>= score_threshold``. Boxes are truncated to
    ints via ``astype(int)``. Labels are mapped through ``class_brands``, falling back
    to ``'Unknown'`` for ids missing from the mapping.
    """
    keep = output['scores'] >= score_threshold
    boxes = output['boxes'][keep].detach().cpu().numpy().astype(int).tolist()
    labels = output['labels'][keep].detach().cpu().tolist()
    return [
        {'box': box, 'class_label': class_brands.get(label, 'Unknown')}
        for box, label in zip(boxes, labels)
    ]


def save_predictions_to_json(model, val_data_loader, device, class_brands,
                             filename='phase1_results.json', score_threshold=0.6):
    """Run ``model`` over ``val_data_loader`` and write thresholded detections to JSON.

    Side effects: moves ``model`` to ``device``, switches it to eval mode, and
    (over)writes ``filename``.

    Args:
        model: torchvision-style detector. It is called on a list of CHW image
            tensors and must return, per image, a dict with ``'boxes'``,
            ``'labels'`` and ``'scores'`` tensors.
        val_data_loader: Iterable of ``(images, image_ids)`` batches. Every image
            must be a ``(3, 512, 512)`` tensor (checked with ``assert``).
        device: Device the model and images are moved to.
        class_brands: Mapping from predicted label id to class name.
        filename: Output JSON path.
        score_threshold: Minimum score (inclusive) for a detection to be kept.

    The JSON maps each image id to a list of
    ``{'box': [x1, y1, x2, y2], 'class_label': name}`` records. Box coordinates
    are in the pixel space of the 512x512 tensors fed to the model.

    NOTE(review): torchvision detectors never emit label 0 (background). If
    ``class_brands`` is keyed from 0 while training used ``class_id + 1``, the
    names written here are shifted by one — confirm the label encoding.
    """
    results = {}

    model.to(device)  # Ensure the model is on the same device as the inputs
    model.eval()
    with torch.no_grad():
        for images, image_ids in val_data_loader:
            images = [image.to(device) for image in images]

            # Validate the input layout before running the detector.
            for image in images:
                assert image.shape[0] == 3, f"Expected 3 channels, but got {image.shape[0]} channels"  # Confirm 3 channels
                assert image.shape[1] == 512 and image.shape[2] == 512, f"Expected (512, 512), but got {(image.shape[1], image.shape[2])}"

            outputs = model(images)

            # The detector returns one output dict per input image, in order.
            for image_id, output in zip(image_ids, outputs):
                results[image_id] = _detections_to_records(output, class_brands, score_threshold)

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    print(f"Results saved to {filename}")


In [13]:
save_predictions_to_json(
    model,
    val_data_loader,
    device,
    class_brands,
    filename='phase1_results.json',
)

Results saved to phase1_results.json
