In [None]:
import os
import shutil

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Path to the pretrained SAM weights and the registry key of the matching architecture.
sam_checkpoint = "/kaggle/input/sam-b/other/default/1/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the architecture without weights, then load the checkpoint explicitly.
# map_location keeps the load working on a CPU-only session even if the
# checkpoint was serialised with CUDA tensors; weights_only=True is kept so the
# file is loaded through PyTorch's restricted unpickler.
sam = sam_model_registry[model_type](checkpoint=None).to(device)
state_dict = torch.load(sam_checkpoint, map_location=device, weights_only=True)
sam.load_state_dict(state_dict)

In [None]:
def checkcolour(masks, hsv):
    """Flag the masks whose mean colour falls inside the green HSV band.

    Parameters
    ----------
    masks : sequence of dict
        Each entry provides a boolean array under ``'segmentation'`` that is
        used to index the first two axes of ``hsv``.
    hsv : numpy.ndarray
        Image whose last axis holds three channels read here as
        (hue, saturation, value).

    Returns
    -------
    numpy.ndarray of bool
        One entry per mask. True where 35 < mean hue < 75 and mean
        saturation > 35. If no mask passes, the upper hue bound is relaxed
        to 100 (the "grow lights on" adjustment) and the test is repeated.
        A mask that selects no pixels is never flagged as green.
    """
    # Preallocate instead of growing an array with np.append on every mask.
    # Rows left as NaN (empty segmentations) fail every comparison below.
    mean_hsv = np.full((len(masks), 3), np.nan)
    for row, mask in enumerate(masks):
        pixels = hsv[mask['segmentation']]
        if pixels.size:
            mean_hsv[row] = pixels.mean(axis=0)

    hue = mean_hsv[:, 0]
    saturation = mean_hsv[:, 1]
    lower_bounds_ok = (hue > 35) & (saturation > 35)

    idx_green = lower_bounds_ok & (hue < 75)
    if idx_green.sum() == 0:
        # grow lights on adjust: accept hues up to 100
        idx_green = lower_bounds_ok & (hue < 100)

    return idx_green

In [None]:
def checkfullplant(masks):
    """Mark the masks that do NOT coincide with the union of all masks.

    The union of every segmentation is used as a stand-in for the whole
    plant. A mask whose IoU with that union is 0.9 or more is treated as a
    whole-plant detection.

    Returns a boolean array with one entry per mask: True when the IoU is
    below 0.9 (the mask should be kept), False otherwise.
    """
    segmentations = [entry['segmentation'] for entry in masks]

    # Per-pixel count of covering masks; any positive count belongs to the union.
    coverage = sum(segmentations, np.zeros_like(segmentations[0], dtype=int))
    whole_plant = coverage > 0

    scores = np.array([iou(seg, whole_plant) for seg in segmentations])
    return scores < 0.9

In [None]:
def getbiggestcontour(contours):
    """Return the index of the contour that is made up of the most points.

    "Biggest" is measured by ``len`` of each contour, not by enclosed area.
    Ties resolve to the first such contour.
    """
    point_counts = list(map(len, contours))
    return np.argmax(point_counts)

def checkshape(masks, min_ratio=0.1):
    """Flag the masks whose main contour fills enough of its enclosing circle.

    For every mask the contour with the most points is selected, and the
    ratio between the contour area and the area of its minimum enclosing
    circle is computed. Long, thin regions give a small ratio.

    Parameters
    ----------
    masks : sequence of dict
        Each entry provides a boolean array under ``'segmentation'``.
    min_ratio : float, optional
        A mask passes when its ratio is strictly greater than this value.
        Defaults to 0.1.

    Returns
    -------
    numpy.ndarray of bool
        One entry per mask; True when the mask passes. A mask with no set
        pixel gets a ratio of 0.
    """
    cratio = []

    for mask in masks:
        test_mask = mask['segmentation']

        if not test_mask.max():
            cratio.append(0)
            continue

        # findContours returns (contours, hierarchy) in OpenCV 4 and
        # (image, contours, hierarchy) in OpenCV 3; [-2] is the contour list
        # in both cases.
        contours = cv2.findContours(
            (test_mask * 255).astype('uint8'),
            cv2.RETR_LIST,
            cv2.CHAIN_APPROX_SIMPLE,
        )[-2]

        # multiple objects possibly detected. Find contour with most points on it and just use that as object
        cnt = contours[getbiggestcontour(contours)]

        area = cv2.contourArea(cnt)
        _, radius = cv2.minEnclosingCircle(cnt)
        carea = np.pi * radius ** 2

        # Guard against a degenerate circle so the ratio never divides by zero.
        cratio.append(area / carea if carea > 0 else 0)

    return np.array(cratio) > min_ratio

In [None]:
def iou(gtmask, test_mask):
    """Intersection-over-union of two masks.

    Both arguments are combined element-wise with logical AND / OR, so any
    pair of arrays that broadcast together and are truthy where the mask is
    set can be used.

    Returns the number of pixels set in both masks divided by the number set
    in either. When neither mask has any pixel set the union is empty and
    0.0 is returned instead of dividing by zero.
    """
    union_count = np.sum(np.logical_or(gtmask, test_mask))
    if union_count == 0:
        return 0.0

    intersection_count = np.sum(np.logical_and(gtmask, test_mask))
    return intersection_count / union_count

In [None]:
def issubset(mask1, mask2):
    """Return True when more than 90% of ``mask2`` lies inside ``mask1``.

    Parameters
    ----------
    mask1 : numpy.ndarray
        The candidate containing mask.
    mask2 : numpy.ndarray
        The candidate sub-part; its pixel count is the denominator.

    Returns
    -------
    bool
        True when (pixels set in both) / (pixels set in mask2) > 0.9.
        An empty ``mask2`` is never reported as a subset, which avoids a
        0/0 division.
    """
    mask2_area = mask2.sum()
    if mask2_area == 0:
        return False

    overlap = np.logical_and(mask1, mask2).sum()
    return overlap / mask2_area > 0.9

def istoobig(masks):
    """Find masks that contain another mask and are heavily overlapped.

    A mask is flagged when some other, not yet flagged, mask is a subset of
    it (per ``issubset``) and the pixels it covers are covered by more than
    1.5 masks on average. Masks that are already flagged are skipped both as
    the contained and as the containing candidate, so the result depends on
    the order of ``masks``.

    Returns the flagged indices sorted in descending order, or an empty list
    when ``masks`` is empty.
    """
    if not masks:
        return []

    # Per-pixel count of how many masks cover each location.
    coverage = np.zeros(masks[0]['segmentation'].shape[:2])
    for entry in masks:
        coverage += entry['segmentation'] * 1

    flagged = []
    n_masks = len(masks)

    for inner in range(n_masks):
        if inner in flagged:
            continue
        inner_seg = masks[inner]['segmentation']

        for outer in range(n_masks):
            if outer == inner or outer in flagged:
                continue
            outer_seg = masks[outer]['segmentation']

            # Flag the containing mask only if both a big and a small copy
            # of the same region really exist (high average coverage).
            if issubset(outer_seg, inner_seg) and coverage[outer_seg].mean() > 1.5:
                flagged.append(outer)

    return sorted(flagged, reverse=True)

def remove_toobig(masks, idx_toobig):
    """Drop flagged masks that are almost fully explained by their sub-masks.

    Works on a shallow copy of ``masks``; the input list is not modified.
    For each examined index, the other masks that are subsets of it (per
    ``issubset``) are accumulated, and the flagged mask is deleted when more
    than 90% of its pixels are covered by them.

    ``idx_toobig`` is expected in descending order so that deleting an entry
    never shifts an index that is still to be visited.

    Returns the filtered list of masks.
    """
    kept = masks.copy()

    # NOTE(review): the slice skips the first entry of idx_toobig (the highest
    # index when the list comes from istoobig), so that mask is never
    # examined -- confirm this is intentional.
    for big_idx in idx_toobig[1:]:
        big_seg = kept[big_idx]['segmentation']
        covered = np.zeros(big_seg.shape)

        for other_idx, other in enumerate(kept):
            if other_idx != big_idx and issubset(big_seg, other['segmentation']):
                covered += other['segmentation']

        explained = np.logical_and(big_seg, covered > 0).sum() / big_seg.sum()
        if explained > 0.9:
            # the sub-masks account for the big mask, so it can be removed
            del kept[big_idx]

    return kept

In [None]:
# Automatic mask generator built on the `sam` model instance.
# NOTE(review): the per-parameter notes below paraphrase the segment_anything
# SamAutomaticMaskGenerator docstring -- confirm against the installed version.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    # Points sampled along one side of the image grid (presumably 32*32 prompts in total).
    points_per_side=32,
    # Presumably masks with a lower model-predicted quality score are discarded.
    pred_iou_thresh=0.88,
    # Presumably masks with a lower stability score are discarded.
    stability_score_thresh=0.95,
    # Presumably > 0 re-runs prediction on image crops in addition to the full image.
    crop_n_layers=1,
    # Presumably reduces the points-per-side used on each crop layer by this factor.
    crop_n_points_downscale_factor=2,
    min_mask_region_area=200,  # presumably removes small disconnected regions/holes below this pixel area
)

In [None]:
# ---------- Dataset parameters ----------
# Every combination of plant / day / level / angle below is turned into an
# image path of the form <root_path>/<plant>/d<day>/<level>/mustard_<plant>_d<day>_<level>_<angle>.png
# by the processing loop; combinations without a file on disk are skipped.
root_path = "/kaggle/input/gromo-mustard-dataset/dataset/content/drive/MyDrive/ACM grand challenge/Crops data/For_age_prediction/mustard"
plants = ["p1"] # change to p2 or p3 to generate masks for p2 or p3
levels = ['L1', 'L2', 'L3', 'L4', 'L5']
# 0, 15, ..., 345: 24 angle values.
angles = list(range(0, 360, 15))
# Day numbers 1 through 47 inclusive.
days = range(1, 48)
# Directory under which the generated .npz mask files are written.
output_root = "/kaggle/working/sam_masks"

# Ensure output root exists
os.makedirs(output_root, exist_ok=True)


In [None]:
def _filter_leaf_masks(masks, hsv):
    """Run the colour, whole-plant, shape and too-big filters in sequence.

    Returns the list of masks that survive every filter (possibly empty).
    """
    # remove things that aren't green enough to be leaves
    idx_green = checkcolour(masks, hsv)
    masks_g = [m for i, m in enumerate(masks) if idx_green[i]]

    if len(masks_g) > 2:
        # check to see if full plant detected and remove
        idx_notall = checkfullplant(masks_g)
        masks_na = [m for i, m in enumerate(masks_g) if idx_notall[i]]
    else:
        masks_na = masks_g.copy()

    if not masks_na:
        return []
    idx_shape = checkshape(masks_na)
    masks_s = [m for i, m in enumerate(masks_na) if idx_shape[i]]

    if not masks_s:
        return []
    return remove_toobig(masks_s, istoobig(masks_s))


# Generate one compressed .npz of leaf masks per available image. The loop is
# resumable: images whose output file already exists are skipped.
for plant in tqdm(plants, desc="Plants", position=0):
    for day in tqdm(days, desc=f"Days (Plant {plant})", leave=False, position=1):
        for level in tqdm(levels, desc=f"Levels (Plant {plant}, Day {day})", leave=False, position=2):
            in_dir = os.path.join(root_path, plant, f"d{day}", level)
            out_dir = os.path.join(output_root, plant, f"d{day}", level)

            for angle in angles:
                # Build output path first to check if already done
                out_file = os.path.join(out_dir, f"mustard_{plant}_d{day}_{level}_{angle}_leaf_masks.npz")
                if os.path.exists(out_file):
                    continue

                # Construct input path
                img_name = f"mustard_{plant}_d{day}_{level}_{angle}.png"
                img_path = os.path.join(in_dir, img_name)
                if not os.path.isfile(img_path):
                    continue

                # Load and preprocess image. cv2.imread signals an unreadable
                # file by returning None rather than raising.
                image = cv2.imread(img_path)
                if image is None:
                    tqdm.write(f"Skipping unreadable image: {img_path}")
                    continue
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = cv2.resize(image, (512, 512))
                hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

                # get masks and keep only the ones that look like single leaves
                masks = mask_generator.generate(image)
                masks_ntb = _filter_leaf_masks(masks, hsv)

                H, W = image.shape[:2]
                if masks_ntb:
                    final_stack = np.stack(
                        [m['segmentation'].astype(np.uint8) for m in masks_ntb],
                        axis=0
                    )
                else:
                    final_stack = np.zeros((0, H, W), dtype=np.uint8)

                # Create the directory only when there is something to save,
                # so missing inputs do not leave empty folders behind.
                os.makedirs(out_dir, exist_ok=True)

                # Save masks. Write to a temporary file and rename it into
                # place so an interrupted run cannot leave a truncated .npz
                # that the "already done" check above would then skip.
                tmp_file = out_file + ".tmp"
                with open(tmp_file, "wb") as fh:
                    np.savez_compressed(fh, final=final_stack)
                os.replace(tmp_file, out_file)

In [None]:
# Directory holding the generated masks and the archive to create from it.
output_root = "/kaggle/working/sam_masks"
zip_path = "/kaggle/working/masks_p1.zip"  # _p2 or _p3 if generating masks for p2 or p3

# Remove a stale archive so the new one is written from scratch.
if os.path.isfile(zip_path):
    os.remove(zip_path)

# Create a zip archive of the entire directory.
# make_archive appends ".zip" itself, so it is given the path without the
# extension. splitext strips only the trailing suffix, whereas str.replace
# would also rewrite any other ".zip" occurring inside the path.
archive_base, _ = os.path.splitext(zip_path)
shutil.make_archive(base_name=archive_base,
                    format="zip",
                    root_dir=output_root)

print(f"Zipped all masks into {zip_path}")
