In [11]:
from flask import Flask, request, jsonify
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms

# Flask application object; import_name tells Flask where this module lives.
app = Flask(import_name=__name__)

In [12]:
@app.route('/')
def form():
    """Serve the upload page whose form POSTs an image file to ``/predict``.

    The file input is named ``image`` because that is the field ``predict``
    reads from ``request.files``. ``required`` stops the browser from
    submitting the form without a file, and ``accept`` narrows the file
    picker to images.

    Returns:
        str: a static HTML document.
    """
    return """
        <!DOCTYPE html>
        <html>
            <body>
                <h1>Image Classification</h1>
                <br>
                <br>
                <p>Insert your image file and then see the Result</p>
                <form action="/predict" method="post" enctype="multipart/form-data">
                    <input type="file" name="image" accept="image/*" required class="btn btn-block"/>
                    <br>
                    <br>
                    <button type="submit" class="btn btn-primary btn-block btn-large">Predict</button>
                </form>
            </body>
        </html>
    """

In [13]:
# Define the CNN architecture; it must match the one used to train CNNModel.pth
class CNN_Network(nn.Module):
    """LeNet-style CNN: two conv -> ReLU -> max-pool stages, then three linear layers.

    The flatten size ``16 * 5 * 5`` only works for 3x32x32 inputs:
    32 --conv5--> 28 --pool--> 14 --conv5--> 10 --pool--> 5.
    Attribute names are the keys of the saved ``state_dict``, so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor; the single pooling layer is shared by both stages.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head producing 10 scores.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Map an (N, 3, 32, 32) batch to (N, 10) unnormalized class scores."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)
        for fc in (self.fc1, self.fc2):
            x = torch.relu(fc(x))
        return self.fc3(x)

# Instantiate the network and load the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU session load on a
# CPU-only machine; nothing in this notebook moves the model or inputs to a GPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# consider weights_only=True if the installed torch supports it (>= 1.13).
model = CNN_Network()
model.load_state_dict(torch.load('CNNModel.pth', map_location='cpu'))
# Put the network in inference mode for serving.
model.eval()

# Class names indexed by the model's output position (10 entries, one per fc3 output).
# NOTE(review): this is the CIFAR-10 label order; confirm it matches the label order
# used when CNNModel.pth was trained.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [14]:
# Preprocessing pipeline, built once at import time instead of on every request.
# Resize to 32x32 because CNN_Network's fc1 expects a 16 x 5 x 5 feature map,
# which only results from 32x32 inputs.
# NOTE(review): these mean/std values are the ImageNet statistics; confirm they
# match the normalization used when CNNModel.pth was trained.
_PREPROCESS = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def preprocess_image(img):
    """Convert a PIL image into a normalized (3, 32, 32) float tensor.

    Args:
        img: a ``PIL.Image.Image`` in any mode (RGB, RGBA, L, P, ...).

    Returns:
        torch.Tensor of shape (3, 32, 32), ready to be batched with ``unsqueeze(0)``.
    """
    # Uploads may be grayscale, palette or RGBA images. The 3-value Normalize
    # and the 3-channel conv1 both need exactly 3 channels, so force RGB.
    rgb_image = img.convert('RGB')
    return _PREPROCESS(rgb_image)

In [15]:
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image.

    Expects a multipart/form-data POST with the file in the ``image`` field
    (as submitted by the form served at ``/``).

    Returns:
        200 JSON ``{"label": str, "probs": list[float]}``, where ``probs`` holds the
        softmax probability for each entry of ``classes``; or
        400 JSON ``{"error": str}`` when no file was sent or it is not a readable image.
    """
    upload = request.files.get('image')
    if upload is None or upload.filename == '':
        # Browsers send an empty part with filename '' when no file was chosen.
        return jsonify({'error': "no file provided in the 'image' field"}), 400

    try:
        pil_image = Image.open(upload.stream)
        # Image.open only parses the header; load() forces a full decode so that
        # corrupt or truncated files are rejected here instead of failing later.
        pil_image.load()
    except (OSError, Image.DecompressionBombError):
        # PIL.UnidentifiedImageError (not an image) is a subclass of OSError.
        return jsonify({'error': 'uploaded file is not a readable image'}), 400

    batch = preprocess_image(pil_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1).squeeze(0)

    best_index = int(probs.argmax())
    return jsonify({
        'label': classes[best_index],
        'probs': probs.tolist(),
    })

In [16]:
if __name__ == "__main__":
    # Blocks until interrupted (CTRL+C). The reloader is presumably disabled
    # because it re-executes the process, which does not suit a notebook kernel.
    app.run(port=5000, use_reloader=False)

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [23/Mar/2023 20:08:37] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [23/Mar/2023 20:08:43] "POST /predict HTTP/1.1" 200 -
