In [1]:
import torch, torchvision
import torch.nn as nn

import numpy as np
import os, json, cv2, random

import detectron2
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.modeling import build_model,build_resnet_backbone,build_backbone
from detectron2.structures import ImageList, Instances
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.evaluation import COCOEvaluator
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.logger import setup_logger
setup_logger()

import timm
from timm.models.vision_transformer import VisionTransformer
from timm.models.vision_transformer import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Register the SegPC train/val splits (COCO-format annotations + image folders).
# Registration is guarded so that re-running this cell in a live kernel does not
# fail on names that are already present in the DatasetCatalog.
for _name, _json_file, _image_root in (
    ("SegPC_train", "../TCIA_SegPC_dataset/coco/COCO.json", "../TCIA_SegPC_dataset/coco/x/"),
    ("SegPC_val", "../TCIA_SegPC_dataset/coco_val/COCO.json", "../TCIA_SegPC_dataset/coco_val/x/"),
):
    if _name not in DatasetCatalog.list():
        register_coco_instances(_name, {}, _json_file, _image_root)

# Metadata handles for each split.
train_meta = MetadataCatalog.get('SegPC_train')
val_meta = MetadataCatalog.get('SegPC_val')

# Materialise the per-image dataset dicts (this is what triggers the
# "Loaded N images in COCO format" log lines).
train_dicts = DatasetCatalog.get("SegPC_train")
val_dicts = DatasetCatalog.get("SegPC_val")

[32m[04/27 23:09:33 d2.data.datasets.coco]: [0mLoaded 298 images in COCO format from ../TCIA_SegPC_dataset/coco/COCO.json
[32m[04/27 23:09:34 d2.data.datasets.coco]: [0mLoaded 200 images in COCO format from ../TCIA_SegPC_dataset/coco_val/COCO.json


In [3]:
class Transformer_Encoder(VisionTransformer):
    """Vision Transformer used as a multi-level feature extractor.

    The classification head is replaced by ``nn.Identity`` and ``forward``
    returns the normalised token sequences produced by the transformer blocks
    listed in ``self.encoder_out`` instead of class logits.

    Args:
        pretrained: If True, ``load_weights`` is called during construction.
        pretrained_model: Key into ``self.dispatcher`` naming the timm factory
            whose pretrained weights should be copied into this module.
        num_classes: Accepted for signature compatibility only. It is not
            forwarded: the parent is always built with ``num_classes=1000``
            (presumably so the pretrained checkpoints' head shapes match --
            TODO confirm) and ``self.num_classes`` is then set to 1.
        img_size, patch_size, in_chans, embed_dim, depth, num_heads, mlp_ratio,
        qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate,
        hybrid_backbone, norm_layer: Forwarded unchanged to
            ``VisionTransformer.__init__``.
    """

    def __init__(self, pretrained = False, pretrained_model = None, img_size=224, patch_size=16, in_chans=3, num_classes=1, embed_dim=768, depth=12,
                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                  drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):

        super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=1000, embed_dim=embed_dim, depth=depth,
                  num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                  drop_path_rate=drop_path_rate, hybrid_backbone=hybrid_backbone, norm_layer=norm_layer)

        self.num_classes = 1
        # Maps a model name to the timm factory function that builds it.
        self.dispatcher = {
            'vit_small_patch16_224': vit_small_patch16_224,
            'vit_base_patch16_224': vit_base_patch16_224,
            'vit_large_patch16_224': vit_large_patch16_224,
            'vit_base_patch16_384': vit_base_patch16_384,
            'vit_base_patch32_384': vit_base_patch32_384,
            'vit_large_patch16_384': vit_large_patch16_384,
            'vit_large_patch32_384': vit_large_patch32_384,
            'vit_small_resnet26d_224': vit_small_resnet26d_224,
            'vit_small_resnet50d_s3_224': vit_small_resnet50d_s3_224,
            'vit_base_resnet26d_224' : vit_base_resnet26d_224,
            'vit_base_resnet50d_224' : vit_base_resnet50d_224,
        }
        self.pretrained_model = pretrained_model
        self.pretrained = pretrained
        # Weights are loaded while the original head is still in place, so the
        # source state dict and this module have the same set of keys.
        if pretrained:
            self.load_weights()
        self.head = nn.Identity()
        # 1-based indices of the transformer blocks whose outputs are returned.
        self.encoder_out = [1,2,3,4,5]

    def forward_features(self, x):
        """Return the normalised outputs of the blocks listed in ``encoder_out``.

        Args:
            x: Input batch; ``x.shape[0]`` is used as the batch size and the
                tensor is passed to ``self.patch_embed`` (assumes an image
                batch shaped (B, C, H, W) -- TODO confirm).

        Returns:
            A list with one tensor per selected block, in block order. Each is
            the full token sequence (class token prepended) after ``self.norm``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        features = []
        # Blocks are numbered from 1 to match the indices in ``encoder_out``.
        for i, blk in enumerate(self.blocks, 1):
            x = blk(x)
            if i in self.encoder_out:
                features.append(x)

        # The shared final norm layer is applied to every collected feature.
        return [self.norm(feature) for feature in features]

    def forward(self, x):
        """Return the list of intermediate features (see ``forward_features``)."""
        return self.forward_features(x)

    def load_weights(self):
        """Copy pretrained weights from the timm model named by ``pretrained_model``.

        Building the reference model is best-effort: if the name is not in
        ``self.dispatcher`` or the factory fails, the reason is printed and the
        method returns without touching this module's weights.

        Raises:
            RuntimeError: Propagated from ``load_state_dict`` when the reference
                model's parameters do not match this module's.
        """
        try:
            model = self.dispatcher[self.pretrained_model](pretrained=True)
        except Exception as err:
            # Only ordinary errors are swallowed; KeyboardInterrupt/SystemExit
            # still propagate.
            print(f'could not load model {self.pretrained_model!r}: {err!r}')
            return
        if model is None:
            return
        self.load_state_dict(model.state_dict())
        print("successfully loaded weights!!!")

In [4]:
# Model-zoo recipe that supplies both the base config and the initial weights.
zoo_config = "Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(zoo_config))

# --- Model ---
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(zoo_config)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64

# --- Data ---
cfg.DATASETS.TRAIN = ("SegPC_train",)
cfg.DATASETS.TEST = ("SegPC_val",)
cfg.DATALOADER.NUM_WORKERS = 4

# --- Solver ---
cfg.SOLVER.IMS_PER_BATCH = 8
cfg.SOLVER.BASE_LR = 0.02 / 8
cfg.SOLVER.LR_SCHEDULER_NAME = 'WarmupCosineLR'
cfg.SOLVER.WARMUP_ITERS = 1500
# NOTE(review): MAX_ITER (37) is smaller than WARMUP_ITERS (1500), EVAL_PERIOD (250)
# and CHECKPOINT_PERIOD (1000) -- looks like a leftover debug value; confirm.
cfg.SOLVER.MAX_ITER = 37
# NOTE(review): STEPS/GAMMA presumably only affect the multi-step scheduler, not
# WarmupCosineLR -- verify against the detectron2 solver docs.
cfg.SOLVER.STEPS = (1000, 1500)
cfg.SOLVER.GAMMA = 0.05
cfg.SOLVER.CHECKPOINT_PERIOD = 1000

# --- Evaluation / runtime ---
cfg.TEST.EVAL_PERIOD = 250
cfg.CUDNN_BENCHMARK = True
cfg.OUTPUT_DIR = "./output/"

In [5]:
class CocoTrainer(DefaultTrainer):
    """DefaultTrainer variant that evaluates with ``COCOEvaluator``."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Build a COCO evaluator for ``dataset_name``.

        When ``output_folder`` is not given, results go to a ``coco_eval``
        directory (created on demand) relative to the working directory.
        """
        if output_folder is not None:
            return COCOEvaluator(dataset_name, cfg, False, output_folder)

        default_folder = "coco_eval"
        os.makedirs(default_folder, exist_ok=True)
        return COCOEvaluator(dataset_name, cfg, False, default_folder)

In [6]:
# Make sure the output directory from the config exists before training starts;
# exist_ok=True keeps the cell safe to re-run.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

In [7]:
# Build the trainer from the config, restore state, then run the training loop.
trainer = CocoTrainer(cfg)
# NOTE(review): with resume=True this presumably continues from the last checkpoint
# found in cfg.OUTPUT_DIR and only falls back to cfg.MODEL.WEIGHTS when there is
# none -- confirm against the detectron2 docs before relying on a "fresh" run.
trainer.resume_or_load(resume=True)
trainer.train()

[32m[04/27 23:09:48 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
 

  max_size = (max_size + (stride - 1)) // stride * stride
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[04/27 23:09:55 d2.evaluation.evaluator]: [0mInference done 11/200. Dataloading: 0.0008 s/iter. Inference: 0.1717 s/iter. Eval: 0.2259 s/iter. Total: 0.3984 s/iter. ETA=0:01:15
[32m[04/27 23:10:01 d2.evaluation.evaluator]: [0mInference done 24/200. Dataloading: 0.0012 s/iter. Inference: 0.1745 s/iter. Eval: 0.2296 s/iter. Total: 0.4054 s/iter. ETA=0:01:11
[32m[04/27 23:10:06 d2.evaluation.evaluator]: [0mInference done 36/200. Dataloading: 0.0013 s/iter. Inference: 0.1753 s/iter. Eval: 0.2385 s/iter. Total: 0.4152 s/iter. ETA=0:01:08
[32m[04/27 23:10:11 d2.evaluation.evaluator]: [0mInference done 48/200. Dataloading: 0.0013 s/iter. Inference: 0.1751 s/iter. Eval: 0.2468 s/iter. Total: 0.4233 s/iter. ETA=0:01:04
[32m[04/27 23:10:17 d2.evaluation.evaluator]: [0mInference done 60/200. Dataloading: 0.0012 s/iter. Inference: 0.1752 s/iter. Eval: 0.2510 s/iter. Total: 0.4276 s/iter. ETA=0:00:59
[32m[04/27 23:10:22 d2.evaluation.evaluator]: [0mInference done 72/200. Dataloading