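# Machine Learning: Serving an MNIST Classifier with Flask

This notebook loads a pretrained feed-forward MNIST digit classifier (`models/mnist_ffn.pth`) and exposes it through a Flask `/predict` endpoint: POST an image in the `file` form field and the response is JSON containing the predicted class.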

In [1]:
import io
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from PIL import Image


# device config: run on the GPU when PyTorch can see one, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# load model
class NeuralNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    NOTE(review): ``model`` is restored below with ``torch.load`` as a whole
    pickled module, which presumably relies on this class name and on the
    attribute names ``linear1`` / ``relu`` / ``linear2`` matching what was
    saved — keep them unchanged.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map a batch of feature vectors ``x`` to raw (unnormalised) class logits."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)
    
# Restore the pretrained network. map_location remaps any CUDA tensors stored in
# the checkpoint onto `device`, so a model saved on a GPU still loads on a
# CPU-only host (without it, torch.load raises there).
# NOTE: this unpickles a whole nn.Module (NeuralNet must be defined above), so
# only load checkpoints from a trusted source. PyTorch 2.6+ defaults torch.load
# to weights_only=True, which rejects full-module pickles; for this trusted file
# pass weights_only=False if you hit that error.
model = torch.load('models/mnist_ffn.pth', map_location=device)
model.eval()  # inference mode (disables training-only behaviour such as dropout)
model.to(device)

# image -> tensor
# Preprocessing pipeline, built once at import time rather than on every request.
# NOTE(review): the Normalize constants are presumably the standard MNIST pixel
# mean/std — they must match whatever the model was trained with.
_MNIST_TRANSFORM = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


def transform_image(img_bytes):
    """Decode raw image bytes into a normalised single-image input batch.

    Args:
        img_bytes: encoded image file contents (any format PIL can open).

    Returns:
        A float tensor of shape (1, 1, 28, 28): a batch of one grayscale
        28x28 image, normalised with the constants above.

    Raises:
        Whatever ``PIL.Image.open`` raises for unreadable data
        (e.g. ``PIL.UnidentifiedImageError``).
    """
    # Context manager closes the image once the tensor has been produced.
    with Image.open(io.BytesIO(img_bytes)) as img:
        return _MNIST_TRANSFORM(img).unsqueeze(0)

# predict
def predict_image(img_tensor):
    """Return the predicted class index for each image in ``img_tensor``.

    Args:
        img_tensor: tensor holding one or more 28x28 images, e.g. the
            (1, 1, 28, 28) output of ``transform_image``. It is flattened to
            (batch, 784) before being passed to the global ``model``.

    Returns:
        A 1-D integer tensor on ``device`` with one predicted class index
        per image.
    """
    images = img_tensor.reshape(-1, 28 * 28).to(device)
    # Pure inference: skip autograd bookkeeping (no graph, less memory per request).
    with torch.no_grad():
        outputs = model(images)
    # Index of the largest logit along the class dimension
    # (same as the index result of torch.max(outputs, 1)).
    return outputs.argmax(dim=1)
    

# Flask App

In [2]:
from flask import Flask, request, jsonify
from re import match

app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image.

    Expects a POST whose uploaded files include the image under the ``file`` field.

    Returns:
        200 with ``{"prediction": <int>}`` on success,
        400 with ``{"error": "no file"}`` when no file (or an empty filename) was sent,
        500 with ``{"error": <message>}`` when decoding or inference fails.
    """
    file = request.files.get('file')

    if not file or not file.filename:
        return jsonify({'error': 'no file'}), 400

    try:
        img_bytes = file.read()
        img_tensor = transform_image(img_bytes)
        prediction = predict_image(img_tensor)
        return jsonify({'prediction': prediction.item()})
    except Exception as ex:
        # Must be a dict ({'error': ...}), not a set literal ({'error', ex}):
        # sets and exception objects are not JSON-serialisable, so the old
        # form crashed inside the error handler itself.
        app.logger.exception('prediction failed')
        return jsonify({'error': str(ex)}), 500
    
# Start Flask's built-in development server; this blocks the kernel until stopped.
# The guard is still true inside the notebook (the app runs as "__main__"), but it
# prevents the dev server from starting if this code is imported as a module,
# e.g. after exporting it to be served by a production WSGI server.
if __name__ == '__main__':
    app.run()


 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [11/Feb/2021 00:18:55] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [11/Feb/2021 00:22:24] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [11/Feb/2021 00:22:45] "POST /predict HTTP/1.1" 200 -
