# **DeiT (Data-Efficient Image Transformer)**

**Paper:** *Training data-efficient image transformers & distillation through attention* (Facebook AI Research; arXiv 2020, ICML 2021)

---

## Motivation

**Vision Transformer (ViT)** achieved strong results but required **huge datasets** (ImageNet-21k, JFT-300M) and **massive compute**.
**DeiT** made ViT trainable **on ImageNet-1k (1.3 M images)** with *no extra data*, through a **strong data-augmentation and regularization recipe** and **distillation through attention**.

---

## ViT Recap

1. **Patch embedding**

   Split an image of size $H\times W\times C$ into non-overlapping $P\times P$ patches; the number of patches is
   $$
   N=\frac{H\times W}{P^2}
   $$
   Each patch is flattened and linearly projected to a $D$-dimensional vector, giving $\mathbf X_p\in\mathbb R^{N\times D}$.

2. **Add tokens and positional embeddings**
   $$
   Z_0=[x_{\text{cls}},x_p^1,\dots,x_p^N]+E_{\text{pos}}
   $$

3. **Transformer encoder**

   $L$ blocks of MSA + FFN + residuals + layer norm.

4. **Classification head** uses final `[CLS]`.
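
The first two recap steps can be sketched in a few lines of PyTorch. This is a minimal illustration, not the reference implementation: the sizes (224×224 input, $P=16$, $D=192$) and the variable names are assumptions chosen for the example, and the strided convolution is one common way to implement "split into patches, then project".

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions): 224x224 RGB input, 16x16 patches, D = 192.
H = W = 224
C, P, D = 3, 16, 192
num_patches = (H * W) // (P * P)  # N = HW / P^2

# A convolution with kernel size = stride = P splits the image into
# non-overlapping patches and linearly projects each one to D dimensions.
patch_embed = nn.Conv2d(C, D, kernel_size=P, stride=P)
cls_token = nn.Parameter(torch.zeros(1, 1, D))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, D))

x = torch.randn(2, C, H, W)                           # batch of 2 images
patches = patch_embed(x).flatten(2).transpose(1, 2)   # (2, N, D)

# Z_0 = [x_cls, x_p^1, ..., x_p^N] + E_pos
z0 = torch.cat([cls_token.expand(2, -1, -1), patches], dim=1) + pos_embed

print(num_patches, tuple(patches.shape), tuple(z0.shape))
```

With these sizes the sequence has $N+1 = 197$ tokens: 196 patch tokens plus `[CLS]`.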

---

## DeiT Architecture

Adds one more token — **[DIST]**:

$$
Z_0=[x_{\text{cls}},x_{\text{dist}},x_p^1,\dots,x_p^N]+E_{\text{pos}}
$$

Both `[CLS]` and `[DIST]` participate equally in attention through all layers.

### Output heads

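After the final block:
$$
Z_L=[z_{\text{cls}},z_{\text{dist}},z_p^1,\dots,z_p^N]
$$
Two separate linear heads map $z_{\text{cls}}$ and $z_{\text{dist}}$ to the class logits $Z_{\text{cls}}$ and $Z_{\text{dist}}$ used in the training loss.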



<img src="images/deit.png" height="30%" width="30%" />


## Training loss

$$
\mathcal{L}_{\text{DeiT}} = (1-\lambda)\mathcal{L}_{\text{CE}}\big(\sigma(Z_{\text{cls}}), y\big) + \lambda\tau^2\mathrm{KL}\left( \sigma\left(\frac{Z_{\text{dist}}}{\tau}\right), \sigma\left(\frac{Z_t}{\tau}\right) \right)
$$

---


| Symbol               | Meaning                                                         |
| :------------------- | :-------------------------------------------------------------- |
| $ Z_{\text{cls}} $   | logits from the **student’s `[CLS]` head**.                     |
| $ Z_{\text{dist}} $  | logits from the **student’s `[DIST]` head**.                    |
| $ Z_t $              | Teacher logits (from pretrained CNN like RegNetY)               |
| $ y $                | Ground-truth class label                                        |
| $ \sigma(\cdot) $    | Softmax function                                                |
| $ \mathcal{L}_{CE} $ | Cross-entropy loss with ground truth                            |
| $ KL(\cdot,\cdot) $  | Kullback–Leibler divergence (between probability distributions) |
| $ \tau $             | **Temperature** to soften the logits                            |
| $ \lambda $          | Balancing factor between supervised and distillation losses     |


Written per token, with the weights kept explicit, the overall DeiT loss is:

$$
\mathcal{L}_{\text{DeiT}} =
(1-\lambda)\,\mathcal{L}_{CE}^{[\text{CLS}]} +
\lambda\tau^2\,\mathrm{KL}^{[\text{DIST}]}
$$

---
#### Intuitive Meaning

The student (DeiT) is trained to satisfy **two goals simultaneously**:

1. **Match the true labels**
   → via the standard cross-entropy term

   $$
   (1-\lambda)\mathcal{L}_{\text{CE}}\big(\sigma(Z_{\text{cls}}), y\big)
   $$

2. **Mimic the teacher’s “dark knowledge”** (soft class probabilities)
   → via the KL divergence term
   $$
   \lambda\tau^2\mathrm{KL}\left( \sigma\left(\frac{Z_{\text{dist}}}{\tau}\right), \sigma\left(\frac{Z_t}{\tau}\right) \right)
   $$

The **teacher’s predictions** (even for incorrect classes) contain valuable information about class similarity — e.g., a cat image might get 0.7 cat, 0.2 dog, 0.1 fox.
These *soft targets* help the student generalize better than one-hot labels.

---

#### The Role of **Temperature $ \tau $**

* When $ \tau > 1 $, the softmax becomes **softer** — it spreads probability mass across classes.
* This reveals **relative similarities** between classes.

$$
\sigma_i(Z / \tau) = \frac{e^{Z_i / \tau}}{\sum_j e^{Z_j / \tau}}
$$

Typical values: $ \tau \in [2, 5] $. The DeiT paper reports $ \tau = 3.0 $ for its soft-distillation runs.

During training:

* Compute both soft teacher and soft student distributions at temperature $ \tau $.
* Multiply the KL term by $ \tau^2 $ (to keep gradient magnitudes consistent).
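
The softening effect is easy to see numerically. The sketch below applies the temperature softmax to a made-up logit vector (the three values and the class names in the comment are purely illustrative) and prints the distribution for a few temperatures.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for three classes, e.g. (cat, dog, fox).
logits = torch.tensor([4.0, 2.0, 1.0])

# Higher tau flattens the distribution while keeping the class ranking.
for tau in (1.0, 3.0, 10.0):
    probs = F.softmax(logits / tau, dim=0)
    print(f"tau={tau:>4}: {[round(p, 3) for p in probs.tolist()]}")

p_sharp = F.softmax(logits / 1.0, dim=0)   # tau = 1: peaked
p_soft = F.softmax(logits / 3.0, dim=0)    # tau = 3: softer
```

The top class stays the same at every temperature; only the amount of probability mass assigned to the other classes changes, which is what exposes the teacher's view of class similarity.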

---

#### Role of **Balancing factor $ \lambda $**

Controls the tradeoff between:

* Fitting to **ground truth** (hard labels)
* Mimicking the **teacher** (soft labels)

Typical choice: $ \lambda = 0.5 $. The DeiT paper reports $ \lambda = 0.1 $ for its soft-distillation runs.

---



#### PyTorch Implementation Example

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Soft-distillation loss for a single student head (T is tau, alpha is lambda)."""
    # Hard-label loss against the ground truth
    ce_loss = F.cross_entropy(student_logits, labels)

    # Soft-label loss: F.kl_div takes student log-probabilities and teacher probabilities
    log_p_s = F.log_softmax(student_logits / T, dim=1)
    p_t = F.softmax(teacher_logits / T, dim=1)
    kd_loss = F.kl_div(log_p_s, p_t, reduction='batchmean') * (T * T)

    # Weighted sum: (1 - lambda) * CE + lambda * tau^2 * KL
    return (1 - alpha) * ce_loss + alpha * kd_loss
```
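
One detail of the function above is easy to get wrong: `F.kl_div` expects **log-probabilities** as its first argument (the student) and **probabilities** as its second (the teacher). The quantity it computes is $\sum_i p_t[i]\,(\log p_t[i] - \log p_s[i])$, averaged over the batch with `reduction='batchmean'`. The check below compares the library call with that formula written out by hand; the random logits and the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(4, 10)   # arbitrary batch of 4, 10 classes
teacher_logits = torch.randn(4, 10)
T = 2.0

log_p_s = F.log_softmax(student_logits / T, dim=1)   # student log-probabilities
p_t = F.softmax(teacher_logits / T, dim=1)           # teacher probabilities

# Library call: input = student log-probs, target = teacher probs.
kl_lib = F.kl_div(log_p_s, p_t, reduction='batchmean')

# Same quantity by hand: sum over classes, mean over the batch.
kl_manual = (p_t * (p_t.log() - log_p_s)).sum(dim=1).mean()

print(kl_lib.item(), kl_manual.item())
```

Passing plain probabilities as the first argument, or swapping the two arguments, silently produces a different number, so a check like this is worth running once.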

---




### Inference

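With $p_{\text{cls}}=\sigma(Z_{\text{cls}})$ and $p_{\text{dist}}=\sigma(Z_{\text{dist}})$, predict from the `[CLS]` head alone or from the average of both heads:
$$
\hat y=\arg\max_k\, p_{\text{cls}}[k]\quad\text{or}\quad
\hat y=\arg\max_k\, \tfrac12\big(p_{\text{cls}}[k]+p_{\text{dist}}[k]\big)
$$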
Averaging usually gives slightly higher accuracy.
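
A minimal sketch of the averaged prediction, assuming the two heads' logits are already available as tensors. The helper name `fuse_predictions` and the toy logits are made up for illustration.

```python
import torch
import torch.nn.functional as F

def fuse_predictions(cls_logits, dist_logits):
    """Average the softmax outputs of both heads and take the argmax."""
    p_cls = F.softmax(cls_logits, dim=1)
    p_dist = F.softmax(dist_logits, dim=1)
    p_avg = 0.5 * (p_cls + p_dist)
    return p_avg.argmax(dim=1), p_avg

# Toy logits for 2 images and 3 classes (illustrative values only).
cls_logits = torch.tensor([[2.0, 1.0, 0.0], [0.0, 0.5, 0.4]])
dist_logits = torch.tensor([[1.5, 1.0, 0.2], [0.0, 0.1, 2.0]])

pred, p_avg = fuse_predictions(cls_logits, dist_logits)
print(pred.tolist())
```

In the second toy example the `[CLS]` head alone would pick class 1, but the `[DIST]` head is confident about class 2, so the fused prediction changes. This is the mechanism by which the two heads can complement each other.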


---

## Textual Architecture Diagram

```
Image (224×224×3)
   ↓
Patch split (16×16) → 196 patches
   ↓
Linear projection → Patch embeddings (196×D)
   ↓
Add [CLS] and [DIST] tokens
   ↓
Add positional embeddings
   ↓
Transformer encoder (12 layers)
   ↓
Outputs:
   [CLS] → Classification head → CE loss
   [DIST] → Distillation head → KD loss
   ↓
Combined loss → update weights
```
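
The pipeline in the diagram can be sketched as a small PyTorch module. This is a toy version under stated assumptions, not the DeiT implementation: the class name `TinyDistilledViT`, the width (64), and the depth (2) are made up to keep it fast, and the encoder uses PyTorch's generic `nn.TransformerEncoderLayer` rather than the blocks from the paper's code.

```python
import torch
import torch.nn as nn

class TinyDistilledViT(nn.Module):
    """Toy two-token ViT; sizes are illustrative, not a DeiT configuration."""

    def __init__(self, img_size=224, patch=16, dim=64, depth=2, heads=4, num_classes=1000):
        super().__init__()
        n = (img_size // patch) ** 2                       # 196 patches for 224/16
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.dist_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, n + 2, dim) * 0.02)
        layer = nn.TransformerEncoderLayer(
            dim, heads, dim * 4, activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, depth)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)            # classification head ([CLS])
        self.head_dist = nn.Linear(dim, num_classes)       # distillation head ([DIST])

    def forward(self, x):
        b = x.shape[0]
        x = self.patch_embed(x).flatten(2).transpose(1, 2)           # (B, N, D)
        tokens = torch.cat(
            [self.cls_token.expand(b, -1, -1), self.dist_token.expand(b, -1, -1), x], dim=1)
        z = self.norm(self.encoder(tokens + self.pos_embed))         # (B, N + 2, D)
        return self.head(z[:, 0]), self.head_dist(z[:, 1])

model = TinyDistilledViT()
cls_logits, dist_logits = model(torch.randn(2, 3, 224, 224))
print(cls_logits.shape, dist_logits.shape)
```

Both special tokens go through the same attention layers as the patch tokens; only the final linear heads are separate.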

---




## **Choosing Architectures For Teacher and Student**


DeiT’s teacher–student setup is designed as follows:

| Role        | Model Type                   | Example Architecture         | Training                                     |
| ----------- | ---------------------------- | ---------------------------- | -------------------------------------------- |
| **Teacher** | **CNN**                      | **RegNetY-16GF**             | Pretrained on ImageNet-1k                    |
| **Student** | **Vision Transformer (ViT)** | **DeiT-Tiny / Small / Base** | Trained from scratch on ImageNet-1k using KD |

So DeiT uses a **CNN teacher** and a **Transformer student**.

That’s important because it transfers **CNN inductive biases** (local spatial patterns, hierarchical representations) into the ViT-like student, which otherwise lacks them.

---

#### Teacher Architecture — **RegNetY**

**RegNetY** (from Facebook AI, 2020) is a family of efficient convolutional networks.

**Why RegNetY?**

* Stable training on ImageNet.
* Scalable family of models (RegNetY-4GF, 8GF, 16GF, etc.).
* Excellent accuracy–efficiency tradeoff.
* Provides good *soft targets* (probability distributions) for distillation.

**Teacher setup in DeiT paper:**

* RegNetY-16GF trained on ImageNet-1k.
* ImageNet-1k has `1,000` classes, `1,281,167` training images, `50,000` validation images, and `100,000` test images.
* Used only to produce logits $ Z_t $ for distillation — not updated during DeiT training.

---

#### Student Architecture — **Vision Transformer (ViT-like)**

DeiT’s student is **architecturally a ViT** with two small additions: the `[DIST]` token and its own linear head.

Common structure:

| Stage                   | Description                                                 | DeiT implementation |
| ----------------------- | ----------------------------------------------------------- | ------------------- |
| **Patch embedding**     | Split image into $16 \times 16$ patches → linear projection | same as ViT         |
| **Tokens**              | Add `[CLS]` + `[DIST]` + patch tokens                       | DeiT modification   |
| **Positional encoding** | Learnable 1D positional embedding                           | same as ViT         |
| **Encoder**             | 12 Transformer blocks (MHSA + MLP + LayerNorm)              | same as ViT         |
| **Heads**               | Two linear heads: classification and distillation           | DeiT modification   |

---

#### DeiT Model Variants

| Model      | Layers | Hidden dim (D) | MLP dim | Heads | Params | Patch size |
| ---------- | ------ | -------------- | ------- | ----- | ------ | ---------- |
| DeiT-Tiny  | 12     | 192            | 768     | 3     | 5M     | 16         |
| DeiT-Small | 12     | 384            | 1536    | 6     | 22M    | 16         |
| DeiT-Base  | 12     | 768            | 3072    | 12    | 86M    | 16         |

All trained **from scratch** on ImageNet-1k.
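
The parameter counts in the table can be roughly reproduced from the hidden dimension alone. The helper below is a back-of-the-envelope sketch that assumes a standard ViT layout (biases on all linear layers, MLP ratio 4, learnable positional embeddings, one final LayerNorm, 1000 classes); the function name `vit_param_count` is made up for this example.

```python
def vit_param_count(dim, depth=12, patch=16, img_size=224, in_chans=3,
                    mlp_ratio=4, num_classes=1000, distilled=False):
    """Approximate parameter count of a ViT/DeiT under a standard layout."""
    num_patches = (img_size // patch) ** 2
    extra_tokens = 2 if distilled else 1                  # [CLS] (+ [DIST])

    patch_embed = in_chans * patch * patch * dim + dim    # projection weight + bias
    tokens = extra_tokens * dim
    pos_embed = (num_patches + extra_tokens) * dim

    attn = (3 * dim * dim + 3 * dim) + (dim * dim + dim)  # qkv + output projection
    hidden = mlp_ratio * dim
    mlp = (dim * hidden + hidden) + (hidden * dim + dim)  # two linear layers
    norms = 2 * 2 * dim                                   # two LayerNorms per block
    block = attn + mlp + norms

    final_norm = 2 * dim
    heads = extra_tokens * (dim * num_classes + num_classes)
    return patch_embed + tokens + pos_embed + depth * block + final_norm + heads

for name, dim in [("DeiT-Tiny", 192), ("DeiT-Small", 384), ("DeiT-Base", 768)]:
    print(f"{name}: {vit_param_count(dim) / 1e6:.1f}M")
```

The results land within about 1M of the rounded figures in the table. Passing `distilled=True` adds the `[DIST]` token, one extra positional entry, and the second head, which is under 1M extra parameters even for DeiT-Base.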

---

#### Teacher–Student Interaction Diagram

```
             ┌─────────────────────────┐
             │   Teacher: RegNetY-16GF │
             └──────────┬──────────────┘
                        │
                 soft targets (Z_t)
                        │
────────────────────────────────────────────
                        │
             ┌─────────────────────────────┐
             │  Student: DeiT Transformer  │
             │  (with [CLS] and [DIST])    │
             └─────────────────────────────┘
                        │
           ┌────────────┴────────────┐
           │                         │
   [CLS] Head (CrossEntropy)   [DIST] Head (KD Loss with Z_t)
           │                         │
           └────────────┬────────────┘
                        ↓
         Total Loss = (1-λ)·CE + λ·τ²·KD
```

---

#### Why This Combination Works

| Aspect                          | CNN (Teacher)                                    | Transformer (Student)                   |
| ------------------------------- | ------------------------------------------------ | --------------------------------------- |
| **Inductive bias**              | Strong (local filters, translation equivariance) | Weak (global attention)                 |
| **Data efficiency**             | High                                             | Low                                     |
| **Training stability**          | Very stable                                      | Needs large data                        |
| **Knowledge distillation role** | Provides local, structured soft targets          | Learns global features more efficiently |

So the CNN acts as a **structural prior** for the Transformer, teaching it to focus on *semantically meaningful local regions* even with limited data.

---


#### Load Teacher (CNN) and Student (DeiT)

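Teacher: pretrained CNN (frozen). The snippets in this section assume `timm` is installed and that `torch` and `F` are imported as in the loss example above.

```python
from timm import create_model  # timm's model factory
teacher = create_model('regnety_160.pycls_in1k', pretrained=True)
```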

Alternative Models:

If you want different capacity levels, you can use:

| Model Name | FLOPs | Parameters |
|------------|-------|------------|
| `regnety_016.pycls_in1k` | 1.6 GF | ~11M |
| `regnety_032.pycls_in1k` | 3.2 GF | ~20M |
| `regnety_080.pycls_in1k` | 8.0 GF | ~39M |
| `regnety_160.pycls_in1k` | **16 GF** | **~84M** (used in this example) |
| `regnety_320.pycls_in1k` | 32 GF | ~145M |

---

```python
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

# Student: Vision Transformer with [CLS] + [DIST] tokens
student = create_model(
    'deit_base_distilled_patch16_224.fb_in1k', pretrained=False)
# Enable distillation mode to return (cls_logits, dist_logits)
student.distilled_training = True
student.train()  # Must be in training mode to get both cls and dist logits

# Example input (batch of 4 images, 3×224×224)
x = torch.randn(4, 3, 224, 224)
# Random ground-truth labels: ImageNet-1k has 1,000 classes, so indices run 0..999; shape is (4,)
y = torch.randint(0, 1000, (4,))

# --------------------------------------------
# 2. Forward Pass
# --------------------------------------------
with torch.no_grad():
    teacher_logits = teacher(x)   # [B, num_classes]

print(teacher_logits)

# Student forward:
# timm's DeiT returns:
#   - if distilled: a tuple (cls_logits, dist_logits)
#   - if not distilled: a single tensor
out = student(x)
```

`student(x)` returns either a tuple or a single tensor, so the script needs a **fallback** that handles distilled and non-distilled models with the same code. The key line is:

```python
    cls_logits = dist_logits = out  # fallback (non-distilled)
```

#### Why This Line Exists?

DeiT Model Outputs:

**Distilled DeiT models** (like `deit_base_distilled_patch16_224`) have two special tokens:
- **[CLS] token**: Standard classification token → produces `cls_logits`
- **[DIST] token**: Distillation token → produces `dist_logits`

These models return a **tuple** `(cls_logits, dist_logits)` when distillation mode is enabled and the model is in training mode, as set up above.

**Non-distilled models** (like regular `deit_base_patch16_224`) only have:
- **[CLS] token**: Standard classification token

These models return a **single tensor**

**The Problem**

The loss function (`deit_distillation_loss`, defined in step 3 below) expects both sets of logits:

```python
loss = deit_distillation_loss(cls_logits, dist_logits, teacher_logits, y)
```

**The Solution**

The code handles both cases:

```python
if isinstance(out, tuple):
    cls_logits, dist_logits = out
else:
    cls_logits = dist_logits = out  # fallback (non-distilled)
```

In the fallback branch:
- `cls_logits` = the model's output
- `dist_logits` = the model's output (same tensor)

This way, if you accidentally use a non-distilled model, the code still runs. Both losses (CE loss and KD loss) are then computed from the same logits, so no separate distillation head is being trained.

```python
if isinstance(out, tuple):
    print("distilled")
    cls_logits, dist_logits = out
else:
    print("non-distilled")
    cls_logits = dist_logits = out  # fallback (non-distilled)
```


```python
# --------------------------------------------
# 3. Define Distillation Loss
# --------------------------------------------


def deit_distillation_loss(cls_logits, dist_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """
    DeiT-style distillation loss
    """
    # Cross-entropy with ground truth (for [CLS] token)
    ce_loss = F.cross_entropy(cls_logits, labels)

    # KL divergence with teacher soft targets (for [DIST] token)
    p_s = F.log_softmax(dist_logits / T, dim=1)
    p_t = F.softmax(teacher_logits / T, dim=1)
    kd_loss = F.kl_div(p_s, p_t, reduction='batchmean') * (T * T)

    # Weighted combination
    return (1 - alpha) * ce_loss + alpha * kd_loss


# --------------------------------------------
# 4. Compute total loss
# --------------------------------------------
loss = deit_distillation_loss(cls_logits, dist_logits, teacher_logits, y)
print(f"Total training loss: {loss.item():.4f}")

# --------------------------------------------
# 5. Backpropagation
# --------------------------------------------
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```