<img src="static/images/datachain-logo.png" alt="Dataset" style="width: 200px;"/>

# 🚀 Run Inference Jobs using DataChain (scalable batch scoring)

In this tutorial, you'll learn how to run batch inference jobs on a large dataset using **[DataChain](https://github.com/iterative/datachain)**. Batch inference is useful when you have a pre-trained model and want to score an entire dataset offline, without real-time serving. DataChain applies the model to every record of a saved dataset and stores the predictions as a new dataset that you can query, update, and analyze.
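To make the batch pattern concrete, here is a minimal, framework-free sketch: records are read once, scored in fixed-size batches by a stand-in prediction function, and collected in input order so they can be persisted. The `batched` and `score_in_batches` helpers and the toy scorer are hypothetical illustrations, not DataChain APIs; in this tutorial DataChain takes care of iteration, parallelism, and storage.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def batched(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most `size` items (hypothetical helper)."""
    if size < 1:
        raise ValueError("size must be >= 1")
    it = iter(items)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk


def score_in_batches(
    items: Iterable[T],
    predict_batch: Callable[[List[T]], List[R]],
    size: int = 32,
) -> List[R]:
    """One offline scoring pass: one prediction per input, in input order."""
    results: List[R] = []
    for chunk in batched(items, size):
        results.extend(predict_batch(chunk))
    return results


def toy_predict(chunk: List[int]) -> List[str]:
    # Stand-in "model": labels each number as even or odd.
    return ["even" if x % 2 == 0 else "odd" for x in chunk]


print(score_in_batches(range(5), toy_predict, size=2))
```

Batching only changes how work is grouped; the output still has exactly one prediction per input record.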

## 📋 Agenda

1. Load Pre-trained Model
2. Define Inference Classes
3. Run Inference
4. Save Predictions
5. Apply Delta Updates
6. Analyze and Visualize Predictions

## 🛠 Prerequisites

Before you begin, ensure you have:

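- **[DataChain](https://github.com/iterative/datachain)** installed in your environment.
- PyTorch and the other libraries used in this notebook (`matplotlib`, `numpy`, `pandas`) installed.
- The `src/train.py` module, which provides the `CNN` model class and the image `transform`.
- Access to the pre-trained model checkpoint `model.pth`.
- Access to the DataChain datasets `fashion-test` (for inference) and `fashion-product-images` (for the delta-update example).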
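Before running the notebook, a quick sanity check can save a failed run halfway through. The sketch below assumes the module names and file paths implied by this notebook's imports (`datachain`, `torch`, `matplotlib`, `pandas`, `model.pth`, `src/train.py`); adjust them to your setup. DataChain datasets such as `fashion-test` are not files on disk, so they are not checked here.

```python
import importlib.util
from pathlib import Path
from typing import Dict, Iterable


def check_prerequisites(
    modules: Iterable[str] = ("datachain", "torch", "matplotlib", "pandas"),
    files: Iterable[str] = ("model.pth", "src/train.py"),
) -> Dict[str, bool]:
    """Map each prerequisite to True if it is importable / present on disk."""
    report: Dict[str, bool] = {}
    for name in modules:
        # find_spec returns None when a top-level module is not installed
        report[name] = importlib.util.find_spec(name) is not None
    for path in files:
        report[path] = Path(path).is_file()
    return report


for item, ok in check_prerequisites().items():
    print(f"{'ok' if ok else 'MISSING':>7}  {item}")
```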

In [None]:
%load_ext autoreload
%autoreload 2

import torch

from datachain import C, DataChain, DataModel
from datachain.lib.image import convert_image
from torch import optim

from src.train import CNN, transform

# 📥 Load Model (pre-trained)

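- Load the checkpoint `model.pth`, which contains the class names, the model weights, and the optimizer state.
- Create an instance of the `CNN` model sized to the number of classes.
- Create an `optim.SGD` optimizer and restore both state dictionaries from the checkpoint.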
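The loading cell assumes `model.pth` is a dictionary with three keys: `classes`, `model_state_dict`, and `optimizer_state_dict`. A small validator, hypothetical and not part of `src/train.py`, makes that contract explicit before any weights are loaded. For inference alone only `classes` and `model_state_dict` matter; the optimizer state is needed only to resume training.

```python
from typing import List, Tuple

REQUIRED_KEYS = ("classes", "model_state_dict")
RESUME_KEYS = ("optimizer_state_dict",)  # only needed to continue training


def inspect_checkpoint(checkpoint: dict) -> Tuple[List[str], int, bool]:
    """Validate the assumed checkpoint layout.

    Returns (classes, num_classes, can_resume_training).
    """
    missing = [key for key in REQUIRED_KEYS if key not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    classes = list(checkpoint["classes"])
    can_resume = all(key in checkpoint for key in RESUME_KEYS)
    return classes, len(classes), can_resume


# Toy checkpoint with placeholder values instead of real tensors
toy = {"classes": ["class_a", "class_b", "class_c"], "model_state_dict": {}}
print(inspect_checkpoint(toy))
```

With PyTorch installed you would pass it the loaded dictionary, for example `inspect_checkpoint(torch.load("model.pth", map_location="cpu"))`.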


In [None]:
#### Load Pre-trained Model  #####

checkpoint = torch.load("model.pth", map_location="cpu")  # works even if trained on a GPU
CLASSES = checkpoint["classes"]
NUM_CLASSES = len(CLASSES)

model = CNN(NUM_CLASSES)
optimizer = optim.SGD(model.parameters(), lr=0.001)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

print(CLASSES)
print(NUM_CLASSES)
print(model, optimizer)

With the weights restored, the model is ready for inference without retraining. The optimizer state is only needed if you want to continue training; inference uses the model weights alone.

# ✅ Run Inference & Save Predictions

- Load the `fashion-test` dataset.
- Enable parallel processing with `settings(parallel=-1)`.
- Use `map()` to call `InferenceMapper.predict` on every `file` in the dataset.
- `predict` reads the image, applies the training `transform`, and returns the predicted class index, its probability, and its label as a `Predictions` object, stored in the `predictions` signal.
- Use `save()` to store the dataset together with the predictions as `fashion-predictions`.
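`map(predictions=..., output=Predictions)` stores each returned `Predictions` object as a nested signal named `predictions`, and later cells address its fields with dotted names such as `C("predictions.pred_proba")`. The toy flattener below is a hypothetical helper, not DataChain code, with made-up values; it only shows how a nested record maps onto those dotted column names.

```python
def flatten(record: dict, prefix: str = "") -> dict:
    """Flatten nested dicts into dotted keys, e.g. {"a": {"b": 1}} -> {"a.b": 1}."""
    flat = {}
    for key, value in record.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, name))  # recurse into nested fields
        else:
            flat[name] = value
    return flat


row = {
    "usage": "class_a",
    "predictions": {"pred_class": 0, "pred_proba": 0.91, "pred_label": "class_a"},
}
print(sorted(flatten(row)))
# ['predictions.pred_class', 'predictions.pred_label', 'predictions.pred_proba', 'usage']
```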

## Define the inference classes

In [None]:
#### Inference #####

class Predictions(DataModel):
    pred_class: int
    pred_proba: float
    pred_label: str

class InferenceMapper:
    """Wraps a trained model and turns one image file into a Predictions record."""

    def __init__(self, model, classes):
        self.model = model
        self.classes = classes
        self.model.eval()  # put layers such as dropout/batch norm in inference mode

    def predict(self, file) -> Predictions:
        # For image datasets, `file.read()` returns the image itself
        img = file.read()
        # Apply the training transform and add a batch dimension
        img = convert_image(img, transform=transform).unsqueeze(0)

        with torch.no_grad():
            outputs = self.model(img)
            probs = torch.nn.functional.softmax(outputs, dim=1)
        confidence, predicted_class = torch.max(probs, 1)
        idx = predicted_class.item()

        return Predictions(
            pred_class=idx,
            pred_proba=confidence.item(),
            pred_label=self.classes[idx],
        )


inference_eng = InferenceMapper(model, CLASSES)

- Define a `Predictions` class that inherits from `DataModel` to store the predicted class, probability, and label.
- Define an `InferenceMapper` class to perform the inference logic.
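The last part of `predict()` turns raw model outputs (logits) into a class index, a confidence, and a label: softmax normalizes the scores into probabilities, and `torch.max(probs, 1)` picks the largest probability and its index. Here is a dependency-free sketch of the same arithmetic for a single image, with made-up logits and placeholder class names:

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top1(logits: Sequence[float], classes: Sequence[str]) -> Tuple[int, float, str]:
    """Return (pred_class, pred_proba, pred_label), mirroring Predictions."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return idx, probs[idx], classes[idx]


pred_class, pred_proba, pred_label = top1([2.0, 0.5, -1.0], ["class_a", "class_b", "class_c"])
print(pred_class, pred_label, round(pred_proba, 3))
```

Subtracting the maximum logit does not change the result but keeps `math.exp` from overflowing on large scores.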

## Run Inference

In [None]:
(
    DataChain.from_dataset("fashion-test")
    .settings(parallel=-1)    
    .map(predictions=lambda file: inference_eng.predict(file), output=Predictions)
    .save("fashion-predictions")
    .show()
)
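
`.settings(parallel=-1)` asks DataChain to run the mapper in parallel rather than one row at a time; see the DataChain `settings()` documentation for how `-1` and other values are interpreted in your version. The sketch below is only an analogy built on the standard library's `concurrent.futures`: it fans a per-item function out across workers and returns results in input order, the property a parallel `map()` must preserve. `fake_predict` is a hypothetical stand-in for `inference_eng.predict`.

```python
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List, Optional


def parallel_map(fn: Callable, items: Iterable, workers: Optional[int] = None) -> List:
    """Apply fn to every item concurrently; results keep the input order."""
    workers = workers or os.cpu_count() or 1
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(fn, items))


def fake_predict(path: str) -> dict:
    # Hypothetical stand-in for inference_eng.predict(file)
    return {"path": path, "pred_label": "class_a" if len(path) % 2 else "class_b"}


results = parallel_map(fake_predict, ["a.jpg", "bb.jpg", "ccc.jpg"], workers=2)
print([r["pred_label"] for r in results])
```

A thread pool keeps the sketch self-contained; DataChain manages its own workers, so you would not write this yourself.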

## Explore predictions

In [None]:
DataChain.from_dataset("fashion-predictions").select("predictions").show(10)

In [None]:
DataChain.from_dataset("fashion-predictions").to_pandas().shape

## Delta Updates 

- Assume predictions were already calculated for an "old" dataset (in this example, all images from before 2017).
- The "new" dataset also contains images from 2017 and later.
- We want to run inference only on this delta, i.e. the images in "new" that are not in "old".
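Conceptually, a delta update is a set difference followed by a union: keep only the rows of `new` that are not in `old`, score just those, and append the results to the existing predictions. The sketch below models rows as dicts matched by a hypothetical `id` field and uses a trivial `predict` stand-in; see the DataChain `subtract()` documentation for how it decides that two rows match.

```python
from typing import Callable, Dict, List


def delta_update(
    old_rows: List[Dict],
    new_rows: List[Dict],
    predictions: List[Dict],
    predict: Callable[[Dict], object],
    key: str = "id",
) -> List[Dict]:
    """Score only rows missing from old_rows and append them to predictions."""
    seen = {row[key] for row in old_rows}
    update = [row for row in new_rows if row[key] not in seen]  # new minus old
    scored = [{**row, "predictions": predict(row)} for row in update]
    return predictions + scored  # existing predictions plus the delta


old = [{"id": 1, "year": 2015}, {"id": 2, "year": 2016}]
new = old + [{"id": 3, "year": 2017}, {"id": 4, "year": 2018}]
existing = [{**r, "predictions": "cached"} for r in old]

merged = delta_update(old, new, existing, predict=lambda r: f"scored-{r['id']}")
print(len(existing), len(merged))
```

Only the two new rows are passed to `predict`, which is the whole point of the delta: inference cost scales with the update, not with the full dataset.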

In [None]:
# Prepare datasets

old = (
    DataChain.from_dataset("fashion-product-images")
    # Combine column conditions with `&`; Python's `and` silently drops one of them
    .filter((C("masterCategory") == "Apparel") & (C("subCategory") == "Topwear"))
    .filter(C("year") < 2017)
)
print(old.to_pandas().shape)

new = (
    DataChain.from_dataset("fashion-product-images")
    .filter((C("masterCategory") == "Apparel") & (C("subCategory") == "Topwear"))
)
print(new.to_pandas().shape)

update = new.subtract(old)
print(update.to_pandas().shape)
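
A note on combining conditions: Python's `and` cannot be overloaded, so `cond1 and cond2` only checks the truthiness of `cond1` and returns one of the two operands unchanged. No SQL `AND` is built, and one condition is silently lost. Column expressions are combined with `&` instead, wrapping each comparison in parentheses because `&` binds more tightly than `==`. The toy `Cond` class below is a hypothetical stand-in for a column expression that shows the difference:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Cond:
    """Toy column condition that renders to SQL-like text."""

    sql: str

    def __and__(self, other: "Cond") -> "Cond":
        # `&` is overloadable, so it can build a combined condition
        return Cond(f"({self.sql} AND {other.sql})")


a = Cond("masterCategory = 'Apparel'")
b = Cond("subCategory = 'Topwear'")

print((a and b).sql)  # `a` is truthy, so `and` just returns `b`
print((a & b).sql)    # both conditions are kept
```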


In [None]:
update_predictions = (
    update
    .map(predictions=lambda file: inference_eng.predict(file), output=Predictions)
)

df = update_predictions.to_pandas()
print(df.shape)
df.head()

In [None]:
# Size of Predictions before the union (old)

DataChain.from_dataset("fashion-predictions").to_pandas().shape

In [None]:
# Add (union) new predictions and save (as a new version)

(
    DataChain.from_dataset("fashion-predictions")
    .union(update_predictions)
    .save("fashion-predictions")
)

In [None]:
# Size of Predictions after union

DataChain.from_dataset("fashion-predictions").to_pandas().shape

# 🔍 Analyze Predictions

After running inference and saving the predictions, analyze the results to understand how your model performs.

DataChain's filtering and ordering operations let you separate correct predictions from mistakes and inspect mistakes at different confidence levels. The following cells walk through each step:

## Show correct predictions

- Load the "fashion-predictions" dataset, which contains the predicted labels and probabilities along with the ground truth labels.
- Filter the dataset to include only the correct predictions where the predicted label (`pred_label`) matches the ground truth label (`usage`).
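The "correct" filter (`usage == pred_label`) and the "mistakes" filter (`usage != pred_label`, used below) are complementary, so their counts should add up to the dataset size, and `correct / total` is the accuracy on the scored set. One caveat: because the filters run as SQL, a row whose `usage` is NULL matches neither comparison. A plain-Python sketch over made-up rows, with a hypothetical `split_predictions` helper:

```python
from typing import Dict, List, Tuple


def split_predictions(rows: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Partition rows into (correct, mistakes) by comparing label fields."""
    correct = [r for r in rows if r["usage"] == r["pred_label"]]
    mistakes = [r for r in rows if r["usage"] != r["pred_label"]]
    return correct, mistakes


rows = [
    {"usage": "class_a", "pred_label": "class_a", "pred_proba": 0.97},
    {"usage": "class_a", "pred_label": "class_b", "pred_proba": 0.90},
    {"usage": "class_b", "pred_label": "class_b", "pred_proba": 0.55},
    {"usage": "class_b", "pred_label": "class_a", "pred_proba": 0.40},
]
correct, mistakes = split_predictions(rows)
accuracy = len(correct) / len(rows)
print(len(correct), len(mistakes), accuracy)
```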

In [None]:
# Show correct predictions

(
    DataChain.from_dataset("fashion-predictions")
    .filter(C("usage") == C("predictions.pred_label"))
    .show(5)
)

## Show incorrect predictions

- Filter the dataset to include only the incorrect predictions where the predicted label (`pred_label`) does not match the ground truth label (`usage`).
- Assign the filtered dataset to the variable `mistakes`.

In [None]:
# Show incorrect predictions

mistakes = (
    DataChain.from_dataset("fashion-predictions")
    .filter(C("usage") != C("predictions.pred_label"))
    .select("file", "usage", "predictions")
)
mistakes.show(5)

In [None]:
df = mistakes.to_pandas()
print(df.shape)

df["predictions.pred_proba"].hist()

## Analyze high-confidence and low-confidence mistakes

- Filter the `mistakes` dataset by the predicted probability (`predictions.pred_proba`):
    -  `high-confidence` mistakes have a predicted probability above 0.85.
    -  `low-confidence` mistakes have a predicted probability below 0.5.
- Order high-confidence mistakes by descending probability with `order_by(C("predictions.pred_proba").desc())`, and low-confidence mistakes by ascending probability with `.asc()`.
- Convert each filtered and ordered dataset to a pandas DataFrame using `to_pandas()`.
- Print the shape of each DataFrame to see how many mistakes fall into that group.
- Display the first few mistakes of each group using `show()`.
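The two confidence filters can be mirrored in plain Python to see exactly which rows each one keeps. With the thresholds used in the cells (above 0.85 and below 0.5), mistakes with a probability between 0.5 and 0.85 fall into neither group. The `confidence_buckets` helper and the probabilities are made up for illustration:

```python
from typing import Dict, List, Tuple

HIGH = 0.85  # threshold used by the high-confidence cell
LOW = 0.5    # threshold used by the low-confidence cell


def confidence_buckets(mistakes: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Return (high, low) confidence mistakes, ordered like the notebook cells."""
    high = sorted(
        (m for m in mistakes if m["pred_proba"] > HIGH),
        key=lambda m: m["pred_proba"],
        reverse=True,  # most confident mistakes first
    )
    low = sorted(
        (m for m in mistakes if m["pred_proba"] < LOW),
        key=lambda m: m["pred_proba"],  # least confident mistakes first
    )
    return high, low


mistakes = [{"pred_proba": p} for p in (0.95, 0.42, 0.70, 0.88, 0.31)]
high, low = confidence_buckets(mistakes)
print([m["pred_proba"] for m in high], [m["pred_proba"] for m in low])
```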

In [None]:
# Find high-confidence mistakes

high_conf_mist = (
    mistakes
    .filter(C("predictions.pred_proba") > 0.85)
    .order_by(C("predictions.pred_proba").desc())
)

print(high_conf_mist.to_pandas().shape)
high_conf_mist.show(5)

In [None]:
# Find low-confidence mistakes

low_conf_mist = (
    mistakes
    .filter(C("predictions.pred_proba") < 0.5)
    .order_by(C("predictions.pred_proba").asc())
)

print(low_conf_mist.to_pandas().shape)
low_conf_mist.show(3)

# 📊 Visualize mistakes

Use `collect()` to pull rows out of a chain into Python. Each row is a tuple with one value per selected signal; `high_conf_mist` selects `file`, `usage`, and `predictions`:

In [None]:
sample_results = list(high_conf_mist.collect())
sample_results[0]

Each collected row contains one value per signal:
- `file` is a special `ImageFile` object (see below).
- `usage` is the original ground-truth class name.
- `predictions` is a `Predictions` object with the prediction fields.
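
Because each row's items follow the order of `select("file", "usage", "predictions")`, you can unpack them by name instead of indexing with `[0]`, `[1]`, `[2]`. The `Row` tuples below are made-up stand-ins (a string instead of an `ImageFile`, a dict instead of a `Predictions` object), and `summarize` is a hypothetical helper:

```python
from typing import Dict, List, NamedTuple


class Row(NamedTuple):
    file: str          # stand-in for ImageFile
    usage: str         # ground-truth label
    predictions: Dict  # stand-in for the Predictions model


def summarize(rows: List[tuple]) -> List[str]:
    """Unpack (file, usage, predictions) tuples in select() order."""
    lines = []
    for file, usage, predictions in rows:
        lines.append(f"{file}: true={usage} pred={predictions['pred_label']}")
    return lines


sample = [Row("img_001.jpg", "class_a", {"pred_label": "class_b", "pred_proba": 0.93})]
print("\n".join(summarize(sample)))
```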

DataChain treats `file` as an `ImageFile` because the source image dataset was created with `DataChain.from_storage(..., type="image")`. `ImageFile` is a `DataModel` in DataChain, and its `.read()` method returns the image itself, which the notebook displays inline.

In [None]:
example = sample_results[0]
example[0].read()

In [None]:
# Prepare utility function to display sample results

import matplotlib.pyplot as plt
import numpy as np

def display_image_matrix(items, top):
    """Plot up to `top` collected (file, usage, predictions) rows in a 5-column grid."""
    columns = 5
    rows = int(np.ceil(top / columns))

    # squeeze=False always returns a 2D array of axes, even for a single row
    fig, ax_arr = plt.subplots(rows, columns, figsize=(15, 3 * rows), squeeze=False)
    fig.suptitle("Displaying images", fontsize=20)
    ax_arr = ax_arr.flatten()

    for i, ax in enumerate(ax_arr):
        ax.axis("off")  # hide axes on every slot, including unused ones
        if i >= min(top, len(items)):
            continue  # no item for this slot: leave it blank
        file, true_class, preds = items[i]
        ax.imshow(file.read())  # ImageFile.read() returns the image
        ax.set_title(
            f"True: {true_class}\nPred: {preds.pred_label}",
            fontsize=14,
            backgroundcolor="white",
        )

    # Adjust layout and padding
    plt.tight_layout(pad=2.0)
    plt.show()
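
The layout arithmetic of the plotting helper is easy to check without matplotlib: with a fixed number of columns the grid needs `ceil(top / columns)` rows, item `i` lands at row `i // columns` and column `i % columns`, and slots without an item should be blanked. A hypothetical, plotting-free helper:

```python
import math
from typing import List, Tuple


def grid_layout(n_items: int, top: int, columns: int = 5) -> Tuple[int, List[Tuple[int, int]], int]:
    """Return (rows, (row, col) of each shown item, number of blank slots)."""
    rows = math.ceil(top / columns)
    shown = min(n_items, top)
    positions = [divmod(i, columns) for i in range(shown)]  # (row, col) per item
    blank = rows * columns - shown
    return rows, positions, blank


rows, positions, blank = grid_layout(n_items=7, top=10)
print(rows, positions[-1], blank)
```

For example, a query that returns only 7 rows for `TOP = 10` fills two rows of the grid and leaves 3 blank slots.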

## Explore Low Confidence Mistakes

In [None]:
# Low Confidence Mistakes

TOP = 10
items = list(low_conf_mist.limit(TOP).collect())
display_image_matrix(items, TOP)

## Explore High Confidence Mistakes

In [None]:
# High Confidence Mistakes

TOP = 10

items = list(high_conf_mist.limit(TOP).collect())
display_image_matrix(items, TOP)

## Explore Correct Predictions

In [None]:
correct_preds = (
    DataChain.from_dataset("fashion-predictions")
    .filter(C("usage") == C("predictions.pred_label"))
    .order_by(C("predictions.pred_proba").desc())
    .select("file", "usage", "predictions")
)

print(correct_preds.to_pandas().shape)
correct_preds.show(5)

In [None]:
# Correct Predictions

TOP = 10

items = list(correct_preds.limit(TOP).collect())
display_image_matrix(items, TOP)

# ☁️ Run in Studio (SaaS)

<a href="https://datachain.ai/">
    <img src="static/images/studio.png" alt="DataChain Studio SaaS" style="width: 600px;"/>
</a>

To run these examples in Studio, follow these steps:

1. Open Studio / YOUR_TEAM / `datasets` workspace
2. Create a new Python Script
3. Copy/paste scripts from this Jupyter Notebook
4. Specify Settings:
   - The computer vision dependencies (such as PyTorch)
   - The pre-trained model `model.pth`
   - The `src/train.py` module (imported as `src.train`)
5. Click the Run button


# 🎉 Summary 

**🌟 Congratulations! You've Successfully Completed the Inference Jobs with DataChain Tutorial! 🌟**

In this tutorial, you learned how to run batch inference over a DataChain dataset and analyze the results. Let's recap the key topics covered:

1. 🔍 **Running Inference:** You applied a pre-trained model to fashion product images with `map()`.
2. 📊 **Saving Predictions:** You saved the predicted class, probability, and label alongside the true labels as a DataChain dataset for further analysis.
3. 🔄 **Delta Updates:** You scored only newly added images and appended their predictions to the existing dataset as a new version.
4. 🔍 **Analyzing Predictions:** You used DataChain's filtering and ordering to find correct predictions, high-confidence mistakes, and low-confidence mistakes, and visualized them.


Once again, congratulations on your incredible achievement! 

## What's Next?

Keep exploring, experimenting, and pushing the boundaries of what's possible in computer vision. Here are some next steps:
- 📚 Check other tutorials and User Guides
- 🔍 Try DataChain on your projects

By mastering these techniques, you'll be well on your way to building powerful and efficient computer vision pipelines with DataChain.

## Get Involved

We'd love to have you join our growing community of DataChain users and contributors! Here's how you can get involved:
- ⭐ Give us a star on [GitHub](https://github.com/iterative/datachain) to show your support
- 🌐 Visit the [datachain.ai website](https://datachain.ai/) to learn more about our products and services
- 📞 Contact us to discuss how DataChain can help streamline your company's ML workflows
- 🙌 Follow us on social media for the latest updates and insights

Thanks for choosing DataChain, and happy coding! 😄