In [1]:
import rasterio
from rasterio.plot import show
from rasterio.mask import mask
from rasterio.windows import from_bounds
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
from patchify import patchify
import torch
from torch.functional import F
from torch import nn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from shapely.geometry import box
from torchvision import transforms
import time
import random
import torchvision.models as models
from pytorch_segmentation.data.train_dataset import TrainDataset
from pytorch_segmentation.data.test_dataset import TestSatDataset
from pytorch_segmentation.models import UNet
import pytorch_segmentation.augmentation.transforms as seg_transforms

from pytorch_segmentation.validate import validate

# Fix the seed of every random number generator this notebook touches
# (PyTorch, Python's `random`, NumPy) so repeated runs behave the same.
seed = 42
for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(seed)

In [2]:
# All tunable settings for this validation run live in `cfg`.
cfg = dict(
    # --- checkpoint to evaluate -------------------------------------------
    save_dir="saved_models",
    model_name="unet_01_12_2022_085945",
    # When True, the `module.` key prefix is stripped from the checkpoint
    # before it is loaded into the network.
    data_parallel=True,

    # --- forwarded to validate() as `mode` / `n_images` -------------------
    mode="worst",
    nimgs=2000,

    # --- SA, high-resolution data -----------------------------------------
    shape_path_sa_high="data/datasets/V24/data_pool/SA_tree_shapes/labels.geojson",
    train_data_path_sa_high="data/datasets/V24/train/SA_high",
    val_data_path_sa_high="data/datasets/V24/val/SA_high",
    test_data_path_sa_high="data/datasets/V24/test/SA_high",

    # --- SA, low-resolution data ------------------------------------------
    shape_path_sa_low="data/datasets/V24/data_pool/SA_tree_shapes/labels.geojson",
    train_data_path_sa_low="data/datasets/V24/train/SA_low",
    val_data_path_sa_low="data/datasets/V24/val/SA_low",
    test_data_path_sa_low="data/datasets/V24/test/SA_low",

    # --- Rwanda 2008 ------------------------------------------------------
    shape_path_rw_2008="data/datasets/V24/data_pool/rwanda_tree_shapes/Training_Data_manual_Trial29_V9_2008.shp",
    train_data_path_rw_2008="data/datasets/V24/train/rwanda_2008",
    val_data_path_rw_2008="data/datasets/V24/val/rwanda_2008",
    test_data_path_rw_2008="data/datasets/V24/test/rwanda_2008",

    # --- Rwanda 2019 ------------------------------------------------------
    # NOTE(review): the shape file name ends in `_2020` - confirm this is the
    # intended label set for the 2019 imagery.
    shape_path_rw_2019="data/datasets/V24/data_pool/rwanda_tree_shapes/Training_Data_manual_Trial29_V9_2020.shp",
    train_data_path_rw_2019="data/datasets/V24/train/rwanda_2019",
    val_data_path_rw_2019="data/datasets/V24/val/rwanda_2019",
    test_data_path_rw_2019="data/datasets/V24/test/rwanda_2019",

    # --- patch extraction for the validation datasets ---------------------
    # Two values, presumably [x, y]; the earlier note read "[x,y,bands]" but
    # no band entry is given here (see `n_channels`) - TODO confirm.
    val_patch_size=[256, 256],
    val_overlap=200,
    padding=False,  # alternative: True

    # --- data loading -----------------------------------------------------
    batch_size=50,  # other values tried before: 75, 150, 200
    metric="iou",
    n_channels=3,
    nworkers=4,
    pin_memory=True,
)

# Directory handed to validate() for this model.
image_dir = f"data/out/validate/{cfg['model_name']}"

In [3]:
# Pick the compute device. The second GPU (cuda:1) stays the preferred choice,
# but a host with a single GPU has no cuda:1, so fall back to cuda:0 there
# instead of failing later when the model is moved to the device.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Validation uses an empty transform pipeline. Steps that were tried and are
# currently disabled:
#   seg_transforms.CLAHE_Norm()
#   seg_transforms.RandomHorizontalFlip(0.5)
#   seg_transforms.RandomVerticalFlip(0.5)
#   seg_transforms.UnmaskEdges([225, 225])
#   seg_transforms.Normalize(mean=[0.5492, 0.5190, 0.4393],
#                            std=[0.1381, 0.1320, 0.1349])
val_transform = seg_transforms.Compose([])

test_transform = None


# 1.1 Data Collection

#### Create Validation Dataset - SA high

In [4]:
# Validation patches for the SA high-resolution data, cut with the shared
# validation patch size / overlap / padding settings from cfg.
val_dataset_sa_high = TrainDataset(
    dataset_path=None,
    data_file_path=cfg["val_data_path_sa_high"],
    shape_path=cfg["shape_path_sa_high"],
    overlap=cfg["val_overlap"],
    patch_size=cfg["val_patch_size"],
    padding=cfg["padding"],
    transform=val_transform,
)

print("Len Val: ", len(val_dataset_sa_high))

100%|██████████| 44/44 [00:05<00:00,  8.02it/s]


Len Val:  11245


#### Create Validation Dataset - SA low

In [5]:
# Validation patches for the SA low-resolution data, cut with the shared
# validation patch size / overlap / padding settings from cfg.
val_dataset_sa_low = TrainDataset(
    dataset_path=None,
    data_file_path=cfg["val_data_path_sa_low"],
    shape_path=cfg["shape_path_sa_low"],
    overlap=cfg["val_overlap"],
    patch_size=cfg["val_patch_size"],
    padding=cfg["padding"],
    transform=val_transform,
)

print("Len Val: ", len(val_dataset_sa_low))

100%|██████████| 13/13 [00:01<00:00,  6.60it/s]


Len Val:  708


#### Create Validation Dataset - Rwanda 2008

In [6]:
# Validation patches for the Rwanda 2008 data, cut with the shared
# validation patch size / overlap / padding settings from cfg.
val_dataset_rw_2008 = TrainDataset(
    dataset_path=None,
    data_file_path=cfg["val_data_path_rw_2008"],
    shape_path=cfg["shape_path_rw_2008"],
    overlap=cfg["val_overlap"],
    patch_size=cfg["val_patch_size"],
    padding=cfg["padding"],
    transform=val_transform,
)

print("Len Val: ", len(val_dataset_rw_2008))

100%|██████████| 31/31 [00:16<00:00,  1.88it/s]


Len Val:  8686


#### Create Validation Dataset - Rwanda 2019

In [7]:
# Validation patches for the Rwanda 2019 data, cut with the shared
# validation patch size / overlap / padding settings from cfg.
val_dataset_rw_2019 = TrainDataset(
    dataset_path=None,
    data_file_path=cfg["val_data_path_rw_2019"],
    shape_path=cfg["shape_path_rw_2019"],
    overlap=cfg["val_overlap"],
    patch_size=cfg["val_patch_size"],
    padding=cfg["padding"],
    transform=val_transform,
)

print("Len Val: ", len(val_dataset_rw_2019))

100%|██████████| 21/21 [00:09<00:00,  2.28it/s]


Len Val:  5004


# 2 Model

In [8]:
# Locate the checkpoint named in cfg and read its weights onto the CPU;
# the network itself is moved to the target device after loading.
model_path = f'{cfg["save_dir"]}/{cfg["model_name"]}.pth'
state_dict = torch.load(model_path, map_location="cpu")

In [9]:


if cfg["data_parallel"]:
    # Build a copy of the checkpoint whose keys no longer carry the `module.`
    # prefix, so it can be loaded into a plain (non-DataParallel) network.
    from collections import OrderedDict

    dp_prefix = "module."
    new_state_dict = OrderedDict()
    for key, weights in state_dict.items():
        # Strip the prefix only where it is actually present. Blindly dropping
        # the first 7 characters would corrupt keys of a checkpoint that was
        # saved without the DataParallel wrapper.
        if key.startswith(dp_prefix):
            key = key[len(dp_prefix):]
        new_state_dict[key] = weights

    # Persist the cleaned checkpoint next to the original one and point
    # `model_path` at it.
    model_path = cfg["save_dir"] + "/" + cfg["model_name"] +  "_new.pth"
    torch.save(new_state_dict,model_path)
    

## 2.1 Advanced Unet

In [10]:
# Change here to adapt to your data:
#   n_channels - number of input bands, taken from cfg (3 for RGB images)
#   n_classes  - number of probabilities you want to get per pixel
net = UNet(n_channels=cfg["n_channels"], n_classes=2, bilinear=False)

# With data_parallel set, use the copy of the checkpoint whose `module.`
# key prefix was stripped; otherwise load the checkpoint as read from disk.
net.load_state_dict(new_state_dict if cfg["data_parallel"] else state_dict)

# This notebook only evaluates the model, so switch it to evaluation mode
# (harmless if validate() does the same). Written as an assignment so the
# cell does not print the whole module as its output.
net = net.to(device=device).eval()



## 2.2 Model Validation


In [11]:

# Pool the four validation subsets (Rwanda 2008 / 2019, SA high / low)
# into a single dataset, in this order.
val_parts = [
    val_dataset_rw_2008,
    val_dataset_rw_2019,
    val_dataset_sa_high,
    val_dataset_sa_low,
]
val_dataset = torch.utils.data.ConcatDataset(val_parts)

In [12]:
# Combined validation set: both Rwanda years plus both SA resolutions.
val_dataset = torch.utils.data.ConcatDataset([val_dataset_rw_2008,val_dataset_rw_2019, val_dataset_sa_high,val_dataset_sa_low])

# The earlier check compared `len(val_dataset) % batch_size` against `nimgs`.
# That remainder is always below batch_size, so with batch_size <= nimgs the
# test was always true and the drop_last=False branch could never run.
# Instead: drop the trailing partial batch only when the full batches alone
# still supply at least `nimgs` samples; otherwise keep every sample.
# (For the current config this still yields drop_last=True.)
# NOTE(review): confirm this matches the intent of the original condition.
n_full_batch_samples = (len(val_dataset) // cfg["batch_size"]) * cfg["batch_size"]
drop_last = n_full_batch_samples >= cfg["nimgs"]

val_dl = DataLoader(
    val_dataset,
    batch_size=cfg["batch_size"],
    num_workers=cfg["nworkers"],
    shuffle=True,
    pin_memory=cfg["pin_memory"],
    drop_last=drop_last,
)

In [13]:
# Run the validation routine on the combined loader. The call is left as the
# cell's final expression so that its return value is shown as the output.
validate(
    net,
    val_dl,
    image_dir,
    n_images=cfg["nimgs"],
    device=device,
    mode=cfg["mode"],
)

Scores: acc     0.877653
iou     0.369249
dice    0.489295


acc     0.877653
iou     0.369249
dice    0.489295
Name: 0, dtype: float64