In [3]:
from fastapi import FastAPI, UploadFile, File
import torch
from pydantic import BaseModel
from torchvision import models
import torch.nn as nn
import torch.optim as optim

# Define MLP model
class MLP(nn.Module):
    """Feed-forward binary classifier: 4 inputs -> 10 hidden units -> 1 sigmoid output."""

    def __init__(self):
        super().__init__()
        # The attribute name `layers` and the module order determine the
        # state-dict keys ("layers.0.*", "layers.2.*"), so both are kept fixed.
        hidden = nn.Linear(4, 10)
        head = nn.Linear(10, 1)
        self.layers = nn.Sequential(hidden, nn.ReLU(), head, nn.Sigmoid())

    def forward(self, x):
        """Pass `x` through the layer stack and return the sigmoid output."""
        return self.layers(x)

# Load the models
# NOTE: torch.load unpickles the checkpoint file, so only load checkpoints
# from a trusted source.
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a host without CUDA.
mlp_model = MLP()
mlp_model.load_state_dict(torch.load('mlp_model.pth', map_location='cpu'))
mlp_model.eval()

# weights=None: the checkpoint loaded below (strict by default) supplies every
# parameter and buffer, so fetching the pretrained ImageNet weights first would
# only add a download that is immediately overwritten.
image_model = models.resnet18(weights=None)
# Replace the final fully connected layer so the head has 2 outputs, matching
# the checkpoint's classifier.
num_ftrs = image_model.fc.in_features
image_model.fc = nn.Linear(num_ftrs, 2)
image_model.load_state_dict(torch.load('image_model.pth', map_location='cpu'))

# Switch to evaluation mode; as the last expression of the cell this also
# displays the model summary.
image_model.eval()




ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# Application object; the @app.post route decorators below register the
# prediction endpoints on it.
app = FastAPI()

# Request body schema for POST /predict/mlp: four float features that the
# endpoint passes to the MLP in the order feature1..feature4.
class MLPInput(BaseModel):
    feature1: float
    feature2: float
    feature3: float
    feature4: float

@app.post("/predict/mlp")
def predict_mlp(input: MLPInput):
    """Score one sample with the MLP.

    Args:
        input: Request body carrying the four float features.

    Returns:
        dict: ``{"prediction": <float>}`` holding the model's single output
        value for this sample.
    """
    # Batch of one sample, features in their declared order -> shape (1, 4).
    data = torch.tensor(
        [[input.feature1, input.feature2, input.feature3, input.feature4]],
        dtype=torch.float32,
    )
    # Inference only: do not record an autograd graph for each request.
    with torch.no_grad():
        prediction = mlp_model(data)
    return {"prediction": prediction.item()}

@app.post("/predict/image")
async def predict_image(file: UploadFile = File(...)):
    """Run the image classifier and return the predicted class.

    Args:
        file: Uploaded image file (multipart form field ``file``).

    Returns:
        dict: ``{"prediction": <int class index>, "probabilities": [<float>, ...]}``
        where ``probabilities`` is the softmax over the model's class scores.
    """
    image = await file.read()
    # NOTE(review): the uploaded bytes are not used yet. Until preprocessing is
    # implemented the model is fed a random tensor, so the response does not
    # depend on the uploaded image.
    # Preprocess the image here
    # image_tensor = preprocess_image(image)
    image_tensor = torch.randn(1, 3, 224, 224)  # Dummy tensor for example
    # Inference only: do not record an autograd graph for each request.
    with torch.no_grad():
        logits = image_model(image_tensor)
    # The classifier emits one score per class, so the output holds more than
    # one element and Tensor.item() would raise. Reduce to a class index.
    probabilities = torch.softmax(logits, dim=1)[0]
    predicted_class = int(torch.argmax(probabilities).item())
    return {"prediction": predicted_class, "probabilities": probabilities.tolist()}

if __name__ == "__main__":
    import asyncio

    import uvicorn

    # NOTE: host 0.0.0.0 listens on every network interface.
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None:
        # Plain script: no event loop yet, so let uvicorn create one and block
        # until the server shuts down.
        uvicorn.run(app, host="0.0.0.0", port=8000)
    else:
        # An event loop is already running (e.g. inside a notebook kernel).
        # uvicorn.run() would call asyncio.run() and raise RuntimeError there,
        # so schedule the server on the existing loop instead. The task is kept
        # in a variable so it is not garbage-collected while serving.
        server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=8000))
        server_task = running_loop.create_task(server.serve())