In [1]:
import os
import threading
from io import BytesIO
from zipfile import ZipFile

import torch
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.models import load_model
from flask import Flask, request, jsonify, send_file, Response

from database import save_data_to_database, get_image_data, save_wigSynthesis
from subprocess_code import prepro_image, synthesis_image
from inference import Predictor, CELEBRITY_LABELS, ANIMAL_LABELS
from beauty_supplies import search_supplies

# Flask application object; every @app.route endpoint below registers on it.
app = Flask(__name__)

# Directory containing the Keras scalp-classification models (*.h5);
# passed to predict_image_with_models() by the /upload endpoint.
model_folder = "./PredictModels"

# 이미지 전처리 함수
def preprocess_image(img_path, size=(192, 192)):
    """Load an image and convert it into a normalized float32 array.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        size: Target ``(width, height)`` passed to ``Image.resize``;
            defaults to 192x192.

    Returns:
        ``numpy.ndarray`` of shape ``(height, width, 3)`` and dtype float32,
        with pixel values scaled into ``[0, 1]``.
    """
    # Image.open() keeps the underlying file open until the image is closed;
    # the context manager releases it once the pixels have been resized.
    with Image.open(img_path) as img:
        # Resize first, then force 3-channel RGB. The order is kept as-is
        # because swapping it can change pixel values for palette/alpha images.
        rgb = img.resize(size).convert("RGB")
    return np.asarray(rgb).astype('float32') / 255

# 모델을 사용한 이미지 분류 함수
def predict_image_with_models(image_path, model_folder):
    """Run every Keras scalp model in *model_folder* on one image and diagnose it.

    Each ``.h5`` file in *model_folder* is treated as one classifier. In sorted
    filename order, the i-th model is labelled with the i-th entry of
    ``scalp_category``. Its output vector is treated as probabilities for the
    classes "0".."3".

    Args:
        image_path: Path of the image file to classify.
        model_folder: Directory containing the ``.h5`` model files.

    Returns:
        Tuple ``(skin_diagnosis, alter_predict, predict1, ..., predict6)``:
        - ``skin_diagnosis``: the label from ``calculate_skin_diagnosis``.
        - ``alter_predict``: the per-model argmax class strings.
        - ``predict1`` .. ``predict6``: one list per category slot, holding
          the class probabilities in percent, rounded to 2 decimals. Slots
          without a model stay empty lists.

    Raises:
        ValueError: if the folder holds more ``.h5`` files than there are
            scalp categories.
    """
    categories = np.array(["0", "1", "2", "3"])
    scalp_category = ["미세각질", "피지과다", "모낭사이홍반", "모낭홍반/농포", "비듬", "탈모"]

    # os.listdir() returns entries in arbitrary order. The diagnosis rules
    # depend on each model's position, so sort for a stable model->slot mapping.
    model_files = sorted(f for f in os.listdir(model_folder) if f.endswith(".h5"))
    print(model_files)
    if len(model_files) > len(scalp_category):
        raise ValueError(
            f"Found {len(model_files)} .h5 models in {model_folder!r}, "
            f"but only {len(scalp_category)} scalp categories are defined"
        )

    alter_predict = []
    # One probability list per category slot (predict1..predict6).
    probabilities = [[] for _ in scalp_category]

    # Every model receives the same input, so preprocess it only once.
    batch = np.array([preprocess_image(image_path)])

    with tf.device("/device:CPU:0"):
        for i, model_file in enumerate(model_files):
            model_path = os.path.join(model_folder, model_file)
            print(model_path)
            model = load_model(model_path)

            class_probs = model.predict(batch)[0]

            print(f"Probabilities for model {model_file} ({scalp_category[i]}):")
            for class_index, probability in enumerate(class_probs):
                print(f'Class {class_index}: {probability * 100}%')
                probabilities[i].append(round(probability * 100, 2))
            print(probabilities[i])

            predicted_category = categories[np.argmax(class_probs)]
            print(f"모델 파일: {model_file}")
            print("예측된 카테고리(" + scalp_category[i] + "): " + predicted_category + "\n")

            alter_predict.append(predicted_category)

    print("alter_predict: ", alter_predict, "\n")

    # Combine the per-model classes into a single scalp-type label.
    skin_diagnosis = calculate_skin_diagnosis(alter_predict)

    return (skin_diagnosis, alter_predict, *probabilities)

def calculate_skin_diagnosis(alter_predict):
    """Map the per-model severity classes to a single scalp-type label.

    ``alter_predict`` is a sequence of class strings ("0".."3"); index i is
    the prediction of the i-th scalp model. Rules are tried in a fixed order
    and the first match wins:

    * 양호 (normal): every entry is "0" (also true for an empty input).
    * 건성 (dry): [0] == 3, [1] in 1..3, [4] in 1..3, no 3 in [2:4] or [5:].
    * 지성 (oily): [1] in 1..3 and no 3 in [2:].
    * 민감성 (sensitive): [0] in 0..3, [2] == 3, no 3 in [3:].
    * 지루성 (seborrheic): [0] == 3, [4] and [5] in 2..3, no 3 in [1:4].
    * 염증성 (inflammatory): [2] == 3 and [3] == 3.
    * 비듬성 (dandruff): [0] == 3, [1]..[4] each in 2..3, [5] != 3.
    * 탈모 (hair loss): [5] == 3 and no 3 in [:5].
    * 복합성 (combination): fallback when nothing above matches.

    NOTE(review): earlier inline comments described some rules differently
    from these checks. For example, oily was described as "[1] == 3", and
    inflammatory as also requiring the other entries to be != 3. Confirm
    which version is intended.

    Inputs shorter than six entries may raise IndexError. An entry that is
    compared numerically but is not an integer string raises ValueError.
    """
    p = alter_predict

    def is_three(idx):
        return p[idx] == '3'

    def within(idx, low, high):
        return low <= int(p[idx]) <= high

    def no_threes(values):
        return all(v != '3' for v in values)

    if all(v == '0' for v in p):
        return "양호"

    # (label, predicate) pairs are evaluated lazily and in order. `and` keeps
    # the same short-circuit (and therefore exception) order as an if-chain.
    rules = (
        ("건성", lambda: is_three(0) and within(1, 1, 3) and within(4, 1, 3)
                         and no_threes(p[2:4]) and no_threes(p[5:])),
        ("지성", lambda: within(1, 1, 3) and no_threes(p[2:])),
        ("민감성", lambda: within(0, 0, 3) and is_three(2) and no_threes(p[3:])),
        ("지루성", lambda: is_three(0) and within(4, 2, 3) and within(5, 2, 3)
                          and no_threes(p[1:4])),
        ("염증성", lambda: is_three(2) and is_three(3)),
        ("비듬성", lambda: is_three(0) and within(1, 2, 3) and within(2, 2, 3)
                          and within(3, 2, 3) and within(4, 2, 3) and p[5] != '3'),
        ("탈모", lambda: is_three(5) and no_threes(p[:5])),
    )
    for label, matches in rules:
        if matches():
            return label

    return "복합성"

@app.route('/upload', methods=['POST'])
def upload_image():
    """Receive a scalp photo, classify it with the scalp models and store it.

    Expects a multipart form with an ``image`` file and a ``userid`` field.
    The image is saved under ``./skin_upload`` and classified with
    ``predict_image_with_models``. The result is persisted via
    ``save_data_to_database`` and returned as JSON:
    - ``message``: the diagnosis label.
    - ``predict1`` .. ``predict6``: comma-separated integer percentages,
      one per class.

    NOTE(review): error responses are sent with HTTP 200 (unchanged);
    consider 400 once the client handles it.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image part'})

    file = request.files['image']
    userid = request.form['userid']

    # Never use the client-supplied filename as a path. Strip any directory
    # components (both separator styles) so "../../x" cannot escape upload_dir.
    filename = os.path.basename((file.filename or '').replace('\\', '/'))
    if filename in ('', '.', '..'):
        return jsonify({'error': 'No selected file'})

    upload_dir = './skin_upload'
    os.makedirs(upload_dir, exist_ok=True)
    # NOTE(review): uploads with the same filename overwrite each other.
    file_path = os.path.join(upload_dir, filename)
    file.save(file_path)

    # Run the scalp models and get the diagnosis plus per-slot probabilities.
    diagnosis_result, alter_predict, *predictions = predict_image_with_models(file_path, model_folder)
    print(diagnosis_result)
    print(*alter_predict)

    response_data = {'message': diagnosis_result}
    for slot, values in enumerate(predictions, start=1):
        joined = ','.join(str(int(round(x))) for x in values)
        print(joined)
        response_data[f'predict{slot}'] = joined

    # Persist the upload and its diagnosis.
    save_data_to_database(file_path, diagnosis_result, alter_predict, userid)
    print("Mysql 전송 완료")

    return jsonify(response_data)
    
# Module-level default wig index.
# NOTE(review): wear_wig() assigns its own local `selected_wig` (there is no
# `global` statement), so wear_wig never reads this value.
selected_wig = 0 # default

# wear_wig() reads and writes fixed, shared files (./unprocessed/source.jpg,
# ./images/src/source.png, ./images/res.jpg), and the server runs with
# threaded=True, so overlapping requests could clobber each other's images.
# All file-touching work in wear_wig() is serialized with this lock.
_wear_wig_lock = threading.Lock()


@app.route('/wear_wig', methods=['POST'])
def wear_wig():
    """Synthesize the uploaded face with the selected reference hairstyle.

    Form fields:
        image: the user's photo (multipart file).
        selectedWig: 0-based integer index; the reference image used is
            ``./images/ref/target{selectedWig + 1}.png``.
        userid: stored with the result via ``save_wigSynthesis``.

    Returns:
        The synthesized ``./images/res.jpg`` as ``image/jpeg``. Returns a JSON
        error with HTTP 400 if ``selectedWig`` is not an integer or has no
        matching reference image.
    """
    # CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so set it before
    # querying CUDA. It is also inherited by child processes (presumably the
    # align_face.py script run by prepro_image — confirm).
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print(torch.cuda.is_available())

    image = request.files['image']
    try:
        selected_wig = int(request.form['selectedWig'])
    except ValueError:
        return jsonify({'error': 'selectedWig must be an integer'}), 400
    userid = request.form['userid']
    print(userid)

    ref_path = './images/ref/target' + str(selected_wig + 1) + '.png'
    # Reject unknown selections up front, rather than failing with a 500
    # after the expensive preprocessing has already run.
    if not os.path.isfile(ref_path):
        return jsonify({'error': 'Unknown wig selection'}), 400

    source_path = './unprocessed/source.jpg'
    synthesis_output = "./images/res.jpg"
    label = 'male'
    entity = "celebrity"

    with _wear_wig_lock:
        os.makedirs(os.path.dirname(source_path), exist_ok=True)
        image.save(source_path)

        # Face alignment / preprocessing. Its output is read below from
        # ./images/src/source.png (presumably written by align_face.py).
        result1 = prepro_image("align_face.py")
        print(result1)

        # Load both images fully, then release the file handles.
        with Image.open('./images/src/source.png') as src_img, Image.open(ref_path) as ref_img:
            src = src_img.convert("RGB")
            ref = ref_img.convert("RGB")

        # NOTE(review): Predictor is rebuilt on every request; if it is
        # stateless it could be created once and reused.
        predictor = Predictor(entity=entity)
        predictor.create_interpolation(label, src_image=src, ref_image=ref)

        # Re-save the result as a plain RGB JPEG. Decode it fully and close
        # the input first, so the file is never written while still open.
        with Image.open(synthesis_output) as result_img:
            rgb_result = result_img.convert('RGB')
        rgb_result.save(synthesis_output, "JPEG")

        save_wigSynthesis(userid, ref_path, source_path, synthesis_output)
        print("Mysql 전송 완료")

        # Read the bytes while still holding the lock. Streaming the shared
        # file after release could send another request's result.
        with open(synthesis_output, 'rb') as fh:
            payload = fh.read()

    return send_file(BytesIO(payload), mimetype='image/jpeg')


# 이미지 파일들을 압축하여 클라이언트에게 보내는 엔드포인트
@app.route('/progress_status', methods=['GET'])
def progress_status():
    """Send the requesting user's stored image data as a zip archive.

    The query parameter ``userId`` is required. ``get_image_data`` returns a
    file path, which is served inline with mimetype ``application/zip``.
    """
    user_id = request.args.get('userId')
    if not user_id:
        return jsonify({'error': 'userId is required'}), 400

    try:
        zip_file_path = get_image_data(user_id)
        print(zip_file_path)
        download_name = os.path.basename(zip_file_path)
        print(download_name)
        # Send the archive to the client.
        return send_file(zip_file_path, as_attachment=False,
                         download_name=download_name, mimetype='application/zip')
    except Exception as e:
        # Request boundary: log the traceback server-side and return an error
        # status, instead of a 200 body the client would try to open as a zip.
        app.logger.exception("progress_status failed for userId=%s", user_id)
        return str(e), 500
    
@app.route('/beauty_supplies', methods=['GET'])
def start_crawl():
    """Search beauty supplies for a symptom and return the result text file.

    The query parameter ``symptom`` is required and is passed to
    ``search_supplies``. The endpoint then returns ``./supplies_result.txt``
    as a plain-text attachment (presumably written by ``search_supplies`` —
    confirm).
    """
    symptom = request.args.get('symptom')
    if not symptom:
        # Without this check, "..." + None below raised TypeError (HTTP 500).
        return jsonify({'error': 'symptom is required'}), 400
    print("받은 증상: " + symptom)
    search_supplies(symptom)
    # NOTE(review): every request shares ./supplies_result.txt. With the
    # threaded server, concurrent requests may receive each other's results.
    return send_file('./supplies_result.txt', as_attachment=True, mimetype='text/plain')
    
if __name__ == '__main__':
    # Bind address and port default to 192.168.35.4:5000. They can be
    # overridden without editing code, e.g. FLASK_RUN_HOST=0.0.0.0.
    app.run(host=os.environ.get('FLASK_RUN_HOST', '192.168.35.4'),
            port=int(os.environ.get('FLASK_RUN_PORT', '5000')),
            debug=False, use_reloader=False, threaded=True)

c:\jupyter\capston\bald\hairskinWig
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://192.168.35.4:5000
Press CTRL+C to quit


True
mingky0603
전처리 완료


  small_blurred = gaussian(cv2.resize(img, (W, H)), H//100, multichannel=True)


Mysql 전송 완료


192.168.35.4 - - [02/Dec/2023 02:44:13] "POST /wear_wig HTTP/1.1" 200 -
