In [1]:
import numpy as np
import os
import numpy as np
from patchify import patchify
import torch
from torch.functional import F
from torch import nn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from torchvision import transforms
import time
import random

from pytorch_segmentation.validate import validate
from pytorch_segmentation.data.rwanda_dataset import RwandaDataset
from pytorch_segmentation.data.inmemory_dataset import InMemorySatDataset
from pytorch_segmentation.models import UNet

# Seed every RNG source used in this notebook (Python, NumPy, PyTorch) so that
# DataLoader shuffling and any sampling are repeatable across runs.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn

In [7]:
# --- Model selection ---------------------------------------------------------
# Checkpoint file name in `save_dir`, without the ".pth" extension. The "_new"
# suffix is the name the DataParallel-conversion cell below saves its
# prefix-stripped copy under.
model_name = "unet_11_07_2022_070457_new"
# True when the checkpoint's keys carry the `module.` prefix (see conversion cell)
data_parallel = False
save_dir = "saved_models"

# --- South Africa (SA) data --------------------------------------------------
# Satellite mosaic; set SAT_DATA_PATH to use a different location.
data_path = os.environ.get("SAT_DATA_PATH", "/home/jovyan/work/satellite_data/tmp/2018.vrt")
label_path_sa_train = "data/datasets/V6/train/SA"
label_path_sa_val = "data/datasets/V6/val/SA"

# --- Rwanda (RW) data --------------------------------------------------------
shape_path_rw = "data/datasets/V1/rwanda_tree_shapes/training_data_polygons_model_29_v2.shp"
train_data_path_rw = "data/datasets/V6/train/rwanda"
val_data_path_rw = "data/datasets/V6/val/rwanda"
test_data_path_rw = "data/datasets/V6/test/rwanda"

# --- Output folders for validate() results -----------------------------------
save_dir_sa = f"data/out/validate/{model_name}/SA"
save_dir_rw = f"data/out/validate/{model_name}/RW"

# --- Patch geometry ----------------------------------------------------------
# Two spatial dimensions only (no band entry); used for the training-set cells.
patch_size = [300, 300]
overlap = 260

# Two spatial dimensions only (no band entry); used for the validation-set cells.
val_patch_size = [256, 256]
val_overlap = 200

padding = False

val_transform = None

# --- DataLoader ---------------------------------------------------------------
batch_size = 100
nworkers = 4
pin_memory = True  # the device cell switches this off when running on CPU

# Passed to validate(..., n_images=...)
n_images = 300
# Passed to export_patches(..., archived=...)
archived = True

In [3]:
# Evaluation device. `gpu_index` selects the card on multi-GPU hosts; if that
# index does not exist, fall back to the default CUDA device instead of failing
# later at `net.to(device)`. Without CUDA, run on CPU.
# To force CPU explicitly, set: device = torch.device("cpu")
gpu_index = 2
if torch.cuda.is_available():
    if gpu_index < torch.cuda.device_count():
        device = torch.device(f"cuda:{gpu_index}")
    else:
        device = torch.device("cuda")
else:
    device = torch.device("cpu")

if device.type == "cpu":
    # Pinned host memory is only useful for host->GPU copies.
    pin_memory = False

# NOTE(review): not referenced by any visible cell (val_transform is used instead).
test_transform = None


# 2 Model

In [4]:
# Load the checkpoint onto the CPU first; the network is moved to `device`
# after the weights are applied.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model_path = f"{save_dir}/{model_name}.pth"
state_dict = torch.load(model_path, map_location="cpu")

In [5]:


if data_parallel:
    # Weights saved from a DataParallel-wrapped model have every key prefixed
    # with "module."; strip it so the keys match a bare UNet. Keys that lack
    # the prefix are kept unchanged. Previously, k[7:] removed the first 7
    # characters of every key unconditionally.
    from collections import OrderedDict
    prefix = "module."
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    # Persist the stripped copy next to the original as "<model_name>_new.pth".
    model_path = save_dir + "/" + model_name + "_new.pth"
    torch.save(new_state_dict, model_path)
    

## 2.1 Advanced U-Net

In [8]:
# Change here to adapt to your data:
#   n_channels=3 for RGB images
#   n_classes is the number of probabilities you want to get per pixel
net = UNet(n_channels=3, n_classes=2, bilinear=False)

# Apply the prefix-stripped weights when the checkpoint came from DataParallel,
# otherwise the weights exactly as loaded.
net.load_state_dict(new_state_dict if data_parallel else state_dict)

# For multi-GPU evaluation, wrap with DataParallel(net, device_ids=[...]) here.
net = net.to(device=device)



# 3 Validate Model

## Validation set

In [9]:
# Per-region output folders for the validation-set results
save_dir_sa_val, save_dir_rw_val = (os.path.join(region_dir, "val") for region_dir in (save_dir_sa, save_dir_rw))

In [10]:
# SA validation patches cut from `data_path` with the masks in `label_path_sa_val`
val_dataset_sa = InMemorySatDataset(
    data_file_path=data_path,
    mask_path=label_path_sa_val,
    overlap=val_overlap,
    patch_size=val_patch_size,
    padding=padding,
    transform=val_transform,
)
dl_sa = DataLoader(
    val_dataset_sa,
    batch_size=batch_size,
    num_workers=nworkers,
    shuffle=True,
    pin_memory=pin_memory,
    drop_last=False,
)

In [11]:
validate(
    net,
    dl_sa,
    save_dir_sa_val,
    mode="best",
    device=device,
    n_images=n_images,
)

Scores: acc     0.961577
iou     0.289808
dice    0.383985


acc     0.961577
iou     0.289808
dice    0.383985
Name: 0, dtype: float64

In [None]:
sa_val_export_dir = f"data/out/val_data/{model_name}/SA/val"
val_dataset_sa.export_patches(sa_val_export_dir, max_n=300, archived=archived)

In [12]:
# RW validation patches from `val_data_path_rw`, labelled by the polygons in `shape_path_rw`
val_dataset_rw = RwandaDataset(
    dataset_path=None,
    data_file_path=val_data_path_rw,
    shape_path=shape_path_rw,
    overlap=val_overlap,
    patch_size=val_patch_size,
    padding=padding,
    transform=val_transform,
)

dl_rw = DataLoader(
    val_dataset_rw,
    batch_size=batch_size,
    num_workers=nworkers,
    shuffle=True,
    pin_memory=pin_memory,
    drop_last=False,
)

100%|██████████| 18/18 [00:25<00:00,  1.40s/it]


In [13]:
validate(
    net,
    dl_rw,
    save_dir_rw_val,
    mode="best",
    device=device,
    n_images=n_images,
)

Scores: acc     0.912840
iou     0.573817
dice    0.703045


acc     0.912840
iou     0.573817
dice    0.703045
Name: 0, dtype: float64

In [None]:
# Export RW validation patches into a per-model folder, like the SA validation
# export above. Patch size and overlap come from this model's config, so a
# shared folder would be overwritten by runs with a different configuration.
val_dataset_rw.export_patches("data/out/val_data/" + model_name + "/RW/val", archived=archived, max_n=300)

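## Training set

The same evaluation on the training split. These cells use the training `patch_size`/`overlap` and call `validate` with `mode="worst"`; the validation-set cells above use `mode="best"`.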

In [None]:
# Per-region output folders for the training-set results
save_dir_sa_train, save_dir_rw_train = (os.path.join(region_dir, "train") for region_dir in (save_dir_sa, save_dir_rw))

In [None]:
# SA training patches use the training geometry (patch_size/overlap), not the validation one
train_dataset_sa = InMemorySatDataset(
    data_file_path=data_path,
    mask_path=label_path_sa_train,
    overlap=overlap,
    patch_size=patch_size,
    padding=padding,
    transform=None,
)
dl_sa = DataLoader(
    train_dataset_sa,
    batch_size=batch_size,
    num_workers=nworkers,
    shuffle=True,
    pin_memory=pin_memory,
    drop_last=False,
)

In [None]:
# Training patches use `patch_size` from the config cell, while the validation
# calls above do not pass patch_size. So pass the per-sample shape explicitly
# as [channels, *patch_size]; 3 channels matches UNet(n_channels=3). Deriving it
# keeps this call in sync if patch_size changes.
validate(net, dl_sa, save_dir_sa_train, n_images=n_images, device=device, mode="worst",
         patch_size=[3, *patch_size])

In [None]:
# Per-model export folder, matching the SA validation export layout, so patches
# cut with another model's patch_size/overlap are not overwritten.
train_dataset_sa.export_patches("data/out/val_data/" + model_name + "/SA/train", archived=archived, max_n=500)

In [None]:
# RW training patches use the training geometry (patch_size/overlap)
train_dataset_rw = RwandaDataset(
    dataset_path=None,
    data_file_path=train_data_path_rw,
    shape_path=shape_path_rw,
    overlap=overlap,
    patch_size=patch_size,
    padding=padding,
    transform=None,
)

dl_rw = DataLoader(
    train_dataset_rw,
    batch_size=batch_size,
    num_workers=nworkers,
    shuffle=True,
    pin_memory=pin_memory,
    drop_last=False,
)

In [None]:
# Per-sample shape [channels, *patch_size] derived from the config cell
# (3 channels matches UNet(n_channels=3)) instead of a hardcoded [3,300,300].
validate(net, dl_rw, save_dir_rw_train, n_images=n_images, device=device, mode="worst",
         patch_size=[3, *patch_size])

In [None]:
# Per-model export folder, matching the SA validation export layout, so patches
# cut with another model's patch_size/overlap are not overwritten.
train_dataset_rw.export_patches("data/out/val_data/" + model_name + "/RW/train", archived=archived, max_n=500)