In [1]:
import sys
from collections import Counter

### Numpy
import numpy as np

### Skimage
from skimage import segmentation

### matplotlib
import matplotlib.pyplot as plt

### PIL
from PIL import Image

### Pandas
import pandas as pd

In [2]:
### Torch Imports
import torch
import torch.nn.init
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [3]:
# Display parameters
%pylab inline
pylab.rcParams['figure.figsize'] = (20, 12)

Populating the interactive namespace from numpy and matplotlib


In [4]:
# Flag to check CUDA compatible GPU
use_cuda = torch.cuda.is_available()

In [5]:
# Parameters
nChannel = 100 # Number of conv filters/output channels
maxIter = 50 # Number of iterations
minLabels = 5 # Number of clusters
lr = 0.1 # Learning rate
nConv = 3 # Number of conv layers
num_superpixels = 100000 # Number of superpixels
compactness = 10 # Compactness of superpixels
visualize = 0 # Visualization flag
input_img = "NAIP_minis/0_2400.tif" # Input image

mask_flag = 1 # For masking based on (k2) NAIP masks

In [6]:
# CNN model
class MyNet(nn.Module):
    
    def __init__(self, input_dim):
        
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, nChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(nChannel)
        self.conv2 = []
        self.bn2 = []
        
        for i in range(nConv-1):
            self.conv2.append(nn.Conv2d(nChannel, nChannel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(nChannel))
            
        self.conv3 = nn.Conv2d(nChannel, nChannel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(nChannel)

    def forward(self, x):
        
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn1(x)
        
        for i in range(nConv-1):
            x = self.conv2[i](x)
            x = F.relu(x)
            x = self.bn2[i](x)
            
        x = self.conv3(x)
        x = self.bn3(x)
        
        return x

In [7]:
# load image
im = Image.open(input_img)
im = numpy.array(im)
im = im[:, :, 1:4]
print (im.shape)

data = torch.from_numpy(np.array([im.transpose((2, 0, 1)).astype('float32')/255.]))

if use_cuda:
    data = data.cuda()
    
data = Variable(data)

(800, 800, 3)


In [8]:
# slic
labels = segmentation.slic(im, compactness=compactness, n_segments=num_superpixels, multichannel=True)
labels = labels.reshape(im.shape[0]*im.shape[1])
u_labels = np.unique(labels)
l_inds = []

for i in range(len(u_labels)):
    l_inds.append(np.where(labels == u_labels[i])[0])

In [None]:
# train
model = MyNet(data.size(1))

if use_cuda:
    
    model.cuda()
    
    for i in range(nConv-1):
        
        model.conv2[i].cuda()
        model.bn2[i].cuda()
        
model.train()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
label_colours = np.random.randint(255, size=(100,3))

for batch_idx in range(maxIter):
    
    # forward prop
    optimizer.zero_grad()
    output = model(data)[0]
    output = output.permute(1, 2, 0).contiguous().view(-1, nChannel)
    ignore, target = torch.max(output, 1)
    im_target = target.data.cpu().numpy()
    nLabels = len(np.unique(im_target))
    
    if visualize:
        
        im_target_rgb = np.array([label_colours[c%100] for c in im_target])
        im_target_rgb = im_target_rgb.reshape((im.shape[0], im.shape[1], 3)).astype(np.uint8)
        # cv2.imshow("output", im_target_rgb)
        # cv2.waitKey(10)
        # im_target_rgb = Image.fromarray(im_target_rgb)
        
        # if batch_idx%10 == 0:
            # name = "NAIP_trainoverCNN/" + str(batch_idx) + "_" + input_img[11:-4] + ".jpg"
            # im_target_rgb.save(name)
            
    # superpixel refinement
    for i in range(len(l_inds)):
        
        labels_per_sp = im_target[l_inds[i]]
        u_labels_per_sp = np.unique(labels_per_sp)
        hist = np.zeros(len(u_labels_per_sp))
        
        for j in range(len(hist)):
            
            hist[j] = len(np.where(labels_per_sp == u_labels_per_sp[j])[0])
            
        im_target[l_inds[i]] = u_labels_per_sp[np.argmax(hist)]
        
    target = torch.from_numpy(im_target)
    
    if use_cuda:
        
        target = target.cuda()
        
    target = Variable(target)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

    print (batch_idx, '/', maxIter, ':', nLabels, loss.data[0])
    
    if nLabels <= 10:
        
        im_target_rgb = np.array([label_colours[c] for c in im_target])
        im_target_rgb = im_target_rgb.reshape((im.shape[0], im.shape[1], 3)).astype(np.uint8)
        name = "NAIP_trainCNN/" + str(batch_idx) + "_" + input_img[11:-4] + "_test.jpg"
        im_target_rgb = Image.fromarray(im_target_rgb)
        im_target_rgb.save(name)
    
    if nLabels <= minLabels:
        
        print ("nLabels", nLabels, "reached minLabels", minLabels, ".")
        break 



0 / 50 : 73 tensor(3.7100)
1 / 50 : 74 tensor(3.5854)
2 / 50 : 74 tensor(3.4221)
3 / 50 : 72 tensor(3.2624)
4 / 50 : 72 tensor(3.1062)
5 / 50 : 68 tensor(2.9542)
6 / 50 : 70 tensor(2.7905)
7 / 50 : 68 tensor(2.6173)
8 / 50 : 65 tensor(2.4415)
9 / 50 : 62 tensor(2.2668)
10 / 50 : 60 tensor(2.0920)
11 / 50 : 55 tensor(1.9220)
12 / 50 : 49 tensor(1.7620)
13 / 50 : 45 tensor(1.6149)
14 / 50 : 35 tensor(1.4773)
15 / 50 : 32 tensor(1.3572)
16 / 50 : 27 tensor(1.2512)
17 / 50 : 26 tensor(1.1575)
18 / 50 : 23 tensor(1.0732)
19 / 50 : 19 tensor(0.9981)
20 / 50 : 16 tensor(0.9315)
21 / 50 : 16 tensor(0.8733)
22 / 50 : 14 tensor(0.8241)


In [None]:
output = model(data)[0]
print (output.shape)

In [None]:
output = output.permute(1, 2, 0).contiguous().view(-1, nChannel)
print (output.shape)

In [None]:
ignore, target = torch.max(output, 1)
print (ignore.shape)
print (target.shape)

In [None]:
im_target = target.data.cpu().numpy()
print (im_target.shape)

In [None]:
lbl = Counter(list(im_target.flatten()))
lbl = lbl.most_common()
lbl = [lbl[i][0] for i in range(len(lbl))]
lbl = np.array(lbl)
lbl

In [None]:
# ['green', 'blue', 'darkorange', 'deeppink', 'cyan', 'indigo', 'crimson', 'grey', 'white', 'lightsalmon', 'pink']
clr_lst = ['3cb44b', '0082c8', 'FF8C00', 'FF1493', '46f0f0', '4B0082', 'e6194b', '808080', 'FFFFFF', 'FFA07A', 'fabebe']
clr_lst = clr_lst[:len(lbl)]

In [None]:
# Convert hex color representations to RGB values
colors = []
for index in range(len(lbl)):
    rgb = list(int(clr_lst[index][i:i+2], 16) for i in (0, 2 ,4))
    colors.append(rgb)
    
color = np.array(colors)
print (color.shape)
color

In [None]:
# Create mask
im_target_rgb_org = np.array([color[np.where(lbl==c)[0][0]] for c in im_target])
im_target_rgb = np.copy(im_target_rgb_org)

maskName = input_img[11:-4] + "_mask.npy"
maskPath = "NAIP_masksnpy/" + str(maskName)
mask = np.load(maskPath)
idx_mask = np.array(np.where(mask!=mask_flag)[0])
im_target_rgb[idx_mask] = 0

im_target_rgb = im_target_rgb.reshape((im.shape[0], im.shape[1], 3)).astype(np.uint8)
im_target_rgb = Image.fromarray(im_target_rgb)

In [None]:
by_color = defaultdict(int)
for pixel in im_target_rgb.getdata():
     by_color[pixel] += 1
by_color

s = [(k, by_color[k]) for k in sorted(by_color, key=by_color.get, reverse=True)]
for k, v in s:
    print (k, v)

In [None]:
plt.imshow(im_target_rgb)