# Chip Classification using EuroSAT - Predict

This notebook demonstrates prediction using a chip classifier trained in the `train-eurosat` notebook on a Sentinel-2 dataset called [EuroSAT](https://github.com/phelber/EuroSAT). Note that using the [wandb logger](https://wandb.ai/) only requires a free account.

## Environment Setup 

Refer to README.md for environment setup. 

In [None]:
import os

# If using LightningAI, switch the working directory to the folder that holds
# this notebook. When REPO_DIR does not exist the working directory is left as is.
REPO_DIR = "/teamspace/studios/this_studio/eda-bids-hackathon-prep/"  # Adjust as appropriate
_notebook_dir = os.path.join(REPO_DIR, "sentinel2-modelling")
if os.path.exists(REPO_DIR):
    os.chdir(_notebook_dir)

In [None]:
import os
import tempfile
from typing import Dict, Optional, Any
from typing import Callable, Optional, cast
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader

from torchgeo.datasets import EuroSAT
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.transforms import AugmentationSequential, indices
from torchgeo.trainers import ClassificationTask
from torchgeo.models import ResNet18_Weights

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

# Seed the random number generators through Lightning's helper so reruns are repeatable.
seed_everything(543)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Load EDS credentials from .env file
from dotenv import load_dotenv
load_dotenv()

In [None]:
# Band names handed to EuroSATDataModule through `bands=` when the datamodule is built below.
BANDS = ('B04', 'B03', 'B02') # make sure these match the model trained

## Predictions

Having trained a model in the `train-eurosat` notebook, we will now use it to make predictions.

In [None]:
# Pick DataLoader settings for the device detected earlier in the notebook.
if device == "cuda":
    batch_size = 128
    num_workers = 8
elif device == "cpu":
    batch_size = 64
    num_workers = 0
else:
    # Fail here with a clear message: merely printing would leave batch_size and
    # num_workers undefined and surface later as an unrelated NameError.
    raise ValueError(f"unknown device: {device!r}")

# Build the EuroSAT datamodule for the selected bands, using the loader
# settings chosen above. `download=True` asks torchgeo to fetch the dataset
# into `root`.
datamodule = EuroSATDataModule(
    # dataset options
    root="data",
    bands=BANDS,
    download=True,
    # DataLoader options
    batch_size=batch_size,
    num_workers=num_workers,
)

Download a model checkpoint from wandb, or point to a local checkpoint. Note that there is a known [issue](https://github.com/microsoft/torchgeo/issues/1639) with the SENTINEL2_ALL_MOCO and RGB weights.

In [None]:
# Path of the checkpoint file to load (in this example it sits under a
# wandb_logs run directory). Adjust it to point at your own checkpoint.
ckpt_path = '/teamspace/studios/this_studio/wandb_logs/eurosat/6knkh8o7/checkpoints/epoch=4-step=130.ckpt'
# Echo the path as the cell output.
ckpt_path

In [None]:
# Restore the trained ClassificationTask from the checkpoint, mapping its
# stored tensors onto the device selected earlier ("cuda" or "cpu").
task = ClassificationTask.load_from_checkpoint(ckpt_path, map_location=torch.device(device))

In [None]:
# Prepare the datamodule for the "test" stage; later cells read
# datamodule.test_dataset and datamodule.test_dataloader().
datamodule.setup(stage="test")

In [None]:
# Trainer with default settings; in this notebook it is only used for
# evaluation via trainer.test().
trainer = Trainer(
    # limit_predict_batches=1 # for a single batch only
)

In [None]:
# Evaluate the restored task on the test split; the datamodule supplies the
# test dataloader.
test_results = trainer.test(model=task, datamodule=datamodule)

## Inference on a single image
Note that since the data was not normalised prior to training, it is possible to pass through an image without normalisation, but we will do so to be safe

In [None]:
%%time
sample = datamodule.test_dataset[2500]
label = cast(int, sample["label"].item())
image = sample['image'].unsqueeze(0).to(device)
pred = task(image)
pred_index = int(torch.argmax(pred))

result_str = f"label: {datamodule.test_dataset.classes[label]}, prediction: {datamodule.test_dataset.classes[pred_index]}"
fig = datamodule.test_dataset.plot(sample, suptitle=result_str)

## Generate confusion matrix

In [None]:
# Collect the true and predicted class indices for the whole test split.
y_true = []
y_pred = []

# After trainer.test() the module may no longer be on `device` (Lightning's
# teardown can move it to CPU); make sure model and inputs share a device.
task.to(device)
task.eval()

with torch.no_grad():
    for batch in datamodule.test_dataloader():
        images = batch['image'].to(device)

        preds = task(images)
        preds_indices = torch.argmax(preds, dim=1).cpu().numpy()

        # Labels never pass through the model, so they are not moved to `device`.
        y_true.extend(batch['label'].cpu().numpy())
        y_pred.extend(preds_indices)

In [None]:
# Build the confusion matrix from the class indices gathered in the previous cell.
# (Passing `labels=datamodule.test_dataset.classes` is intentionally left disabled.)
cm = confusion_matrix(y_true, y_pred)

# Show the raw matrix as the cell output.
cm

In [None]:
# Helper that draws a confusion matrix as a heatmap
def plot_confusion_matrix(cm, labels):
    """Render a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: Matrix of integer counts (cells are formatted with "d").
        labels: Tick labels used on both the x and y axes.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    _, axis = plt.subplots(figsize=(8, 8))
    cell_style = dict(annot=True, fmt="d", cmap="Blues")
    sns.heatmap(cm, ax=axis, xticklabels=labels, yticklabels=labels, **cell_style)
    axis.set_ylabel('Actual')
    axis.set_xlabel('Predicted')
    plt.show()

# Recompute the confusion matrix (same call as in the previous cell) so this
# cell only needs y_true / y_pred to be defined.
cm = confusion_matrix(y_true, y_pred)

# Draw it using the dataset's class names as tick labels.
plot_confusion_matrix(cm, datamodule.test_dataset.classes)