# Notebook to Create Example Predictions

In [None]:
import torch
from models.msnet import MSNet
import numpy as np
from utils.evaluation import visualise_batch_predictions
from skimage.io import imread
import kornia.augmentation as K

### Settings

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Band indices passed to the test data loader and to the visualisation helper.
bands = list(range(4))

# Number of samples per batch drawn from the test data loader.
batchsize = 10

### Load Model and Trained Weights

**NOTE**: remember to download the trained weights (`model_weights.pth`) and place them in the root directory.

In [None]:
# Build the network on the selected device and restore the trained weights.
model = MSNet(num_classes=2).to(device)

# The checkpoint is a downloaded file, and plain torch.load unpickles arbitrary objects.
# weights_only=True restricts loading to tensors and primitive containers, which is all
# a state dict needs, so a tampered file cannot execute code when it is loaded.
state_dict = torch.load("model_weights.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch to evaluation mode (disables training-only behaviour such as dropout or
# batch-norm statistic updates, if the network contains such layers).
model.eval()

## Predict for a Single Image

In [None]:
# Read the example tile from disk as float32, then reorder it from (H, W, C) to the
# model's [batch, channels, height, width] layout: (1, C, H, W).
image_array = imread("image_10016.tif").astype(np.float32)
image = torch.tensor(image_array).permute([2, 0, 1]).unsqueeze(0)

# Read the matching ground-truth mask and give it batch and channel axes: (1, 1, H, W).
label_array = imread("label_10016.tif").astype(np.float32)
label = torch.tensor(label_array).unsqueeze(0).unsqueeze(0)

# Per-band mean and standard deviation of the training set, used to standardise the input.
means = torch.tensor([265.7371, 445.2234, 393.7881, 2773.2734])
stds = torch.tensor([91.8786, 110.0122, 191.7516, 709.2327])

# Resize image and mask to 224x224, then standardise the image bands.
# NOTE(review): whether kornia resizes the "mask" key with nearest-neighbour (rather than
# the image's interpolation) depends on the installed kornia version — confirm, otherwise
# mask edges may contain interpolated, non-integer class values.
preprocess = K.container.AugmentationSequential(
    K.Resize((224, 224)),
    K.Normalize(mean=means, std=stds),
    data_keys=["image", "mask"],
)
image, label = preprocess(image, label)

In [None]:
# Predict for the single image. This is pure inference, so autograd is disabled and
# no computation graph is built or held in memory during the forward pass.
with torch.no_grad():
    logits = model(image.to(device))

# Per-pixel class index, keeping the class axis as a singleton channel
# (i.e. [batch, 1, H, W], assuming the model outputs [batch, classes, H, W]).
predictions = logits.argmax(dim=1, keepdim=True)

# Visualise image, ground-truth mask and prediction side by side.
visualise_batch_predictions(image, label, predictions, rescale=True, bands=bands)

## Predict Batch-wise on the Full Test Loader

**NOTE**: You must set `PATH_TO_DATA` in `config/config.py` to the folder that contains the dataset.

In [None]:
# Imports needed only by the batch-wise section; presumably kept here so the
# single-image section above runs without a configured dataset.
from config import config
from utils.data import create_dataloaders

# Dataset location, read from PATH_TO_DATA in config/config.py.
DATA_PATH = config.PATH_TO_DATA

In [None]:
# Build the data loaders and keep only the third (test) loader; the first two
# (presumably train and validation) are not needed here.
_, _, test_loader = create_dataloaders(
    DATA_PATH,
    bands=bands,
    batch_size=batchsize,
)

# An explicit iterator lets the next cell pull one batch per run via next().
loader = iter(test_loader)

**NOTE**: Re-run the following (last) cell to "flip" through the test data loader one batch at a time and inspect all test images.

In [None]:
# Fetch the next test batch. When the iterator is exhausted, restart it from the first
# batch so this cell can be re-run indefinitely instead of failing with StopIteration.
try:
    batch_sample, batch_mask = next(loader)
except StopIteration:
    loader = iter(test_loader)
    batch_sample, batch_mask = next(loader)
batch_sample, batch_mask = batch_sample.to(device), batch_mask.to(device)

# Inference only: disable autograd so no computation graph is kept in memory.
with torch.no_grad():
    batch_output = model(batch_sample)
# Per-pixel class index (class axis 1 removed).
batch_predictions = batch_output.argmax(dim=1)

# Mask and predictions lack a channel axis (presumably [B, H, W]); unsqueeze(1) adds one
# so they match the [B, 1, H, W] layout used for the single-image visualisation.
visualise_batch_predictions(batch_sample, batch_mask.unsqueeze(1), batch_predictions.unsqueeze(1), rescale=True, bands=bands)