<a href="https://colab.research.google.com/github/jiahfong/incoherent-thoughts/blob/develop/Consistency_based_SSAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Problem setting
In conventional pool-based AL methods, unlabelled data is mainly used by the selection mechanism and rarely for model training.

# 2. Contributions
* Exploit both labelled and unlabelled data by using AL + SSL
* Propose a new selection metric that chooses points in a way that is consistent with the SSL training objective, so that the selected points are effective at improving model performance.
* Study an important question: when can we start AL? (cold start, burn-in)
    - propose a measure that is empirically correlated with the AL loss
    - in the absence of labelled data, a subset is chosen to be manually labelled, thereby initiating the AL training cycle. However, if the subset is too small, the models in subsequent AL cycles are highly skewed and result in *biased selections*, a phenomenon known as the *cold-start problem*.
    - the trivial solution is to simply increase the size of this subset, but that implies increasing the labelling budget.
    - using a better understanding of the data, the authors propose a method that alleviates the cold-start problem whilst minimising the size of the initial subset.
    - nonetheless, one still has to choose the start size, and determining an initial subset size that avoids cold start is not trivial. The authors propose a measure that has been empirically shown to help estimate a suitable start size.

# 3. Consistency-based semi-supervised active learning

A model should be consistent in its predictions about a sample and its meaningful distortions, i.e. if a sample and its distortions yield inconsistent predictions, the sample should be acquired for human labelling.

1. for i = 1 $\dots$ $T$:
2. &nbsp;&nbsp;&nbsp;&nbsp; train model $\mathcal{M}$ on $\mathcal{D}_{train}$ (labelled) and $\mathcal{D}_{pool}$ (unlabelled) with objective function $\mathcal{L}$
3. &nbsp;&nbsp;&nbsp;&nbsp; select the $b$ highest-scoring points $B = \{x_1, \dots, x_b\} \subset \mathcal{D}_{pool}$ under the scoring function, i.e. $B = \text{argmax}_{B \subset \mathcal{D}_{pool},\ |B| = b}\ \mathcal{C}(B, \mathcal{M})$
4. &nbsp;&nbsp;&nbsp;&nbsp; query labels for $B$ and set $\mathcal{D}_{train} = \mathcal{D}_{train} \cup B$
5. &nbsp;&nbsp;&nbsp;&nbsp; $\mathcal{D}_{pool} = \mathcal{D}_{pool} \setminus B$
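The loop above can be sketched in Python as follows. This is a minimal sketch, not the authors' code: `train`, `score` and `oracle` are hypothetical callables standing in for SSL training with $\mathcal{L}$, the per-sample consistency score and the human annotator. The model object is carried from one iteration to the next, matching the warm start the authors use.

```python
from typing import Callable, List, Tuple


def active_learning_loop(
    train_set: List[Tuple[object, int]],
    pool: List[object],
    model: object,
    train: Callable,   # hypothetical: (model, labelled, unlabelled) -> trained model
    score: Callable,   # hypothetical: (x, model) -> per-sample score, higher = acquire
    oracle: Callable,  # hypothetical: x -> true label (stands in for a human annotator)
    T: int,
    b: int,
):
    for _ in range(T):
        # Step 2: warm-start from the previous model; train on labelled + unlabelled data.
        model = train(model, train_set, pool)
        # Step 3: score every pool point once and keep the indices of the top b.
        scores = [score(x, model) for x in pool]
        ranked = sorted(range(len(pool)), key=lambda i: scores[i], reverse=True)
        chosen = set(ranked[:b])
        batch = [pool[i] for i in sorted(chosen)]
        # Step 4: query labels for the batch and add it to the training set.
        train_set = train_set + [(x, oracle(x)) for x in batch]
        # Step 5: remove the acquired points from the pool.
        pool = [x for i, x in enumerate(pool) if i not in chosen]
    return model, train_set, pool
```

With toy callables, e.g. `score=lambda x, m: x`, the loop simply acquires the largest pool items first.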

* $T$ = number of iterations
* $\mathcal{L} = \frac{1}{|\mathcal{D}_{train}|}\sum_{(x,y)\in\mathcal{D}_{train}} \mathcal{L}_1(\mathcal{M}(x), y) + \frac{1}{|\mathcal{D}_{pool}|} \sum_{x \in \mathcal{D}_{pool}}\mathcal{L}_2(\mathcal{M}, x)$ where:
    - $\mathcal{L}_1$ is your regular loss (e.g. cross-entropy)
    - $\mathcal L_2$ could be either KL divergence ($D_{KL}\left(P(\hat Y | x, \mathcal{M})\ ||\ P(\hat Y | \tilde{x}, \mathcal{M})\right)$) or L2 norm ($||P(\hat Y | x, \mathcal{M}) - P(\hat Y | \tilde{x}, \mathcal{M})||_{2}$) where $\tilde{x}$ is a perturbation of $x$.
    - Note: in their experiments, the authors used MixMatch (Berthelot et al., 2019) as $\mathcal{L}$.
* $\mathcal{C}(B, \mathcal{M}) = \sum_{x \in B} \sum_{y = 1}^{C} \text{Var}\left[ P(\hat Y = y | x, \mathcal{M}), P(\hat Y = y | \tilde{x}_1, \mathcal{M}), \dots, P(\hat Y = y | \tilde{x}_N, \mathcal{M})\right]$ where $\tilde{x}_1, \dots, \tilde{x}_N$ are $N$ perturbations of $x$.
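The two choices of $\mathcal{L}_2$, a generic (non-MixMatch) form of $\mathcal{L}$, and the score $\mathcal{C}$ can be sketched with NumPy. Assumptions: softmax outputs are precomputed; `probs` has shape `(n, N + 1, C)` holding predictions for $x$ followed by its $N$ perturbations; Var is the population variance (`ddof=0`), which the notation above does not pin down. NumPy cannot backpropagate, so this only evaluates the quantities rather than training with them. All function names are hypothetical.

```python
import numpy as np

EPS = 1e-12  # guards log(0)


def kl_consistency(p_clean, p_pert):
    """L_2 as D_KL(P(Y|x) || P(Y|x~)), one value per row; inputs shaped (n, C)."""
    p = np.clip(p_clean, EPS, 1.0)
    q = np.clip(p_pert, EPS, 1.0)
    return np.sum(p * (np.log(p) - np.log(q)), axis=-1)


def l2_consistency(p_clean, p_pert):
    """L_2 as ||P(Y|x) - P(Y|x~)||_2, one value per row."""
    return np.linalg.norm(p_clean - p_pert, axis=-1)


def ssl_loss(p_lab, y, p_unlab, p_unlab_pert, consistency=kl_consistency):
    """Generic L: mean cross-entropy on D_train + mean consistency on D_pool (not MixMatch)."""
    ce = -np.mean(np.log(np.clip(p_lab[np.arange(len(y)), y], EPS, 1.0)))
    return ce + np.mean(consistency(p_unlab, p_unlab_pert))


def consistency_score(probs):
    """Per-sample term of C: variance over the N + 1 predictions, summed over classes."""
    return probs.var(axis=1).sum(axis=-1)


def select_top_b(probs, b):
    """Indices of the b pool points with the highest consistency score."""
    return np.argsort(-consistency_score(probs), kind="stable")[:b]
```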

Points to note:

1. To minimise $\mathcal{L}$, the model must be robust to perturbed inputs from $\mathcal{D}_{pool}$. Since $\mathcal{C}$ picks the $b$ points whose predictions are most inconsistent under perturbation, i.e. the points the consistency term penalises most, the acquisition function is *directly* aligned with lowering the total loss $\mathcal{L}$!
2. Taking the argmax in line 3 reduces to simply taking the top $b$ points: $\mathcal{C}(B, \mathcal{M})$ is a sum of independent per-sample scores, so no other batch of size $b$ can have a larger sum.
3. Perturbations here include standard augmentation techniques like random crops, horizontal flips, etc.
4. Perhaps interestingly, in each iteration the model $\mathcal{M}$ is warm-started from the trained weights of the previous iteration. This differs from BatchBALD and ICAL, which both re-initialise the model to prevent correlations between acquired batches.
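Point 2 can be checked numerically: a brute-force search over all size-$b$ batches reaches the same total score as simply taking the top $b$. A small sketch with hypothetical random scores; it relies on $\mathcal{C}$ being a sum of per-sample terms and would not hold for a batch-aware score such as BatchBALD's joint mutual information.

```python
from itertools import combinations

import numpy as np


def best_batch_bruteforce(scores, b):
    """Exhaustive argmax of sum(scores[i] for i in B) over all batches B with |B| = b."""
    return max(combinations(range(len(scores)), b), key=lambda B: sum(scores[i] for i in B))


def top_b(scores, b):
    """Shortcut: the indices of the b largest scores, in ascending index order."""
    return tuple(sorted(np.argsort(-np.asarray(scores))[:b].tolist()))


rng = np.random.default_rng(0)
for _ in range(100):
    s = rng.random(8)  # hypothetical per-sample consistency scores
    brute, greedy = best_batch_bruteforce(s, 3), top_b(s, 3)
    assert np.isclose(s[list(brute)].sum(), s[list(greedy)].sum())
```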

# 4. In practice

1. $\mathcal{L}$ is MixMatch, a recently proposed state-of-the-art SSL method
2. $\mathcal{M}_t$ is initialised to $\mathcal{M}_{t-1}$ (cf. BatchBALD & ICAL)
3. $N$ is usually 50, but the authors observed that 5 is usually enough

# 5. Results & Experiments

The results outlined here are specifically for CIFAR-10, although the authors ran experiments on both CIFAR-10 and CIFAR-100.

## 5.1 With existing AL acquisition functions w/o SSL

Their method significantly outperformed _Uniform_ (random acquisition), _Entropy_, and _k-centre_ (Sener & Savarese, 2018: acquires the points farthest from their nearest neighbour in the labelled pool, using the last layer as the embedding).
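For reference, the two non-random baselines can be sketched as follows. These are simplified illustrations, not the baselines' official implementations: `entropy_scores` ranks pool points by the entropy of their softmax output, and `k_centre_greedy` is the greedy (2-approximation) form of k-centre, run on embeddings assumed to be precomputed (e.g. last-layer features) with Euclidean distance.

```python
import numpy as np


def entropy_scores(probs):
    """Entropy of each row of softmax outputs shaped (n, C); higher = more uncertain."""
    p = np.clip(probs, 1e-12, 1.0)
    return -np.sum(p * np.log(p), axis=-1)


def k_centre_greedy(pool_emb, labelled_emb, b):
    """Greedily pick b pool points, each maximising its distance to the nearest centre.

    Centres start as the labelled embeddings; every acquired point becomes a centre too.
    """
    # Distance from each pool point to its nearest labelled point.
    diffs = pool_emb[:, None, :] - labelled_emb[None, :, :]
    nearest = np.linalg.norm(diffs, axis=-1).min(axis=1)
    chosen = []
    for _ in range(b):
        i = int(np.argmax(nearest))
        chosen.append(i)
        # The new centre can only shrink nearest-centre distances (its own becomes 0).
        nearest = np.minimum(nearest, np.linalg.norm(pool_emb - pool_emb[i], axis=-1))
    return chosen
```

On a toy pool `[[1, 0], [10, 0], [9, 0], [0, 5]]` with a single labelled point at the origin, the greedy picks `(10, 0)` and then `(0, 5)`, skipping `(9, 0)` because it is already close to an acquired centre.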

## 5.2 Integrating existing AL acquisition functions with SSL framework

Figure 1 in the paper shows that their method still outperforms all baselines when the three acquisition functions above are combined with consistency-based SSL, although the margin is much smaller than before (~2%).

> More importantly, this illustrates that simply introducing SSL into regular training gives significant improvements: from ~50% to ~88% accuracy for all three acquisition functions!

# 6. Discussion

The authors pointed out that the following attributes are known to be important when designing acquisition functions: diversity, uncertainty & confident mis-classification, and compliancy between acquired class distribution and per-class classification error.

## 6.1 Diversity

_k-centre_, for example, tries to acquire points that cover the entire input space. The authors visualised the entire unlabelled pool in a 2-D PCA plot and showed that the points acquired by _consistency_ (their method) are as diverse as those acquired by _k-centre_: they are spread out across the plot, indicating a diverse batch.

> Upshot: their method acquires a diverse batch, even when the selection criterion function $\mathcal{C}$ does not explicitly require it to do so. Meanwhile, _entropy_ seems to acquire similar points.

Interestingly (if I interpret it correctly), Figure 3 (left) shows the average pairwise distance between acquired points in a batch, computed with L2 distance. This is probably _not_ an ideal choice for images: semantically similar images can have large L2 distances (at least in pixel space), so showing that an acquisition function selects mutually distant points does not by itself imply diversity.
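A sketch of the batch-diversity statistic as I read Figure 3 (left), plus a toy case of the caveat above. Assumption: distances are taken on flattened inputs; the paper may use a different space. An 8×8 checkerboard shifted by one pixel is the same texture, yet it lies farther from the original in L2 than a flat grey image does.

```python
import numpy as np


def mean_pairwise_l2(batch):
    """Average L2 distance over all ordered pairs of distinct items in a batch."""
    x = batch.reshape(len(batch), -1).astype(float)
    d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    n = len(x)
    return d.sum() / (n * (n - 1))  # the diagonal is zero, so only distinct pairs count


checker = (np.indices((8, 8)).sum(axis=0) % 2).astype(float)
shifted = np.roll(checker, 1, axis=1)  # same pattern, moved one pixel to the right
grey = np.full((8, 8), 0.5)            # semantically unrelated flat image

print(np.linalg.norm(checker - shifted))  # 8.0: every pixel flips
print(np.linalg.norm(checker - grey))     # 4.0
```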

## 6.2 Uncertainty & confident mis-classification

The authors argued that raw softmax probabilities tend to be poorly calibrated: a model can be highly confident even when it is wrong. Since entropy-based methods rank points by these probabilities, they may pick suboptimal points, e.g. overlooking confident mistakes.

They showed that their method is better than _entropy_ at detecting highly confident misclassifications. They also showed that _k-centre_ and _uniform_ do not acquire points based on uncertainty at all (which makes sense, since neither looks at the model's predictive probabilities).

> Their method tends to select highly uncertain samples, but not necessarily the most uncertain ones (after all, it does not explicitly seek them).
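One way to quantify confident misclassification in an acquired batch is sketched below. It is a hypothetical analysis, not the paper's exact protocol: it needs ground-truth labels for the pool (available in benchmark runs such as CIFAR-10), and the 0.9 confidence threshold is an arbitrary choice. The toy case also shows why entropy struggles here: a confident mistake has low entropy, so entropy ranks it below a genuinely ambiguous sample.

```python
import numpy as np


def confident_mistakes(probs, labels, threshold=0.9):
    """Boolean mask: wrong top-1 prediction made with max-probability >= threshold."""
    return (probs.argmax(axis=1) != labels) & (probs.max(axis=1) >= threshold)


def fraction_confident_mistakes(probs, labels, acquired, threshold=0.9):
    """Share of an acquired batch (indices into the pool) that are confident mistakes."""
    return confident_mistakes(probs, labels, threshold)[acquired].mean()


def entropy(probs):
    p = np.clip(probs, 1e-12, 1.0)
    return -np.sum(p * np.log(p), axis=-1)


probs = np.array([[0.95, 0.05],   # confidently wrong (true class is 1)
                  [0.50, 0.50],   # ambiguous
                  [0.10, 0.90]])  # confidently right
labels = np.array([1, 0, 1])
```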



## 6.3 Compliancy between acquired class distribution and per-class classification error

This is an intuitive idea: the points an acquisition function should acquire are likely the points the model currently misclassifies. So if we plot the class distribution of the acquired points, it should roughly follow the distribution of misclassified target classes. After all, those points are causing our model to incur a higher loss.

> Samples acquired by _entropy_ and _consistency_ are correlated with per-class classification error whilst _k-centre_ and _uniform_ show less correlation.
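The compliance idea can be summarised by a single number, e.g. the Pearson correlation between the class histogram of the acquired batch and the per-class error rate on held-out data. This is a sketch under that assumption: the comparison above is between two distributions, and the exact statistic the paper reports may differ. `class_compliance` and the toy data are hypothetical.

```python
import numpy as np


def class_compliance(acquired_labels, val_labels, val_preds, num_classes):
    """Pearson correlation between acquired class frequencies and per-class error rates."""
    acquired_labels = np.asarray(acquired_labels)
    acquired_frac = np.bincount(acquired_labels, minlength=num_classes) / len(acquired_labels)
    per_class_err = np.array([
        np.mean(val_preds[val_labels == c] != c) if np.any(val_labels == c) else 0.0
        for c in range(num_classes)
    ])
    return np.corrcoef(acquired_frac, per_class_err)[0, 1]


val_labels = np.array([0, 0, 1, 1, 2, 2])
val_preds = np.array([0, 0, 1, 0, 0, 1])  # per-class error: 0.0, 0.5, 1.0
```

A positive value means the batch concentrates on the classes the model currently gets wrong. For an exactly flat histogram the correlation is undefined and NumPy returns `nan`.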

## 6.4 Summary of discussion

The claims above can be summarised as:

_consistency_ is similar to _entropy_ in that it acquires highly uncertain samples, and similar to _k-centre_ in that it acquires a diverse set of points, despite not being explicitly told to do so. Lastly, much like _entropy_, it intuitively picks samples from the target classes that the model most often misclassifies.

# TODO: When to start AL?


# TODO
1. The consistency-based SSL algorithm here doesn't actually assign labels to the unlabelled pool; it relies entirely on the complementary acquisition function to obtain true labels. (cf. Noisy Student SSL or self-training with pseudo-labels (P. Rhee et al., 2017))
2. Discussion (talk about desired properties of AL algorithms and how this fulfils most of them)
    - esp. diversity. They claim to acquire a diverse batch, but does this hold in repetitive pools? Given that BALD didn't fare so well there (cf. Repeated MNIST in the BatchBALD paper), this method likely suffers from the same problem, since the summation in $\mathcal{C}$ does not account for the other points acquired in the same batch.
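The concern about repetitive pools can be made concrete with a toy example. Because $\mathcal{C}$ scores each sample independently, exact duplicates receive identical scores, so top-$b$ acquisition takes every copy of a high-scoring point before anything else. The pool and scores below are hypothetical.

```python
import numpy as np


def top_b(scores, b):
    """Top-b acquisition on independent per-sample scores, as with C(B, M)."""
    return np.argsort(-np.asarray(scores), kind="stable")[:b]


# Hypothetical pool in which point "a" appears three times (e.g. a repeated dataset).
score_of = {"a": 0.9, "b": 0.8, "c": 0.3, "d": 0.1}
pool_ids = ["a", "a", "a", "b", "c", "d"]
scores = [score_of[p] for p in pool_ids]

chosen = top_b(scores, 3)
print([pool_ids[i] for i in chosen])  # ['a', 'a', 'a']: one distinct point for three labels
```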

> see their 2-moons plot: it's a nice illustration of why random acquisition works well initially (esp. when the initial training dataset is small)


# Closing remarks/thoughts

1. The authors mentioned that entropy-based methods using softmax probabilities are sub-optimal since the probabilities may be overconfident. Does this still hold in the Bayesian setting, e.g. with $p(y | x, \mathcal{D}_{train}) = \int p(y | x, \mathbf \omega)p(\mathbf\omega | \mathcal{D}_{train}) d\mathbf\omega$? Did the authors consider using this when calculating entropies?
2. Is L2 distance a meaningful measure of batch diversity for images? (See the caveat on Figure 3 (left) in 6.1.)
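For question 1, the Bayesian predictive distribution can be approximated by Monte Carlo, $p(y \mid x, \mathcal{D}_{train}) \approx \frac{1}{S}\sum_{s=1}^{S} p(y \mid x, \omega_s)$ with $\omega_s \sim p(\omega \mid \mathcal{D}_{train})$, where the samples might come from e.g. MC dropout or an ensemble (my assumption; the paper does not do this). The sketch shows that two confident but disagreeing samples yield a high predictive entropy, which a single softmax would hide. Whether this fully fixes overconfidence remains an open question.

```python
import numpy as np


def entropy(p, axis=-1):
    p = np.clip(p, 1e-12, 1.0)
    return -np.sum(p * np.log(p), axis=axis)


def predictive_entropy(mc_probs):
    """mc_probs: (S, n, C) softmax outputs under S posterior samples omega_s.

    Averages over samples to estimate p(y | x, D_train), then takes its entropy.
    """
    return entropy(mc_probs.mean(axis=0))


def bald(mc_probs):
    """Mutual information: predictive entropy minus the mean per-sample entropy."""
    return predictive_entropy(mc_probs) - entropy(mc_probs).mean(axis=0)


# Two hypothetical posterior samples, each confident, but disagreeing on one input.
mc_probs = np.array([[[0.99, 0.01]],
                     [[0.01, 0.99]]])
```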