Skip to content

Commit

Permalink
build: test for seaplane
Browse files Browse the repository at this point in the history
  • Loading branch information
beiyuouo committed May 1, 2022
1 parent 449442a commit 7b7729e
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/config.py
Expand Up @@ -34,7 +34,7 @@ class config:
label_file_path = os.path.join(labeled_data_path, 'label.txt')

lr = 0.005
epochs = 100
epochs = 20
batch_size = 16
log_interval = 100

Expand Down
14 changes: 13 additions & 1 deletion src/main.py
Expand Up @@ -78,5 +78,17 @@ def test():
model = ResNetMini(3, 2)


def transfer_model():
model = ResNetMini(3, 2)
model.load_state_dict(torch.load(config.model_path))
model.eval()
torch.onnx.export(model,
torch.randn(1, 3, 64, 64),
config.model_onnx_path,
verbose=False,
export_params=True)


if __name__ == '__main__':
train()
train()
transfer_model()
49 changes: 49 additions & 0 deletions src/test_cv2.py
@@ -0,0 +1,49 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import cv2
import numpy as np
from config import config, mkdir


def val_cv2(img_path):
img_list = os.listdir(img_path)
# print(img_list)

model = cv2.dnn.readNetFromONNX(config.model_onnx_path)
acc = 0

for img_name in img_list:
# print(img_name)
img = cv2.imread(os.path.join(img_path, img_name))
img = cv2.resize(img, (64, 64))

blob = cv2.dnn.blobFromImage(img, 1 / 255.0, (64, 64), (0, 0, 0), swapRB=True, crop=False)
model.setInput(blob)
out = model.forward()
# print(out.shape)
# print(out)
label = np.argmax(out, axis=1)[0]
# print(label)
if label == 1:
acc += 1
# break

print(f'{img_path}\n err: {len(img_list)-acc} acc: {acc / len(img_list)}')


if __name__ == '__main__':
val_img_path = [
os.path.join(config.data_path, 'val', 'airplane in the sky flying left', 'yes'),
os.path.join(config.data_path, 'val', 'airplane in the sky flying left', 'bad'),
os.path.join(config.data_path, 'val', 'airplanes in the sky that are flying to the right',
'yes'),
os.path.join(config.data_path, 'val', 'airplanes in the sky that are flying to the right',
'bad'),
]
for val_img_path_ in val_img_path:
val_cv2(val_img_path_)

test_img_path = os.path.join(config.data_path, 'test')
for class_ in config.classes:
mkdir(os.path.join(test_img_path, class_), remove=True)

0 comments on commit 7b7729e

Please sign in to comment.