In [None]:
# System libs
import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms
# Our libs
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import colorEncode

def get_background_classes():
    """Return the ADE20K classes that should be treated as background.

    Keys are class names. Values are class ids that callers compare against
    ``pred + 1`` and use as keys into the names table read from
    ``object150_info.csv``.
    """
    # Further candidates that could be enabled here (ids as listed in the
    # original notes): bed=8, cabinet=11, table=16, chair=20, painting=23,
    # sofa=24, shelf=25, mirror=28, armchair=31, seat=32, desk=34,
    # wardrobe=36, lamp=37. Rug's id was left unknown in the notes
    # (presumably 29 -- confirm in data/object150_info.csv).
    background = dict(
        wall=1,
        floor=4,
        ceiling=6,
        windowpane=9,
        door=15,
    )
    return background

def create_foreground_mask(pred, pred_classes, names, background_classes, max_report=15):
    """Create a binary mask marking every pixel that is not a background class.

    Args:
        pred: Integer array of 0-based per-pixel class indices.
        pred_classes: Class indices in the order they should be reported
            (the caller passes them sorted by descending pixel count). They
            only affect the printed summary, never the mask.
        names: Mapping from 1-based class id to display name.
        background_classes: Mapping whose values are the 1-based class ids
            to treat as background.
        max_report: Maximum number of classes present in ``pred`` to print.

    Returns:
        numpy.uint8 array with ``pred``'s shape: 1 = foreground, 0 = background.
    """
    # pred holds 0-based indices while the background ids are 1-based.
    background_ids = {int(v) - 1 for v in background_classes.values()}

    # Classify every pixel, not just the most frequent classes. Otherwise a
    # background class ranked outside the top-N would stay marked foreground.
    background_lookup = numpy.array(sorted(background_ids), dtype=numpy.int64)
    foreground_mask = (~numpy.isin(pred, background_lookup)).astype(numpy.uint8)

    # Print what we're finding
    print("\nDetected objects:")
    reported = 0
    for c in pred_classes:
        if reported >= max_report:
            break
        pixel_count = numpy.count_nonzero(pred == c)
        if pixel_count == 0:
            # Orderings derived from bincount include absent classes; skip them.
            continue
        reported += 1
        percentage = (pixel_count / pred.size) * 100
        class_name = names[c + 1] if c + 1 in names else f'Class {c}'
        type_label = "BACKGROUND" if int(c) in background_ids else "FOREGROUND"
        print(f"{class_name}: {percentage:.2f}% of image ({type_label})")

    return foreground_mask

# Checkpoint / metadata locations, resolved relative to the current working
# directory (same hard-coded paths as before).
_ENCODER_WEIGHTS = 'ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth'
_DECODER_WEIGHTS = 'ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth'
_CLASS_INFO_CSV = 'data/object150_info.csv'

# Fixed RGB tint and blend weight for the foreground overlay, so every image's
# visualization is reproducible and comparable across runs.
_OVERLAY_COLOR = (255, 0, 0)
_OVERLAY_ALPHA = 0.3

# Lazily populated caches: batch processing loads the network weights and the
# class-name table once instead of once per image.
_segmentation_module_cache = None
_class_names_cache = None


def _get_segmentation_module():
    """Return the ADE20K segmentation module in eval mode, building it on first use."""
    global _segmentation_module_cache
    if _segmentation_module_cache is None:
        print("Initializing model...")
        net_encoder = ModelBuilder.build_encoder(
            arch='resnet50dilated',
            fc_dim=2048,
            weights=_ENCODER_WEIGHTS)
        net_decoder = ModelBuilder.build_decoder(
            arch='ppm_deepsup',
            fc_dim=2048,
            num_class=150,
            weights=_DECODER_WEIGHTS,
            use_softmax=True)
        crit = torch.nn.NLLLoss(ignore_index=-1)
        module = SegmentationModule(net_encoder, net_decoder, crit)
        module.eval()
        _segmentation_module_cache = module
    return _segmentation_module_cache


def _load_class_names():
    """Return {int(column 0): first ';'-separated entry of column 5} from the class CSV, read once."""
    global _class_names_cache
    if _class_names_cache is None:
        names = {}
        # newline='' is the mode the csv module documents for reading files.
        with open(_CLASS_INFO_CSV, newline='') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            for row in reader:
                names[int(row[0])] = row[5].split(";")[0]
        _class_names_cache = names
    return _class_names_cache


def process_image(image_path, output_dir=None):
    """Segment one image and optionally save foreground-mask outputs.

    Args:
        image_path: Path of the image to segment (converted to RGB).
        output_dir: If given, a directory (created if missing) that receives
            ``<name>_foreground_mask.png`` (0/255 mask) and
            ``<name>_foreground_vis.png`` (tinted overlay). If None, nothing
            is written.

    Returns:
        (pred, img_original): the per-pixel argmax class indices as a numpy
        array, and the original image as a numpy array. ``pred`` is
        presumably (H, W) because segSize is set to the image's height and
        width -- confirm against mit_semseg.
    """
    segmentation_module = _get_segmentation_module()

    print("Processing image...")
    # Normalize with the ImageNet mean/std used by the original pipeline.
    pil_to_tensor = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])

    # The context manager releases the file handle once the image is converted.
    with PIL.Image.open(image_path) as raw_image:
        pil_image = raw_image.convert('RGB')
    print(f"Original image size: {pil_image.size}")
    img_original = numpy.array(pil_image)
    img_data = pil_to_tensor(pil_image)
    singleton_batch = {'img_data': img_data[None]}
    output_size = img_data.shape[1:]

    print("Running segmentation...")
    # Run the segmentation at the full input resolution.
    with torch.no_grad():
        scores = segmentation_module(singleton_batch, segSize=output_size)

    # Per-pixel argmax over the class dimension.
    _, pred = torch.max(scores, dim=1)
    pred = pred.cpu()[0].numpy()

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        base_filename = os.path.splitext(os.path.basename(image_path))[0]
        names = _load_class_names()

        # Present classes, ordered by descending pixel count. bincount also
        # yields zero counts for absent ids below the max, so drop those.
        counts = numpy.bincount(pred.ravel())
        predicted_classes = counts.argsort()[::-1]
        predicted_classes = predicted_classes[counts[predicted_classes] > 0]

        background_classes = get_background_classes()
        foreground_mask = create_foreground_mask(pred, predicted_classes, names, background_classes)

        # Save binary foreground mask (0 / 255)
        mask_path = os.path.join(output_dir, f"{base_filename}_foreground_mask.png")
        PIL.Image.fromarray(foreground_mask * 255).save(mask_path)

        # Save the visualization: tint foreground pixels with a fixed colour.
        vis_img = img_original.copy()
        is_fg = foreground_mask == 1
        tint = numpy.asarray(_OVERLAY_COLOR, dtype=numpy.float64)
        blended = vis_img[is_fg] * (1 - _OVERLAY_ALPHA) + tint * _OVERLAY_ALPHA
        vis_img[is_fg] = blended.astype(numpy.uint8)
        PIL.Image.fromarray(vis_img).save(os.path.join(output_dir, f"{base_filename}_foreground_vis.png"))

    return pred, img_original

if __name__ == "__main__":
    import traceback

    # Paths
    input_dir = r"C:\Users\ferran.marti\Desktop\MaskGenerator\selected_images_withFurniture"
    output_dir = r"C:\Users\ferran.marti\Desktop\MaskGenerator\automatic_masks"
    image_extensions = ('.png', '.jpg', '.jpeg')

    try:
        print("Note: Running on CPU. This might be slower than GPU execution.")

        # Compare extensions case-insensitively so files like "IMG_01.JPG" are
        # not silently skipped. Sort for a reproducible processing order.
        image_files = sorted(
            f for f in os.listdir(input_dir)
            if f.lower().endswith(image_extensions)
            and os.path.isfile(os.path.join(input_dir, f))
        )
        total_images = len(image_files)

        print(f"\nFound {total_images} images to process")

        failed = []
        for idx, image_file in enumerate(image_files, 1):
            image_path = os.path.join(input_dir, image_file)
            print(f"\nProcessing image {idx}/{total_images}: {image_file}")

            try:
                process_image(image_path, output_dir)
            except Exception as e:
                # Keep going with the remaining images, but keep the traceback.
                print(f"Error processing {image_file}: {str(e)}")
                traceback.print_exc()
                failed.append(image_file)
            else:
                print(f"Successfully processed {image_file}")

        if failed:
            print(f"\n{len(failed)} of {total_images} images failed: {', '.join(failed)}")
        print(f"\nAll processing complete. Results saved to: {output_dir}")

    except Exception as e:
        print(f"Error: {str(e)}")
        traceback.print_exc()