In [None]:
from fastai.vision.all import *
from fastcore.parallel import *
from skimage.measure import label, regionprops, find_contours
import cv2

In [None]:
SEED=2022
def seed_everything(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices), and
    exports `PYTHONHASHSEED` so that child processes inherit it.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN autotune (and therefore vary) the convolution
    # algorithm between runs, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
seed_everything(SEED)

In [None]:
vfs = get_files('data/video_clips')
len(vfs)

In [None]:
def extract_images(vf, n=1):
    """Save the first `n` non-blank frames of video `vf` as JPEG files.

    Frames are written to `<cwd>/data/segmentation/images/` and named
    `<video stem>_<index:05>.jpg`, with the index starting at 1. Frames whose
    grayscale version has no non-zero pixel (completely black) are skipped and
    do not count towards `n`.

    Args:
        vf: `Path` of the video, absolute or relative to the working directory.
        n: maximum number of frames to save.
    """
    dst = Path().absolute()/'data/segmentation/images'
    # parents=True: `data/segmentation` may not exist yet on a fresh checkout.
    dst.mkdir(parents=True, exist_ok=True)

    video = cv2.VideoCapture(str(Path().absolute()/vf))
    saved = 0  # number of frames written so far
    try:
        while saved < n:
            ret, frame = video.read()
            if not ret:
                # end of stream, or the video could not be read
                break
            if cv2.countNonZero(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)) == 0:
                # blank frame (black pixels only): look at the next one
                continue
            saved += 1
            name = str(dst/f'{vf.stem}_{saved:05}.jpg')
            cv2.imwrite(name, frame)
    finally:
        # always free the capture handle, even if decoding or writing raised
        video.release()

In [None]:
%time parallel(extract_images,vfs,n_workers=100)

In [None]:
!prodigy image.manual binaryseg ./data/segmentation/images --label FOREGROUND --remove-base64 --width 1280

In [None]:
!prodigy db-out binaryseg > ./data/segmentation/binaryseg.jsonl

In [None]:
path = Path("data/segmentation")

In [None]:
codes = ["Background", "Foreground"]

def get_image_mask(fn):
    """Return the mask belonging to image `fn` as a `PILMask`.

    The mask path is derived textually from the image path: every 'images'
    becomes 'masks' and every 'jpg' becomes 'png'.
    """
    mask_path = str(fn).replace('images', 'masks')
    mask_path = mask_path.replace('jpg', 'png')
    return PILMask.create(Path(mask_path))


More than 100 videos have frames that contain nothing but a black background, with no surgical view or equipment. We had to exclude those files from training.

In [None]:
def proc_data():
    """Delete every image under `path/'images'` that has no mask file.

    The expected mask location is the image path with 'images' replaced by
    'masks' and 'jpg' replaced by 'png'. Each deletion is reported on stdout.
    """
    for img in get_image_files(path/'images'):
        mask_file = Path(str(img).replace('images', 'masks').replace('jpg','png'))
        # keep images that have a mask; ignore files that vanished meanwhile
        if mask_file.exists() or not os.path.exists(img):
            continue
        os.remove(img)
        print(img, ' removed successfully.')

proc_data()

In [None]:
size=(180,320)
batch_size=32

In [None]:
def get_dls(size, batch_size):
    """Build the segmentation `DataLoaders` from the images in `path/'images'`.

    Args:
        size: target size passed to both the item resize and the augmentations.
        batch_size: batch size forwarded to `dataloaders`.

    Returns:
        DataLoaders with a random 80/20 train/validation split.
    """
    # per-item: squish every image/mask to a common size so they can be batched
    per_item = [Resize(size, ResizeMethod.Squish)]
    # per-batch: augmentations, int->float conversion, ImageNet normalisation.
    # div_mask=255 presumably maps masks stored as 0/255 to class ids 0/1 - TODO confirm
    per_batch = [*aug_transforms(size=size,min_scale=1),
                 IntToFloatTensor(div_mask=255),
                 Normalize.from_stats(*imagenet_stats)]
    seg_block = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=codes)),
        get_items=get_image_files,
        get_y=get_image_mask,
        splitter=RandomSplitter(valid_pct=0.2),
        item_tfms=per_item,
        batch_tfms=per_batch,
    )
    return seg_block.dataloaders(path/'images', batch_size=batch_size)

In [None]:
dls=get_dls(size=size, batch_size=batch_size);

In [None]:
xb,yb = dls.one_batch()
xb[0].shape, type(xb[0]), yb[1].shape,type(yb[1])

In [None]:
torch.unique(yb[1])

In [None]:
dls.show_batch(max_n=5,nrows=2,vmin=1, vmax=30, figsize=(14,10),unique=True)

## Baseline learner with default loss and opt functions

In [None]:
def custom_accuracy(inp, targ):
    """Pixel accuracy: fraction of positions where the arg-max class equals the target.

    Args:
        inp: model output; classes are read along dim 1.
        targ: target labels; a singleton dim 1, if present, is dropped first.
    """
    predicted = inp.argmax(dim=1)
    expected = targ.squeeze(1)
    hits = (predicted == expected).float()
    return hits.mean()

In [None]:
learn = unet_learner(dls,resnet34, self_attention=True, metrics=custom_accuracy).to_fp16()

In [None]:
learn.lr_find()

In [None]:
learn.fine_tune(12,1e-3)

In [None]:
learn.show_results(vmin=1, vmax=30, figsize=(14,10))

In [None]:
# `learn.export` saves to `learn.path/<fname>`; create the folder first so the
# cell also works on a fresh checkout where `models/seg` does not exist yet.
learn.path=Path('models/seg')
learn.path.mkdir(parents=True, exist_ok=True)
learn.export('seg_v1.pkl')

In [None]:
preds, targs = learn.tta()

In [None]:
preds.shape, targs.shape

In [None]:
PILMask.create(np.array(targs[5]*255).astype(np.uint8))

In [None]:
PILMask.create((np.array(preds[5].argmax(0))*255).astype(np.uint8))

In [None]:
interp = SegmentationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
top_losses, top_idxs = interp.top_losses()

interp.plot_top_losses(9, figsize=(15,15))

## CrossEntropyLossFlat

In [None]:
learn = unet_learner(dls,resnet34, loss_func = CrossEntropyLossFlat(axis=1),self_attention=True, metrics=custom_accuracy).to_fp16()

In [None]:
learn.lr_find()

In [None]:
learn.fine_tune(12,1e-2)

In [None]:
# `learn.export` saves to `learn.path/<fname>`; create the folder first so the
# cell also works on a fresh checkout where `models/seg` does not exist yet.
learn.path=Path('models/seg')
learn.path.mkdir(parents=True, exist_ok=True)
learn.export('seg_celf_v1.pkl')

In [None]:
preds, targs = learn.tta()

In [None]:
PILMask.create(np.array(targs[2000]*255).astype(np.uint8))

In [None]:
PILMask.create((np.array(preds[2000].argmax(0))*255).astype(np.uint8))

In [None]:
interp = SegmentationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
top_losses, top_idxs = interp.top_losses()

interp.plot_top_losses(9, figsize=(15,15))

## Ranger

In [None]:
opt = ranger
learn = unet_learner(dls,resnet34, loss_func = CrossEntropyLossFlat(axis=1),self_attention=True, act_cls=Mish, opt_func=opt, metrics=custom_accuracy).to_fp16()

In [None]:
learn.lr_find()

In [None]:
lr=1e-4

In [None]:
learn.fit_flat_cos(12, slice(lr))

In [None]:
# `learn.export` saves to `learn.path/<fname>`; create the folder first so the
# cell also works on a fresh checkout where `models/seg` does not exist yet.
learn.path=Path('models/seg')
learn.path.mkdir(parents=True, exist_ok=True)
learn.export('seg_ranger_v1.pkl')

In [None]:
# Recompute the TTA predictions for the Ranger-trained learner. Without this,
# `preds`/`targs` still hold the outputs of the previously trained model and the
# previews in this section would not reflect the current learner.
preds, targs = learn.tta()
PILMask.create(np.array(targs[2000]*255).astype(np.uint8))

In [None]:
PILMask.create((np.array(preds[2000].argmax(0))*255).astype(np.uint8))

In [None]:
interp = SegmentationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
top_losses, top_idxs = interp.top_losses()

interp.plot_top_losses(9, figsize=(15,15))

# Using model for cropping images

In [5]:
#|default_exp crop_images

In [1]:
from nbdev.export import nb_export

In [2]:
#|export
import os
# Default to GPU 1, but respect a device selection already made by the caller
# (e.g. `CUDA_VISIBLE_DEVICES=0 python crop_images.py`).
# NOTE(review): this is placed before the torch/fastai imports, presumably so it
# is set before CUDA is initialised - keep it there.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "1")

In [3]:
#|export
from fastai.vision.all import *
from fastcore.parallel import *
from skimage.measure import label, regionprops, find_contours
from datetime import datetime
import cv2

In [4]:
#|export
codes = ["Background", "Foreground"]

In [5]:
#|export
def get_image_mask(fn):
    """Open the mask that corresponds to image `fn`.

    The mask path is the image path with 'images' -> 'masks' and
    'jpg' -> 'png' substituted everywhere in the string.
    """
    mask_location = str(fn)
    for old, new in (('images', 'masks'), ('jpg', 'png')):
        mask_location = mask_location.replace(old, new)
    return PILMask.create(Path(mask_location))

In [6]:
#|export
def custom_accuracy(inp, targ):
    """Mean agreement between the arg-max over dim 1 of `inp` and `targ`.

    `targ` has its dim 1 squeezed (a no-op unless that dim has size 1) before
    the element-wise comparison.
    """
    labels = targ.squeeze(1)
    return inp.argmax(dim=1).eq(labels).float().mean()

In [16]:
# Source frames and the folder that receives their cropped versions.
src = Path().absolute()/'data/train_images_orig'
dst = Path().absolute()/'data/train_images_crop'
dst.mkdir(exist_ok=True)

# Identify every image by '<parent folder>/<file name>' so both trees can be compared.
fssrc = {f'{parent_label(f)}/{f.name}' for f in get_image_files(src)}
fsdst = {f'{parent_label(f)}/{f.name}' for f in get_image_files(dst)}

# Images present in the source tree that have not been cropped yet.
fsdelta = L(fssrc.difference(fsdst))
len(fsdelta)

1933454

In [17]:
fs = L([src/f for f in fsdelta])

In [20]:
len(fs),fs[0]

(1933454,
 Path('/home/bilal/mlworks/surgtoolloc2/data/train_images_orig/clip_011547/01575.jpg'))

In [4]:
#|export
def _crop_to_foreground(im, pred):
    """Crop `im` to the bounding box of the largest predicted foreground region.

    Args:
        im: the original `PILImage`.
        pred: per-class prediction for `im`; classes are read along dim 0.

    Returns:
        A cropped `PILImage`, or `im` itself when no foreground pixel was predicted.
    """
    (h,w) = im.shape
    # binary 0/255 mask at prediction resolution, stretched back to the image size
    mask = PILMask.create((np.array(pred.argmax(0))*255).astype(np.uint8))
    mask = Resize((h,w), ResizeMethod.Squish)(mask)

    regions = regionprops(label(np.array(mask)))
    if not regions:
        # nothing predicted as foreground: there is no box to crop to
        return im
    # Use the largest connected component rather than the first one in scan
    # order, so a small stray blob cannot decide the crop.
    min_row, min_col, max_row, max_col = max(regions, key=lambda r: r.area).bbox
    return PILImage.create(np.array(im)[min_row:max_row, min_col:max_col])


def main(model_path=None, bunch_size=50000):
    """Crop every not-yet-processed training image to its predicted foreground.

    Images under `<cwd>/data/train_images_orig` that have no counterpart (same
    parent folder and file name) under `<cwd>/data/train_images_crop` are run
    through the segmentation learner in bunches, cropped and saved to the
    destination tree. Images with no predicted foreground are saved uncropped.

    Args:
        model_path: exported learner to load. Defaults to
            `<cwd>/models/seg/seg_v1.pkl`, i.e. relative to the working
            directory like the data folders.
        bunch_size: number of images predicted per `get_preds` call.
    """
    # loading the best model
    if model_path is None:
        model_path = Path().absolute()/'models/seg/seg_v1.pkl'
    learn = load_learner(model_path, cpu=False)

    # define src and dst folders
    src = Path().absolute()/'data/train_images_orig'
    dst = Path().absolute()/'data/train_images_crop'
    dst.mkdir(parents=True, exist_ok=True)

    # images are identified by '<parent folder>/<file name>' in both trees
    fssrc = {parent_label(f)+'/'+f.name for f in get_image_files(src)}
    fsdst = {parent_label(f)+'/'+f.name for f in get_image_files(dst)}
    fs = L([src/f for f in fssrc - fsdst])

    # for each bunch of images, predict masks, then crop the images and save them
    for start in range(0, len(fs), bunch_size):
        end = min(start + bunch_size, len(fs))

        print("-Start Time =", datetime.now().strftime("%H:%M:%S"))

        print(f'-Predicting masks for images: {start} -> {end}.')
        # keep our own handle on the test dataloader instead of relying on
        # whatever `learn.dl` happens to point to after `get_preds`
        dl = learn.dls.test_dl(fs[start:end])
        preds,_ = learn.get_preds(dl=dl)

        print(f'-Cropping and saving images: {start} -> {end}.')

        n_uncropped = 0
        for p, f in zip(preds, dl.items):
            dst_clip = dst/parent_label(f)
            dst_clip.mkdir(exist_ok=True)

            im = PILImage.create(f)
            im_c = _crop_to_foreground(im, p)
            if im_c is im:
                n_uncropped += 1
            im_c.save(dst_clip/f.name)

        if n_uncropped:
            print(f'-No foreground found in {n_uncropped} images; saved uncropped.')

        print("-End Time =", datetime.now().strftime("%H:%M:%S"))



In [None]:
#|export
if __name__=='__main__':
    main()

In [21]:
nb_export('02_segmentation-image_cleansing.ipynb', '.')

In [None]:
len(get_image_files(src))

In [None]:
len(get_image_files(dst))

In [None]:
assert len(get_image_files(src))==len(get_image_files(dst)), 'Not all images are cropped'