### ConvCRF
[Convolutional CRFs for Semantic Segmentation](https://doi.org/10.48550/arXiv.1805.04777)  
Marvin T. T. Teichmann and Roberto Cipolla

In [None]:
# pip install git+https://github.com/lucasb-eyer/pydensecrf.git

### CryoParticleSegment

In [None]:
#%%bash
#git clone https://cyanazuki@github.com/cyanazuki/CryoParticleSegment.git
#cd CryoParticleSegment
#python setup.py install

## ⭐ Setup
You must run every cell in this section.

### ✅ Directory Settings

In [None]:
# @title  { display-mode: "form" }
# Colab form: each `# @param` annotation exposes the constant as an editable field.

# Root of the micrograph data: later cells read the `train/`, `val/`, `test/` sub-folders and the `*_filenames.txt` lists from it.
IMAGE_DIR = "/content/drive/MyDrive/research_xs/processed_micrographs_np" # @param {type:"string"}
# Directory handed to every dataset as `label_dir`.
LABEL_DIR = "/content/drive/MyDrive/research_xs/ground_truth_mask" # @param {type:"string"}
# Output folder: receives checkpoints (`ckpt_dir`), `trainer_result.pickle` and the `test_image/` predictions.
RESULT_DIR = "/content/drive/MyDrive/research_xs/final/deeplabv3_resnet50/convcrf_with_F" # @param {type:"string"}

In [None]:
# @title  { display-mode: "form" }
# @markdown Mount Google Drive and copy **`IMAGE_DIR`** / **`LABEL_DIR`** to local disk when they live under `/content`📁.

def _under_content(path):
  """Return True when `content` is among the first three '/'-separated parts of `path`."""
  return "content" in path.split("/")[:3]

# Local (non-Drive) copies that the rest of the notebook reads from.
LOCAL_IMAGE_DIR = "/content/image_dir"
LOCAL_LABEL_DIR = "/content/label_dir"

if _under_content(IMAGE_DIR) or _under_content(LABEL_DIR):
  try:
    from google.colab import drive
  except ImportError:
    # Not running on Colab: keep the directories exactly as configured.
    print("google.colab is not available; using IMAGE_DIR and LABEL_DIR as configured.")
  else:
    import shutil

    # Mount or copy failures are no longer swallowed: a silent failure here only
    # resurfaces later as a confusing "file not found" in the dataset cells.
    drive.mount('/content/drive')
    # Remove Colab's bundled sample data; a no-op when it is already gone.
    shutil.rmtree("/content/sample_data", ignore_errors=True)
    # The `!=` guards keep the cell re-runnable: once a directory already points
    # at its local copy there is nothing left to copy.
    if _under_content(IMAGE_DIR) and IMAGE_DIR != LOCAL_IMAGE_DIR:
      shutil.copytree(IMAGE_DIR, LOCAL_IMAGE_DIR, dirs_exist_ok=True)
      IMAGE_DIR = LOCAL_IMAGE_DIR
    if _under_content(LABEL_DIR) and LABEL_DIR != LOCAL_LABEL_DIR:
      shutil.copytree(LABEL_DIR, LOCAL_LABEL_DIR, dirs_exist_ok=True)
      LABEL_DIR = LOCAL_LABEL_DIR

Mounted at /content/drive


In [None]:
# @title  { display-mode: "form" }
# @markdown Source code directory.
SRC_DIR = "/content/drive/MyDrive/research/src" # @param {type:"string"}

if True:
  !cp -r {SRC_DIR}/EM_project/*.py /content/
else:
  !cp {SRC_DIR}/EM_project/convcrf.py /content/convcrf.py
  !cp {SRC_DIR}/EM_project/dataset.py /content/dataset.py
  !cp {SRC_DIR}/EM_project/lr_scheduler.py /content/lr_scheduler.py
  !cp {SRC_DIR}/EM_project/metrics.py /content/metrics.py
  !cp {SRC_DIR}/EM_project/model.py /content/model.py
  !cp {SRC_DIR}/EM_project/trainer.py /content/trainer.py
  !cp {SRC_DIR}/EM_project/utils.py /content/utils.py
# !rm /content/convcrf.py

### ✅ Packages Handling

In [None]:
# @title  { display-mode: "form" }
# @markdown Useful packages.

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# @title  { display-mode: "form" }
# @markdown User-defined packages.

from dataset import MicrographDataset, MicrographDatasetEvery
from dataset import reconstruct_patched
from model import create_model
from trainer import CryoEMEvaluator
from trainer import CryoEMTrainerWithScheduler, tqdm_plugin_for_Trainer

## ⭐ Main

### ✅ Setting

In [None]:
# @markdown Parameters.

# Number of output classes: sizes the replaced classifier heads and is passed to the CRF model, trainer and evaluator.
NUM_CLASSES = 2
# Epoch count passed to `trainer.train`.
EPOCHS = 300
# `batch_size` passed to `trainer.train`.
BATCH = 2
# `crop_size` for all three datasets and `shape` for `create_crf_model`.
CROP_SIZE = (1024, 1024)
# Learning rate of the Adam optimizer.
LR = 1e-5
# `patience` of the ReduceLROnPlateau scheduler.
RLR_PATIENCE = 3
# `patience` handed to the trainer (presumably the early-stopping patience — confirm in trainer.py).
ES_PATIENCE = 20
# CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# @markdown Set seed.

import random

random_state = 42
# Seed every RNG the pipeline could draw from, not only PyTorch: the dataset and
# trainer modules are not visible here and may use Python's `random` or NumPy
# (e.g. for random crops or shuffling).
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)
torch.cuda.manual_seed_all(random_state)

### ✅ Dataset

In [None]:
# Training split: images under IMAGE_DIR/train, file list in IMAGE_DIR/train_filenames.txt.
train_dir = os.path.join(IMAGE_DIR, 'train')
# ndmin=1 keeps the result a 1-D array even when the list file holds a single
# name (np.loadtxt would otherwise return a 0-d array that has no length).
train_filenames = np.loadtxt(
    os.path.join(IMAGE_DIR, 'train_filenames.txt'), dtype=str, ndmin=1)
train_dataset = MicrographDataset(
    image_dir=train_dir, label_dir=LABEL_DIR,
    filenames=train_filenames, crop_size=CROP_SIZE)

In [None]:
# Validation split: images under IMAGE_DIR/val, file list in IMAGE_DIR/val_filenames.txt.
val_dir = os.path.join(IMAGE_DIR, 'val')
# ndmin=1 keeps the result a 1-D array even when the list file holds a single
# name (np.loadtxt would otherwise return a 0-d array that has no length).
val_filenames = np.loadtxt(
    os.path.join(IMAGE_DIR, 'val_filenames.txt'), dtype=str, ndmin=1)
val_dataset = MicrographDatasetEvery(
    image_dir=val_dir, label_dir=LABEL_DIR,
    filenames=val_filenames, crop_size=CROP_SIZE)
# batch_size=None disables the DataLoader's automatic batching, so each dataset
# item is yielded as-is.
val_loader = DataLoader(val_dataset, batch_size=None, shuffle=False, pin_memory=True)

In [None]:
# Test split: images under IMAGE_DIR/test, file list in IMAGE_DIR/test_filenames.txt.
test_dir = os.path.join(IMAGE_DIR, 'test')
# ndmin=1 keeps the result a 1-D array even when the list file holds a single
# name (np.loadtxt would otherwise return a 0-d array that has no length).
test_filenames = np.loadtxt(
    os.path.join(IMAGE_DIR, 'test_filenames.txt'), dtype=str, ndmin=1)
test_dataset = MicrographDatasetEvery(
    image_dir=test_dir, label_dir=LABEL_DIR,
    filenames=test_filenames, crop_size=CROP_SIZE)
# batch_size=None disables the DataLoader's automatic batching, so each dataset
# item is yielded as-is.
test_loader = DataLoader(test_dataset, batch_size=None, shuffle=False, pin_memory=True)

## ⭐ ConvCRF with FCN fine-tuned on cryo-EM

### ✅ Model

In [None]:
# Pretrained DeepLabV3-ResNet50 fetched through torch.hub, pinned to the v0.10.0 tag of torchvision.
backbone = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'deeplabv3_resnet50',
    pretrained=True,
)

# Replace the stem convolution with a freshly initialised one that takes a single input channel.
backbone.backbone.conv1 = torch.nn.Conv2d(
    in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)

# Replace layer 4 of the main head, then of the auxiliary head, with a 1x1
# convolution that emits NUM_CLASSES channels.
for head in (backbone.classifier, backbone.aux_classifier):
  head[4] = torch.nn.Conv2d(256, NUM_CLASSES, kernel_size=1, stride=1)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth
100%|██████████| 161M/161M [00:00<00:00, 186MB/s]


In [None]:
CHECKPOINT_PATH = "/content/drive/MyDrive/research_xs/final/deeplabv3_resnet50/checkpoint26_B2_L-3_C1024.pt" # @param {type:"string"}
state_dict_path = CHECKPOINT_PATH
# NOTE: torch.load unpickles the file, so only point this at checkpoints you trust.
state_dict = torch.load(state_dict_path, map_location=DEVICE)
# strict=False tolerates key mismatches; report them so a checkpoint that does
# not match the model cannot silently leave layers at their previous weights.
load_report = backbone.load_state_dict(state_dict, strict=False)
if load_report.missing_keys:
  print("Keys missing from checkpoint (weights left unchanged): ", load_report.missing_keys)
if load_report.unexpected_keys:
  print("Unexpected keys in checkpoint (ignored): ", load_report.unexpected_keys)
backbone.to(DEVICE)
backbone.eval()
print("Load model at: ", state_dict_path)

Load model at:  /content/drive/MyDrive/research_xs/final/deeplabv3_resnet50/checkpoint26_B2_L-3_C1024.pt


In [None]:
# Freeze the feature extractor (`backbone.backbone`): none of its parameters
# will receive gradients from here on.
for frozen_param in backbone.backbone.parameters():
    frozen_param.requires_grad = False

In [None]:
# Run a Python garbage-collection pass, then ask PyTorch to release its cached
# CUDA memory, before the CRF model is built.
import gc
gc.collect()
torch.cuda.empty_cache()

### ✅ Add convcrf

In [None]:
# @markdown config
# Settings dictionary handed to `create_crf_model` as `config` in the next cell.
# NOTE(review): the trailing comments (`# sym`, `# scalar`, `# 80,` ...) look like
# alternative or reference values for the same key — confirm against convcrf.py.
config = {
    'filter_size': 7,
    'blur': 4,
    'merge': True,
    'norm': 'none', # sym
    'trainable': False,
    'convcomp': False,

    'weight': 'vector', # scalar
    'unary_weight': 1,
    'weight_init': 0.2,
    'logsoftmax': True,
    'softmax': True,
    'final_softmax': False,

    # Nested settings under 'pos_feats' (presumably the position-feature kernel — confirm).
    'pos_feats': {
        'sdims': 3, # 3
        'compat': 3, # 3,
    },

    # Nested settings under 'col_feats' (presumably the colour/intensity-feature kernel — confirm).
    'col_feats': {
        'sdims': 3, # 80,
        'schan': 0.5, # 13
        'compat': 1, # 10
        'use_bias': False, # True
    },
    "trainable_bias": False,
    "pyinn": False
}

In [None]:
from model import create_crf_model

# Build the combined model from the loaded backbone and the CRF `config`;
# `shape` reuses the dataset crop size.
model = create_crf_model(
    backbone,
    config=config,
    shape=CROP_SIZE,
    num_classes=NUM_CLASSES,
    use_gpu=torch.cuda.is_available(),
)

In [None]:
# Cross-entropy loss with label smoothing of 0.1.
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
# Adam over every model parameter at learning rate LR.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
# Scheduler that lowers the learning rate after RLR_PATIENCE epochs without
# improvement and prints a message each time it does.
scheduler = ReduceLROnPlateau(
    optimizer=optimizer,
    patience=RLR_PATIENCE,
    verbose=True,
)

### ✅ Training

In [None]:
# Wrap the trainer class with the tqdm plugin, then instantiate it.
Trainer = tqdm_plugin_for_Trainer(CryoEMTrainerWithScheduler)
trainer = Trainer(
    model, train_dataset, criterion, optimizer, DEVICE,
    num_classes=NUM_CLASSES,
    lr_scheduler=scheduler,
    patience=ES_PATIENCE,
)

# Train for up to EPOCHS epochs, validating on `val_loader` and writing
# checkpoints to RESULT_DIR.
trainer.train(
    EPOCHS,
    val_loader=val_loader,
    batch_size=BATCH,
    ckpt_dir=RESULT_DIR,
    random_state=random_state,
)

Epoch   1/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1493


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3228
global correct: 90.34
average row correct: ['97.35', '68.88']
Recall: ['90.53', '89.49']
IoU: ['88.36', '63.72']
F1 Score: ['93.82', '77.84']
mean IoU: 76.04
Saving model at /content/drive/MyDrive/research_xs/final/deeplabv3_resnet50/convcrf_with_F/checkpoint1.pt
Loss improve to 1.3228181733025446.
Epoch   2/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1581


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3797
global correct: 89.67
average row correct: ['97.50', '65.72']
Recall: ['89.69', '89.58']
IoU: ['87.67', '61.06']
F1 Score: ['93.43', '75.82']
mean IoU: 74.36
No improvement for 1 epoch.
Epoch   3/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1100


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3583
global correct: 89.58
average row correct: ['97.36', '65.77']
Recall: ['89.69', '89.09']
IoU: ['87.56', '60.86']
F1 Score: ['93.37', '75.67']
mean IoU: 74.21
No improvement for 2 epoch.
Epoch   4/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1397


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4074
global correct: 89.32
average row correct: ['97.40', '64.59']
Recall: ['89.37', '89.06']
IoU: ['87.29', '59.85']
F1 Score: ['93.22', '74.88']
mean IoU: 73.57
No improvement for 3 epoch.
Epoch   5/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0636


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3840
global correct: 89.41
average row correct: ['97.41', '64.95']
Recall: ['89.47', '89.15']
IoU: ['87.40', '60.19']
F1 Score: ['93.27', '75.15']
mean IoU: 73.79
Epoch 00005: reducing learning rate of group 0 to 1.0000e-06.
No improvement for 4 epoch.
Epoch   6/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0697


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3663
global correct: 89.34
average row correct: ['97.26', '65.10']
Recall: ['89.50', '88.62']
IoU: ['87.30', '60.08']
F1 Score: ['93.22', '75.06']
mean IoU: 73.69
No improvement for 5 epoch.
Epoch   7/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0682


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3187
global correct: 90.10
average row correct: ['97.13', '68.58']
Recall: ['90.43', '88.67']
IoU: ['88.08', '63.05']
F1 Score: ['93.66', '77.34']
mean IoU: 75.57
Saving model at /content/drive/MyDrive/research_xs/final/deeplabv3_resnet50/convcrf_with_F/checkpoint7.pt
Loss improve to 1.3186716967158847.
Epoch   8/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0864


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3531
global correct: 89.82
average row correct: ['97.24', '67.12']
Recall: ['90.04', '88.85']
IoU: ['87.80', '61.90']
F1 Score: ['93.50', '76.47']
mean IoU: 74.85
No improvement for 1 epoch.
Epoch   9/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1105


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4416
global correct: 88.80
average row correct: ['97.53', '62.11']
Recall: ['88.73', '89.16']
IoU: ['86.78', '57.75']
F1 Score: ['92.92', '73.22']
mean IoU: 72.26
No improvement for 2 epoch.
Epoch  10/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1074


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3738
global correct: 89.42
average row correct: ['97.31', '65.29']
Recall: ['89.55', '88.83']
IoU: ['87.39', '60.34']
F1 Score: ['93.27', '75.27']
mean IoU: 73.87
No improvement for 3 epoch.
Epoch  11/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0552


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4302
global correct: 88.59
average row correct: ['97.49', '61.37']
Recall: ['88.52', '88.88']
IoU: ['86.55', '56.99']
F1 Score: ['92.79', '72.60']
mean IoU: 71.77
Epoch 00011: reducing learning rate of group 0 to 1.0000e-07.
No improvement for 4 epoch.
Epoch  12/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0604


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3475
global correct: 89.86
average row correct: ['97.16', '67.53']
Recall: ['90.15', '88.62']
IoU: ['87.83', '62.14']
F1 Score: ['93.52', '76.65']
mean IoU: 74.99
No improvement for 5 epoch.
Epoch  13/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0292


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4071
global correct: 89.01
average row correct: ['97.42', '63.33']
Recall: ['89.04', '88.91']
IoU: ['86.98', '58.69']
F1 Score: ['93.04', '73.97']
mean IoU: 72.84
No improvement for 6 epoch.
Epoch  14/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0857


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.5903
global correct: 88.16
average row correct: ['97.83', '58.60']
Recall: ['87.84', '89.83']
IoU: ['86.16', '54.95']
F1 Score: ['92.57', '70.93']
mean IoU: 70.56
No improvement for 7 epoch.
Epoch  15/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1249


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.6040
global correct: 88.07
average row correct: ['97.69', '58.67']
Recall: ['87.84', '89.27']
IoU: ['86.06', '54.80']
F1 Score: ['92.51', '70.80']
mean IoU: 70.43
Epoch 00015: reducing learning rate of group 0 to 1.0000e-08.
No improvement for 8 epoch.
Epoch  16/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0736


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4652
global correct: 88.95
average row correct: ['97.48', '62.87']
Recall: ['88.92', '89.10']
IoU: ['86.93', '58.38']
F1 Score: ['93.01', '73.72']
mean IoU: 72.65
No improvement for 9 epoch.
Epoch  17/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1081


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3619
global correct: 89.82
average row correct: ['97.25', '67.12']
Recall: ['90.04', '88.86']
IoU: ['87.80', '61.91']
F1 Score: ['93.51', '76.48']
mean IoU: 74.86
No improvement for 10 epoch.
Epoch  18/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0949


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4492
global correct: 89.03
average row correct: ['97.25', '63.93']
Recall: ['89.18', '88.37']
IoU: ['86.98', '58.97']
F1 Score: ['93.04', '74.19']
mean IoU: 72.98
No improvement for 11 epoch.
Epoch  19/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0991


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.6262
global correct: 87.93
average row correct: ['97.67', '58.17']
Recall: ['87.71', '89.08']
IoU: ['85.91', '54.30']
F1 Score: ['92.42', '70.39']
mean IoU: 70.11
No improvement for 12 epoch.
Epoch  20/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0673


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3568
global correct: 89.56
average row correct: ['97.20', '66.21']
Recall: ['89.79', '88.56']
IoU: ['87.53', '60.99']
F1 Score: ['93.35', '75.77']
mean IoU: 74.26
No improvement for 13 epoch.
Epoch  21/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1358


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3340
global correct: 90.09
average row correct: ['97.13', '68.58']
Recall: ['90.43', '88.65']
IoU: ['88.08', '63.04']
F1 Score: ['93.66', '77.33']
mean IoU: 75.56
No improvement for 14 epoch.
Epoch  22/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1133


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3677
global correct: 89.83
average row correct: ['97.16', '67.44']
Recall: ['90.12', '88.58']
IoU: ['87.80', '62.04']
F1 Score: ['93.51', '76.58']
mean IoU: 74.92
No improvement for 15 epoch.
Epoch  23/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1339


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4377
global correct: 89.17
average row correct: ['97.45', '63.85']
Recall: ['89.18', '89.12']
IoU: ['87.14', '59.23']
F1 Score: ['93.13', '74.40']
mean IoU: 73.19
No improvement for 16 epoch.
Epoch  24/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0955


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3452
global correct: 90.01
average row correct: ['97.16', '68.14']
Recall: ['90.31', '88.69']
IoU: ['87.99', '62.70']
F1 Score: ['93.61', '77.07']
mean IoU: 75.34
No improvement for 17 epoch.
Epoch  25/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0768


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3673
global correct: 89.43
average row correct: ['97.27', '65.47']
Recall: ['89.60', '88.70']
IoU: ['87.40', '60.43']
F1 Score: ['93.28', '75.34']
mean IoU: 73.91
No improvement for 18 epoch.
Epoch  26/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0848


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2617
global correct: 90.59
average row correct: ['96.77', '71.68']
Recall: ['91.26', '87.90']
IoU: ['88.57', '65.24']
F1 Score: ['93.94', '78.96']
mean IoU: 76.90
Saving model at /content/drive/MyDrive/research_xs/final/deeplabv3_resnet50/convcrf_with_F/checkpoint26.pt
Loss improve to 1.2617267370224.
Epoch  27/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1337


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3981
global correct: 89.65
average row correct: ['97.30', '66.26']
Recall: ['89.81', '88.93']
IoU: ['87.63', '61.21']
F1 Score: ['93.41', '75.94']
mean IoU: 74.42
No improvement for 1 epoch.
Epoch  28/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0878


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4516
global correct: 89.24
average row correct: ['97.59', '63.73']
Recall: ['89.16', '89.64']
IoU: ['87.24', '59.36']
F1 Score: ['93.19', '74.50']
mean IoU: 73.30
No improvement for 2 epoch.
Epoch  29/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1159


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3711
global correct: 89.78
average row correct: ['97.32', '66.74']
Recall: ['89.95', '89.07']
IoU: ['87.77', '61.69']
F1 Score: ['93.49', '76.31']
mean IoU: 74.73
No improvement for 3 epoch.
Epoch  30/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1190


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4619
global correct: 88.87
average row correct: ['97.56', '62.30']
Recall: ['88.78', '89.31']
IoU: ['86.85', '57.98']
F1 Score: ['92.96', '73.40']
mean IoU: 72.41
No improvement for 4 epoch.
Epoch  31/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0826


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3947
global correct: 89.46
average row correct: ['97.31', '65.47']
Recall: ['89.60', '88.86']
IoU: ['87.44', '60.50']
F1 Score: ['93.30', '75.39']
mean IoU: 73.97
No improvement for 5 epoch.
Epoch  32/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0850


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3485
global correct: 89.83
average row correct: ['97.26', '67.11']
Recall: ['90.04', '88.91']
IoU: ['87.81', '61.92']
F1 Score: ['93.51', '76.48']
mean IoU: 74.87
No improvement for 6 epoch.
Epoch  33/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1072


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3491
global correct: 89.88
average row correct: ['97.21', '67.49']
Recall: ['90.14', '88.78']
IoU: ['87.87', '62.19']
F1 Score: ['93.54', '76.69']
mean IoU: 75.03
No improvement for 7 epoch.
Epoch  34/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1069


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3750
global correct: 89.80
average row correct: ['97.20', '67.17']
Recall: ['90.05', '88.71']
IoU: ['87.78', '61.88']
F1 Score: ['93.49', '76.45']
mean IoU: 74.83
No improvement for 8 epoch.
Epoch  35/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0669


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.4890
global correct: 88.75
average row correct: ['97.55', '61.85']
Recall: ['88.66', '89.21']
IoU: ['86.73', '57.55']
F1 Score: ['92.89', '73.05']
mean IoU: 72.14
No improvement for 9 epoch.
Epoch  36/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0954


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3749
global correct: 89.69
average row correct: ['97.19', '66.73']
Recall: ['89.93', '88.61']
IoU: ['87.66', '61.46']
F1 Score: ['93.42', '76.13']
mean IoU: 74.56
No improvement for 10 epoch.
Epoch  37/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0882


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3940
global correct: 89.57
average row correct: ['97.42', '65.57']
Recall: ['89.64', '89.27']
IoU: ['87.56', '60.78']
F1 Score: ['93.37', '75.61']
mean IoU: 74.17
No improvement for 11 epoch.
Epoch  38/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1073


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3669
global correct: 89.74
average row correct: ['97.23', '66.85']
Recall: ['89.97', '88.76']
IoU: ['87.72', '61.64']
F1 Score: ['93.46', '76.27']
mean IoU: 74.68
No improvement for 12 epoch.
Epoch  39/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1197


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3466
global correct: 89.98
average row correct: ['97.19', '67.92']
Recall: ['90.26', '88.77']
IoU: ['87.96', '62.55']
F1 Score: ['93.59', '76.96']
mean IoU: 75.25
No improvement for 13 epoch.
Epoch  40/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1052


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3949
global correct: 89.36
average row correct: ['97.30', '65.10']
Recall: ['89.50', '88.75']
IoU: ['87.33', '60.13']
F1 Score: ['93.24', '75.10']
mean IoU: 73.73
No improvement for 14 epoch.
Epoch  41/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0609


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3032
global correct: 90.30
average row correct: ['96.99', '69.86']
Recall: ['90.77', '88.37']
IoU: ['88.29', '63.97']
F1 Score: ['93.78', '78.03']
mean IoU: 76.13
No improvement for 15 epoch.
Epoch  42/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0952


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.2988
global correct: 90.29
average row correct: ['96.91', '70.04']
Recall: ['90.82', '88.12']
IoU: ['88.26', '64.00']
F1 Score: ['93.77', '78.05']
mean IoU: 76.13
No improvement for 16 epoch.
Epoch  43/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0834


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3962
global correct: 89.75
average row correct: ['97.45', '66.19']
Recall: ['89.81', '89.48']
IoU: ['87.75', '61.41']
F1 Score: ['93.48', '76.09']
mean IoU: 74.58
No improvement for 17 epoch.
Epoch  44/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1437


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.5107
global correct: 88.34
average row correct: ['97.62', '59.98']
Recall: ['88.18', '89.18']
IoU: ['86.32', '55.92']
F1 Score: ['92.66', '71.73']
mean IoU: 71.12
No improvement for 18 epoch.
Epoch  45/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.1090


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3407
global correct: 90.09
average row correct: ['97.22', '68.31']
Recall: ['90.36', '88.94']
IoU: ['88.09', '62.96']
F1 Score: ['93.67', '77.27']
mean IoU: 75.52
No improvement for 19 epoch.
Epoch  46/300:


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Training score:
  loss	: 1.0982


Validation:   0%|          | 0/9 [00:00<?, ?it/s]

Validation score:
  loss : 1.3069
global correct: 90.16
average row correct: ['97.08', '69.00']
Recall: ['90.54', '88.55']
IoU: ['88.14', '63.35']
F1 Score: ['93.70', '77.56']
mean IoU: 75.75
Early stopping


In [None]:
import pickle

# Trainer attributes saved so a later session can restore them onto a new trainer.
saved_attributes = ('_results', 'best_epoch', 'best_loss', 'loss')
trainer_result = {name: getattr(trainer, name) for name in saved_attributes}

with open(f'{RESULT_DIR}/trainer_result.pickle', 'wb') as f:
    pickle.dump(trainer_result, f)

## ⭐ Evaluate

### ⏭ Test score for all saved checkpoints.

In [None]:
import re

def _checkpoint_sort_key(filename):
  """Order checkpoint files by the first integer in their name, then by name."""
  match = re.search(r"\d+", filename)
  return (int(match.group()) if match else -1, filename)

# os.listdir returns entries in arbitrary order and `'.pt' in path` is only a
# substring match; keep real `.pt` files and evaluate them in epoch order.
checkpoint_paths = sorted(
    (path for path in os.listdir(RESULT_DIR) if path.endswith('.pt')),
    key=_checkpoint_sort_key)
for checkpoint_path in checkpoint_paths:
  state_dict_path = f"{RESULT_DIR}/{checkpoint_path}"
  state_dict = torch.load(state_dict_path, map_location=DEVICE)
  model.load_state_dict(state_dict, strict=False)
  model.eval()
  print("\nLoad model at: ", state_dict_path)
  # A new trainer is built for each checkpoint only to reuse its `evaluate` method.
  trainer = CryoEMTrainerWithScheduler(model, train_dataset, criterion, optimizer, DEVICE,
                num_classes=NUM_CLASSES,
                lr_scheduler=scheduler, patience=ES_PATIENCE)
  result = trainer.evaluate(test_loader)

### ⏭ Load previous

In [None]:
# @markdown Set show_checkpoint to True to check every checkpoint.
# @markdown If checkpoint_filename is "" (empty), the last checkpoint will be used.
show_checkpoint = False # @param {type:"boolean"}
checkpoint_filename = "" # @param {type:"string"}

import pickle
import re

def _checkpoint_sort_key(filename):
  """Order checkpoint files by the first integer in their name, then by name."""
  match = re.search(r"\d+", filename)
  return (int(match.group()) if match else -1, filename)

# Sort by epoch number so "the last checkpoint" really is the latest one;
# os.listdir alone returns entries in arbitrary order.
checkpoint_paths = sorted(
    (path for path in os.listdir(RESULT_DIR) if path.endswith('.pt')),
    key=_checkpoint_sort_key)
if show_checkpoint:
  print(checkpoint_paths)
if str(checkpoint_filename) == "":
  if not checkpoint_paths:
    raise FileNotFoundError(f"No .pt checkpoint found in {RESULT_DIR}")
  checkpoint_filename = checkpoint_paths[-1]
state_dict_path = f"{RESULT_DIR}/{checkpoint_filename}"
state_dict = torch.load(state_dict_path, map_location=DEVICE)
model.load_state_dict(state_dict, strict=False)
model.eval()
print("Load model at: ", state_dict_path)

# Rebuild the trainer with the same positional arguments as the training cell.
# The previous call passed an extra `val_dataset`, and the resulting error was
# hidden by a bare `except`.
Trainer = tqdm_plugin_for_Trainer(CryoEMTrainerWithScheduler)
trainer = Trainer(model, train_dataset, criterion, optimizer, DEVICE,
              num_classes = NUM_CLASSES,
              lr_scheduler=scheduler, patience=ES_PATIENCE)

# Restoring the saved training history stays best-effort: a missing pickle is
# reported instead of aborting the cell.
trainer_result_path = f'{RESULT_DIR}/trainer_result.pickle'
try:
  # NOTE: pickle.load executes arbitrary code; only load files written by this notebook.
  with open(trainer_result_path, 'rb') as f:
    trainer_result = pickle.load(f)
except FileNotFoundError:
  print("No saved trainer result at: ", trainer_result_path)
else:
  for key in trainer_result:
    setattr(trainer, key, trainer_result[key])

Load model at:  /content/drive/MyDrive/research_xs/final/deeplabv3_resnet50/convcrf_with_F/checkpoint26.pt


### ✅ Testing

In [None]:
# Run a Python garbage-collection pass, then ask PyTorch to release its cached
# CUDA memory, before evaluation starts.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Evaluator for the currently loaded model, configured with the 'AP' and 'iou' metrics.
evaluator = CryoEMEvaluator(
    model=model,
    device=DEVICE,
    metrics=['AP', 'iou'],
    num_classes=NUM_CLASSES,
)

# Score the validation split first, then the test split; `result` ends up
# holding the test-split return value.
for split_name, split_loader in (("validation", val_loader), ("test", test_loader)):
  print(f"ConvCRF {split_name} result:")
  result = evaluator.evaluate(loader=split_loader)

ConvCRF validation result:
global correct: 90.59
average row correct: ['96.77', '71.68']
Recall: ['91.26', '87.90']
IoU: ['88.57', '65.24']
F1 Score: ['93.94', '78.96']
mean IoU: 76.90
ConvCRF test result:
global correct: 89.91
average row correct: ['95.55', '70.31']
Recall: ['91.78', '82.01']
IoU: ['88.02', '60.92']
F1 Score: ['93.63', '75.71']
mean IoU: 74.47


In [None]:
from torchvision.utils import save_image
from dataset import reconstruct_patched

# Create the output folder from Python: unlike `!mkdir`, this is re-runnable and
# copes with spaces in RESULT_DIR.
test_image_dir = os.path.join(RESULT_DIR, "test_image")
os.makedirs(test_image_dir, exist_ok=True)

model.eval()
with torch.no_grad():
  # Each dataset item is unpacked as (image, _, grid, _); only the image and the
  # grid are needed here. Presumably the image holds the patches of one
  # micrograph and `grid` their layout — confirm in dataset.py.
  for idx, (test_image, _, grid, _) in enumerate(test_dataset):
    inputs = test_image.to(DEVICE)
    outputs = model(inputs)['out']
    # Class index per pixel; no .detach() needed inside torch.no_grad().
    preds = outputs.argmax(dim=1).cpu()
    filename = f"{os.path.splitext(test_dataset.filenames[idx])[0]}.png"
    pred_path = os.path.join(test_image_dir, filename)
    # Reassemble the predictions with `grid` and save them as a float image.
    save_image(reconstruct_patched(preds, grid).float(), pred_path)