## Imports and Preliminaries

In [1]:
import math
import torch
import torch.nn.functional as F

from torch import Tensor
from torchvision import models
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss
from torch.optim import RMSprop, Adagrad
from overrides import overrides, final
from abc import abstractmethod



## Training function

The SymNet architecture is trained with a composite objective built from several loss modules. They are reported below, following the original formulation of [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> (see § 3.3). 

$$ \displaystyle\min_{C^{(s)}, C^{(t)}, C^{(st)}} \mathcal{E}_{task}^{(s)}(G, C^{(s)}) + \mathcal{E}_{task}^{(t)}(G, C^{(t)}) + \mathcal{E}_{domain}^{(st)}(G, C^{(st)})$$
$$\displaystyle\min_{G} \mathcal{F}_{category}^{(st)}(G, C^{(st)}) + \lambda \big[\mathcal{F}_{domain}^{(st)}(G, C^{(st)}) + \mathcal{M}^{(st)}(G, C^{(st)})\big]$$

To reproduce the original implementation and its multiple losses, the loss classes are organized as a tree. The root is the class <code>_Loss</code>, which inherits from <code>[torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)</code> and has two subclasses:
<code>EntropyMinimizationLoss</code> and <code>SplitLoss</code>. The former implements the <i>Entropy Minimization Principle</i> (§ 3.2.1 [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i>), whereas the latter is further specialized into two subclasses, as depicted below:

```
_Loss
│
├── EntropyMinimizationLoss
│ 
└── SplitLoss
    │   
    ├── SplitCrossEntropyLoss
    │
    └── DomainDiscriminationLoss
```

In [2]:
class _Loss(torch.nn.Module):
    
    _THRESHOLD = 1e-20
    
    def __init__(self):
        super(_Loss, self).__init__()
        
    def forward(self, input: Tensor):
        prob = self.to_softmax(input)
        return self.loss(prob)
        
    @final
    def add_threshold(self, prob: Tensor):
        '''
        Check whether the probability distribution after the softmax 
        is equal to 0 in any cell. If this holds, a standard threshold
        is added in order to avoid log(0) case. 

        Parameters
        ----------
        prob: Tensor
            output tensor of the softmax operation

        Returns
        -------
        Tensor
            updated tensor (in case the condition above holds)
        '''
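        zeros = (prob == 0)
        if torch.any(zeros):
            # Build the threshold on the same device and dtype as prob, and add it
            # out of place so autograd can still use the saved softmax output
            thre_tensor = torch.zeros_like(prob)
            thre_tensor[zeros] = self._THRESHOLD
            prob = prob + thre_tensor
        return prob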
    
    def to_softmax(self, features: Tensor):
        '''
        Apply the softmax operation on the features tensor, 
        being the output of a feature extractor. 
        
        Parameters
        ----------
        features: Tensor
            input tensor of the softmax operation

        Returns
        -------
        Tensor
            probability distribution with (possible) threshold
        '''
        prob = F.softmax(features, dim=1)
        return self.add_threshold(prob)
    
    @abstractmethod
    def loss(self, prob: Tensor):
        pass

### Entropy Minimization Principle


The Entropy Minimization objective ([Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i>, § 3.2.1) is adopted here as a regularizer. Specifically, it updates the feature extractor (<i>G</i>) and enhances the discrimination among task categories. This prevents target samples from getting stuck in wrong category predictions during the early training stages.

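$$\displaystyle\min_{G}\mathcal{M}^{(st)}(G, C^{(st)}) = -\frac{1}{n_t}\sum_{j=1}^{n_t}\sum_{k=1}^{K}q_k^{(st)}({\bf x}_j^{(t)})\log\bigg(q_k^{(st)}({\bf x}_j^{(t)})\bigg)$$
where $q_k^{(st)}({\bf x}) = p_k^{(st)}({\bf x}) + p_{k+K}^{(st)}({\bf x})$ is the probability of category $k$ summed over its source and target outputs.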
<br /><br />
<b>TODO:</b> Add formulas and check this all again --> REGULARIZER

In [3]:
class EntropyMinimizationLoss(_Loss):
    
    def __init__(self, n_classes: int):
        super(EntropyMinimizationLoss, self).__init__()
        self.n_classes = n_classes
    
    @overrides
    def loss(self, prob: Tensor):
        prob_source = prob[:, :self.n_classes]
        prob_target = prob[:, self.n_classes:]
        prob_sum = prob_source + prob_target
        return -(prob_sum.log().mul(prob_sum).sum(dim=1).mean())
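
As a quick sanity check of the formula, the sketch below recomputes the same quantity as a plain function. The helper name <code>merged_entropy</code> and the toy inputs are assumptions made for illustration only. With uniform logits the merged distribution is uniform over the $K$ categories, so the entropy should equal $\log K$; with one dominant logit it should be close to zero.

```python
import math
import torch
import torch.nn.functional as F

def merged_entropy(logits: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Mean entropy of q_k = p_k + p_{k+K} for a batch of 2K-way logits."""
    prob = F.softmax(logits, dim=1)
    q = prob[:, :n_classes] + prob[:, n_classes:]
    q = q.clamp_min(1e-20)  # guard against log(0)
    return -(q * q.log()).sum(dim=1).mean()

K = 3
uniform = torch.zeros(4, 2 * K)   # every output equally likely
peaked = torch.zeros(4, 2 * K)
peaked[:, 0] = 50.0               # one output dominates

entropy_uniform = merged_entropy(uniform, K)
entropy_peaked = merged_entropy(peaked, K)
print(entropy_uniform.item(), math.log(K), entropy_peaked.item())
```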

### Cross-Entropy Based Losses

The losses in the overall training objective described above are implemented starting from the class <code>SplitLoss</code>. They are either a standard cross-entropy loss or a combination of two losses (i.e., <i>two-way cross-entropy loss</i>). Specifically, two task classifiers solve the classification task on the <i>source</i> ($ C^{(s)} $) and <i>target</i> ($ C^{(t)} $) domains. The source <code>task classifier</code>, $ C^{(s)} $, is trained using the following cross-entropy loss over the <i>labeled</i> source samples: <br />
$$ \displaystyle\min_{C^{(s)}}\mathcal{E}_{task}^{(s)}(G, C^{(s)})=-\frac{1}{n_{s}}\sum_{i=1}^{n_{s}}\log\bigg(p_{y_{i}^{(s)}}^{(s)}({\bf x}_{i}^{(s)})\bigg) $$
<br />In the formula above, $G$ represents the <code>feature extractor</code>, ${\bf x}_{i}^{(s)}$ the $i$-th source sample with label $y_{i}^{(s)}$, and $p_{y_{i}^{(s)}}^{(s)}({\bf x}_{i}^{(s)})$ the $y_{i}^{(s)}$-th element of the probability vector $p^{(s)}({\bf x}_{i}^{(s)}) \in [0,1]^{K}$ produced by $ C^{(s)} $ after the <code>  [softmax](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)</code>  operation. 

Since target samples are <i>unlabeled</i>, there is no direct supervision signal to learn a task classifier $ C^{(t)} $. Therefore, [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> leverage the <i>labeled</i> source samples by using the following cross-entropy loss: 
$$ \displaystyle\min_{C^{(t)}}\mathcal{E}_{task}^{(t)}(G, C^{(t)})=-\frac{1}{n_{s}}\sum_{i=1}^{n_{s}}\log\bigg(p_{y_{i}^{(s)}}^{(t)}({\bf x}_{i}^{(s)})\bigg) $$
<br />Note that $C^{(t)}$ becomes distinguishable from $C^{(s)}$ through the domain discrimination training of the classifier $C^{(st)}$, reported below. Moreover, [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> stress that the <i>labeled</i> source samples help $C^{(t)}$ discriminate among task categories.

$$ \displaystyle\min_{C^{(st)}}\mathcal{E}_{domain}^{(st)}(G, C^{(st)})=-\frac{1}{n_{t}}\sum_{j=1}^{n_t}\log\bigg(\sum_{k=1}^{K} p_{k+K}^{(st)}({\bf x}_{j}^{(t)})\bigg)-\frac{1}{n_s}\sum_{i=1}^{n_s}\log\bigg(\sum_{k=1}^{K}p_{k}^{(st)}({\bf x}_{i}^{(s)})\bigg) $$

Furthermore, a two-level confusion loss is applied to the feature extractor. The first loss is category-level and relies on the source labels, whereas the second is domain-level and relies on the unlabeled target samples. The formulas are, respectively: 

$$ \displaystyle\min_{G}\mathcal{F}_{category}^{(st)}(G, C^{(st)})=-\frac{1}{2n_{s}}\sum_{i=1}^{n_s}\log\bigg(p_{y_i^{(s)}+K}^{(st)}({\bf x}_{i}^{(s)})\bigg)-\frac{1}{2n_s}\sum_{i=1}^{n_s}\log\bigg(p_{y_i^{(s)}}^{(st)}({\bf x}_{i}^{(s)})\bigg) $$

$$ \displaystyle\min_{G}\mathcal{F}_{domain}^{(st)}(G, C^{(st)})=-\frac{1}{2n_{t}}\sum_{j=1}^{n_t}\log\bigg(\sum_{k=1}^{K} p_{k+K}^{(st)}({\bf x}_{j}^{(t)})\bigg)-\frac{1}{2n_t}\sum_{j=1}^{n_t}\log\bigg(\sum_{k=1}^{K}p_{k}^{(st)}({\bf x}_{j}^{(t)})\bigg) $$
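
The domain-level formulas can be sketched as plain functions, which may help when checking the class-based implementation. The helper names <code>domain_mass</code>, <code>domain_discrimination</code> and <code>domain_confusion</code> are assumptions introduced for this illustration. The example also shows the intuition behind the confusion term: it is smallest when the probability mass is split evenly between the source and target halves, where it equals $\log 2$.

```python
import math
import torch
import torch.nn.functional as F

def domain_mass(logits: torch.Tensor, n_classes: int):
    """Total probability assigned to the source half and to the target half."""
    prob = F.softmax(logits, dim=1)
    return prob[:, :n_classes].sum(dim=1), prob[:, n_classes:].sum(dim=1)

def domain_discrimination(src_logits, tgt_logits, n_classes):
    # Source samples should fall in the first K outputs, target samples in the last K
    src_mass, _ = domain_mass(src_logits, n_classes)
    _, tgt_mass = domain_mass(tgt_logits, n_classes)
    return -tgt_mass.log().mean() - src_mass.log().mean()

def domain_confusion(tgt_logits, n_classes):
    # Target samples should be equally likely to fall in either half
    src_mass, tgt_mass = domain_mass(tgt_logits, n_classes)
    return -0.5 * tgt_mass.log().mean() - 0.5 * src_mass.log().mean()

K = 3
uniform = torch.zeros(4, 2 * K)   # mass split 0.5 / 0.5
skewed = uniform.clone()
skewed[:, :K] = 2.0               # most mass on the source half

conf_uniform = domain_confusion(uniform, K)
conf_skewed = domain_confusion(skewed, K)
discr_uniform = domain_discrimination(uniform, uniform, K)
print(conf_uniform.item(), conf_skewed.item(), discr_uniform.item())
```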




<b>TODO:</b> Add the two-way cross-entropy loss with definition of domain discrimination and confusion

In [4]:
class SplitLoss(_Loss):
    
    def __init__(self, n_classes: int, source: bool, split_first: bool):
        super(SplitLoss, self).__init__()
        self.n_classes = n_classes
        self._is_source = source
        self._split_first = split_first
    
    @overrides
    def to_softmax(self, features: Tensor):
        if self._split_first:
            prob = self.split_vector(features)
            prob = F.softmax(prob, dim=1)
        else:
            prob = F.softmax(features, dim=1)
            prob = self.split_vector(prob)
        return self.add_threshold(prob)
    
    @final
    def split_vector(self, prob: Tensor):
        return prob[:,:self.n_classes] if self._is_source else prob[:,self.n_classes:]
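
The difference between the two branches of <code>to_softmax</code> can be sketched on a toy tensor (the values below are arbitrary and chosen only for illustration). Splitting first yields a proper distribution over the $K$ categories of one classifier, while applying the softmax first yields the corresponding slice of the $2K$-way distribution, whose entries sum to less than one.

```python
import torch
import torch.nn.functional as F

K = 2
logits = torch.tensor([[1.0, 2.0, 0.0, 0.5]])  # 2K outputs for one sample

# split_first=True: softmax over the K source outputs only
split_then_softmax = F.softmax(logits[:, :K], dim=1)

# split_first=False: softmax over all 2K outputs, then keep the source half
softmax_then_split = F.softmax(logits, dim=1)[:, :K]

# Renormalizing the slice recovers the K-way distribution
renormalized = softmax_then_split / softmax_then_split.sum(dim=1, keepdim=True)
print(split_then_softmax, softmax_then_split, renormalized)
```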

In [5]:
class SplitCrossEntropyLoss(SplitLoss):
    
    def _get_y_labels(self):
        return self._y_labels
    def _set_y_labels(self, y_labels: Variable):
        # Vectorized bounds check on the label tensor
        if torch.any(y_labels >= self.n_classes):
            raise ValueError('Expected all y labels < n_classes')
        self._y_labels = y_labels
    y_labels = property(fget=_get_y_labels, fset=_set_y_labels)
    
    def __init__(self, n_classes: int, source: bool, split_first: bool):
        super(SplitCrossEntropyLoss, self).__init__(n_classes, source, split_first)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
    
    @overrides
    def loss(self, prob: Tensor):
        '''Computes cross-entropy loss w.r.t. ground-truth (y label)'''
        return self.cross_entropy_loss(prob, self.y_labels)
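
One point worth checking in <code>SplitCrossEntropyLoss</code>: <code>torch.nn.CrossEntropyLoss</code> expects raw logits and applies a log-softmax internally, whereas <code>loss</code> receives a tensor that has already gone through <code>to_softmax</code>. The sketch below, on arbitrary toy values, compares the possible combinations. Only the negative log-likelihood of the log-probabilities reproduces the $-\log p_y$ term of the formulas above, so swapping in <code>F.nll_loss(prob.log(), labels)</code> may be closer to the intended objective. This is a suggestion, not part of the original implementation.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])
prob = F.softmax(logits, dim=1)

ce_on_logits = F.cross_entropy(logits, labels)     # standard usage
nll_on_log_prob = F.nll_loss(prob.log(), labels)   # -log p_y, same value
ce_on_prob = F.cross_entropy(prob, labels)         # softmax applied twice

print(ce_on_logits.item(), nll_on_log_prob.item(), ce_on_prob.item())
```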

In [6]:
class DomainDiscriminationLoss(SplitLoss):
    
    def __init__(self, n_classes: int, source: bool):
        super(DomainDiscriminationLoss, self).__init__(n_classes, source, False)
        
    @overrides
    def loss(self, prob: Tensor):
        return -(prob.sum(dim=1).log().mean())

According to the definition of the overall model objective given by [Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i> (see § 3.3), the class <code>TrainingObjectives</code> combines the aforementioned losses. Moreover, the trade-off parameter $\lambda$ is introduced to suppress noisy signals from the domain confusion loss and the entropy term at the early stages of training. The value of $\lambda$ depends on the number of epochs, as it is iteratively computed through the following formula:
$$\lambda = \frac{2}{1 + e^{(- \gamma \cdot \frac{ep}{n_{ep}})}} - 1$$
where $\gamma$ is usually set to $10$ ([Zhang](https://arxiv.org/abs/1904.04663) <i>et al.</i>, see § 4.1) and $\frac{ep}{n_{ep}}$ is the current epoch over the total number of epochs, updated at each epoch. Therefore, $\lambda$ starts close to $0$ and gradually increases towards $1$ (for $\gamma = 10$, $\lambda \approx 0.9999$ when $ep = n_{ep}$). Thus, the suppression of the domain confusion loss and of the entropy term fades over time, and both are weighted almost fully by the end of training. 
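
To get a feel for the schedule, the formula can be evaluated at a few epochs. The function name <code>lambda_trade_off</code> and the choice of 100 total epochs are assumptions made for this sketch. The expression is algebraically equal to $\tanh(\gamma \cdot ep / (2 n_{ep}))$, which is used here as a cross-check.

```python
import math

def lambda_trade_off(curr_epoch: int, tot_epochs: int, gamma: float = 10.0) -> float:
    """Trade-off weight: 0 at the start, approaching 1 at the end of training."""
    return 2.0 / (1.0 + math.exp(-gamma * curr_epoch / tot_epochs)) - 1.0

epochs = (0, 10, 50, 100)
values = [lambda_trade_off(ep, 100) for ep in epochs]
for ep, lam in zip(epochs, values):
    print(f"epoch {ep:3d}: lambda = {lam:.4f}")
```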

In [7]:
class TrainingObjectives:
    
    @staticmethod
    def domain_discrimination_loss(src_dom_discrim_loss, tgt_dom_discrim_loss):
        return src_dom_discrim_loss + tgt_dom_discrim_loss
    
    @staticmethod
    def category_confusion_loss(src_cat_conf_loss, tgt_cat_conf_loss):
        return 0.5 * (src_cat_conf_loss + tgt_cat_conf_loss)
    
    @staticmethod
    def domain_confusion_loss(src_dom_conf_loss, tgt_dom_conf_loss):
        return 0.5 * (src_dom_conf_loss + tgt_dom_conf_loss)
    
    @staticmethod
    def overall_classifier_loss(src_task_class_loss, tgt_task_class_loss, domain_discrim_loss):
        return src_task_class_loss + tgt_task_class_loss + domain_discrim_loss
    
    @staticmethod
    def overall_generator_loss(cat_conf_loss, dom_conf_loss, tgt_entropy_loss, curr_epoch, tot_epochs):
        lambda_trade_off = 2 / (1 + math.exp(-1 * 10 * curr_epoch / tot_epochs)) - 1
        return cat_conf_loss + lambda_trade_off * (dom_conf_loss + tgt_entropy_loss)

Example checking the classes implemented above on fixed inputs:

In [8]:
# Example of batch size and num of classes
batch_size = 2
num_classes = 3

# Example of model's outcome after batch of inputs from source domain
# X_source_features = torch.randn(batch_size, num_classes * 2)
X_source_features = torch.tensor(
    [[-1.3382,  0.6833,  1.3363, -0.0465,  0.8953, -1.4505],
    [ 0.2133, -1.5612, -1.6918, -1.9907,  0.9956, -0.2287]])

# Example of model's outcome after batch of inputs from target domain
# X_target_features = torch.randn(batch_size, num_classes * 2)
X_target_features = torch.tensor(
    [[ 0.3109, -1.7531, -1.5460,  0.3308,  1.3116, -0.3035],
    [ 0.2777, -0.2101,  0.1629,  1.6425,  0.8126,  0.5605]])

# Example of labels associated with batch of inputs from source domain
y_source_labels = torch.tensor([2, 0])
y_source_labels_var = Variable(y_source_labels)

In [9]:
# Source Task Classifier Loss
source_task_class_loss = SplitCrossEntropyLoss(n_classes=3, source=True, split_first=True)
source_task_class_loss.y_labels = y_source_labels_var
_src_task_class_loss = source_task_class_loss(X_source_features)

# (Cross-Domain) Target Task Classifier Loss
target_task_class_loss = SplitCrossEntropyLoss(n_classes=3, source=False, split_first=True)
target_task_class_loss.y_labels = y_source_labels_var
_tgt_task_class_loss = target_task_class_loss(X_source_features)

# Domain Discrimination Loss
source_dom_discrim_loss = DomainDiscriminationLoss(n_classes=3, source=True)
target_dom_discrim_loss = DomainDiscriminationLoss(n_classes=3, source=False)
_src_dom_discrim_loss = source_dom_discrim_loss(X_source_features)
_tgt_dom_discrim_loss = target_dom_discrim_loss(X_target_features)
_domain_discrim_loss = TrainingObjectives.domain_discrimination_loss(_src_dom_discrim_loss, _tgt_dom_discrim_loss)

# Category-level Confusion Loss
source_cat_conf_loss = SplitCrossEntropyLoss(n_classes=3, source=True, split_first=False)
target_cat_conf_loss = SplitCrossEntropyLoss(n_classes=3, source=False, split_first=False)
source_cat_conf_loss.y_labels = y_source_labels_var
target_cat_conf_loss.y_labels = y_source_labels_var
_src_cat_conf_loss = source_cat_conf_loss(X_source_features)
_tgt_cat_conf_loss = target_cat_conf_loss(X_source_features)
_category_conf_loss = TrainingObjectives.category_confusion_loss(_src_cat_conf_loss, _tgt_cat_conf_loss)

# Domain-level Confusion Loss
source_dom_conf_loss = DomainDiscriminationLoss(n_classes=3, source=True)
target_dom_conf_loss = DomainDiscriminationLoss(n_classes=3, source=False)
_src_dom_conf_loss = source_dom_conf_loss(X_target_features)
_tgt_dom_conf_loss = target_dom_conf_loss(X_target_features)
_domain_conf_loss = TrainingObjectives.domain_confusion_loss(_src_dom_conf_loss, _tgt_dom_conf_loss)

# Entropy Minimization Principle
target_entropy_loss = EntropyMinimizationLoss(n_classes=3)
_tgt_entropy_loss = target_entropy_loss(X_target_features)

# Overall Classifier Loss
_overall_classifier_loss = TrainingObjectives.overall_classifier_loss(_src_task_class_loss, _tgt_task_class_loss, _domain_discrim_loss)
print('Overall Classifier Loss =', _overall_classifier_loss)

# Overall Feature Extractor Loss
_overall_generator_loss = TrainingObjectives.overall_generator_loss(_category_conf_loss, _domain_conf_loss, _tgt_entropy_loss, 432, 1000)
print('Overall Generator Loss =', _overall_generator_loss)

Overall Classifier Loss = tensor(3.3036)
Overall Generator Loss = tensor(2.8974)


## Feature Extractor (<i>G</i>) - ResNet18

<b>TODO:</b> Describe feature extractor, how it's built up and how it works

In [10]:
class FeatureExtractor:
    
    def __init__(self, n_classes: int, n_layers_trained: int, model='resnet18', optimizer='rmsprop', lr=0.01, weight_decay=0):
        
        # Load pretrained model
        if model.lower() == 'resnet18': 
            self.model = models.resnet18(pretrained=True)
        elif model.lower() == 'resnet50': 
            self.model = models.resnet50(pretrained=True)
        else:
            raise ValueError('Unknown model')
        
        # Modify last fully-connected layer
        self.model.fc = torch.nn.Linear(
            in_features = self.model.fc.in_features, 
            out_features = n_classes * 2
        )
        
        # Freeze pretrained layers
        params = list(self.model.parameters())
        for i in range(len(params)):
            n_layers_frozen = len(params) - i - 1
            params[i].requires_grad = (n_layers_frozen < n_layers_trained)
        params_to_train = filter(lambda p: p.requires_grad, self.model.parameters())
        
        # Initialize optimizer
        if optimizer.lower() == 'rmsprop':
            self.optim = torch.optim.RMSprop(
                params = params_to_train,
                lr = lr,
                weight_decay = weight_decay
            )
        elif optimizer.lower() == 'adadelta':
            self.optim = torch.optim.Adadelta(
                params = params_to_train,
                lr = lr,
                weight_decay = weight_decay
            )
        elif optimizer.lower() == 'sgd':
            self.optim = torch.optim.SGD(
                params = params_to_train,
                lr = lr,
                momentum = 0.9,  # assumed value: Nesterov requires momentum > 0
                weight_decay = weight_decay,
                nesterov = True
            )
        else:
            raise ValueError('Unknown optimizer')
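
The freezing logic above counts parameter tensors rather than layers, so the weight and the bias of a layer count separately. The sketch below illustrates this on a small stand-in model, so that no pretrained weights need to be downloaded. The helper <code>freeze_all_but_last</code> and the toy network are assumptions made for illustration.

```python
import torch

def freeze_all_but_last(model: torch.nn.Module, n_trained: int) -> int:
    """Keep only the last n_trained parameter tensors trainable."""
    params = list(model.parameters())
    for i, p in enumerate(params):
        n_after = len(params) - i - 1   # tensors that come after this one
        p.requires_grad = n_after < n_trained
    return sum(p.requires_grad for p in params)

# Toy model with four parameter tensors: 0.weight, 0.bias, 2.weight, 2.bias
toy = torch.nn.Sequential(
    torch.nn.Linear(8, 4),
    torch.nn.ReLU(),
    torch.nn.Linear(4, 2),
)
n_trainable = freeze_all_but_last(toy, n_trained=2)
trainable_names = [name for name, p in toy.named_parameters() if p.requires_grad]
print(n_trainable, trainable_names)
```

With <code>n_trained=2</code> only the weight and bias of the last linear layer stay trainable, which mirrors training just the replaced <code>fc</code> layer of the ResNet.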
        

## Model Trainer

In [12]:
class ModelTrainer:
    
    def __init__(self, model: FeatureExtractor, n_classes: int, epochs: int):
        self.model = model 
        self.curr_epoch = 0
        self.tot_epochs = epochs
        self.n_classes = n_classes
        # Task classifier losses
        self.src_task_class_loss = SplitCrossEntropyLoss(n_classes=n_classes, source=True, split_first=True).cuda()
        self.tgt_task_class_loss = SplitCrossEntropyLoss(n_classes=n_classes, source=False, split_first=True).cuda()
        # Domain discrimination losses
        self.src_dom_discrim_loss = DomainDiscriminationLoss(n_classes=n_classes, source=True).cuda()
        self.tgt_dom_discrim_loss = DomainDiscriminationLoss(n_classes=n_classes, source=False).cuda()
        # Category-level confusion losses
        self.src_cat_conf_loss = SplitCrossEntropyLoss(n_classes=n_classes, source=True, split_first=False).cuda()
        self.tgt_cat_conf_loss = SplitCrossEntropyLoss(n_classes=n_classes, source=False, split_first=False).cuda()
        # Domain-level confusion losses
        self.src_dom_conf_loss = DomainDiscriminationLoss(n_classes=n_classes, source=True).cuda()
        self.tgt_dom_conf_loss = DomainDiscriminationLoss(n_classes=n_classes, source=False).cuda()
        # Entropy minimization loss
        self.tgt_entropy_loss = EntropyMinimizationLoss(n_classes=n_classes).cuda()
        
        
    def train_one_epoch(self, source_dataloader, target_dataloader):
        self.curr_epoch += 1
        # Train for current epoch: iterate over paired source/target batches
        # until the shorter of the two loaders is exhausted
        for (X_source, y_source), (X_target, _) in zip(source_dataloader, target_dataloader):
            
            # Put the model in training mode
            self.model.model.train()
            
            # Convert to torch.autograd variables
            X_source_var = Variable(X_source) 
            y_source_var = Variable(y_source)
            X_target_var = Variable(X_target)
            
            # Compute features for both inputs
            X_source_features = self.model.model(X_source_var)
            X_target_features = self.model.model(X_target_var)
            
            # Compute overall training objective losses
            classifier_loss, generator_loss = self.overall_losses(
                X_source_features, 
                X_target_features, 
                y_source_var
            )
            
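            params = list(self.model.model.parameters())
            # Number of parameter tensors that belong to G, i.e. all but the final
            # fc layer (159 for torchvision's ResNet50, 60 for ResNet18)
            n_backbone = len(params) - len(list(self.model.model.fc.parameters()))
            
            # Gradients of the classifier objective (the graph is kept for the second pass)
            self.model.optim.zero_grad()
            classifier_loss.backward(retain_graph=True)  
            grad_classifier_tmp = [None if p.grad is None else p.grad.detach().clone() for p in params]
            
            # Gradients of the feature extractor objective
            self.model.optim.zero_grad()
            generator_loss.backward()
            grad_generator_tmp = [None if p.grad is None else p.grad.detach().clone() for p in params]
            
            # G takes the generator gradients, the final fc layer the classifier ones;
            # frozen parameters have no gradient and are skipped
            for count, p in enumerate(params):
                grad_tmp = grad_generator_tmp[count] if count < n_backbone else grad_classifier_tmp[count]
                if grad_tmp is not None:
                    p.grad = grad_tmp
            self.model.optim.step()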

    def overall_losses(self, X_source_features, X_target_features, y_source_var):
        # Source task classifier loss
        self.src_task_class_loss.y_labels = y_source_var
        _src_task_class_loss = self.src_task_class_loss(X_source_features)
        
        # (Cross-domain) Target task classifier loss
        self.tgt_task_class_loss.y_labels = y_source_var
        _tgt_task_class_loss = self.tgt_task_class_loss(X_source_features)
        
        # Domain discrimination loss
        _src_dom_discrim_loss = self.src_dom_discrim_loss(X_source_features)
        _tgt_dom_discrim_loss = self.tgt_dom_discrim_loss(X_target_features)
        _domain_discrim_loss = TrainingObjectives.domain_discrimination_loss(
            _src_dom_discrim_loss, 
            _tgt_dom_discrim_loss
        )
        
        # Category-level confusion loss
        self.src_cat_conf_loss.y_labels = y_source_var
        self.tgt_cat_conf_loss.y_labels = y_source_var
        _src_cat_conf_loss = self.src_cat_conf_loss(X_source_features)
        _tgt_cat_conf_loss = self.tgt_cat_conf_loss(X_source_features)
        _category_conf_loss = TrainingObjectives.category_confusion_loss(
            _src_cat_conf_loss, 
            _tgt_cat_conf_loss
        )
        
        # Domain-level confusion loss
        _src_dom_conf_loss = self.src_dom_conf_loss(X_target_features)
        _tgt_dom_conf_loss = self.tgt_dom_conf_loss(X_target_features)
        _domain_conf_loss = TrainingObjectives.domain_confusion_loss(
            _src_dom_conf_loss, 
            _tgt_dom_conf_loss
        )

        # Entropy minimization loss
        _tgt_entropy_loss = self.tgt_entropy_loss(X_target_features)
        
        # Overall classifier loss
        _overall_classifier_loss = TrainingObjectives.overall_classifier_loss(
            _src_task_class_loss, 
            _tgt_task_class_loss, 
            _domain_discrim_loss
        )

        # Overall feature extractor loss
        _overall_generator_loss = TrainingObjectives.overall_generator_loss(
            _category_conf_loss, 
            _domain_conf_loss, 
            _tgt_entropy_loss, 
            self.curr_epoch, 
            self.tot_epochs
        )
        
        return _overall_classifier_loss, _overall_generator_loss
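
The gradient routing performed in <code>train_one_epoch</code> can be sketched on a toy model: two objectives are differentiated over the same graph, and each group of parameters receives the gradient of only one of them. The modules <code>backbone</code> and <code>head</code> and the two stand-in losses are hypothetical. <code>torch.autograd.grad</code> is used here instead of two <code>backward</code> calls, which is a swapped-in alternative rather than the approach taken above.

```python
import torch

torch.manual_seed(0)
backbone = torch.nn.Linear(4, 3)   # stands in for the feature extractor G
head = torch.nn.Linear(3, 2)       # stands in for the final fc classifiers
params = list(backbone.parameters()) + list(head.parameters())
n_backbone = len(list(backbone.parameters()))

out = head(backbone(torch.randn(5, 4)))
loss_head = out.pow(2).mean()      # stand-in for the classifier objective
loss_backbone = out.abs().mean()   # stand-in for the generator objective

# Differentiate both objectives; keep the graph alive for the second call
grads_head = torch.autograd.grad(loss_head, params, retain_graph=True)
grads_backbone = torch.autograd.grad(loss_backbone, params)

# Route each gradient to its own parameter group, then take one step
for i, p in enumerate(params):
    p.grad = (grads_backbone[i] if i < n_backbone else grads_head[i]).clone()
torch.optim.SGD(params, lr=0.1).step()
```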

See the [backward specification](https://stackoverflow.com/questions/46774641/what-does-the-parameter-retain-graph-mean-in-the-variables-backward-method) for the meaning of <code>retain_graph=True</code> in the first <code>backward</code> call.
