<a href="https://colab.research.google.com/github/harsh1532/Remove_Sky/blob/main/Remove_Sky.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import warnings

import cv2
import matplotlib.pyplot as plt
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer

def remove_sky_from_dataset(
    dataset_name,
    output_dataset_name,
    *,
    json_file="path_to_annotation_json",
    image_root="path_to_images_folder",
    weights="path_to_pretrained_model_weights",
    output_dir="path_to_save_output_images",
    score_thresh=0.5,
    show=True,
):
    """Mask each image of a COCO-format dataset down to its top-scoring instance.

    Every image is run through a Mask R-CNN (R50-FPN, 3x schedule) instance
    segmentation model. Pixels outside the mask of the highest-confidence
    detection are set to 0 (black). The result is written under ``output_dir``.

    NOTE(review): despite the name, nothing here detects sky specifically. The
    COCO instance-segmentation label set has no "sky" class, so sky is removed
    only insofar as it lies outside the top-scoring object's mask.

    Args:
        dataset_name: Name under which the COCO dataset is registered with
            detectron2. It is registered from ``json_file``/``image_root`` on
            first use and reused if already registered (e.g. notebook re-runs).
        output_dataset_name: Intended name of the output dataset. Currently
            unused. TODO: write an annotation file for the masked images.
        json_file: Path to the COCO annotation JSON.
        image_root: Directory that the JSON's image file names are joined to.
        weights: Checkpoint path (or URL) assigned to ``cfg.MODEL.WEIGHTS``.
        output_dir: Directory that masked images are written to; created if
            missing. Each image keeps its path relative to ``image_root``.
        score_thresh: Minimum detection score kept by the predictor.
        show: If True, display each masked image with predictions overlaid.

    Raises:
        OSError: If a masked image cannot be written.
    """
    # Re-registering an existing name raises, which breaks notebook re-runs.
    if dataset_name not in DatasetCatalog.list():
        register_coco_instances(dataset_name, {}, json_file, image_root)
    dataset_metadata = MetadataCatalog.get(dataset_name)

    cfg = get_cfg()
    # Resolve the YAML from detectron2's bundled model zoo. The bare relative
    # path only works when the CWD happens to be detectron2's configs directory.
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    )
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.WEIGHTS = weights
    predictor = DefaultPredictor(cfg)

    os.makedirs(output_dir, exist_ok=True)

    # DatasetCatalog (not MetadataCatalog) returns the list of per-image dicts.
    for record in DatasetCatalog.get(dataset_name):
        image_path = record["file_name"]
        image = cv2.imread(image_path)  # BGR, as DefaultPredictor expects by default
        if image is None:  # cv2.imread reports failure by returning None, not raising
            warnings.warn(f"Skipping unreadable image: {image_path}")
            continue

        instances = predictor(image)["instances"].to("cpu")

        image_without_sky = image.copy()
        if len(instances) > 0:
            # Keep only the highest-*confidence* instance (not the spatially highest).
            best = instances.scores.argmax().item()
            keep_mask = instances.pred_masks[best].numpy().astype(bool)
            image_without_sky[~keep_mask] = 0
        else:
            # argmax on an empty tensor raises; with nothing detected, save as-is.
            warnings.warn(f"No detections in {image_path}; saving it unchanged")

        if show:
            # Visualizer takes and returns RGB, and plt.imshow expects RGB, so
            # only the BGR input from OpenCV needs flipping.
            # NOTE(review): class names come from this dataset's metadata; if the
            # weights use a different label set, the drawn labels will be wrong.
            v = Visualizer(image_without_sky[:, :, ::-1], metadata=dataset_metadata, scale=1.2)
            out = v.draw_instance_predictions(instances)
            plt.imshow(out.get_image())
            plt.show()

        # One output file per input, preserving any sub-directories under image_root.
        rel_path = os.path.relpath(image_path, image_root)
        if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
            rel_path = os.path.basename(image_path)
        output_image_path = os.path.join(output_dir, rel_path)
        os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
        # imwrite picks the encoder from the extension and returns False on failure.
        if not cv2.imwrite(output_image_path, image_without_sky):
            raise OSError(f"Failed to write masked image to {output_image_path}")

    # TODO: write an annotation file for output_dataset_name.

# Guarded so importing this module doesn't register datasets or run inference.
# In a notebook cell __name__ is "__main__", so the call still runs there.
if __name__ == "__main__":
    remove_sky_from_dataset("original_dataset", "output_dataset")


ModuleNotFoundError: ignored