In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

import os
import pathlib

In [282]:
def binary_focal_loss_with_logits(truth, logits,
                                  gamma = 0, alpha = None):
    """
    Implementation of binary focal loss function 
    
    truth:
        Ground truth confidence, i.e. 1 for close anchors, 0 for anchors
        that are too far off and -1 for anchors to be ignored. Must have
        shape (?, fh, fw, k)
        
    logits:
        PPN confidence output, must have shape (?, fh, fw, k).
        
    gamma: 
        As defined in mathematical equation
    
    alpha:
        As defined in mathematical equation
    """
    
    # Using sigmoid not softmax as is typical for binary classifiers
    x = nn.functional.sigmoid(logits)
    
    
    
    

In [285]:
class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
            
    def forward():
        return
        

In [313]:
def reg_loss(reg_logits,reg_gt,conf_gt):
    """
    Function to calculate the regression part of 
    loss function. Masks all elements of ground 
    truth with confidence not equal to 1 then 
    calculates mean squared error.
        
    reg_logits: 
        Predicted normalised regression for each
        origin (?, fh, fw, 2)
    
    reg_gt: 
        Regression truth for each origin (?, fh, fw, 2)
    
    conf_gt:
        Confidence ground truth for each origin
        (0 for no galaxy close enough or 1 for 
        galaxy close enough) (?, fh, fw, 1)
    """
    regLoss = nn.MSELoss(reduction="mean")
    pos_mask = torch.eq(conf_gt,1)
    
    batch_size = conf_gt.shape[0]
    dim = conf_gt.shape[2]

    return regLoss(reg_logits[pos_mask.expand(batch_size,dim,dim,2)],
            reg_gt[pos_mask.expand(batch_size,dim,dim,2)])

In [None]:
# Testing reg_loss
dim_ = 2
batch_ = 3
x = torch.rand(batch,dim,dim,2)
y = torch.rand(batch,dim,dim,2)
conf_gt = torch.randint(low=0,high=2,size=(batch,dim,dim,1))
pos_mask_2 = torch.eq(conf_gt,1)

In [None]:
print(regLoss(x[pos_mask_2.expand(batch,dim,dim,2)],y[pos_mask_2.expand(batch,dim,dim,2)]))
print(reg_loss(y,x,conf_gt))

In [305]:
test_config = {
    "image_size": 224,
    "feature_size": 56,
    "r_far": np.sqrt(0.5*0.5 + 0.5*0.5),
    "r_near": np.sqrt(0.5*0.5 + 0.5*0.5),
    "N_conf": 1/128.0,
    "N_reg": 1/128.0,
    "Conf_loss_fn": 'crossentropy'
}

# N_reg = N_conf = 1/batch_size


In [306]:
def total_loss(reg_logits,reg_gt,conf_logits,conf_gt,config):
    """
    Function to calculate the total loss function.
    Combines the regressional loss and confidence loss.
    
    reg_gt: 
        Regression truth for each origin (?, fh, fw, 2)
        
    reg_logits: 
        Predicted normalised regression for each
        origin (?, fh, fw, 2)
    
    conf_gt:
        Confidence ground truth for each origin
        (0 for no galaxy close enough or 1 for 
        galaxy close enough) (?, fh, fw, 1)
    
    conf_logits:
        Confidence prediction for each origin. Note 
        this is before sigmoid applied. (?, fh, fw, 1)
        
    config:
        Config file
    """
    
    # Calculate regression part of loss function
    regression_loss = reg_loss(reg_logits,reg_gt,conf_gt)
    
    # Define which points to include in confidence mask
    conf_mask = torch.ne(conf_gt,-1)
    
    if config['Conf_loss_fn'] == 'CrossEntropy':
        CrossEntropy = nn.CrossEntropyLoss(weight=None, 
                        size_average=None, ignore_index=- 1, 
                        reduce=None, reduction='mean',
                        label_smoothing=0.0
                        )
        confidence_loss = CrossEntropy(conf_logits[conf_mask],
                                       conf_gt[conf_mask])
        
    else:
        return "Need to implement focal loss if necessary"
    
    N_conf, N_reg = config['N_conf'], config['N_reg']
    
    return (N_conf * confidence_loss + N_reg * regression_loss)




In [356]:
# Testing reg_loss
dim_ = 2
batch_ = 3
x = torch.rand(batch,dim,dim,2)
y = torch.rand(batch,dim,dim,2)
conf_gt = torch.randint(low=0,high=2,size=(batch,dim,dim,1),dtype=torch.float32)
conf_logits = torch.randint(low=0,high=2,size=(batch,dim,dim,1),dtype=torch.float32)
pos_mask_2 = torch.eq(conf_gt,1)

CE = nn.CrossEntropyLoss(weight=None, 
                    size_average=None, ignore_index=- 1, 
                    reduce=None, reduction='mean',
                    label_smoothing=0.0)

print(CE(conf_logits,conf_gt))
print(total_loss(x,y,conf_logits,conf_gt, test_config)*128)

tensor(0.8133)
tensor(0.0990)
tensor(0.8133)
tensor(0.9122)


In [1]:
############################## FULL MODEL FUNCTIONS

In [72]:
test_config = {
    "image_size": 224,
    "feature_size": 7,
    "r_far": np.sqrt(0.5*0.5 + 0.5*0.5),
    "r_near": np.sqrt(0.5*0.5 + 0.5*0.5),
    "N_conf": 1/1.0,
    "N_reg": 1/1.0,
    "Conf_loss_fn": 'CrossEntropy',
    "layers": [8,4,2,1],
    "batch_size": 1
}

def reg_loss(reg_logits,reg_gt,conf_gt):
    """
    Function to calculate the regression part of 
    loss function. Masks all elements of ground 
    truth with confidence not equal to 1 then 
    calculates mean squared error.
        
    reg_logits: 
        Predicted normalised regression for each
        origin (?, 2, fh, fw)
        
    reg_gt: 
        Regression truth for each origin (?, 2, fh, fw)
    
    conf_gt:
        Confidence ground truth for each origin
        (0 for no galaxy close enough or 1 for 
        galaxy close enough) (?, 1, fh, fw)
    """
    regLoss = nn.MSELoss(reduction="mean")
    pos_mask = torch.eq(conf_gt.expand(-1,2,-1,-1),1)
    #print(pos_mask)
    #print(reg_logits[pos_mask] == reg_gt[pos_mask])
    #print(reg_gt[pos_mask])
    return regLoss(reg_logits[pos_mask],reg_gt[pos_mask])


def total_loss(reg_logits,reg_gt,conf_logits,conf_gt,config):
    """
    Function to calculate the total loss function.
    
    Combines the regressional loss and confidence loss.
    
    Masks all confidence values equal to -1 then calculates
    confidence loss with BinaryCrossEntropy with logits. 
    
    If needed can add focal loss however will see how 
    training goes to start with.
    
    reg_gt: 
        Regression truth for each origin (?, 2, fh, fw)
        
    reg_logits: 
        Predicted normalised regression for each
        origin (?, 2, fh, fw)
    
    conf_gt:
        Confidence ground truth for each origin
        (0 for no galaxy close enough or 1 for 
        galaxy close enough) (?, 1, fh, fw)
    
    conf_logits:
        Confidence prediction for each origin. Note 
        this is before sigmoid applied. (?, 1, fh, fw)
        
    config:
        Config file
    """
    
    # Calculate regression part of loss function
    regression_loss = reg_loss(reg_logits,reg_gt,conf_gt)
    print(regression_loss)
    
    # Define which points to include in confidence mask
    conf_mask = torch.ne(conf_gt,-1)
    print(conf_mask)
    
    if config['Conf_loss_fn'] == 'CrossEntropy':
        
        #CrossEntropy = nn.CrossEntropyLoss(weight=None, 
        #                size_average=None, ignore_index=- 1, 
        #                reduce=None, reduction='mean',
        #                label_smoothing=0.0
        #                )
        #confidence_loss = CrossEntropy(conf_logits[conf_mask],
        #                               conf_gt[conf_mask])
        
        # need to use withlogitsloss in actually model as haven't sigmoided yet
        # but the ground truth has been sigmoided so need to just use BCELoss
        #WithLogitsLoss
        BCE = torch.nn.BCELoss(weight=None, size_average=None, 
                                         reduce=None, reduction='mean',
                                         )
         
        confidence_loss = BCE(conf_logits[conf_mask],conf_gt[conf_mask])
        
    else:
        return "Need to implement focal loss if necessary"
    
    print(confidence_loss)
    
    N_conf, N_reg = config['N_conf'], config['N_reg']
    
    return (N_conf * confidence_loss + N_reg * regression_loss)



dir_path = "/Users/edroberts/Desktop/im_gen/training_data/train"
os.chdir(dir_path)

images = np.load('training_images.npy')
test_im[0] = images[0]
test_im = np.zeros((1,1,224,224))
test_conf = np.zeros((2,1,7,7))
test_reg = np.zeros((2,2,7,7))
anchors  = np.load("anchors.npy")

os.chdir(dir_path+"/anchor_labels")
test_conf[0] = np.load("test_0000_conf.npy")
test_reg[0] = np.load("test_0000_reg.npy")
test_conf[1] = np.load("test_0001_conf.npy")
test_reg[1] = np.load("test_0001_reg.npy")

print(test_conf == test_conf)
ttest_conf = torch.tensor(test_conf)
ttest_reg = torch.tensor(test_reg)
#print(test_conf.shape)
#print(test_reg.shape)

#edited_ttest_conf = torch.clone(ttest_conf)
#print(edited_ttest_conf)
#edited_ttest_conf[0,0,0,0] = 1
#print(edited_ttest_conf)

#edited_ttest_reg = torch.clone(ttest_reg)
##print(edited_ttest_reg)
#edited_ttest_reg[0,0,0,1] = -1.3125
#print(edited_ttest_reg)

#print(reg_loss(edited_ttest_reg,ttest_reg,ttest_conf))

print(total_loss(ttest_reg,ttest_reg,
                 ttest_conf,ttest_conf,
                 test_config))

[[[[ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]]]


 [[[ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]
   [ True  True  True  True  True  True  True]]]]
tensor(0., dtype=torch.float64)
tensor([[[[True, True, True, True, True, True, True],
          [True, True, True, True, True, True, True],
          [True, True, True, True, True, True, True],
          [True, True, True, True, True, True, True],
          [True, True, True, True, True, True, True],
          [True, True, True, True