In [1]:
def train():
    '''
    Training loop for a single epoch including the attribution prior.

    For every batch: compute attributions for the model's predicted classes
    via APExp.shap_values, standardize them with per_x_standardization,
    penalize their total variation with pix_loss, and add that penalty
    (weighted by ``lamb``) to the ``criterion`` classification loss before
    the optimizer step.

    The epoch is deliberately cut short: it stops right after the batch with
    index ``len(train_dataloader) - 10`` (only when the loader has at least
    10 batches; otherwise every batch is processed).

    Relies on module-level globals: model, optimizer, criterion,
    train_dataloader, device, APExp, lamb, tv_weight.

    Output:
          (train_loss, classif_loss, attribution_loss) --> three lists with one
          float per processed batch: the total loss, the classification loss,
          and the lambda-weighted attribution loss (per-batch values, not
          averages)
    '''
    model.train()

    train_loss = []
    attribution_loss = []
    classif_loss = []

    # index of the last batch to process; hoisted because the loader length
    # does not change during the epoch
    last_batch = len(train_dataloader) - 10

    for batch_idx, (x, y) in enumerate(train_dataloader):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        # attributions are taken for the model's *predicted* class of each
        # sample (argmax over dim 1), not for the ground-truth label y
        _, sparse_labels = y_hat.max(dim=1)

        # Expected Gradients attributions with a single background reference,
        # e.g. x of shape (64, 1, 240) -> attributions of shape (64, 1, 240)
        attributions = APExp.shap_values(model, x, sparse_labels=sparse_labels)

        # standardize each sample's attributions along the feature axis
        normalized_attributions = per_x_standardization(attributions)

        # differentiable total-variation penalty that encourages smooth attributions
        var = pix_loss(normalized_attributions, tv_weight=tv_weight)
        # manual nanmean over the non-NaN entries (the original note says
        # torch.nanmean was unusable here); still NaN if every entry is NaN
        attribution_prior = var[~torch.isnan(var)].mean()

        # cross-entropy style loss for the prediction
        classification_loss = criterion(y_hat, y)

        # total loss = classification loss + lambda-weighted attribution prior
        weighted_prior = lamb * attribution_prior
        total_loss = classification_loss + weighted_prior

        # NOTE(review): retain_graph=True is kept from the original; a single
        # backward per batch normally does not need it — confirm APExp does not
        # reuse graph state across batches before removing it.
        total_loss.backward(retain_graph=True)
        optimizer.step()

        # collect per-batch losses
        train_loss.append(total_loss.item())
        attribution_loss.append(weighted_prior.item())
        classif_loss.append(classification_loss.item())

        # end the epoch early (original rationale: "so there is enough data
        # and program doesn't crash")
        if batch_idx == last_batch:
            break

    return train_loss, classif_loss, attribution_loss


def test():
    """
    Evaluate the model on ``test_dataloader``.

    Puts the model in eval mode, disables gradient tracking, and scores every
    test batch with ``criterion``.

    Relies on module-level globals: model, test_dataloader, device, criterion.

    Output:
          numpy array with one loss value per test batch (not an average)
    """
    model.eval()
    per_batch = []
    with torch.no_grad():
        for inputs, targets in test_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs)
            per_batch.append(criterion(predictions, targets).item())
    return np.array(per_batch)


def per_x_standardization(x):
    '''
    Standardize each spectrum's attributions to zero mean and unit variance.

    1D analogue of tf.image.per_image_standardization: every slice along the
    last axis is shifted by its own mean and divided by its own unbiased
    standard deviation. The standard deviation is floored at 1/features so a
    constant slice does not cause a division by zero. (Differences from the
    tf op: this uses the unbiased std and a 1/features floor.)

    Input:
          x --> (-1, 1, features) tensor; any leading dimensions are accepted
                and each slice along the last axis is standardized independently
    Output:
          normalized_attributions --> tensor with the same shape as x
    '''
    n_features = x.shape[-1]
    # keepdim=True keeps the statistics broadcastable against x, so no
    # squeeze/unsqueeze round-trip is needed (which also made batch size 1 and
    # multi-channel inputs fragile)
    mean = x.mean(dim=-1, keepdim=True)
    # the floor is applied as a scalar, so it follows x's device and dtype
    # instead of depending on a global `device`
    adjusted_stddev = x.std(dim=-1, unbiased=True, keepdim=True).clamp(min=1.0 / n_features)
    normalized_attributions = (x - mean) / adjusted_stddev

    return normalized_attributions


def pix_loss(nprof, tv_weight=5):
    """
    Weighted 1D total variation of each profile.

    1D counterpart of tf.image.total_variation: sums the absolute differences
    between adjacent entries along axis 2 and scales that sum by ``tv_weight``.

    Inputs:
    - nprof: Tensor of shape (batch, 1, features)
    - tv_weight: Scalar weight w_t applied to the total variation
      (defaults to 5; the original note suggests 1 for Expected Gradients)
    Output:
    - Tensor of shape (batch, 1) holding the weighted total variation of
      each spectrum
    """
    neighbour_steps = nprof[:, :, 1:] - nprof[:, :, :-1]
    total_variation = neighbour_steps.abs().sum(dim=2)
    return tv_weight * total_variation