### ConvCRF
[Convolutional CRFs for Semantic Segmentation](https://doi.org/10.48550/arXiv.1805.04777)  
Marvin T. T. Teichmann and Roberto Cipolla

In [1]:
# pip install git+https://github.com/lucasb-eyer/pydensecrf.git

### CryoParticleSegment

In [2]:
#%%bash
#git clone https://cyanazuki@github.com/cyanazuki/CryoParticleSegment.git
#cd CryoParticleSegment
#python setup.py install

## ⭐ Setup
You must run every cell in this section before continuing.

### ✅ Directory Settings

In [3]:
# @title  { display-mode: "form" }

# Colab form fields holding the data and output locations used by the rest of the notebook.
# IMAGE_DIR: root of the micrograph data; later cells read the `train`, `val` and `test`
# sub-folders and the `*_filenames.txt` lists from it.
IMAGE_DIR = "/content/drive/MyDrive/research_xs/processed_micrographs_np" # @param {type:"string"}
# LABEL_DIR: passed as `label_dir` to every dataset constructed below.
LABEL_DIR = "/content/drive/MyDrive/research_xs/ground_truth_mask" # @param {type:"string"}
# RESULT_DIR: predicted test masks are written to `<RESULT_DIR>/test_image`.
RESULT_DIR = "/content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf" # @param {type:"string"}

In [4]:
# @title  { display-mode: "form" }
# @markdown Detect whether using folder in Google Drive as **`RESULT DIR`**📁.

if "content" in IMAGE_DIR.split("/")[:3] or "content" in LABEL_DIR.split("/")[:3]:
  try:
    from google.colab import drive
    drive.mount('/content/drive')
    !rm -r /content/sample_data
    if "content" in IMAGE_DIR.split("/")[:3]:
      !cp -r {IMAGE_DIR} /content/image_dir
      IMAGE_DIR = "/content/image_dir"
    if "content" in LABEL_DIR.split("/")[:3]:
      !cp -r {LABEL_DIR} /content/label_dir
      LABEL_DIR = "/content/label_dir"
  except:
    pass

Mounted at /content/drive


In [5]:
# @title  { display-mode: "form" }
# @markdown Source code directory.
SRC_DIR = "/content/drive/MyDrive/research/src" # @param {type:"string"}

if True:
  !cp -r {SRC_DIR}/EM_project/*.py /content/
else:
  !cp {SRC_DIR}/EM_project/convcrf.py /content/convcrf.py
  !cp {SRC_DIR}/EM_project/dataset.py /content/dataset.py
  !cp {SRC_DIR}/EM_project/lr_scheduler.py /content/lr_scheduler.py
  !cp {SRC_DIR}/EM_project/metrics.py /content/metrics.py
  !cp {SRC_DIR}/EM_project/model.py /content/model.py
  !cp {SRC_DIR}/EM_project/trainer.py /content/trainer.py
  !cp {SRC_DIR}/EM_project/utils.py /content/utils.py
#!rm /content/convcrf.py

### ✅ Packages Handling

In [6]:
# @title  { display-mode: "form" }
# @markdown Useful packages.

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

In [7]:
# @title  { display-mode: "form" }
# @markdown User-defined packages.

from dataset import MicrographDataset, MicrographDatasetEvery
from dataset import reconstruct_patched
from model import create_model
from trainer import CryoEMEvaluator

## ⭐ Main

### ✅ Setting

In [8]:
# @markdown Parameters.

# Values that were commented out in this evaluation notebook, kept for reference:
#   EPOCHS = 100, BATCH = 7, LR = 1e-3, RLR_PATIENCE = 3, ES_PATIENCE = 20
NUM_CLASSES = 2
CROP_SIZE = (512, 512)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
  DEVICE = torch.device("cuda")
else:
  DEVICE = torch.device("cpu")

In [9]:
# @markdown Set seed.

# Seed every RNG this notebook has imported so a Restart & Run All is repeatable.
# NumPy is seeded as well because it is imported above and any NumPy-based
# randomness in the project modules would otherwise stay unseeded.
random_state = 42
np.random.seed(random_state)
torch.manual_seed(random_state)
torch.cuda.manual_seed_all(random_state)

### ✅ Dataset

In [10]:
# Training split: images live in `<IMAGE_DIR>/train`, the file list in
# `<IMAGE_DIR>/train_filenames.txt`.
train_dir = os.path.join(IMAGE_DIR, 'train')
# ndmin=1 keeps the result 1-D even when the list holds a single filename;
# without it np.loadtxt returns a 0-d array that cannot be indexed or iterated.
train_filenames = np.loadtxt(f"{IMAGE_DIR}/train_filenames.txt", dtype=str, ndmin=1)
train_dataset = MicrographDataset(
    image_dir=train_dir,
    label_dir=LABEL_DIR,
    filenames=train_filenames,
    crop_size=CROP_SIZE,
)

In [11]:
# Validation split: images live in `<IMAGE_DIR>/val`, the file list in
# `<IMAGE_DIR>/val_filenames.txt`.
val_dir = os.path.join(IMAGE_DIR, 'val')
# ndmin=1 keeps the result 1-D even when the list holds a single filename;
# without it np.loadtxt returns a 0-d array that cannot be indexed or iterated.
val_filenames = np.loadtxt(f"{IMAGE_DIR}/val_filenames.txt", dtype=str, ndmin=1)
val_dataset = MicrographDatasetEvery(
    image_dir=val_dir,
    label_dir=LABEL_DIR,
    filenames=val_filenames,
    crop_size=CROP_SIZE,
)
# batch_size=None disables the DataLoader's automatic batching, so each dataset
# item is passed through as-is.
val_loader = DataLoader(val_dataset, batch_size=None, shuffle=False, pin_memory=True)

In [12]:
# Test split: images live in `<IMAGE_DIR>/test`, the file list in
# `<IMAGE_DIR>/test_filenames.txt`.
test_dir = os.path.join(IMAGE_DIR, 'test')
# ndmin=1 keeps the result 1-D even when the list holds a single filename;
# without it np.loadtxt returns a 0-d array that cannot be indexed or iterated.
test_filenames = np.loadtxt(f"{IMAGE_DIR}/test_filenames.txt", dtype=str, ndmin=1)
test_dataset = MicrographDatasetEvery(
    image_dir=test_dir,
    label_dir=LABEL_DIR,
    filenames=test_filenames,
    crop_size=CROP_SIZE,
)
# batch_size=None disables the DataLoader's automatic batching, so each dataset
# item is passed through as-is.
test_loader = DataLoader(test_dataset, batch_size=None, shuffle=False, pin_memory=True)

### ✅ Model

In [13]:
# Build the torchvision v0.10.0 FCN-ResNet101 with NUM_CLASSES output channels.
backbone = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'fcn_resnet101',
    num_classes=NUM_CLASSES,
)

# Swap the first convolution for one that takes a single input channel
# instead of the stock RGB stem.
backbone.backbone.conv1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
# classifier[3] is presumably the Dropout layer of the FCN head; its
# probability is set to 0.5 — TODO confirm against the training setup.
backbone.classifier[3].p = 0.5

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:02<00:00, 71.7MB/s]


## ⭐ Evaluate

### ✅ Load model

In [14]:
CHECKPOINT_PATH = "/content/drive/MyDrive/research_xs/final/fcn_resnet101/checkpoint74_B7_L-3_C512.pt" # @param {type:"string"}
state_dict_path = CHECKPOINT_PATH
# NOTE(security): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(state_dict_path, map_location=DEVICE)
# strict=False tolerates key mismatches; report them so a partially loaded
# backbone does not go unnoticed.
load_report = backbone.load_state_dict(state_dict, strict=False)
if load_report.missing_keys:
  print("Missing keys (left at their initial values):", load_report.missing_keys)
if load_report.unexpected_keys:
  print("Unexpected keys (ignored):", load_report.unexpected_keys)
model = create_model(backbone)
model.to(DEVICE)
# Evaluation mode: disables dropout and uses the stored batch-norm statistics.
model.eval()
print("Load model at: ", state_dict_path)

Load model at:  /content/drive/MyDrive/research_xs/final/fcn_resnet101/checkpoint74_B7_L-3_C512.pt


In [15]:
import gc

# Collect unreachable Python objects first, then hand cached CUDA blocks back
# to the driver (only meaningful when a GPU is present).
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

### ✅ Evaluator

In [None]:
# NOTE(review): this unexecuted cell runs the same FCN evaluation as the
# executed cell that follows it; consider keeping only one of them.
evaluator = CryoEMEvaluator(
    model=model,
    device=DEVICE,
    metrics=['iou'],
    num_classes=NUM_CLASSES,
)

# Validation split.
print("FCN validation result:")
result = evaluator.evaluate(loader=val_loader)

# Test split (overwrites `result`).
print("FCN test result:")
result = evaluator.evaluate(loader=test_loader)

FCN validation result:
global correct: 91.44
average row correct: ['94.86', '80.98']
Recall: ['93.85', '83.75']
IoU: ['89.30', '69.99']
F1 Score: ['94.35', '82.34']
mean IoU: 79.65
FCN test result:
global correct: 90.29
average row correct: ['94.27', '76.50']
Recall: ['93.30', '79.37']
IoU: ['88.29', '63.81']
F1 Score: ['93.78', '77.91']
mean IoU: 76.05


In [16]:
# Evaluate the plain FCN (no CRF) on the validation and test splits.
evaluator = CryoEMEvaluator(
    model=model, device=DEVICE, metrics=['iou'],
    num_classes=NUM_CLASSES)

# Keep each split's result under its own name so the validation result is not
# overwritten by the test run.
print("FCN validation result:")
fcn_val_result = evaluator.evaluate(loader=val_loader)
print("FCN test result:")
fcn_test_result = evaluator.evaluate(loader=test_loader)

# `result` used to hold the last evaluation; keep it bound for compatibility.
result = fcn_test_result

FCN validation result:
global correct: 91.44
average row correct: ['94.86', '80.98']
Recall: ['93.85', '83.75']
IoU: ['89.30', '69.99']
F1 Score: ['94.35', '82.34']
mean IoU: 79.65
FCN test result:
global correct: 90.29
average row correct: ['94.27', '76.50']
Recall: ['93.30', '79.37']
IoU: ['88.29', '63.81']
F1 Score: ['93.78', '77.91']
mean IoU: 76.05


### ✅ Evaluate convcrf

In [17]:
# @markdown config dondong
import yaml

!cp {SRC_DIR}/config.yml /content/config.yml
config_path = "/content/config.yml"

with open(config_path, 'r') as f:
    data = yaml.safe_load(f)
    crf_args = data['crf_config']

In [18]:
# @markdown config
# Inline ConvCRF settings handed to `create_crf_model`. Trailing notes
# reproduce the alternative values the author had recorded next to each entry.
config = dict(
    filter_size=7,
    blur=4,
    merge=True,
    norm='none',  # author's note: sym
    trainable=False,
    convcomp=False,

    weight='vector',  # author's note: scalar
    unary_weight=1,
    weight_init=0.2,
    logsoftmax=True,
    softmax=True,
    final_softmax=False,

    # author's notes: sdims 3, compat 3
    pos_feats=dict(sdims=3, compat=3),

    # author's notes: sdims 80, schan 13, compat 10, use_bias True
    col_feats=dict(sdims=3, schan=0.5, compat=1, use_bias=False),

    trainable_bias=False,
    pyinn=False,
)

In [19]:
from model import create_crf_model

# Build the CRF model on top of the loaded backbone (presumably FCN + ConvCRF,
# going by the helper's name). The inline `config` dict is used; `crf_args`
# read from config.yml is the alternative that is not passed here.
# Note: this rebinds `model`, replacing the plain FCN model created earlier.
model = create_crf_model(
    backbone,
    config=config,
    shape=CROP_SIZE,
    num_classes=NUM_CLASSES,
    use_gpu=torch.cuda.is_available(),
)

In [20]:
# Evaluate the CRF model on the validation and test splits.
# config['filter_size'] = 7
evaluator = CryoEMEvaluator(
    model=model, device=DEVICE, metrics=['iou'],
    num_classes=NUM_CLASSES)

# Keep each split's result under its own name so the validation result is not
# overwritten by the test run.
print("ConvCRF validation result:")
crf_val_result = evaluator.evaluate(loader=val_loader)
print("ConvCRF test result:")
crf_test_result = evaluator.evaluate(loader=test_loader)

# `result` used to hold the last evaluation; keep it bound for compatibility.
result = crf_test_result

ConvCRF validation result:
global correct: 90.48
average row correct: ['95.91', '73.88']
Recall: ['91.82', '85.53']
IoU: ['88.36', '65.67']
F1 Score: ['93.82', '79.28']
mean IoU: 77.02
ConvCRF test result:
global correct: 89.57
average row correct: ['95.35', '69.53']
Recall: ['91.56', '81.17']
IoU: ['87.65', '59.87']
F1 Score: ['93.42', '74.90']
mean IoU: 73.76


In [None]:
from torchvision.utils import save_image
from dataset import reconstruct_patched

# Create the output folder from Python: unlike a bare `mkdir`, this also creates
# missing parent directories and does not fail when the cell is re-run.
pred_dir = os.path.join(RESULT_DIR, "test_image")
os.makedirs(pred_dir, exist_ok=True)

model.eval()
with torch.no_grad():
  # Each dataset item is unpacked into four values; only the image and `grid`
  # are used here. `grid` is forwarded to `reconstruct_patched`.
  for idx, (test_image, _, grid, _) in enumerate(test_dataset):
    inputs = test_image.to(DEVICE)
    outputs = model(inputs)['out']
    # Per-pixel class index; no .detach() needed inside torch.no_grad().
    preds = outputs.argmax(dim=1).cpu()
    # Save under the source file's name with a .png extension.
    filename = f"{os.path.splitext(test_dataset.filenames[idx])[0]}.png"
    pred_path = os.path.join(pred_dir, filename)
    save_image(reconstruct_patched(preds, grid).float(), pred_path)