-
Notifications
You must be signed in to change notification settings - Fork 0
/
torch_csrnet.py
80 lines (67 loc) · 2.71 KB
/
torch_csrnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import PIL.Image as Image
from torchvision import transforms
import matplotlib.pyplot as plt
import random
import os
import torch
from torch_Model import CSRNet
def get_model_a(checkpoint_path=None):
    """Build a CSRNet and load the Part A checkpoint into it.

    Args:
        checkpoint_path: Checkpoint file to load. Defaults to
            ``./model/PartAmodel_best.pth.tar`` relative to the current working
            directory, built with ``os.path.join`` so it resolves on any OS
            (the previous hard-coded backslash path only worked on Windows).

    Returns:
        The CSRNet model with the checkpoint weights loaded, in eval mode.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join('.', 'model', 'PartAmodel_best.pth.tar')
    model = CSRNet()
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # The weights live under the 'state_dict' key of the loaded checkpoint.
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model
def get_model_b(checkpoint_path=None):
    """Build a CSRNet and load the Part B checkpoint into it.

    Args:
        checkpoint_path: Checkpoint file to load. Defaults to
            ``./model/PartBmodel_best.pth.tar`` relative to the current working
            directory, built with ``os.path.join`` so it resolves on any OS
            (the previous hard-coded backslash path only worked on Windows).

    Returns:
        The CSRNet model with the checkpoint weights loaded, in eval mode.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join('.', 'model', 'PartBmodel_best.pth.tar')
    model = CSRNet()
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # The weights live under the 'state_dict' key of the loaded checkpoint.
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model
# Module-level state shared by get_prediction(): the Part A model is loaded
# once, at import time, so importing this module needs that checkpoint on disk.
model = get_model_a()

# RGB preprocessing: ToTensor() turns the PIL image into a float tensor, then
# each channel is normalised with the mean/std below (the widely used ImageNet
# statistics -- presumably what the checkpoint was trained with; confirm).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def get_prediction(file, out_dir=None):
    """Run the module-level model on one image and save its output map.

    Args:
        file: Image path (or file-like object) accepted by ``PIL.Image.open``.
        out_dir: Directory the rendered map is written to. Defaults to
            ``./csr_rtn``; it is created if it does not exist.

    Returns:
        ``(prediction, density)``: ``prediction`` is the sum over the whole
        model output, truncated to ``int``; ``density`` is the path of the
        ``.jpg`` image the output map was saved to.
    """
    import uuid

    if out_dir is None:
        out_dir = os.path.join('.', 'csr_rtn')

    # Decode and preprocess inside `with` so the underlying file handle is closed.
    with Image.open(file) as pil_img:
        img = transform(pil_img.convert('RGB'))

    # Inference only: skip autograd bookkeeping (saves memory and time).
    with torch.no_grad():
        output = model(img.unsqueeze(0))

    # int() truncates toward zero, exactly as before.
    prediction = int(output.sum().item())

    os.makedirs(out_dir, exist_ok=True)
    # A uuid4 suffix avoids the collisions (and silent overwrites) a random
    # integer drawn from 1..100000 made likely after a few hundred calls.
    density = os.path.join(out_dir, 'density_map' + uuid.uuid4().hex + '.jpg')
    # [0, 0] selects the first channel of the single batch item, as the
    # original [0][0] did -- assumes a 4-D (N, C, H, W) output; confirm.
    plt.imsave(density, output[0, 0].cpu().numpy())
    return prediction, density
def torch_to_onnx(type):
    """Export the Part A or Part B CSRNet checkpoint to an ONNX file.

    Args:
        type: Which checkpoint to export, ``'a'`` or ``'b'``. It is passed
            through ``str()`` and lower-cased first, so ``'A'``/``'B'`` also
            work. (The name shadows the ``type`` builtin but is kept so
            keyword callers keep working.)

    Writes ``model_a.onnx`` or ``model_b.onnx`` into the current directory.

    Raises:
        ValueError: If ``type`` is neither ``'a'`` nor ``'b'`` (such values
            used to be ignored silently, exporting nothing).
    """
    type_lower = str(type).lower()
    # Validate before loading anything so a bad argument fails fast.
    if type_lower == 'a':
        load_model = get_model_a
    elif type_lower == 'b':
        load_model = get_model_b
    else:
        raise ValueError("torch_to_onnx: expected 'a' or 'b', got %r" % (type,))

    net = load_model()
    # The loaders already call eval(); repeated here so export never traces
    # training-mode behaviour even if a loader changes.
    net.eval()
    # Dummy input used only to trace the graph. Without dynamic_axes the
    # exported model's input shape is fixed to this (1, 3, 224, 224).
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(net, (dummy_input,), os.path.join('.', 'model_' + type_lower + '.onnx'))
def predict_directory(files_dir=os.path.join('.', 'csr_rtn', 'images_a')):
    """Run get_prediction() on every .jpg/.jpeg file in ``files_dir``.

    Batch-inference helper; it is not invoked by the ``__main__`` block below.

    Args:
        files_dir: Directory to scan (non-recursively). Defaults to
            ``./csr_rtn/images_a``. The extension match is case-insensitive.

    Returns:
        A list of ``(path, prediction, density)`` tuples in ``os.listdir``
        order. Each tuple is also printed as it is produced.
    """
    results = []
    for name in os.listdir(files_dir):
        if not name.lower().endswith(('.jpg', '.jpeg')):
            continue
        path = os.path.join(files_dir, name)
        prediction, density = get_prediction(os.path.abspath(path))
        print(path, prediction, density)
        results.append((path, prediction, density))
    return results


if __name__ == '__main__':
    # Export both checkpoints to ONNX. For batch inference over a folder of
    # images, call predict_directory() instead.
    torch_to_onnx('a')
    torch_to_onnx('b')