In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
from flask import Flask, request, jsonify
from PIL import Image
import io
import requests
import zipfile

# Download and prepare the dataset.
data_dir = 'data/hymenoptera_data'
# Test for the extracted 'train' split rather than for data_dir itself, so a
# download that failed after the directory was created is retried next run.
if not os.path.isdir(os.path.join(data_dir, 'train')):
    # NOTE(review): the archive is expected to unpack into its own top-level
    # 'hymenoptera_data/' folder, so it is extracted into the parent of
    # data_dir. Extracting into data_dir itself would nest the splits one
    # level too deep for the ImageFolder paths built from data_dir.
    # Confirm against the archive layout.
    data_root = os.path.dirname(data_dir)
    os.makedirs(data_root, exist_ok=True)
    url = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'
    r = requests.get(url, allow_redirects=True, timeout=60)
    # Fail loudly instead of writing an HTML error page to disk as a "zip".
    r.raise_for_status()
    zip_path = os.path.join(data_root, 'hymenoptera_data.zip')
    with open(zip_path, 'wb') as zip_file:
        zip_file.write(r.content)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(data_root)

# Model and dataset configuration.
# Per-channel mean / std shared by the training and validation pipelines.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

data_transforms = {
    # Training: random crop + horizontal flip as augmentation.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    # Validation: deterministic resize + center crop.
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]),
}

# One ImageFolder dataset and one DataLoader per split, keyed by split name.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}
dataloaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x], batch_size=4, shuffle=True, num_workers=4)
    for x in ['train', 'val']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor of
# `weights=`; confirm against the installed version before switching.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Size the new classification head from the dataset instead of hard-coding 2,
# so the script keeps working if class folders are added or removed.
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler steps.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25, *,
                loaders=None, sizes=None, run_device=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase (gradients enabled, optimizer stepped per
    batch, scheduler stepped once at the end of the phase) followed by a
    'val' phase (gradients disabled). Per-phase loss and accuracy are
    printed. The weights of the epoch with the highest validation accuracy
    are kept and loaded back into ``model`` before returning.

    Args:
        model: Network to train; modified in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        loaders: Optional mapping with 'train' and 'val' iterables yielding
            ``(inputs, labels)`` batches. Defaults to the module-level
            ``dataloaders``.
        sizes: Optional mapping with the number of samples per phase, used to
            average loss and accuracy. Defaults to the module-level
            ``dataset_sizes``.
        run_device: Optional device batches are moved to. Defaults to the
            module-level ``device``.

    Returns:
        The same ``model`` object, holding the best validation weights.
    """
    # Fall back to the module-level objects only when no override is given,
    # so existing positional callers behave exactly as before.
    loaders = dataloaders if loaders is None else loaders
    sizes = dataset_sizes if sizes is None else sizes
    run_device = device if run_device is None else run_device

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # train(True) == train(), train(False) == eval().
            model.train(is_train)

            running_loss = 0.0
            # Plain Python int: stays valid even if the loader yields no
            # batches (the previous tensor-only `.double()` call did not).
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(run_device)
                labels = labels.to(run_device)

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Weight the batch loss by batch size so the epoch figure is
                # a per-sample mean even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects / sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Snapshot the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # '.4f' (precision), not '4f' (minimum width): match the per-epoch lines.
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

# Fine-tune the network; the returned model carries the weights that
# train_model selected.
model_ft = train_model(
    model=model_ft,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    num_epochs=25,
)

# Save the trained model's parameters.
torch.save(obj=model_ft.state_dict(), f='model.pth')

# Create the Flask API.
app = Flask(import_name=__name__)

def transform_image(image_bytes):
    """Decode raw image bytes into a normalized, batched input tensor.

    Args:
        image_bytes: Encoded image file contents (``bytes``).

    Returns:
        The transformed image with a leading batch dimension of size 1
        added by ``unsqueeze(0)``.

    Raises:
        OSError: Propagated from Pillow when the bytes cannot be opened or
            decoded as an image.
    """
    # Same deterministic preprocessing as the 'val' pipeline used in training.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = Image.open(io.BytesIO(image_bytes))
    # Force three channels: grayscale, palette or RGBA uploads would otherwise
    # produce 1- or 4-channel tensors that the 3-value Normalize cannot handle.
    image = image.convert('RGB')
    return transform(image).unsqueeze(0)

def get_prediction(image_bytes):
    """Classify one encoded image and return its predicted class name.

    Args:
        image_bytes: Encoded image file contents (``bytes``).

    Returns:
        The entry of ``class_names`` at the index of the highest model output.
    """
    tensor = transform_image(image_bytes=image_bytes).to(device)
    model_ft.eval()
    # Inference only: skip autograd bookkeeping so no graph is built or kept
    # alive for each request.
    with torch.no_grad():
        outputs = model_ft(tensor)
    _, y_hat = outputs.max(1)
    # Convert the one-element tensor to a plain int before indexing the list.
    return class_names[y_hat.item()]

@app.route('/predict', methods=['POST'])
def predict():
    """Handle ``POST /predict``: classify the image uploaded as form file 'file'.

    Returns:
        On success, a JSON body ``{'class_name': <name>}``.
        A JSON ``{'error': ...}`` body with status 400 when the 'file' part is
        missing or empty, or when the upload cannot be read as an image.
    """
    # The route only accepts POST, so no method check is needed here.
    file = request.files.get('file')
    if file is None:
        return jsonify({'error': "missing 'file' upload"}), 400

    img_bytes = file.read()
    if not img_bytes:
        return jsonify({'error': "uploaded 'file' is empty"}), 400

    try:
        class_name = get_prediction(image_bytes=img_bytes)
    except OSError:
        # Pillow reports unreadable or truncated images as OSError subclasses;
        # that is a client error, not a server failure.
        return jsonify({'error': 'uploaded file is not a readable image'}), 400
    return jsonify({'class_name': class_name})

if __name__ == '__main__':
    # Disable the auto-reloader: it re-executes this script in a child
    # process, which would repeat the module-level training run before the
    # server starts serving.
    # NOTE(review): debug=True also enables the interactive debugger; keep it
    # for local development only.
    app.run(debug=True, use_reloader=False)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 93.5MB/s]


Epoch 0/24
----------
train Loss: 0.5578 Acc: 0.7418
val Loss: 0.1996 Acc: 0.9150

Epoch 1/24
----------
train Loss: 0.3824 Acc: 0.8279
val Loss: 0.2633 Acc: 0.8889

Epoch 2/24
----------
train Loss: 0.4951 Acc: 0.8402
val Loss: 0.4846 Acc: 0.8235

Epoch 3/24
----------
train Loss: 0.6224 Acc: 0.7623
val Loss: 0.6186 Acc: 0.7843

Epoch 4/24
----------
train Loss: 0.6065 Acc: 0.8156
val Loss: 0.4387 Acc: 0.8366

Epoch 5/24
----------
train Loss: 0.7058 Acc: 0.7582
val Loss: 0.3520 Acc: 0.8758

Epoch 6/24
----------
train Loss: 0.6321 Acc: 0.7418
val Loss: 0.5584 Acc: 0.8039

Epoch 7/24
----------
train Loss: 0.4993 Acc: 0.8033
val Loss: 0.2959 Acc: 0.8889

Epoch 8/24
----------
train Loss: 0.3509 Acc: 0.8648
val Loss: 0.2361 Acc: 0.9020

Epoch 9/24
----------
train Loss: 0.3179 Acc: 0.8607
val Loss: 0.2384 Acc: 0.8954

Epoch 10/24
----------
train Loss: 0.2527 Acc: 0.8934
val Loss: 0.2541 Acc: 0.9020

Epoch 11/24
----------
train Loss: 0.3158 Acc: 0.8934
val Loss: 0.2156 Acc: 0.9150

Ep

 * Running on http://127.0.0.1:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m
INFO:werkzeug: * Restarting with stat
