### ConvCRF
[Convolutional CRFs for Semantic Segmentation](https://doi.org/10.48550/arXiv.1805.04777)  
Marvin T. T. Teichmann and Roberto Cipolla

In [1]:
# pip install git+https://github.com/lucasb-eyer/pydensecrf.git

### CryoParticleSegment

In [2]:
#%%bash
#git clone https://cyanazuki@github.com/cyanazuki/CryoParticleSegment.git
#cd CryoParticleSegment
#python setup.py install

## ⭐ Setup
You must run every cell in this section.

### ✅ Directory Settings

In [3]:
# @title  { display-mode: "form" }

# Data root: the dataset cells read <IMAGE_DIR>/{train,val,test} and
# <IMAGE_DIR>/{train,val,test}_filenames.txt.
IMAGE_DIR = "/content/drive/MyDrive/research_xs/processed_micrographs_np" # @param {type:"string"}
# Passed to every dataset as `label_dir`.
LABEL_DIR = "/content/drive/MyDrive/research_xs/ground_truth_mask" # @param {type:"string"}
# Checkpoint directory for training; trainer_result.pickle is written here too.
RESULT_DIR = "/content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F" # @param {type:"string"}

In [4]:
# @title  { display-mode: "form" }
# @markdown Mount Google Drive and cache **`IMAGE_DIR`** / **`LABEL_DIR`** on the local disk📁.

import os
import shutil


def _under_content(path):
  """Return True when "content" is one of the first three "/"-separated parts of `path`."""
  return "content" in path.split("/")[:3]


def _cache_locally(src_dir, local_dir):
  """Copy `src_dir` to `local_dir` unless that copy already exists; return `local_dir`.

  shutil.copytree raises when the copy fails, so the caller never ends up
  pointing at a directory that was not actually created. A partially written
  copy is removed before the error is re-raised so that a later run does not
  mistake it for a complete cache.
  """
  if os.path.isdir(local_dir):
    print(f"Reusing existing {local_dir}; delete it to force a fresh copy.")
    return local_dir
  try:
    shutil.copytree(src_dir, local_dir)
  except Exception:
    shutil.rmtree(local_dir, ignore_errors=True)
    raise
  return local_dir


if _under_content(IMAGE_DIR) or _under_content(LABEL_DIR):
  try:
    from google.colab import drive
    drive.mount('/content/drive')
    # Colab's bundled sample data is not needed; ignore_errors keeps re-runs quiet.
    shutil.rmtree("/content/sample_data", ignore_errors=True)
    # Skip a directory that already points at its local copy, so re-running
    # this cell does not try to copy a directory onto itself.
    if _under_content(IMAGE_DIR) and IMAGE_DIR != "/content/image_dir":
      IMAGE_DIR = _cache_locally(IMAGE_DIR, "/content/image_dir")
    if _under_content(LABEL_DIR) and LABEL_DIR != "/content/label_dir":
      LABEL_DIR = _cache_locally(LABEL_DIR, "/content/label_dir")
  except ImportError:
    # Not running on Colab: keep the directories exactly as configured.
    print("google.colab is not available; using IMAGE_DIR and LABEL_DIR as configured.")
  except Exception as exc:
    # Still best effort, but report the failure instead of hiding it.
    print(f"Warning: Drive mount or local copy failed ({exc!r}); "
          "IMAGE_DIR and LABEL_DIR may still point at their original locations.")

Mounted at /content/drive


In [5]:
# @title  { display-mode: "form" }
# @markdown Source code directory.
# NOTE(review): this path is under "research", while the data/result paths in
# this notebook are under "research_xs" — confirm this is intentional.
SRC_DIR = "/content/drive/MyDrive/research/src" # @param {type:"string"}

# Copy the project modules next to the notebook so the later
# `from dataset import ...` / `from model import ...` cells can import them.
# The `if True` is a manual toggle: the first branch copies every .py file,
# the (currently unused) else branch copies an explicit list of modules.
if True:
  !cp -r {SRC_DIR}/EM_project/*.py /content/
else:
  !cp {SRC_DIR}/EM_project/convcrf.py /content/convcrf.py
  !cp {SRC_DIR}/EM_project/dataset.py /content/dataset.py
  !cp {SRC_DIR}/EM_project/lr_scheduler.py /content/lr_scheduler.py
  !cp {SRC_DIR}/EM_project/metrics.py /content/metrics.py
  !cp {SRC_DIR}/EM_project/model.py /content/model.py
  !cp {SRC_DIR}/EM_project/trainer.py /content/trainer.py
  !cp {SRC_DIR}/EM_project/utils.py /content/utils.py
# !rm /content/convcrf.py

### ✅ Packages Handling

In [6]:
# @title  { display-mode: "form" }
# @markdown Useful packages.

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [7]:
# @title  { display-mode: "form" }
# @markdown User-defined packages.

from dataset import MicrographDataset, MicrographDatasetEvery
from dataset import reconstruct_patched
from model import create_model
from trainer import CryoEMEvaluator
from trainer import CryoEMTrainerWithScheduler, tqdm_plugin_for_Trainer

## ⭐ Main

### ✅ Setting

In [8]:
# @markdown Parameters.

# Number of output classes, epoch budget, mini-batch size and crop size.
NUM_CLASSES = 2
EPOCHS = 100
BATCH = 7
CROP_SIZE = (512, 512)

# Learning rate, plus the patience values handed to ReduceLROnPlateau
# (RLR_PATIENCE) and to the trainer's `patience` argument (ES_PATIENCE).
LR = 1e-5
RLR_PATIENCE = 3
ES_PATIENCE = 20

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [9]:
# @markdown Set seed.

import random

random_state = 42
# Seed every RNG the pipeline may draw from, not only PyTorch's: code that
# uses Python's `random` or NumPy's global generator would otherwise differ
# from run to run.
# NOTE(review): whether the dataset/trainer modules use these generators is
# not visible from this notebook — seeding them is harmless either way.
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)
torch.cuda.manual_seed_all(random_state)

### ✅ Dataset

In [10]:
# Training split: images are read from IMAGE_DIR/train, and the list of
# files to use comes from IMAGE_DIR/train_filenames.txt.
train_dir = os.path.join(IMAGE_DIR, 'train')
# ndmin=1 keeps a one-line file list as a 1-D array; without it np.loadtxt
# returns a 0-d array that cannot be iterated or indexed.
train_filenames = np.loadtxt(f"{IMAGE_DIR}/train_filenames.txt", dtype=str, ndmin=1)
train_dataset = MicrographDataset(
    image_dir=train_dir,
    label_dir=LABEL_DIR,
    filenames=train_filenames,
    crop_size=CROP_SIZE,
)

In [11]:
# Validation split: images are read from IMAGE_DIR/val, and the list of
# files to use comes from IMAGE_DIR/val_filenames.txt.
val_dir = os.path.join(IMAGE_DIR, 'val')
# ndmin=1 keeps a one-line file list as a 1-D array; without it np.loadtxt
# returns a 0-d array that cannot be iterated or indexed.
val_filenames = np.loadtxt(f"{IMAGE_DIR}/val_filenames.txt", dtype=str, ndmin=1)
val_dataset = MicrographDatasetEvery(
    image_dir=val_dir,
    label_dir=LABEL_DIR,
    filenames=val_filenames,
    crop_size=CROP_SIZE,
)
# batch_size=None disables the DataLoader's automatic batching, so each
# dataset item is yielded as-is.
val_loader = DataLoader(val_dataset, batch_size=None, shuffle=False, pin_memory=True)

In [12]:
# Test split: images are read from IMAGE_DIR/test, and the list of files to
# use comes from IMAGE_DIR/test_filenames.txt.
test_dir = os.path.join(IMAGE_DIR, 'test')
# ndmin=1 keeps a one-line file list as a 1-D array; without it np.loadtxt
# returns a 0-d array that cannot be iterated or indexed.
test_filenames = np.loadtxt(f"{IMAGE_DIR}/test_filenames.txt", dtype=str, ndmin=1)
test_dataset = MicrographDatasetEvery(
    image_dir=test_dir,
    label_dir=LABEL_DIR,
    filenames=test_filenames,
    crop_size=CROP_SIZE,
)
# batch_size=None disables the DataLoader's automatic batching, so each
# dataset item is yielded as-is.
test_loader = DataLoader(test_dataset, batch_size=None, shuffle=False, pin_memory=True)

## ⭐ ConvCRF with FCN fine-tuned on cryo-EM

### ✅ Model

In [13]:
# Fetch the fcn_resnet101 architecture from the torchvision v0.10.0 hub repo,
# with NUM_CLASSES output classes.
backbone = torch.hub.load('pytorch/vision:v0.10.0', 'fcn_resnet101', num_classes=NUM_CLASSES)

# Replace the first convolution with one that takes a single input channel
# (the new layer is freshly initialised; its weights come from the checkpoint
# loaded in a later cell).
backbone.backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Set attribute `p` of classifier[3] to 0.5.
# NOTE(review): presumably the Dropout layer of the FCN head — confirm.
backbone.classifier[3].p = 0.5

# model = create_model(backbone)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:00<00:00, 183MB/s]


In [14]:
# Checkpoint of the fine-tuned FCN whose weights initialise the backbone.
CHECKPOINT_PATH = "/content/drive/MyDrive/research_xs/final/fcn_resnet101/checkpoint74_B7_L-3_C512.pt" # @param {type:"string"}
state_dict_path = CHECKPOINT_PATH
state_dict = torch.load(state_dict_path, map_location=torch.device(device))
# strict=False tolerates key mismatches, which would otherwise leave layers
# silently un-initialised; report the mismatches so they are visible.
load_result = backbone.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print("Missing keys (left at their initial values): ", load_result.missing_keys)
if load_result.unexpected_keys:
    print("Unexpected keys (ignored): ", load_result.unexpected_keys)
backbone.to(device)
backbone.eval()
print("Load model at: ", state_dict_path)

Load model at:  /content/drive/MyDrive/research_xs/final/fcn_resnet101/checkpoint74_B7_L-3_C512.pt


In [15]:
# Freeze the feature extractor: no parameter of backbone.backbone will
# require gradients from here on.
backbone.backbone.requires_grad_(False)

In [16]:
import gc
# Collect unreachable Python objects first, then ask PyTorch to release the
# cached GPU memory it is no longer using.
gc.collect()
torch.cuda.empty_cache()

### ✅ Add convcrf

In [17]:
# @markdown config
# Settings dict handed to create_crf_model(config=...) in the next cell.
# The meaning of each key is defined by the project's convcrf/model modules,
# not by this notebook.
# NOTE(review): the trailing comments (e.g. "# sym", "# 80,") look like
# alternative or previously used values — confirm.
config = {
    'filter_size': 7,
    'blur': 4,
    'merge': True,
    'norm': 'none', # sym
    'trainable': False,
    'convcomp': False,

    'weight': 'vector', # scalar
    'unary_weight': 1,
    'weight_init': 0.2,
    'logsoftmax': True,
    'softmax': True,
    'final_softmax': False,

    'pos_feats': {
        'sdims': 3, # 3
        'compat': 3, # 3,
    },

    'col_feats': {
        'sdims': 3, # 80,
        'schan': 0.5, # 13
        'compat': 1, # 10
        'use_bias': False, # True
    },
    "trainable_bias": False,
    "pyinn": False
}

In [18]:
from model import create_crf_model

# Build the final model from the loaded backbone and the `config` dict above.
# `shape` is the crop size and `use_gpu` follows CUDA availability; what
# create_crf_model does with them is defined in the project's model.py.
model = create_crf_model(
    backbone, config=config, shape=CROP_SIZE,
    num_classes=NUM_CLASSES, use_gpu=torch.cuda.is_available()) #crf_args

In [19]:
# Cross-entropy loss with label smoothing of 0.1.
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
# Adam receives every parameter of `model`, including the backbone parameters
# whose requires_grad was set to False in an earlier cell.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Lower the learning rate once the monitored value has stopped improving for
# RLR_PATIENCE epochs; verbose=True prints a message on each reduction.
scheduler = ReduceLROnPlateau(optimizer, patience=RLR_PATIENCE, verbose=True)

## ⭐ Training

### ✅ Training

In [20]:
# Wrap the trainer class with the project's tqdm plugin
# (NOTE(review): presumably adds the progress bars seen in the output — confirm).
Trainer = tqdm_plugin_for_Trainer(CryoEMTrainerWithScheduler)
# ES_PATIENCE is passed as the trainer's `patience`; the log below ends with
# "Early stopping" after 20 epochs without improvement.
trainer = Trainer(model, train_dataset, criterion, optimizer, device,
              num_classes = NUM_CLASSES,
              lr_scheduler=scheduler, patience=ES_PATIENCE)

# Train for at most EPOCHS epochs, validating on val_loader; checkpoints are
# written to RESULT_DIR (see the "Saving model at ..." lines in the output).
trainer.train(EPOCHS, val_loader=val_loader, batch_size = BATCH,
              ckpt_dir = RESULT_DIR, random_state = random_state)

Epoch   1/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.3277


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2254
global correct: 91.06
average row correct: ['96.02', '75.88']
Recall: ['92.41', '86.20']
IoU: ['89.00', '67.66']
F1 Score: ['94.18', '80.71']
mean IoU: 78.33
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint1.pt
Loss improve to 1.225419521331787.
Epoch   2/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.2723


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2366
global correct: 90.88
average row correct: ['96.19', '74.64']
Recall: ['92.06', '86.52']
IoU: ['88.83', '66.87']
F1 Score: ['94.08', '80.14']
mean IoU: 77.85
No improvement for 1 epoch.
Epoch   3/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.2901


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2615
global correct: 90.71
average row correct: ['96.56', '72.80']
Recall: ['91.56', '87.39']
IoU: ['88.67', '65.88']
F1 Score: ['94.00', '79.43']
mean IoU: 77.28
No improvement for 2 epoch.
Epoch   4/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.2199


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2185
global correct: 90.96
average row correct: ['96.02', '75.49']
Recall: ['92.29', '86.12']
IoU: ['88.89', '67.30']
F1 Score: ['94.12', '80.45']
mean IoU: 78.10
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint4.pt
Loss improve to 1.2185107072194417.
Epoch   5/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1912


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2274
global correct: 90.76
average row correct: ['96.26', '73.95']
Recall: ['91.87', '86.60']
IoU: ['88.70', '66.36']
F1 Score: ['94.01', '79.78']
mean IoU: 77.53
No improvement for 1 epoch.
Epoch   6/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1930


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2197
global correct: 90.77
average row correct: ['96.23', '74.09']
Recall: ['91.91', '86.54']
IoU: ['88.71', '66.43']
F1 Score: ['94.02', '79.83']
mean IoU: 77.57
No improvement for 2 epoch.
Epoch   7/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1325


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2115
global correct: 90.73
average row correct: ['96.04', '74.50']
Recall: ['92.01', '86.03']
IoU: ['88.65', '66.46']
F1 Score: ['93.98', '79.85']
mean IoU: 77.55
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint7.pt
Loss improve to 1.2115115986929998.
Epoch   8/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1413


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1838
global correct: 91.02
average row correct: ['95.83', '76.30']
Recall: ['92.52', '85.69']
IoU: ['88.94', '67.68']
F1 Score: ['94.14', '80.72']
mean IoU: 78.31
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint8.pt
Loss improve to 1.1837986840142145.
Epoch   9/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1505


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1901
global correct: 90.93
average row correct: ['96.09', '75.15']
Recall: ['92.20', '86.27']
IoU: ['88.86', '67.12']
F1 Score: ['94.10', '80.33']
mean IoU: 77.99
No improvement for 1 epoch.
Epoch  10/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1476


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1759
global correct: 90.99
average row correct: ['95.85', '76.15']
Recall: ['92.47', '85.72']
IoU: ['88.91', '67.57']
F1 Score: ['94.13', '80.65']
mean IoU: 78.24
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint10.pt
Loss improve to 1.175904115041097.
Epoch  11/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1589


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1767
global correct: 90.83
average row correct: ['96.15', '74.56']
Recall: ['92.04', '86.37']
IoU: ['88.76', '66.71']
F1 Score: ['94.05', '80.03']
mean IoU: 77.74
No improvement for 1 epoch.
Epoch  12/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1861


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1492
global correct: 91.08
average row correct: ['95.69', '76.97']
Recall: ['92.70', '85.40']
IoU: ['88.99', '68.02']
F1 Score: ['94.17', '80.97']
mean IoU: 78.51
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint12.pt
Loss improve to 1.1492021878560383.
Epoch  13/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1930


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1475
global correct: 91.00
average row correct: ['95.80', '76.33']
Recall: ['92.52', '85.60']
IoU: ['88.91', '67.64']
F1 Score: ['94.13', '80.70']
mean IoU: 78.28
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint13.pt
Loss improve to 1.147511535220676.
Epoch  14/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1440


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1653
global correct: 90.92
average row correct: ['96.08', '75.13']
Recall: ['92.19', '86.25']
IoU: ['88.85', '67.09']
F1 Score: ['94.10', '80.30']
mean IoU: 77.97
No improvement for 1 epoch.
Epoch  15/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1659


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1842
global correct: 90.71
average row correct: ['96.21', '73.89']
Recall: ['91.85', '86.45']
IoU: ['88.64', '66.22']
F1 Score: ['93.98', '79.68']
mean IoU: 77.43
No improvement for 2 epoch.
Epoch  16/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1415


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1964
global correct: 90.48
average row correct: ['96.33', '72.58']
Recall: ['91.48', '86.61']
IoU: ['88.40', '65.26']
F1 Score: ['93.84', '78.98']
mean IoU: 76.83
No improvement for 3 epoch.
Epoch  17/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1686


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1644
global correct: 90.74
average row correct: ['96.05', '74.49']
Recall: ['92.01', '86.06']
IoU: ['88.65', '66.46']
F1 Score: ['93.99', '79.85']
mean IoU: 77.56
Epoch 00017: reducing learning rate of group 0 to 1.0000e-06.
No improvement for 4 epoch.
Epoch  18/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1325


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1422
global correct: 90.76
average row correct: ['96.04', '74.63']
Recall: ['92.05', '86.04']
IoU: ['88.68', '66.57']
F1 Score: ['94.00', '79.93']
mean IoU: 77.63
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint18.pt
Loss improve to 1.142203198538886.
Epoch  19/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0958


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1383
global correct: 90.80
average row correct: ['96.04', '74.79']
Recall: ['92.09', '86.08']
IoU: ['88.73', '66.72']
F1 Score: ['94.03', '80.04']
mean IoU: 77.72
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint19.pt
Loss improve to 1.1383166313171387.
Epoch  20/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0624


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1383
global correct: 90.79
average row correct: ['95.95', '74.99']
Recall: ['92.14', '85.84']
IoU: ['88.70', '66.74']
F1 Score: ['94.01', '80.05']
mean IoU: 77.72
No improvement for 1 epoch.
Epoch  21/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1412


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1547
global correct: 90.71
average row correct: ['96.24', '73.83']
Recall: ['91.83', '86.51']
IoU: ['88.65', '66.21']
F1 Score: ['93.98', '79.67']
mean IoU: 77.43
No improvement for 2 epoch.
Epoch  22/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0373


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1263
global correct: 90.89
average row correct: ['96.03', '75.17']
Recall: ['92.20', '86.11']
IoU: ['88.82', '67.04']
F1 Score: ['94.08', '80.27']
mean IoU: 77.93
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint22.pt
Loss improve to 1.1263291703330145.
Epoch  23/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1489


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1336
global correct: 90.84
average row correct: ['96.05', '74.89']
Recall: ['92.12', '86.13']
IoU: ['88.76', '66.83']
F1 Score: ['94.05', '80.12']
mean IoU: 77.80
No improvement for 1 epoch.
Epoch  24/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.2070


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1321
global correct: 90.88
average row correct: ['95.86', '75.65']
Recall: ['92.33', '85.67']
IoU: ['88.79', '67.15']
F1 Score: ['94.06', '80.35']
mean IoU: 77.97
No improvement for 2 epoch.
Epoch  25/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0837


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1351
global correct: 90.83
average row correct: ['96.03', '74.92']
Recall: ['92.13', '86.07']
IoU: ['88.75', '66.82']
F1 Score: ['94.04', '80.11']
mean IoU: 77.79
No improvement for 3 epoch.
Epoch  26/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1408


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1412
global correct: 90.68
average row correct: ['96.20', '73.83']
Recall: ['91.83', '86.40']
IoU: ['88.61', '66.14']
F1 Score: ['93.96', '79.62']
mean IoU: 77.37
Epoch 00026: reducing learning rate of group 0 to 1.0000e-07.
No improvement for 4 epoch.
Epoch  27/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0771


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1413
global correct: 90.52
average row correct: ['96.23', '73.05']
Recall: ['91.61', '86.37']
IoU: ['88.43', '65.50']
F1 Score: ['93.86', '79.16']
mean IoU: 76.97
No improvement for 5 epoch.
Epoch  28/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0563


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1547
global correct: 90.46
average row correct: ['96.37', '72.40']
Recall: ['91.43', '86.72']
IoU: ['88.39', '65.17']
F1 Score: ['93.84', '78.91']
mean IoU: 76.78
No improvement for 6 epoch.
Epoch  29/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1166


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1202
global correct: 90.93
average row correct: ['96.05', '75.27']
Recall: ['92.23', '86.17']
IoU: ['88.86', '67.15']
F1 Score: ['94.10', '80.35']
mean IoU: 78.01
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint29.pt
Loss improve to 1.120242416858673.
Epoch  30/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1384


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1334
global correct: 90.72
average row correct: ['96.23', '73.89']
Recall: ['91.85', '86.50']
IoU: ['88.66', '66.25']
F1 Score: ['93.99', '79.70']
mean IoU: 77.45
No improvement for 1 epoch.
Epoch  31/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0987


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1420
global correct: 90.72
average row correct: ['96.19', '74.03']
Recall: ['91.88', '86.39']
IoU: ['88.65', '66.30']
F1 Score: ['93.99', '79.73']
mean IoU: 77.47
No improvement for 2 epoch.
Epoch  32/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1368


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1381
global correct: 90.82
average row correct: ['96.12', '74.64']
Recall: ['92.06', '86.27']
IoU: ['88.75', '66.72']
F1 Score: ['94.04', '80.04']
mean IoU: 77.74
No improvement for 3 epoch.
Epoch  33/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0375


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1476
global correct: 90.79
average row correct: ['96.22', '74.20']
Recall: ['91.94', '86.52']
IoU: ['88.73', '66.51']
F1 Score: ['94.03', '79.89']
mean IoU: 77.62
Epoch 00033: reducing learning rate of group 0 to 1.0000e-08.
No improvement for 4 epoch.
Epoch  34/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1634


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1519
global correct: 90.83
average row correct: ['96.17', '74.52']
Recall: ['92.02', '86.41']
IoU: ['88.77', '66.70']
F1 Score: ['94.05', '80.02']
mean IoU: 77.73
No improvement for 5 epoch.
Epoch  35/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1960


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1851
global correct: 90.39
average row correct: ['96.63', '71.29']
Recall: ['91.14', '87.38']
IoU: ['88.34', '64.64']
F1 Score: ['93.81', '78.52']
mean IoU: 76.49
No improvement for 6 epoch.
Epoch  36/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1112


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1972
global correct: 90.04
average row correct: ['96.81', '69.35']
Recall: ['90.61', '87.67']
IoU: ['87.99', '63.19']
F1 Score: ['93.61', '77.44']
mean IoU: 75.59
No improvement for 7 epoch.
Epoch  37/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1767


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1710
global correct: 90.43
average row correct: ['96.54', '71.73']
Recall: ['91.26', '87.16']
IoU: ['88.37', '64.87']
F1 Score: ['93.83', '78.69']
mean IoU: 76.62
No improvement for 8 epoch.
Epoch  38/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0985


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1356
global correct: 90.68
average row correct: ['96.05', '74.28']
Recall: ['91.95', '86.01']
IoU: ['88.60', '66.28']
F1 Score: ['93.95', '79.72']
mean IoU: 77.44
No improvement for 9 epoch.
Epoch  39/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1476


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1280
global correct: 90.89
average row correct: ['95.94', '75.47']
Recall: ['92.28', '85.87']
IoU: ['88.81', '67.13']
F1 Score: ['94.07', '80.33']
mean IoU: 77.97
No improvement for 10 epoch.
Epoch  40/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0654


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1288
global correct: 90.78
average row correct: ['95.95', '75.00']
Recall: ['92.15', '85.82']
IoU: ['88.69', '66.73']
F1 Score: ['94.01', '80.05']
mean IoU: 77.71
No improvement for 11 epoch.
Epoch  41/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1315


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1258
global correct: 90.84
average row correct: ['95.91', '75.34']
Recall: ['92.24', '85.76']
IoU: ['88.75', '66.96']
F1 Score: ['94.04', '80.21']
mean IoU: 77.86
No improvement for 12 epoch.
Epoch  42/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1503


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1339
global correct: 90.82
average row correct: ['96.07', '74.78']
Recall: ['92.09', '86.15']
IoU: ['88.75', '66.76']
F1 Score: ['94.04', '80.07']
mean IoU: 77.75
No improvement for 13 epoch.
Epoch  43/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1293


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1190
global correct: 90.97
average row correct: ['95.85', '76.04']
Recall: ['92.44', '85.71']
IoU: ['88.89', '67.48']
F1 Score: ['94.12', '80.59']
mean IoU: 78.19
Saving model at /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint43.pt
Loss improve to 1.1190180910958185.
Epoch  44/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1682


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1602
global correct: 90.70
average row correct: ['96.28', '73.63']
Recall: ['91.78', '86.61']
IoU: ['88.63', '66.11']
F1 Score: ['93.97', '79.60']
mean IoU: 77.37
No improvement for 1 epoch.
Epoch  45/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1146


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1852
global correct: 90.50
average row correct: ['96.49', '72.19']
Recall: ['91.39', '87.07']
IoU: ['88.45', '65.20']
F1 Score: ['93.87', '78.93']
mean IoU: 76.82
No improvement for 2 epoch.
Epoch  46/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1204


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1603
global correct: 90.62
average row correct: ['96.34', '73.12']
Recall: ['91.64', '86.74']
IoU: ['88.56', '65.76']
F1 Score: ['93.93', '79.35']
mean IoU: 77.16
No improvement for 3 epoch.
Epoch  47/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1444


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1442
global correct: 90.79
average row correct: ['96.11', '74.55']
Recall: ['92.03', '86.23']
IoU: ['88.72', '66.62']
F1 Score: ['94.02', '79.97']
mean IoU: 77.67
No improvement for 4 epoch.
Epoch  48/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.0978


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1350
global correct: 90.80
average row correct: ['96.03', '74.82']
Recall: ['92.10', '86.03']
IoU: ['88.72', '66.72']
F1 Score: ['94.02', '80.04']
mean IoU: 77.72
No improvement for 5 epoch.
Epoch  49/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1868


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1705
global correct: 90.54
average row correct: ['96.46', '72.47']
Recall: ['91.46', '87.00']
IoU: ['88.49', '65.38']
F1 Score: ['93.89', '79.07']
mean IoU: 76.94
No improvement for 6 epoch.
Epoch  50/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.2726


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1486
global correct: 90.71
average row correct: ['96.33', '73.52']
Recall: ['91.75', '86.75']
IoU: ['88.65', '66.10']
F1 Score: ['93.98', '79.59']
mean IoU: 77.38
No improvement for 7 epoch.
Epoch  51/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1295


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1321
global correct: 90.87
average row correct: ['96.20', '74.59']
Recall: ['92.05', '86.52']
IoU: ['88.82', '66.82']
F1 Score: ['94.08', '80.11']
mean IoU: 77.82
No improvement for 8 epoch.
Epoch  52/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.2101


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1399
global correct: 90.69
average row correct: ['96.26', '73.64']
Recall: ['91.78', '86.58']
IoU: ['88.62', '66.10']
F1 Score: ['93.97', '79.59']
mean IoU: 77.36
No improvement for 9 epoch.
Epoch  53/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1671


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1480
global correct: 90.77
average row correct: ['96.30', '73.85']
Recall: ['91.84', '86.72']
IoU: ['88.71', '66.34']
F1 Score: ['94.02', '79.77']
mean IoU: 77.53
No improvement for 10 epoch.
Epoch  54/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1010


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1233
global correct: 90.91
average row correct: ['95.93', '75.58']
Recall: ['92.31', '85.85']
IoU: ['88.83', '67.21']
F1 Score: ['94.08', '80.39']
mean IoU: 78.02
No improvement for 11 epoch.
Epoch  55/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1149


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1403
global correct: 90.82
average row correct: ['96.08', '74.74']
Recall: ['92.08', '86.18']
IoU: ['88.75', '66.74']
F1 Score: ['94.04', '80.05']
mean IoU: 77.74
No improvement for 12 epoch.
Epoch  56/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1120


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1408
global correct: 90.74
average row correct: ['95.99', '74.70']
Recall: ['92.06', '85.91']
IoU: ['88.65', '66.54']
F1 Score: ['93.99', '79.91']
mean IoU: 77.60
No improvement for 13 epoch.
Epoch  57/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1044


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1261
global correct: 90.86
average row correct: ['95.85', '75.62']
Recall: ['92.32', '85.63']
IoU: ['88.77', '67.10']
F1 Score: ['94.05', '80.31']
mean IoU: 77.94
No improvement for 14 epoch.
Epoch  58/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1260


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1384
global correct: 90.76
average row correct: ['96.10', '74.45']
Recall: ['92.00', '86.20']
IoU: ['88.69', '66.52']
F1 Score: ['94.01', '79.90']
mean IoU: 77.60
No improvement for 15 epoch.
Epoch  59/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1188


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1479
global correct: 90.71
average row correct: ['96.19', '73.97']
Recall: ['91.87', '86.40']
IoU: ['88.64', '66.26']
F1 Score: ['93.98', '79.70']
mean IoU: 77.45
No improvement for 16 epoch.
Epoch  60/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1950


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1304
global correct: 90.90
average row correct: ['96.02', '75.25']
Recall: ['92.22', '86.08']
IoU: ['88.83', '67.09']
F1 Score: ['94.08', '80.30']
mean IoU: 77.96
No improvement for 17 epoch.
Epoch  61/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1873


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1205
global correct: 90.83
average row correct: ['96.12', '74.68']
Recall: ['92.07', '86.29']
IoU: ['88.77', '66.76']
F1 Score: ['94.05', '80.07']
mean IoU: 77.76
No improvement for 18 epoch.
Epoch  62/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.2140


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1379
global correct: 90.81
average row correct: ['96.18', '74.40']
Recall: ['91.99', '86.44']
IoU: ['88.75', '66.62']
F1 Score: ['94.04', '79.97']
mean IoU: 77.69
No improvement for 19 epoch.
Epoch  63/100:


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Training score:
  loss	: 1.1052


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.1527
global correct: 90.78
average row correct: ['96.27', '74.03']
Recall: ['91.89', '86.64']
IoU: ['88.73', '66.44']
F1 Score: ['94.03', '79.84']
mean IoU: 77.58
Early stopping


In [21]:
import pickle

# Snapshot of the trainer bookkeeping that should outlive this kernel session.
# The keys are trainer attribute names; the "Load previous" cell copies them
# back onto a freshly built trainer with setattr().
_saved_attrs = ('_results', 'best_epoch', 'best_loss', 'loss')
trainer_result = {attr: getattr(trainer, attr) for attr in _saved_attrs}

with open(f'{RESULT_DIR}/trainer_result.pickle', 'wb') as pickle_file:
    pickle.dump(trainer_result, pickle_file)

## ⭐ Evaluation

### ⏭ Test score for all saved checkpoints.

In [None]:
import re

# Evaluate every saved checkpoint in RESULT_DIR on the test loader.
#
# os.listdir() returns entries in arbitrary order, and the former substring
# test ('.pt' in path) also matched names such as 'x.pth' or 'x.pt.bak'.
# Keep only '*.pt' files and sort them naturally, so that 'checkpoint2.pt'
# is evaluated before 'checkpoint10.pt'.
checkpoint_paths = sorted(
    (path for path in os.listdir(RESULT_DIR) if path.endswith('.pt')),
    key=lambda name: [
        int(token) if token.isdecimal() else token
        for token in re.split(r'(\d+)', name)
    ],
)
if not checkpoint_paths:
  print(f"No '.pt' checkpoint found in {RESULT_DIR}")

# Test result of each checkpoint, keyed by checkpoint file name, so the
# scores can still be compared after the loop (``result`` alone only keeps
# the value of the last iteration).
checkpoint_results = {}
for checkpoint_path in checkpoint_paths:
  state_dict_path = f"{RESULT_DIR}/{checkpoint_path}"
  state_dict = torch.load(state_dict_path, map_location=torch.device(device))
  # strict=False: keys that do not match the model are ignored instead of raising.
  model.load_state_dict(state_dict, strict=False)
  model.eval()
  print("\nLoad model at: ", state_dict_path)
  # NOTE(review): the "Load previous" cell passes ``val_dataset`` as the third
  # positional argument of this trainer; here it is omitted. Confirm which
  # signature trainer.py currently expects before running this cell.
  trainer = CryoEMTrainerWithScheduler(model, train_dataset, criterion, optimizer, device,
                num_classes=NUM_CLASSES,
                lr_scheduler=scheduler, patience=ES_PATIENCE)
  result = trainer.evaluate(test_loader)
  checkpoint_results[checkpoint_path] = result


Load model at:  /content/drive/MyDrive/research_xs/final/fcn_resnet101/fcn_resnet101_B7_LR-5_CS512/Crf_model_convcrf_with_F/checkpoint1.pt
  loss : 31.6959
  Average Precision : 0.1185
global correct: 86.98
average row correct: ['95.31', '58.08']
Recall: ['88.75', '78.13']
IoU: ['85.04', '49.96']
F1 Score: ['91.91', '66.63']
mean IoU: 67.50

Load model at:  /content/drive/MyDrive/research_xs/final/fcn_resnet101/fcn_resnet101_B7_LR-5_CS512/Crf_model_convcrf_with_F/checkpoint2.pt
  loss : 31.4424
  Average Precision : 0.1204
global correct: 87.09
average row correct: ['95.18', '59.03']
Recall: ['88.96', '77.91']
IoU: ['85.12', '50.57']
F1 Score: ['91.96', '67.17']
mean IoU: 67.85

Load model at:  /content/drive/MyDrive/research_xs/final/fcn_resnet101/fcn_resnet101_B7_LR-5_CS512/Crf_model_convcrf_with_F/checkpoint4.pt
  loss : 31.3555
  Average Precision : 0.1209
global correct: 87.11
average row correct: ['95.14', '59.28']
Recall: ['89.02', '77.85']
IoU: ['85.14', '50.72']
F1 Score: ['9

### ⏭ Load previous checkpoint

In [22]:
# @markdown Set show_checkpoint to True to check every checkpoint.
# @markdown If checkpoint_filename is "" (empty), the last (highest-numbered) checkpoint will be used.
show_checkpoint = False # @param {type:"boolean"}
checkpoint_filename = "" # @param {type:"string"}

import pickle
import re

# os.listdir() gives no ordering guarantee, so "the last checkpoint" is only
# well defined after sorting. A natural sort compares the numeric parts as
# integers ('checkpoint9.pt' < 'checkpoint43.pt'); only real '*.pt' files are kept.
checkpoint_paths = sorted(
    (path for path in os.listdir(RESULT_DIR) if path.endswith('.pt')),
    key=lambda name: [
        int(token) if token.isdecimal() else token
        for token in re.split(r'(\d+)', name)
    ],
)
if show_checkpoint:
  print(checkpoint_paths)
if str(checkpoint_filename) == "":
  if not checkpoint_paths:
    # Fail with an explicit message instead of an IndexError on an empty list.
    raise FileNotFoundError(f"No '.pt' checkpoint found in {RESULT_DIR}")
  checkpoint_filename = checkpoint_paths[-1]

state_dict_path = f"{RESULT_DIR}/{checkpoint_filename}"
state_dict = torch.load(state_dict_path, map_location=torch.device(device))
# strict=False: keys that do not match the model are ignored instead of raising.
model.load_state_dict(state_dict, strict=False)
model.eval()
print("Load model at: ", state_dict_path)

# Best effort: rebuild a trainer and restore the bookkeeping saved after
# training. A failure here must not undo the model loading above, but it is
# reported instead of being silently swallowed.
try:
  # Run if changing trainer.
  # !rm __pycache__/*
  # !cp /content/drive/MyDrive/research/src/EM_project/trainer.py /content/trainer.py
  # from trainer import CryoEMTrainerWithScheduler
  Trainer = tqdm_plugin_for_Trainer(CryoEMTrainerWithScheduler)
  trainer = Trainer(model, train_dataset, val_dataset, criterion, optimizer, device,
                num_classes = NUM_CLASSES,
                lr_scheduler=scheduler, patience=ES_PATIENCE)
  with open(f'{RESULT_DIR}/trainer_result.pickle', 'rb') as f:
    # SECURITY: pickle.load can execute arbitrary code; only load a file that
    # was written by this notebook's own training run.
    trainer_result = pickle.load(f)
  # Copy every saved entry back onto the trainer as an attribute of the same name.
  for key in trainer_result:
    setattr(trainer, key, trainer_result[key])
except Exception as err:
  print(f"Trainer state was not restored: {err!r}")

Load model at:  /content/drive/MyDrive/research_xs/final/fcn_resnet101/convcrf_with_F/checkpoint43.pt


### ✅ Testing

In [23]:
import gc
# Free memory before the evaluation cells: first let Python collect
# unreachable objects ...
gc.collect()
# ... then ask PyTorch's CUDA caching allocator to release cached blocks that
# are no longer occupied by live tensors.
torch.cuda.empty_cache()

In [24]:
# Score the currently loaded weights on the validation split and then on the
# test split, using an evaluator configured with the 'iou' metric.
evaluator = CryoEMEvaluator(
    model=model,
    device=device,
    metrics=['iou'],
    num_classes=NUM_CLASSES,
)
for split_title, split_loader in (
    ("FCN validation result:", val_loader),
    ("FCN test result:", test_loader),
):
  print(split_title)
  # ``result`` is overwritten per split, so it ends up holding the test-split value.
  result = evaluator.evaluate(loader=split_loader)

FCN validation result:
global correct: 90.97
average row correct: ['95.85', '76.04']
Recall: ['92.44', '85.71']
IoU: ['88.89', '67.48']
F1 Score: ['94.12', '80.59']
mean IoU: 78.19
FCN test result:
global correct: 89.99
average row correct: ['95.35', '71.38']
Recall: ['92.04', '81.56']
IoU: ['88.08', '61.46']
F1 Score: ['93.66', '76.13']
mean IoU: 74.77


In [25]:
from torchvision.utils import save_image
from dataset import reconstruct_patched

# Write one predicted mask per test item to RESULT_DIR/test_image as PNG.
#
# os.makedirs(..., exist_ok=True) replaces `!mkdir`: it does not complain when
# the cell is re-run, creates missing parent folders, and is not affected by
# spaces or shell metacharacters in RESULT_DIR.
test_image_dir = os.path.join(RESULT_DIR, "test_image")
os.makedirs(test_image_dir, exist_ok=True)

model.eval()
with torch.no_grad():
  # Each dataset item unpacks into four values; only the first (the input)
  # and the third (``grid``, handed to reconstruct_patched) are used here.
  for idx, (test_image, _, grid, _) in enumerate(test_dataset):
    inputs = test_image.to(device)
    outputs = model(inputs)['out']
    # argmax over dim 1 — presumably the class dimension; TODO confirm.
    preds = outputs.argmax(dim=1).cpu().detach()
    # Output file name: the dataset file name with its extension replaced by .png.
    filename = f"{os.path.splitext(test_dataset.filenames[idx])[0]}.png"
    pred_path = os.path.join(test_image_dir, filename)
    # save_image is given a float tensor, hence the cast of the reconstructed prediction.
    save_image(reconstruct_patched(preds, grid).float(), pred_path)