## Demo

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json

import pickle
import numpy as np
import pandas as pd
import torch 
import torchmetrics
import torchvision
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from models.from_config import build_from_config
from models.double_branch import DoubleBranchCNN
from data_handlers.csv_dataset import CustomDatasetFromDataFrame
from utils import utils
from utils import transfer_learning as tl
from train import train, dual_train
from test import test

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- Dataset locations (relative to the notebook's working directory) ---
CSV_PATH = os.path.join("data", "dataset.csv")
TEST_CSV = os.path.join("data", "madagascar_test_dataset.csv")
TRAIN_CSV = os.path.join("data", "madagascar_train_dataset.csv")
# The empty last component makes the directory path end with a separator.
DATA_DIR = os.path.join("data", "landsat_7", "")
FOLD_PATH = os.path.join("data", "dhs_incountry_folds.pkl")

# --- Model / training configuration files (JSON) ---
CONFIG_FILE_MS = os.path.join("configs", "resnet18_ms_e2e_l7_yeh.json")
CONFIG_FILE_MSNL = os.path.join("configs", "resnet18_msnl_e2e_l7_yeh.json")

# --- Value bounds passed to the dataset as tile_min / tile_max ---
# Eight entries each; presumably one per raster band -- TODO confirm band order.
TILE_MIN = [
    -0.0994,
    -0.0574,
    -0.0318,
    -0.0209,
    -0.0102,
    -0.0152,
    0.0,
    -0.07087274,
]
TILE_MAX = [
    2.0,
    2.0,
    2.0,
    2.0,
    2.0,
    2.0,
    316.7,
    3104.1401,
]

In [4]:
# Parse both JSON configuration files into plain Python objects.
with open(CONFIG_FILE_MS) as ms_file, open(CONFIG_FILE_MSNL) as msnl_file:
    config_ms = json.load(ms_file)
    config_msnl = json.load(msnl_file)

# Read the full table and turn the default row index into an explicit "index"
# column. The "Unnamed: 0" column coming from the CSV is deliberately kept
# (the drop was disabled in the original cell).
csv = pd.read_csv(CSV_PATH).reset_index()
csv.head()

Unnamed: 0,index,country,year,cluster,lat,lon,households,wealthpooled
0,0,angola,2011,1,-12.350257,13.534922,36,2.312757
1,1,angola,2011,2,-12.360865,13.551494,32,2.010293
2,2,angola,2011,3,-12.613421,13.413085,36,0.877744
3,3,angola,2011,4,-12.581454,13.397711,35,1.066994
4,4,angola,2011,5,-12.578135,13.418748,37,1.750153


In [5]:
# Compute the per-channel mean and std of the normalised images over the complete dataset.
# This only needs to run once (TODO: move it to a standalone script); kept commented out for reference.

# TEST_TRANSFORM  = torch.nn.Sequential(
#         torchvision.transforms.CenterCrop(size=224),
#     )
# dummy_dataset = CustomDatasetFromDataFrame(csv,
#                                            DATA_DIR,
#                                            transform=TEST_TRANSFORM,
#                                            tile_max=TILE_MAX,
#                                            tile_min=TILE_MIN)
# dummy_loader = torch.utils.data.DataLoader(
#         dummy_dataset, 
#         batch_size=64
#     )

# def compute_mean_and_std(dataloader, batch_size):
#     channels_sum, channels_squared_sum, num_batches = 0, 0, 0
#     for data, _ in dataloader:
#         if data is not None:
#             weight = data.size()[0] / batch_size
#             # Mean over batch, height and width, but not over the channels
#             channels_sum += weight*torch.mean(data, dim=[0,2,3])
#             channels_squared_sum += weight*torch.mean(data**2, dim=[0,2,3])
#             num_batches += weight
#     mean = channels_sum / num_batches
#     # std = sqrt(E[X^2] - (E[X])^2)
#     std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
#     return mean, std

# means, stds = compute_mean_and_std(dummy_loader, 64)

In [6]:
# Hard-coded normalisation statistics (seven values each), presumably the output
# of the commented-out compute_mean_and_std pass -- TODO confirm.
means = torch.tensor(
    [0.6952, 0.6890, 0.6851, 0.6834, 0.6818, 0.6826, 0.0043]
)
stds = torch.tensor(
    [9.5266, 9.7209, 9.8435, 9.8968, 9.9495, 9.9249, 0.0632]
)

In [7]:
# Training pipeline: centre crop to 224, random horizontal flip (p=0.5), then
# normalisation with the `means` / `stds` tensors.
TRAIN_TRANSFORM = torch.nn.Sequential(
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.RandomHorizontalFlip(0.5),
    torchvision.transforms.Normalize(means, stds),
)

# Evaluation pipeline: same crop and normalisation, without the random flip.
TEST_TRANSFORM = torch.nn.Sequential(
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.Normalize(means, stds),
)

In [8]:
# Single training run of the model built from CONFIG_FILE_MS on the Madagascar
# train / test CSVs. The per-fold loop of the spatially aware cross-validation
# is currently disabled (see the commented `for fold in folds` line below).
#
# NOTE(review): pickle.load can execute arbitrary code -- only point FOLD_PATH
# at a trusted file. `folds` is not used in this cell; it is loaded only so the
# disabled cross-validation loop can be re-enabled.
with open(FOLD_PATH, 'rb') as f:
    folds = pickle.load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
# for fold in folds:
writer = SummaryWriter()
r2 = torchmetrics.R2Score().to(device=device)

# Data frames: read through the TRAIN_CSV / TEST_CSV constants instead of
# repeating the literal paths.
csv_train = pd.read_csv(TRAIN_CSV)
# reset_index() adds an "index" column to the validation frame only, exactly as
# before. NOTE(review): the training frame is not reset -- confirm that this
# asymmetry is intended.
csv_test = pd.read_csv(TEST_CSV).reset_index()
# Fold-based alternative (disabled):
# train_split = np.concatenate((folds['A']['train'],folds['B']['train'],folds['C']['train']))
# val_split = folds['E']['train']
# train_df = csv.iloc[train_split]; val_df = csv.iloc[val_split]
train_df = csv_train
val_df = csv_test

# Datasets
train_dataset = CustomDatasetFromDataFrame(
    train_df,
    DATA_DIR,
    transform=TRAIN_TRANSFORM,
    tile_max=TILE_MAX,
    tile_min=TILE_MIN,
)
val_dataset = CustomDatasetFromDataFrame(
    val_df,
    DATA_DIR,
    transform=TEST_TRANSFORM,
    tile_max=TILE_MAX,
    tile_min=TILE_MIN,
)

# DataLoaders
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config_ms['batch_size'],
    shuffle=True,
    num_workers=8,
    pin_memory=True
)
# Validation data is only evaluated, so it is iterated in a fixed order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config_ms['batch_size'],
    shuffle=False,
    num_workers=8,
    pin_memory=True
)

# Model
base_model = torchvision.models.resnet18(weights='ResNet18_Weights.DEFAULT')
# base_model = torchgeo.models.resnet18(weights=torchgeo.models.ResNet18_Weights.SENTINEL2_ALL_MOCO)
ms_branch = build_from_config( base_model=base_model, config_file=CONFIG_FILE_MS )
model = ms_branch.to(device=device)

# Loss, optimizer, LR scheduler
loss_fn = utils.configure_loss( config_ms )
optimizer = utils.configure_optimizer( config_ms, model )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer)

# Same file name as before: <checkpoint_path>_fold_all.pth
ckpt_path = config_ms['checkpoint_path'] + '_fold_' + 'all' + ".pth"

# print(f"Training on fold {fold}")
print("Training on fold (All)")
try:
    results = train(
        model=model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        loss_fn=loss_fn,
        epochs=config_ms['n_epochs'],
        batch_size=config_ms['batch_size'],
        in_channels=config_ms['in_channels'],
        writer=writer,
        device=device,
        ckpt_path=ckpt_path,
        r2=r2
    )
finally:
    # Flush and release the TensorBoard event file even if training is
    # interrupted or raises.
    writer.close()

# NOTE(review): this writes the final-epoch weights to the same path that was
# handed to train() as ckpt_path; if train() stores a best-epoch checkpoint
# there, it gets overwritten here -- confirm that this is intended.
torch.save(model.state_dict(), ckpt_path)
# final_results = utils.compute_average_crossval_results(results=results)

2023-05-11 16:37:24.034781: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-11 16:37:24.129310: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


Training on fold (All)


100%|██████████| 498/498 [21:28<00:00,  2.59s/it]
100%|██████████| 10/10 [00:42<00:00,  4.22s/it]


Epoch: 1 | train_loss: 0.4935 | train_r2: 0.2557 | test_loss: 0.6798 | test_r2: -0.6735


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.85it/s]


Epoch: 2 | train_loss: 0.3265 | train_r2: 0.5108 | test_loss: 0.4703 | test_r2: -0.1262


100%|██████████| 498/498 [03:53<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.86it/s]


Epoch: 3 | train_loss: 0.3117 | train_r2: 0.5342 | test_loss: 3.5967 | test_r2: -7.9386


100%|██████████| 498/498 [03:55<00:00,  2.12it/s]
100%|██████████| 10/10 [00:05<00:00,  1.86it/s]


Epoch: 4 | train_loss: 0.3030 | train_r2: 0.5434 | test_loss: 0.5378 | test_r2: -0.3705


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.87it/s]


Epoch: 5 | train_loss: 0.2941 | train_r2: 0.5599 | test_loss: 0.5299 | test_r2: -0.3647


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.86it/s]


Epoch: 6 | train_loss: 0.3000 | train_r2: 0.5501 | test_loss: 2.2282 | test_r2: -4.3690


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.88it/s]


Epoch: 7 | train_loss: 0.2884 | train_r2: 0.5680 | test_loss: 2.1380 | test_r2: -4.5275


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.87it/s]


Epoch: 8 | train_loss: 0.2871 | train_r2: 0.5695 | test_loss: 0.9718 | test_r2: -1.5318


100%|██████████| 498/498 [03:54<00:00,  2.12it/s]
100%|██████████| 10/10 [00:05<00:00,  1.85it/s]


Epoch: 9 | train_loss: 0.2818 | train_r2: 0.5774 | test_loss: 5.7901 | test_r2: -13.1843


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.83it/s]


Epoch: 10 | train_loss: 0.2806 | train_r2: 0.5788 | test_loss: 0.8757 | test_r2: -1.1795


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.83it/s]


Epoch: 11 | train_loss: 0.2805 | train_r2: 0.5783 | test_loss: 1.0348 | test_r2: -1.5830


100%|██████████| 498/498 [03:54<00:00,  2.12it/s]
100%|██████████| 10/10 [00:05<00:00,  1.85it/s]


Epoch: 12 | train_loss: 0.2728 | train_r2: 0.5911 | test_loss: 0.5394 | test_r2: -0.3189


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.87it/s]


Epoch: 13 | train_loss: 0.2675 | train_r2: 0.5982 | test_loss: 3.4188 | test_r2: -7.2514


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.84it/s]


Epoch: 14 | train_loss: 0.2493 | train_r2: 0.6260 | test_loss: 2.1029 | test_r2: -4.2272


100%|██████████| 498/498 [03:47<00:00,  2.19it/s]
100%|██████████| 10/10 [00:05<00:00,  1.85it/s]


Epoch: 15 | train_loss: 0.2415 | train_r2: 0.6381 | test_loss: 0.7307 | test_r2: -0.8599


100%|██████████| 498/498 [03:53<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.85it/s]


Epoch: 16 | train_loss: 0.2392 | train_r2: 0.6416 | test_loss: 0.5545 | test_r2: -0.3746


100%|██████████| 498/498 [03:53<00:00,  2.13it/s]
100%|██████████| 10/10 [00:05<00:00,  1.84it/s]


Epoch: 17 | train_loss: 0.2364 | train_r2: 0.6446 | test_loss: 0.5374 | test_r2: -0.3254


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.83it/s]


Epoch: 18 | train_loss: 0.2347 | train_r2: 0.6464 | test_loss: 1.2775 | test_r2: -2.1343


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch: 19 | train_loss: 0.2301 | train_r2: 0.6561 | test_loss: 2.5019 | test_r2: -5.0301


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.85it/s]


Epoch: 20 | train_loss: 0.2302 | train_r2: 0.6548 | test_loss: 0.4998 | test_r2: -0.2188


100%|██████████| 498/498 [03:53<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch: 21 | train_loss: 0.2273 | train_r2: 0.6593 | test_loss: 1.1905 | test_r2: -1.9249


100%|██████████| 498/498 [03:54<00:00,  2.13it/s]
100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch: 22 | train_loss: 0.2258 | train_r2: 0.6615 | test_loss: 0.4509 | test_r2: -0.0731


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch: 23 | train_loss: 0.2263 | train_r2: 0.6596 | test_loss: 3.7118 | test_r2: -8.1867


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch: 24 | train_loss: 0.2233 | train_r2: 0.6653 | test_loss: 0.4977 | test_r2: -0.1830


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch: 25 | train_loss: 0.2213 | train_r2: 0.6671 | test_loss: 1.7847 | test_r2: -3.3483


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.82it/s]


Epoch: 26 | train_loss: 0.2198 | train_r2: 0.6713 | test_loss: 0.4340 | test_r2: -0.0283


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Epoch: 27 | train_loss: 0.2196 | train_r2: 0.6701 | test_loss: 6.4400 | test_r2: -14.8303


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Epoch: 28 | train_loss: 0.2174 | train_r2: 0.6750 | test_loss: 4.4163 | test_r2: -9.7498


100%|██████████| 498/498 [03:53<00:00,  2.13it/s]
100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Epoch: 29 | train_loss: 0.2181 | train_r2: 0.6726 | test_loss: 1.6705 | test_r2: -3.1558


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch: 30 | train_loss: 0.2167 | train_r2: 0.6742 | test_loss: 0.6095 | test_r2: -0.4723


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch: 31 | train_loss: 0.2147 | train_r2: 0.6776 | test_loss: 0.6649 | test_r2: -0.6781


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Epoch: 32 | train_loss: 0.2141 | train_r2: 0.6791 | test_loss: 0.6821 | test_r2: -0.6454


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Epoch: 33 | train_loss: 0.2144 | train_r2: 0.6773 | test_loss: 1.4501 | test_r2: -2.5836


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch: 34 | train_loss: 0.2107 | train_r2: 0.6837 | test_loss: 0.9826 | test_r2: -1.3510


100%|██████████| 498/498 [03:52<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Epoch: 35 | train_loss: 0.2107 | train_r2: 0.6835 | test_loss: 0.5696 | test_r2: -0.3727


100%|██████████| 498/498 [03:48<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch: 36 | train_loss: 0.2096 | train_r2: 0.6846 | test_loss: 0.4789 | test_r2: -0.1405


100%|██████████| 498/498 [03:50<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Epoch: 37 | train_loss: 0.2079 | train_r2: 0.6884 | test_loss: 0.5514 | test_r2: -0.3264


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.78it/s]


Epoch: 38 | train_loss: 0.2005 | train_r2: 0.7002 | test_loss: 0.2190 | test_r2: 0.4630


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Epoch: 39 | train_loss: 0.1980 | train_r2: 0.7026 | test_loss: 0.3803 | test_r2: 0.0900


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch: 40 | train_loss: 0.1979 | train_r2: 0.7024 | test_loss: 0.3972 | test_r2: 0.0515


100%|██████████| 498/498 [03:53<00:00,  2.13it/s]
100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Epoch: 41 | train_loss: 0.1957 | train_r2: 0.7054 | test_loss: 0.2303 | test_r2: 0.4389


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch: 42 | train_loss: 0.1955 | train_r2: 0.7065 | test_loss: 0.4382 | test_r2: -0.1124


100%|██████████| 498/498 [03:47<00:00,  2.19it/s]
100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch: 43 | train_loss: 0.1954 | train_r2: 0.7064 | test_loss: 0.4589 | test_r2: -0.1256


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch: 44 | train_loss: 0.1940 | train_r2: 0.7086 | test_loss: 0.1710 | test_r2: 0.5810


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Epoch: 45 | train_loss: 0.1943 | train_r2: 0.7080 | test_loss: 0.2420 | test_r2: 0.4058


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch: 46 | train_loss: 0.1925 | train_r2: 0.7106 | test_loss: 0.2585 | test_r2: 0.3506


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch: 47 | train_loss: 0.1928 | train_r2: 0.7102 | test_loss: 0.3594 | test_r2: 0.1084


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch: 48 | train_loss: 0.1927 | train_r2: 0.7099 | test_loss: 0.2245 | test_r2: 0.4567


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


Epoch: 49 | train_loss: 0.1915 | train_r2: 0.7119 | test_loss: 0.2972 | test_r2: 0.2620


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Epoch: 50 | train_loss: 0.1920 | train_r2: 0.7123 | test_loss: 0.3752 | test_r2: 0.0984


100%|██████████| 498/498 [03:54<00:00,  2.12it/s]
100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch: 51 | train_loss: 0.1910 | train_r2: 0.7139 | test_loss: 0.2025 | test_r2: 0.5152


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch: 52 | train_loss: 0.1903 | train_r2: 0.7138 | test_loss: 0.2401 | test_r2: 0.4139


100%|██████████| 498/498 [03:47<00:00,  2.19it/s]
100%|██████████| 10/10 [00:05<00:00,  1.72it/s]


Epoch: 53 | train_loss: 0.1904 | train_r2: 0.7143 | test_loss: 0.5442 | test_r2: -0.3352


100%|██████████| 498/498 [03:55<00:00,  2.12it/s]
100%|██████████| 10/10 [00:05<00:00,  1.72it/s]


Epoch: 54 | train_loss: 0.1893 | train_r2: 0.7166 | test_loss: 0.3030 | test_r2: 0.2442


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.71it/s]


Epoch: 55 | train_loss: 0.1901 | train_r2: 0.7148 | test_loss: 0.2068 | test_r2: 0.4854


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch: 56 | train_loss: 0.1869 | train_r2: 0.7191 | test_loss: 0.1959 | test_r2: 0.5191


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Epoch: 57 | train_loss: 0.1872 | train_r2: 0.7182 | test_loss: 0.2053 | test_r2: 0.4871


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.71it/s]


Epoch: 58 | train_loss: 0.1878 | train_r2: 0.7175 | test_loss: 0.2096 | test_r2: 0.4809


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Epoch: 59 | train_loss: 0.1870 | train_r2: 0.7185 | test_loss: 0.1905 | test_r2: 0.5481


100%|██████████| 498/498 [03:53<00:00,  2.13it/s]
100%|██████████| 10/10 [00:05<00:00,  1.71it/s]


Epoch: 60 | train_loss: 0.1880 | train_r2: 0.7181 | test_loss: 0.1991 | test_r2: 0.5176


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.72it/s]


Epoch: 61 | train_loss: 0.1872 | train_r2: 0.7181 | test_loss: 0.2101 | test_r2: 0.4897


100%|██████████| 498/498 [03:51<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.70it/s]


Epoch: 62 | train_loss: 0.1872 | train_r2: 0.7189 | test_loss: 0.2028 | test_r2: 0.4917


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.72it/s]


Epoch: 63 | train_loss: 0.1876 | train_r2: 0.7184 | test_loss: 0.2099 | test_r2: 0.4937


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.71it/s]


Epoch: 64 | train_loss: 0.1870 | train_r2: 0.7192 | test_loss: 0.2034 | test_r2: 0.5166


100%|██████████| 498/498 [03:53<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.72it/s]


Epoch: 65 | train_loss: 0.1879 | train_r2: 0.7165 | test_loss: 0.2113 | test_r2: 0.4852


100%|██████████| 498/498 [03:49<00:00,  2.17it/s]
100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Epoch: 66 | train_loss: 0.1876 | train_r2: 0.7185 | test_loss: 0.2230 | test_r2: 0.4578


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Epoch: 67 | train_loss: 0.1869 | train_r2: 0.7184 | test_loss: 0.2088 | test_r2: 0.4917


100%|██████████| 498/498 [03:53<00:00,  2.13it/s]
100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Epoch: 68 | train_loss: 0.1868 | train_r2: 0.7190 | test_loss: 0.2167 | test_r2: 0.4763


100%|██████████| 498/498 [03:53<00:00,  2.13it/s]
100%|██████████| 10/10 [00:05<00:00,  1.70it/s]


Epoch: 69 | train_loss: 0.1879 | train_r2: 0.7178 | test_loss: 0.1977 | test_r2: 0.5244


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.68it/s]


Epoch: 70 | train_loss: 0.1880 | train_r2: 0.7161 | test_loss: 0.2062 | test_r2: 0.4940


100%|██████████| 498/498 [03:53<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Epoch: 71 | train_loss: 0.1871 | train_r2: 0.7200 | test_loss: 0.2063 | test_r2: 0.5030


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.68it/s]


Epoch: 72 | train_loss: 0.1860 | train_r2: 0.7217 | test_loss: 0.2079 | test_r2: 0.4732


100%|██████████| 498/498 [03:52<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Epoch: 73 | train_loss: 0.1870 | train_r2: 0.7198 | test_loss: 0.1982 | test_r2: 0.4797


100%|██████████| 498/498 [03:52<00:00,  2.15it/s]
100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Epoch: 74 | train_loss: 0.1874 | train_r2: 0.7173 | test_loss: 0.2005 | test_r2: 0.5129


100%|██████████| 498/498 [03:54<00:00,  2.12it/s]
100%|██████████| 10/10 [00:06<00:00,  1.65it/s]


Epoch: 75 | train_loss: 0.1864 | train_r2: 0.7187 | test_loss: 0.2122 | test_r2: 0.4690


100%|██████████| 498/498 [03:48<00:00,  2.18it/s]
100%|██████████| 10/10 [00:05<00:00,  1.68it/s]


Epoch: 76 | train_loss: 0.1872 | train_r2: 0.7190 | test_loss: 0.2056 | test_r2: 0.4783


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Epoch: 77 | train_loss: 0.1870 | train_r2: 0.7189 | test_loss: 0.2067 | test_r2: 0.4963


100%|██████████| 498/498 [03:51<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Epoch: 78 | train_loss: 0.1876 | train_r2: 0.7183 | test_loss: 0.2035 | test_r2: 0.5116


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Epoch: 79 | train_loss: 0.1869 | train_r2: 0.7182 | test_loss: 0.2088 | test_r2: 0.4920


100%|██████████| 498/498 [03:54<00:00,  2.12it/s]
100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Epoch: 80 | train_loss: 0.1870 | train_r2: 0.7194 | test_loss: 0.2114 | test_r2: 0.4995


100%|██████████| 498/498 [03:53<00:00,  2.13it/s]
100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Epoch: 81 | train_loss: 0.1859 | train_r2: 0.7212 | test_loss: 0.2115 | test_r2: 0.4870


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Epoch: 82 | train_loss: 0.1870 | train_r2: 0.7194 | test_loss: 0.2031 | test_r2: 0.5108


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:05<00:00,  1.68it/s]


Epoch: 83 | train_loss: 0.1868 | train_r2: 0.7194 | test_loss: 0.1996 | test_r2: 0.5103


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:06<00:00,  1.65it/s]


Epoch: 84 | train_loss: 0.1877 | train_r2: 0.7176 | test_loss: 0.1945 | test_r2: 0.5300


100%|██████████| 498/498 [03:54<00:00,  2.12it/s]
100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Epoch: 85 | train_loss: 0.1875 | train_r2: 0.7188 | test_loss: 0.2079 | test_r2: 0.4957


100%|██████████| 498/498 [03:54<00:00,  2.13it/s]
100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Epoch: 86 | train_loss: 0.1871 | train_r2: 0.7190 | test_loss: 0.2061 | test_r2: 0.5023


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Epoch: 87 | train_loss: 0.1871 | train_r2: 0.7186 | test_loss: 0.1946 | test_r2: 0.5295


100%|██████████| 498/498 [03:48<00:00,  2.18it/s]
100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Epoch: 88 | train_loss: 0.1868 | train_r2: 0.7197 | test_loss: 0.2048 | test_r2: 0.5057


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:06<00:00,  1.64it/s]


Epoch: 89 | train_loss: 0.1868 | train_r2: 0.7189 | test_loss: 0.2062 | test_r2: 0.5011


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:06<00:00,  1.64it/s]


Epoch: 90 | train_loss: 0.1883 | train_r2: 0.7164 | test_loss: 0.2050 | test_r2: 0.4895


100%|██████████| 498/498 [03:54<00:00,  2.12it/s]
100%|██████████| 10/10 [00:06<00:00,  1.62it/s]


Epoch: 91 | train_loss: 0.1870 | train_r2: 0.7187 | test_loss: 0.1999 | test_r2: 0.5057


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:06<00:00,  1.64it/s]


Epoch: 92 | train_loss: 0.1866 | train_r2: 0.7204 | test_loss: 0.2079 | test_r2: 0.4788


100%|██████████| 498/498 [03:50<00:00,  2.16it/s]
100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Epoch: 93 | train_loss: 0.1866 | train_r2: 0.7200 | test_loss: 0.1973 | test_r2: 0.5123


100%|██████████| 498/498 [03:53<00:00,  2.13it/s]
100%|██████████| 10/10 [00:06<00:00,  1.61it/s]


Epoch: 94 | train_loss: 0.1864 | train_r2: 0.7196 | test_loss: 0.2139 | test_r2: 0.4857


100%|██████████| 498/498 [03:48<00:00,  2.18it/s]
100%|██████████| 10/10 [00:06<00:00,  1.62it/s]


Epoch: 95 | train_loss: 0.1867 | train_r2: 0.7196 | test_loss: 0.1978 | test_r2: 0.5277


100%|██████████| 498/498 [03:58<00:00,  2.09it/s]
100%|██████████| 10/10 [00:06<00:00,  1.62it/s]


Epoch: 96 | train_loss: 0.1866 | train_r2: 0.7199 | test_loss: 0.1980 | test_r2: 0.5280


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:06<00:00,  1.61it/s]


Epoch: 97 | train_loss: 0.1889 | train_r2: 0.7153 | test_loss: 0.2003 | test_r2: 0.5104


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:06<00:00,  1.61it/s]


Epoch: 98 | train_loss: 0.1870 | train_r2: 0.7186 | test_loss: 0.2034 | test_r2: 0.4901


100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 10/10 [00:06<00:00,  1.61it/s]


Epoch: 99 | train_loss: 0.1867 | train_r2: 0.7185 | test_loss: 0.2046 | test_r2: 0.5076


100%|██████████| 498/498 [03:51<00:00,  2.15it/s]
100%|██████████| 10/10 [00:06<00:00,  1.65it/s]


Epoch: 100 | train_loss: 0.1865 | train_r2: 0.7205 | test_loss: 0.1966 | test_r2: 0.5298


In [8]:
# Single training run of the two-branch model (DoubleBranchCNN) configured by
# CONFIG_FILE_MSNL on the Madagascar train / test CSVs. The per-fold loop of
# the spatially aware cross-validation is currently disabled.

device = "cuda" if torch.cuda.is_available() else "cpu"
# for fold in folds:
writer = SummaryWriter()
r2 = torchmetrics.R2Score().to(device=device)

# Data frames: read through the TRAIN_CSV / TEST_CSV constants instead of
# repeating the literal paths.
csv_train = pd.read_csv(TRAIN_CSV)
# reset_index() adds an "index" column to the validation frame only, exactly as
# before. NOTE(review): the training frame is not reset -- confirm that this
# asymmetry is intended.
csv_test = pd.read_csv(TEST_CSV).reset_index()
# Fold-based alternative (disabled):
# train_split = np.concatenate((folds['A']['train'],folds['B']['train'],folds['C']['train']))
# val_split = folds['E']['train']
# train_df = csv.iloc[train_split]; val_df = csv.iloc[val_split]
train_df = csv_train
val_df = csv_test

# Datasets (nl=True is the only difference from the single-branch setup)
train_dataset = CustomDatasetFromDataFrame(
    train_df,
    DATA_DIR,
    transform=TRAIN_TRANSFORM,
    tile_max=TILE_MAX,
    tile_min=TILE_MIN,
    nl=True,
)
val_dataset = CustomDatasetFromDataFrame(
    val_df,
    DATA_DIR,
    transform=TEST_TRANSFORM,
    tile_max=TILE_MAX,
    tile_min=TILE_MIN,
    nl=True,
)

# DataLoaders
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config_msnl['batch_size'],
    shuffle=True,
    num_workers=8,
    pin_memory=True
)
# Validation data is only evaluated, so it is iterated in a fixed order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config_msnl['batch_size'],
    shuffle=False,
    num_workers=8,
    pin_memory=True
)

# Model: first branch built from the config on a pretrained ResNet-18, second
# branch a ResNet-18 created without pretrained weights and adapted by
# tl.update_single_layer.
base_model = torchvision.models.resnet18(weights='ResNet18_Weights.DEFAULT')
# base_model = torchgeo.models.resnet18(weights=torchgeo.models.ResNet18_Weights.SENTINEL2_ALL_MOCO)
ms_branch = build_from_config( base_model=base_model, config_file=CONFIG_FILE_MSNL )
nl_branch = tl.update_single_layer(torchvision.models.resnet18())
model = DoubleBranchCNN(b1=ms_branch, b2=nl_branch, output_features=1)
model = model.to(device=device)

# Loss, optimizer, LR scheduler
loss_fn = utils.configure_loss( config_msnl )
optimizer = utils.configure_optimizer( config_msnl, model )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer)

# Same file name as before: <checkpoint_path>_fold_all.pth
ckpt_path = config_msnl['checkpoint_path'] + '_fold_' + 'all' + ".pth"

# print(f"Training on fold {fold}")
print("Training on fold (All)")
try:
    results = dual_train(
        model=model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        loss_fn=loss_fn,
        epochs=config_msnl['n_epochs'],
        batch_size=config_msnl['batch_size'],
        in_channels=config_msnl['in_channels'],
        writer=writer,
        device=device,
        ckpt_path=ckpt_path,
        r2=r2
    )
finally:
    # Flush and release the TensorBoard event file even if training is
    # interrupted or raises.
    writer.close()

# NOTE(review): this writes the final-epoch weights to the same path that was
# handed to dual_train() as ckpt_path; if dual_train() stores a best-epoch
# checkpoint there, it gets overwritten here -- confirm that this is intended.
torch.save(model.state_dict(), ckpt_path)
# final_results = utils.compute_average_crossval_results(results=results)

2023-05-12 13:32:39.721695: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-12 13:32:39.799629: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


Training on fold (All)


100%|██████████| 498/498 [04:14<00:00,  1.96it/s]
100%|██████████| 10/10 [00:07<00:00,  1.43it/s]


Epoch: 1 | train_loss: 0.4165 | train_r2: 0.3527 | test_loss: 0.2250 | test_r2: 0.4639


100%|██████████| 498/498 [04:05<00:00,  2.03it/s]
100%|██████████| 10/10 [00:07<00:00,  1.36it/s]


Epoch: 2 | train_loss: 0.2686 | train_r2: 0.5979 | test_loss: 0.3033 | test_r2: 0.2715


100%|██████████| 498/498 [04:06<00:00,  2.02it/s]
100%|██████████| 10/10 [00:07<00:00,  1.27it/s]


Epoch: 3 | train_loss: 0.2642 | train_r2: 0.6049 | test_loss: 0.2728 | test_r2: 0.3286


100%|██████████| 498/498 [04:06<00:00,  2.02it/s]
100%|██████████| 10/10 [00:08<00:00,  1.22it/s]


Epoch: 4 | train_loss: 0.2598 | train_r2: 0.6104 | test_loss: 0.1880 | test_r2: 0.5413


100%|██████████| 498/498 [04:07<00:00,  2.01it/s]
100%|██████████| 10/10 [00:08<00:00,  1.19it/s]


Epoch: 5 | train_loss: 0.2548 | train_r2: 0.6167 | test_loss: 0.5678 | test_r2: -0.3547


100%|██████████| 498/498 [04:07<00:00,  2.01it/s]
100%|██████████| 10/10 [00:09<00:00,  1.10it/s]


Epoch: 6 | train_loss: 0.2495 | train_r2: 0.6258 | test_loss: 0.3438 | test_r2: 0.1748


100%|██████████| 498/498 [04:07<00:00,  2.01it/s]
100%|██████████| 10/10 [00:09<00:00,  1.06it/s]


Epoch: 7 | train_loss: 0.2515 | train_r2: 0.6232 | test_loss: 0.1635 | test_r2: 0.6077


100%|██████████| 498/498 [04:08<00:00,  2.01it/s]
100%|██████████| 10/10 [00:09<00:00,  1.01it/s]


Epoch: 8 | train_loss: 0.2458 | train_r2: 0.6329 | test_loss: 1.0675 | test_r2: -1.6023


100%|██████████| 498/498 [04:07<00:00,  2.01it/s]
100%|██████████| 10/10 [00:10<00:00,  1.02s/it]


Epoch: 9 | train_loss: 0.2421 | train_r2: 0.6377 | test_loss: 0.1588 | test_r2: 0.6199


100%|██████████| 498/498 [04:09<00:00,  2.00it/s]
100%|██████████| 10/10 [00:10<00:00,  1.06s/it]


Epoch: 10 | train_loss: 0.2412 | train_r2: 0.6390 | test_loss: 0.2532 | test_r2: 0.3992


100%|██████████| 498/498 [04:09<00:00,  2.00it/s]
100%|██████████| 10/10 [00:10<00:00,  1.09s/it]


Epoch: 11 | train_loss: 0.2412 | train_r2: 0.6373 | test_loss: 0.1771 | test_r2: 0.5678


100%|██████████| 498/498 [04:09<00:00,  2.00it/s]
100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Epoch: 12 | train_loss: 0.2371 | train_r2: 0.6454 | test_loss: 0.1704 | test_r2: 0.5548


100%|██████████| 498/498 [05:04<00:00,  1.64it/s]
100%|██████████| 10/10 [00:11<00:00,  1.19s/it]


Epoch: 13 | train_loss: 0.2377 | train_r2: 0.6432 | test_loss: 0.7272 | test_r2: -0.7850


100%|██████████| 498/498 [10:32<00:00,  1.27s/it]
100%|██████████| 10/10 [00:12<00:00,  1.24s/it]


Epoch: 14 | train_loss: 0.2354 | train_r2: 0.6468 | test_loss: 0.2218 | test_r2: 0.4601


100%|██████████| 498/498 [13:13<00:00,  1.59s/it]
100%|██████████| 10/10 [00:16<00:00,  1.61s/it]


Epoch: 15 | train_loss: 0.2357 | train_r2: 0.6457 | test_loss: 0.1542 | test_r2: 0.6254


100%|██████████| 498/498 [15:26<00:00,  1.86s/it]
100%|██████████| 10/10 [00:47<00:00,  4.76s/it]


Epoch: 16 | train_loss: 0.2321 | train_r2: 0.6517 | test_loss: 1.5231 | test_r2: -2.7903


100%|██████████| 498/498 [18:33<00:00,  2.24s/it]
100%|██████████| 10/10 [00:13<00:00,  1.39s/it]


Epoch: 17 | train_loss: 0.2327 | train_r2: 0.6505 | test_loss: 0.1375 | test_r2: 0.6710


100%|██████████| 498/498 [21:46<00:00,  2.62s/it]
100%|██████████| 10/10 [00:13<00:00,  1.39s/it]


Epoch: 18 | train_loss: 0.2339 | train_r2: 0.6489 | test_loss: 0.4592 | test_r2: -0.1292


100%|██████████| 498/498 [23:18<00:00,  2.81s/it]
100%|██████████| 10/10 [00:50<00:00,  5.02s/it]


Epoch: 19 | train_loss: 0.2313 | train_r2: 0.6526 | test_loss: 0.1765 | test_r2: 0.5514


 11%|█         | 56/498 [01:56<03:57,  1.86it/s] 

3. Test Results

In [None]:
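# Evaluation on the validation loader (currently disabled): uncomment to compute R2 and plot predicted vs. true index.
# test_r2, Y_true, Y_pred = test(model=model, dataloader=val_loader, device=device)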
# # Y_true = [ utils.denormalize_asset(asset) for asset in Y_true]
# # Y_pred = [ utils.denormalize_asset(asset) for asset in Y_pred]
# results = pd.DataFrame({
#     'true index':np.array(Y_true),
#     'predicted index':np.array(Y_pred)
# })
# from scipy.stats import pearsonr
# import seaborn as sns
# sns.set_palette("rocket")
# sns.regplot(x='true index', y='predicted index', data=results).set(title='R2 = '+str(test_r2))