forked from dsp-uga/team-linden-p2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TestTiramisu.py
77 lines (60 loc) · 2.96 KB
/
TestTiramisu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from src.tiramisu.datasets import cilia, joint_transforms
from src.tiramisu.utils import training
from src.tiramisu.models import tiramisu
import adabound
import torch
from torchvision import transforms
from torch.utils import data
from imageio import imwrite, imread
from pathlib import Path
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time
import argparse
def main(args):
    """Run a trained FCDenseNet103 (Tiramisu) model over the test set and save masks.

    Expected layout under ``args.rootDir``::

        <rootDir>/test/data/            test inputs, listed to name the outputs
        <rootDir>/weights/<targModel>   checkpoint dict with a 'state_dict' key
        <rootDir>/results/              created if missing; masks written here

    Args:
        args: parsed command-line namespace with ``rootDir`` (str) and
            ``targModel`` (str, checkpoint file name looked up under
            ``<rootDir>/weights/``). ``args.rootDir`` is normalized in place.

    Raises:
        FileNotFoundError: if the root directory, ``test/data`` or the model
            checkpoint does not exist. This is a subclass of ``Exception``, so
            callers that caught the previous bare ``Exception`` still work.
    """
    args.rootDir = os.path.normpath(args.rootDir)
    root = args.rootDir
    test_data_dir = os.path.join(root, "test", "data")
    results_dir = os.path.join(root, "results")
    weights_dir = os.path.join(root, "weights")
    # Plain concatenation on purpose: os.path.join would silently drop the
    # weights directory if targModel were an absolute path.
    model_path = root + "/weights/" + args.targModel

    # Validate the directory layout before doing any GPU work.
    if not os.path.exists(root):
        raise FileNotFoundError(f"ERROR: The dir '{root}' doesn't exist")
    if not os.path.exists(test_data_dir):
        raise FileNotFoundError(f"ERROR: The dir '{test_data_dir}' doesn't exist")
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"ERROR: The target model file '{model_path}' doesn't exist")

    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)
    # Point the training module's module-level output paths at this root.
    training.RESULTS_PATH = Path(results_dir)
    training.WEIGHTS_PATH = Path(weights_dir)

    test_cilia = cilia.Cilia(root, 'test')
    test_loader = data.DataLoader(test_cilia, batch_size=1, shuffle=False)

    # Load the target model.
    model = tiramisu.FCDenseNet103(n_classes=3, in_channels=1).cuda()
    model.load_state_dict(torch.load(model_path)['state_dict'])
    # Inference mode: disables dropout and uses running BatchNorm statistics.
    # This is idempotent, so it is harmless if get_test_pred also calls it.
    model.eval()

    # Output file names come from the sorted listing of test/data.
    # NOTE(review): assumes cilia.Cilia yields samples in this same sorted
    # order and count; confirm against the dataset class.
    test_names = sorted(os.listdir(test_data_dir))
    # No gradients are needed for prediction; this saves GPU memory.
    with torch.no_grad():
        for i, img in enumerate(test_loader):
            pred = training.get_test_pred(model, img)
            # First (only, batch_size=1) sample; assumes pred is (batch, H, W).
            pred_img = pred[0, :, :]
            imwrite(os.path.join(results_dir, test_names[i] + '.png'),
                    pred_img.numpy().astype(np.uint8))
    print('----Testing done successfully----')
    print('Masks have been created in ', results_dir)
if __name__ == '__main__':
    # Command-line entry point: parse the root directory and checkpoint name,
    # then run inference over <rootDir>/test/data.
    parser = argparse.ArgumentParser(description='This ' + \
        'is part of the UGA CSCI 8360 Project 2. Please visit our ' + \
        'GitHub project at https://github.com/dsp-uga/team-linden-p2 ' + \
        'for more information regarding data organization ' + \
        'expectations and examples on how to execute our scripts.')
    parser.add_argument('-r', '--rootDir', required=True,
                        help='The base directory storing files and ' + \
                        'directories conforming with organization ' + \
                        'expectations, please visit our GitHub website')
    parser.add_argument('-tm', '--targModel', required=True,
                        help='Checkpoint file name, looked up under ' + \
                        '<rootDir>/weights/, that defines the CNN weights')
    args = parser.parse_args()
    main(args)