In [14]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm

# Global lists to store metrics for plotting
training_metrics = {}

def prepare_dataloaders(
    path, batch_size=32, num_workers=2, val_split=0.2, test_split=0.1, augment=False
):
    """
    Prepares train, validation, and test dataloaders from an ImageFolder dataset.
    Supports optional data augmentation for the training set.

    Args:
        path (str): Root directory path to the dataset.
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of subprocesses for data loading.
        val_split (float): Fraction of data to use for validation.
        test_split (float): Fraction of data to use for testing.
        augment (bool): If True, apply augmentation to training dataset.

    Returns:
        tuple: (train_loader, val_loader, test_loader or None)
    """
    # Base transform (resize + normalize)
    base_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Augmented transform for training only
    aug_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load full dataset (with base transform first)
    full_dataset = datasets.ImageFolder(root=path, transform=base_transform)

    total_size = len(full_dataset)
    test_size = int(total_size * test_split)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size - test_size

    if test_split > 0:
        train_set, val_set, test_set = random_split(
            full_dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = DataLoader(
            test_set,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )
    else:
        train_set, val_set = random_split(
            full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = None

    # Apply augmentation only to training subset if requested
    if augment:
        train_set.dataset.transform = aug_transform

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    return train_loader, val_loader, test_loader


# -------------------------------
# Simple MTSD model
# -------------------------------


# ========== ECA and Backbone ==========
class ECALayer(nn.Module):
    def __init__(self, channels, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x).squeeze(-1).transpose(1, 2)
        y = self.conv(y).transpose(1, 2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Branch(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)


class MultiBranchNet(nn.Module):
    def __init__(self, input_channels=3, num_classes=5):
        super().__init__()
        self.blocks = nn.ModuleList([
            BottleneckBlock(input_channels, 64),
            BottleneckBlock(64, 128),
            BottleneckBlock(128, 256),
            BottleneckBlock(256, 512),
        ])
        self.att = nn.ModuleList([
            ECALayer(64),
            ECALayer(128),
            ECALayer(256),
            ECALayer(512),
        ])
        self.classifiers = nn.ModuleList([
            Branch(64, num_classes),
            Branch(128, num_classes),
            Branch(256, num_classes),
            Branch(512, num_classes),
        ])

    def forward(self, x):
        features, logits = [], []
        for i in range(4):
            x = self.blocks[i](x)
            x = self.att[i](x)
            features.append(x)
            logits.append(self.classifiers[i](x))
        return logits, features


def compute_asm(feature_map, target_size=None, max_hw=400):
    if target_size is not None:
        feature_map = F.interpolate(
            feature_map, size=target_size, mode="bilinear", align_corners=False
        )
    B, C, H, W = feature_map.size()
    if max_hw < H * W:
        feature_map = F.adaptive_avg_pool2d(
            feature_map, (int(max_hw**0.5), int(max_hw**0.5))
        )
    F_T = feature_map.view(B, C, -1)
    sim = torch.bmm(F_T.transpose(1, 2), F_T)
    norm = torch.norm(sim, dim=(1, 2), keepdim=True) + 1e-6
    return sim / norm


def compute_total_loss(preds, feats, labels, alpha=3, beta=0.3, gamma=3000, delta=None):
    ce = nn.CrossEntropyLoss()
    l1 = ce(preds[3], labels)
    l2 = sum(ce(preds[i], labels) for i in range(3))
    kl_loss, asm_loss = 0, 0
    target_size = feats[3].shape[2:]
    for i in range(3):
        pi = F.log_softmax(preds[i], dim=1)
        for j in range(i + 1, 4):
            pj = F.softmax(preds[j].detach(), dim=1)
            kl = F.kl_div(pi, pj, reduction="batchmean")
            asm_s = compute_asm(feats[i], target_size)
            asm_t = compute_asm(feats[j].detach(), target_size)
            if asm_s.shape != asm_t.shape:
                min_shape = list(map(min, asm_s.shape, asm_t.shape))
                asm_s = asm_s[:, : min_shape[1], : min_shape[2]]
                asm_t = asm_t[:, : min_shape[1], : min_shape[2]]
            asm = F.mse_loss(asm_s, asm_t)
            w = 1.0
            kl_loss += w * kl
            asm_loss += w * asm
    return alpha * l1 + (1 - beta) * l2 + beta * kl_loss + gamma * asm_loss


def train(
    model,
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=50,
    method_name="MTSD",
):
    model.to(device)

    # ---- QAT setup ----
    if qat:
        # Define QAT configuration
        model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
        # Prepare model for QAT
        torch.quantization.prepare_qat(model, inplace=True)
        print("[QAT] Model prepared for Quantization-Aware Training")

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    delta = torch.tensor(0.5, requires_grad=True, device=device)
    delta_opt = optim.Adam([delta], lr=1e-3)

    for epoch in tqdm(range(1, num_epochs + 1)):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds, feats = model(x)
            loss = compute_total_loss(preds, feats, y, delta=delta)
            loss.backward()
            optimizer.step()
            delta_opt.step()
            total_loss += loss.item()
        scheduler.step()

        val_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                preds, _ = model(x)
                pred = preds[-1].argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        acc = correct / total
        training_metrics[method_name] = {"loss": [], "acc": []}
        training_metrics[method_name]["loss"].append(val_loss)
        training_metrics[method_name]["acc"].append(acc)
        print(
            f"[MTSD{'-QAT' if qat else ''}] Epoch {epoch}: Loss = {val_loss:.4f}, Acc = {acc:.4f}"
        )

    # ---- Convert to quantized model after QAT ----
    if qat:
        model.cpu()
        torch.quantization.convert(model, inplace=True)
        print("[QAT] Model converted to quantized version")
    torch.save(model.state_dict(), f"models/mtsd_{method_name}.pth")
    return model


# -------------------------------
# Losses
# -------------------------------
class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, targets):
        hard_loss = self.ce(student_logits, targets)
        T = self.temperature
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction="batchmean",
        ) * (T * T)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss


# -------------------------------
# Training helpers
# -------------------------------


def get_resnet_teacher(num_classes=5, pretrained=True):
    model = models.resnet18(pretrained=pretrained)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model


def train_teacher(model, train_loader, val_loader, device, epochs=10):
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    ce = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs)):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = ce(model(x), y)
            loss.backward()
            opt.step()
    torch.save(model.state_dict(), "models/teacher.pth")
    return model


def evaluate(model, loader, device):
    """
    Evaluate a MultiBranchNet or similar model.
    
    Args:
        model: PyTorch model
        loader: DataLoader for evaluation
        device: "cuda" or "cpu"

    Returns:
        Accuracy (float)
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)          # unpack logits and features
            preds = logits[-1].argmax(1)  # use last branch for final prediction
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total



# -------------------------------
# KD Methods
# -------------------------------
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

def normal_kd(student, teacher, train_loader, val_loader, device, epochs=10, training_metrics=None):
    """
    Train a student model using Knowledge Distillation from a teacher.

    Args:
        student: student model (MultiBranchNet)
        teacher: teacher model
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        device: "cuda" or "cpu"
        epochs: number of epochs
        training_metrics: dict to log 'loss' and 'acc'

    Returns:
        Trained student model
    """
    if training_metrics is None:
        training_metrics = {"Normal KD": {"loss": [], "acc": []}}

    student.to(device)
    teacher.to(device).eval()
    kd_loss = DistillationLoss()
    optimizer = optim.Adam(student.parameters(), lr=1e-3)

    for epoch in tqdm(range(1, epochs + 1)):
        student.train()
        total_loss = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # Teacher logits
            with torch.no_grad():
                tlogits = teacher(x)

            # Student logits (last branch only)
            slogits, _ = student(x)
            loss = kd_loss(slogits[-1], tlogits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)

        # Validation
        student.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = student(x)
                preds = logits[-1].argmax(1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        acc = correct / total

        # Log metrics
        training_metrics["Normal KD"]["loss"].append(avg_loss)
        training_metrics["Normal KD"]["acc"].append(acc)

        print(f"[Normal KD] Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.4f}")

    # Save model
    torch.save(student.state_dict(), "models/student.pth")
    return student


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Standard and augmented loaders
train_loader, val_loader, test_loader = prepare_dataloaders(
    "data4/rice_dataset", batch_size=32, augment=False, num_workers=4
)
train_loader_aug, val_loader, test_loader = prepare_dataloaders(
    "data4/rice_dataset", batch_size=32, augment=True, num_workers=4
)

In [5]:
# --- Teachers ---
results = []
teacher = train_teacher(
    get_resnet_teacher(num_classes=4), train_loader, val_loader, device, epochs=20
)

100%|██████████| 20/20 [00:59<00:00,  2.99s/it]


In [8]:
# ----------------------------
# Method 1: Normal KD (unchanged)
# ----------------------------
s1 = normal_kd(
    MultiBranchNet(num_classes=4),
    teacher,
    train_loader,
    val_loader,
    device,
    epochs=20,
)

  5%|▌         | 1/20 [00:16<05:04, 16.02s/it]

[Normal KD] Epoch 1: Loss = 3.8519, Acc = 0.6012


 10%|█         | 2/20 [00:31<04:46, 15.90s/it]

[Normal KD] Epoch 2: Loss = 2.1748, Acc = 0.6501


 15%|█▌        | 3/20 [00:47<04:29, 15.83s/it]

[Normal KD] Epoch 3: Loss = 1.7409, Acc = 0.6501


 20%|██        | 4/20 [01:03<04:13, 15.84s/it]

[Normal KD] Epoch 4: Loss = 1.4389, Acc = 0.6223


 25%|██▌       | 5/20 [01:19<03:57, 15.81s/it]

[Normal KD] Epoch 5: Loss = 1.2234, Acc = 0.7007


 30%|███       | 6/20 [01:34<03:41, 15.79s/it]

[Normal KD] Epoch 6: Loss = 1.0680, Acc = 0.7040


 35%|███▌      | 7/20 [01:50<03:24, 15.76s/it]

[Normal KD] Epoch 7: Loss = 0.9259, Acc = 0.7192


 40%|████      | 8/20 [02:06<03:09, 15.77s/it]

[Normal KD] Epoch 8: Loss = 0.9054, Acc = 0.7293


 45%|████▌     | 9/20 [02:22<02:53, 15.77s/it]

[Normal KD] Epoch 9: Loss = 0.8314, Acc = 0.7530


 50%|█████     | 10/20 [02:37<02:37, 15.77s/it]

[Normal KD] Epoch 10: Loss = 0.7193, Acc = 0.7715


 55%|█████▌    | 11/20 [02:53<02:21, 15.75s/it]

[Normal KD] Epoch 11: Loss = 0.6513, Acc = 0.7369


 60%|██████    | 12/20 [03:09<02:06, 15.75s/it]

[Normal KD] Epoch 12: Loss = 0.6372, Acc = 0.7378


 65%|██████▌   | 13/20 [03:25<01:50, 15.76s/it]

[Normal KD] Epoch 13: Loss = 0.6219, Acc = 0.7580


 70%|███████   | 14/20 [03:40<01:34, 15.76s/it]

[Normal KD] Epoch 14: Loss = 0.5328, Acc = 0.7799


 75%|███████▌  | 15/20 [03:56<01:18, 15.76s/it]

[Normal KD] Epoch 15: Loss = 0.5390, Acc = 0.7344


 80%|████████  | 16/20 [04:12<01:03, 15.77s/it]

[Normal KD] Epoch 16: Loss = 0.5188, Acc = 0.7218


 85%|████████▌ | 17/20 [04:28<00:47, 15.77s/it]

[Normal KD] Epoch 17: Loss = 0.4546, Acc = 0.7344


 90%|█████████ | 18/20 [04:44<00:31, 15.76s/it]

[Normal KD] Epoch 18: Loss = 0.4766, Acc = 0.7167


 95%|█████████▌| 19/20 [04:59<00:15, 15.77s/it]

[Normal KD] Epoch 19: Loss = 0.4184, Acc = 0.7664


100%|██████████| 20/20 [05:15<00:00, 15.78s/it]

[Normal KD] Epoch 20: Loss = 0.4646, Acc = 0.7791





AttributeError: 'tuple' object has no attribute 'argmax'

In [11]:
acc1 = evaluate(s1, val_loader, device)
size1 = sum(p.numel() for p in s1.parameters())
results.append(["Normal KD", acc1, size1])
print(acc1, size1)

0.7900505902192243 1556764


In [15]:
# ----------------------------
# Method 2: Multi-Teacher KD (MTSD)
# ----------------------------
s2 = train(
    MultiBranchNet(num_classes=4),
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD",
)
acc2 = evaluate(s2, val_loader, device)
size2 = sum(p.numel() for p in s2.parameters())
results.append(["Multi-Teacher SD", acc2, size2])
print(acc2, size2)

  5%|▌         | 1/20 [00:21<06:45, 21.32s/it]

[MTSD] Epoch 1: Loss = 4.3539, Acc = 0.4553


 10%|█         | 2/20 [00:42<06:20, 21.15s/it]

[MTSD] Epoch 2: Loss = 2.9054, Acc = 0.6037


 15%|█▌        | 3/20 [01:03<05:58, 21.12s/it]

[MTSD] Epoch 3: Loss = 2.3851, Acc = 0.6602


 20%|██        | 4/20 [01:24<05:37, 21.10s/it]

[MTSD] Epoch 4: Loss = 1.7426, Acc = 0.7032


 25%|██▌       | 5/20 [01:45<05:16, 21.08s/it]

[MTSD] Epoch 5: Loss = 1.5520, Acc = 0.6745


 30%|███       | 6/20 [02:06<04:54, 21.07s/it]

[MTSD] Epoch 6: Loss = 1.4819, Acc = 0.6518


 35%|███▌      | 7/20 [02:27<04:33, 21.06s/it]

[MTSD] Epoch 7: Loss = 1.1972, Acc = 0.6686


 40%|████      | 8/20 [02:48<04:12, 21.05s/it]

[MTSD] Epoch 8: Loss = 1.0907, Acc = 0.7150


 45%|████▌     | 9/20 [03:09<03:51, 21.07s/it]

[MTSD] Epoch 9: Loss = 0.9254, Acc = 0.7007


 50%|█████     | 10/20 [03:30<03:30, 21.06s/it]

[MTSD] Epoch 10: Loss = 1.0007, Acc = 0.7032


 55%|█████▌    | 11/20 [03:51<03:09, 21.06s/it]

[MTSD] Epoch 11: Loss = 0.7695, Acc = 0.7395


 60%|██████    | 12/20 [04:12<02:48, 21.06s/it]

[MTSD] Epoch 12: Loss = 0.7062, Acc = 0.7167


 65%|██████▌   | 13/20 [04:33<02:27, 21.06s/it]

[MTSD] Epoch 13: Loss = 0.6927, Acc = 0.7175


 70%|███████   | 14/20 [04:55<02:06, 21.10s/it]

[MTSD] Epoch 14: Loss = 0.6786, Acc = 0.7487


 75%|███████▌  | 15/20 [05:16<01:45, 21.08s/it]

[MTSD] Epoch 15: Loss = 0.6715, Acc = 0.7336


 80%|████████  | 16/20 [05:37<01:24, 21.07s/it]

[MTSD] Epoch 16: Loss = 0.6510, Acc = 0.7344


 85%|████████▌ | 17/20 [05:58<01:03, 21.08s/it]

[MTSD] Epoch 17: Loss = 0.6607, Acc = 0.7192


 90%|█████████ | 18/20 [06:19<00:42, 21.07s/it]

[MTSD] Epoch 18: Loss = 0.6680, Acc = 0.7327


 95%|█████████▌| 19/20 [06:40<00:21, 21.07s/it]

[MTSD] Epoch 19: Loss = 0.6338, Acc = 0.7293


100%|██████████| 20/20 [07:01<00:00, 21.08s/it]

[MTSD] Epoch 20: Loss = 0.6458, Acc = 0.7344





0.7251264755480608 1556764


In [16]:
# ----------------------------
# Method 3: MTSD + Augmentation
# ----------------------------
s3 = train(
    MultiBranchNet(num_classes=4),
    train_loader_aug,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD-Aug",
)
acc3 = evaluate(s3, val_loader, device)
size3 = sum(p.numel() for p in s3.parameters())
results.append(["MTSD + Aug", acc3, size3])


  5%|▌         | 1/20 [00:21<06:41, 21.15s/it]

[MTSD] Epoch 1: Loss = 5.2228, Acc = 0.6501


 10%|█         | 2/20 [00:42<06:19, 21.11s/it]

[MTSD] Epoch 2: Loss = 4.2412, Acc = 0.7690


 15%|█▌        | 3/20 [01:03<05:59, 21.12s/it]

[MTSD] Epoch 3: Loss = 3.6893, Acc = 0.8331


 20%|██        | 4/20 [01:24<05:37, 21.12s/it]

[MTSD] Epoch 4: Loss = 3.4524, Acc = 0.8111


 25%|██▌       | 5/20 [01:45<05:16, 21.12s/it]

[MTSD] Epoch 5: Loss = 3.2569, Acc = 0.8255


 30%|███       | 6/20 [02:06<04:55, 21.13s/it]

[MTSD] Epoch 6: Loss = 3.1550, Acc = 0.7901


 35%|███▌      | 7/20 [02:27<04:34, 21.12s/it]

[MTSD] Epoch 7: Loss = 3.0183, Acc = 0.8044


 40%|████      | 8/20 [02:48<04:13, 21.12s/it]

[MTSD] Epoch 8: Loss = 2.8455, Acc = 0.8676


 45%|████▌     | 9/20 [03:10<03:52, 21.13s/it]

[MTSD] Epoch 9: Loss = 2.7113, Acc = 0.9056


 50%|█████     | 10/20 [03:31<03:31, 21.13s/it]

[MTSD] Epoch 10: Loss = 2.6042, Acc = 0.8541


 55%|█████▌    | 11/20 [03:52<03:10, 21.12s/it]

[MTSD] Epoch 11: Loss = 2.3179, Acc = 0.9325


 60%|██████    | 12/20 [04:13<02:48, 21.12s/it]

[MTSD] Epoch 12: Loss = 2.2216, Acc = 0.9477


 65%|██████▌   | 13/20 [04:34<02:27, 21.13s/it]

[MTSD] Epoch 13: Loss = 2.1772, Acc = 0.9503


 70%|███████   | 14/20 [04:55<02:06, 21.13s/it]

[MTSD] Epoch 14: Loss = 2.1203, Acc = 0.9494


 75%|███████▌  | 15/20 [05:16<01:45, 21.14s/it]

[MTSD] Epoch 15: Loss = 2.1368, Acc = 0.9519


 80%|████████  | 16/20 [05:38<01:24, 21.13s/it]

[MTSD] Epoch 16: Loss = 2.0635, Acc = 0.9570


 85%|████████▌ | 17/20 [05:59<01:03, 21.13s/it]

[MTSD] Epoch 17: Loss = 2.1048, Acc = 0.9646


 90%|█████████ | 18/20 [06:20<00:42, 21.13s/it]

[MTSD] Epoch 18: Loss = 2.0652, Acc = 0.9604


 95%|█████████▌| 19/20 [06:41<00:21, 21.12s/it]

[MTSD] Epoch 19: Loss = 2.0190, Acc = 0.9696


100%|██████████| 20/20 [07:02<00:00, 21.13s/it]

[MTSD] Epoch 20: Loss = 2.0138, Acc = 0.9671





In [17]:
print(acc3, size3)

0.975548060708263 1556764


In [18]:
# ----------------------------
# Method 4: MTSD + Aug + QAT
# ----------------------------
s4 = train(
    MultiBranchNet(num_classes=4),
    train_loader_aug,
    val_loader,
    device,
    qat=True,
    num_epochs=20,
    method_name="MTSD-Aug-QAT",
)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.prepare_qat(model, inplace=True)


[QAT] Model prepared for Quantization-Aware Training


  5%|▌         | 1/20 [00:24<07:39, 24.17s/it]

[MTSD-QAT] Epoch 1: Loss = 5.3231, Acc = 0.7285


 10%|█         | 2/20 [00:48<07:12, 24.05s/it]

[MTSD-QAT] Epoch 2: Loss = 4.3093, Acc = 0.8069


 15%|█▌        | 3/20 [01:12<06:47, 24.00s/it]

[MTSD-QAT] Epoch 3: Loss = 3.7450, Acc = 0.8212


 20%|██        | 4/20 [01:36<06:23, 23.98s/it]

[MTSD-QAT] Epoch 4: Loss = 3.4882, Acc = 0.8398


 25%|██▌       | 5/20 [01:59<05:59, 23.97s/it]

[MTSD-QAT] Epoch 5: Loss = 3.3144, Acc = 0.8204


 30%|███       | 6/20 [02:23<05:35, 23.97s/it]

[MTSD-QAT] Epoch 6: Loss = 2.9218, Acc = 0.8406


 35%|███▌      | 7/20 [02:47<05:11, 23.97s/it]

[MTSD-QAT] Epoch 7: Loss = 2.9363, Acc = 0.8752


 40%|████      | 8/20 [03:11<04:47, 23.97s/it]

[MTSD-QAT] Epoch 8: Loss = 2.6886, Acc = 0.8449


 45%|████▌     | 9/20 [03:35<04:23, 23.98s/it]

[MTSD-QAT] Epoch 9: Loss = 2.4628, Acc = 0.8997


 50%|█████     | 10/20 [03:59<03:59, 23.97s/it]

[MTSD-QAT] Epoch 10: Loss = 2.3543, Acc = 0.9182


 55%|█████▌    | 11/20 [04:23<03:35, 23.98s/it]

[MTSD-QAT] Epoch 11: Loss = 2.0466, Acc = 0.9401


 60%|██████    | 12/20 [04:47<03:11, 23.98s/it]

[MTSD-QAT] Epoch 12: Loss = 1.9491, Acc = 0.9444


 65%|██████▌   | 13/20 [05:11<02:47, 23.98s/it]

[MTSD-QAT] Epoch 13: Loss = 1.9086, Acc = 0.9376


 70%|███████   | 14/20 [05:35<02:23, 23.98s/it]

[MTSD-QAT] Epoch 14: Loss = 1.9638, Acc = 0.9334


 75%|███████▌  | 15/20 [05:59<01:59, 23.97s/it]

[MTSD-QAT] Epoch 15: Loss = 1.8843, Acc = 0.9528


 80%|████████  | 16/20 [06:23<01:35, 23.97s/it]

[MTSD-QAT] Epoch 16: Loss = 1.8529, Acc = 0.9553


 85%|████████▌ | 17/20 [06:47<01:11, 23.97s/it]

[MTSD-QAT] Epoch 17: Loss = 1.8986, Acc = 0.9570


 90%|█████████ | 18/20 [07:11<00:47, 23.97s/it]

[MTSD-QAT] Epoch 18: Loss = 1.8401, Acc = 0.9519


 95%|█████████▌| 19/20 [07:35<00:23, 23.97s/it]

[MTSD-QAT] Epoch 19: Loss = 1.7566, Acc = 0.9570


100%|██████████| 20/20 [07:59<00:00, 23.98s/it]

[MTSD-QAT] Epoch 20: Loss = 1.7346, Acc = 0.9553
[QAT] Model converted to quantized version



For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.convert(model, inplace=True)


NotImplementedError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d.new' is only available for these backends: [Meta, QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMAIA, AutogradMeta, Tracer, AutocastCPU, AutocastMTIA, AutocastMAIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at /pytorch/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at /pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp:2047 [kernel]
QuantizedCUDA: registered at /pytorch/aten/src/ATen/native/quantized/cudnn/Conv.cpp:386 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:479 [backend fallback]
Functionalize: registered at /pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:375 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:104 [backend fallback]
AutogradOther: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:63 [backend fallback]
AutogradCPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:67 [backend fallback]
AutogradCUDA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:75 [backend fallback]
AutogradXLA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:87 [backend fallback]
AutogradMPS: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:95 [backend fallback]
AutogradXPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:71 [backend fallback]
AutogradHPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:108 [backend fallback]
AutogradLazy: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:91 [backend fallback]
AutogradMTIA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:79 [backend fallback]
AutogradMAIA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:83 [backend fallback]
AutogradMeta: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:99 [backend fallback]
Tracer: registered at /pytorch/torch/csrc/autograd/TraceTypeManual.cpp:294 [backend fallback]
AutocastCPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:322 [backend fallback]
AutocastMTIA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:466 [backend fallback]
AutocastMAIA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:504 [backend fallback]
AutocastXPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:542 [backend fallback]
AutocastMPS: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:210 [backend fallback]
PythonTLSSnapshot: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:475 [backend fallback]
PreDispatch: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:198 [backend fallback]


In [23]:
size4 = sum(p.numel() for p in s4.parameters())
print(size4)

1920


In [29]:
training_metrics

{'MTSD': {'loss': [0.6458064128262128],
  'acc': [0.7344013490725126],
  'size': 1556764},
 'MTSD-Aug': {'loss': [2.013754116472348],
  'acc': [0.9671163575042159],
  'size': 1556764},
 'MTSD-Aug-QAT': {'loss': [1.7346172794815182],
  'acc': [0.9553119730185498],
  'size': 1920},
 'Normal KD': {'acc': 0.7934232715008431, 'size': 1556764}}

In [30]:
# ----------------------------
# Method 5: Normal KD (Augmentation)
# ----------------------------
s5 = normal_kd(
    MultiBranchNet(num_classes=4),
    teacher,
    train_loader_aug,
    val_loader,
    device,
    epochs=20,
)

  5%|▌         | 1/20 [00:15<05:02, 15.94s/it]

[Normal KD] Epoch 1: Loss = 3.2703, Acc = 0.7083


 10%|█         | 2/20 [00:31<04:45, 15.88s/it]

[Normal KD] Epoch 2: Loss = 2.3953, Acc = 0.7673


 15%|█▌        | 3/20 [00:47<04:29, 15.86s/it]

[Normal KD] Epoch 3: Loss = 2.1139, Acc = 0.7538


 20%|██        | 4/20 [01:03<04:13, 15.86s/it]

[Normal KD] Epoch 4: Loss = 1.9400, Acc = 0.7782


 25%|██▌       | 5/20 [01:19<03:57, 15.85s/it]

[Normal KD] Epoch 5: Loss = 1.8532, Acc = 0.8322


 30%|███       | 6/20 [01:35<03:41, 15.85s/it]

[Normal KD] Epoch 6: Loss = 1.7726, Acc = 0.8449


 35%|███▌      | 7/20 [01:50<03:25, 15.84s/it]

[Normal KD] Epoch 7: Loss = 1.6343, Acc = 0.8128


 40%|████      | 8/20 [02:06<03:10, 15.86s/it]

[Normal KD] Epoch 8: Loss = 1.6119, Acc = 0.8465


 45%|████▌     | 9/20 [02:22<02:54, 15.90s/it]

[Normal KD] Epoch 9: Loss = 1.5622, Acc = 0.8659


 50%|█████     | 10/20 [02:38<02:38, 15.90s/it]

[Normal KD] Epoch 10: Loss = 1.4762, Acc = 0.8887


 55%|█████▌    | 11/20 [02:54<02:23, 15.89s/it]

[Normal KD] Epoch 11: Loss = 1.4372, Acc = 0.8870


 60%|██████    | 12/20 [03:10<02:07, 15.88s/it]

[Normal KD] Epoch 12: Loss = 1.3354, Acc = 0.8752


 65%|██████▌   | 13/20 [03:26<01:51, 15.89s/it]

[Normal KD] Epoch 13: Loss = 1.2921, Acc = 0.8853


 70%|███████   | 14/20 [03:42<01:35, 15.91s/it]

[Normal KD] Epoch 14: Loss = 1.2350, Acc = 0.9123


 75%|███████▌  | 15/20 [03:58<01:19, 15.91s/it]

[Normal KD] Epoch 15: Loss = 1.1925, Acc = 0.8980


 80%|████████  | 16/20 [04:14<01:03, 15.90s/it]

[Normal KD] Epoch 16: Loss = 1.1817, Acc = 0.8845


 85%|████████▌ | 17/20 [04:30<00:47, 15.90s/it]

[Normal KD] Epoch 17: Loss = 1.1407, Acc = 0.8997


 90%|█████████ | 18/20 [04:46<00:31, 15.94s/it]

[Normal KD] Epoch 18: Loss = 1.1409, Acc = 0.9317


 95%|█████████▌| 19/20 [05:01<00:15, 15.91s/it]

[Normal KD] Epoch 19: Loss = 1.1026, Acc = 0.9140


100%|██████████| 20/20 [05:17<00:00, 15.89s/it]

[Normal KD] Epoch 20: Loss = 1.0362, Acc = 0.8895





In [31]:
acc5 = evaluate(s5, val_loader, device)
size5 = sum(p.numel() for p in s5.parameters())
training_metrics["Normal KD with AUG"] = { "acc": acc5, "size":  size5}
print(acc5, size5)

0.893760539629005 1556764


In [32]:
training_metrics

{'MTSD': {'loss': [0.6458064128262128],
  'acc': [0.7344013490725126],
  'size': 1556764},
 'MTSD-Aug': {'loss': [2.013754116472348],
  'acc': [0.9671163575042159],
  'size': 1556764},
 'MTSD-Aug-QAT': {'loss': [1.7346172794815182],
  'acc': [0.9553119730185498],
  'size': 1920},
 'Normal KD': {'acc': 0.7934232715008431, 'size': 1556764},
 'Normal KD with AUG': {'acc': 0.893760539629005, 'size': 1556764}}