In [2]:
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image

# Model output index -> human-readable label. `predict` pairs these with the
# softmax outputs by position, so the order must match the trained head.
labels = "drawings hentai neutral porn sexy".split()
# CNN classifier: ResNet-18 backbone followed by a small MLP head.
class Classifier(nn.Module):
    """ResNet-18 backbone whose 1000-way output feeds an MLP producing 5 logits.

    The backbone keeps torchvision's own ImageNet classification head, so its
    output has 1000 features; ``fc_layers`` maps 1000 -> 512 -> 128 -> 5.
    """

    def __init__(self, *, pretrained=True):
        """Build the network.

        Args:
            pretrained: when True (the default, identical to the previous
                behaviour) the backbone is initialised with torchvision's
                ``ResNet18_Weights.DEFAULT`` ImageNet weights, which may need
                to be downloaded. Pass False when the weights are about to be
                replaced by ``load_state_dict`` anyway, to skip that step.
        """
        super().__init__()
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        self.cnn_layers = resnet18(weights=weights)
        # NOTE(review): there is no activation between the first two Linear
        # layers (Dropout is linear), so they collapse to one affine map.
        # Inserting a ReLU would shift the nn.Sequential indices, and with
        # them state_dict keys such as "fc_layers.2.weight", breaking existing
        # checkpoints. The layout is therefore kept exactly as trained.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        )

    def forward(self, x):
        """Return class logits for the image batch ``x``.

        Presumably x is a normalized float tensor of shape (batch, 3, H, W),
        as produced by ``preprocess`` plus ``unsqueeze(0)``. TODO confirm.
        """
        # Extract features with the convolutional backbone, then classify.
        features = self.cnn_layers(x)
        return self.fc_layers(features)
# Input pre-processing applied to each PIL image before inference:
#   1. Resize(224): an int size matches the *shorter* edge to 224 px and keeps
#      the aspect ratio, so the result is not necessarily square.
#   2. ToTensor(): PIL image -> float tensor (C, H, W); 8-bit images are
#      scaled to [0, 1].
#   3. Normalize(): per-channel (x - mean) / std with the ImageNet statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Load the trained weights. This script never moves the model to a GPU (and
# predict feeds CPU tensors), so map the checkpoint onto the CPU. This also
# lets a checkpoint saved from a GPU session load on machines without CUDA.
# Change map_location (and move model/inputs with .to()) to run on a GPU.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model = Classifier()
model.load_state_dict(torch.load('classify_nsfw_v1.pth', map_location='cpu'))
model.eval()  # inference behaviour for Dropout/BatchNorm
def predict(inp):
    """Classify one PIL image and return a probability per label.

    Args:
        inp: a PIL image (the Gradio input component uses type='pil').

    Returns:
        dict mapping each entry of ``labels`` to its softmax probability as a
        Python float.
    """
    # Uploads may be RGBA, palette ('P') or grayscale. ToTensor would then
    # yield 4 or 1 channels (or palette indices), and the 3-channel Normalize
    # would fail or produce garbage, so force 3-channel RGB first.
    batch = preprocess(inp.convert('RGB')).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair labels with outputs by position instead of a hard-coded range(5).
    return {label: float(p) for label, p in zip(labels, probabilities.tolist())}

# Gradio UI: an image upload is passed to `predict` and the two most likely
# labels are shown. gr.inputs / gr.outputs were deprecated in Gradio 3 and
# removed in Gradio 4; the top-level gr.Image / gr.Label components replace them.
inputs = gr.Image(type='pil')
outputs = gr.Label(num_top_classes=2)
gr.Interface(fn=predict, inputs=inputs, outputs=outputs,
             examples=["./example/anime.jpg"]).launch()



Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


