## Imports and Preliminary

In [2]:
import torch
import torch.nn.functional as F
from torch import Tensor
from overrides import overrides, final

## Classifiers

Two classifiers have been implemented to solve the classification task on the <i>source</i> ($ C^{(s)} $) and <i>target</i> ($ C^{(t)} $) domains. The former, the source <code>task classifier</code> $ C^{(s)} $, is trained using the following cross-entropy loss over the <i>labeled</i> source samples: <br /><br />
$$ \min_{C^{(s)}}\mathcal{E}_{task}^{(s)}(G, C^{(s)})=-\frac{1}{n_{s}}\sum_{i=1}^{n_{s}}\log\bigg(p_{y_{i}^{(s)}}^{(s)}({\bf x}_{i}^{(s)})\bigg) $$
In the formula above, $G$ represents the <code>feature extractor</code>, ${\bf x}_{i}^{(s)}$ the $i$-th labeled source sample with label $y_{i}^{(s)}$, and $p^{(s)}({\bf x}_{i}^{(s)}) \in [0,1]^{K}$ the probability distribution obtained by applying the <code>[softmax](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)</code> operation to the output vector of $ C^{(s)} $; $p_{y_{i}^{(s)}}^{(s)}({\bf x}_{i}^{(s)})$ is the entry of that distribution for the ground-truth class.

Since target samples are unlabeled, there exist no direct
supervision signals to learn a task classifier $ C^{(t)} $. Therefore, [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> leverage the labeled source samples by using the following cross-entropy loss: 
$$ \min_{C^{(t)}}\mathcal{E}_{task}^{(t)}(G, C^{(t)})=-\frac{1}{n_{s}}\sum_{i=1}^{n_{s}}\log\bigg(p_{y_{i}^{(s)}}^{(t)}({\bf x}_{i}^{(s)})\bigg) $$

It is worth noting that $C^{(t)}$ will be distinguishable from $C^{(s)}$ through the domain discrimination training of the classifier $C^{(st)}$. Moreover, [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> stress the use of labeled source samples to enhance $C^{(t)}$'s performance in discriminating among task categories. 

In [None]:
class CrossEntropyClassifier(torch.nn.Module):
    """
    Base module for losses of the form ``-mean(log(sum_k p_k))``.

    The raw classifier output is turned into probabilities with a softmax
    along dimension 1; the probabilities of a group of neurons (all of them,
    the first ``n_classes`` or the remaining ones) are summed per sample, and
    the loss is the negative mean of the logarithm of those sums.

    This class does not define ``forward``; subclasses combine
    ``preprocess`` and ``cross_entropy_loss`` to do so.
    """
    # Added to probability sums that are exactly 0 so that log() stays finite.
    _THRESHOLD = 1e-6

    def __init__(self, n_classes: int):
        """
        Args:
            n_classes (int): number of task categories K, used as the split
                point between the first and the last group of neurons.
        """
        super().__init__()
        self.n_classes = n_classes

    @final
    def add_threshold(self, prob: Tensor, batch_size: int) -> Tensor:
        """
        Sum the probabilities along dimension 1 and add a threshold of 1e-6
        to every sum that is exactly 0, in order to avoid the log(0) case.

        The result is built out-of-place from ``prob`` itself, so it lives on
        the same device and has the same dtype as ``prob`` and works for any
        number of trailing dimensions.

        Args:
            prob (Tensor): output tensor of the softmax operation
            batch_size (int): batch size of the input. Kept for backward
                compatibility; the shape is taken from ``prob`` directly.

        Returns:
            Tensor: per-sample probability sums, with zeros replaced by the threshold
        """
        prob_sum = prob.sum(dim=1)
        # Boolean mask -> {0, 1} in prob_sum's dtype; only exact zeros get shifted,
        # every other entry has 0 added and is therefore unchanged.
        is_zero = prob_sum == 0
        return prob_sum + is_zero.to(prob_sum.dtype) * self._THRESHOLD

    @final
    def preprocess(self, input: Tensor, split=False, take_first=True) -> Tensor:
        """
        Transforms the output vector of any classifier applying the softmax operation
        along dimension 1, optionally keeps only one group of neurons, and sums the
        retained probabilities per sample (see 'add_threshold').

        Args:
            input (Tensor): output vector of any classifier
            split (bool, optional): whether to consider 2K neurons (constructed classifier C^{st}) or not. Defaults to False.
            take_first (bool, optional): whether to take the first K neurons (source classifier)
                                        or the last K neurons (target classifier). Only used when
                                        'split' is True. Defaults to True.
        Returns:
            Tensor: per-sample probability sums usable as input for 'cross_entropy_loss'
        """
        prob = F.softmax(input, dim=1)
        if split:
            # The softmax is taken over all neurons *before* slicing, so the
            # retained group generally sums to less than 1.
            if take_first:
                prob = prob[:, :self.n_classes]
            else:
                prob = prob[:, self.n_classes:]
        return self.add_threshold(prob, input.size(0))

    @final
    def cross_entropy_loss(self, prob_sum: Tensor) -> Tensor:
        """
        Negative mean log-likelihood of the given probabilities.

        Args:
            prob_sum (Tensor): strictly positive probabilities (one per sample)

        Returns:
            Tensor: scalar tensor equal to -mean(log(prob_sum))
        """
        return -(prob_sum.log().mean())

class SourceClassifier(CrossEntropyClassifier):
    """
    Loss module for the source side of the classifier output.

    Only the first k neurons (k = n_classes) are taken into account:
    after the softmax over the whole output, their probabilities are
    summed per sample and fed to the cross-entropy loss.
    """
    def __init__(self, n_classes: int):
        super().__init__(n_classes)

    @overrides
    def forward(self, input: Tensor):
        """
        Args:
            input (Tensor): raw output of the classifier; presumably of shape
                (batch, 2 * n_classes) -- TODO confirm against the caller.

        Returns:
            Scalar loss: -mean(log(probability mass of the first n_classes neurons)).
        """
        return self.cross_entropy_loss(
            self.preprocess(input, split=True, take_first=True)
        )
    
class TargetClassifier(CrossEntropyClassifier):
    """
    Loss module for the target side of the classifier output.

    Only the neurons after the first k (k = n_classes) are taken into
    account: after the softmax over the whole output, their probabilities
    are summed per sample and fed to the cross-entropy loss.
    """
    def __init__(self, n_classes: int):
        super().__init__(n_classes)

    @overrides
    def forward(self, input: Tensor):
        """
        Args:
            input (Tensor): raw output of the classifier; presumably of shape
                (batch, 2 * n_classes) -- TODO confirm against the caller.

        Returns:
            Scalar loss: -mean(log(probability mass of the neurons from index n_classes on)).
        """
        return self.cross_entropy_loss(
            self.preprocess(input, split=True, take_first=False)
        )