## Imports and Preliminary

In [3]:
import torch
import torch.nn.functional as F

from torch import Tensor
from torchvision import models
from torch.optim import RMSprop, Adagrad
from overrides import overrides, final
from abc import abstractmethod

  from .autonotebook import tqdm as notebook_tqdm


## Training function

The SymNet architecture is trained with an overall objective that is composed of several terms. The individual loss modules are therefore combined as reported below, following the original formulation of [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> (see Section 3.3).

$$ \displaystyle\min_{C^{(s)}, C^{(t)}, C^{(st)}} \large \mathcal{E}_{task}^{(s)}(G, C^{(s)}) +  \large\mathcal{E}_{task}^{(t)}(G, C^{(t)}) + \large\mathcal{E}_{task}^{(st)}(G, C^{(st)})$$
$$\displaystyle\min_{G} \large \mathcal{F}_{category}^{(st)}(G, C^{(st)}) + \lambda [\large \mathcal{F}_{domain}^{(st)}(G, C^{(st)}) + \large \mathcal{M}^{(st)}(G, C^{(st)})]$$

In order to reproduce the original implementation and its multiple losses, a tree structure has been chosen. The root is represented by the class <code>_Loss</code>, which inherits from the <code>[torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)</code> and has two submodules as children:
<code>_CrossEntropyLoss</code> and <code>_EntropyLoss</code>. The latter implements the <i>Entropy Minimization Principle</i> (Section 3.2.1 of [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i>), whereas the former is further specialised into the subclasses shown below:

```
_Loss
│
├── _EntropyLoss
│ 
└── _CrossEntropyLoss
    │   
    └── SplitCrossEntropyLoss
        │
        └── DomainDiscriminationLoss
```

In [4]:
class _Loss(torch.nn.Module):
    '''
    Base class of the loss hierarchy.

    ``forward`` converts its input into probabilities through
    ``to_softmax`` and passes them to ``loss``. Concrete subclasses
    must override ``loss`` and may override ``to_softmax``.
    '''

    # Value written into cells that are exactly 0 so that a later
    # log() stays finite.
    _THRESHOLD = 1e-20

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor):
        '''
        Compute the loss of a batch of classifier outputs.

        Parameters
        ----------
        input: Tensor
            tensor fed to the softmax (the softmax runs along dim 1,
            so at least two dimensions are required)

        Returns
        -------
        Tensor
            whatever the subclass implementation of ``loss`` returns
        '''
        prob = self.to_softmax(input)
        return self.loss(prob)

    @final
    def add_threshold(self, prob: Tensor):
        '''
        Check whether the probability distribution after the softmax
        is equal to 0 in any cell. Every such cell is replaced by a
        standard threshold in order to avoid the log(0) case; all
        other cells are returned unchanged.

        Parameters
        ----------
        prob: Tensor
            output tensor of the softmax operation

        Returns
        -------
        Tensor
            new tensor with the zero cells replaced (the input is
            left untouched)
        '''
        # full_like keeps the device and dtype of ``prob``, so the
        # correction also works for CUDA / non-default-dtype inputs.
        floor = torch.full_like(prob, self._THRESHOLD)
        # Out-of-place on purpose: the softmax output is needed again
        # by autograd during the backward pass, so it must not be
        # modified in place.
        return torch.where(prob == 0, floor, prob)

    def to_softmax(self, features: Tensor):
        '''
        Apply the softmax operation on the features tensor,
        being the output of a classifier. It returns the distribution
        of probability within the range [0,1].

        Parameters
        ----------
        features: Tensor
            input tensor of the softmax operation

        Returns
        -------
        Tensor
            probability distribution with (possible) threshold
        '''
        prob = F.softmax(features, dim=1)
        return self.add_threshold(prob)

    @abstractmethod
    def loss(self, prob: Tensor):
        '''
        Turn a probability tensor into a loss value.

        Parameters
        ----------
        prob: Tensor
            probabilities produced by ``to_softmax``

        Raises
        ------
        NotImplementedError
            always, in this base class
        '''
        # torch.nn.Module does not use ABCMeta, so @abstractmethod alone
        # does not prevent instantiation; fail loudly instead of
        # silently returning None.
        raise NotImplementedError(
            f'{type(self).__name__} must implement loss()'
        )

### Entropy Minimization Principle

The Entropy Minimization objective ([Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i>, Section 3.2.1) is adopted here to update the feature extractor (<i>G</i>) and to enhance the discrimination among task categories. This avoids having target samples stuck in wrong category predictions during the early training stages.
#TODO: check this again!!

In [None]:
class _EntropyLoss(_Loss):
    '''
    Entropy of a batch of probability distributions.
    '''

    def __init__(self):
        super().__init__()

    @overrides
    def loss(self, prob: Tensor):
        '''
        Sum ``prob * log(prob)`` along dim 1, average the result over
        the remaining entries and flip the sign.

        Parameters
        ----------
        prob: Tensor
            probability tensor; must have a dim 1

        Returns
        -------
        Tensor
            scalar tensor holding the mean entropy
        '''
        weighted_log = prob * prob.log()
        per_sample = weighted_log.sum(dim=1)
        return -per_sample.mean()

### Cross Entropy loss(es)

Two classifiers have been implemented to solve the classification task on the <i>source</i> ($ C^{(s)} $) and <i>target</i> ($ C^{(t)} $) domain. The former <code>task classifier</code>, $ C^{(s)} $, is trained using the following cross-entropy loss over the <i>labeled</i> source samples: <br />
$$ \min_{C^{(s)}}\mathcal{E}_{task}^{(s)}(G, C^{(s)})=-\frac{1}{n_{s}}\sum_{i=1}^{n_{s}}\log\bigg(p_{y_{i}^{(s)}}^{(s)}({\bf x}_{i}^{(s)})\bigg) $$
<br />In the formula above, $G$ represents the <code>feature extractor</code>, ${\bf x}_{i}^{(s)}$ the output vector of $ C^{(s)} $, and $p_{y_{i}^{(s)}}^{(s)}({\bf x}_{i}^{(s)}) \in [0,1]^{K}$ the distribution of probability after the <code>  [softmax](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)</code>  operation. 

Since target samples are <i>unlabeled</i>, there is no direct
supervision signal to learn a task classifier $ C^{(t)} $. Therefore, [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> leverage the <i>labeled</i> source samples by using the following cross-entropy loss: 
$$ \min_{C^{(t)}}\mathcal{E}_{task}^{(t)}(G, C^{(t)})=-\frac{1}{n_{s}}\sum_{i=1}^{n_{s}}\log\bigg(p_{y_{i}^{(s)}}^{(t)}({\bf x}_{i}^{(s)})\bigg) $$
<br />It is worth noticing that $C^{(t)}$ will be distinguishable from $C^{(s)}$ through the domain discrimination training of the classifier $C^{(st)}$. Moreover, [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> stress the use of <i>labeled</i> source samples to enhance $C^{(t)}$ performance in discriminating among task categories. 

In [None]:
class _CrossEntropyLoss(_Loss):
    '''
    Negative mean log-probability loss.
    '''

    def __init__(self):
        super().__init__()

    @overrides
    def loss(self, prob: Tensor):
        '''
        Take the logarithm of every entry of ``prob``, average over
        all entries and flip the sign.

        Parameters
        ----------
        prob: Tensor
            probability tensor

        Returns
        -------
        Tensor
            scalar tensor holding the negative mean log-probability
        '''
        # NOTE(review): no labels are involved here, the mean runs over
        # every entry of ``prob``; the formulas in this notebook index
        # the probability by the ground-truth class - confirm that this
        # is intended.
        log_prob = torch.log(prob)
        return -torch.mean(log_prob)

In [None]:
class SplitCrossEntropyLoss(_CrossEntropyLoss):
    '''
    Cross-entropy loss evaluated on one part of the softmax output:
    the first ``n_classes`` columns when ``source`` is true, the
    remaining columns otherwise.
    '''

    def __init__(self, n_classes: int, source: bool):
        super().__init__()
        self.n_classes = n_classes
        self._is_source = source

    @overrides
    def to_softmax(self, features: Tensor):
        '''
        Apply the softmax along dim 1 over all columns of ``features``,
        keep the selected part and add the threshold to it.

        Parameters
        ----------
        features: Tensor
            input tensor of the softmax operation

        Returns
        -------
        Tensor
            selected part of the probabilities with (possible) threshold
        '''
        joint = F.softmax(features, dim=1)
        selected = self.split_softmax(joint)
        return self.add_threshold(selected)

    @final
    def split_softmax(self, prob: Tensor):
        '''
        Select the columns of ``prob`` this loss is responsible for.

        Parameters
        ----------
        prob: Tensor
            probability tensor with the columns to split on dim 1

        Returns
        -------
        Tensor
            columns ``[:n_classes]`` for a source loss, columns
            ``[n_classes:]`` otherwise
        '''
        if self._is_source:
            return prob[:, :self.n_classes]
        return prob[:, self.n_classes:]

In [None]:
class DomainDiscriminationLoss(SplitCrossEntropyLoss):
    '''
    Loss on the total probability mass of the selected part of the
    softmax output.
    '''

    def __init__(self, n_classes: int, source: bool):
        super().__init__(n_classes, source)

    @overrides
    def loss(self, prob: Tensor):
        '''
        Sum ``prob`` along dim 1, take the logarithm of each sum,
        average and flip the sign.

        Parameters
        ----------
        prob: Tensor
            probability tensor; must have a dim 1

        Returns
        -------
        Tensor
            scalar tensor holding the negative mean log of the sums
        '''
        domain_prob = torch.sum(prob, dim=1)
        return -torch.log(domain_prob).mean()

In [None]:
# Number of task categories; the classifier output is expected to have
# twice as many columns (first part / second part of the split).
N_CLASSES = 1000


class _CombinedLoss(torch.nn.Module):
    '''
    Weighted sum of several loss modules.

    torch.nn.Module defines neither ``+`` nor ``*``, so loss modules
    cannot be combined arithmetically when they are created. This
    wrapper stores them and combines their *values* when it is called.

    Parameters
    ----------
    *terms: torch.nn.Module
        loss modules to be summed
    weight: float
        factor applied to the sum
    '''

    def __init__(self, *terms: torch.nn.Module, weight: float = 1.0):
        super().__init__()
        self.terms = torch.nn.ModuleList(terms)
        self.weight = weight

    def forward(self, *inputs: Tensor):
        '''
        Evaluate every term and return the weighted sum.

        Parameters
        ----------
        *inputs: Tensor
            either a single tensor shared by all terms, or exactly one
            tensor per term (in the order the terms were given)

        Returns
        -------
        Tensor
            ``weight * sum(term(input))``

        Raises
        ------
        ValueError
            if the number of inputs is neither 1 nor the number of terms
        '''
        if len(inputs) == 1:
            inputs = inputs * len(self.terms)
        if len(inputs) != len(self.terms):
            raise ValueError(
                f'expected 1 or {len(self.terms)} inputs, got {len(inputs)}'
            )
        total = sum(term(x) for term, x in zip(self.terms, inputs))
        return self.weight * total


# NOTE(review): which batch (source or target) feeds each term of the
# combined losses below must be decided by the training loop - confirm
# against Zhang et al. before wiring them up.

# Domain discrimination
source_dom_class_loss = DomainDiscriminationLoss(n_classes=N_CLASSES, source=True)
target_dom_class_loss = DomainDiscriminationLoss(n_classes=N_CLASSES, source=False)
domain_class_loss = _CombinedLoss(source_dom_class_loss, target_dom_class_loss)

# Task classification
source_task_class_loss = SplitCrossEntropyLoss(n_classes=N_CLASSES, source=True)
target_task_class_loss = SplitCrossEntropyLoss(n_classes=N_CLASSES, source=False)

# Domain confusion
source_dom_conf_loss = DomainDiscriminationLoss(n_classes=N_CLASSES, source=True)
target_dom_conf_loss = DomainDiscriminationLoss(n_classes=N_CLASSES, source=False)
domain_conf_loss = _CombinedLoss(
    source_dom_conf_loss, target_dom_conf_loss, weight=0.5
)

# Category confusion
source_cat_conf_loss = SplitCrossEntropyLoss(n_classes=N_CLASSES, source=True)
target_cat_conf_loss = SplitCrossEntropyLoss(n_classes=N_CLASSES, source=False)
category_conf_loss = _CombinedLoss(
    source_cat_conf_loss, target_cat_conf_loss, weight=0.5
)

# Entropy minimization
target_entropy_loss = _EntropyLoss()

## Feature Extractor (<i>G</i>) - Resnet18

In [None]:
class FeatureExtractor:
    
    def __init__(self, n_classes: int, n_layers_trained: int, model='resnet18', optimizer='rmsprop', lr=0.01, weight_decay=0, momentum=0.9):
        '''
        Load a pretrained ResNet, replace its last fully-connected
        layer, freeze the leading parameters and create the optimizer
        for the remaining ones.

        Parameters
        ----------
        n_classes: int
            number of task categories; the new last layer has
            ``2 * n_classes`` outputs
        n_layers_trained: int
            number of trailing parameter tensors (as listed by
            ``model.parameters()``, i.e. weights and biases count
            separately, the new last layer included) left trainable
        model: str
            'resnet18' or 'resnet50' (case-insensitive)
        optimizer: str
            'rmsprop', 'adadelta' or 'sgd' (case-insensitive)
        lr: float
            learning rate handed to the optimizer
        weight_decay: float
            weight decay handed to the optimizer
        momentum: float
            momentum handed to SGD only; it has to be positive because
            SGD is created with ``nesterov=True``

        Raises
        ------
        ValueError
            if ``model`` or ``optimizer`` is not one of the names above
        '''
        # Upload pretrained model
        model_name = model.lower()
        if model_name == 'resnet18':
            self.model = models.resnet18(pretrained=True)
        elif model_name == 'resnet50':
            self.model = models.resnet50(pretrained=True)
        else:
            raise ValueError('Unknown model')

        # Modify last fully-connected layer
        self.model.fc = torch.nn.Linear(
            in_features=self.model.fc.in_features,
            out_features=n_classes * 2
        )

        # Freeze pretrained layers: only the last ``n_layers_trained``
        # parameter tensors keep requires_grad=True.
        params = list(self.model.parameters())
        first_trainable = len(params) - n_layers_trained
        for index, param in enumerate(params):
            param.requires_grad = index >= first_trainable
        params_to_train = [p for p in params if p.requires_grad]

        # Initialize optimizer
        optimizer_name = optimizer.lower()
        if optimizer_name == 'rmsprop':
            self.optim = torch.optim.RMSprop(
                params=params_to_train,
                lr=lr,
                weight_decay=weight_decay
            )
        elif optimizer_name == 'adadelta':
            self.optim = torch.optim.Adadelta(
                params=params_to_train,
                lr=lr,
                weight_decay=weight_decay
            )
        elif optimizer_name == 'sgd':
            # Nesterov momentum is rejected by torch.optim.SGD unless a
            # positive momentum is supplied, so it must be passed here.
            self.optim = torch.optim.SGD(
                params=params_to_train,
                lr=lr,
                momentum=momentum,
                weight_decay=weight_decay,
                nesterov=True
            )
        else:
            raise ValueError('Unknown optimizer')