-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
146 lines (114 loc) · 5.7 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
from flask import Flask, render_template, request
import cv2
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
app = Flask(__name__)

# Where uploaded images are stored and which file extensions are accepted.
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    ALLOWED_EXTENSIONS={'jpg', 'jpeg', 'png'},
)

# Make sure the upload and output folders exist before any request arrives.
for folder in (app.config['UPLOAD_FOLDER'], 'static/outputs'):
    os.makedirs(folder, exist_ok=True)
# Pretrained weight files used by this app, keyed by the local file name
# each one is stored under in the working directory.
model_weights = {
    'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
    'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
    'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
    'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
    'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
    'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth',
}

# Download every weight file that is not available yet.
for weight_file, weight_url in model_weights.items():
    if os.path.exists(weight_file):
        continue
    try:
        # torch.hub handles the transfer itself, so no external `wget`
        # binary or shell is needed; the absolute path keeps the destination
        # in the current working directory, as before.
        torch.hub.download_url_to_file(weight_url, os.path.abspath(weight_file))
    except OSError as error:
        # Best effort, like the original shell call: report the failure and
        # carry on so the remaining files are still attempted.
        print(f'Failed to download {weight_file}: {error}')
# Background enhancer: a RealESRGANer wrapped around the compact SRVGG network.
model_path = 'realesr-general-x4v3.pth'
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
# Half precision is requested only when CUDA is available.
half = torch.cuda.is_available()
upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)
def allowed_file(filename):
    """Return True when the text after the last dot is an allowed extension."""
    _, dot, suffix = filename.rpartition('.')
    if not dot:
        # No dot at all, so there is no extension to check.
        return False
    return suffix.lower() in app.config['ALLOWED_EXTENSIONS']
# Maps each supported ``version`` value to the weight file and the ``arch``
# argument passed to GFPGANer for it.
_FACE_MODELS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
    'CodeFormer': ('CodeFormer.pth', 'CodeFormer'),
}


def enhance_image(image_path, version, scale):
    """Run face enhancement on an image and save the result to static/outputs.

    Args:
        image_path: Path of the image file to read.
        version: Which restorer to use; one of the keys of ``_FACE_MODELS``.
        scale: Requested scale factor; values above 4 are treated as 4.
            The enhancer is created with ``upscale=2``, and when ``scale``
            differs from 2 the enhanced image is resized to ``scale / 2``
            times the size of the (possibly pre-enlarged) input.

    Returns:
        ``(output_path, image_path)`` on success. ``(None, None)`` when the
        image cannot be read, is larger than 3500 pixels on either side,
        ``version`` is unknown, or any processing step fails.
    """
    filename = os.path.splitext(os.path.basename(image_path))[0]
    try:
        scale = min(scale, 4)  # avoid too large scale value

        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # imread signals an unreadable/unsupported file by returning None.
            print('Could not read image:', image_path)
            return None, None

        img_mode = None
        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        elif len(img.shape) == 2:  # for gray inputs
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        h, w = img.shape[0:2]
        if h > 3500 or w > 3500:
            print('Too large size')
            return None, None
        if h < 300:
            # Enlarge small inputs before they are handed to the enhancer.
            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        model_spec = _FACE_MODELS.get(version)
        if model_spec is None:
            print('Unknown version:', version)
            return None, None
        weights_path, arch = model_spec
        face_enhancer = GFPGANer(
            model_path=weights_path, upscale=2, arch=arch, channel_multiplier=2, bg_upsampler=upsampler)

        try:
            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        except RuntimeError as error:
            # Without an enhanced image there is nothing to resize or save.
            print('Error:', error)
            return None, None

        if scale != 2:
            try:
                interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
                h, w = img.shape[0:2]
                output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
            except Exception as error:
                # Deliberate best effort: on a bad scale the image is still
                # saved at the size the enhancer produced.
                print('Wrong scale input.', error)

        # RGBA images should be saved in png format
        extension = 'png' if img_mode == 'RGBA' else 'jpg'
        output_path = f'static/outputs/{filename}.{extension}'
        if not cv2.imwrite(output_path, output):
            print('Could not write output image:', output_path)
            return None, None
        return output_path, image_path
    except Exception as error:
        # Boundary handler: callers rely on (None, None) rather than an
        # exception for every kind of failure.
        print('Global exception:', error)
        return None, None
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the upload form and handle image-enhancement submissions.

    A GET request renders ``index.html``. A POST request validates the
    uploaded file and the ``version`` / ``scale`` form fields, saves the
    upload into the configured upload folder, runs ``enhance_image`` on it
    and renders ``result.html`` with the before/after image paths. Any
    validation or processing failure re-renders ``index.html`` with an
    ``error`` message.
    """
    if request.method != 'POST':
        return render_template('index.html')

    # Check if the POST request has the file part
    if 'file' not in request.files:
        return render_template('index.html', error='No file part')
    file = request.files['file']

    # Validate file
    if not file.filename:
        return render_template('index.html', error='No file selected')
    # The file name is client-controlled: keep only its last path component
    # (treating backslashes as separators too) so the upload cannot be
    # written outside the upload folder.
    safe_name = os.path.basename(file.filename.replace('\\', '/'))
    if not allowed_file(safe_name):
        return render_template('index.html', error='Invalid file type')

    # Validate form fields instead of letting a missing or malformed value
    # surface as an unhandled error.
    version = request.form.get('version')
    if version is None:
        return render_template('index.html', error='No version selected')
    try:
        scale = int(request.form.get('scale', ''))
    except ValueError:
        return render_template('index.html', error='Invalid scale')

    # Save the file
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], safe_name)
    file.save(file_path)

    # Perform image enhancement
    output_path, input_path = enhance_image(file_path, version, scale)
    if output_path is None:
        # enhance_image reports every failure as (None, None).
        return render_template('index.html', error='Image enhancement failed')

    # Render the result page with image paths
    return render_template('result.html', before_image_path=input_path, after_image_path=output_path)
if __name__ == '__main__':
    # Start Flask's built-in server on port 52525 with debug mode enabled.
    app.run(
        port=52525,
        debug=True,
    )