## Parts segmentation
- We already wrote an EDA notebook at `eda/00_parts_seg.ipynb`.
- It outputs `bg_imgs.txt`, which we use as the base for everything below.

In [None]:
import os

import fastcore.all as fc
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from hexray25.coco_utils import COCO

%matplotlib inline

In [None]:
# Data root. Defaults to the original machine path; set VJT_DATA_ROOT to run
# this notebook elsewhere without editing the code.
root = fc.Path(os.environ.get("VJT_DATA_ROOT", "/home/ubuntu/foundations/vjt-data/"))
assert root.is_dir(), f"data root not found: {root}"

images_loc = fc.L((root / "images").glob("*.png"))
print(f"total images: {len(images_loc)}")

# The three background-separation annotation files, loaded via hexray25's COCO wrapper.
annots1 = COCO.for_vjt(root / "annotations/phase6_background_separation_27apr24.json")
annots2 = COCO.for_vjt(root / "annotations/background_seperation_ds3_ds4_bg_jan12_2025_fixed.json")
annots3 = COCO.for_vjt(root / "annotations/full_bg_backup_may2025_fixed.json")

In [None]:
# Union of image names across the three annotation files. Deduplicate while
# preserving order, so an image listed in more than one file can neither land
# in both the train and validation splits nor be processed twice. When there
# are no duplicates, the result (and therefore the seeded split) is unchanged.
raw_imgs = annots1.imgname + annots2.imgname + annots3.imgname
all_imgs = list(dict.fromkeys(raw_imgs))
print(f"{len(raw_imgs)} image names listed, {len(all_imgs)} unique")


In [None]:
def get_annot_file(img_name, sources=None):
    """Return the annotation object whose image list contains `img_name`.

    Sources are searched in order, so if an image is listed in several
    annotation files the first one wins.

    Args:
        img_name: image name as it appears in an annotation object's `imgname`.
        sources: optional sequence of objects exposing an `imgname` container.
            Defaults to (annots1, annots2, annots3), matching the original lookup.

    Returns:
        The first matching annotation object, or None if no source lists the image.
    """
    if sources is None:
        sources = (annots1, annots2, annots3)
    return next((src for src in sources if img_name in src.imgname), None)


> Split the images into train (90%) and validation (10%) sets. The validation split is saved as `test.txt`.

In [None]:
# Image-level 90/10 split; random_state fixes the split across runs.
train_ds, val_ds = train_test_split(all_imgs, test_size=0.1, random_state=42)
# Guard against leakage, e.g. the same image name listed in several annotation files.
assert not set(train_ds) & set(val_ds), "some images appear in both train and val splits"

split_dir = fc.Path("ds/parts/")
split_dir.mkdir(parents=True, exist_ok=True)  # parents=True: also create ds/ if missing

# One image name per line. The validation split keeps its original file name, test.txt.
for fname, names in (("train.txt", train_ds), ("test.txt", val_ds)):
    with open(split_dir / fname, "w") as f:  # `with` closes the file even if a write fails
        f.writelines(f"{n}\n" for n in names)

In [None]:
# Pick one image for a visual sanity check. Seeding makes the preview reproducible.
img_name = all_imgs[int(np.random.default_rng(42).integers(len(all_imgs)))]
annots = get_annot_file(img_name)
assert annots is not None, f"{img_name} not found in any annotation file"
I, anns = annots.loadimgAnns(img_name=img_name, root=root)
print("before filtering", len(anns))
# Keep only category_id == 1. NOTE(review): presumably the `parts` category;
# confirm with annots.print_stats().
anns = [ann for ann in anns if ann['category_id'] == 1]
print("after filtering", len(anns))


In [None]:
# Print summary statistics for the annotation file the sampled image came from.
# NOTE(review): presumably includes per-category counts, which would confirm
# that category_id 1 is `parts`.
annots.print_stats()

> We use only the `parts` annotations for training.

In [None]:
# Binary preview mask: a pixel is foreground if any kept annotation covers it.
# This is the same `> 0` foreground rule used when the masks are written to disk,
# so overlapping parts are not shown brighter than they will be saved.
mask = np.zeros(I.shape[:2], dtype=bool)
for ann in anns:
    mask |= annots.annToMask(ann) > 0
mask = mask.astype(np.uint8)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(I)
ax[0].set_title(str(img_name))
ax[1].imshow(mask, cmap="gray")
ax[1].set_title(f"parts mask ({len(anns)} annotations)")
for a in ax:
    a.axis("off")
plt.show()

In [None]:
# Write one binary mask per image: 255 where any category_id == 1 annotation
# covers the pixel, 0 elsewhere.
masks_dir = fc.Path("ds/parts/masks/")
masks_dir.mkdir(parents=True, exist_ok=True)
for img_name in tqdm(all_imgs, desc="writing masks"):
    annots = get_annot_file(img_name)
    # The image is loaded here only for its height/width.
    I, anns = annots.loadimgAnns(img_name=img_name, root=root)
    anns = [ann for ann in anns if ann['category_id'] == 1]

    mask = np.zeros(I.shape[:2], dtype=bool)
    for ann in anns:
        mask |= annots.annToMask(ann) > 0
    mask = (mask * 255).astype(np.uint8)

    # NOTE(review): PIL picks the file format from img_name's extension. A lossy
    # format such as .jpg would corrupt the binary mask, so confirm all names are .png.
    Image.fromarray(mask).save(masks_dir / str(img_name))