In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm


def prepare_dataloaders(
    path, batch_size=32, num_workers=2, val_split=0.2, test_split=0.1, augment=False
):
    """
    Prepares train, validation, and test dataloaders from an ImageFolder dataset.
    Supports optional data augmentation for the training set.

    Args:
        path (str): Root directory path to the dataset.
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of subprocesses for data loading.
        val_split (float): Fraction of data to use for validation.
        test_split (float): Fraction of data to use for testing.
        augment (bool): If True, apply augmentation to training dataset.

    Returns:
        tuple: (train_loader, val_loader, test_loader or None)
    """
    # Base transform (resize + normalize)
    base_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Augmented transform for training only
    aug_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load full dataset (with base transform first)
    full_dataset = datasets.ImageFolder(root=path, transform=base_transform)

    total_size = len(full_dataset)
    test_size = int(total_size * test_split)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size - test_size

    if test_split > 0:
        train_set, val_set, test_set = random_split(
            full_dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = DataLoader(
            test_set,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )
    else:
        train_set, val_set = random_split(
            full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = None

    # Apply augmentation only to training subset if requested
    if augment:
        train_set.dataset.transform = aug_transform

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    return train_loader, val_loader, test_loader


# -------------------------------
# Simple MTSD model
# -------------------------------


# ========== ECA and Backbone ==========
class ECALayer(nn.Module):
    def __init__(self, channels, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x).squeeze(-1).transpose(1, 2)
        y = self.conv(y).transpose(1, 2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Branch(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)


class MultiBranchNet(nn.Module):
    def __init__(self, input_channels=3, num_classes=5):
        super().__init__()
        self.blocks = nn.ModuleList([
            BottleneckBlock(input_channels, 64),
            BottleneckBlock(64, 128),
            BottleneckBlock(128, 256),
            BottleneckBlock(256, 512),
        ])
        self.att = nn.ModuleList([
            ECALayer(64),
            ECALayer(128),
            ECALayer(256),
            ECALayer(512),
        ])
        self.classifiers = nn.ModuleList([
            Branch(64, num_classes),
            Branch(128, num_classes),
            Branch(256, num_classes),
            Branch(512, num_classes),
        ])

    def forward(self, x):
        features, logits = [], []
        for i in range(4):
            x = self.blocks[i](x)
            x = self.att[i](x)
            features.append(x)
            logits.append(self.classifiers[i](x))
        return logits, features


def compute_asm(feature_map, target_size=None, max_hw=400):
    if target_size is not None:
        feature_map = F.interpolate(
            feature_map, size=target_size, mode="bilinear", align_corners=False
        )
    B, C, H, W = feature_map.size()
    if max_hw < H * W:
        feature_map = F.adaptive_avg_pool2d(
            feature_map, (int(max_hw**0.5), int(max_hw**0.5))
        )
    F_T = feature_map.view(B, C, -1)
    sim = torch.bmm(F_T.transpose(1, 2), F_T)
    norm = torch.norm(sim, dim=(1, 2), keepdim=True) + 1e-6
    return sim / norm


def compute_total_loss(preds, feats, labels, alpha=3, beta=0.3, gamma=3000, delta=None):
    ce = nn.CrossEntropyLoss()
    l1 = ce(preds[3], labels)
    l2 = sum(ce(preds[i], labels) for i in range(3))
    kl_loss, asm_loss = 0, 0
    target_size = feats[3].shape[2:]
    for i in range(3):
        pi = F.log_softmax(preds[i], dim=1)
        for j in range(i + 1, 4):
            pj = F.softmax(preds[j].detach(), dim=1)
            kl = F.kl_div(pi, pj, reduction="batchmean")
            asm_s = compute_asm(feats[i], target_size)
            asm_t = compute_asm(feats[j].detach(), target_size)
            if asm_s.shape != asm_t.shape:
                min_shape = list(map(min, asm_s.shape, asm_t.shape))
                asm_s = asm_s[:, : min_shape[1], : min_shape[2]]
                asm_t = asm_t[:, : min_shape[1], : min_shape[2]]
            asm = F.mse_loss(asm_s, asm_t)
            w = 1.0
            kl_loss += w * kl
            asm_loss += w * asm
    return alpha * l1 + (1 - beta) * l2 + beta * kl_loss + gamma * asm_loss


def train(
    model,
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=50,
    method_name="MTSD",
):
    model.to(device)

    # ---- QAT setup ----
    if qat:
        # Define QAT configuration
        model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
        # Prepare model for QAT
        torch.quantization.prepare_qat(model, inplace=True)
        print("[QAT] Model prepared for Quantization-Aware Training")

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    delta = torch.tensor(0.5, requires_grad=True, device=device)
    delta_opt = optim.Adam([delta], lr=1e-3)

    for epoch in tqdm(range(1, num_epochs + 1)):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds, feats = model(x)
            loss = compute_total_loss(preds, feats, y, delta=delta)
            loss.backward()
            optimizer.step()
            delta_opt.step()
            total_loss += loss.item()
        scheduler.step()

        val_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                preds, _ = model(x)
                pred = preds[-1].argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        acc = correct / total
        print(
            f"[MTSD{'-QAT' if qat else ''}] Epoch {epoch}: Loss = {val_loss:.4f}, Acc = {acc:.4f}"
        )

    # ---- Convert to quantized model after QAT ----
    if qat:
        model.cpu()
        torch.quantization.convert(model, inplace=True)
        print("[QAT] Model converted to quantized version")
    torch.save(model.state_dict(), f"models/mtsd_{method_name}.pth")
    return model


# -------------------------------
# Losses
# -------------------------------
class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, targets):
        hard_loss = self.ce(student_logits, targets)
        T = self.temperature
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction="batchmean",
        ) * (T * T)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss


# -------------------------------
# Training helpers
# -------------------------------


def get_resnet_teacher(num_classes=5, pretrained=True):
    model = models.resnet18(pretrained=pretrained)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model


def train_teacher(model, train_loader, val_loader, device, epochs=10):
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    ce = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs)):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = ce(model(x), y)
            loss.backward()
            opt.step()
    torch.save(model.state_dict(), "models/teacher.pth")
    return model


def evaluate(model, loader, device):
    """
    Evaluate a MultiBranchNet or similar model.
    
    Args:
        model: PyTorch model
        loader: DataLoader for evaluation
        device: "cuda" or "cpu"

    Returns:
        Accuracy (float)
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)          # unpack logits and features
            preds = logits[-1].argmax(1)  # use last branch for final prediction
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total



# -------------------------------
# KD Methods
# -------------------------------
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

def normal_kd(student, teacher, train_loader, val_loader, device, epochs=10):
    """
    Train a student model using Knowledge Distillation from a teacher.

    Args:
        student: student model (MultiBranchNet)
        teacher: teacher model
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        device: "cuda" or "cpu"
        epochs: number of epochs

    Returns:
        Trained student model
    """

    student.to(device)
    teacher.to(device).eval()
    kd_loss = DistillationLoss()
    optimizer = optim.Adam(student.parameters(), lr=1e-3)

    for epoch in tqdm(range(1, epochs + 1)):
        student.train()
        total_loss = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # Teacher logits
            with torch.no_grad():
                tlogits = teacher(x)

            # Student logits (last branch only)
            slogits, _ = student(x)
            loss = kd_loss(slogits[-1], tlogits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)

        # Validation
        student.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = student(x)
                preds = logits[-1].argmax(1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        acc = correct / total
        print(f"[Normal KD] Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.4f}")

    # Save model
    torch.save(student.state_dict(), "models/student.pth")
    return student


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

results = {}

# Standard and augmented loaders
# train_loader1, val_loader1, test_loader1 = prepare_dataloaders(
#     "data", batch_size=32, augment=False, num_workers=4
# )
# train_loader_aug1, val_loader1, test_loader1 = prepare_dataloaders(
#     "data", batch_size=32, augment=True, num_workers=4
# )

# NUM_CLASSES_1 = len(train_loader1.dataset.dataset.classes)

train_loader2, val_loader2, test_loader2 = prepare_dataloaders(
    "data2/Dataset", batch_size=32, augment=False, num_workers=4
)
train_loader_aug2, val_loader2, test_loader2 = prepare_dataloaders(
    "data2/Dataset", batch_size=32, augment=True, num_workers=4
)

NUM_CLASSES_2 = len(train_loader2.dataset.dataset.classes)

In [3]:
train_loader, val_loader, test_loader = train_loader2, val_loader2, test_loader2
train_loader_aug, val_loader, test_loader = train_loader_aug2, val_loader2, test_loader2
NUM_CLASSES = NUM_CLASSES_2

In [4]:
import os
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

teacher = train_teacher(
    get_resnet_teacher(num_classes=NUM_CLASSES),
    train_loader,
    val_loader,
    device,
    epochs=20
)

100%|██████████| 20/20 [00:09<00:00,  2.04it/s]


In [5]:
# ----------------------------
# Method 1: Normal KD (unchanged)
# ----------------------------
s1 = normal_kd(
    MultiBranchNet(num_classes=NUM_CLASSES),
    teacher,
    train_loader,
    val_loader,
    device,
    epochs=20,
)

acc1 = evaluate(s1, val_loader, device)
size1 = sum(p.numel() for p in s1.parameters())
results["Normal KD"] = { "acc": acc1, "size":  size1}
print(acc1, size1)

  5%|▌         | 1/20 [00:03<00:59,  3.15s/it]

[Normal KD] Epoch 1: Loss = 6.1379, Acc = 0.4091


 10%|█         | 2/20 [00:04<00:41,  2.30s/it]

[Normal KD] Epoch 2: Loss = 5.0429, Acc = 0.5000


 15%|█▌        | 3/20 [00:06<00:34,  2.05s/it]

[Normal KD] Epoch 3: Loss = 4.4378, Acc = 0.5455


 20%|██        | 4/20 [00:08<00:30,  1.92s/it]

[Normal KD] Epoch 4: Loss = 3.9135, Acc = 0.4773


 25%|██▌       | 5/20 [00:10<00:28,  1.87s/it]

[Normal KD] Epoch 5: Loss = 3.1474, Acc = 0.5909


 30%|███       | 6/20 [00:11<00:25,  1.82s/it]

[Normal KD] Epoch 6: Loss = 2.7623, Acc = 0.5682


 35%|███▌      | 7/20 [00:13<00:23,  1.78s/it]

[Normal KD] Epoch 7: Loss = 2.1422, Acc = 0.6818


 40%|████      | 8/20 [00:15<00:21,  1.76s/it]

[Normal KD] Epoch 8: Loss = 1.8753, Acc = 0.5227


 45%|████▌     | 9/20 [00:16<00:19,  1.76s/it]

[Normal KD] Epoch 9: Loss = 1.7106, Acc = 0.7273


 50%|█████     | 10/20 [00:18<00:17,  1.75s/it]

[Normal KD] Epoch 10: Loss = 1.4717, Acc = 0.7045


 55%|█████▌    | 11/20 [00:20<00:15,  1.75s/it]

[Normal KD] Epoch 11: Loss = 0.9106, Acc = 0.7727


 60%|██████    | 12/20 [00:22<00:14,  1.76s/it]

[Normal KD] Epoch 12: Loss = 0.8441, Acc = 0.7273


 65%|██████▌   | 13/20 [00:23<00:12,  1.75s/it]

[Normal KD] Epoch 13: Loss = 0.8146, Acc = 0.5909


 70%|███████   | 14/20 [00:25<00:10,  1.74s/it]

[Normal KD] Epoch 14: Loss = 0.6426, Acc = 0.6364


 75%|███████▌  | 15/20 [00:27<00:08,  1.74s/it]

[Normal KD] Epoch 15: Loss = 0.5845, Acc = 0.5227


 80%|████████  | 16/20 [00:29<00:06,  1.74s/it]

[Normal KD] Epoch 16: Loss = 0.4890, Acc = 0.7273


 85%|████████▌ | 17/20 [00:30<00:05,  1.74s/it]

[Normal KD] Epoch 17: Loss = 0.4752, Acc = 0.5455


 90%|█████████ | 18/20 [00:32<00:03,  1.73s/it]

[Normal KD] Epoch 18: Loss = 0.5519, Acc = 0.8636


 95%|█████████▌| 19/20 [00:34<00:01,  1.75s/it]

[Normal KD] Epoch 19: Loss = 0.3053, Acc = 0.7500


100%|██████████| 20/20 [00:36<00:00,  1.81s/it]

[Normal KD] Epoch 20: Loss = 0.4364, Acc = 0.6818





0.6818181818181818 1555800


In [6]:
# ----------------------------
# Method 2: Normal KD (Augmentation)
# ----------------------------
s2 = normal_kd(
    MultiBranchNet(num_classes=NUM_CLASSES),
    teacher,
    train_loader_aug,
    val_loader,
    device,
    epochs=20,
)
acc2 = evaluate(s2, val_loader, device)
size2 = sum(p.numel() for p in s2.parameters())
results["Normal KD with AUG"] = { "acc": acc2, "size":  size2}
print(acc2, size2)

  5%|▌         | 1/20 [00:01<00:34,  1.80s/it]

[Normal KD] Epoch 1: Loss = 4.2358, Acc = 0.3864


 10%|█         | 2/20 [00:03<00:32,  1.81s/it]

[Normal KD] Epoch 2: Loss = 3.5375, Acc = 0.3864


 15%|█▌        | 3/20 [00:05<00:30,  1.81s/it]

[Normal KD] Epoch 3: Loss = 2.8886, Acc = 0.5000


 20%|██        | 4/20 [00:07<00:28,  1.80s/it]

[Normal KD] Epoch 4: Loss = 2.5597, Acc = 0.6364


 25%|██▌       | 5/20 [00:09<00:27,  1.80s/it]

[Normal KD] Epoch 5: Loss = 2.6640, Acc = 0.5909


 30%|███       | 6/20 [00:10<00:25,  1.81s/it]

[Normal KD] Epoch 6: Loss = 2.5547, Acc = 0.5455


 35%|███▌      | 7/20 [00:12<00:23,  1.81s/it]

[Normal KD] Epoch 7: Loss = 2.4740, Acc = 0.6591


 40%|████      | 8/20 [00:14<00:21,  1.82s/it]

[Normal KD] Epoch 8: Loss = 2.2439, Acc = 0.7500


 45%|████▌     | 9/20 [00:16<00:19,  1.82s/it]

[Normal KD] Epoch 9: Loss = 2.0396, Acc = 0.7955


 50%|█████     | 10/20 [00:18<00:18,  1.82s/it]

[Normal KD] Epoch 10: Loss = 1.9670, Acc = 0.8636


 55%|█████▌    | 11/20 [00:19<00:16,  1.81s/it]

[Normal KD] Epoch 11: Loss = 1.9725, Acc = 0.6364


 60%|██████    | 12/20 [00:21<00:14,  1.81s/it]

[Normal KD] Epoch 12: Loss = 1.9282, Acc = 0.7955


 65%|██████▌   | 13/20 [00:23<00:12,  1.82s/it]

[Normal KD] Epoch 13: Loss = 2.2216, Acc = 0.5909


 70%|███████   | 14/20 [00:25<00:10,  1.81s/it]

[Normal KD] Epoch 14: Loss = 1.9884, Acc = 0.7273


 75%|███████▌  | 15/20 [00:27<00:09,  1.81s/it]

[Normal KD] Epoch 15: Loss = 2.0131, Acc = 0.5455


 80%|████████  | 16/20 [00:28<00:07,  1.80s/it]

[Normal KD] Epoch 16: Loss = 1.9884, Acc = 0.6364


 85%|████████▌ | 17/20 [00:30<00:05,  1.81s/it]

[Normal KD] Epoch 17: Loss = 2.0359, Acc = 0.6136


 90%|█████████ | 18/20 [00:32<00:03,  1.80s/it]

[Normal KD] Epoch 18: Loss = 1.7617, Acc = 0.7955


 95%|█████████▌| 19/20 [00:34<00:01,  1.81s/it]

[Normal KD] Epoch 19: Loss = 1.6978, Acc = 0.9091


100%|██████████| 20/20 [00:36<00:00,  1.81s/it]

[Normal KD] Epoch 20: Loss = 1.6983, Acc = 0.7045





0.6590909090909091 1555800


In [7]:
# ----------------------------
# Method 3: Multi-Teacher KD (MTSD)
# ----------------------------
s3 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD",
)


acc3 = evaluate(s3, val_loader, device)
size3 = sum(p.numel() for p in s3.parameters())
results["Multi-Teacher SD"] = { "acc": acc3, "size":  size3}
print(acc3, size3)

  5%|▌         | 1/20 [00:02<00:40,  2.12s/it]

[MTSD] Epoch 1: Loss = 5.5708, Acc = 0.2727


 10%|█         | 2/20 [00:04<00:36,  2.05s/it]

[MTSD] Epoch 2: Loss = 5.1060, Acc = 0.3636


 15%|█▌        | 3/20 [00:06<00:34,  2.04s/it]

[MTSD] Epoch 3: Loss = 4.7970, Acc = 0.3864


 20%|██        | 4/20 [00:08<00:32,  2.04s/it]

[MTSD] Epoch 4: Loss = 4.2679, Acc = 0.3864


 25%|██▌       | 5/20 [00:10<00:30,  2.03s/it]

[MTSD] Epoch 5: Loss = 4.0170, Acc = 0.4545


 30%|███       | 6/20 [00:12<00:28,  2.01s/it]

[MTSD] Epoch 6: Loss = 3.4600, Acc = 0.4545


 35%|███▌      | 7/20 [00:14<00:26,  2.00s/it]

[MTSD] Epoch 7: Loss = 3.2893, Acc = 0.5000


 40%|████      | 8/20 [00:16<00:24,  2.01s/it]

[MTSD] Epoch 8: Loss = 3.1835, Acc = 0.4545


 45%|████▌     | 9/20 [00:18<00:22,  2.01s/it]

[MTSD] Epoch 9: Loss = 2.9799, Acc = 0.5227


 50%|█████     | 10/20 [00:20<00:20,  2.00s/it]

[MTSD] Epoch 10: Loss = 3.1432, Acc = 0.6818


 55%|█████▌    | 11/20 [00:22<00:18,  2.01s/it]

[MTSD] Epoch 11: Loss = 2.7178, Acc = 0.5227


 60%|██████    | 12/20 [00:24<00:16,  2.03s/it]

[MTSD] Epoch 12: Loss = 2.7850, Acc = 0.7500


 65%|██████▌   | 13/20 [00:26<00:14,  2.02s/it]

[MTSD] Epoch 13: Loss = 2.6327, Acc = 0.5227


 70%|███████   | 14/20 [00:28<00:12,  2.03s/it]

[MTSD] Epoch 14: Loss = 2.5822, Acc = 0.6136


 75%|███████▌  | 15/20 [00:30<00:10,  2.04s/it]

[MTSD] Epoch 15: Loss = 2.6560, Acc = 0.7500


 80%|████████  | 16/20 [00:32<00:08,  2.05s/it]

[MTSD] Epoch 16: Loss = 2.5161, Acc = 0.6818


 85%|████████▌ | 17/20 [00:34<00:06,  2.05s/it]

[MTSD] Epoch 17: Loss = 2.5350, Acc = 0.6364


 90%|█████████ | 18/20 [00:36<00:04,  2.02s/it]

[MTSD] Epoch 18: Loss = 2.5225, Acc = 0.5909


 95%|█████████▌| 19/20 [00:38<00:02,  2.01s/it]

[MTSD] Epoch 19: Loss = 2.4700, Acc = 0.8182


100%|██████████| 20/20 [00:40<00:00,  2.02s/it]

[MTSD] Epoch 20: Loss = 2.4530, Acc = 0.7500





0.5909090909090909 1555800


In [8]:
# ----------------------------
# Method 4: MTSD + Augmentation
# ----------------------------
s4 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader_aug,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD-Aug",
)
acc4 = evaluate(s4, val_loader, device)
size4 = sum(p.numel() for p in s4.parameters())
results["MTSD + Aug"] =  { "acc": acc4, "size":  size4}
print(acc4, size4)

  5%|▌         | 1/20 [00:02<00:39,  2.08s/it]

[MTSD] Epoch 1: Loss = 5.5652, Acc = 0.3864


 10%|█         | 2/20 [00:04<00:37,  2.06s/it]

[MTSD] Epoch 2: Loss = 5.1146, Acc = 0.5227


 15%|█▌        | 3/20 [00:06<00:35,  2.06s/it]

[MTSD] Epoch 3: Loss = 5.0700, Acc = 0.4545


 20%|██        | 4/20 [00:08<00:33,  2.07s/it]

[MTSD] Epoch 4: Loss = 4.7051, Acc = 0.4318


 25%|██▌       | 5/20 [00:10<00:31,  2.07s/it]

[MTSD] Epoch 5: Loss = 4.6094, Acc = 0.6364


 30%|███       | 6/20 [00:12<00:29,  2.08s/it]

[MTSD] Epoch 6: Loss = 4.7180, Acc = 0.6818


 35%|███▌      | 7/20 [00:14<00:27,  2.09s/it]

[MTSD] Epoch 7: Loss = 4.0206, Acc = 0.6818


 40%|████      | 8/20 [00:16<00:25,  2.10s/it]

[MTSD] Epoch 8: Loss = 3.8613, Acc = 0.7273


 45%|████▌     | 9/20 [00:18<00:23,  2.10s/it]

[MTSD] Epoch 9: Loss = 3.8068, Acc = 0.4091


 50%|█████     | 10/20 [00:20<00:20,  2.09s/it]

[MTSD] Epoch 10: Loss = 3.8987, Acc = 0.7727


 55%|█████▌    | 11/20 [00:22<00:18,  2.10s/it]

[MTSD] Epoch 11: Loss = 3.7167, Acc = 0.7955


 60%|██████    | 12/20 [00:25<00:16,  2.10s/it]

[MTSD] Epoch 12: Loss = 3.4387, Acc = 0.8182


 65%|██████▌   | 13/20 [00:27<00:14,  2.10s/it]

[MTSD] Epoch 13: Loss = 3.3204, Acc = 0.7045


 70%|███████   | 14/20 [00:29<00:12,  2.08s/it]

[MTSD] Epoch 14: Loss = 3.4684, Acc = 0.8864


 75%|███████▌  | 15/20 [00:31<00:10,  2.08s/it]

[MTSD] Epoch 15: Loss = 3.1836, Acc = 0.7955


 80%|████████  | 16/20 [00:33<00:08,  2.07s/it]

[MTSD] Epoch 16: Loss = 3.3112, Acc = 0.8409


 85%|████████▌ | 17/20 [00:35<00:06,  2.08s/it]

[MTSD] Epoch 17: Loss = 3.5704, Acc = 0.7955


 90%|█████████ | 18/20 [00:37<00:04,  2.08s/it]

[MTSD] Epoch 18: Loss = 3.5036, Acc = 0.8182


 95%|█████████▌| 19/20 [00:39<00:02,  2.09s/it]

[MTSD] Epoch 19: Loss = 3.0997, Acc = 0.8182


100%|██████████| 20/20 [00:41<00:00,  2.08s/it]

[MTSD] Epoch 20: Loss = 3.0986, Acc = 0.8409





0.8409090909090909 1555800


In [9]:
# ----------------------------
# Method 5: MTSD + Aug + QAT
# ----------------------------
s5 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader_aug,
    val_loader,
    device,
    qat=True,
    num_epochs=20,
    method_name="MTSD-Aug-QAT",
)

# acc5 = evaluate(s5, val_loader, device)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.prepare_qat(model, inplace=True)


[QAT] Model prepared for Quantization-Aware Training


  5%|▌         | 1/20 [00:02<00:44,  2.32s/it]

[MTSD-QAT] Epoch 1: Loss = 5.5363, Acc = 0.2500


 10%|█         | 2/20 [00:04<00:41,  2.28s/it]

[MTSD-QAT] Epoch 2: Loss = 5.2271, Acc = 0.3636


 15%|█▌        | 3/20 [00:06<00:38,  2.27s/it]

[MTSD-QAT] Epoch 3: Loss = 5.1329, Acc = 0.3864


 20%|██        | 4/20 [00:09<00:36,  2.27s/it]

[MTSD-QAT] Epoch 4: Loss = 4.9453, Acc = 0.4318


 25%|██▌       | 5/20 [00:11<00:34,  2.27s/it]

[MTSD-QAT] Epoch 5: Loss = 4.9779, Acc = 0.3864


 30%|███       | 6/20 [00:13<00:31,  2.27s/it]

[MTSD-QAT] Epoch 6: Loss = 4.6529, Acc = 0.3864


 35%|███▌      | 7/20 [00:15<00:29,  2.27s/it]

[MTSD-QAT] Epoch 7: Loss = 4.7352, Acc = 0.3864


 40%|████      | 8/20 [00:18<00:27,  2.26s/it]

[MTSD-QAT] Epoch 8: Loss = 4.4899, Acc = 0.5000


 45%|████▌     | 9/20 [00:20<00:24,  2.26s/it]

[MTSD-QAT] Epoch 9: Loss = 4.2765, Acc = 0.7955


 50%|█████     | 10/20 [00:22<00:22,  2.26s/it]

[MTSD-QAT] Epoch 10: Loss = 4.1706, Acc = 0.7273


 55%|█████▌    | 11/20 [00:24<00:20,  2.25s/it]

[MTSD-QAT] Epoch 11: Loss = 4.0083, Acc = 0.7727


 60%|██████    | 12/20 [00:27<00:18,  2.25s/it]

[MTSD-QAT] Epoch 12: Loss = 3.9505, Acc = 0.7955


 65%|██████▌   | 13/20 [00:29<00:15,  2.26s/it]

[MTSD-QAT] Epoch 13: Loss = 3.9752, Acc = 0.8182


 70%|███████   | 14/20 [00:31<00:13,  2.26s/it]

[MTSD-QAT] Epoch 14: Loss = 3.5785, Acc = 0.7500


 75%|███████▌  | 15/20 [00:33<00:11,  2.25s/it]

[MTSD-QAT] Epoch 15: Loss = 3.7990, Acc = 0.7273


 80%|████████  | 16/20 [00:36<00:09,  2.26s/it]

[MTSD-QAT] Epoch 16: Loss = 3.7063, Acc = 0.7955


 85%|████████▌ | 17/20 [00:38<00:06,  2.26s/it]

[MTSD-QAT] Epoch 17: Loss = 3.7520, Acc = 0.7955


 90%|█████████ | 18/20 [00:40<00:04,  2.25s/it]

[MTSD-QAT] Epoch 18: Loss = 3.5344, Acc = 0.7045


 95%|█████████▌| 19/20 [00:42<00:02,  2.25s/it]

[MTSD-QAT] Epoch 19: Loss = 3.7096, Acc = 0.7727


100%|██████████| 20/20 [00:45<00:00,  2.26s/it]

[MTSD-QAT] Epoch 20: Loss = 3.3943, Acc = 0.7727
[QAT] Model converted to quantized version



For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.convert(model, inplace=True)


In [12]:
size5 = sum(p.numel() for p in s5.parameters())
results["MTSD-Aug-QAT"] = { "acc": 0.7727, "size":  size5}
print(size5)

1920


In [13]:
results

{'Normal KD': {'acc': 0.6818181818181818, 'size': 1555800},
 'Normal KD with AUG': {'acc': 0.6590909090909091, 'size': 1555800},
 'Multi-Teacher SD': {'acc': 0.5909090909090909, 'size': 1555800},
 'MTSD + Aug': {'acc': 0.8409090909090909, 'size': 1555800},
 'MTSD-Aug-QAT': {'acc': 0.7727, 'size': 1920}}