## Imports and Preliminaries

In [2]:
import torch
import torch.nn.functional as F


from torch import Tensor, FloatTensor
from overrides import overrides, final

## Classifiers

The source task classifier $ C^{(s)} $ is trained with the following cross-entropy loss over the labeled source samples: <br /><br />
$$ \min_{C^{(s)}}\mathcal{E}_{task}^{(s)}(G, C^{(s)})=-\frac{1}{n_{s}}\sum_{i=1}^{n_{s}}\log\left(p_{y_{i}^{(s)}}^{(s)}(\mathbf{x}_{i}^{(s)})\right) $$

Since the target samples are unlabeled, there is no direct
supervision signal for learning a target task classifier $ C^{(t)} $. <br /> Zhang <i>et al.</i> instead leverage the labeled source samples and use the following cross-entropy loss: <br /><br />
$$ \min_{C^{(t)}}\mathcal{E}_{task}^{(t)}(G, C^{(t)})=-\frac{1}{n_{s}}\sum_{i=1}^{n_{s}}\log\left(p_{y_{i}^{(s)}}^{(t)}(\mathbf{x}_{i}^{(s)})\right) $$



In [None]:
class CrossEntropyClassifier(torch.nn.Module):
    _THRESHOLD = 1e-6
    
    def __init__(self, n_classes: int):
        super(CrossEntropyClassifier, self).__init__()
        self.n_classes = n_classes
      
    @final
    def add_threshold(self, prob: Tensor, batch_size: int) -> Tensor:
        zeros = (prob.sum(dim=1) == 0)
        prob_sum = prob.sum(dim=1)
        if any(zeros):
            thre_tensor = FloatTensor(batch_size)._fill(0)
            thre_tensor[zeros] = self._THRESHOLD
            prob_sum += thre_tensor
        return prob_sum
      
    @final
    def preprocess(self, input: Tensor, split=False, take_first=True) -> Tensor:
        batch_size = input.size(0)
        prob = F.softmax(input, dim=1)
        fw_prob = prob if not split else (prob[:,:self.n_classes] if take_first else prob[:,self.n_classes:])
        return self.add_threshold(fw_prob, batch_size)
      
    @final
    def cross_entropy_loss(self, prob_sum: Tensor):
        return -(prob_sum.log().mean())

class SourceClassifier(CrossEntropyClassifier):
    """Loss on the softmax mass of the first ``n_classes`` logit columns.

    ``forward`` returns ``-mean(log(sum over k < n_classes of softmax(input)[:, k]))``.
    Rows whose mass is exactly 0 are replaced by the base-class threshold.
    The first columns are presumably the source half of a concatenated
    source/target output; confirm this against the model.
    """

    def __init__(self, n_classes: int):
        super().__init__(n_classes)

    @overrides
    def forward(self, input: Tensor):
        # Probability mass of columns [:n_classes], one value per row.
        leading_mass = self.preprocess(input, split=True, take_first=True)
        return self.cross_entropy_loss(leading_mass)
    
class TargetClassifier(CrossEntropyClassifier):
    """Loss on the softmax mass of the logit columns from ``n_classes`` onward.

    ``forward`` returns ``-mean(log(sum over k >= n_classes of softmax(input)[:, k]))``.
    Rows whose mass is exactly 0 are replaced by the base-class threshold.
    The trailing columns are presumably the target half of a concatenated
    source/target output; confirm this against the model.
    """

    def __init__(self, n_classes: int):
        super().__init__(n_classes)

    @overrides
    def forward(self, input: Tensor):
        # Probability mass of columns [n_classes:], one value per row.
        trailing_mass = self.preprocess(input, split=True, take_first=False)
        return self.cross_entropy_loss(trailing_mass)