In [None]:
import torch, os
from torch.utils.data import Dataset
import shutil
import torch.nn as nn
import fsspec
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import rasterio as rio
import pandas as pd, numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_score
import sys
sys.path.append("/content")
from src.utils import tensor_to_rgb, plot_image

# Run from /content so the relative paths used later (e.g. the "data/..."
# chip cache directory built by FluviusDataset) resolve under it.
os.chdir("/content")

In [None]:
# Load `KEY = value` pairs from the credentials file into the environment,
# then build the fsspec/pandas storage options for the Azure blob account.
with open("/content/credentials") as f:
    # splitlines() + dropping blank lines tolerates a trailing newline (which
    # used to yield an empty entry and crash the unpacking below) and strips
    # CRLF line endings instead of leaving "\r" on the value.
    env_vars = [line for line in f.read().splitlines() if line.strip()]

for entry_no, var in enumerate(env_vars, start=1):
    if " = " not in var:
        # Deliberately do not echo the entry: it may contain a secret.
        raise ValueError(f"Malformed credentials entry #{entry_no}: expected 'KEY = value'")
    # maxsplit=1 keeps any " = " that appears inside the value itself.
    key, value = var.split(" = ", 1)
    os.environ[key] = value

storage_options = {"account_name": os.environ["ACCOUNT_NAME"],
                   "account_key": os.environ["BLOB_KEY"]}

In [None]:
# Sample-level labels and features, read directly from blob storage.
labels = pd.read_csv(
    "az://modeling-data/partitioned_feature_data_buffer500m_daytol8_cloudthr80percent_lulcndvi_masking.csv",
    storage_options=storage_options,
)

# Drop chips with too few water pixels (<= 20). This is a filter rather than an
# in-place drop. `~(... <= ...)` is used instead of `>` so rows with a missing
# pixel count are kept, exactly as the previous drop behaved.
MIN_WATER_PIXELS = 20
labels = labels.loc[~(labels["n_water_pixels"] <= MIN_WATER_PIXELS)].copy()

# Turn each chip URL into an fsspec `az://` path by keeping everything after ".net/".
labels["raw_img_az_path"] = [f'az://{x.split(".net/")[1]}' for x in labels["raw_img_chip_href"]]
labels["water_az_path"] = [f'az://{x.split(".net/")[1]}' for x in labels["water_chip_href"]]

# Tabular features fed to the model alongside the image (17 columns).
aggregate_features = [
        "is_brazil", "sentinel-2-l2a_AOT", "sentinel-2-l2a_B02",
        "sentinel-2-l2a_B03", "sentinel-2-l2a_B04", "sentinel-2-l2a_B08",
        "sentinel-2-l2a_WVP", "sentinel-2-l2a_B05", "sentinel-2-l2a_B06",
        "sentinel-2-l2a_B07", "sentinel-2-l2a_B8A", "sentinel-2-l2a_B11",
        "sentinel-2-l2a_B12", "mean_viewing_azimuth", "mean_viewing_zenith",
        "mean_solar_azimuth", "mean_solar_zenith"
    ]

# Columns carried into both the train and validation tables (one shared list
# instead of two copy-pasted ones).
label_columns = [
    "region", "site_no", "sample_id", "SSC (mg/L)", "is_brazil",
    "raw_img_az_path", "water_az_path", "mean_viewing_azimuth",
    "mean_viewing_zenith", "mean_solar_azimuth", "mean_solar_zenith",
    "sentinel-2-l2a_AOT", "sentinel-2-l2a_B02",
    "sentinel-2-l2a_B03", "sentinel-2-l2a_B04", "sentinel-2-l2a_B08",
    "sentinel-2-l2a_WVP", "sentinel-2-l2a_B05", "sentinel-2-l2a_B06",
    "sentinel-2-l2a_B07", "sentinel-2-l2a_B8A", "sentinel-2-l2a_B11",
    "sentinel-2-l2a_B12",
]
train_labels = labels.loc[labels["partition"] == "train", label_columns]
validation_labels = labels.loc[labels["partition"] == "validate", label_columns]
assert len(train_labels) > 0, "no rows in the 'train' partition"
assert len(validation_labels) > 0, "no rows in the 'validate' partition"

# Fit the feature scaler on training rows only, so validation features are
# scaled with training statistics.
scaler = MinMaxScaler().fit(np.array(train_labels.loc[:, aggregate_features]))


In [None]:
# Band name -> band index passed to rasterio's `read` (rasterio band indexes
# are 1-based). Presumably matches the band order the chips were written
# with — confirm against the chip-generation code.
RIO_BANDS_ORDERED = {
    "aot": 1,
    "blue": 2,
    "green": 3,
    "red": 4,
    "nir": 5,
    "wvp": 6,
    "rededge1": 7,
    "rededge2": 8,
    "redege2": 8,  # legacy misspelling, kept so existing callers keep working
    "rededge3": 9,
    "rededge4": 10,
    "swir1": 11,
    "swir2": 12,
}


class FluviusDataset(Dataset):
    """PyTorch dataset of locally cached image chips paired with log-SSC targets.

    Each item is ``(img, features, y_label)``:
      * ``img``: the selected bands clipped to [0, REFLECTANCE_MAX] and divided
        by REFLECTANCE_MAX. If ``include_water`` is set, the water chip is
        appended as extra channel(s). The whole stack is mapped by
        ``(x - 0.5) * 2``, moved to channels-last, cast to float32, and then
        passed through ``transform`` if one is given.
      * ``features``: the scaled aggregate features for the row, or an empty
        float tensor when none were requested.
      * ``y_label``: ``log(SSC (mg/L))`` as a float32 scalar tensor.

    Args:
        labels: DataFrame with ``region``, ``raw_img_az_path`` and
            ``SSC (mg/L)`` columns. It also needs ``water_az_path`` when
            downloading or when ``include_water`` is set, plus every column
            named in ``aggregate_features``.
        storage_options: options for the fsspec ``az`` filesystem. Only used
            when ``download_data`` is True.
        scaler: fitted object whose ``transform`` is applied to the aggregate
            feature matrix. Unused when ``aggregate_features`` is empty.
        bands: band names; each must be a key of ``RIO_BANDS_ORDERED``.
        aggregate_features: column names of per-row tabular features.
        include_water: append the water chip to the image channels.
        transform: callable applied to the channels-last float32 image.
        download_data: copy every raw and water chip into ``local_base_dir``.

    Raises:
        ValueError: if ``labels`` is empty or a band name is unknown.
    """

    REFLECTANCE_MAX = 15000  # clip ceiling for raw band values before rescaling

    def __init__(
        self,
        labels,
        storage_options,
        scaler,
        bands=("aot", "blue", "green", "red", "nir"),
        aggregate_features=(),
        include_water=False,
        transform=None,
        download_data=False
    ):
        if len(labels) == 0:
            raise ValueError("labels is empty; cannot build a FluviusDataset")

        self.labels = labels
        if len(aggregate_features) != 0:
            # list() so a tuple of names is not read by .loc as a single key.
            self.aggregate_features = scaler.transform(
                np.array(labels.loc[:, list(aggregate_features)])
            )
        else:
            self.aggregate_features = []

        unknown = [band for band in bands if band not in RIO_BANDS_ORDERED]
        if unknown:
            raise ValueError(
                f"Unknown band name(s) {unknown}; expected keys of RIO_BANDS_ORDERED"
            )
        self.rio_band_idx = [RIO_BANDS_ORDERED[band] for band in bands]

        # Derive the cache directory from THIS dataset's first row (previously
        # the global `train_labels` was indexed by label 0, which raised
        # KeyError whenever that row had been filtered out). The path's last
        # two components are stripped, along with the "az://" prefix.
        # NOTE(review): assumes all chips of a dataset share this directory.
        first_path = labels["raw_img_az_path"].iloc[0]
        self.az_base_dir = os.path.dirname(os.path.dirname(first_path))[len("az://"):]
        self.local_base_dir = f"data/{self.az_base_dir}"
        self.include_water = include_water
        self.transform = transform

        if download_data:
            self._download(storage_options)

    def _local_path(self, region, az_path):
        """Local cache path of a chip: ``<local_base_dir>/<region>_<basename>``."""
        return f"{self.local_base_dir}/{region}_{os.path.basename(az_path)}"

    def _download(self, storage_options):
        """Copy each row's raw and water chips from blob storage into the local cache."""
        print("Downloading Chips. This might take a while...")
        os.makedirs(self.local_base_dir, exist_ok=True)

        fs = fsspec.filesystem("az", **storage_options)
        # The water chip comes from the `water_az_path` column (the same
        # column __getitem__ reads), so saved and looked-up names always agree.
        for region, img_path, water_path in zip(
            self.labels["region"],
            self.labels["raw_img_az_path"],
            self.labels["water_az_path"],
        ):
            for az_path in (img_path, water_path):
                fs.get_file(az_path, self._local_path(region, az_path))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(img, features, y_label)`` for the row at position ``index``."""
        row = self.labels.iloc[index]
        region = row["region"]

        with rio.open(self._local_path(region, row["raw_img_az_path"])) as src:
            spect_img = (
                src.read(tuple(self.rio_band_idx)).clip(0, self.REFLECTANCE_MAX)
                / self.REFLECTANCE_MAX
            )

        if self.include_water:
            with rio.open(self._local_path(region, row["water_az_path"])) as src:
                water = src.read()
            array = np.concatenate([spect_img, water], axis=0)
        else:
            array = spect_img

        # (x - 0.5) * 2 maps [0, 1] to [-1, 1]; moveaxis puts channels last,
        # the ndarray layout torchvision's ToTensor expects.
        img = np.moveaxis((array - 0.5) * 2, 0, -1).astype(np.float32)

        y_label = torch.tensor(np.log(row["SSC (mg/L)"]).astype(np.float32))

        if self.transform is not None:
            img = self.transform(img)

        if len(self.aggregate_features) != 0:
            features = torch.FloatTensor(self.aggregate_features[index, :])
        else:
            features = torch.FloatTensor([])

        return (img, features, y_label)

In [None]:
import torchvision.transforms.functional as TF
import random

class DiscreteRandomRotation:
    """Rotate an image by an angle picked uniformly at random from a fixed set.

    Args:
        angles: candidate angles in degrees, forwarded to
            ``torchvision.transforms.functional.rotate``.
    """

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        # One draw from Python's `random` module per call.
        chosen_angle = random.choice(self.angles)
        rotated = TF.rotate(x, chosen_angle)
        return rotated


def _shared_preprocessing():
    """Return fresh instances of the steps both pipelines start with."""
    return [
        transforms.ToTensor(),  # HWC ndarray -> CHW tensor (float input is not rescaled)
        transforms.CenterCrop(96),
    ]


# Training: shared preprocessing, then random flips and a random 90-degree-step rotation.
transform_train = transforms.Compose(
    _shared_preprocessing()
    + [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        DiscreteRandomRotation([0, 90, 180, 270]),
    ]
)

# Validation: shared preprocessing only (deterministic).
transform_validate = transforms.Compose(_shared_preprocessing())


In [None]:
## HyperParameters
train_batch_size = 64
val_batch_size = 64
learning_rate = 0.01
epochs = 500
SEED = 42

## Reproducibility
import random

# Seed every RNG the pipeline draws from. The DataLoader's shuffle order and
# the base seed for its worker processes both come from torch's default generator.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Load DataSets
dataset_kwargs = dict(
    storage_options=storage_options,
    scaler=scaler,
    aggregate_features=aggregate_features,
    download_data=False,
    include_water=False,
)
train = FluviusDataset(train_labels, transform=transform_train, **dataset_kwargs)
validation = FluviusDataset(validation_labels, transform=transform_validate, **dataset_kwargs)
# The training rows WITHOUT random flips/rotations. Used for the per-sample
# evaluation pass so training-set predictions are deterministic instead of
# being made on randomly augmented images.
train_eval = FluviusDataset(train_labels, transform=transform_validate, **dataset_kwargs)

## Set up data loaders
train_loader = DataLoader(train, batch_size=train_batch_size, shuffle=True, num_workers=3)
val_loader = DataLoader(validation, batch_size=val_batch_size, shuffle=False, num_workers=3)

# batch_size=1 loaders for per-sample predictions. With one sample per batch,
# the model's squeeze() yields a scalar, so each prediction's .tolist() is a float.
val_loader_all = DataLoader(validation, batch_size=1, shuffle=False, num_workers=3)
train_loader_all = DataLoader(train_eval, batch_size=1, shuffle=False, num_workers=3)


In [None]:
## Specify the model
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class CNN(nn.Module):
    """Small CNN regressor that fuses an image embedding with tabular features.

    Image path: (3x3 conv -> ReLU -> 4x4 max-pool) twice, then flatten and
    project to 8 dims with ``fc1``. The tabular ``features`` (if any) are
    concatenated to that embedding and passed through a 4-layer MLP head
    that produces one output per sample.

    Args:
        n_channels: number of input image channels.
        n_features: number of tabular features concatenated after ``fc1``.
            The default 17 matches the previous hard-coded ``fc2`` input of 25.
            Use 0 when no aggregate features are supplied.
        img_size: side length of the square input image (default 96).

    Raises:
        ValueError: if ``img_size`` is smaller than 16.
    """

    def __init__(self, n_channels, n_features=17, img_size=96):
        super().__init__()
        # Two kernel-4/stride-4 pools shrink each spatial side by 16 (flooring);
        # the padding=1 3x3 convs keep the size unchanged.
        pooled_size = img_size // 16
        if pooled_size == 0:
            raise ValueError(f"img_size must be at least 16, got {img_size}")

        self.conv1 = nn.Conv2d(
            in_channels=n_channels,
            out_channels=16,
            kernel_size=3,
            padding=1
        )
        # NOTE: named "avgpool*" but these are max-pools; names kept so that
        # existing attribute access keeps working.
        self.avgpool1 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.conv2 = nn.Conv2d(
            in_channels=16,
            out_channels=6,
            kernel_size=3,
            padding=1
        )
        self.avgpool2 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.fc1 = nn.Linear(6 * pooled_size * pooled_size, 8)
        # fc1's 8-dim image embedding plus the tabular features.
        self.fc2 = nn.Linear(8 + n_features, 48)
        self.fc3 = nn.Linear(48, 16)
        self.fc4 = nn.Linear(16, 8)
        self.fc5 = nn.Linear(8, 1)

    def forward(self, x, features):
        """Predict one value per sample from images ``x`` and ``features`` of shape (B, k)."""
        x = self.avgpool1(F.relu(self.conv1(x)))  # convolve, activate, pool #1
        x = self.avgpool2(F.relu(self.conv2(x)))  # convolve, activate, pool #2
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        if features.shape[1] != 0:
            x = torch.cat((x, features), dim=1)
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.relu(self.fc4(x))
        # squeeze() drops every size-1 dim, so a batch of one gives a 0-d tensor.
        x = self.fc5(x).squeeze()
        return x

In [None]:
# Demo the data loader: render validation sample 10 as an RGB preview.
# NOTE(review): with the default bands (aot, blue, green, red, nir), rgb=[4, 3, 2]
# selects red/green/blue only if tensor_to_rgb uses 1-based indices — confirm.
val_img, _, _ = validation[10]

my_img = tensor_to_rgb(
    val_img,
    clip_bounds=[0, 0.7],
    gamma=0.6,
    rgb=[4, 3, 2],
)
plot_image(my_img)

In [None]:
## Train the model!
import torch.optim as optim

model = CNN(len(train.rio_band_idx))
model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Multiply the learning rate by 0.25 every 200 epochs.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=200,
    gamma=0.25
)

for epoch in range(epochs):  # loop over the dataset multiple times
    # The validation pass below switches to eval mode, so switch back each epoch.
    model.train()
    running_loss = 0.0
    # `targets`, not `labels`: reusing `labels` would clobber the label DataFrame.
    for img, features, targets in train_loader:
        img, features, targets = img.to(device), features.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(img, features)
        loss = criterion(outputs, targets.squeeze())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validation loss (mean over batches)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for img, features, targets in val_loader:
            img, features, targets = img.to(device), features.to(device), targets.to(device)
            y_val_pred = model(img, features)
            val_loss += criterion(y_val_pred, targets.squeeze()).item()

    scheduler.step()
    print(f"Epoch {epoch + 1}/{epochs} | Train Loss: {running_loss/len(train_loader):.3f} | Val Loss: {val_loss/len(val_loader):.3f}")


print('Finished Training')

In [None]:
def _predict_each(loader):
    """Run `model` in eval mode over `loader`; return one `.tolist()` result per batch."""
    with torch.no_grad():
        model.eval()
        return [
            model(img.to(device), features.to(device)).tolist()
            for img, features, _ in loader
        ]


val_pred_list = _predict_each(val_loader_all)
train_pred_list = _predict_each(train_loader_all)

In [None]:
def _observed_log_ssc(dataset):
    """Return log(SSC) for every row of ``dataset.labels`` as float32-rounded Python floats.

    This is the same value as the ``y_label`` FluviusDataset returns per item
    (np.log, then cast to float32), in dataset order, but without opening any
    image file.
    """
    ssc = dataset.labels["SSC (mg/L)"].to_numpy()
    return np.log(ssc).astype(np.float32).tolist()


# Reading targets from the label table avoids loading (and, for `train`,
# randomly augmenting) every chip just to extract its target.
train_observed = _observed_log_ssc(train)
val_observed = _observed_log_ssc(validation)

# Validation images (center-cropped only), kept for visual inspection of individual samples.
val_imgs = [validation[i][0] for i in range(len(validation))]

In [None]:
from matplotlib import pyplot as plt

def plot_obs_predict(obs_pred, title, savefig=False, outfn=""):
    """Plot observed vs. predicted ln(SSC), drawing each sample as its index label.

    Args:
        obs_pred: DataFrame whose first column holds predictions and second
            column observations (both ln(SSC)). Each sample is drawn as its
            index label at (predicted, observed).
        title: figure title.
        savefig: also write the figure to ``outfn``.
        outfn: output path; required when ``savefig`` is True.

    Raises:
        ValueError: if ``savefig`` is True but ``outfn`` is empty.
    """
    if savefig and not outfn:
        raise ValueError("outfn must be a non-empty path when savefig=True")

    predicted = obs_pred.iloc[:, 0]
    observed = obs_pred.iloc[:, 1]

    plt.figure(figsize=(12, 12))
    # 1:1 reference line over the fixed ln(SSC) range 0..7.
    plt.plot(list(range(0, 8)), list(range(0, 8)), color="black", label="One-to-one line")
    # Vertical line at the mean observation, presumably as a "predict the
    # mean" baseline.
    plt.axvline(x=np.mean(observed), color="black", linestyle="--", label="Mean observed")
    plt.xlabel("ln(SSC) Predicted")
    plt.ylabel("ln(SSC) Observed")

    # Iterating columns positionally avoids the deprecated positional
    # fallback of row[0] / row[1] on a label-indexed row Series.
    # NOTE(review): text annotations do not expand the axes limits, so samples
    # outside the range set by the lines above may fall off-screen.
    for sample_idx, x_pred, y_obs in zip(obs_pred.index, predicted, observed):
        plt.annotate(
            f"{sample_idx}",
            (x_pred, y_obs),
            textcoords="offset points",
            xytext=(0, 0),
        )
    plt.title(title)
    plt.legend()
    if savefig:
        plt.savefig(
            outfn,
            bbox_inches="tight",
            facecolor="#FFFFFF",
            dpi=150
        )

# One row per validation sample: model prediction next to observed ln(SSC).
val_obs_pred = pd.DataFrame(
    dict(
        predicted=val_pred_list,
        observed=val_observed,
    )
)


In [None]:
plot_obs_predict(
    val_obs_pred,
    title="Observed vs. Predicted for Validation",
)

In [None]:
# Residual (predicted - observed) per validation sample, in dataset order.
predicted_arr = np.array(val_obs_pred["predicted"])
observed_arr = np.array(val_obs_pred["observed"])
val_miss = list(np.subtract(predicted_arr, observed_arr))

# Preview one validation chip by position.
# NOTE(review): the R² cell excludes position 164, not 163 — confirm which
# sample is the suspected outlier.
my_img = tensor_to_rgb(
    val_imgs[163],
    clip_bounds=[0, 0.7],
    gamma=0.6,
    rgb=[4, 3, 2],
)
plot_image(my_img)

In [None]:

# r2_score(y_true, y_pred): R² is not symmetric, so observations go first.
print(r2_score(val_observed, val_pred_list))


def _drop_excluded(values):
    """Drop the first 40 samples and the sample at position 164."""
    # NOTE(review): the reason for these exclusions is not recorded — document it.
    return values[40:164] + values[164 + 1:]


print(r2_score(_drop_excluded(val_observed), _drop_excluded(val_pred_list)))