In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import pdb

# supporting:

sys.path.insert(0,'..')
from config import global_config
from dataset import LaneNetDataset
from model import vgg_encoder
from model import fcn_decoder

Mainly loss calculation and inference function

In [69]:
class LaneNet:

    def __init__(self):
        self.encoder = vgg_encoder.VGGEncoder()
        self.decoder = fcn_decoder.FCNDecoder()
        self.conv1 = nn.Conv2d(64, 3, kernel_size=1, bias=False)  # pixembedding

        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
            normalize])

    def inference(self, src):
        decode_logits, decode_deconv  = self.run_model(src)
        
        binary_seg_ret = F.softmax(decode_logits)
        binary_seg_ret = np.argmax(binary_seg_ret, dim=1)
        
        pix_embedding = F.relu(self.conv1(decode_deconv))
        return (binary_seg_ret, pix_embedding)
    
    def run_model(self, src):
        src_tensor = self.preprocess(src)

        # have to check if batch or not
        if len(src_tensor) != 4:
            src_tensor = src_tensor.unsqueeze(0)
        
        # encode
        ret = self.encoder(src_tensor)
        # decode
        decode_logits, decode_deconv  = self.decoder(ret)
        return (decode_logits, decode_deconv)

    def compute_loss(self, src, binary, instance):

        decode_logits, decode_deconv  = self.run_model(src)

        # step 1:
        # calculate loss between binary and decode logits
        #
        # use softmax_cross_entropy
        binary_segmenatation_loss = torch.sum(- binary * F.log_softmax(decode_logits, -1), -1)
        binary_segmenatation_loss = binary_segmenatation_loss.mean()

        # step 2:
        # calculate discrimitive loss between deconv and instance
        # change deconv into pix_embedding

        # then calculate discrimitive loss
        pix_embedding = F.relu(self.conv1(decode_deconv))
        disc_loss, l_var, l_dist, l_reg = \
                lanenet_discriminative_loss.discriminative_loss(
                    pix_embedding, instance, 3, 0.5, 1.5, 1.0, 1.0, 0.001)
        
        total_loss = 0.7*binary_segmentation_loss + 0.3*disc_loss
        
        ret = {
            'total_loss': total_loss,
            'binary_seg_logits': decode_logits,
            'instance_seg_logits': pix_embedding,
            'binary_seg_loss': binary_segmenatation_loss,
            'discriminative_loss': disc_loss
        }
        
        return ret
        
    def discrimitive_loss(self, prediction, correct_label, feature_dim,
                        delta_v, delta_d, param_var, param_dist, param_reg):
        
        # saving list (maybe implement dynamic tensor?)
        output_ta_loss = []
        output_ta_var = []
        output_ta_dist = []
        output_ta_reg = []
        
        # for each batch calculate the loss
        i = 0
        while i < prediction.shape[0]:
            # calculate discrimitive loss for single image
            single_prediction = prediction[i]
            single_label = correct_lable[i]
            # pdb.set_trace()
            disc_loss, l_var, l_dist, l_reg = single_discrimitive_loss(
                single_prediction, single_label, feature_dim, delta_v, delta_d, param_var, param_dist, param_reg)
            
            output_ta_loss.append(disc_loss.unsqueeze(0))
            output_ta_va.append(l_var.unsqueeze(0))
            output_ta_dist.append(l_dist.unsqueeze(0))
            output_ta_reg.append(l_reg.unsqueeze(0))
            
            i += 1  # next image in batch
        
        out_loss_op = torch.cat(output_ta_loss)
        out_var_op = torch.cat(output_ta_var)
        out_dist_op = torch.cat(output_ta_dist)
        out_reg_op = torch.cat(output_ta_reg)
        
        # calculate mean of the batch
        disc_loss = out_loss_op.mean()
        l_var = out_var_op.mean()
        l_dist = out_vdist_op.mean()
        l_reg = out_reg_op.mean()

        return disc_loss, l_var, l_dist, l_reg
        
    def discriminative_loss_single(
            prediction,
            correct_label,
            feature_dim,
            delta_v,
            delta_d,
            param_var,
            param_dist,
            param_reg):
        """
        The example partition loss function mentioned in the paper equ(1)
        :param prediction: inference of network
        :param correct_label: instance label
        :param feature_dim: feature dimension of prediction
        :param delta_v: cutoff variance distance
        :param delta_d: curoff cluster distance
        :param param_var: weight for intra cluster variance
        :param param_dist: weight for inter cluster distances
        :param param_reg: weight regularization
        """

        # Make it a single line
        correct_label = correct_label.view([correct_label.shape[0] * correct_label.shape[1]]).float()
        reshaped_pred = prediction.view([feature_dim, prediction[0] * prediction[1]]).float()
        
        # Get unique labels
        unique_labels, unique_id = torch.unique(correct_label, sorted=True, return_inverse=True)
        ids, counts = np.unique(unique_id, return_counts=True)
        num_instances = len(counts)
        counts = torch.tensor(counts, dtype=torch.float32)
        
        # Calculate the pixel embedding mean vector
        segmented_sum = torch.zeros(feature_dim, num_instances).scatter_add(1, unique_id.repeat([feature_dim,1]), reshaped_pred)
        mu = torch.div(segmented_sum, counts)
        mu_expand = torch.gather(mu, 1, unique_id.repeat([feature_dim,1]))

        # Calculate loss(var)
        distance = (mu_expand - reshaped_pred).t().norm(dim=1)
        distance -= torch.tensor(delta_v, dtype=torch.float32)
        distance = torch.clamp(distance, min=0.)   # min is 0.
        distance = distance.pow(2)
        
        l_var = torch.zeros(num_instances).scatter_add(0, unique_id, distance)
        l_var = torch.div(l_var, counts)
        l_var = l_var.sum()
        l_var = torch.div(l_var, num_instances)  # single value 
   
        # Calculate the loss(dist) of the formula
        for i in range(feature_dim):
            for j in range(feature_dim):
                if i != j:
                    diff = mu[i] - mu[j]
                    mu_diff.append(diff.unsqueeze(0))
                    
        mu_diff = torch.cat(mu_diff)
        
        mu_norm = mu_diff.norm(dim=1)
        mu_norm = (2. * delta_d - mu_norm)
        mu_norm = torch.clamp(mu_norm, min=0.)
        mu_norm = mu_norm.pow(2)
        
        l_dist = mu_norm.mean()
        
        # Calculate the regular term loss mentioned in the original Discriminative Loss paper
        l_reg = mu.norm(dim=1).mean()

        # Consolidation losses are combined according to the parameters mentioned in the original Discriminative Loss paper
        param_scale = 1.
        l_var = param_var * l_var
        l_dist = param_dist * l_dist
        l_reg = param_reg * l_reg

        loss = param_scale * (l_var + l_dist + l_reg)

        return loss, l_var, l_dist, l_reg


In [2]:
# checking 3/4d matrix and 
v = torch.randn(3, 4, 5)
v_ = v.unsqueeze(0)
# v_.shape

l = torch.randn(4, 3, 5, 5)
l.shape[0]
# l[2].shape

var = []
for i in range(l.shape[0]):
    l_ = l[0].unsqueeze(0)
    var.append(l_)
    
la = torch.cat(var)

In [3]:
la.shape

torch.Size([4, 3, 5, 5])

In [4]:
print(la[0])
la[0].view([3*5*5])

tensor([[[ 1.7211, -2.7096,  0.3037,  0.2904, -1.1449],
         [-1.9282, -1.0326, -1.1341,  0.1440,  0.7094],
         [ 2.1354,  1.4281, -0.4329, -0.7835, -0.7612],
         [ 1.3064, -0.4989, -0.0346,  1.5963,  1.4391],
         [ 0.1881,  0.0033,  0.5158, -0.1344, -2.5246]],

        [[-0.0331,  1.0059, -0.5271, -0.0486, -0.4435],
         [ 0.7803,  0.2696, -1.0937, -0.2011, -1.2671],
         [-0.0911,  0.1432, -0.7650, -0.6285,  1.1247],
         [ 0.5227,  0.2437, -0.4807,  1.6500,  0.6388],
         [ 1.3636, -0.7209, -0.5638,  1.5149,  0.0320]],

        [[-0.2427, -0.1569, -0.4877,  1.5384, -0.1461],
         [ 0.0479,  1.1628,  1.0367,  0.8321,  0.9458],
         [-0.2865,  0.7631,  0.6680,  1.0855, -0.9787],
         [-0.2787, -0.1016, -0.3290, -1.0247,  0.4326],
         [ 0.2514,  0.8999, -0.0997,  1.3971,  0.2803]]])


tensor([ 1.7211, -2.7096,  0.3037,  0.2904, -1.1449, -1.9282, -1.0326, -1.1341,
         0.1440,  0.7094,  2.1354,  1.4281, -0.4329, -0.7835, -0.7612,  1.3064,
        -0.4989, -0.0346,  1.5963,  1.4391,  0.1881,  0.0033,  0.5158, -0.1344,
        -2.5246, -0.0331,  1.0059, -0.5271, -0.0486, -0.4435,  0.7803,  0.2696,
        -1.0937, -0.2011, -1.2671, -0.0911,  0.1432, -0.7650, -0.6285,  1.1247,
         0.5227,  0.2437, -0.4807,  1.6500,  0.6388,  1.3636, -0.7209, -0.5638,
         1.5149,  0.0320, -0.2427, -0.1569, -0.4877,  1.5384, -0.1461,  0.0479,
         1.1628,  1.0367,  0.8321,  0.9458, -0.2865,  0.7631,  0.6680,  1.0855,
        -0.9787, -0.2787, -0.1016, -0.3290, -1.0247,  0.4326,  0.2514,  0.8999,
        -0.0997,  1.3971,  0.2803])

In [5]:
h, j = torch.unique(l[0], sorted=True, return_inverse=True)

In [6]:
a = torch.tensor([10, 5, 3, 0, 4, 5, 4, 5, 5, 3, 2, 1])
a_ = torch.tensor([[9, 4, 3, 0, 3, 5, 3, 5, 6, 3, 2, 1], [8, 3, 3, 0, 3, 5, 3, 5, 7, 3, 1, 1]], dtype=torch.float32)

In [7]:
unique_labels, unique_id = torch.unique(a, sorted=True, return_inverse=True)

In [8]:
unique_labels

tensor([ 0,  1,  2,  3,  4,  5, 10])

In [9]:
unique_id

tensor([6, 5, 3, 0, 4, 5, 4, 5, 5, 3, 2, 1])

In [10]:
labels, counts = np.unique(unique_id, return_counts=True)

In [11]:
print(labels, counts)
len(counts)

[0 1 2 3 4 5 6] [1 1 1 2 2 4 1]


7

In [12]:
# segmented sum:

index = torch.tensor([[0, 0, 1, 1, 0, 1]])
data = torch.tensor([[5., 1., 7., 2., 3., 4.]])

torch.zeros(1, 2).scatter_add(1, index, data)

tensor([[ 9., 13.]])

In [13]:
print(unique_id.shape)
print(a_.shape)
print(len(counts))
print(unique_id.repeat([2,1]))
segmented_sum = torch.zeros(2, len(counts)).scatter_add(1, unique_id.repeat([2,1]), a_)

torch.Size([12])
torch.Size([2, 12])
7
tensor([[6, 5, 3, 0, 4, 5, 4, 5, 5, 3, 2, 1],
        [6, 5, 3, 0, 4, 5, 4, 5, 5, 3, 2, 1]])


In [14]:
a_count = torch.tensor(counts, dtype=torch.float32)
# a_count = a_count.view(-1, 1)
print('seg sum', segmented_sum)
print('counts ', a_count)
mu = torch.div(segmented_sum, a_count)
print('normalized ', mu)

seg sum tensor([[ 0.,  1.,  2.,  6.,  6., 20.,  9.],
        [ 0.,  1.,  1.,  6.,  6., 20.,  8.]])
counts  tensor([1., 1., 1., 2., 2., 4., 1.])
normalized  tensor([[0., 1., 2., 3., 3., 5., 9.],
        [0., 1., 1., 3., 3., 5., 8.]])


In [15]:
print('ids ', unique_id)
print('mu ', mu)
mu_expand = torch.gather(mu, 1, unique_id.repeat([2,1]))
mu_expand

ids  tensor([6, 5, 3, 0, 4, 5, 4, 5, 5, 3, 2, 1])
mu  tensor([[0., 1., 2., 3., 3., 5., 9.],
        [0., 1., 1., 3., 3., 5., 8.]])


tensor([[9., 5., 3., 0., 3., 5., 3., 5., 5., 3., 2., 1.],
        [8., 5., 3., 0., 3., 5., 3., 5., 5., 3., 1., 1.]])

In [32]:
print(mu_expand-a_)
distance = (mu_expand - a_).norm(dim=0)
# distance = torch.norm(mu_expand - a_, 2, -1)
distance

tensor([[ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.],
        [ 0.,  2.,  0.,  0.,  0.,  0.,  0.,  0., -2.,  0.,  0.,  0.]])


tensor([0.0000, 2.2361, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.2361,
        0.0000, 0.0000, 0.0000])

In [33]:
delta_v = 0.03
distance -= torch.tensor(delta_v, dtype=torch.float32)
distance

tensor([-0.0300,  2.2061, -0.0300, -0.0300, -0.0300, -0.0300, -0.0300, -0.0300,
         2.2061, -0.0300, -0.0300, -0.0300])

In [34]:
distance = torch.clamp(distance, min=0.)
distance 

tensor([0.0000, 2.2061, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.2061,
        0.0000, 0.0000, 0.0000])

In [35]:
distance = distance.pow(2)
distance

tensor([0.0000, 4.8667, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 4.8667,
        0.0000, 0.0000, 0.0000])

In [36]:
l_var = torch.zeros(len(counts)).scatter_add(0, unique_id, distance)
l_var

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 9.7335, 0.0000])

In [37]:
l_var = torch.div(l_var, torch.tensor(counts, dtype=torch.float32))
l_var

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.4334, 0.0000])

In [38]:
l_var = l_var.sum()
l_var = torch.div(l_var, len(counts))
l_var

tensor(0.3476)

In [50]:
feature_dim = 2
num_instances = len(counts)
mu_diff = []

# think of a better way to do this:
for i in range(feature_dim):
    for j in range(feature_dim):
        if i != j:
            diff = mu[i] - mu[j]
            mu_diff.append(diff.unsqueeze(0))
            
mu_diff = torch.cat(mu_diff)
mu_diff

tensor([[ 0.,  0.,  1.,  0.,  0.,  0.,  1.],
        [ 0.,  0., -1.,  0.,  0.,  0., -1.]])

In [64]:
mu_norm = mu_diff.norm(dim=1)
mu_norm

tensor([1.4142, 1.4142])

In [65]:
delta_d = 3
mu_norm = (2. * delta_d - mu_norm)
mu_norm

tensor([4.5858, 4.5858])

In [66]:
mu_norm = torch.clamp(mu_norm, min=0.)
mu_norm = mu_norm.pow(2)
mu_norm

tensor([21.0294, 21.0294])

In [67]:
l_dist = mu_norm.mean()
l_dist

tensor(21.0294)

In [68]:
print(mu.norm(dim=1))
print(mu.norm(dim=1).mean())

l_reg = mu.norm(dim=1).mean()
l_reg

tensor([11.3578, 10.4403])
tensor(10.8991)


tensor(10.8991)