-
Notifications
You must be signed in to change notification settings - Fork 3
/
app.py
133 lines (113 loc) · 3.87 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import warnings
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import json
import os
import torch
from scipy.ndimage import gaussian_filter
import cv2
from method import AdaCLIP_Trainer
import numpy as np
############ Init Model
# Checkpoints for the three pre-trained variants selectable in the UI.
ckt_path1 = 'weights/pretrained_mvtec_colondb.pth'
ckt_path2 = 'weights/pretrained_visa_clinicdb.pth'
ckt_path3 = 'weights/pretrained_all.pth'

# Configurations
image_size = 518
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE: renamed from `model` — the original reused that name for both the
# backbone string and the trainer instance below, shadowing the first.
backbone = 'ViT-L-14-336'
prompting_depth = 4
prompting_length = 5
prompting_type = 'SD'
prompting_branch = 'VL'
use_hsf = True
k_clusters = 20

config_path = os.path.join('./model_configs', f'{backbone}.json')

# Prepare model: load the CLIP backbone's architecture description.
with open(config_path, 'r') as f:
    model_configs = json.load(f)

# Set up the feature hierarchy: tap intermediate features at the four
# quarter-depth stages of the vision transformer.
n_layers = model_configs['vision_cfg']['layers']
substage = n_layers // 4
features_list = [substage, substage * 2, substage * 3, substage * 4]

model = AdaCLIP_Trainer(
    backbone=backbone,
    feat_list=features_list,
    input_dim=model_configs['vision_cfg']['width'],
    output_dim=model_configs['embed_dim'],
    learning_rate=0.,  # inference only — no training
    device=device,
    image_size=image_size,
    prompting_depth=prompting_depth,
    prompting_length=prompting_length,
    prompting_branch=prompting_branch,
    prompting_type=prompting_type,
    use_hsf=use_hsf,
    k_clusters=k_clusters
).to(device)
def process_image(image, text, options):
    """Run zero-shot anomaly detection on one image.

    Args:
        image: input PIL image (any mode; converted to RGB).
        text: class name used as the text prompt.
        options: selected pre-trained dataset label from the UI radio.

    Returns:
        (overlay, score_str): a PIL image blending a JET heat-map with the
        input, and the anomaly score formatted to three decimals.
    """
    # Select the checkpoint matching the chosen pre-trained datasets.
    if 'MVTec AD+Colondb' in options:
        model.load(ckt_path1)
    elif 'VisA+Clinicdb' in options:
        model.load(ckt_path2)
    elif 'All' in options:
        model.load(ckt_path3)
    else:
        # Default to 'All' if no valid option is provided
        model.load(ckt_path3)
        print('Invalid option. Defaulting to All.')

    # Normalise to RGB, then keep a resized BGR copy for OpenCV blending.
    rgb = image.convert('RGB')
    bgr = cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR)
    bgr = cv2.resize(bgr, (image_size, image_size))

    # Preprocess the RGB image and run the model.
    batch = model.preprocess(rgb).unsqueeze(0).to(model.device)
    with torch.no_grad():
        anomaly_map, anomaly_score = model.clip_model(batch, [text], aggregation=True)

    # Smooth the map and scale it to 8-bit for colour-mapping.
    # (assumes the raw map lies in [0, 1] — TODO confirm against clip_model)
    amap = gaussian_filter(anomaly_map[0, :, :].cpu().numpy(), sigma=4)
    amap = (amap * 255).astype(np.uint8)
    score = anomaly_score[0].cpu().numpy()

    # JET heat-map blended 50/50 with the resized input, back to RGB for Gradio.
    heat = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(heat, 0.5, bgr, 0.5, 0)
    overlay_pil = Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    return overlay_pil, f'{score:.3f}'
# Worked examples shown beneath the interface: [image path, class name, dataset].
examples = [
    ["asset/img.png", "candle", "MVTec AD+Colondb"],
    ["asset/img2.png", "bottle", "VisA+Clinicdb"],
    ["asset/img3.png", "button", "All"],
]

# Gradio interface layout
input_widgets = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(label="Class Name"),
    gr.Radio(
        ["MVTec AD+Colondb", "VisA+Clinicdb", "All"],
        label="Pre-trained Datasets",
    ),
]
output_widgets = [
    gr.Image(type="pil", label="Output Image"),
    gr.Textbox(label="Anomaly Score"),
]

demo = gr.Interface(
    fn=process_image,
    inputs=input_widgets,
    outputs=output_widgets,
    examples=examples,
    title="AdaCLIP -- Zero-shot Anomaly Detection",
    description="Upload an image, enter class name, and select pre-trained datasets to do zero-shot anomaly detection",
)

# Launch the demo
demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=10002)