In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm


def prepare_dataloaders(
    path, batch_size=32, num_workers=2, val_split=0.2, test_split=0.1, augment=False
):
    """
    Prepares train, validation, and test dataloaders from an ImageFolder dataset.
    Supports optional data augmentation for the training set.

    Args:
        path (str): Root directory path to the dataset.
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of subprocesses for data loading.
        val_split (float): Fraction of data to use for validation.
        test_split (float): Fraction of data to use for testing.
        augment (bool): If True, apply augmentation to training dataset.

    Returns:
        tuple: (train_loader, val_loader, test_loader or None)
    """
    # Base transform (resize + normalize)
    base_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Augmented transform for training only
    aug_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load full dataset (with base transform first)
    full_dataset = datasets.ImageFolder(root=path, transform=base_transform)

    total_size = len(full_dataset)
    test_size = int(total_size * test_split)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size - test_size

    if test_split > 0:
        train_set, val_set, test_set = random_split(
            full_dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = DataLoader(
            test_set,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )
    else:
        train_set, val_set = random_split(
            full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = None

    # Apply augmentation only to training subset if requested
    if augment:
        train_set.dataset.transform = aug_transform

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    return train_loader, val_loader, test_loader


# -------------------------------
# Simple MTSD model
# -------------------------------


# ========== ECA and Backbone ==========
class ECALayer(nn.Module):
    def __init__(self, channels, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x).squeeze(-1).transpose(1, 2)
        y = self.conv(y).transpose(1, 2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Branch(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)


class MultiBranchNet(nn.Module):
    def __init__(self, input_channels=3, num_classes=5):
        super().__init__()
        self.blocks = nn.ModuleList([
            BottleneckBlock(input_channels, 64),
            BottleneckBlock(64, 128),
            BottleneckBlock(128, 256),
            BottleneckBlock(256, 512),
        ])
        self.att = nn.ModuleList([
            ECALayer(64),
            ECALayer(128),
            ECALayer(256),
            ECALayer(512),
        ])
        self.classifiers = nn.ModuleList([
            Branch(64, num_classes),
            Branch(128, num_classes),
            Branch(256, num_classes),
            Branch(512, num_classes),
        ])

    def forward(self, x):
        features, logits = [], []
        for i in range(4):
            x = self.blocks[i](x)
            x = self.att[i](x)
            features.append(x)
            logits.append(self.classifiers[i](x))
        return logits, features


def compute_asm(feature_map, target_size=None, max_hw=400):
    if target_size is not None:
        feature_map = F.interpolate(
            feature_map, size=target_size, mode="bilinear", align_corners=False
        )
    B, C, H, W = feature_map.size()
    if max_hw < H * W:
        feature_map = F.adaptive_avg_pool2d(
            feature_map, (int(max_hw**0.5), int(max_hw**0.5))
        )
    F_T = feature_map.view(B, C, -1)
    sim = torch.bmm(F_T.transpose(1, 2), F_T)
    norm = torch.norm(sim, dim=(1, 2), keepdim=True) + 1e-6
    return sim / norm


def compute_total_loss(preds, feats, labels, alpha=3, beta=0.3, gamma=3000, delta=None):
    ce = nn.CrossEntropyLoss()
    l1 = ce(preds[3], labels)
    l2 = sum(ce(preds[i], labels) for i in range(3))
    kl_loss, asm_loss = 0, 0
    target_size = feats[3].shape[2:]
    for i in range(3):
        pi = F.log_softmax(preds[i], dim=1)
        for j in range(i + 1, 4):
            pj = F.softmax(preds[j].detach(), dim=1)
            kl = F.kl_div(pi, pj, reduction="batchmean")
            asm_s = compute_asm(feats[i], target_size)
            asm_t = compute_asm(feats[j].detach(), target_size)
            if asm_s.shape != asm_t.shape:
                min_shape = list(map(min, asm_s.shape, asm_t.shape))
                asm_s = asm_s[:, : min_shape[1], : min_shape[2]]
                asm_t = asm_t[:, : min_shape[1], : min_shape[2]]
            asm = F.mse_loss(asm_s, asm_t)
            w = 1.0
            kl_loss += w * kl
            asm_loss += w * asm
    return alpha * l1 + (1 - beta) * l2 + beta * kl_loss + gamma * asm_loss


def train(
    model,
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=50,
    method_name="MTSD",
):
    model.to(device)

    # ---- QAT setup ----
    if qat:
        # Define QAT configuration
        model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
        # Prepare model for QAT
        torch.quantization.prepare_qat(model, inplace=True)
        print("[QAT] Model prepared for Quantization-Aware Training")

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    delta = torch.tensor(0.5, requires_grad=True, device=device)
    delta_opt = optim.Adam([delta], lr=1e-3)

    for epoch in tqdm(range(1, num_epochs + 1)):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds, feats = model(x)
            loss = compute_total_loss(preds, feats, y, delta=delta)
            loss.backward()
            optimizer.step()
            delta_opt.step()
            total_loss += loss.item()
        scheduler.step()

        val_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                preds, _ = model(x)
                pred = preds[-1].argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        acc = correct / total
        print(
            f"[MTSD{'-QAT' if qat else ''}] Epoch {epoch}: Loss = {val_loss:.4f}, Acc = {acc:.4f}"
        )

    # ---- Convert to quantized model after QAT ----
    if qat:
        model.cpu()
        torch.quantization.convert(model, inplace=True)
        print("[QAT] Model converted to quantized version")
    torch.save(model.state_dict(), f"models/mtsd_{method_name}.pth")
    return model


# -------------------------------
# Losses
# -------------------------------
class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, targets):
        hard_loss = self.ce(student_logits, targets)
        T = self.temperature
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction="batchmean",
        ) * (T * T)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss


# -------------------------------
# Training helpers
# -------------------------------


def get_resnet_teacher(num_classes=5, pretrained=True):
    model = models.resnet18(pretrained=pretrained)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model


def train_teacher(model, train_loader, val_loader, device, epochs=10):
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    ce = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs)):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = ce(model(x), y)
            loss.backward()
            opt.step()
    torch.save(model.state_dict(), "models/teacher.pth")
    return model


def evaluate(model, loader, device):
    """
    Evaluate a MultiBranchNet or similar model.
    
    Args:
        model: PyTorch model
        loader: DataLoader for evaluation
        device: "cuda" or "cpu"

    Returns:
        Accuracy (float)
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)          # unpack logits and features
            preds = logits[-1].argmax(1)  # use last branch for final prediction
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total



# -------------------------------
# KD Methods
# -------------------------------
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

def normal_kd(student, teacher, train_loader, val_loader, device, epochs=10):
    """
    Train a student model using Knowledge Distillation from a teacher.

    Args:
        student: student model (MultiBranchNet)
        teacher: teacher model
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        device: "cuda" or "cpu"
        epochs: number of epochs

    Returns:
        Trained student model
    """

    student.to(device)
    teacher.to(device).eval()
    kd_loss = DistillationLoss()
    optimizer = optim.Adam(student.parameters(), lr=1e-3)

    for epoch in tqdm(range(1, epochs + 1)):
        student.train()
        total_loss = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # Teacher logits
            with torch.no_grad():
                tlogits = teacher(x)

            # Student logits (last branch only)
            slogits, _ = student(x)
            loss = kd_loss(slogits[-1], tlogits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)

        # Validation
        student.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = student(x)
                preds = logits[-1].argmax(1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        acc = correct / total
        print(f"[Normal KD] Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.4f}")

    # Save model
    torch.save(student.state_dict(), "models/student.pth")
    return student


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

results = {}

# Standard and augmented loaders
# train_loader1, val_loader1, test_loader1 = prepare_dataloaders(
#     "data", batch_size=32, augment=False, num_workers=4
# )
# train_loader_aug1, val_loader1, test_loader1 = prepare_dataloaders(
#     "data", batch_size=32, augment=True, num_workers=4
# )

# NUM_CLASSES_1 = len(train_loader1.dataset.dataset.classes)


path = 'data4/Peach Dataset'
train_loader4, _, _ = prepare_dataloaders(
    path, batch_size=32, augment=False, num_workers=4
)
train_loader_aug4, val_loader4, test_loader4 = prepare_dataloaders(
    path, batch_size=32, augment=True, num_workers=4
)

NUM_CLASSES_4 = len(train_loader4.dataset.dataset.classes)
NUM_CLASSES_4

6

In [3]:
train_loader = train_loader4
train_loader_aug, val_loader, test_loader = train_loader_aug4, val_loader4, test_loader4
NUM_CLASSES = NUM_CLASSES_4

In [4]:
import os
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

teacher = train_teacher(
    get_resnet_teacher(num_classes=NUM_CLASSES),
    train_loader,
    val_loader,
    device,
    epochs=20
)

100%|██████████| 20/20 [00:46<00:00,  2.35s/it]


In [5]:
# ----------------------------
# Method 1: Normal KD (unchanged)
# ----------------------------
s1 = normal_kd(
    MultiBranchNet(num_classes=NUM_CLASSES),
    teacher,
    train_loader,
    val_loader,
    device,
    epochs=20,
)

acc1 = evaluate(s1, val_loader, device)
size1 = sum(p.numel() for p in s1.parameters())
results["Normal KD"] = { "acc": acc1, "size":  size1}
print(acc1, size1)

  5%|▌         | 1/20 [00:15<04:56, 15.61s/it]

[Normal KD] Epoch 1: Loss = 4.5540, Acc = 0.2213


 10%|█         | 2/20 [00:29<04:26, 14.81s/it]

[Normal KD] Epoch 2: Loss = 4.0486, Acc = 0.2391


 15%|█▌        | 3/20 [00:44<04:07, 14.56s/it]

[Normal KD] Epoch 3: Loss = 3.8885, Acc = 0.2609


 20%|██        | 4/20 [00:58<03:50, 14.43s/it]

[Normal KD] Epoch 4: Loss = 3.7745, Acc = 0.2668


 25%|██▌       | 5/20 [01:12<03:35, 14.36s/it]

[Normal KD] Epoch 5: Loss = 3.7381, Acc = 0.2055


 30%|███       | 6/20 [01:26<03:20, 14.35s/it]

[Normal KD] Epoch 6: Loss = 3.6272, Acc = 0.2451


 35%|███▌      | 7/20 [01:41<03:06, 14.34s/it]

[Normal KD] Epoch 7: Loss = 3.5827, Acc = 0.2095


 40%|████      | 8/20 [01:55<02:52, 14.35s/it]

[Normal KD] Epoch 8: Loss = 3.5616, Acc = 0.2628


 45%|████▌     | 9/20 [02:09<02:37, 14.35s/it]

[Normal KD] Epoch 9: Loss = 3.5005, Acc = 0.2846


 50%|█████     | 10/20 [02:24<02:23, 14.33s/it]

[Normal KD] Epoch 10: Loss = 3.3613, Acc = 0.2747


 55%|█████▌    | 11/20 [02:38<02:08, 14.33s/it]

[Normal KD] Epoch 11: Loss = 3.3871, Acc = 0.1818


 60%|██████    | 12/20 [02:52<01:54, 14.32s/it]

[Normal KD] Epoch 12: Loss = 3.2686, Acc = 0.2549


 65%|██████▌   | 13/20 [03:07<01:40, 14.34s/it]

[Normal KD] Epoch 13: Loss = 3.1450, Acc = 0.1700


 70%|███████   | 14/20 [03:21<01:26, 14.34s/it]

[Normal KD] Epoch 14: Loss = 3.0980, Acc = 0.1996


 75%|███████▌  | 15/20 [03:35<01:11, 14.35s/it]

[Normal KD] Epoch 15: Loss = 2.9957, Acc = 0.2826


 80%|████████  | 16/20 [03:50<00:57, 14.34s/it]

[Normal KD] Epoch 16: Loss = 2.8489, Acc = 0.1996


 85%|████████▌ | 17/20 [04:04<00:43, 14.35s/it]

[Normal KD] Epoch 17: Loss = 2.7645, Acc = 0.2470


 90%|█████████ | 18/20 [04:18<00:28, 14.32s/it]

[Normal KD] Epoch 18: Loss = 2.8190, Acc = 0.2451


 95%|█████████▌| 19/20 [04:33<00:14, 14.34s/it]

[Normal KD] Epoch 19: Loss = 2.7040, Acc = 0.1858


100%|██████████| 20/20 [04:47<00:00, 14.38s/it]

[Normal KD] Epoch 20: Loss = 2.6210, Acc = 0.1581





0.15810276679841898 1558692


In [6]:
# ----------------------------
# Method 2: Normal KD (Augmentation)
# ----------------------------
s2 = normal_kd(
    MultiBranchNet(num_classes=NUM_CLASSES),
    teacher,
    train_loader_aug,
    val_loader,
    device,
    epochs=20,
)
acc2 = evaluate(s2, val_loader, device)
size2 = sum(p.numel() for p in s2.parameters())
results["Normal KD with AUG"] = { "acc": acc2, "size":  size2}
print(acc2, size2)

  5%|▌         | 1/20 [00:14<04:33, 14.41s/it]

[Normal KD] Epoch 1: Loss = 3.5617, Acc = 0.1897


 10%|█         | 2/20 [00:28<04:19, 14.41s/it]

[Normal KD] Epoch 2: Loss = 3.2281, Acc = 0.1779


 15%|█▌        | 3/20 [00:43<04:05, 14.42s/it]

[Normal KD] Epoch 3: Loss = 3.2125, Acc = 0.1818


 20%|██        | 4/20 [00:57<03:50, 14.41s/it]

[Normal KD] Epoch 4: Loss = 3.1500, Acc = 0.1640


 25%|██▌       | 5/20 [01:11<03:35, 14.39s/it]

[Normal KD] Epoch 5: Loss = 3.1543, Acc = 0.1976


 30%|███       | 6/20 [01:26<03:21, 14.41s/it]

[Normal KD] Epoch 6: Loss = 3.1496, Acc = 0.2194


 35%|███▌      | 7/20 [01:40<03:07, 14.41s/it]

[Normal KD] Epoch 7: Loss = 3.0936, Acc = 0.2016


 40%|████      | 8/20 [01:55<02:52, 14.41s/it]

[Normal KD] Epoch 8: Loss = 3.1473, Acc = 0.2253


 45%|████▌     | 9/20 [02:09<02:38, 14.41s/it]

[Normal KD] Epoch 9: Loss = 3.0300, Acc = 0.2115


 50%|█████     | 10/20 [02:24<02:24, 14.42s/it]

[Normal KD] Epoch 10: Loss = 3.0178, Acc = 0.2154


 55%|█████▌    | 11/20 [02:38<02:09, 14.41s/it]

[Normal KD] Epoch 11: Loss = 2.9971, Acc = 0.1680


 60%|██████    | 12/20 [02:52<01:55, 14.40s/it]

[Normal KD] Epoch 12: Loss = 3.0262, Acc = 0.2036


 65%|██████▌   | 13/20 [03:07<01:40, 14.41s/it]

[Normal KD] Epoch 13: Loss = 2.9861, Acc = 0.2292


 70%|███████   | 14/20 [03:21<01:26, 14.36s/it]

[Normal KD] Epoch 14: Loss = 2.9989, Acc = 0.2411


 75%|███████▌  | 15/20 [03:35<01:11, 14.37s/it]

[Normal KD] Epoch 15: Loss = 2.9782, Acc = 0.2470


 80%|████████  | 16/20 [03:50<00:57, 14.37s/it]

[Normal KD] Epoch 16: Loss = 3.0107, Acc = 0.2194


 85%|████████▌ | 17/20 [04:04<00:43, 14.36s/it]

[Normal KD] Epoch 17: Loss = 3.0246, Acc = 0.2451


 90%|█████████ | 18/20 [04:19<00:28, 14.38s/it]

[Normal KD] Epoch 18: Loss = 3.0291, Acc = 0.2411


 95%|█████████▌| 19/20 [04:33<00:14, 14.38s/it]

[Normal KD] Epoch 19: Loss = 2.9771, Acc = 0.2352


100%|██████████| 20/20 [04:47<00:00, 14.39s/it]

[Normal KD] Epoch 20: Loss = 2.9804, Acc = 0.2292





0.23517786561264822 1558692


In [7]:
# ----------------------------
# Method 3: Multi-Teacher KD (MTSD)
# ----------------------------
s3 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD",
)


acc3 = evaluate(s3, val_loader, device)
size3 = sum(p.numel() for p in s3.parameters())
results["Multi-Teacher SD"] = { "acc": acc3, "size":  size3}
print(acc3, size3)

  5%|▌         | 1/20 [00:17<05:32, 17.49s/it]

[MTSD] Epoch 1: Loss = 8.7475, Acc = 0.2609


 10%|█         | 2/20 [00:34<05:15, 17.50s/it]

[MTSD] Epoch 2: Loss = 8.4556, Acc = 0.1798


 15%|█▌        | 3/20 [00:52<04:56, 17.46s/it]

[MTSD] Epoch 3: Loss = 8.4080, Acc = 0.1976


 20%|██        | 4/20 [01:09<04:39, 17.45s/it]

[MTSD] Epoch 4: Loss = 8.2935, Acc = 0.2411


 25%|██▌       | 5/20 [01:27<04:20, 17.38s/it]

[MTSD] Epoch 5: Loss = 8.2966, Acc = 0.1957


 30%|███       | 6/20 [01:44<04:03, 17.41s/it]

[MTSD] Epoch 6: Loss = 8.2127, Acc = 0.2115


 35%|███▌      | 7/20 [02:02<03:46, 17.43s/it]

[MTSD] Epoch 7: Loss = 8.0470, Acc = 0.2352


 40%|████      | 8/20 [02:19<03:29, 17.44s/it]

[MTSD] Epoch 8: Loss = 8.0054, Acc = 0.2213


 45%|████▌     | 9/20 [02:36<03:11, 17.44s/it]

[MTSD] Epoch 9: Loss = 7.9996, Acc = 0.2372


 50%|█████     | 10/20 [02:54<02:54, 17.44s/it]

[MTSD] Epoch 10: Loss = 7.9082, Acc = 0.1976


 55%|█████▌    | 11/20 [03:11<02:36, 17.43s/it]

[MTSD] Epoch 11: Loss = 7.8552, Acc = 0.2510


 60%|██████    | 12/20 [03:29<02:19, 17.43s/it]

[MTSD] Epoch 12: Loss = 7.6731, Acc = 0.2312


 65%|██████▌   | 13/20 [03:46<02:01, 17.43s/it]

[MTSD] Epoch 13: Loss = 7.5987, Acc = 0.2233


 70%|███████   | 14/20 [04:04<01:44, 17.44s/it]

[MTSD] Epoch 14: Loss = 7.5715, Acc = 0.2273


 75%|███████▌  | 15/20 [04:21<01:27, 17.46s/it]

[MTSD] Epoch 15: Loss = 7.5896, Acc = 0.2134


 80%|████████  | 16/20 [04:39<01:09, 17.47s/it]

[MTSD] Epoch 16: Loss = 7.4965, Acc = 0.2292


 85%|████████▌ | 17/20 [04:56<00:52, 17.47s/it]

[MTSD] Epoch 17: Loss = 7.5161, Acc = 0.2549


 90%|█████████ | 18/20 [05:13<00:34, 17.44s/it]

[MTSD] Epoch 18: Loss = 7.5441, Acc = 0.2194


 95%|█████████▌| 19/20 [05:31<00:17, 17.46s/it]

[MTSD] Epoch 19: Loss = 7.5404, Acc = 0.2352


100%|██████████| 20/20 [05:48<00:00, 17.44s/it]

[MTSD] Epoch 20: Loss = 7.4855, Acc = 0.2115





0.23122529644268774 1558692


In [8]:
# ----------------------------
# Method 4: MTSD + Augmentation
# ----------------------------
s4 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader_aug,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD-Aug",
)
acc4 = evaluate(s4, val_loader, device)
size4 = sum(p.numel() for p in s4.parameters())
results["MTSD + Aug"] =  { "acc": acc4, "size":  size4}
print(acc4, size4)

  5%|▌         | 1/20 [00:17<05:31, 17.47s/it]

[MTSD] Epoch 1: Loss = 8.9883, Acc = 0.1759


 10%|█         | 2/20 [00:34<05:14, 17.47s/it]

[MTSD] Epoch 2: Loss = 8.7764, Acc = 0.1779


 15%|█▌        | 3/20 [00:52<04:57, 17.49s/it]

[MTSD] Epoch 3: Loss = 8.6520, Acc = 0.2292


 20%|██        | 4/20 [01:10<04:40, 17.53s/it]

[MTSD] Epoch 4: Loss = 8.6094, Acc = 0.3142


 25%|██▌       | 5/20 [01:27<04:23, 17.54s/it]

[MTSD] Epoch 5: Loss = 8.6321, Acc = 0.2648


 30%|███       | 6/20 [01:45<04:05, 17.52s/it]

[MTSD] Epoch 6: Loss = 8.5527, Acc = 0.2905


 35%|███▌      | 7/20 [02:02<03:47, 17.52s/it]

[MTSD] Epoch 7: Loss = 8.5319, Acc = 0.2589


 40%|████      | 8/20 [02:20<03:30, 17.51s/it]

[MTSD] Epoch 8: Loss = 8.5396, Acc = 0.3241


 45%|████▌     | 9/20 [02:37<03:12, 17.50s/it]

[MTSD] Epoch 9: Loss = 8.4724, Acc = 0.3360


 50%|█████     | 10/20 [02:55<02:54, 17.50s/it]

[MTSD] Epoch 10: Loss = 8.4384, Acc = 0.3320


 55%|█████▌    | 11/20 [03:12<02:37, 17.48s/it]

[MTSD] Epoch 11: Loss = 8.3202, Acc = 0.3696


 60%|██████    | 12/20 [03:30<02:19, 17.49s/it]

[MTSD] Epoch 12: Loss = 8.2802, Acc = 0.3656


 65%|██████▌   | 13/20 [03:47<02:02, 17.50s/it]

[MTSD] Epoch 13: Loss = 8.2413, Acc = 0.3458


 70%|███████   | 14/20 [04:05<01:44, 17.50s/it]

[MTSD] Epoch 14: Loss = 8.2357, Acc = 0.3458


 75%|███████▌  | 15/20 [04:22<01:27, 17.51s/it]

[MTSD] Epoch 15: Loss = 8.2394, Acc = 0.3755


 80%|████████  | 16/20 [04:40<01:09, 17.50s/it]

[MTSD] Epoch 16: Loss = 8.2520, Acc = 0.3498


 85%|████████▌ | 17/20 [04:57<00:52, 17.47s/it]

[MTSD] Epoch 17: Loss = 8.2834, Acc = 0.3557


 90%|█████████ | 18/20 [05:14<00:34, 17.46s/it]

[MTSD] Epoch 18: Loss = 8.2562, Acc = 0.3656


 95%|█████████▌| 19/20 [05:32<00:17, 17.48s/it]

[MTSD] Epoch 19: Loss = 8.2556, Acc = 0.3715


100%|██████████| 20/20 [05:49<00:00, 17.50s/it]

[MTSD] Epoch 20: Loss = 8.2128, Acc = 0.3893





0.3715415019762846 1558692


In [9]:
# ----------------------------
# Method 5: MTSD + Aug + QAT
# ----------------------------
s5 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader_aug,
    val_loader,
    device,
    qat=True,
    num_epochs=20,
    method_name="MTSD-Aug-QAT",
)

# acc5 = evaluate(s5, val_loader, device)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.prepare_qat(model, inplace=True)


[QAT] Model prepared for Quantization-Aware Training


  5%|▌         | 1/20 [00:19<06:11, 19.54s/it]

[MTSD-QAT] Epoch 1: Loss = 9.0939, Acc = 0.2115


 10%|█         | 2/20 [00:39<05:50, 19.50s/it]

[MTSD-QAT] Epoch 2: Loss = 8.9733, Acc = 0.2727


 15%|█▌        | 3/20 [00:58<05:31, 19.50s/it]

[MTSD-QAT] Epoch 3: Loss = 8.7204, Acc = 0.2115


 20%|██        | 4/20 [01:18<05:11, 19.50s/it]

[MTSD-QAT] Epoch 4: Loss = 8.6626, Acc = 0.2273


 25%|██▌       | 5/20 [01:37<04:52, 19.49s/it]

[MTSD-QAT] Epoch 5: Loss = 8.6690, Acc = 0.3478


 30%|███       | 6/20 [01:56<04:32, 19.47s/it]

[MTSD-QAT] Epoch 6: Loss = 8.5107, Acc = 0.2846


 35%|███▌      | 7/20 [02:16<04:13, 19.49s/it]

[MTSD-QAT] Epoch 7: Loss = 8.6618, Acc = 0.2549


 40%|████      | 8/20 [02:35<03:54, 19.51s/it]

[MTSD-QAT] Epoch 8: Loss = 8.4813, Acc = 0.1858


 45%|████▌     | 9/20 [02:55<03:34, 19.52s/it]

[MTSD-QAT] Epoch 9: Loss = 8.4695, Acc = 0.2372


 50%|█████     | 10/20 [03:15<03:15, 19.51s/it]

[MTSD-QAT] Epoch 10: Loss = 8.4735, Acc = 0.2984


 55%|█████▌    | 11/20 [03:34<02:55, 19.50s/it]

[MTSD-QAT] Epoch 11: Loss = 8.3343, Acc = 0.3636


 60%|██████    | 12/20 [03:54<02:36, 19.51s/it]

[MTSD-QAT] Epoch 12: Loss = 8.2691, Acc = 0.3715


 65%|██████▌   | 13/20 [04:13<02:16, 19.48s/it]

[MTSD-QAT] Epoch 13: Loss = 8.2327, Acc = 0.3538


 70%|███████   | 14/20 [04:32<01:56, 19.49s/it]

[MTSD-QAT] Epoch 14: Loss = 8.2887, Acc = 0.3913


 75%|███████▌  | 15/20 [04:52<01:37, 19.45s/it]

[MTSD-QAT] Epoch 15: Loss = 8.2777, Acc = 0.3498


 80%|████████  | 16/20 [05:11<01:17, 19.45s/it]

[MTSD-QAT] Epoch 16: Loss = 8.2750, Acc = 0.3636


 85%|████████▌ | 17/20 [05:31<00:58, 19.47s/it]

[MTSD-QAT] Epoch 17: Loss = 8.2286, Acc = 0.3577


 90%|█████████ | 18/20 [05:50<00:38, 19.47s/it]

[MTSD-QAT] Epoch 18: Loss = 8.2772, Acc = 0.3518


 95%|█████████▌| 19/20 [06:10<00:19, 19.47s/it]

[MTSD-QAT] Epoch 19: Loss = 8.2508, Acc = 0.3597


100%|██████████| 20/20 [06:29<00:00, 19.49s/it]

[MTSD-QAT] Epoch 20: Loss = 8.2352, Acc = 0.3874
[QAT] Model converted to quantized version



For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.convert(model, inplace=True)


In [12]:
size5 = sum(p.numel() for p in s5.parameters())
results["MTSD-Aug-QAT"] = { "acc": 0.3874, "size":  size5}
print(size5)

1920


In [13]:
results

{'Normal KD': {'acc': 0.15810276679841898, 'size': 1558692},
 'Normal KD with AUG': {'acc': 0.23517786561264822, 'size': 1558692},
 'Multi-Teacher SD': {'acc': 0.23122529644268774, 'size': 1558692},
 'MTSD + Aug': {'acc': 0.3715415019762846, 'size': 1558692},
 'MTSD-Aug-QAT': {'acc': 0.3874, 'size': 1920}}