# Classifying Satellite Imagery with GNNs
## CIS700

### Matt Graber

### Obtaining the data

First, we will obtain the data. We will be attempting to classify the Corrected Reflectance (True Color) Suomi NPP / VIIRS product available from the Global Imagery Browse Services (GIBS) API. To label the data for training and testing, we will use the Clear Sky Confidence (Day) Suomi NPP / VIIRS product and the Land Water Map (OSM) layer (OpenStreetMap data served through GIBS).

Corrected Reflectance (True Color) Layer example: https://gibs.earthdata.nasa.gov/wms/epsg4326/best/wms.cgi?version=1.3.0&service=WMS&request=GetMap&format=image/png&STYLE=default&bbox=-90,-180,90,180&CRS=EPSG:4326&HEIGHT=600&WIDTH=600&TIME=2022-05-01&layers=VIIRS_SNPP_CorrectedReflectance_TrueColor

Clear Sky Confidence Layer Example: https://gibs.earthdata.nasa.gov/wms/epsg4326/best/wms.cgi?version=1.3.0&service=WMS&request=GetMap&format=image/png&STYLE=default&bbox=-90,-180,90,180&CRS=EPSG:4326&HEIGHT=600&WIDTH=600&TIME=2022-05-01&layers=VIIRS_SNPP_Clear_Sky_Confidence_Day

Land Water Map Layer Example: https://gibs.earthdata.nasa.gov/wms/epsg4326/best/wms.cgi?version=1.3.0&service=WMS&request=GetMap&format=image/png&STYLE=default&bbox=-90,-180,90,180&CRS=EPSG:4326&HEIGHT=600&WIDTH=600&TIME=2022-05-01&layers=OSM_Land_Water_Map

In [27]:
import cv2
import csv
import numpy as np
import datetime
import time
import os
from skimage import io
from PIL import Image as plimg

In [None]:
layers = ["VIIRS_SNPP_CorrectedReflectance_TrueColor", "VIIRS_SNPP_Clear_Sky_Confidence_Day"]
startdate = datetime.date(2022,5,1)
enddate = datetime.date(2022,5,4)  # exclusive: the last day downloaded is enddate - 1 day
img_extent_step = 5  # tile size in degrees (latitude and longitude)
resolution = 128     # tile width and height in pixels

GIBS_WMS_URL = "https://gibs.earthdata.nasa.gov/wms/epsg4326/best/wms.cgi"
RETRY_DELAY_SECONDS = 5


def iter_tile_extents(step):
    """Yield the WMS bbox string of every `step` x `step` degree tile covering the globe.

    Each bbox is latitude-first ("minlat,minlon,maxlat,maxlon"), the axis order
    used by EPSG:4326 in WMS 1.3.0. Longitude is the outer loop.
    """
    for longitude in range(-180, 180, step):
        for latitude in range(-90, 90, step):
            yield "{0},{1},{2},{3}".format(latitude, longitude,
                                           latitude + step,
                                           longitude + step)


def build_getmap_url(layer, extents, date=None):
    """Return the GIBS WMS GetMap URL for one square PNG tile of `layer`.

    `date` is sent as the TIME parameter; pass None for static layers, which
    are requested without TIME.
    """
    url = ("{0}?version=1.3.0&service=WMS&request=GetMap&"
           "format=image/png&STYLE=default&bbox={1}&CRS=EPSG:4326&"
           "HEIGHT={2}&WIDTH={2}").format(GIBS_WMS_URL, extents, resolution)
    if date is not None:
        url += "&TIME={}".format(date)
    return url + "&layers={}".format(layer)


def download_tile(url, outfilepath):
    """Save the image at `url` to `outfilepath`, retrying until the file is readable.

    Files that already exist and decode with cv2 are skipped, so re-running this
    cell resumes an interrupted download. Retries are unbounded, one every
    RETRY_DELAY_SECONDS seconds.
    """
    while not os.path.exists(outfilepath) or cv2.imread(outfilepath) is None:
        try:
            img = plimg.fromarray(io.imread(url))
            img.save(outfilepath)
        except Exception as err:
            # Transient failures (e.g. a dropped connection) are retried. Catching
            # Exception instead of using a bare `except:` keeps KeyboardInterrupt
            # able to stop this otherwise endless loop.
            print("Error encountered ({}), retrying".format(err))
            time.sleep(RETRY_DELAY_SECONDS)


# Daily layers: one directory per layer per date.
for layer in layers:
    print("Downloading {} images...".format(layer))
    layer_outdir = os.path.join(os.getcwd(), "images", layer)
    currentdate = startdate

    while currentdate < enddate:
        outdir = os.path.join(layer_outdir, str(currentdate))
        os.makedirs(outdir, exist_ok=True)

        print("Downloading images for {}...".format(currentdate))
        for extents in iter_tile_extents(img_extent_step):
            outfilepath = os.path.join(outdir, '{0}_{1}_{2}.png'.format(layer, currentdate, extents))
            download_tile(build_getmap_url(layer, extents, currentdate), outfilepath)

        currentdate += datetime.timedelta(days=1)

# OSM_Land_Water_Map is a static layer, meaning that we don't need to re-download it for every day.
layer = "OSM_Land_Water_Map"
print("Downloading {} images...".format(layer))
outdir = os.path.join(os.getcwd(), "images", "{}".format(layer))
# makedirs also creates a missing "images" parent directory.
os.makedirs(outdir, exist_ok=True)

for extents in iter_tile_extents(img_extent_step):
    outfilepath = os.path.join(outdir, '{0}_{1}.png'.format(layer, extents))
    download_tile(build_getmap_url(layer, extents), outfilepath)

Downloading VIIRS_SNPP_CorrectedReflectance_TrueColor images...
Downloading images for 2022-05-01...
Downloading images for 2022-05-02...
Downloading images for 2022-05-03...
Downloading VIIRS_SNPP_Clear_Sky_Confidence_Day images...
Downloading images for 2022-05-01...


libpng error: IDAT: CRC error
libpng error: bad adaptive filter value


Downloading images for 2022-05-02...
Downloading images for 2022-05-03...
Downloading OSM_Land_Water_Map images...


### Labeling the data
Next, we will label each Corrected Reflectance image as mostly clear or mostly cloudy, and as mostly land or mostly water, to use as the training data for the neural networks. We will do this based on the proportions of pixel colors in the corresponding Clear Sky Confidence and Land Water Map images.

In [57]:
labeled_data_filename = "labeled_data.csv"

layer_to_label_path = os.path.join("images","VIIRS_SNPP_CorrectedReflectance_TrueColor")
clear_sky_layer_path = os.path.join("images", "VIIRS_SNPP_Clear_Sky_Confidence_Day")
land_water_map_path = os.path.join("images", "OSM_Land_Water_Map")
lw_filelist = sorted(os.listdir(land_water_map_path))

resolution = 128
pixel_count = resolution ** 2
# Exclude any image in which more than 40% of the pixels are "no data" (zero),
# i.e. keep only images with at least 60% non-zero pixels.
nodata_threshold = pixel_count * 0.6

# dict for memoization of land water map results, since this is a static layer
lw_results = {}

# Tile filenames are "<layer>_<date>_<extents>.png" (daily layers) or
# "<layer>_<extents>.png" (static layer); the layer name is the directory name.
co_re_layer = os.path.basename(layer_to_label_path)
cl_sk_layer = os.path.basename(clear_sky_layer_path)
lw_layer = os.path.basename(land_water_map_path)


def extents_from_filename(filename, prefix):
    """Return the bbox part of a tile filename of the form '<prefix><extents>.png'.

    Returns None when `filename` does not have that form (e.g. a stray file).
    """
    suffix = ".png"
    if not (filename.startswith(prefix) and filename.endswith(suffix)):
        return None
    return filename[len(prefix):-len(suffix)]


def count_pixels_in_range(gray_img, low, high):
    """Count the pixels of a single-channel image whose value lies in [low, high]."""
    return cv2.countNonZero(cv2.inRange(gray_img, low, high))


def read_grayscale(path):
    """Read `path` as grayscale; print a warning and return None if it is missing or unreadable."""
    img = cv2.imread(path, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE
    if img is None:
        print("Could not read {}, skipping".format(path))
    return img


# newline='' is required by the csv module to avoid blank rows on some platforms.
with open(labeled_data_filename, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["filepath", "weather", "terrain"])
    for date in sorted(os.listdir(layer_to_label_path)):
        print("Labeling for date {}...".format(date))
        co_re_datepath = os.path.join(layer_to_label_path, date)
        cl_sk_datepath = os.path.join(clear_sky_layer_path, date)
        co_re_prefix = "{}_{}_".format(co_re_layer, date)
        for co_re_filename in sorted(os.listdir(co_re_datepath)):
            extents = extents_from_filename(co_re_filename, co_re_prefix)
            if extents is None:
                continue
            # Pair the three layers' tiles by the bbox encoded in their filenames.
            # Pairing by position in os.listdir() output is unreliable because
            # os.listdir() returns entries in arbitrary order.
            co_re_imgpath = os.path.join(co_re_datepath, co_re_filename)
            cl_sk_imgpath = os.path.join(cl_sk_datepath, "{}_{}_{}.png".format(cl_sk_layer, date, extents))
            lw_imgpath = os.path.join(land_water_map_path, "{}_{}.png".format(lw_layer, extents))

            # First, skip corrected reflectance tiles that are mostly "no data"
            # (i.e. all or mostly pure black).
            co_re_img = read_grayscale(co_re_imgpath)
            if co_re_img is None or cv2.countNonZero(co_re_img) < nodata_threshold:
                continue

            # Next, decide whether the tile is mostly clear or mostly cloudy.
            # In this layer reddish means clear skies and whitish means cloudy; read as
            # grayscale the whitish pixels have the higher values, so a majority of
            # dark (0-127) pixels means clear skies.
            cl_sk_img = read_grayscale(cl_sk_imgpath)
            if cl_sk_img is None:
                continue
            if count_pixels_in_range(cl_sk_img, 0, 127) > count_pixels_in_range(cl_sk_img, 128, 255):
                weather = "clear"
            else:
                weather = "cloudy"

            # Finally, decide whether the tile is mostly land or water (memoized per tile).
            if lw_imgpath not in lw_results:
                lw_img = read_grayscale(lw_imgpath)
                if lw_img is None:
                    continue
                # Gray level 128 is counted as water and 75 as land; the majority wins.
                water_pixels = count_pixels_in_range(lw_img, 128, 128)
                land_pixels = count_pixels_in_range(lw_img, 75, 75)
                lw_results[lw_imgpath] = "water" if water_pixels > land_pixels else "land"

            writer.writerow([co_re_imgpath, weather, lw_results[lw_imgpath]])

print("Labeling complete!")


Labeling for date 2022-05-01...
Labeling for date 2022-05-02...
Labeling for date 2022-05-03...
Labeling complete!


### Convolutional Neural Network (non-graph)
Now that our data is labeled, we can start by training a baseline non-graph convolutional neural network classifier.

In [58]:
import torch as tr
from torch import nn
from torch.nn import functional as F
from torchvision import transforms, models
import random
import ast

def encode(label):
    """Convert a label string into its integer class index.

    'clear' (weather) and 'water' (terrain) become 0; every other label,
    i.e. 'cloudy' or 'land', becomes 1.
    """
    return 0 if label in ('clear', 'water') else 1

def load_data(filename, train_fraction=0.75, seed=None):
    """Read the labeled CSV, shuffle its rows, and split them into training and testing sets.

    Args:
        filename: CSV with "filepath", "weather" and "terrain" columns.
        train_fraction: fraction of rows, in [0, 1], placed in the training set
            (rounded down); the remaining rows form the testing set.
        seed: optional seed for a private random.Random; None shuffles using the
            global `random` module state.

    Returns:
        (training_data, testing_data), each a tuple of three parallel lists:
        (image paths, weather labels, terrain labels). A CSV with no data rows
        yields empty lists instead of raising.

    Raises:
        ValueError: if train_fraction is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError("train_fraction must be in [0, 1], got {}".format(train_fraction))

    # newline='' is what the csv module expects for files it reads.
    with open(filename, newline='') as datacsv:
        rows = [(row["filepath"], row["weather"], row["terrain"])
                for row in csv.DictReader(datacsv)]

    # Shuffle whole rows so every image stays paired with its own labels.
    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(rows)

    # Build the parallel columns explicitly; unlike zip(*rows) this also works for zero rows.
    imgs = [row[0] for row in rows]
    weather = [row[1] for row in rows]
    terrain = [row[2] for row in rows]

    split_size = int(train_fraction * len(rows))
    training_data = (imgs[:split_size], weather[:split_size], terrain[:split_size])
    testing_data = (imgs[split_size:], weather[split_size:], terrain[split_size:])
    return training_data, testing_data

class CorrectedReflectanceDataset(tr.utils.data.Dataset):
    """Map-style dataset of Corrected Reflectance tiles with weather and terrain labels.

    `data` is one of the (imgs, weather, terrain) tuples returned by load_data():
    parallel lists of image file paths and label strings.
    """

    def __init__(self, data):
        self.imgs, self.weather, self.terrain = data
        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        # Required by DataLoader's samplers; the Dataset base class does not provide it.
        return len(self.imgs)

    def __getitem__(self, idx):
        # Convert to RGB because mobilenet_v2 expects 3 input channels and a PNG may
        # carry an alpha channel. convert() loads the pixels, so the file handle can be
        # closed immediately by the `with` block.
        with plimg.open(self.imgs[idx]) as src:
            rgb = src.convert("RGB")

        # Per-image, per-channel standardization. std is clamped so that a constant
        # channel yields zeros instead of NaN/inf.
        img = self.to_tensor(rgb)
        mean = img.mean(dim=(1, 2), keepdim=True)
        std = img.std(dim=(1, 2), keepdim=True).clamp_min(1e-6)
        img = (img - mean) / std

        # Labels are integer class indices (via encode()) so the default DataLoader
        # collate turns them into tensors usable as F.cross_entropy targets.
        return {
            'img': img,
            'labels': {
                'weather_labels': encode(self.weather[idx]),
                'terrain_labels': encode(self.terrain[idx]),
            }
        }
    
class MultiOutputModel(nn.Module):
    """MobileNetV2 feature extractor followed by two independent linear classification heads.

    forward() returns a dict of raw logits: 'weather' with shape
    [batch, n_weather_classes] and 'terrain' with shape [batch, n_terrain_classes].
    """

    def __init__(self, n_weather_classes, n_terrain_classes):
        super().__init__()
        # Build the backbone once; keep only its feature extractor (no classifier).
        backbone = models.mobilenet_v2()
        self.base_model = backbone.features
        last_channel = backbone.last_channel  # size of the layer before the classifier

        # The classifiers need 2-D input, but the features are
        # [<batch_size>, <channels>, <height>, <width>], so average the spatial dims to 1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Separate classifiers for our outputs.
        self.weather = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_weather_classes)
        )
        self.terrain = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=last_channel, out_features=n_terrain_classes)
        )

    def forward(self, x):
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into the classifiers
        x = tr.flatten(x, start_dim=1)

        return {
            'weather': self.weather(x),
            'terrain': self.terrain(x)
        }

    def get_loss(self, net_output, ground_truth):
        """Return (total_loss, {'weather': weather_loss, 'terrain': terrain_loss}).

        total_loss is the unweighted sum of both heads' cross-entropy losses.
        ground_truth must hold integer class-index tensors under 'weather_labels'
        and 'terrain_labels'; they are moved to the logits' device (e.g. CPU
        batches when the model runs on a GPU).
        """
        weather_logits = net_output['weather']
        terrain_logits = net_output['terrain']
        weather_loss = F.cross_entropy(
            weather_logits, ground_truth['weather_labels'].to(weather_logits.device))
        terrain_loss = F.cross_entropy(
            terrain_logits, ground_truth['terrain_labels'].to(terrain_logits.device))
        loss = weather_loss + terrain_loss
        return loss, {'weather': weather_loss, 'terrain': terrain_loss}

filename = "labeled_data.csv"
training_data, testing_data = load_data(filename)

N_epochs = 50
batch_size = 16
# Use the first GPU when available, otherwise fall back to the CPU.
device = 'cuda:0' if tr.cuda.is_available() else 'cpu'

train_dataloader = tr.utils.data.DataLoader(
    CorrectedReflectanceDataset(training_data), batch_size=batch_size, shuffle=True)
test_dataloader = tr.utils.data.DataLoader(
    CorrectedReflectanceDataset(testing_data), batch_size=batch_size, shuffle=False)

model = MultiOutputModel(n_weather_classes=2, n_terrain_classes=2).to(device)

optimizer = tr.optim.Adam(model.parameters())


def batch_accuracy(net_output, target_labels):
    """Return the mean of the weather-head and terrain-head accuracies for one batch."""
    head_accuracies = [
        (net_output[head].argmax(dim=1) == target_labels[key]).float().mean()
        for head, key in (('weather', 'weather_labels'), ('terrain', 'terrain_labels'))
    ]
    return sum(head_accuracies) / len(head_accuracies)


for epoch in range(N_epochs):
    epoch_start = time.time()
    print("Epoch: {}/{}".format(epoch+1, N_epochs))
    # Set to training mode
    model.train()
    # Sample-weighted loss and accuracy totals within the epoch
    train_loss = 0.0
    train_acc = 0.0
    valid_loss = 0.0  # validation totals; not updated by the training loop below
    valid_acc = 0.0
    n_train_samples = 0
    for i, batch in enumerate(train_dataloader):
        # Clean existing gradients
        optimizer.zero_grad()

        img = batch['img'].to(device)
        target_labels = {name: labels.to(device) for name, labels in batch['labels'].items()}

        # Forward pass - compute outputs on input data using the model
        output = model(img)

        # Compute loss (sum of both heads' cross-entropy losses)
        loss_train, _ = model.get_loss(output, target_labels)

        # Backpropagate the gradients and update the parameters
        loss_train.backward()
        optimizer.step()

        # Accumulate sample-weighted loss and accuracy for the epoch
        n_in_batch = img.size(0)
        n_train_samples += n_in_batch
        train_loss += loss_train.item() * n_in_batch
        with tr.no_grad():
            acc = batch_accuracy(output, target_labels)
        train_acc += acc.item() * n_in_batch
        print("Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}".format(i, loss_train.item(), acc.item()))

    if n_train_samples:
        print("Epoch {} finished in {:.1f}s: mean training loss {:.4f}, mean accuracy {:.4f}".format(
            epoch+1, time.time() - epoch_start,
            train_loss / n_train_samples, train_acc / n_train_samples))
