In [None]:
import torch, torchvision

import numpy as np
import cv2
import random
import matplotlib.pyplot as plt
import matplotlib
import h5py
import collections
import os
import pycocotools
import sys 
import json
import pickle
from sklearn.cluster import KMeans
from scipy.ndimage.measurements import label
import mrcfile
import skimage.io
import skimage.morphology

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.structures import BoxMode
from detectron2.data import datasets
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultTrainer
from detectron2.utils.visualizer import ColorMode
from detectron2.structures import Boxes, ImageList, Instances
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model

import sys
sys.path.append('../dataset_and_eval')
import eval

In [None]:
# First stage feature adaptation

In [None]:
# Build the stage-1 training set: an interleaved list of (source COCO record,
# unlabelled target record) pairs, registered with detectron2 under the name 'train'.
#
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# point these paths at files you produced yourself.
with open('/pasteur/data/darcnn_dataset/datadicts/bbbc_unlabelled_dict.p', 'rb') as cryo_file:
    cryo_dataset_dict = pickle.load(cryo_file)
with open('/pasteur/u/joycj/darcnn/coco/coco_dict.p', 'rb') as coco_file:
    coco_dataset_dict = pickle.load(coco_file)
print(len(cryo_dataset_dict))
print(len(coco_dataset_dict))

# The same random index is used for both lists, so it must be valid for both.
# 9749 is the original pair count; clamp it so a shorter list cannot raise
# IndexError. When both lists hold >= 9749 records this is identical to before.
num_pairs = min(9749, len(coco_dataset_dict), len(cryo_dataset_dict))

joint_dataset_dict = []
for _ in range(num_pairs):
    # Sampling is with replacement and unseeded: the pairing differs between runs.
    idx = random.randint(0, num_pairs - 1)
    joint_dataset_dict.append(coco_dataset_dict[idx])
    joint_dataset_dict.append(cryo_dataset_dict[idx])
print(len(joint_dataset_dict))

# `d` is only bound as an unused default argument of the loader lambda.
d = ['temp']
DatasetCatalog.register('train', lambda d=d: joint_dataset_dict)
metadata = MetadataCatalog.get('train')



In [None]:
# Stage-1 (feature adaptation) configuration.
# Starts from detectron2's Mask R-CNN R50-FPN 3x config and overrides it below.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# Initial weights: a checkpoint stored under .../coco/class_agnostic_maskrcnn/.
cfg.MODEL.WEIGHTS = '/pasteur/data/darcnn_dataset/checkpoints/coco/class_agnostic_maskrcnn/model_final.pth'
# Alternative starting checkpoint, currently disabled:
#cfg.MODEL.WEIGHTS = '/pasteur/data/darcnn_dataset/checkpoints/bbbc/base_darcnn_fin/model_0000639.pth' # 639



# Train and "test" both point at the dataset registered as 'train'.
cfg.DATASETS.TRAIN = ('train',)
cfg.DATASETS.TEST = ('train',)
cfg.DATALOADER.NUM_WORKERS = 1

# Solver settings. A checkpoint every 20 iterations over 3000 iterations
# means up to 150 checkpoint files in OUTPUT_DIR.
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.0001
cfg.SOLVER.MAX_ITER = 3000
cfg.SOLVER.CHECKPOINT_PERIOD = 20

# Custom meta-architecture, looked up by name. Its definition/registration is
# not visible in this notebook -- presumably provided by project code; verify.
cfg.MODEL.META_ARCHITECTURE = 'DomainSeparationDARCNN'
# Keep records that have no annotations in the training set.
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False
cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 0
cfg.DATALOADER.ASPECT_RATIO_GROUPING = False
    
# Project-specific keys added onto the config node. Their meaning is defined by
# the custom meta-architecture and cannot be determined from this notebook
# (NOTE(review): these look like loss weights -- confirm against the model code).
cfg.MODEL.FG = 10
cfg.MODEL.BG = 10
cfg.MODEL.SHARED_SIM = 1
cfg.MODEL.CRYO_DIFF = 1
cfg.MODEL.COCO_DIFF = 1
    
    
cfg.TEST.DETECTIONS_PER_IMAGE = 50
    
# Directory that receives stage-1 checkpoints (.../final/bbbc_darcnn).
cfg.OUTPUT_DIR = '/pasteur/data/darcnn_dataset/checkpoints/final/' + 'bbbc_darcnn' 

# Stage-1 training is disabled by wrapping it in a string literal. As the last
# expression of the cell, the string itself is echoed as the cell output.
'''os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()'''



In [None]:
# Stage-1 sanity check: run the adapted model on one target image and show the
# prediction overlay next to the ground-truth mask.
img_path = '/pasteur/data/darcnn_dataset/bbbc/10k_256/image800.png'
gt_path = '/pasteur/data/darcnn_dataset/bbbc/10k_256/mask800.png'
img = cv2.imread(img_path)
gt = cv2.imread(gt_path)

# cv2.imread returns None (it does not raise) when a file is missing or
# unreadable; fail here with the offending path instead of at the slice below.
if img is None:
    raise FileNotFoundError('cv2.imread could not read: ' + img_path)
if gt is None:
    raise FileNotFoundError('cv2.imread could not read: ' + gt_path)

# Restrict both arrays to their top-left 256x256 window.
img = img[:256, :256]
gt = gt[:256, :256]

# Inference overrides on the shared `cfg` object (this mutates it in place).
cfg.MODEL.WEIGHTS = '/pasteur/data/darcnn_dataset/checkpoints/final/bbbc_darcnn/model_0000399.pth'

cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.05
cfg.TEST.DETECTIONS_PER_IMAGE = 200
cfg.MODEL.RPN.POST_NMS_TOPK_TEST = 1000

# `predictor` is reused by the pseudolabel-generation cell further down.
predictor = DefaultPredictor(cfg)

outputs = predictor(img)
# The Visualizer is given the image with its channel order reversed (BGR -> RGB).
v = Visualizer(img[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.0)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20 , 20))
ax1.imshow(gt)
ax1.set_title('Ground-truth mask')
ax2.imshow(out.get_image())
ax2.set_title('Stage-1 predictions')

# Kept as globals: the rendered overlay and the raw predicted masks.
from_gt = out.get_image()
from_gt_pred_masks = outputs['instances'].pred_masks.cpu().numpy()

In [None]:
# Second stage pseudolabelling

In [None]:
def bbox2(img):
    """Return the tight bounding box of the non-zero region of a 2-D mask.

    The result is the tuple ``(cmin, rmin, cmax, rmax)``: the first and last
    column index, and the first and last row index, that contain at least one
    non-zero entry. Both ends are inclusive.

    Raises:
        IndexError: if ``img`` contains no non-zero entry.
    """
    row_has_fg = np.any(img, axis=1)
    col_has_fg = np.any(img, axis=0)

    occupied_rows = np.nonzero(row_has_fg)[0]
    occupied_cols = np.nonzero(col_has_fg)[0]

    # Picking the first and last occupied index fails with IndexError when the
    # mask is empty; callers rely on that.
    rmin, rmax = occupied_rows[[0, -1]]
    cmin, cmax = occupied_cols[[0, -1]]

    return cmin, rmin, cmax, rmax

In [None]:
def get_tnbc_self_train_dataset_dicts(predictor,
                                      tnbc_path='/pasteur/data/darcnn_dataset/bbbc/10k_256/',
                                      max_entries=1000):
    """Build detectron2-style dataset records whose annotations are model predictions.

    For every file in ``tnbc_path`` whose name contains ``'image'``, the image is
    blurred and contrast-stretched, passed through ``predictor``, and each
    non-empty predicted mask becomes one annotation (bounding box + RLE mask).
    ``file_name`` in the record points at the original, un-preprocessed file.

    Args:
        predictor: callable taking an image array and returning a dict whose
            ``'instances'`` entry exposes ``pred_masks`` and ``scores`` tensors.
        tnbc_path: directory to scan for images.
        max_entries: only the first ``max_entries`` directory entries are
            considered. The slice is applied BEFORE the ``'image'`` name filter,
            so fewer than ``max_entries`` records may be produced.

    Returns:
        list of dicts with keys ``file_name``, ``image_id``, ``height``,
        ``width``, ``confidence`` and ``annotations``.

    Raises:
        OSError: if a selected file cannot be decoded by ``cv2.imread``.
    """
    dataset_dicts = []

    # os.listdir returns entries in arbitrary order, so which files fall inside
    # the slice is filesystem-dependent.
    for path in os.listdir(tnbc_path)[:max_entries]:
        if 'image' not in path:
            continue

        # Image ids are consecutive over the records actually produced.
        image_id = len(dataset_dicts)
        print(image_id)

        curr_path = os.path.join(tnbc_path, path)

        img = cv2.imread(curr_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError('cv2.imread could not read: ' + curr_path)

        # Record the real image size instead of assuming 256x256, so the record
        # always agrees with the predicted masks.
        height, width = img.shape[:2]

        # Preprocessing applied only for prediction: light blur, then a linear
        # contrast stretch (saturating to uint8).
        enhanced = cv2.GaussianBlur(img, (3,3), 0)
        enhanced = cv2.convertScaleAbs(enhanced, alpha=2.5, beta=-250)

        instances = predictor(enhanced)['instances']
        pred_masks = instances.pred_masks.cpu().numpy()
        scores = instances.scores.cpu().numpy()

        objs = []
        for curr_obj in pred_masks:
            # An all-background mask has no bounding box; skip it explicitly
            # rather than catching whatever bbox2 happens to raise.
            if not curr_obj.any():
                continue

            cmin, rmin, cmax, rmax = bbox2(curr_obj)
            curr_obj = curr_obj.astype(np.uint8)

            objs.append({
                # bbox2 yields inclusive pixel indices, while an XYXY_ABS box ends
                # at the far edge of the last pixel; hence the +1. Without it a
                # one-pixel-thick mask would get a zero-area box.
                'bbox': (int(cmin), int(rmin), int(cmax) + 1, int(rmax) + 1),
                'bbox_mode': BoxMode.XYXY_ABS,
                # pycocotools expects a Fortran-ordered uint8 array.
                'segmentation': pycocotools.mask.encode(np.asarray(curr_obj, order="F")),
                'category_id': 1,
            })

        dataset_dicts.append({
            'file_name': curr_path,
            'image_id': image_id,
            'height': height,
            'width': width,
            # Scores of ALL predictions, including those skipped above, so this
            # array can be longer than 'annotations'.
            'confidence': scores,
            'annotations': objs,
        })

    return dataset_dicts




In [None]:
# Generate pseudolabels with the current `predictor`, cache them on disk, and
# register the shuffled records with detectron2 under the name 'tnbc'.
pseudolabel_cache_path = '/pasteur/data/darcnn_dataset/datadicts/bbbc_final.p'

# To reuse a previous run instead of re-predicting, load the cache instead:
#with open(pseudolabel_cache_path, 'rb') as cache_file:
#    tnbc_dataset_dicts_predicted = pickle.load(cache_file)
tnbc_dataset_dicts_predicted = get_tnbc_self_train_dataset_dicts(predictor)

print(len(tnbc_dataset_dicts_predicted))
# Use a context manager so the file is flushed and closed deterministically.
with open(pseudolabel_cache_path, 'wb') as cache_file:
    pickle.dump(tnbc_dataset_dicts_predicted, cache_file)

# In-place, unseeded shuffle: the on-disk cache keeps the pre-shuffle order.
random.shuffle(tnbc_dataset_dicts_predicted)

# `d` is only bound as an unused default argument of the loader lambda.
d = ['temp']
DatasetCatalog.register('tnbc', lambda d=d: tnbc_dataset_dicts_predicted)
metadata = MetadataCatalog.get('tnbc')



In [None]:
# Show a few randomly chosen pseudolabelled records as a visual sanity check.
# min(...) keeps random.sample from raising when fewer than 3 records exist.
num_previews = min(3, len(tnbc_dataset_dicts_predicted))
for record in random.sample(tnbc_dataset_dicts_predicted, num_previews):
    img = cv2.imread(record["file_name"])
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('cv2.imread could not read: ' + record["file_name"])

    # cv2 loads BGR; reverse the channels to hand the Visualizer an RGB image.
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=1.0)
    out = visualizer.draw_dataset_dict(record)
    plt.figure()
    # get_image() is already RGB, which is what plt.imshow expects. Reversing the
    # channels again (as one would for cv2.imshow) swaps red and blue in the overlay.
    plt.imshow(out.get_image())

In [None]:
# Stage-2 (pseudolabel self-training) configuration and training run.
# A fresh config replaces the stage-1 `cfg` object.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# Initialise from a stage-1 checkpoint.
# NOTE(review): this is iteration 299, while the stage-1 preview cell loads
# model_0000399.pth from the same directory -- confirm which one is intended.
cfg.MODEL.WEIGHTS = '/pasteur/data/darcnn_dataset/checkpoints/final/bbbc_darcnn/model_0000299.pth'

# Train and "test" both point at the pseudolabelled dataset registered as 'tnbc'.
cfg.DATASETS.TRAIN = ('tnbc',)
cfg.DATASETS.TEST = ('tnbc',)
cfg.DATALOADER.NUM_WORKERS = 1

# Project-specific keys added onto the config node. Their meaning is defined by
# the custom meta-architecture and cannot be determined from this notebook
# (NOTE(review): these look like loss weights -- confirm against the model code).
cfg.MODEL.FG = 10
cfg.MODEL.BG = 10
cfg.MODEL.SHARED_SIM = 1
cfg.MODEL.CRYO_DIFF = 1
cfg.MODEL.COCO_DIFF = 1

cfg.SOLVER.BASE_LR = 0.0001 
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.MAX_ITER = 3000

# Custom meta-architecture, looked up by name. Its definition/registration is
# not visible in this notebook -- presumably provided by project code; verify.
cfg.MODEL.META_ARCHITECTURE = 'PseudolabelTargetOnlyDARCNN'
# Keep records that have no annotations in the training set.
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False
cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 0
cfg.DATALOADER.ASPECT_RATIO_GROUPING = False
# The pseudolabel annotations store masks as RLE bitmasks, not polygons.
cfg.INPUT.MASK_FORMAT = 'bitmask'

# A checkpoint every 20 iterations over 3000 iterations means up to 150
# checkpoint files in OUTPUT_DIR.
cfg.SOLVER.CHECKPOINT_PERIOD = 20

cfg.OUTPUT_DIR = '/pasteur/data/darcnn_dataset/checkpoints/final/bbbc_pseudolabel_no_aug'

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
# resume=True: per detectron2's documented behaviour, an existing last checkpoint
# in OUTPUT_DIR takes precedence over cfg.MODEL.WEIGHTS -- clear OUTPUT_DIR to
# restart from the stage-1 weights.
trainer.resume_or_load(resume=True)
trainer.train()




In [None]:
# Stage-2 sanity check: run the pseudolabel-trained model on one target image
# and show the prediction overlay next to the ground-truth mask.
img_path = '/pasteur/data/darcnn_dataset/bbbc/10k_256/image800.png'
gt_path = '/pasteur/data/darcnn_dataset/bbbc/10k_256/mask800.png'
img = cv2.imread(img_path)
gt = cv2.imread(gt_path)

# cv2.imread returns None (it does not raise) when a file is missing or
# unreadable; fail here with the offending path instead of at the slice below.
if img is None:
    raise FileNotFoundError('cv2.imread could not read: ' + img_path)
if gt is None:
    raise FileNotFoundError('cv2.imread could not read: ' + gt_path)

# Restrict both arrays to their top-left 256x256 window.
img = img[:256, :256]
gt = gt[:256, :256]

# NOTE(review): the training cell in this notebook writes checkpoints to
# '.../final/bbbc_pseudolabel_no_aug', but these weights are read from
# '.../final/bbbc_pseudolabel' -- confirm this is the intended run.
cfg.MODEL.WEIGHTS = '/pasteur/data/darcnn_dataset/checkpoints/final/bbbc_pseudolabel/model_0000839.pth'

# Inference overrides on the shared `cfg` object (this mutates it in place).
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.06 # 0.07
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.05
cfg.TEST.DETECTIONS_PER_IMAGE = 200
cfg.MODEL.RPN.POST_NMS_TOPK_TEST = 1000

predictor = DefaultPredictor(cfg)

outputs = predictor(img)

# The Visualizer is given the image with its channel order reversed (BGR -> RGB).
v = Visualizer(img[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.0)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20 , 20))
ax1.imshow(gt)
ax1.set_title('Ground-truth mask')
ax2.imshow(out.get_image())
ax2.set_title('Stage-2 predictions')
predicted = out.get_image()

# Kept as globals: the rendered overlay and the raw predicted masks.
from_gt = out.get_image()
from_gt_pred_masks = outputs['instances'].pred_masks.cpu().numpy()