In [None]:
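## Goal of this notebook

- Load the trained snow-cover CNN (`SnowCoverCNN`). Note that it is a patch classifier (64×64 input patches, one snow / no-snow label per patch), not a UNet segmentation model.
- Preprocess the satellite imagery as needed (scaling, adding a batch dimension, splitting into patches).
- Use the model to predict snow cover on the imagery.
- Flatten the predicted and ground-truth label rasters so they can be compared pixel by pixel.
- Compute accuracy metrics: overall accuracy, a confusion matrix, and a classification report.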

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import rasterio
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# Model definition copied into this notebook so the saved weights can be
# loaded without importing the training code. Attribute names (conv1..fc2)
# must match the checkpoint, because load_state_dict matches parameters by name.
class SnowCoverCNN(nn.Module):
    """Small CNN that classifies a multi-band image patch into snow classes.

    Three conv(3x3, padding 1) -> ReLU -> max-pool(2x2) stages halve the
    spatial size three times; fc1 is sized for a 64 x 8 x 8 feature map, so
    the model expects 64 x 64 input patches. It returns one row of class
    logits per patch -- it does not produce a per-pixel mask.

    Args:
        in_channels: number of bands per input patch (default 4).
        num_classes: number of output logits (default 2).
    """

    def __init__(self, in_channels=4, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        # Output order annotated as (snow, no snow).
        # NOTE(review): confirm this index order against the training labels.
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits for a batch of patches.

        Args:
            x: (batch, in_channels, 64, 64) tensor, or a single unbatched
               (in_channels, 64, 64) patch.

        Returns:
            (batch, num_classes) logits; an unbatched patch yields batch size 1.

        Raises:
            RuntimeError: from fc1 when the pooled feature map does not hold
                exactly 64 * 8 * 8 values per sample (e.g. a 128 x 128 patch).
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(torch.relu(conv(x)))
        # Flatten per sample instead of view(-1, 64 * 8 * 8): the latter silently
        # splits an oversized input into several bogus "samples" rather than
        # failing. Unbatched (3-D) input is still treated as a batch of one.
        x = x.reshape(x.shape[0] if x.dim() == 4 else 1, -1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    
# Load the trained snow-cover model and evaluate it against a labelled scene.
#
# SnowCoverCNN is a patch classifier, not a UNet-style segmentation network:
# three 2x max-pools feed a fully connected layer sized for 64 * 8 * 8 features,
# so it takes 64x64 patches and returns one pair of class logits per patch. To
# compare against a per-pixel ground-truth raster, the scene is tiled into
# 64x64 patches, each patch is classified, and its class is painted over every
# pixel of that patch.

# Replace these paths with your own checkpoint / imagery.
checkpoint_path = 'model_checkpoint.pth'
satellite_image_path = 'satellite_image.tif'
ground_truth_label_path = 'ground_truth_label.tif'

PATCH_SIZE = 64        # spatial input size implied by SnowCoverCNN.fc1
BATCH_SIZE = 256       # patches per forward pass; bounds peak memory on big scenes
PIXEL_SCALE = 255.0    # assumes 8-bit imagery -- TODO confirm for this product
CLASS_LABELS = [0, 1]  # label values reported on, in target_names order

model = SnowCoverCNN()
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# weights_only=True refuses to unpickle arbitrary objects from the file.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))
model.eval()

# rasterio's read() returns a (bands, rows, cols) array.
with rasterio.open(satellite_image_path) as src:
    satellite_image = src.read().astype(np.float32) / PIXEL_SCALE

bands, height, width = satellite_image.shape
if bands != model.conv1.in_channels:
    raise ValueError(
        f"model expects {model.conv1.in_channels} bands, image has {bands}"
    )
n_rows, n_cols = height // PATCH_SIZE, width // PATCH_SIZE
if n_rows == 0 or n_cols == 0:
    raise ValueError(
        f"image is {height}x{width}, smaller than one {PATCH_SIZE}x{PATCH_SIZE} patch"
    )
# Partial patches along the bottom/right edges are dropped from the evaluation.
crop_h, crop_w = n_rows * PATCH_SIZE, n_cols * PATCH_SIZE

# Add only a batch axis: the band axis from read() already is the channel axis.
satellite_image_tensor = torch.from_numpy(satellite_image).unsqueeze(0)

# Cut the cropped scene into non-overlapping patches, row-major over the grid:
# (bands, crop_h, crop_w) -> (n_rows * n_cols, bands, PATCH_SIZE, PATCH_SIZE).
patches = (
    satellite_image_tensor[0, :, :crop_h, :crop_w]
    .reshape(bands, n_rows, PATCH_SIZE, n_cols, PATCH_SIZE)
    .permute(1, 3, 0, 2, 4)
    .reshape(-1, bands, PATCH_SIZE, PATCH_SIZE)
)

with torch.no_grad():
    # argmax over raw logits equals argmax over their softmax, so skip softmax.
    patch_classes = torch.cat([
        model(patches[start:start + BATCH_SIZE]).argmax(dim=1)
        for start in range(0, patches.shape[0], BATCH_SIZE)
    ]).reshape(n_rows, n_cols).numpy()

# Paint each patch's class over its pixels -> (crop_h, crop_w) label map.
predicted_label = patch_classes.repeat(PATCH_SIZE, axis=0).repeat(PATCH_SIZE, axis=1)
predicted_label_flat = predicted_label.flatten()

with rasterio.open(ground_truth_label_path) as src:
    ground_truth_label = src.read(1)  # assumes the labels are in band 1
if ground_truth_label.shape != (height, width):
    raise ValueError(
        f"ground truth is {ground_truth_label.shape}, image is {(height, width)}"
    )
# NOTE(review): nodata pixels (if the raster has any) are not masked here and
# would be scored as errors by accuracy_score -- confirm the label encoding.
# Crop to the area covered by whole patches so both arrays align pixel-for-pixel.
ground_truth_label_flat = ground_truth_label[:crop_h, :crop_w].flatten()

# NOTE(review): the model definition annotates its two outputs as
# (snow, no snow), but target_names below read label 0 as snow-free and 1 as
# snow-covered. Confirm which index the model was trained to emit for snow.
# Passing labels explicitly keeps a 2x2 matrix and a valid report even when a
# scene contains only one class.
accuracy = accuracy_score(ground_truth_label_flat, predicted_label_flat)
conf_matrix = confusion_matrix(ground_truth_label_flat, predicted_label_flat, labels=CLASS_LABELS)
class_report = classification_report(
    ground_truth_label_flat,
    predicted_label_flat,
    labels=CLASS_LABELS,
    target_names=['snow-free', 'snow-covered'],
)

print("Accuracy:", accuracy)
print("Confusion Matrix:\n", conf_matrix)
print("Classification Report:\n", class_report)
