In [None]:
import torch
import torchvision.transforms as transforms
from lime import lime_image
from skimage.segmentation import mark_boundaries

from PIL import Image
import io
import os
import ipywidgets as widgets
from IPython.display import display, HTML
import matplotlib.pyplot as plt
from torchvision import models
import torch.nn as nn
import pandas as pd


# Loading saved model
import torchvision.models as models
import torch.nn as nn
import torch
model_path="/content/food101_model_2 (1).pth"
def load_model(model_path):
    # Load ResNet50
    model = models.resnet50(weights=None)

    # Matching the training architecture
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(model.fc.in_features, 101)
    )

    # Loading saved weights
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    idx_to_class = checkpoint['idx_to_class']
    return model, idx_to_class




# Nutrition Function
def load_data():
    nutrition_df = pd.read_csv('/content/nutrition.csv')  # nutrition file path
    return None, nutrition_df
def calculate_nutrition(food_name, weight, nutrition_df):
    food_data = nutrition_df[nutrition_df['label'].str.contains(food_name, case=False)]
    if food_data.empty:
        return None
    exact_match = food_data[food_data['weight'] == float(weight)]
    if not exact_match.empty:
        return {
            'Calories': exact_match['calories'].iloc[0],
            'Protein (g)': exact_match['protein'].iloc[0],
            'Carbohydrates (g)': exact_match['carbohydrates'].iloc[0],
            'Fat (g)': exact_match['fats'].iloc[0],
            'Fiber (g)': exact_match['fiber'].iloc[0],
            'Sugars (g)': exact_match['sugars'].iloc[0],
            'Sodium (mg)': exact_match['sodium'].iloc[0]
        }
    return None
#Displaying nutrition table
def display_nutrition(nutrients):
    if nutrients is None:
        return HTML("<p>No nutrition data available for this food.</p>")

    html = """
    <style>
        .nutrition-table {
            width: 50%;
            border-collapse: collapse;
            background-color: #1c1c1c;
            color: white;
        }
        .nutrition-table th {
            background-color: #1c1c1c;
            color: white;
            padding: 8px;
            text-align: center;
            border: 1px solid #333;
        }
        .nutrition-table td {
            padding: 8px;
            border: 1px solid #333;
        }
        .nutrient-name {
            text-align: right;
            color: #3498db;
        }
        .nutrient-value {
            text-align: right;
        }
    </style>
    <h3 style='color: white;'>Nutritional Information</h3>
    <table class='nutrition-table'>
        <tr>
            <th>Nutrient</th>
            <th>Value</th>
        </tr>
    """
    important_nutrients = ['Calories', 'Protein (g)', 'Carbohydrates (g)', 'Fat (g)', 'Fiber (g)', 'Sugars (g)', 'Sodium (mg)']

    for nutrient in important_nutrients:
        if nutrient in nutrients:
            value = nutrients[nutrient]
            html += f"""
                <tr>
                    <td class='nutrient-name'>{nutrient}</td>
                    <td class='nutrient-value'>{value:.1f}</td>
                </tr>
            """
    html += "</table>"
    return HTML(html)

# Classifying Images

def classify_image(model, idx_to_class, image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
    predicted_idx = torch.argmax(outputs, dim=1).item()
    return idx_to_class[predicted_idx]

# Running final interactive tool

def run_final_demo():
  model_path="/content/food101_model_2 (1).pth"
  model, idx_to_class = load_model(model_path)
  _, nutrition_df = load_data()

    # Upload image
  image_upload = widgets.FileUpload(accept='image/*', multiple=False)
  weight_dropdown = widgets.Dropdown(description="Weight (g):", options=[], layout=widgets.Layout(width="40%"))
  analyze_button = widgets.Button(description='Analyze Image', button_style='primary')

  output_image = widgets.Output()
  output_prediction = widgets.Output()
  output_nutrition = widgets.Output()

  def on_image_upload(change):
        output_image.clear_output()
        output_prediction.clear_output()
        output_nutrition.clear_output()

        if image_upload.value:
            content = list(image_upload.value.values())[0]['content']
            image = Image.open(io.BytesIO(content)).convert("RGB")


            with output_image:
                plt.figure(figsize=(5,5))
                plt.imshow(image)
                plt.axis('off')
                plt.title("Uploaded Image")
                plt.show()

            # Classify image
            predicted_label = classify_image(model, idx_to_class, image)

            with output_prediction:
                display(HTML(f"<h3 style='color:white;'>Predicted Food: <span style='color:#2ecc71'>{predicted_label}</span></h3>"))

            # Update weight dropdown
            food_data = nutrition_df[nutrition_df['label'].str.contains(predicted_label, case=False)]
            available_weights = sorted({int(w) for w in food_data['weight'].values})
            weight_dropdown.options = [str(w) for w in available_weights]
            if available_weights:
                weight_dropdown.value = str(available_weights[0])
            weight_dropdown.predicted_label = predicted_label
            weight_dropdown.image = image

  image_upload.observe(on_image_upload, names='value')

  def on_analyze_button_clicked(b):
        predicted_label = weight_dropdown.predicted_label
        selected_weight = weight_dropdown.value
        nutrients = calculate_nutrition(predicted_label, selected_weight, nutrition_df)

        output_nutrition.clear_output()
        with output_nutrition:
            display(display_nutrition(nutrients))

  analyze_button.on_click(on_analyze_button_clicked)

  display(widgets.VBox([
        widgets.HTML("<h2 style='color:white;'>Upload a Food Image</h2>"),
        image_upload,
        output_image,
        output_prediction,
        widgets.HBox([weight_dropdown, analyze_button]),
        output_nutrition
    ]))

# Run it!
run_final_demo()


VBox(children=(HTML(value="<h2 style='color:white;'>Upload a Food Image</h2>"), FileUpload(value={}, accept='i…

UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7e7d5fe137e0>

In [None]:
!pip install lime scikit-image


Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/275.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━[0m [32m174.1/275.7 kB[0m [31m5.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... [?25l[?25hdone
  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=30e0e33d44c229b856ebbfa7b02dbf97790de0b1b754ae5283dfc11778b4808d
  Stored in directory: /root/.cache/pip/wheels/85/fa/a3/9c2d44c9f3cd77cf4e533b58900b2bf4487f2a17e8ec212a3d
Successfully built lime
Installing collected packages: lime
Successfully installed lime-0.2.0.1


In [6]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import os
import ipywidgets as widgets
from IPython.display import display, HTML
import matplotlib.pyplot as plt
from torchvision import models
import torch.nn as nn
import pandas as pd
import numpy as np

# LIME & Visualization
from lime import lime_image
from skimage.segmentation import mark_boundaries


# Load model
model_path = "/content/food101_model_2 (1).pth"

def load_model(model_path):
    model = models.resnet50(weights=None)
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(model.fc.in_features, 101)
    )
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    idx_to_class = checkpoint['idx_to_class']
    return model, idx_to_class

# Load nutrition data
def load_data():
    nutrition_df = pd.read_csv('/content/nutrition.csv')
    return None, nutrition_df

def calculate_nutrition(food_name, weight, nutrition_df):
    food_data = nutrition_df[nutrition_df['label'].str.contains(food_name, case=False)]
    if food_data.empty:
        return None
    exact_match = food_data[food_data['weight'] == float(weight)]
    if not exact_match.empty:
        return {
            'Calories': exact_match['calories'].iloc[0],
            'Protein (g)': exact_match['protein'].iloc[0],
            'Carbohydrates (g)': exact_match['carbohydrates'].iloc[0],
            'Fat (g)': exact_match['fats'].iloc[0],
            'Fiber (g)': exact_match['fiber'].iloc[0],
            'Sugars (g)': exact_match['sugars'].iloc[0],
            'Sodium (mg)': exact_match['sodium'].iloc[0]
        }
    return None

def display_nutrition(nutrients):
    if nutrients is None:
        return HTML("<p>No nutrition data available for this food.</p>")
    html = """
    <style>
        .nutrition-table {
            width: 50%;
            border-collapse: collapse;
            background-color: #1c1c1c;
            color: white;
        }
        .nutrition-table th {
            background-color: #1c1c1c;
            color: white;
            padding: 8px;
            text-align: center;
            border: 1px solid #333;
        }
        .nutrition-table td {
            padding: 8px;
            border: 1px solid #333;
        }
        .nutrient-name {
            text-align: right;
            color: #3498db;
        }
        .nutrient-value {
            text-align: right;
        }
    </style>
    <h3 style='color: white;'>Nutritional Information</h3>
    <table class='nutrition-table'>
        <tr>
            <th>Nutrient</th>
            <th>Value</th>
        </tr>
    """
    important_nutrients = ['Calories', 'Protein (g)', 'Carbohydrates (g)', 'Fat (g)', 'Fiber (g)', 'Sugars (g)', 'Sodium (mg)']
    for nutrient in important_nutrients:
        if nutrient in nutrients:
            value = nutrients[nutrient]
            html += f"""
                <tr>
                    <td class='nutrient-name'>{nutrient}</td>
                    <td class='nutrient-value'>{value:.1f}</td>
                </tr>
            """
    html += "</table>"
    return HTML(html)

# Classification function
def classify_image(model, idx_to_class, image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
    predicted_idx = torch.argmax(outputs, dim=1).item()
    return idx_to_class[predicted_idx]

# Final demo function
def run_final_demo():
    model, idx_to_class = load_model(model_path)
    _, nutrition_df = load_data()

    image_upload = widgets.FileUpload(accept='image/*', multiple=False)
    weight_dropdown = widgets.Dropdown(description="Weight (g):", options=[], layout=widgets.Layout(width="40%"))
    analyze_button = widgets.Button(description='Analyze Image', button_style='primary')

    output_image = widgets.Output()
    output_prediction = widgets.Output()
    output_nutrition = widgets.Output()
    lime_image_display = widgets.Output()

    def on_image_upload(change):
        output_image.clear_output()
        output_prediction.clear_output()
        output_nutrition.clear_output()
        lime_image_display.clear_output()

        if image_upload.value:
            content = list(image_upload.value.values())[0]['content']
            image = Image.open(io.BytesIO(content)).convert("RGB")

            with output_image:
                plt.figure(figsize=(5,5))
                plt.imshow(image)
                plt.axis('off')
                plt.title("Uploaded Image")
                plt.show()

            # Predict
            predicted_label = classify_image(model, idx_to_class, image)

            with output_prediction:
                display(HTML(f"<h3 style='color:white;'>Predicted Food: <span style='color:#2ecc71'>{predicted_label}</span></h3>"))

            # Update weight dropdown
            food_data = nutrition_df[nutrition_df['label'].str.contains(predicted_label, case=False)]
            available_weights = sorted({int(w) for w in food_data['weight'].values})
            weight_dropdown.options = [str(w) for w in available_weights]
            if available_weights:
                weight_dropdown.value = str(available_weights[0])
            weight_dropdown.predicted_label = predicted_label
            weight_dropdown.image = image

            # LIME explanation
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ])

            def batch_predict(images):
                model.eval()
                batch = torch.stack([transform(Image.fromarray(img)).unsqueeze(0).squeeze(0) for img in images], dim=0)
                with torch.no_grad():
                    logits = model(batch)
                return torch.nn.functional.softmax(logits, dim=1).numpy()

            explainer = lime_image.LimeImageExplainer()
            image_np = np.array(image)

            explanation = explainer.explain_instance(
                image_np,
                batch_predict,
                top_labels=1,
                hide_color=0,
                num_samples=1000
            )

            with lime_image_display:
                temp, mask = explanation.get_image_and_mask(
                    explanation.top_labels[0],
                    positive_only=True,
                    num_features=5,
                    hide_rest=False
                )
                plt.figure(figsize=(5,5))
                plt.imshow(mark_boundaries(temp / 255.0, mask))
                plt.title("LIME Explanation")
                plt.axis('off')
                plt.show()

    image_upload.observe(on_image_upload, names='value')

    def on_analyze_button_clicked(b):
        predicted_label = weight_dropdown.predicted_label
        selected_weight = weight_dropdown.value
        nutrients = calculate_nutrition(predicted_label, selected_weight, nutrition_df)

        output_nutrition.clear_output()
        with output_nutrition:
            display(display_nutrition(nutrients))

    analyze_button.on_click(on_analyze_button_clicked)

    display(widgets.VBox([
        widgets.HTML("<h2 style='color:white;'>Upload a Food Image</h2>"),
        image_upload,
        output_image,
        output_prediction,
        lime_image_display,
        widgets.HBox([weight_dropdown, analyze_button]),
        output_nutrition
    ]))

# Run
run_final_demo()


VBox(children=(HTML(value="<h2 style='color:white;'>Upload a Food Image</h2>"), FileUpload(value={}, accept='i…

  0%|          | 0/1000 [00:00<?, ?it/s]