-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
78 lines (63 loc) · 2.95 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import torch
import time
import sys
import argparse
from datetime import timedelta
from hydra import compose, initialize
from torch import nn
from omegaconf import OmegaConf
from src.logging import logger
from src.models import get_model
from src.ICM_dataset import ICMDataset
from torch.utils.data import DataLoader
from src.training import valid_epoch, eval_uncertainty_model
def main(cfg) -> None:
    """Evaluate a classification model and its uncertainty on the validation split.

    Runs one validation epoch (accuracy / F1 / loss / confusion matrix), then an
    MC-dropout uncertainty evaluation, saving artifacts under ``logger.output_path``.

    Args:
        cfg: Hydra/OmegaConf config with ``base``, ``model``, ``training`` and
            ``uncertainty`` sections (see the project's ``config.yaml``).
    """
    # Report which device will be used.
    if torch.cuda.is_available() and cfg.base.device == "cuda":
        logger.info(f'Evaluating the model in {torch.cuda.get_device_name(torch.cuda.current_device())}')
    else:
        # Fixed message: this script evaluates, it does not train.
        logger.warn('CAREFUL!! Evaluating the model with CPU')

    # Build the model and move it to the target device.
    # NOTE(review): no checkpoint is loaded here — presumably get_model() restores
    # trained weights; confirm, otherwise this evaluates an untrained network.
    model = get_model(cfg.model.encoder)
    model = model.to(cfg.base.device)

    # Only the loss is needed for evaluation (no optimizer/scheduler).
    criterion = getattr(nn, cfg.training.loss)()

    # Load the evaluation dataset (the "valid" split, in eval mode).
    eval_dataset = ICMDataset(path=os.path.join(cfg.base.dataset, "valid"),
                              train=False,
                              species=cfg.base.classes)

    # The uncertainty evaluation requires per-sample processing.
    if cfg.uncertainty.eval_dataloader.batch_size != 1:
        logger.warn("The test batch size must be 1. Changing it to 1.")
        cfg.uncertainty.eval_dataloader.batch_size = 1
    # Use the imported DataLoader name for consistency with the file's imports.
    eval_loader = DataLoader(eval_dataset, **cfg.uncertainty.eval_dataloader)

    logger.info("===== Starting evaluation =====")
    start = time.time()

    # Standard metrics over the evaluation set.
    acc, f1, loss, cm = valid_epoch(model,
                                    eval_loader,
                                    criterion,
                                    log_step=cfg.training.log_step,
                                    epoch=0,
                                    wb_log=False,
                                    cls_names=cfg.base.classes,
                                    device=cfg.base.device)
    cm.savefig(os.path.join(logger.output_path, "test_confusion_matrix.jpg"))
    logger.info(f"Accuracy: {acc:.4f} | F1: {f1:.4f} | Loss: {loss:.4f}")

    # MC-dropout uncertainty estimation over the same loader.
    eval_uncertainty_model(model=model,
                           eval_loader=eval_loader,
                           mc_samples=cfg.uncertainty.mc_samples,
                           dropout_rate=cfg.uncertainty.dropout_rate,
                           num_classes=len(cfg.base.classes),
                           wb_log=False,
                           output_path=logger.output_path,
                           device=cfg.base.device)

    logger.info(f"===== Evaluation finished in {timedelta(seconds=round(time.time() - start))} =====")
if __name__ == "__main__":
    # CLI: only the Hydra config file name is configurable.
    parser = argparse.ArgumentParser(description="Evaluate a model following the instructions in the README file")
    parser.add_argument('--config', default='config.yaml')
    # parse_args() already defaults to sys.argv[1:]; passing it explicitly is redundant.
    args = parser.parse_args()

    # Compose the Hydra config from the ./config directory.
    # NOTE(review): job_name="training" in an evaluation script — confirm this is
    # intentional (it may affect Hydra's output directory naming).
    initialize(version_base=None, config_path="config", job_name="training")
    config = compose(config_name=args.config)
    config = OmegaConf.create(config)
    main(config)