## Class 10 - Image encoding and classification using CNNs and ViTs
In this class, we will use pretrained CNN and transformer models to encode and classify images.
We will use a library you are familiar with, namely Hugging Face's _transformers_, as well as (under the hood) _timm_ (another Hugging Face library, specific to computer vision) and _PyTorch_.

### Instructions
**Overall goal**: train a simple fully-connected neural network that classifies snack types from this dataset (https://huggingface.co/datasets/Matthijs/snacks) based on image embeddings extracted from a pretrained ResNet model (https://arxiv.org/abs/1512.03385).

**Suggestion**: use the `example.ipynb` notebook if you are in doubt! This might be especially helpful with respect to PyTorch. Everything we use PyTorch for here is something we already did as part of NLP, but if you have not used PyTorch since, it is perfectly understandable that you need a reminder.

Steps:
1. Use Hugging Face's `datasets` library to import the dataset: https://huggingface.co/datasets/Matthijs/snacks. Import the training and the test splits from this dataset, and visualize the frequency of each class.
2. Load a pretrained ResNet model using Hugging Face's `transformers`. Check the documentation (https://huggingface.co/docs/transformers/en/model_doc/resnet, https://huggingface.co/microsoft/resnet-50) for more details on how to load the model. Visualize the structure of the model, and reflect on how each of the blocks works, based on what we discussed in the lecture.
3. Extract ResNet embeddings for all the images in the training set and all the images in the test set. You will need to look for the `pooler_output` attribute of the model's output. This sounds complex, but it's really just a minor change over the examples presented in the model's documentation.
4. Optional: apply dimensionality reduction to the embeddings using TSNE (https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html): is any information about the labels encoded in the embeddings?
5. Train a simple neural network that, for each image in the training set, takes the embeddings as inputs and predicts the snack label
    - You will have to create a PyTorch `Dataset` (https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) where both your embeddings and your labels are included.
    - You will also have to create a `DataLoader` for the training set, as this will make it much easier to do things like batching and shuffling
    - You need to specify the loss and the optimizer
    - Then, you need to specify a training loop. Remember the neural net gymnastics:
        - Reset the gradients left over from the previous iteration
        - Do the forward step, i.e., make a prediction
        - Compute the loss based on predictions and true labels
        - Compute gradients (backpropagation)
        - Update the weights with the optimizer
6. Evaluate the resulting model:
    - Make predictions for all examples in the test set
    - Use `sklearn.metrics` utils (https://scikit-learn.org/stable/modules/model_evaluation.html#classification-metrics) to compute accuracy and F1-score
    - Also plot the confusion matrix, to see which categories are most often confused
7. Explore zero-shot classification for computer vision
    - Try classification on custom labels on one of the images from today's dataset. To do so, adapt the code provided here: https://huggingface.co/docs/transformers/en/tasks/zero_shot_image_classification, by simply replacing the desired labels and the input
    - Can you get reasonable predictions on more nuanced labels than the ones provided with the dataset?
    - Which kind of models support this behavior? Why?
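
Step 7 can be sketched roughly as follows. This is a minimal example, not the tutorial's exact code: the helper name `zero_shot_classify`, the `nuanced_labels` list, and the choice of the `openai/clip-vit-base-patch32` checkpoint are assumptions made for illustration. The function is only defined here, since calling it downloads the checkpoint.

```python
def zero_shot_classify(image, candidate_labels,
                       checkpoint='openai/clip-vit-base-patch32'):
    """Score `image` (a PIL image) against free-text candidate labels.

    Assumes `transformers` is installed and the checkpoint can be downloaded.
    """
    # Imported lazily so that defining the helper stays cheap
    from transformers import pipeline

    classifier = pipeline('zero-shot-image-classification', model=checkpoint)
    # The pipeline returns one {'score': ..., 'label': ...} dict per label
    return classifier(image, candidate_labels=candidate_labels)


# Hypothetical labels, more nuanced than the 20 classes of the dataset
nuanced_labels = ['a green apple', 'a red apple', 'a sliced apple']
```

With the dataset loaded, a call could look like `zero_shot_classify(train[0]['image'], nuanced_labels)`.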

In [None]:
# label dictionary: inferred from the dataset's online documentation
label_dict = {0: 'apple',
              1: 'banana',
              2: 'cake',
              3: 'candy',
              4: 'carrot',
              5: 'cookie',
              6: 'doughnut',
              7: 'grape',
              8: 'hotdog',
              9: 'icecream',
              10: 'juice',
              11: 'muffin',
              12: 'orange',
              13: 'pineapple',
              14: 'popcorn',
              15: 'pretzel',
              16: 'salad',
              17: 'strawberry',
              18: 'waffle',
              19: 'watermelon'}

### task 1

In [1]:
from datasets import load_dataset



In [2]:
dataset = load_dataset("Matthijs/snacks")

Downloading data: 100%|██████████| 76.5M/76.5M [00:01<00:00, 66.1MB/s]
Downloading data: 100%|██████████| 14.9M/14.9M [00:00<00:00, 32.7MB/s]
Downloading data: 100%|██████████| 15.2M/15.2M [00:00<00:00, 30.3MB/s]
Generating train split: 100%|██████████| 4838/4838 [00:00<00:00, 24284.46 examples/s]
Generating test split: 100%|██████████| 952/952 [00:00<00:00, 24698.01 examples/s]
Generating validation split: 100%|██████████| 955/955 [00:00<00:00, 25579.27 examples/s]


In [10]:
train = dataset['train']
test = dataset['test']
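
The class-frequency part of task 1 could be sketched as follows. This is a minimal example that assumes labels arrive as a list of integer ids (as in `train['label']`). The helper names `class_counts` and `plot_class_counts`, the `toy_labels` list, and the three-class `toy_label_dict` are hypothetical stand-ins so that the block runs without downloading the data.

```python
from collections import Counter

# Small subset of the label dictionary, repeated to keep the sketch self-contained
toy_label_dict = {0: 'apple', 1: 'banana', 2: 'cake'}


def class_counts(labels, id_to_name):
    """Return {class name: number of examples}, ordered by class id."""
    counts = Counter(labels)
    return {id_to_name[i]: counts[i] for i in sorted(id_to_name)}


def plot_class_counts(counts, title):
    """Bar plot of class frequencies (assumes matplotlib is installed)."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 3))
    ax.bar(list(counts.keys()), list(counts.values()))
    ax.set_ylabel('number of images')
    ax.set_title(title)
    ax.tick_params(axis='x', rotation=90)
    fig.tight_layout()
    return fig


# Hypothetical labels standing in for train['label']
toy_labels = [0, 0, 1, 2, 2, 2]
toy_counts = class_counts(toy_labels, toy_label_dict)
print(toy_counts)  # {'apple': 2, 'banana': 1, 'cake': 3}
```

On the real data, the same helpers would be called as `plot_class_counts(class_counts(train['label'], label_dict), 'train')`, and likewise for `test`.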

### task 2

In [19]:
from transformers import AutoImageProcessor, ResNetModel
import torch
from torchview import draw_graph

image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetModel.from_pretrained("microsoft/resnet-50")



In [None]:
#terminal: sudo apt-get install graphviz
#draw_graph(model, input_size=(1,3,224,224), expand_nested=True).visual_graph 
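
If graphviz is not available, `print(model)` already shows the nested structure. A lighter alternative is to list the top-level blocks with their parameter counts, as in the sketch below. The helper `summarize_children` and the tiny `toy_cnn` are hypothetical; `toy_cnn` only stands in for the pretrained model so the block runs on its own.

```python
import torch.nn as nn


def summarize_children(module):
    """Return (name, class name, parameter count) for each top-level child."""
    rows = []
    for name, child in module.named_children():
        n_params = sum(p.numel() for p in child.parameters())
        rows.append((name, type(child).__name__, n_params))
    return rows


# Hypothetical stand-in for the pretrained model: conv layer, ReLU, pooling
toy_cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
)

summary = summarize_children(toy_cnn)
for name, kind, n_params in summary:
    print(f'{name:>10}  {kind:<20} {n_params:>10,d}')
```

Calling `summarize_children(model)` on the loaded ResNet gives a quick view of where most of the parameters sit.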

### task 3

In [20]:
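def _extract(dataset, outs):
    """Append one ResNet embedding per image to `outs` and return them stacked."""
    for image in dataset['image']:
        # Preprocess the image into a batch of one tensor
        inputs = image_processor(image, return_tensors="pt")

        # Inference only: no gradients needed
        with torch.no_grad():
            outputs = model(**inputs)

        # pooler_output has shape (1, 2048, 1, 1); squeeze it to (2048,)
        embedding = outputs.pooler_output
        outs.append(embedding.squeeze())
    return torch.stack(outs)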

In [21]:
train_embeddings = []
test_embeddings = []

train_embeddings = _extract(train, train_embeddings)
test_embeddings = _extract(test, test_embeddings)
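
For the optional task 4, a t-SNE projection could be sketched as below. This is a minimal example on synthetic data: `toy_embeddings` and `toy_labels` are hypothetical stand-ins for the real embeddings and labels, and the `perplexity` value is an arbitrary choice that needs tuning on real data.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)

# Hypothetical stand-in for the embeddings: 3 classes, 20 points each, 16 dims
toy_labels = np.repeat(np.arange(3), 20)
toy_embeddings = rng.normal(size=(60, 16)) + 5.0 * toy_labels[:, None]

# perplexity must be smaller than the number of samples
tsne = TSNE(n_components=2, perplexity=10, init='pca', random_state=0)
toy_2d = tsne.fit_transform(toy_embeddings)
print(toy_2d.shape)  # (60, 2)
```

With the real data, `train_embeddings.numpy()` would replace `toy_embeddings`. A scatter plot of the two output columns colored by label (e.g., `plt.scatter(toy_2d[:, 0], toy_2d[:, 1], c=toy_labels)`) then shows whether images of the same class cluster together.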

### task 5

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim

In [23]:
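class Embeddings(Dataset):
    """Pairs each precomputed image embedding with its integer label."""

    def __init__(self, labels, embeddings):
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.embedding = embeddings

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (embedding, label) pair; the DataLoader handles batching
        embedding = self.embedding[idx]
        label = self.labels[idx]
        return embedding, label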

In [24]:
train_dataloader = DataLoader(Embeddings(embeddings=train_embeddings, 
                                         labels=train['label']),
                              batch_size=32, shuffle=True)

### task 6

In [25]:
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 128)
        self.fc3 = nn.Linear(128, 20)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
net = Classifier()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
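
The notebook stops before the training loop and the evaluation. One way these two steps could look is sketched below. It is a minimal example on synthetic data, not the reference solution: the `toy_*` names, `make_toy_split`, `train`, and `predict` are hypothetical, and a small `nn.Sequential` stands in for the classifier so that the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

torch.manual_seed(0)


def make_toy_split(n_per_class=40, n_classes=3, dim=8):
    """Hypothetical data: random 'embeddings' shifted by their class id."""
    y = torch.arange(n_classes).repeat_interleave(n_per_class)
    x = torch.randn(len(y), dim) + 3.0 * y[:, None].float()
    return x, y


x_train, y_train = make_toy_split()
x_test, y_test = make_toy_split(n_per_class=10)
toy_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=16, shuffle=True)

toy_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
toy_criterion = nn.CrossEntropyLoss()
toy_optimizer = optim.Adam(toy_net.parameters(), lr=0.01)


def train(net, loader, criterion, optimizer, epochs):
    """Run the training loop and return the mean loss of each epoch."""
    history = []
    net.train()
    for _ in range(epochs):
        running = 0.0
        for inputs, labels in loader:
            optimizer.zero_grad()              # reset gradients from the previous batch
            logits = net(inputs)               # forward step
            loss = criterion(logits, labels)   # loss from predictions and true labels
            loss.backward()                    # backpropagation: compute gradients
            optimizer.step()                   # update the weights
            running += loss.item()
        history.append(running / len(loader))
    return history


def predict(net, inputs):
    """Return the predicted class id for each row of `inputs`."""
    net.eval()
    with torch.no_grad():
        return net(inputs).argmax(dim=1)


losses = train(toy_net, toy_loader, toy_criterion, toy_optimizer, epochs=20)
y_pred = predict(toy_net, x_test).numpy()
y_true = y_test.numpy()

acc = accuracy_score(y_true, y_pred)
macro_f1 = f1_score(y_true, y_pred, average='macro')
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
print(f'accuracy={acc:.3f}  macro-F1={macro_f1:.3f}')
print(cm)
```

On the real data, `net`, `train_dataloader`, `criterion`, and `optimizer` take the place of the `toy_*` objects, and predictions are made on `test_embeddings` against `test['label']`. Macro-averaged F1 is one reasonable choice here because it weights all 20 classes equally.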

#### Bonus CV tasks
- Try training a classifier on embeddings extracted from a vision transformer: how does performance compare to ResNet?
- Build a CNN from scratch for this classification task: how does performance compare? There are lots of tutorials online, you can start from this one: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html.
- Look up models and pipelines for more computer vision tasks, including semantic segmentation, object detection (which can also be performed zero-shot using ViT), etc.

#### Bonus NN-related tasks
- Train NN-based models (FFNNs or RNNs) on the bike data: how does performance compare to the best algorithms?

### Additional tasks
- A wide array of pretrained models for audio classification (including zero-shot audio classification) are also available on Hugging Face. Experiment with those :)

### Additional notes
- Here, we worked with pretrained models. If you want to train a CNN from scratch (e.g., because you are working with "special" image data, and no pretrained models are available), you can look up, for example, _PyTorch_'s intro tutorial (https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html). There are really a billion resources online if this is what you decide to focus on for your final project.
- If, instead, fine-tuning a full model is what you want to do, plenty of tutorials are available for tasks ranging from image classification to object detection and image segmentation. Look up _timm_'s documentation (https://huggingface.co/docs/hub/en/timm) and _transformers_'s documentation (https://huggingface.co/docs/transformers/tasks/image_classification).