In [1]:
import os
from collections import defaultdict

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torchvision import models, transforms
from torchvision import datasets
import cv2
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from matplotlib import patches

Load the trained food classifier (ResNet-18 with a Food-101 output head) and define its input preprocessing:

In [None]:
# Paths a reader may need to change to run this notebook on their own machine.
DATA_DIR = "food-101/food-101/images"  # one sub-folder per food class (ImageFolder layout)
CHECKPOINT_PATH = "best_model.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageFolder takes its classes from the sorted sub-directory names of the root.
# Reading those names directly gives the same list (in the same index order)
# without building a full ImageFolder, which would walk every image file.
class_names = sorted(entry.name for entry in os.scandir(DATA_DIR) if entry.is_dir())
num_classes = len(class_names)

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model = model.to(device)
model.eval()
# segment_and_classify looks up human-readable labels on model.fc.classes.
model.fc.classes = class_names

# Preprocess an image for the models: fixed 224x224 resize, tensor conversion,
# then per-channel normalization with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

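Load the pretrained semantic segmentation model (DeepLabV3 with a ResNet-101 backbone). Its masks use the segmentation model's own label set (Pascal VOC categories for torchvision's pretrained weights), not food classes, so they serve only as region proposals for the food classifier: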

In [None]:
# Pretrained DeepLabV3 (ResNet-101 backbone), moved to the same device as the
# classifier and switched to inference mode. Chaining .eval() into the
# assignment keeps the cell from dumping the very large module repr as output.
seg_model = models.segmentation.deeplabv3_resnet101(pretrained=True).to(device).eval()

Segment the image, then classify the bounding-box crop of each segment:

In [None]:
def _predict_label_map(image, seg_model, transform, device):
    """Run `seg_model` on a PIL image and return a per-pixel label map at the image's size.

    `transform` resizes the image before inference, so the predicted label map is
    upsampled (nearest neighbour, so labels stay integral) back to the original
    height x width. Mask coordinates are therefore valid indices into the
    full-resolution image. Upsampling the argmax map rather than the logits
    keeps memory small for large photos.
    """
    width, height = image.size
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = seg_model(batch)['out']  # batch of one: [0] selects the single image
    small_labels = torch.argmax(logits[0], dim=0)
    full_labels = nn.functional.interpolate(
        small_labels[None, None].float(), size=(height, width), mode='nearest'
    )[0, 0].long()
    return full_labels.cpu().numpy()


def _bounding_box(mask):
    """Return inclusive (min_x, max_x, min_y, max_y) of the True pixels in a 2-D boolean mask."""
    ys, xs = np.nonzero(mask)
    return int(xs.min()), int(xs.max()), int(ys.min()), int(ys.max())


def segment_and_classify(image_path, model, seg_model, transform, device,
                         class_names=None, ignore_labels=()):
    """Segment an image, then classify the bounding-box crop of every segment.

    Args:
        image_path: path to an image file readable by PIL.
        model: classifier whose output has one logit per class along dim 1.
        seg_model: segmentation model whose output dict has an 'out' entry.
        transform: preprocessing applied to the full image and to every crop.
        device: torch device that both models live on.
        class_names: optional sequence mapping classifier indices to names.
            Defaults to `model.fc.classes` if that attribute has been set;
            otherwise the raw predicted class index is reported.
        ignore_labels: segmentation labels to skip (e.g. ``(0,)`` to drop the
            background segment). By default every label is classified.

    Returns:
        (results, label_map):
            results: list of dicts with keys 'food_type', 'area_pixels' and
                'segment_coords'. The coords are (min_x, max_x, min_y, max_y),
                inclusive, in original-image pixels.
            label_map: 2-D numpy array of segmentation labels at the original
                image resolution.

    Note: the segmentation labels come from `seg_model`'s own label set, not
    from the food classifier; each segment is just a region proposal.
    """
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    original_image = np.array(image)

    label_map = _predict_label_map(image, seg_model, transform, device)

    if class_names is None:
        class_names = getattr(getattr(model, "fc", None), "classes", None)

    ignored = {int(label) for label in ignore_labels}
    results = []
    for label in np.unique(label_map):
        if int(label) in ignored:
            continue
        mask = label_map == label
        min_x, max_x, min_y, max_y = _bounding_box(mask)

        # Crop the segment's bounding box and classify it.
        crop = Image.fromarray(original_image[min_y:max_y + 1, min_x:max_x + 1])
        crop_tensor = transform(crop).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(crop_tensor)
        class_index = int(torch.argmax(logits, dim=1).item())
        food_type = class_names[class_index] if class_names is not None else class_index

        results.append({
            'food_type': food_type,
            'area_pixels': int(mask.sum()),  # pixel count of this segment
            'segment_coords': (min_x, max_x, min_y, max_y)
        })

    return results, label_map

def visualize_segmentation(image_path, results, output_predictions, overlay_mask=False):
    """Draw each segment's bounding box and predicted food label over the image.

    Args:
        image_path: path to the image that was segmented.
        results: list of dicts as returned by `segment_and_classify`, each with
            'food_type', 'area_pixels' and inclusive 'segment_coords'
            (min_x, max_x, min_y, max_y).
        output_predictions: 2-D segmentation label map. It is drawn only when
            `overlay_mask` is True.
        overlay_mask: if True, tint the image with the label map, stretched to
            cover the whole image.
    """
    # Convert to RGB so grayscale inputs are not rendered through a colormap;
    # the context manager closes the underlying file.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    width, height = image.size

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    if overlay_mask:
        # The extent matches the image's pixel grid, so the mask lines up even
        # if it was computed at a different resolution.
        ax.imshow(output_predictions, cmap='tab20', alpha=0.4, interpolation='nearest',
                  extent=(-0.5, width - 0.5, height - 0.5, -0.5))

    for result in results:
        min_x, max_x, min_y, max_y = result['segment_coords']
        # Coords are inclusive pixel indices and imshow centres pixel i on i,
        # so the box starts half a pixel earlier and spans (max - min + 1) pixels.
        rect = patches.Rectangle((min_x - 0.5, min_y - 0.5),
                                 max_x - min_x + 1, max_y - min_y + 1,
                                 linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.text(min_x, min_y, f"{result['food_type']}: {result['area_pixels']}px",
                color="red", fontsize=12)

    ax.set_title("Segments and predicted food classes")
    ax.axis("off")
    plt.show()


Run the full pipeline (segment, classify, visualize) on an example image:

In [None]:
def main(image_path):
    """Segment and classify one image, plot the result, and return the per-segment records.

    Uses the notebook-level `model`, `seg_model`, `transform` and `device`
    defined in the setup cells above.
    """
    segment_records, label_map = segment_and_classify(
        image_path, model, seg_model, transform, device
    )
    visualize_segmentation(image_path, segment_records, label_map)
    return segment_records

# Example usage. Point `image_path` at a real image; the guard keeps
# Restart & Run All from failing on the placeholder path.
image_path = "path_to_your_image.jpg"  # TODO: replace with a real image path
if os.path.isfile(image_path):
    results = main(image_path)
else:
    results = []
    print(f"Image not found: {image_path!r}. Set `image_path` to a real file to run the demo.")