In [1]:
import os
import io
import json
import torch
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS
from model import CombinedModel  # 导入 ResNet-ViT 模型

app = Flask(__name__)
CORS(app)  # enable cross-origin requests (resolves cross-domain issues for the front end)

# Locations of the trained weights and the class-index lookup table.
# Raw strings are used because these Windows paths contain backslash sequences such as
# "\d" and "\p", which are invalid escape sequences in ordinary string literals
# (DeprecationWarning; SyntaxWarning from Python 3.12). The raw strings evaluate to
# exactly the same path text as before. Either location can be overridden with an
# environment variable so the service can run on another machine without editing code.
weights_path = os.environ.get(
    "WEIGHTS_PATH",
    r"D:\桌面\deep-learning-for-image-processing-master\deep-learning-for-image-processing-master\deploying_service\deploying_pytorch\pytorch_flask_service/trained_model_weights.pth",
)
class_json_path = os.environ.get(
    "CLASS_JSON_PATH",
    r"D:\桌面\deep-learning-for-image-processing-master\deep-learning-for-image-processing-master\deploying_service\deploying_pytorch\pytorch_flask_service/class_indices.json",
)
assert os.path.exists(weights_path), "weights path does not exist..."
assert os.path.exists(class_json_path), "class json path does not exist..."

# Select device: first CUDA GPU if available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Build the ResNet-ViT CombinedModel with 9 output classes and load the trained weights.
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code;
# only point weights_path at a trusted file.
model = CombinedModel(num_classes=9).to(device)
model.load_state_dict(torch.load(weights_path, map_location=device))

# Switch to evaluation mode (affects layers such as dropout / batch norm, if present).
model.eval()

# Class-index lookup table: get_prediction looks names up by str(index).
# Opened in binary mode so json.load auto-detects the UTF encoding; the `with` block
# guarantees the file handle is closed once the table is loaded.
with open(class_json_path, "rb") as json_file:
    class_indict = json.load(json_file)
def transform_image(image_bytes):
    """Decode raw image bytes into a normalized single-image batch on `device`.

    The image is resized so its shorter side is 255 px, center-cropped to 224x224,
    converted to a tensor and normalized per channel, then given a leading batch
    dimension of size 1.

    Args:
        image_bytes: raw bytes of an encoded image file.

    Returns:
        A tensor of shape (1, 3, 224, 224) on the module-level `device`.

    Raises:
        ValueError: if the decoded image is not in "RGB" mode.
    """
    preprocess = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean / std (presumably matches training preprocessing).
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img = Image.open(io.BytesIO(image_bytes))
    if img.mode == "RGB":
        batch = preprocess(img).unsqueeze(0)
        return batch.to(device)
    raise ValueError("input file does not RGB image...")
def get_prediction(image_bytes):
    """Classify an image and return its class probabilities formatted for the client.

    Args:
        image_bytes: raw bytes of an uploaded image file.

    Returns:
        A dict with a single key "result". On success, its value is a list of
        strings, one per class, sorted by descending probability and formatted as
        "class:<name> probability:<p>". If any step fails (undecodable or non-RGB
        image, class-index mismatch, ...), the list instead holds a single string
        with the exception message, so the HTTP caller always receives JSON.
    """
    try:
        tensor = transform_image(image_bytes=image_bytes)
        # Call the module (not .forward()) so any registered hooks run, and disable
        # autograd explicitly so this stays inference-only even when called outside
        # the no_grad-decorated /predict route.
        with torch.no_grad():
            logits = model(tensor)
        outputs = torch.softmax(logits.squeeze(), dim=0)
        prediction = outputs.detach().cpu().numpy()
        template = "class:{:<15} probability:{:.3f}"
        # Pair each class name (looked up by stringified index) with its probability.
        index_pre = [(class_indict[str(index)], float(p)) for index, p in enumerate(prediction)]
        # Sort by probability, highest first.
        index_pre.sort(key=lambda x: x[1], reverse=True)
        text = [template.format(k, v) for k, v in index_pre]
        return_info = {"result": text}
    except Exception as e:
        # Deliberate best-effort: report the error text to the client instead of a 500.
        return_info = {"result": [str(e)]}
    return return_info

cuda:0




In [2]:
@app.route("/predict", methods=["POST"])
@torch.no_grad()
def predict():
    """Handle an uploaded image and return the ranked class probabilities as JSON.

    Expects a multipart form upload with the image under the "file" field; the
    response body is the dict produced by get_prediction.
    """
    uploaded = request.files["file"]
    return jsonify(get_prediction(image_bytes=uploaded.read()))
@app.route("/up1", methods=["GET"])
def up1():
    """Serve the up1.html template."""
    page = "up1.html"
    return render_template(page)
@app.route("/index", methods=["GET"])
def index():
    """Serve the main page (index.html) at /index."""
    template_name = "index.html"
    return render_template(template_name)

In [None]:
@app.route("/", methods=["GET", "POST"])
def root():
    """Serve the main page (index.html) at the site root, for both GET and POST."""
    landing_page = "index.html"
    return render_template(landing_page)


if __name__ == '__main__':
    # Bind on all network interfaces (0.0.0.0) so other machines on the LAN can reach
    # the service on port 5000. This is Flask's built-in development server; use a
    # production WSGI server for real deployments.
    server_options = {"host": "0.0.0.0", "port": 5000}
    app.run(**server_options)


 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on all addresses.
 * Running on http://192.168.0.114:5000/ (Press CTRL+C to quit)
192.168.0.114 - - [09/May/2023 16:21:52] "GET / HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:52] "GET /style.css HTTP/1.1" 404 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/霍氏粉褶菌003.jpg HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/双孢蘑菇001.jpg HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/乳牛肝菌006.jpg HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/松乳菌007.jpg HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/褪色红菇008.jpg HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/掷丝膜菌009.jpg HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/浅黄褐湿伞005.jpg HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/毒蝇伞002.jpg HTTP/1.1" 200 -
192.168.0.114 - - [09/May/2023 16:21:53] "GET /static/img/丽柄牛肝菌004.jpg HTTP/1.1" 2