# Training

In [0]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

In [0]:
device = torch.device("cuda")
np.random.seed(seed=33)

if True:
  deeplab_resnet101 = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)

normalize3 = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
normalize1 = torchvision.transforms.Normalize(mean=[0.485], std=[0.229])
transforms3 = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize3])
transforms1 = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize1])

In [0]:
rgb_data = Image.open('/content/gdrive/My Drive/Potsdam/top_potsdam_2_10_RGB.tif')
elevation_data = Image.open('/content/gdrive/My Drive/Potsdam/dsm_potsdam_02_10_normalized_lastools.jpg')
label_data = Image.open('/content/gdrive/My Drive/Potsdam/top_potsdam_2_10_label.tif')

In [0]:
def transmute_to_classes(window):
  # 2 = tree
  # 3 = ground
  # 4 = clutter
  # 1 = building
  # 5 = car
  # 0 = everything else
  retval = (4*(window[:, :, 0]/0xff).astype(np.long) + 2*(window[:, :, 1]/0xff).astype(np.long) + 1*(window[:, :, 2]/0xff).astype(np.long))
  cars = (retval == 6)
  not_cars = (retval != 6)
  retval = (retval * not_cars) + 5*cars
  retval = retval * (retval < 6)
  return retval

def random_potsdam_training_window(rgb_data, elevation_data, label_data):
  size = 224
  x = np.random.randint(0, 6000 - size)
  y = np.random.randint(0, 6000 - size)
  box = (x, y, x + size, y + size)
  rgb_window = rgb_data.crop(box)
  elevation_window = elevation_data.crop(box)
  labels_window = np.array(label_data.crop(box))
  labels_window = transmute_to_classes(labels_window)
  return (rgb_window, elevation_window, labels_window)

def random_potsdam_training_batch(rgb_ar, elevation_ar, labels_ar):
  batch_size = 16
  
  rgbs = []
  elvs = []
  labs = []
  
  for i in range(batch_size):
    rgb, elv, lab = random_potsdam_training_window(rgb_ar, elevation_ar, labels_ar)

    rgbs.append(transforms3(rgb))
    elvs.append(transforms1(elv))
    labs.append(torch.unsqueeze(torch.from_numpy(lab), 0))

  rgbs = torch.stack(rgbs).to(device)
  elvs = torch.stack(elvs).to(device)
  labs = torch.cat(labs, dim=0).to(device)

  return (rgbs, elvs, labs)


In [0]:
# Reshape Network for 7 Classes

last_class = deeplab_resnet101.classifier[4] = torch.nn.Conv2d(256, 7, kernel_size=(1,1), stride=(1,1))
last_class_aux = deeplab_resnet101.aux_classifier[4] = torch.nn.Conv2d(256, 7, kernel_size=(1,1), stride=(1,1))

deeplab_resnet101 = deeplab_resnet101.to(device)

In [0]:
# Feature Extraction Only

if False:
  for p in deeplab_resnet101.parameters():
    p.requires_grad = False

  for p in last_class.parameters():
    p.requires_grad = True

  for p in last_class_aux.parameters():
    p.requires_grad = True


In [0]:
# Optimizer

ps = []
for n, p in deeplab_resnet101.named_parameters():
  if p.requires_grad == True:
    ps.append(p)
    
opt = torch.optim.SGD(ps, lr=0.01, momentum=0.9)

In [0]:
# Objective Function

obj = torch.nn.CrossEntropyLoss().to(device)

In [0]:
# Get a Batch

if False:
  import time

  before = time.time()
  batch_tensor = random_potsdam_training_batch(rgb_data, elevation_data, label_data)
  after = time.time()
  print(after - before)

In [0]:
# Train for One Step

if False:
  opt.zero_grad()
  pred = deeplab_resnet101(batch_tensor[0])
  loss = obj(pred.get('out'), batch_tensor[2])
  loss.backward()
  opt.step()

In [0]:
# Train

import time
import math

steps_per_epoch_per_image = int((6000 * 6000) / (224 * 224 * 16))
epochs = 3

deeplab_resnet101.train()
for i in range(epochs):
  for j in range(steps_per_epoch_per_image):
    batch_tensor = random_potsdam_training_batch(rgb_data, elevation_data, label_data)
    opt.zero_grad()
    pred = deeplab_resnet101(batch_tensor[0])
    loss = obj(pred.get('out'), batch_tensor[2]) + 0.4*obj(pred.get('aux'), batch_tensor[2])
    loss.backward()
    opt.step()
  print('epoch={} time={} loss={}'.format(i, time.time(), loss.item()))

# Eval

In [0]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
device = torch.device("cuda")
np.random.seed(seed=33)

if True:
  deeplab_resnet101 = torch.load('/content/gdrive/My Drive/Potsdam/deeplab_resnet101_ft.pth')

normalize3 = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
normalize1 = torchvision.transforms.Normalize(mean=[0.485], std=[0.229])
transforms3 = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize3])
transforms1 = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize1])

In [0]:
rgb_data = Image.open('/content/gdrive/My Drive/Potsdam/Validation/top_potsdam_2_12_RGB.tif')
elevation_data = Image.open('/content/gdrive/My Drive/Potsdam/Validation/dsm_potsdam_02_12_normalized_lastools.jpg')
label_data = Image.open('/content/gdrive/My Drive/Potsdam/Validation/top_potsdam_2_12_label.tif')

In [0]:
def transmute_to_classes(window):
  # 2 = tree
  # 3 = ground
  # 4 = clutter
  # 1 = building
  # 5 = car
  # 0 = everything else
  retval = (4*(window[:, :, 0]/0xff).astype(np.long) + 2*(window[:, :, 1]/0xff).astype(np.long) + 1*(window[:, :, 2]/0xff).astype(np.long))
  retval = retval * (retval < 6) # XXX error here
  return retval

def potsdam_eval_window(x, y, rgb_data, elevation_data, label_data):
  size = 1000
  x = int(x * size)
  y = int(y * size)
  box = (x, y, x + size, y + size)
  rgb_window = rgb_data.crop(box)
  elevation_window = elevation_data.crop(box)
  labels_window = np.array(label_data.crop(box))
  labels_window = transmute_to_classes(labels_window)
  return (rgb_window, elevation_window, labels_window)

def potsdam_eval_batch(x, y, rgb_ar, elevation_ar, labels_ar):
  batch_size = 1
  
  rgbs = []
  elvs = []
  labs = []
  
  for i in range(batch_size):
    rgb, elv, lab = potsdam_eval_window(x, y, rgb_ar, elevation_ar, labels_ar)

    rgbs.append(transforms3(rgb))
    elvs.append(transforms1(elv))
    labs.append(torch.unsqueeze(torch.from_numpy(lab), 0))

  rgbs = torch.stack(rgbs).to(device)
  elvs = torch.stack(elvs).to(device)
  labs = torch.cat(labs, dim=0).to(device)

  return (rgbs, elvs, labs)


In [0]:
# Get a Batch
if False:
  batch_tensor = potsdam_eval_batch(2.5, 2.5, rgb_data, elevation_data, label_data)


In [0]:
if False:
  deeplab_resnet101.eval()
  out = deeplab_resnet101(batch_tensor[0])
  out = out['out'].data.cpu().numpy()
  out.shape

In [0]:
if False:
  index = 0
  predicted_segments = np.apply_along_axis(np.argmax, 0, out[index])
  plt.imshow(predicted_segments)

In [0]:
if False:
  groundtruth_segments = batch_tensor[2].data.cpu().numpy()[index]
  plt.imshow(groundtruth_segments)

In [0]:
if False:
  plt.imshow(np.fabs(predicted_segments - groundtruth_segments))

In [0]:
if False:
  img = np.transpose((batch_tensor[0][index].cpu().numpy() * 255 + 255), (1, 2, 0)).astype(int)
  plt.imshow(img)

In [0]:
tps = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
fps = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
fns = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

deeplab_resnet101.eval()

for x in range(6):
  for y in range(6):
      batch_tensor = potsdam_eval_batch(x, y, rgb_data, elevation_data, label_data)
      out = deeplab_resnet101(batch_tensor[0])
      out = out['out'].data.cpu().numpy()
      index = 0
      predicted_segments = np.apply_along_axis(np.argmax, 0, out[index])
      groundtruth_segments = batch_tensor[2].data.cpu().numpy()[index]
      for c in range(6):
        tps[c] += ((predicted_segments == c) * (groundtruth_segments == c)).sum()
        fps[c] += ((predicted_segments == c) * (groundtruth_segments != c)).sum()
        fns[c] += ((predicted_segments != c) * (groundtruth_segments == c)).sum()

print('True Positives:  {}'.format(tps))
print('False Negatives: {}'.format(fps))
print('False Negatives: {}'.format(fns))

True Positives:  [6493420.0, 3122085.0, 5042536.0, 17133250.0, 99557.0, 0.0]
False Negatives: [554530.0, 162405.0, 924831.0, 2224644.0, 242742.0, 0.0]
False Negatives: [745670.0, 449560.0, 1590466.0, 1189350.0, 134106.0, 0.0]


In [0]:
recalls = []
precisions = []

for c in range(5): # XXX no cars
  recall = tps[c] / (tps[c] + fns[c])
  recalls.append(recall)
  precision = tps[c] / (tps[c] + fps[c])
  precisions.append(precision)

print('Recalls:   {}'.format(recalls))
print('Precision: {}'.format(precisions))

Recalls:   [0.8969939591854778, 0.8741308276718431, 0.7602192792946542, 0.9350883608221541, 0.4260708798568879]
Precision: [0.9213203839414298, 0.9505539672825918, 0.845018581897175, 0.8850782011720904, 0.29084805973724726]


In [0]:
names = [
    'other         ',
    'building      ',
    'tree          ',
    'low vegitation',
    'clutter       ',
    'car           '
]
f1s = []

for c in range(5): # XXX no cars
  f1 = 2 * (precisions[c] * recalls[c]) / (precisions[c] + recalls[c])
  f1s.append(f1)
  print('{} {}'.format(names[c], f1))


0.9089944453154747 everything else
0.910741985098018 building
0.8003791000088966 tree
0.909396251545959 ground
0.3457068348259087 clutter


In [0]:
precision = np.array(tps).sum() / (np.array(tps).sum() + np.array(fps).sum())
recall = np.array(tps).sum() / (np.array(tps).sum() + np.array(fns).sum())
f1 = 2 * (precision * recall) / (precision + recall)

print('Overall Precision: {}'.format(precision))
print('Overall Recall:    {}'.format(recall))
print('Overall f1:        {}'.format(f1))

Overall Precision: 0.8858568888888889
Overall Recall:    0.8858568888888889
Overall f1:        0.8858568888888889


In [0]:
np.array(fps).sum()

In [0]:
np.array(fns).sum()