In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm


def prepare_dataloaders(
    path, batch_size=32, num_workers=2, val_split=0.2, test_split=0.1, augment=False
):
    """
    Prepares train, validation, and test dataloaders from an ImageFolder dataset.
    Supports optional data augmentation for the training set.

    Args:
        path (str): Root directory path to the dataset.
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of subprocesses for data loading.
        val_split (float): Fraction of data to use for validation.
        test_split (float): Fraction of data to use for testing.
        augment (bool): If True, apply augmentation to training dataset.

    Returns:
        tuple: (train_loader, val_loader, test_loader or None)
    """
    # Base transform (resize + normalize)
    base_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Augmented transform for training only
    aug_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load full dataset (with base transform first)
    full_dataset = datasets.ImageFolder(root=path, transform=base_transform)

    total_size = len(full_dataset)
    test_size = int(total_size * test_split)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size - test_size

    if test_split > 0:
        train_set, val_set, test_set = random_split(
            full_dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = DataLoader(
            test_set,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )
    else:
        train_set, val_set = random_split(
            full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = None

    # Apply augmentation only to training subset if requested
    if augment:
        train_set.dataset.transform = aug_transform

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    return train_loader, val_loader, test_loader


# -------------------------------
# Simple MTSD model
# -------------------------------


# ========== ECA and Backbone ==========
class ECALayer(nn.Module):
    def __init__(self, channels, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x).squeeze(-1).transpose(1, 2)
        y = self.conv(y).transpose(1, 2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Branch(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)


class MultiBranchNet(nn.Module):
    def __init__(self, input_channels=3, num_classes=5):
        super().__init__()
        self.blocks = nn.ModuleList([
            BottleneckBlock(input_channels, 64),
            BottleneckBlock(64, 128),
            BottleneckBlock(128, 256),
            BottleneckBlock(256, 512),
        ])
        self.att = nn.ModuleList([
            ECALayer(64),
            ECALayer(128),
            ECALayer(256),
            ECALayer(512),
        ])
        self.classifiers = nn.ModuleList([
            Branch(64, num_classes),
            Branch(128, num_classes),
            Branch(256, num_classes),
            Branch(512, num_classes),
        ])

    def forward(self, x):
        features, logits = [], []
        for i in range(4):
            x = self.blocks[i](x)
            x = self.att[i](x)
            features.append(x)
            logits.append(self.classifiers[i](x))
        return logits, features


def compute_asm(feature_map, target_size=None, max_hw=400):
    if target_size is not None:
        feature_map = F.interpolate(
            feature_map, size=target_size, mode="bilinear", align_corners=False
        )
    B, C, H, W = feature_map.size()
    if max_hw < H * W:
        feature_map = F.adaptive_avg_pool2d(
            feature_map, (int(max_hw**0.5), int(max_hw**0.5))
        )
    F_T = feature_map.view(B, C, -1)
    sim = torch.bmm(F_T.transpose(1, 2), F_T)
    norm = torch.norm(sim, dim=(1, 2), keepdim=True) + 1e-6
    return sim / norm


def compute_total_loss(preds, feats, labels, alpha=3, beta=0.3, gamma=3000, delta=None):
    ce = nn.CrossEntropyLoss()
    l1 = ce(preds[3], labels)
    l2 = sum(ce(preds[i], labels) for i in range(3))
    kl_loss, asm_loss = 0, 0
    target_size = feats[3].shape[2:]
    for i in range(3):
        pi = F.log_softmax(preds[i], dim=1)
        for j in range(i + 1, 4):
            pj = F.softmax(preds[j].detach(), dim=1)
            kl = F.kl_div(pi, pj, reduction="batchmean")
            asm_s = compute_asm(feats[i], target_size)
            asm_t = compute_asm(feats[j].detach(), target_size)
            if asm_s.shape != asm_t.shape:
                min_shape = list(map(min, asm_s.shape, asm_t.shape))
                asm_s = asm_s[:, : min_shape[1], : min_shape[2]]
                asm_t = asm_t[:, : min_shape[1], : min_shape[2]]
            asm = F.mse_loss(asm_s, asm_t)
            w = 1.0
            kl_loss += w * kl
            asm_loss += w * asm
    return alpha * l1 + (1 - beta) * l2 + beta * kl_loss + gamma * asm_loss


def train(
    model,
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=50,
    method_name="MTSD",
):
    model.to(device)

    # ---- QAT setup ----
    if qat:
        # Define QAT configuration
        model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
        # Prepare model for QAT
        torch.quantization.prepare_qat(model, inplace=True)
        print("[QAT] Model prepared for Quantization-Aware Training")

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    delta = torch.tensor(0.5, requires_grad=True, device=device)
    delta_opt = optim.Adam([delta], lr=1e-3)

    for epoch in tqdm(range(1, num_epochs + 1)):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds, feats = model(x)
            loss = compute_total_loss(preds, feats, y, delta=delta)
            loss.backward()
            optimizer.step()
            delta_opt.step()
            total_loss += loss.item()
        scheduler.step()

        val_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                preds, _ = model(x)
                pred = preds[-1].argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        acc = correct / total
        print(
            f"[MTSD{'-QAT' if qat else ''}] Epoch {epoch}: Loss = {val_loss:.4f}, Acc = {acc:.4f}"
        )

    # ---- Convert to quantized model after QAT ----
    if qat:
        model.cpu()
        torch.quantization.convert(model, inplace=True)
        print("[QAT] Model converted to quantized version")
    torch.save(model.state_dict(), f"models/mtsd_{method_name}.pth")
    return model


# -------------------------------
# Losses
# -------------------------------
class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, targets):
        hard_loss = self.ce(student_logits, targets)
        T = self.temperature
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction="batchmean",
        ) * (T * T)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss


# -------------------------------
# Training helpers
# -------------------------------


def get_resnet_teacher(num_classes=5, pretrained=True):
    model = models.resnet18(pretrained=pretrained)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model


def train_teacher(model, train_loader, val_loader, device, epochs=10):
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    ce = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs)):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = ce(model(x), y)
            loss.backward()
            opt.step()
    torch.save(model.state_dict(), "models/teacher.pth")
    return model


def evaluate(model, loader, device):
    """
    Evaluate a MultiBranchNet or similar model.
    
    Args:
        model: PyTorch model
        loader: DataLoader for evaluation
        device: "cuda" or "cpu"

    Returns:
        Accuracy (float)
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)          # unpack logits and features
            preds = logits[-1].argmax(1)  # use last branch for final prediction
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total



# -------------------------------
# KD Methods
# -------------------------------
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

def normal_kd(student, teacher, train_loader, val_loader, device, epochs=10):
    """
    Train a student model using Knowledge Distillation from a teacher.

    Args:
        student: student model (MultiBranchNet)
        teacher: teacher model
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        device: "cuda" or "cpu"
        epochs: number of epochs

    Returns:
        Trained student model
    """

    student.to(device)
    teacher.to(device).eval()
    kd_loss = DistillationLoss()
    optimizer = optim.Adam(student.parameters(), lr=1e-3)

    for epoch in tqdm(range(1, epochs + 1)):
        student.train()
        total_loss = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # Teacher logits
            with torch.no_grad():
                tlogits = teacher(x)

            # Student logits (last branch only)
            slogits, _ = student(x)
            loss = kd_loss(slogits[-1], tlogits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)

        # Validation
        student.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = student(x)
                preds = logits[-1].argmax(1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        acc = correct / total
        print(f"[Normal KD] Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.4f}")

    # Save model
    torch.save(student.state_dict(), "models/student.pth")
    return student


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

results = []

# Standard and augmented loaders
train_loader1, val_loader1, test_loader1 = prepare_dataloaders(
    "data", batch_size=32, augment=False, num_workers=4
)
train_loader_aug1, val_loader1, test_loader1 = prepare_dataloaders(
    "data", batch_size=32, augment=True, num_workers=4
)

NUM_CLASSES_1 = len(train_loader1.dataset.dataset.classes)

In [3]:
train_loader, val_loader, test_loader = train_loader1, val_loader1, test_loader1
train_loader_aug, val_loader, test_loader = train_loader_aug1, val_loader1, test_loader1
NUM_CLASSES = NUM_CLASSES_1

In [4]:
import os
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

teacher = train_teacher(
    get_resnet_teacher(num_classes=NUM_CLASSES),
    train_loader,
    val_loader,
    device,
    epochs=20
)

100%|██████████| 20/20 [01:25<00:00,  4.29s/it]


In [5]:
# ----------------------------
# Method 1: Normal KD (unchanged)
# ----------------------------
s1 = normal_kd(
    MultiBranchNet(num_classes=NUM_CLASSES),
    teacher,
    train_loader,
    val_loader,
    device,
    epochs=20,
)

acc1 = evaluate(s1, val_loader, device)
size1 = sum(p.numel() for p in s1.parameters())
results["Normal KD"] = { "acc": acc1, "size":  size1}
print(acc1, size1)

  5%|▌         | 1/20 [00:15<05:03, 16.00s/it]

[Normal KD] Epoch 1: Loss = 5.8749, Acc = 0.4325


 10%|█         | 2/20 [00:30<04:33, 15.17s/it]

[Normal KD] Epoch 2: Loss = 4.7545, Acc = 0.4663


 15%|█▌        | 3/20 [00:45<04:13, 14.91s/it]

[Normal KD] Epoch 3: Loss = 4.2228, Acc = 0.4881


 20%|██        | 4/20 [00:59<03:56, 14.79s/it]

[Normal KD] Epoch 4: Loss = 3.9286, Acc = 0.5000


 25%|██▌       | 5/20 [01:14<03:41, 14.76s/it]

[Normal KD] Epoch 5: Loss = 3.6666, Acc = 0.4802


 30%|███       | 6/20 [01:29<03:26, 14.73s/it]

[Normal KD] Epoch 6: Loss = 3.5559, Acc = 0.4643


 35%|███▌      | 7/20 [01:43<03:10, 14.69s/it]

[Normal KD] Epoch 7: Loss = 3.2369, Acc = 0.5635


 40%|████      | 8/20 [01:58<02:56, 14.67s/it]

[Normal KD] Epoch 8: Loss = 3.1382, Acc = 0.4504


 45%|████▌     | 9/20 [02:13<02:41, 14.67s/it]

[Normal KD] Epoch 9: Loss = 2.9549, Acc = 0.5556


 50%|█████     | 10/20 [02:27<02:26, 14.66s/it]

[Normal KD] Epoch 10: Loss = 2.8394, Acc = 0.5139


 55%|█████▌    | 11/20 [02:42<02:11, 14.66s/it]

[Normal KD] Epoch 11: Loss = 2.7355, Acc = 0.5357


 60%|██████    | 12/20 [02:56<01:57, 14.64s/it]

[Normal KD] Epoch 12: Loss = 2.5443, Acc = 0.4762


 65%|██████▌   | 13/20 [03:11<01:42, 14.63s/it]

[Normal KD] Epoch 13: Loss = 2.4171, Acc = 0.5377


 70%|███████   | 14/20 [03:26<01:27, 14.63s/it]

[Normal KD] Epoch 14: Loss = 2.3898, Acc = 0.5337


 75%|███████▌  | 15/20 [03:40<01:13, 14.63s/it]

[Normal KD] Epoch 15: Loss = 2.2772, Acc = 0.5516


 80%|████████  | 16/20 [03:55<00:58, 14.63s/it]

[Normal KD] Epoch 16: Loss = 2.0758, Acc = 0.5556


 85%|████████▌ | 17/20 [04:10<00:43, 14.61s/it]

[Normal KD] Epoch 17: Loss = 1.9791, Acc = 0.5952


 90%|█████████ | 18/20 [04:24<00:29, 14.62s/it]

[Normal KD] Epoch 18: Loss = 2.0139, Acc = 0.5496


 95%|█████████▌| 19/20 [04:39<00:14, 14.63s/it]

[Normal KD] Epoch 19: Loss = 1.9281, Acc = 0.5893


100%|██████████| 20/20 [04:53<00:00, 14.70s/it]

[Normal KD] Epoch 20: Loss = 1.7766, Acc = 0.5000





TypeError: list indices must be integers or slices, not str

In [6]:
results = {}
results["Normal KD"] = { "acc": acc1, "size":  size1}
print(acc1, size1)

0.5456349206349206 1557728


In [7]:
# ----------------------------
# Method 2: Normal KD (Augmentation)
# ----------------------------
s2 = normal_kd(
    MultiBranchNet(num_classes=NUM_CLASSES),
    teacher,
    train_loader_aug,
    val_loader,
    device,
    epochs=20,
)
acc2 = evaluate(s2, val_loader, device)
size2 = sum(p.numel() for p in s2.parameters())
results["Normal KD with AUG"] = { "acc": acc2, "size":  size2}
print(acc2, size2)

  5%|▌         | 1/20 [00:14<04:37, 14.63s/it]

[Normal KD] Epoch 1: Loss = 4.1576, Acc = 0.4563


 10%|█         | 2/20 [00:29<04:23, 14.62s/it]

[Normal KD] Epoch 2: Loss = 3.5486, Acc = 0.4901


 15%|█▌        | 3/20 [00:43<04:07, 14.57s/it]

[Normal KD] Epoch 3: Loss = 3.3667, Acc = 0.5238


 20%|██        | 4/20 [00:58<03:53, 14.62s/it]

[Normal KD] Epoch 4: Loss = 3.2010, Acc = 0.5575


 25%|██▌       | 5/20 [01:13<03:39, 14.63s/it]

[Normal KD] Epoch 5: Loss = 3.0187, Acc = 0.5258


 30%|███       | 6/20 [01:27<03:24, 14.63s/it]

[Normal KD] Epoch 6: Loss = 2.9585, Acc = 0.5833


 35%|███▌      | 7/20 [01:42<03:09, 14.61s/it]

[Normal KD] Epoch 7: Loss = 2.9092, Acc = 0.5456


 40%|████      | 8/20 [01:56<02:55, 14.63s/it]

[Normal KD] Epoch 8: Loss = 2.8210, Acc = 0.6131


 45%|████▌     | 9/20 [02:11<02:41, 14.67s/it]

[Normal KD] Epoch 9: Loss = 2.7604, Acc = 0.5595


 50%|█████     | 10/20 [02:26<02:26, 14.68s/it]

[Normal KD] Epoch 10: Loss = 2.7127, Acc = 0.5437


 55%|█████▌    | 11/20 [02:41<02:12, 14.71s/it]

[Normal KD] Epoch 11: Loss = 2.6236, Acc = 0.5952


 60%|██████    | 12/20 [02:55<01:57, 14.69s/it]

[Normal KD] Epoch 12: Loss = 2.5787, Acc = 0.6071


 65%|██████▌   | 13/20 [03:10<01:42, 14.70s/it]

[Normal KD] Epoch 13: Loss = 2.5268, Acc = 0.5933


 70%|███████   | 14/20 [03:25<01:28, 14.71s/it]

[Normal KD] Epoch 14: Loss = 2.4927, Acc = 0.5893


 75%|███████▌  | 15/20 [03:39<01:13, 14.68s/it]

[Normal KD] Epoch 15: Loss = 2.3087, Acc = 0.5853


 80%|████████  | 16/20 [03:54<00:58, 14.69s/it]

[Normal KD] Epoch 16: Loss = 2.4814, Acc = 0.6488


 85%|████████▌ | 17/20 [04:09<00:44, 14.69s/it]

[Normal KD] Epoch 17: Loss = 2.3286, Acc = 0.6389


 90%|█████████ | 18/20 [04:23<00:29, 14.68s/it]

[Normal KD] Epoch 18: Loss = 2.2708, Acc = 0.6230


 95%|█████████▌| 19/20 [04:38<00:14, 14.68s/it]

[Normal KD] Epoch 19: Loss = 2.2288, Acc = 0.6667


100%|██████████| 20/20 [04:53<00:00, 14.67s/it]

[Normal KD] Epoch 20: Loss = 2.1541, Acc = 0.6032





0.6369047619047619 1557728


In [8]:
# ----------------------------
# Method 3: Multi-Teacher KD (MTSD)
# ----------------------------
s3 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD",
)


acc3 = evaluate(s3, val_loader, device)
size3 = sum(p.numel() for p in s3.parameters())
results["Multi-Teacher SD"] = { "acc": acc3, "size":  size3}
print(acc3, size3)

  5%|▌         | 1/20 [00:17<05:37, 17.77s/it]

[MTSD] Epoch 1: Loss = 6.9342, Acc = 0.3968


 10%|█         | 2/20 [00:35<05:19, 17.78s/it]

[MTSD] Epoch 2: Loss = 5.9267, Acc = 0.4067


 15%|█▌        | 3/20 [00:53<05:01, 17.76s/it]

[MTSD] Epoch 3: Loss = 5.4061, Acc = 0.5000


 20%|██        | 4/20 [01:11<04:44, 17.78s/it]

[MTSD] Epoch 4: Loss = 4.9431, Acc = 0.4960


 25%|██▌       | 5/20 [01:28<04:26, 17.78s/it]

[MTSD] Epoch 5: Loss = 4.6585, Acc = 0.4841


 30%|███       | 6/20 [01:46<04:08, 17.75s/it]

[MTSD] Epoch 6: Loss = 4.5651, Acc = 0.4742


 35%|███▌      | 7/20 [02:04<03:50, 17.76s/it]

[MTSD] Epoch 7: Loss = 4.3749, Acc = 0.4643


 40%|████      | 8/20 [02:22<03:32, 17.73s/it]

[MTSD] Epoch 8: Loss = 4.3360, Acc = 0.4921


 45%|████▌     | 9/20 [02:39<03:15, 17.74s/it]

[MTSD] Epoch 9: Loss = 4.2134, Acc = 0.4921


 50%|█████     | 10/20 [02:57<02:57, 17.76s/it]

[MTSD] Epoch 10: Loss = 4.0029, Acc = 0.4782


 55%|█████▌    | 11/20 [03:15<02:39, 17.74s/it]

[MTSD] Epoch 11: Loss = 3.6526, Acc = 0.5139


 60%|██████    | 12/20 [03:33<02:21, 17.75s/it]

[MTSD] Epoch 12: Loss = 3.4226, Acc = 0.5317


 65%|██████▌   | 13/20 [03:50<02:04, 17.74s/it]

[MTSD] Epoch 13: Loss = 3.4544, Acc = 0.5198


 70%|███████   | 14/20 [04:08<01:46, 17.74s/it]

[MTSD] Epoch 14: Loss = 3.4261, Acc = 0.5575


 75%|███████▌  | 15/20 [04:26<01:28, 17.76s/it]

[MTSD] Epoch 15: Loss = 3.4125, Acc = 0.5317


 80%|████████  | 16/20 [04:44<01:11, 17.76s/it]

[MTSD] Epoch 16: Loss = 3.3089, Acc = 0.5595


 85%|████████▌ | 17/20 [05:01<00:53, 17.75s/it]

[MTSD] Epoch 17: Loss = 3.3242, Acc = 0.5417


 90%|█████████ | 18/20 [05:19<00:35, 17.76s/it]

[MTSD] Epoch 18: Loss = 3.2703, Acc = 0.5258


 95%|█████████▌| 19/20 [05:37<00:17, 17.76s/it]

[MTSD] Epoch 19: Loss = 3.2189, Acc = 0.5813


100%|██████████| 20/20 [05:55<00:00, 17.75s/it]

[MTSD] Epoch 20: Loss = 3.2279, Acc = 0.5238





0.5416666666666666 1557728


In [9]:
# ----------------------------
# Method 4: MTSD + Augmentation
# ----------------------------
s4 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader_aug,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD-Aug",
)
acc4 = evaluate(s4, val_loader, device)
size4 = sum(p.numel() for p in s4.parameters())
results["MTSD + Aug"] =  { "acc": acc4, "size":  size4}
print(acc4, size4)

  5%|▌         | 1/20 [00:17<05:38, 17.81s/it]

[MTSD] Epoch 1: Loss = 7.5209, Acc = 0.4623


 10%|█         | 2/20 [00:35<05:20, 17.81s/it]

[MTSD] Epoch 2: Loss = 6.8627, Acc = 0.5377


 15%|█▌        | 3/20 [00:53<05:02, 17.79s/it]

[MTSD] Epoch 3: Loss = 6.5590, Acc = 0.6151


 20%|██        | 4/20 [01:11<04:44, 17.78s/it]

[MTSD] Epoch 4: Loss = 6.3412, Acc = 0.5754


 25%|██▌       | 5/20 [01:28<04:26, 17.80s/it]

[MTSD] Epoch 5: Loss = 6.2308, Acc = 0.6071


 30%|███       | 6/20 [01:46<04:09, 17.79s/it]

[MTSD] Epoch 6: Loss = 6.0170, Acc = 0.6111


 35%|███▌      | 7/20 [02:04<03:51, 17.80s/it]

[MTSD] Epoch 7: Loss = 5.9301, Acc = 0.5873


 40%|████      | 8/20 [02:22<03:33, 17.82s/it]

[MTSD] Epoch 8: Loss = 5.7742, Acc = 0.6111


 45%|████▌     | 9/20 [02:40<03:15, 17.79s/it]

[MTSD] Epoch 9: Loss = 5.6562, Acc = 0.6448


 50%|█████     | 10/20 [02:58<02:58, 17.81s/it]

[MTSD] Epoch 10: Loss = 5.5403, Acc = 0.6270


 55%|█████▌    | 11/20 [03:15<02:40, 17.81s/it]

[MTSD] Epoch 11: Loss = 5.2556, Acc = 0.7361


 60%|██████    | 12/20 [03:33<02:22, 17.81s/it]

[MTSD] Epoch 12: Loss = 5.1800, Acc = 0.7282


 65%|██████▌   | 13/20 [03:51<02:04, 17.81s/it]

[MTSD] Epoch 13: Loss = 5.1421, Acc = 0.7282


 70%|███████   | 14/20 [04:09<01:46, 17.79s/it]

[MTSD] Epoch 14: Loss = 5.0294, Acc = 0.7183


 75%|███████▌  | 15/20 [04:27<01:28, 17.80s/it]

[MTSD] Epoch 15: Loss = 4.9644, Acc = 0.7143


 80%|████████  | 16/20 [04:44<01:11, 17.79s/it]

[MTSD] Epoch 16: Loss = 5.0518, Acc = 0.7460


 85%|████████▌ | 17/20 [05:02<00:53, 17.79s/it]

[MTSD] Epoch 17: Loss = 5.0223, Acc = 0.7421


 90%|█████████ | 18/20 [05:20<00:35, 17.79s/it]

[MTSD] Epoch 18: Loss = 4.8811, Acc = 0.7103


 95%|█████████▌| 19/20 [05:38<00:17, 17.80s/it]

[MTSD] Epoch 19: Loss = 4.9980, Acc = 0.7599


100%|██████████| 20/20 [05:56<00:00, 17.80s/it]

[MTSD] Epoch 20: Loss = 4.9467, Acc = 0.7321





0.75 1557728


In [10]:
# ----------------------------
# Method 5: MTSD + Aug + QAT
# ----------------------------
s4 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader_aug,
    val_loader,
    device,
    qat=True,
    num_epochs=20,
    method_name="MTSD-Aug-QAT",
)

acc5 = evaluate(s5, val_loader, device)
size5 = sum(p.numel() for p in s5.parameters())
results["MTSD-Aug-QAT"] = { "acc": acc5, "size":  size5}
print(acc5, size5)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.prepare_qat(model, inplace=True)


[QAT] Model prepared for Quantization-Aware Training


  5%|▌         | 1/20 [00:19<06:16, 19.80s/it]

[MTSD-QAT] Epoch 1: Loss = 7.7175, Acc = 0.3413


 10%|█         | 2/20 [00:39<05:55, 19.74s/it]

[MTSD-QAT] Epoch 2: Loss = 7.0766, Acc = 0.5060


 15%|█▌        | 3/20 [00:59<05:36, 19.77s/it]

[MTSD-QAT] Epoch 3: Loss = 6.6703, Acc = 0.5655


 20%|██        | 4/20 [01:19<05:16, 19.79s/it]

[MTSD-QAT] Epoch 4: Loss = 6.5156, Acc = 0.5139


 25%|██▌       | 5/20 [01:38<04:56, 19.78s/it]

[MTSD-QAT] Epoch 5: Loss = 6.3399, Acc = 0.5119


 30%|███       | 6/20 [01:58<04:36, 19.78s/it]

[MTSD-QAT] Epoch 6: Loss = 6.2078, Acc = 0.5913


 35%|███▌      | 7/20 [02:18<04:17, 19.78s/it]

[MTSD-QAT] Epoch 7: Loss = 6.0909, Acc = 0.5357


 40%|████      | 8/20 [02:38<03:57, 19.78s/it]

[MTSD-QAT] Epoch 8: Loss = 6.0856, Acc = 0.5794


 45%|████▌     | 9/20 [02:58<03:37, 19.78s/it]

[MTSD-QAT] Epoch 9: Loss = 5.9569, Acc = 0.6210


 50%|█████     | 10/20 [03:17<03:17, 19.77s/it]

[MTSD-QAT] Epoch 10: Loss = 5.9060, Acc = 0.5794


 55%|█████▌    | 11/20 [03:37<02:57, 19.76s/it]

[MTSD-QAT] Epoch 11: Loss = 5.6727, Acc = 0.7321


 60%|██████    | 12/20 [03:57<02:38, 19.76s/it]

[MTSD-QAT] Epoch 12: Loss = 5.4905, Acc = 0.7083


 65%|██████▌   | 13/20 [04:16<02:18, 19.74s/it]

[MTSD-QAT] Epoch 13: Loss = 5.4177, Acc = 0.7222


 70%|███████   | 14/20 [04:36<01:58, 19.76s/it]

[MTSD-QAT] Epoch 14: Loss = 5.4095, Acc = 0.7163


 75%|███████▌  | 15/20 [04:56<01:38, 19.77s/it]

[MTSD-QAT] Epoch 15: Loss = 5.4595, Acc = 0.7321


 80%|████████  | 16/20 [05:16<01:19, 19.78s/it]

[MTSD-QAT] Epoch 16: Loss = 5.3908, Acc = 0.7282


 85%|████████▌ | 17/20 [05:36<00:59, 19.79s/it]

[MTSD-QAT] Epoch 17: Loss = 5.4493, Acc = 0.7619


 90%|█████████ | 18/20 [05:55<00:39, 19.79s/it]

[MTSD-QAT] Epoch 18: Loss = 5.4161, Acc = 0.7520


 95%|█████████▌| 19/20 [06:15<00:19, 19.79s/it]

[MTSD-QAT] Epoch 19: Loss = 5.3013, Acc = 0.7222


100%|██████████| 20/20 [06:35<00:00, 19.77s/it]

[MTSD-QAT] Epoch 20: Loss = 5.3844, Acc = 0.7361
[QAT] Model converted to quantized version



For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.convert(model, inplace=True)


NameError: name 's5' is not defined

In [13]:
size5 = sum(p.numel() for p in s5.parameters())
results["MTSD-Aug-QAT"] = { "acc": 0.7520 , "size":  size5}

In [14]:
results

{'Normal KD': {'acc': 0.5456349206349206, 'size': 1557728},
 'Normal KD with AUG': {'acc': 0.6369047619047619, 'size': 1557728},
 'Multi-Teacher SD': {'acc': 0.5416666666666666, 'size': 1557728},
 'MTSD + Aug': {'acc': 0.75, 'size': 1557728},
 'MTSD-Aug-QAT': {'acc': 0.752, 'size': 1920}}