The following example shows how to use the Sen1Floods11 dataset to train a fully convolutional neural network (FCNN). Here we train and validate on hand-labeled chips of flood events; the dataset includes several other options, which are detailed in the README. To use a different dataset, as outlined further below, replace the train, test, and validation split CSVs and download the corresponding data.

Authenticate Google Cloud Platform. Note that to run this code, you must connect your notebook runtime to a GPU. 

Install RasterIO

In [1]:
#!pip install rasterio

In [2]:
#%pip install scikit-learn
#%pip install pandas

Define a model checkpoint folder, for storing network checkpoints during training

In [3]:
"""%cd /home
!sudo mkdir checkpoints"""


Download the train, test, and validation splits for flood water. To use different splits, replace these paths with the path to a CSV containing the desired splits.

In [4]:
"""!gsutil cp gs://sen1floods11/v1.1/splits/flood_handlabeled/flood_train_data.csv .
!gsutil cp gs://sen1floods11/v1.1/splits/flood_handlabeled/flood_test_data.csv .
!gsutil cp gs://sen1floods11/v1.1/splits/flood_handlabeled/flood_valid_data.csv ."""


Download the raw train, test, and validation data. In this example we download hand-labeled flood images, but you can replace these paths with whichever dataset you would like to use. Further documentation of the Sen1Floods11 dataset and its organization is available in the README.

In [5]:


"""!gsutil -m rsync -r gs://sen1floods11/v1.1/data/flood_events/HandLabeled/S1Hand files/S1
!gsutil -m rsync -r gs://sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand files/Labels"""


In [6]:
import sys
sys.path.append('./src')

Define model training hyperparameters

In [7]:
LR = 5e-4
MAX_LR = 5e-4
EPOCHS = 100

EPOCHS_PER_UPDATE = 1
RUNNAME = "Sen1Floods11"

Define functions to process and augment training and testing images

In [8]:
import torch
from torchvision import transforms
import torch.nn.functional as fun
import torchvision.transforms.functional as F
import torchvision.transforms as T
import random
from PIL import Image
import numpy as np
class InMemoryDataset(torch.utils.data.Dataset):
  
  def __init__(self, data_list, preprocess_func, source='S1', select_bands=(0,1,2)):
    self.data_list = data_list
    self.preprocess_func = preprocess_func
    self.source = source
    self.select_bands = select_bands
  
  def __getitem__(self, i):
    return self.preprocess_func(self.data_list[i], self.source, self.select_bands)
  
  def __len__(self):
    return len(self.data_list)


def processAndAugment(data, source='S1', select_bands=(0,1,2)): 
  (x,y) = data
  im,label = x.copy(), y.copy()
  label = label.astype(np.float64)  # np.float was deprecated in NumPy 1.20
  
  if source == 'S1':
    bands = 2
  else:
    bands = len(select_bands)

  # convert to PIL for easier transforms
  ims = []
  for i in range(bands):
    ims.append(Image.fromarray(im[i]))

  label = Image.fromarray(label.squeeze())      

  # Get params for random transforms
  top, left, h, w = transforms.RandomCrop.get_params(ims[0], (256, 256))

  # Use a separate loop variable so the crop offsets are not overwritten
  for b in range(bands):
    ims[b] = F.crop(ims[b], top, left, h, w)
  label = F.crop(label, top, left, h, w)
 
  if random.random() > 0.5:
    for i in range(bands):
      ims[i] = F.hflip(ims[i])
    label = F.hflip(label)

  if random.random() > 0.5:
    for i in range(bands):
      ims[i] = F.vflip(ims[i])
    label = F.vflip(label)

  if random.random() > 0.75:
    rotation = random.choice((90, 180, 270))
    for i in range(bands): 
      ims[i] = F.rotate(ims[i], rotation)
    label = F.rotate(label, rotation)
  
  """if random.random() > 0.2:
    for i in range(bands):
      ims[i] = F.gaussian_blur(ims[i], 7)"""

  # Standardize each band: subtract its mean and divide by its standard deviation
  if source == 'S1':
    norm = transforms.Normalize([0.6851, 0.5235], [0.0820, 0.1102])
  else: #TODO band selector
    mean_list = np.array([0.16269160022432763, 0.13960347063125136, 0.13640611841716485, 
    0.1218228479188587, 0.14660729066303788, 0.23869029753700105, 0.284561256276994, 0.2622957968923778, 
    0.3077482214806557, 0.048687436781988974, 0.006377861007811543, 0.20306476302374007, 0.11791660722096743])
    std_list = np.array([0.07001713384623806, 0.07390945268205054, 0.07352482387959473, 0.08649366949997794, 
    0.07768803358037298, 0.09213683430927469, 0.10843734609719749, 0.10226341800670553, 0.1196442553176325, 
    0.03366110543131479, 0.014399923282248634, 0.09808706134697646, 0.07646083655721092])
    norm = transforms.Normalize(mean_list[np.array(select_bands)], std_list[np.array(select_bands)])
    

  blur = transforms.GaussianBlur(19, (.5,1.5))
  t1 = transforms.ToPILImage()

  ims_T = []
  for i in range(bands):
    ims_T.append(transforms.ToTensor()(ims[i]).squeeze())
  
  im = torch.stack(ims_T)
  if random.random() > .8:
    #BLURSAVE = random.randint(0,1000000000)
    #ts = t1(im)
    #ts.save('blurred/{}crisp.png'.format(BLURSAVE))
    im = blur(im)
    #te = t1(im)
    #te.save('blurred/{}blur.png'.format( BLURSAVE))

  im = norm(im)
  
  label = transforms.ToTensor()(label).squeeze()
  if torch.sum(label.gt(.003) * label.lt(.004)):
    label *= 255
  label = label.round()

  return im, label


def processTestIm(data, source='S1', select_bands=(0,1,2)): 
  if source == 'S1':
    bands = 2
  else:
    bands = len(select_bands)
  
  (x,y) = data
  im,label = x.copy(), y.copy()
  label = label.astype(np.float64)  # np.float was deprecated in NumPy 1.20
  if source == 'S1':
    norm = transforms.Normalize([0.6851, 0.5235], [0.0820, 0.1102])
  else: #TODO band selector
    mean_list = np.array([0.16269160022432763, 0.13960347063125136, 0.13640611841716485, 
    0.1218228479188587, 0.14660729066303788, 0.23869029753700105, 0.284561256276994, 0.2622957968923778, 
    0.3077482214806557, 0.048687436781988974, 0.006377861007811543, 0.20306476302374007, 0.11791660722096743])
    std_list = np.array([0.07001713384623806, 0.07390945268205054, 0.07352482387959473, 0.08649366949997794, 
    0.07768803358037298, 0.09213683430927469, 0.10843734609719749, 0.10226341800670553, 0.1196442553176325, 
    0.03366110543131479, 0.014399923282248634, 0.09808706134697646, 0.07646083655721092])
    norm = transforms.Normalize(mean_list[np.array(select_bands)], std_list[np.array(select_bands)])


  # convert to PIL for easier transforms
  im_c = []
  for i in range(bands):
    im_c.append(Image.fromarray(im[i]).resize((512,512)))

  label = Image.fromarray(label.squeeze()).resize((512,512))

  im_cs = []
  for i in range(bands):
    im_cs.append([F.crop(im_c[i], 0, 0, 256, 256), F.crop(im_c[i], 0, 256, 256, 256),
            F.crop(im_c[i], 256, 0, 256, 256), F.crop(im_c[i], 256, 256, 256, 256)])
  labels = [F.crop(label, 0, 0, 256, 256), F.crop(label, 0, 256, 256, 256),
            F.crop(label, 256, 0, 256, 256), F.crop(label, 256, 256, 256, 256)]

  ims = []
  for i in range(4):
    temp = []
    for j in range(bands):
      temp.append(transforms.ToTensor()(im_cs[j][i]).squeeze())
    ims.append(torch.stack(temp))
      
  
  ims = [norm(im) for im in ims]
  ims = torch.stack(ims)
  
  labels = [(transforms.ToTensor()(label).squeeze()) for label in labels]
  labels = torch.stack(labels)
  
  
  if torch.sum(labels.gt(.003) * labels.lt(.004)):
    labels *= 255
  labels = labels.round()
  
  return ims, labels

def save_images(image_tensor, label_tensor, count, i):
    # Predicted class per pixel (0 or 1), scaled to 0/255 for viewing
    img = torch.argmax(image_tensor, dim=0)
    img = img * 255
    img = img.cpu()
    img = img.numpy().astype('uint8')
    # Move the label to the CPU; astype() returns a copy, so the input tensor is not modified
    lbl = label_tensor.detach().cpu().numpy().astype('uint8')
    lbl[lbl==255] = 0  # draw ignored pixels as background
    lbl = lbl * 255
    
    img = Image.fromarray(img)
    lbl = Image.fromarray(lbl)
    # The predictions/ folder must already exist
    img.save('predictions/pred_{}_{}.png'.format(count, i))
    lbl.save('predictions/label_{}_{}.png'.format(count, i))
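
The augmentation in `processAndAugment` only makes sense if every band and the label receive exactly the same crop and flip; otherwise pixels and labels no longer line up. The following is a minimal NumPy sketch of that idea, not the PIL/torchvision code used in this notebook. The function name `paired_crop_flip` and its arguments are hypothetical and chosen only for illustration.

```python
import numpy as np

def paired_crop_flip(image, label, size=256, rng=None):
    """Apply one random crop and one random horizontal flip to image and label.

    image: array of shape (bands, H, W); label: array of shape (H, W).
    """
    rng = np.random.default_rng() if rng is None else rng
    _, height, width = image.shape
    # Sample the crop offsets once so every band and the label share them.
    top = int(rng.integers(0, height - size + 1))
    left = int(rng.integers(0, width - size + 1))
    image = image[:, top:top + size, left:left + size]
    label = label[top:top + size, left:left + size]
    if rng.random() > 0.5:
        # Flip the width axis of both arrays together.
        image = image[:, :, ::-1]
        label = label[:, ::-1]
    return image.copy(), label.copy()

# Toy check: the label is derived from band 0, so alignment is easy to verify.
rng = np.random.default_rng(0)
image = rng.random((2, 512, 512)).astype(np.float32)
label = (image[0] > 0.5).astype(np.float32)
crop_im, crop_lb = paired_crop_flip(image, label, rng=rng)
```

Because the toy label is a threshold of band 0, recomputing the threshold on the cropped image should reproduce the cropped label whenever the two stay aligned.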

Load *flood water* train, test, and validation data from splits. In this example, this is the data we will use to train our model.

In [9]:
from time import time
import csv
import os
import numpy as np
import rasterio

def getArrFlood(fname):
  # Read all bands and close the file handle when done
  with rasterio.open(fname) as src:
    return src.read()

def download_flood_water_data_from_list(l, source, select_bands): #TODO band selector

  i = 0
  tot_nan = 0
  tot_good = 0
  flood_data = []
  for (im_fname, mask_fname) in l:
    if not os.path.exists(os.path.join("data/", im_fname)):
      print('No data for ', im_fname)
      continue

    temp_x = getArrFlood(os.path.join("data/", im_fname))
    #TODO band selector slice applicable bands here
    if source == 'S1':
      arr_x = np.nan_to_num(temp_x)
    else:
      arr_x = np.nan_to_num(temp_x)[select_bands,:,:]
    arr_y = getArrFlood(os.path.join("data/", mask_fname))
    arr_y[arr_y == -1] = 255 
    
    if source == 'S1':
      arr_x = np.clip(arr_x, -50, 1)
      arr_x = (arr_x + 50) / 51
    else:
      arr_x = arr_x / 10000
      
    if i % 100 == 0:
      print(im_fname, mask_fname)
    i += 1
    flood_data.append((arr_x,arr_y))
  #print(flood_data)
  return flood_data

# Isaac note: Change the fname to be the path to the weakly labeled csv (S1_Weak_data_Otsu.csv or S2_Index_Label_Weak.csv)
def load_flood_train_data(input_root, label_root, source='S1', select_bands=(0,1,2)):
  fname = "splits/flood_handlabeled/flood_train_data.csv"
  training_files = []
  with open(fname) as f:
    for line in csv.reader(f):
      training_files.append(tuple((input_root+line[0], label_root+line[1])))

  return download_flood_water_data_from_list(training_files, source, select_bands)

def load_flood_valid_data(input_root, label_root, source='S1', select_bands=(0,1,2)):
  fname = "splits/flood_handlabeled/flood_valid_data.csv"
  validation_files = []
  with open(fname) as f:
    for line in csv.reader(f):
      validation_files.append(tuple((input_root+line[0], label_root+line[1])))

  return download_flood_water_data_from_list(validation_files, source, select_bands)

def load_flood_test_data(input_root, label_root, source='S1', select_bands=(0,1,2)):
  fname = "splits/flood_handlabeled/flood_test_data.csv"
  testing_files = []
  with open(fname) as f:
    for line in csv.reader(f):
      testing_files.append(tuple((input_root+line[0], label_root+line[1])))
  
  return download_flood_water_data_from_list(testing_files, source, select_bands)
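
For the `S1` source, the loader above clips each pixel to the range [-50, 1], rescales it linearly to [0, 1], and remaps the label value -1 to 255 so that the losses and metrics can ignore it. A small NumPy sketch of those two steps on toy values follows; `scale_s1` and `remap_nodata` are hypothetical helper names, not functions from this notebook.

```python
import numpy as np

def scale_s1(arr):
    """Clip values to [-50, 1] and rescale them linearly to [0, 1]."""
    arr = np.nan_to_num(arr)      # NaN pixels become 0 before clipping
    arr = np.clip(arr, -50, 1)
    return (arr + 50) / 51

def remap_nodata(mask, nodata=-1, ignore=255):
    """Replace the no-data label with the ignore value used downstream."""
    mask = mask.copy()
    mask[mask == nodata] = ignore
    return mask

sample = np.array([-80.0, -50.0, -24.5, 1.0, 10.0, np.nan])
scaled = scale_s1(sample)
labels = remap_nodata(np.array([-1, 0, 1], dtype=np.int16))
```

One consequence worth keeping in mind: because `nan_to_num` runs first, a NaN pixel is treated as 0 and lands near the top of the scaled range (50/51) rather than at 0.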

Load training data and validation data. Note that here, we have chosen to train and validate our model on flood data. However, you can simply replace the load function call with one of the options defined above to load a different dataset.

In [10]:
"""temp_train = load_flood_train_data('flood_events/HandLabeled/S1Hand/', 'flood_events/HandLabeled/LabelHand/', source='S1')
#print(type(temp_train))
print(np.shape(temp_train[0][0]))
print(len(temp_train))
list_train = []
list_test = []
for i in range(len(temp_train)):
    list_train.append(temp_train[i][0])
    list_test.append(temp_train[i][1])
train_array = np.array(list_train)
test_array = np.array(list_test)
print(train_array.shape)
print(test_array.shape)
print(test_array[0,0])"""

In [11]:
"""print(np.count_nonzero(test_array==2))"""


"""means = []
stds = []
for i in range(13):
    means.append(np.mean(train_array[:,i,:,:]))
    stds.append(np.std(train_array[:,i,:,:]))
print(means)
print(stds)"""
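
The commented-out loop above computes one mean and one standard deviation per band; these are presumably the statistics passed to `transforms.Normalize` earlier in the notebook. As a sketch, the same per-band values can be computed in a single reduction over every axis except the band axis. The helper name `band_stats` and the random toy array are assumptions for illustration.

```python
import numpy as np

def band_stats(chips):
    """Return per-band mean and std for an array shaped (N, bands, H, W)."""
    # Reduce over the chip, row, and column axes; keep the band axis.
    means = chips.mean(axis=(0, 2, 3))
    stds = chips.std(axis=(0, 2, 3))
    return means, stds

rng = np.random.default_rng(0)
chips = rng.random((4, 13, 8, 8))   # toy stand-in for train_array
means, stds = band_stats(chips)
```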




In [12]:
# Isaac note: change the train data arguments here to be the path to the weakly labeled data. The S1Weak folder contains the actual
# training data. S1OtsuLabelWeak and S2IndexLabelWeak are different label sets generated by different algorithms. Use whichever you want.
SOURCE = 'S1'
train_data = load_flood_train_data('flood_events/HandLabeled/S1Hand/', 'flood_events/HandLabeled/LabelHand/', source=SOURCE, select_bands=(np.arange(13)))
train_dataset = InMemoryDataset(train_data, processAndAugment, source=SOURCE, select_bands=(np.arange(13)))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, sampler=None,
                  batch_sampler=None, num_workers=0, collate_fn=None,
                  pin_memory=True, drop_last=False, timeout=0,
                  worker_init_fn=None)
train_iter = iter(train_loader)

valid_data = load_flood_valid_data('flood_events/HandLabeled/S1Hand/', 'flood_events/HandLabeled/LabelHand/', source=SOURCE, select_bands=(np.arange(13))) 
valid_dataset = InMemoryDataset(valid_data, processTestIm, source=SOURCE, select_bands=(np.arange(13)))
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=4, shuffle=True, sampler=None,
                  batch_sampler=None, num_workers=0, collate_fn=lambda x: (torch.cat([a[0] for a in x], 0), torch.cat([a[1] for a in x], 0)),
                  pin_memory=True, drop_last=False, timeout=0,
                  worker_init_fn=None)
valid_iter = iter(valid_loader)

flood_events/HandLabeled/S1Hand/Ghana_103272_S1Hand.tif flood_events/HandLabeled/LabelHand/Ghana_103272_LabelHand.tif
flood_events/HandLabeled/S1Hand/Pakistan_132143_S1Hand.tif flood_events/HandLabeled/LabelHand/Pakistan_132143_LabelHand.tif
flood_events/HandLabeled/S1Hand/Sri-Lanka_916628_S1Hand.tif flood_events/HandLabeled/LabelHand/Sri-Lanka_916628_LabelHand.tif
flood_events/HandLabeled/S1Hand/Ghana_5079_S1Hand.tif flood_events/HandLabeled/LabelHand/Ghana_5079_LabelHand.tif
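
The validation loader uses a custom `collate_fn` because `processTestIm` already returns four 256x256 tiles per chip, so the per-chip stacks have to be concatenated rather than stacked again. A NumPy sketch of that merge, with `np.concatenate` standing in for `torch.cat` and a hypothetical helper name `concat_tiles`:

```python
import numpy as np

def concat_tiles(batch):
    """Merge per-chip tile stacks into one flat batch of tiles."""
    images = np.concatenate([item[0] for item in batch], axis=0)
    labels = np.concatenate([item[1] for item in batch], axis=0)
    return images, labels

# Four chips, each already split into 4 tiles of 2 bands x 256 x 256.
batch = [(np.zeros((4, 2, 256, 256), dtype=np.float32),
          np.zeros((4, 256, 256), dtype=np.float32)) for _ in range(4)]
images, labels = concat_tiles(batch)
```

With `batch_size=4`, each validation step therefore feeds 16 tiles to the network.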



Define differentiable loss functions. Assessment metrics defined below use argmax, making them non-differentiable.

In [14]:
#https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook
import torch.nn as nn  # the loss classes below subclass nn.Module

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, output, target, smooth=1):
        if output.shape[1] == 1:
            output = fun.sigmoid(output)
        else:
            output = fun.softmax(output, dim=1)[:,1]

        #flatten label and prediction tensors
        output = output.flatten()
        target = target.flatten()
        no_ignore = target.ne(255).cuda()
        output = output.masked_select(no_ignore)
        target = target.masked_select(no_ignore)
        TP = torch.sum(output * target)
        return 1 - ((2. * TP + smooth) / (output.sum() + target.sum() + smooth))




In [15]:
# Dice Loss squared with optional gamma
class DiceLossSquared(nn.Module):
    def __init__(self, weight=None, size_average=True, gamma=1):
        super(DiceLossSquared, self).__init__()
        self.gamma = gamma

    def forward(self, output, target, smooth=1):
        if output.shape[1] == 1:
            output = fun.sigmoid(output)
        else:
            output = fun.softmax(output, dim=1)[:,1]

        #flatten label and prediction tensors
        output = output.flatten()
        target = target.flatten()
        no_ignore = target.ne(255).cuda()
        output = output.masked_select(no_ignore)
        target = target.masked_select(no_ignore)
        TP = torch.sum(output * target)
        return (1 - ((2. * TP + smooth) / ((output**2).sum() + (target**2).sum() + smooth)))**self.gamma




In [16]:
class IOU(nn.Module):
    def __init__(self, weight=None, size_average=True, smooth=1):
        super(IOU, self).__init__()
        self.smooth = smooth

    def forward(self, output, target):
        if output.shape[1] == 1:
            output = fun.sigmoid(output)
        else:
            output = fun.softmax(output, dim=1)[:,1]    
        
        #flatten label and prediction tensors
        output = output.flatten()
        target = target.flatten()
        no_ignore = target.ne(255).cuda()
        output = output.masked_select(no_ignore)
        target = target.masked_select(no_ignore)
        intersection = torch.sum(output * target)
        union = torch.sum(target) + torch.sum(output) - intersection
        return 1 - ((intersection + self.smooth) / (union + self.smooth))

In [17]:
# Modification to Dice loss to correct class imbalance
class FocalTverskyLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, smooth=1, alpha = .5, beta = .5, gamma = 1):
        super(FocalTverskyLoss, self).__init__()
        self.smooth = smooth
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def forward(self, output, target):
        if output.shape[1] == 1:
            output = fun.sigmoid(output)
        else:
            output = fun.softmax(output, dim=1)[:,1]       
        
        #flatten label and prediction tensors
        output = output.flatten()
        target = target.flatten()
        no_ignore = target.ne(255).cuda()
        output = output.masked_select(no_ignore)
        target = target.masked_select(no_ignore)
        TP = torch.sum(output * target)
        FP = torch.sum((1 - target) * output)
        FN = torch.sum(target * (1 - output))
        return (1 - ((TP + self.smooth) / (TP + (self.alpha * FN) + (self.beta * FP) + self.smooth)))**self.gamma
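
The losses above share one recipe: drop pixels labeled 255, compute soft true positives, false positives, and false negatives from the predicted probabilities, and combine them. The NumPy sketch below mirrors the Dice and focal Tversky formulas on toy probabilities rather than network outputs; the function names are hypothetical. With `alpha = beta = 0.5`, `gamma = 1`, and no smoothing, the Tversky form reduces to Dice.

```python
import numpy as np

def soft_counts(prob, target, ignore=255):
    """Soft TP, FP, FN over pixels whose label is not the ignore value."""
    keep = target != ignore
    prob, target = prob[keep], target[keep]
    tp = np.sum(prob * target)
    fp = np.sum(prob * (1 - target))
    fn = np.sum((1 - prob) * target)
    return tp, fp, fn

def dice_loss(prob, target, smooth=1.0):
    tp, fp, fn = soft_counts(prob, target)
    # sum(prob) + sum(target) equals 2*TP + FP + FN for 0/1 targets.
    return 1 - (2 * tp + smooth) / (2 * tp + fp + fn + smooth)

def focal_tversky_loss(prob, target, alpha=0.5, beta=0.5, gamma=1.0, smooth=1.0):
    tp, fp, fn = soft_counts(prob, target)
    # alpha weights false negatives, beta weights false positives.
    return (1 - (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)) ** gamma

prob = np.array([0.9, 0.2, 0.7, 0.4, 0.6])
target = np.array([1.0, 0.0, 1.0, 255.0, 0.0])   # fourth pixel is ignored
dice = dice_loss(prob, target, smooth=0.0)
tversky = focal_tversky_loss(prob, target, smooth=0.0)
```

Raising `alpha` above `beta` penalizes missed water pixels more heavily than false alarms, which is one way to counter class imbalance.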

Define the network. We use a fully convolutional network (FCN) with a ResNet50 backbone. To test a different model architecture, optimizer, or loss function, replace them here.

In [18]:
import torch
import torchvision.models as models
import torch.nn as nn
from src import seg_models

net = models.segmentation.fcn_resnet50(pretrained=False, num_classes=2, pretrained_backbone=False)
#print('initial net:', net)
net.backbone.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)
#net = seg_models.MyUnet()
#net = seg_models.AttentionUNet()
#criterion = IOU()
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1,8]).float().cuda(), ignore_index=255) 
#criterion = DiceLoss()
#criterion = DiceLossSquared()
#criterion = FocalTverskyLoss()
def convertBNtoGN(module, num_groups=16):
  # Recursively replace every BatchNorm2d layer with a GroupNorm layer
  if isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d):
    mod = nn.GroupNorm(num_groups, module.num_features,
                       eps=module.eps, affine=module.affine)
    if module.affine:
        # Carry over the scale and shift parameters
        mod.weight.data = module.weight.data.clone().detach()
        mod.bias.data = module.bias.data.clone().detach()
    return mod

  for name, child in module.named_children():
      module.add_module(name, convertBNtoGN(child, num_groups=num_groups))

  return module

net = convertBNtoGN(net)

# Create the optimizer after the conversion so it tracks the GroupNorm parameters
optimizer = torch.optim.AdamW(net.parameters(),lr=LR)
#optimizer = torch.optim.SGD(net.parameters(), lr= MAX_LR, momentum=.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, len(train_loader) * 10, T_mult=2, eta_min=0, last_epoch=-1)
#scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=LR, max_lr=MAX_LR, step_size_up=128)
#print('modified net:', net)
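
The scheduler above is stepped once per batch with `T_0 = len(train_loader) * 10` and `T_mult=2`, so the first cosine cycle lasts 10 epochs and each later cycle is twice as long as the one before. The pure-Python sketch below reproduces the shape of that schedule as we understand it from the cosine warm-restart formula; `warm_restart_lr` is a hypothetical helper, and `t_0=100` assumes 10 batches per epoch purely for illustration.

```python
import math

def warm_restart_lr(step, base_lr, t_0, t_mult=2, eta_min=0.0):
    """Learning rate after `step` scheduler steps under cosine warm restarts."""
    t_i, t_cur = t_0, step
    # Walk through completed cycles; each cycle is t_mult times longer.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    cosine = (1 + math.cos(math.pi * t_cur / t_i)) / 2
    return eta_min + (base_lr - eta_min) * cosine

# Assumed numbers: 10 batches per epoch gives t_0 = 100 steps.
schedule = [warm_restart_lr(s, base_lr=5e-4, t_0=100) for s in range(300)]
```

In this toy schedule the rate decays from 5e-4 to nearly zero over steps 0 to 99, jumps back to 5e-4 at step 100, and then decays over the next 200 steps.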




Define assessment metrics. For our purposes, we use overall accuracy and mean intersection over union. However, we also include functions for calculating true positives, false positives, true negatives, and false negatives.

In [19]:
def computeIOU(output, target):
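  # A single-channel output holds logits for the positive class;
  # expand it to two channels so that argmax returns 1 where the logit is positive
  if output.shape[1] == 1:
    output = torch.cat((-output, output), dim=1)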

  output = torch.argmax(output, dim=1).flatten()  
  target = target.flatten()
  
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  intersection = torch.sum(output * target)
  union = torch.sum(target) + torch.sum(output) - intersection
  iou = (intersection + .0000001) / (union + .0000001)
  
  if iou != iou:
    print("failed, replacing with 0")
    iou = torch.tensor(0).float()
  
  return iou
  
def computeAccuracy(output, target):
  output = torch.argmax(output, dim=1).flatten() 
  target = target.flatten()
  
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  correct = torch.sum(output.eq(target))
  
  return correct.float() / len(target)

def truePositives(output, target):
  output = torch.argmax(output, dim=1).flatten() 
  target = target.flatten()
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  correct = torch.sum(output * target)
  
  return correct

def trueNegatives(output, target):
  output = torch.argmax(output, dim=1).flatten() 
  target = target.flatten()
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  output = (output == 0)
  target = (target == 0)
  correct = torch.sum(output * target)
  
  return correct

def falsePositives(output, target):
  output = torch.argmax(output, dim=1).flatten() 
  target = target.flatten()
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  output = (output == 1)
  target = (target == 0)
  correct = torch.sum(output * target)
  
  return correct

def falseNegatives(output, target):
  output = torch.argmax(output, dim=1).flatten() 
  target = target.flatten()
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  output = (output == 0)
  target = (target == 1)
  correct = torch.sum(output * target)
  
  return correct
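
All of the metrics above follow the same pattern: mask out pixels labeled 255, then count agreements between the predicted class and the label. A NumPy sketch on toy arrays shows how IOU and accuracy fall out of the four confusion counts; `confusion_counts` is a hypothetical helper, not part of this notebook.

```python
import numpy as np

def confusion_counts(pred, target, ignore=255):
    """Count TP, FP, TN, FN for binary masks, skipping ignored pixels."""
    keep = target != ignore
    pred, target = pred[keep], target[keep]
    tp = int(np.sum((pred == 1) & (target == 1)))
    fp = int(np.sum((pred == 1) & (target == 0)))
    tn = int(np.sum((pred == 0) & (target == 0)))
    fn = int(np.sum((pred == 0) & (target == 1)))
    return tp, fp, tn, fn

pred = np.array([1, 1, 0, 0, 1, 0])
target = np.array([1, 0, 0, 1, 255, 1])   # fifth pixel is ignored
tp, fp, tn, fn = confusion_counts(pred, target)
iou = tp / (tp + fp + fn)                  # intersection over union for water
accuracy = (tp + tn) / (tp + fp + tn + fn)
```

True negatives do not enter the IOU, which is why IOU is much stricter than accuracy when water covers only a small share of the pixels.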

Define training loop

In [20]:
training_losses = []
training_accuracies = []
training_ious = []

def train_loop(inputs, labels, net, optimizer, scheduler):
  global running_loss
  global running_iou
  global running_count
  global running_accuracy
  
  # zero the parameter gradients
  optimizer.zero_grad()
  net = net.cuda()
  
  # forward + backward + optimize
  outputs = net(inputs.cuda())
  loss = criterion(outputs["out"], labels.long().cuda())
  loss.backward()
  optimizer.step()
  print(scheduler.get_last_lr())
  scheduler.step()

  running_loss += loss.detach()  # detach so the autograd graph is not kept alive
  running_iou += computeIOU(outputs["out"], labels.cuda())
  running_accuracy += computeAccuracy(outputs["out"], labels.cuda())
  running_count += 1

Define validation loop

In [21]:
from time import time
valid_losses = []
valid_accuracies = []
valid_ious = []

def validation_loop(validation_data_loader, net):
  global running_loss
  global running_iou
  global running_count
  global running_accuracy
  global max_valid_iou

  global training_losses
  global training_accuracies
  global training_ious
  global valid_losses
  global valid_accuracies
  global valid_ious

  net = net.eval()
  net = net.cuda()
  count = 0
  iou = 0
  loss = 0
  accuracy = 0
  with torch.no_grad():
      for (images, labels) in validation_data_loader:
          net = net.cuda()
          outputs = net(images.cuda())
          valid_loss = criterion(outputs["out"], labels.long().cuda())
          valid_iou = computeIOU(outputs["out"], labels.cuda())
          valid_accuracy = computeAccuracy(outputs["out"], labels.cuda())
          iou += valid_iou
          loss += valid_loss
          accuracy += valid_accuracy
          count += 1

  iou = iou / count
  accuracy = accuracy / count

  if iou > max_valid_iou:
    max_valid_iou = iou
    save_path = os.path.join("checkpoints", "{}_{}_{}.cp".format(RUNNAME, i, iou.item()))
    optim_save_path = os.path.join("checkpoints", "{}_{}_{}_{}.cp".format(RUNNAME, i, iou.item(), 'optim'))
    scheduler_save_path = os.path.join("checkpoints", "{}_{}_{}_{}.cp".format(RUNNAME, i, iou.item(), 'sheduler'))
    torch.save(net.state_dict(), save_path)
    torch.save(optimizer.state_dict(), optim_save_path)
    torch.save(scheduler.state_dict(), scheduler_save_path)
    print("model saved at", save_path)
    

  loss = loss / count
  print("Training Loss:", running_loss / running_count)
  print("Training IOU:", running_iou / running_count)
  print("Training Accuracy:", running_accuracy / running_count)
  print("Validation Loss:", loss)
  print("Validation IOU:", iou)
  print("Validation Accuracy:", accuracy)


  """training_losses.append(running_loss / running_count)
  training_accuracies.append(running_accuracy / running_count)
  training_ious.append(running_iou / running_count)
  valid_losses.append(loss)
  valid_accuracies.append(accuracy)
  valid_ious.append(iou)"""
  training_losses.append(running_loss.detach().cpu() / running_count)
  training_accuracies.append(running_accuracy.detach().cpu() / running_count)
  training_ious.append(running_iou.detach().cpu() / running_count)
  valid_losses.append(loss.detach().cpu())
  valid_accuracies.append(accuracy.detach().cpu())
  valid_ious.append(iou.detach().cpu())
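
The validation loop encodes the run name, epoch index, and validation IOU in each model checkpoint filename. The sketch below shows one way to build such a name and to recover the fields from it later, for example to pick the best checkpoint in a folder. All three helpers are hypothetical, and they handle only the model files, not the `_optim` and `_sheduler` companions.

```python
import os

def checkpoint_path(run_name, epoch, iou, folder="checkpoints"):
    """Build a model checkpoint path in the run_epoch_iou.cp pattern."""
    return os.path.join(folder, "{}_{}_{}.cp".format(run_name, epoch, iou))

def parse_checkpoint(path):
    """Recover (run_name, epoch, iou) from a model checkpoint path."""
    stem = os.path.basename(path)[:-len(".cp")]
    # Split from the right so run names containing underscores survive.
    run_name, epoch, iou = stem.rsplit("_", 2)
    return run_name, int(epoch), float(iou)

def best_checkpoint(paths):
    """Pick the path with the highest validation IOU encoded in its name."""
    return max(paths, key=lambda p: parse_checkpoint(p)[2])

paths = [
    "checkpoints/Sen1Floods11_1_0.5941653251647949.cp",
    "checkpoints/Baseline_S1_Dice_26_0.5623193383216858.cp",
]
best = best_checkpoint(paths)
```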

Define testing loop (here, you can replace assessment metrics).

In [22]:
def test_loop(test_data_loader, net):
  net = net.eval()
  net = net.cuda()
  count = 0
  iou = 0
  loss = 0
  accuracy = 0
  with torch.no_grad():
      for (images, labels) in tqdm(test_data_loader):
          net = net.cuda()
          outputs = net(images.cuda())
          for i in range(outputs['out'].shape[0]):
            save_images(outputs['out'][i], labels[i], count, i)

          valid_loss = criterion(outputs["out"], labels.long().cuda())
          valid_iou = computeIOU(outputs["out"], labels.cuda())
          iou += valid_iou
          accuracy += computeAccuracy(outputs["out"], labels.cuda())
          count += 1

  iou = iou / count
  print("Test IOU:", iou)
  print("Test Accuracy:", accuracy / count)

Define training and validation scheme

In [23]:
from tqdm.notebook import tqdm
from IPython.display import clear_output

running_loss = 0
running_iou = 0
running_count = 0
running_accuracy = 0

training_losses = []
training_accuracies = []
training_ious = []
valid_losses = []
valid_accuracies = []
valid_ious = []


def train_epoch(net, optimizer, scheduler, train_iter):
  for (inputs, labels) in tqdm(train_iter):
    train_loop(inputs.cuda(), labels.cuda(), net.cuda(), optimizer, scheduler)
 

def train_validation_loop(net, optimizer, scheduler, train_loader,
                          valid_loader, num_epochs, cur_epoch):
  global running_loss
  global running_iou
  global running_count
  global running_accuracy
  net = net.train()
  running_loss = 0
  running_iou = 0
  running_count = 0
  running_accuracy = 0
  
  for i in tqdm(range(num_epochs)):
    train_iter = iter(train_loader)
    train_epoch(net, optimizer, scheduler, train_iter)
  clear_output()
  
  print("Current Epoch:", cur_epoch)
  validation_loop(iter(valid_loader), net)

Load checkpoints

In [24]:
"""device = torch.device('cuda')
save_path = 'checkpoints/Sen1Floods11_1_0.5941653251647949.cp'
optim_save_path = 'checkpoints/Sen1Floods11_1_0.5941653251647949_optim.cp'
scheduler_save_path = 'checkpoints/Sen1Floods11_1_0.5941653251647949_sheduler.cp'
net_checkpoint = torch.load(save_path)
optim_checkpoint = torch.load(optim_save_path)
schedule_checkpoint = torch.load(scheduler_save_path)
net.load_state_dict(net_checkpoint)
net.to(device=device)
optimizer.load_state_dict(optim_checkpoint)
scheduler.load_state_dict(schedule_checkpoint)"""

Train model and assess metrics over epochs

In [25]:
"""import os
from IPython.display import display
import matplotlib.pyplot as plt

max_valid_iou = 0
start = 0

epochs = []
training_losses = []
training_accuracies = []
training_ious = []
valid_losses = []
valid_accuracies = []
valid_ious = []

for i in range(start, EPOCHS):
  train_validation_loop(net, optimizer, scheduler, train_loader, valid_loader, 1, i)
  epochs.append(i)
  x = epochs
  plt.plot(x, training_losses, label='training losses')
  #plt.plot(x, training_accuracies, 'tab:orange', label='training accuracy')
  plt.plot(x, training_ious, 'tab:purple', label='training iou')
  plt.plot(x, valid_losses, label='valid losses')
  #plt.plot(x, valid_accuracies, 'tab:red',label='valid accuracy')
  plt.plot(x, valid_ious, 'tab:green',label='valid iou')
  plt.legend(loc="upper left")

  display(plt.show())

  print("max valid iou:", max_valid_iou)"""


Load test data

In [26]:
SOURCE = 'S1'
test_data = load_flood_test_data('flood_events/HandLabeled/S1Hand/', 'flood_events/HandLabeled/LabelHand/', source=SOURCE, select_bands=(np.arange(13))) 
test_dataset = InMemoryDataset(test_data, processTestIm, source=SOURCE, select_bands=(np.arange(13)))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, sampler=None,
                  batch_sampler=None, num_workers=0, collate_fn=lambda x: (torch.cat([a[0] for a in x], 0), torch.cat([a[1] for a in x], 0)),
                  pin_memory=True, drop_last=False, timeout=0,
                  worker_init_fn=None)
test_iter = iter(test_loader)


flood_events/HandLabeled/S1Hand/Ghana_313799_S1Hand.tif flood_events/HandLabeled/LabelHand/Ghana_313799_LabelHand.tif
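
At validation and test time, `processTestIm` resizes each chip to 512x512 and cuts it into four 256x256 tiles in the order top-left, top-right, bottom-left, bottom-right. The NumPy sketch below reproduces that tiling and its inverse, which can be handy for stitching per-tile predictions back into a full chip. Both helper names are hypothetical, and the sketch assumes a square input with an even side length.

```python
import numpy as np

def split_quadrants(image):
    """Split a (bands, 512, 512) array into four (bands, 256, 256) tiles."""
    half = image.shape[-1] // 2
    tiles = [image[..., r:r + half, c:c + half]
             for r in (0, half) for c in (0, half)]
    return np.stack(tiles)

def merge_quadrants(tiles):
    """Inverse of split_quadrants: reassemble the four tiles."""
    top = np.concatenate([tiles[0], tiles[1]], axis=-1)
    bottom = np.concatenate([tiles[2], tiles[3]], axis=-1)
    return np.concatenate([top, bottom], axis=-2)

image = np.arange(2 * 512 * 512, dtype=np.float32).reshape(2, 512, 512)
tiles = split_quadrants(image)
restored = merge_quadrants(tiles)
```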


Load a checkpoint and test the model.

In [27]:
device = torch.device('cuda')
#save_path = 'checkpoints/Sen1Floods11_62_0.6036016941070557.cp' #S2 weak
#save_path = 'checkpoints/Sen1Floods11_1347_0.6440591216087341.cp' 
#save_path = 'checkpoints/Default_53_0.5246185064315796.cp'
save_path = 'checkpoints/S1WeakOtsu_15_0.5288943648338318.cp'
#save_path = 'checkpoints/Baseline_S1_Dice_26_0.5623193383216858.cp'

net_checkpoint = torch.load(save_path)
net.load_state_dict(net_checkpoint)
net.to(device=device)
net.eval()

test_loop(test_loader, net)



Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  label = label.astype(np.float)


Test IOU: tensor(0.5018, device='cuda:0')
Test Accuracy: tensor(0.9344, device='cuda:0')
