<a href="https://colab.research.google.com/github/casual-lab/colab-notebooks/blob/main/medical_inst_segmentation_medicalnet_neu_rat_lung.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

# 准备数据



In [2]:
!pip install "labelbox[data]"

Collecting labelbox[data]
  Downloading labelbox-3.19.0-py3-none-any.whl (162 kB)
[?25l[K     |██                              | 10 kB 18.2 MB/s eta 0:00:01[K     |████                            | 20 kB 24.7 MB/s eta 0:00:01[K     |██████                          | 30 kB 29.8 MB/s eta 0:00:01[K     |████████                        | 40 kB 29.6 MB/s eta 0:00:01[K     |██████████                      | 51 kB 22.2 MB/s eta 0:00:01[K     |████████████                    | 61 kB 24.5 MB/s eta 0:00:01[K     |██████████████                  | 71 kB 25.4 MB/s eta 0:00:01[K     |████████████████▏               | 81 kB 26.8 MB/s eta 0:00:01[K     |██████████████████▏             | 92 kB 28.9 MB/s eta 0:00:01[K     |████████████████████▏           | 102 kB 27.7 MB/s eta 0:00:01[K     |██████████████████████▏         | 112 kB 27.7 MB/s eta 0:00:01[K     |████████████████████████▏       | 122 kB 27.7 MB/s eta 0:00:01[K     |██████████████████████████▏     | 133 kB 27.7 M

In [3]:
from labelbox import Client, OntologyBuilder
from labelbox.data.annotation_types import Geometry
from  labelbox.data.annotation_types.collection import LabelList
from PIL import Image
import numpy as np
import os
import torch

In [4]:
from enum import Enum

class SegClsName:
  VESSEL = "血管"
  BRONCHUS = "支气管"

  def get_all_names():
    return ['支气管', '血管']

In [5]:
API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiJjbDFwMTI0NncwMnZ0MHo3cGdieGthaGRoIiwib3JnYW5pemF0aW9uSWQiOiJjbDFwMTI0NmwwMnZzMHo3cDJhZXlicXBxIiwiYXBpS2V5SWQiOiJjbDF3N2NxY3o0M2tiMHpiaWh3ZjE0Y2t5Iiwic2VjcmV0IjoiNTExZGFlOTQ4NzQ0MjI0YjQ4MjI1MWZmZTk0NDJkMDkiLCJpYXQiOjE2NDk3NzE0OTcsImV4cCI6MjI4MDkyMzQ5N30.afEQowJg4cIlz2yZJMOQE8r5NuzAglwcifskm8GfZQY"
PROJECT_ID = "cl1vkawjv12se0zdr5o4vf9xu"
client = Client(api_key=API_KEY)
project = client.get_project(PROJECT_ID)
labels = project.label_generator().as_list()



In [6]:
def segmentation_cls_filter(lb_labels: LabelList, cls_name):
  '''取存在特定实体标签的样本'''
  result = LabelList()
  for lb in lb_labels:
    for a in lb.annotations:
      # print(a.name)
      if a.name == cls_name:
        result.append(lb)
        break

  return result

def classification_filter(lb_labels: LabelList):
  '''取存在全局分类标签的样本'''
  return LabelList([lb for lb in labels if len(lb.classification_annotations())==0])

In [7]:
class LabelBoxInstSegDataset(torch.utils.data.Dataset):

  def __init__(self, lb_labels, cls_name, transforms=None):
    self.transforms = transforms
    
    assert cls_name in SegClsName.get_all_names()
    self.cls_name = cls_name
    self.lb_labels = segmentation_cls_filter(lb_labels, cls_name)

  def __getitem__(self, idx):
    data_item = self.lb_labels[idx]
    # load images
    img = Image.fromarray(np.uint8(self.lb_labels[idx].data.value)).convert("RGB")
    mask = np.zeros((img.size[1], img.size[0]))
    for i, a in enumerate(data_item.annotations):
      if a.name == self.cls_name:
        mask += a.value.draw()[:, :, 0]

    # instances are encoded as different colors
    obj_ids = np.unique(mask)
    # first id is the background, so remove it
    obj_ids = obj_ids[1:]

    # split the color-encoded mask into a set
    # of binary masks
    masks = (mask == obj_ids[:, None, None])

    # get bounding box coordinates for each mask
    num_objs = len(obj_ids)
    boxes = []
    for i in range(num_objs):
      pos = np.where(masks[i])
      xmin = np.min(pos[1])
      xmax = np.max(pos[1])
      ymin = np.min(pos[0])
      ymax = np.max(pos[0])
      if xmin == xmax or ymin == ymax: continue
      boxes.append([xmin, ymin, xmax, ymax])

    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    # there is only one class
    labels = torch.ones((num_objs,), dtype=torch.int64)
    masks = torch.as_tensor(masks, dtype=torch.uint8)

    image_id = torch.tensor([idx])
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
    # suppose all instances are not crowd
    iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

    target = {}
    target["boxes"] = boxes
    target["labels"] = labels
    target["masks"] = masks
    target["image_id"] = image_id
    target["area"] = area
    target["iscrowd"] = iscrowd

    if self.transforms is not None:
      img, target = self.transforms(img, target)

    return img, target

  def __len__(self):
    return len(self.lb_labels)

In [8]:
dataset = LabelBoxInstSegDataset(labels, SegClsName.BRONCHUS)
dataset[0]

(<PIL.Image.Image image mode=RGB size=1440x1024 at 0x7FE7AD0D9C50>,
 {'area': tensor([95760.]),
  'boxes': tensor([[691., 130., 995., 445.]]),
  'image_id': tensor([0]),
  'iscrowd': tensor([0]),
  'labels': tensor([1]),
  'masks': tensor([[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8)})

# 定义模型

In [9]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

      
def get_instance_segmentation_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

In [10]:
%%shell

# Download TorchVision repo to use some files from
# references/detection
git clone https://github.com/pytorch/vision.git
cd vision
git checkout v0.8.2

cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/coco_eval.py ../
cp references/detection/engine.py ../
cp references/detection/coco_utils.py ../

Cloning into 'vision'...
remote: Enumerating objects: 119919, done.[K
remote: Counting objects: 100% (12625/12625), done.[K
remote: Compressing objects: 100% (1057/1057), done.[K
remote: Total 119919 (delta 11645), reused 12382 (delta 11506), pack-reused 107294[K
Receiving objects: 100% (119919/119919), 232.11 MiB | 22.61 MiB/s, done.
Resolving deltas: 100% (104299/104299), done.
Note: checking out 'v0.8.2'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by performing another checkout.

If you want to create a new branch to retain commits you create, you may
do so (now or later) by using -b with the checkout command again. Example:

  git checkout -b <new-branch-name>

HEAD is now at 2f40a483d [v0.8.X] .circleci: Add Python 3.9 to CI (#3063)




In [11]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T


def get_transform(train):
    transforms = []
    # converts the image, a PIL image, into a PyTorch Tensor
    transforms.append(T.ToTensor())
    if train:
        # during training, randomly flip the training images
        # and ground-truth for data augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

In [12]:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
dataset = LabelBoxInstSegDataset(labels, SegClsName.BRONCHUS, transforms=get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn
)
# For Training
images,targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images,targets)   # Returns losses and detections
# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)           # Returns predictions


Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


  0%|          | 0.00/160M [00:00<?, ?B/s]

  cpuset_checked))
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


# 训练

In [13]:
# use our dataset and defined transformations
dataset = LabelBoxInstSegDataset(labels, SegClsName.BRONCHUS, transforms=get_transform(train=True))
dataset_test = LabelBoxInstSegDataset(labels, SegClsName.BRONCHUS, transforms=get_transform(train=False))

# split the dataset in train and test set
torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-5])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-5:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

  cpuset_checked))


In [14]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and person
num_classes = 2

# get the model using our helper function
model = get_instance_segmentation_model(num_classes)
# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

# and a learning rate scheduler which decreases the learning rate by
# 10x every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth


  0%|          | 0.00/170M [00:00<?, ?B/s]

In [None]:
# let's train it for 10 epochs
from torch.optim.lr_scheduler import StepLR
num_epochs = 10

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

torch.save(model.state_dict(), "/content/drive/MyDrive/model.pt")

  cpuset_checked))


Epoch: [0]  [ 0/36]  eta: 0:27:36  lr: 0.000148  loss: 2.4396 (2.4396)  loss_classifier: 0.4310 (0.4310)  loss_box_reg: 0.0668 (0.0668)  loss_mask: 1.8789 (1.8789)  loss_objectness: 0.0561 (0.0561)  loss_rpn_box_reg: 0.0067 (0.0067)  time: 46.0069  data: 2.1246
Epoch: [0]  [10/36]  eta: 0:16:58  lr: 0.001575  loss: 1.7277 (1.6444)  loss_classifier: 0.1609 (0.2352)  loss_box_reg: 0.0793 (0.0803)  loss_mask: 1.2942 (1.2879)  loss_objectness: 0.0358 (0.0340)  loss_rpn_box_reg: 0.0067 (0.0071)  time: 39.1705  data: 0.2007


# 评估

In [None]:
# pick one image from the test set
img, _ = dataset_test[1]
# put the model in evaluation mode
model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])

In [None]:
prediction

In [None]:
Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())

In [None]:
Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())

In [None]:
from PIL import Image, ImageDraw, ImageFont
origin = Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())
draw = ImageDraw.Draw(origin)
draw.rectangle(xy=tuple(prediction[0]['boxes'][0]), fill=None, outline="red", width=1)
origin

In [None]:
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt


for i, box in enumerate(prediction[0]['boxes']):
  origin = Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())
  draw = ImageDraw.Draw(origin)
  draw.rectangle(xy=tuple(box), fill=None, outline=i*200, width=1)
  display(origin)