# Classifying images

In this exercise, we're going to classify images. You'll need to:
1. Read images from S3.
2. Preprocess a dataset.
3. Implement a custom `Predictor`.
4. Use a pre-trained model to generate predictions.

Make sure to reference [the latest version of the AIR documentation](https://docs.ray.io/en/master/ray-air/getting-started.html).

### Task 1: Read images from S3

First, let's load our image data. We're going to be working with a subset of ImageNet that contains one image of each class. 

Read the images at `s3://air-example-data-2/imagenet-sample-images/` into a [Ray Dataset](https://docs.ray.io/en/master/data/api/dataset.html). Your dataset should contain 1000 rows, and its representation should look like `Dataset(num_blocks=..., num_rows=1000, schema={image: ..., ...})`.

In [None]:
from ray.data import Dataset

dataset: Dataset = ...

assert dataset.count() == 1000
dataset

In [None]:
import ray
from ray.data import Dataset
from ray.data.datasource import ImageFolderDatasource

dataset: Dataset = ray.data.read_datasource(ImageFolderDatasource(), root="s3://air-example-data-2/imagenet-sample-images", size=(224, 224))

assert dataset.count() == 1000
dataset

### Task 2: Preprocess images

Our pretrained model expects inputs to be normalized. If we don't normalize the images, our model won't perform well.
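
To make the normalization concrete, here is a minimal pure-Python sketch of the per-channel arithmetic. It assumes pixel values are first scaled to [0, 1] (as `ToTensor` does for 8-bit images) and then standardized with the ImageNet mean and standard deviation passed to `Normalize` in this task. The helper names `to_unit_range` and `normalize_pixel` are hypothetical and only illustrate the formula `(value - mean) / std`; the exercise itself should use the torchvision `transform`.

```python
# ImageNet channel statistics, copied from the `Normalize` call in this notebook.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def to_unit_range(rgb_uint8):
    """Scale 8-bit channel values from [0, 255] to [0, 1] (hypothetical helper)."""
    return [value / 255 for value in rgb_uint8]


def normalize_pixel(rgb):
    """Standardize one RGB pixel whose channels are already in [0, 1] (hypothetical helper)."""
    return [(value - mean) / std for value, mean, std in zip(rgb, MEAN, STD)]


# A single illustrative pixel: full red, no green, mid-level blue.
pixel = to_unit_range([255, 0, 128])
normalized = normalize_pixel(pixel)
```

A pixel equal to the channel means maps to zero in every channel, which is the point of the transformation: inputs end up centered and scaled the way the pretrained weights expect.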

Apply `transform` to every image in the dataset.

In [None]:
from torchvision.transforms import Compose, ToTensor, Normalize

transform = Compose([
    ToTensor(), 
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])    
])

transformed_dataset: Dataset = ...

assert all(record["image"].shape[0] == 3 for record in transformed_dataset.take_all())
transformed_dataset

In [None]:
from ray.data.preprocessors import BatchMapper
from torchvision.transforms import Compose, ToTensor, Normalize

transform = Compose([
    ToTensor(), 
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])    
])

def preprocess(df):
    df.loc[:, "image"] = [transform(image) for image in df["image"]]
    return df

preprocessor = BatchMapper(preprocess)
transformed_dataset: Dataset = preprocessor.fit_transform(dataset)

assert all(record["image"].shape[0] == 3 for record in transformed_dataset.take_all())
transformed_dataset

### Task 3: Extend `TorchPredictor`

`resnet101` returns confidence scores rather than labels. In the code snippet below, the model returns 1000 logits for each input image, one per ImageNet class. Each logit represents the model's confidence that the image belongs to that class.

In [12]:
from ray.train.torch import TorchPredictor
from torchvision.models import resnet101

model = resnet101(pretrained=True)
predictor = TorchPredictor(model)

batch = next(transformed_dataset.iter_batches(batch_size=4))["image"].to_numpy()
outputs = predictor.predict(batch)
outputs.shape



(4, 1000)

Logits aren't relevant to this exercise, so let's extend the built-in `TorchPredictor` class to return labels instead.
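
As a rough illustration of how logits map to labels, the pure-Python sketch below picks the index of the largest logit in each row. The functions `softmax` and `argmax` and the toy `batch_logits` values are hypothetical stand-ins for illustration, not part of Ray or PyTorch. The sketch also shows that converting logits to probabilities with softmax does not change which index wins, so the label can be read directly from the logits.

```python
import math


def softmax(logits):
    """Convert logits to probabilities (shifted by the max for numerical stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Return the index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


# Toy batch: 2 "images" and 3 classes instead of 1000.
batch_logits = [[0.2, 3.1, -1.0], [2.5, 0.0, 0.4]]

# The predicted label for each row is the index of its largest logit.
labels = [argmax(row) for row in batch_logits]

# Softmax is monotonic, so it preserves the winning index.
labels_from_probs = [argmax(softmax(row)) for row in batch_logits]
```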

Implement `CustomTorchPredictor.call_model`. Your implementation should return a tensor containing the predicted label for each image in the batch.

In [None]:
import torch

class CustomTorchPredictor(TorchPredictor):

    def call_model(self, tensor: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
        

predictor = CustomTorchPredictor(model)
predictions = predictor.predict(batch)

assert predictions.shape == (4,)
predictions

In [None]:
import torch

class CustomTorchPredictor(TorchPredictor):

    def call_model(self, tensor: torch.Tensor) -> torch.Tensor:
        outputs = super().call_model(tensor)
        return torch.argmax(outputs, dim=1)


predictor = CustomTorchPredictor(model)
predictions = predictor.predict(batch)

assert predictions.shape == (4,)
predictions

**HINT**: Use `torch.argmax` to get predicted labels from model outputs.

In [None]:
import torch

outputs = model(torch.zeros(4, 3, 256, 256))
assert outputs.shape == (4, 1000)
predictions = torch.argmax(outputs, dim=1)
assert predictions.shape == (4,)

### Task 4: Make predictions for the entire dataset

Now that we've preprocessed our dataset and implemented a custom predictor, we can finally classify the images.

Classify all of the images in the dataset, and assign `predictions` to a dataset that describes the predicted labels. The dataset representation should look like `Dataset(num_blocks=..., num_rows=1000, schema={predictions: int64})`.

In [None]:
predictions: Dataset = ...
predictions

In [None]:
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint

checkpoint = TorchCheckpoint.from_model(model)
batch_predictor = BatchPredictor(checkpoint, CustomTorchPredictor)
predictions: Dataset = batch_predictor.predict(transformed_dataset, feature_columns=["image"])
predictions

If you did everything correctly, your model should classify 87.6% of the images correctly.

In [None]:
def score(outputs: Dataset) -> float:
    # The scoring below expects row i to carry label i.
    assert outputs.count() == 1000
    predicted_labels = [record["predictions"] for record in outputs.take_all()]
    return sum(label == expected_label for expected_label, label in enumerate(predicted_labels)) / 1000


assert score(predictions) == 0.876
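
The scoring rule used above can be sketched on a toy list without Ray. This is a minimal illustration that assumes, as `score` does, that the row at position `i` is expected to have label `i`. The function name `accuracy` and the list `toy_predictions` are hypothetical.

```python
def accuracy(predicted_labels):
    """Fraction of positions whose predicted label equals the position index."""
    correct = sum(
        label == expected_label
        for expected_label, label in enumerate(predicted_labels)
    )
    return correct / len(predicted_labels)


# Four toy predictions: positions 0, 1 and 3 match their index, position 2 does not.
toy_predictions = [0, 1, 5, 3]
toy_accuracy = accuracy(toy_predictions)
```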