In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm


def prepare_dataloaders(
    path, batch_size=32, num_workers=2, val_split=0.2, test_split=0.1, augment=False
):
    """
    Prepares train, validation, and test dataloaders from an ImageFolder dataset.
    Supports optional data augmentation for the training set.

    Args:
        path (str): Root directory path to the dataset.
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of subprocesses for data loading.
        val_split (float): Fraction of data to use for validation.
        test_split (float): Fraction of data to use for testing.
        augment (bool): If True, apply augmentation to training dataset.

    Returns:
        tuple: (train_loader, val_loader, test_loader or None)
    """
    # Base transform (resize + normalize)
    base_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Augmented transform for training only
    aug_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load full dataset (with base transform first)
    full_dataset = datasets.ImageFolder(root=path, transform=base_transform)

    total_size = len(full_dataset)
    test_size = int(total_size * test_split)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size - test_size

    if test_split > 0:
        train_set, val_set, test_set = random_split(
            full_dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = DataLoader(
            test_set,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )
    else:
        train_set, val_set = random_split(
            full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
        test_loader = None

    # Apply augmentation only to training subset if requested
    if augment:
        train_set.dataset.transform = aug_transform

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    return train_loader, val_loader, test_loader


# -------------------------------
# Simple MTSD model
# -------------------------------


# ========== ECA and Backbone ==========
class ECALayer(nn.Module):
    def __init__(self, channels, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x).squeeze(-1).transpose(1, 2)
        y = self.conv(y).transpose(1, 2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Branch(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)


class MultiBranchNet(nn.Module):
    def __init__(self, input_channels=3, num_classes=5):
        super().__init__()
        self.blocks = nn.ModuleList([
            BottleneckBlock(input_channels, 64),
            BottleneckBlock(64, 128),
            BottleneckBlock(128, 256),
            BottleneckBlock(256, 512),
        ])
        self.att = nn.ModuleList([
            ECALayer(64),
            ECALayer(128),
            ECALayer(256),
            ECALayer(512),
        ])
        self.classifiers = nn.ModuleList([
            Branch(64, num_classes),
            Branch(128, num_classes),
            Branch(256, num_classes),
            Branch(512, num_classes),
        ])

    def forward(self, x):
        features, logits = [], []
        for i in range(4):
            x = self.blocks[i](x)
            x = self.att[i](x)
            features.append(x)
            logits.append(self.classifiers[i](x))
        return logits, features


def compute_asm(feature_map, target_size=None, max_hw=400):
    if target_size is not None:
        feature_map = F.interpolate(
            feature_map, size=target_size, mode="bilinear", align_corners=False
        )
    B, C, H, W = feature_map.size()
    if max_hw < H * W:
        feature_map = F.adaptive_avg_pool2d(
            feature_map, (int(max_hw**0.5), int(max_hw**0.5))
        )
    F_T = feature_map.view(B, C, -1)
    sim = torch.bmm(F_T.transpose(1, 2), F_T)
    norm = torch.norm(sim, dim=(1, 2), keepdim=True) + 1e-6
    return sim / norm


def compute_total_loss(preds, feats, labels, alpha=3, beta=0.3, gamma=3000, delta=None):
    ce = nn.CrossEntropyLoss()
    l1 = ce(preds[3], labels)
    l2 = sum(ce(preds[i], labels) for i in range(3))
    kl_loss, asm_loss = 0, 0
    target_size = feats[3].shape[2:]
    for i in range(3):
        pi = F.log_softmax(preds[i], dim=1)
        for j in range(i + 1, 4):
            pj = F.softmax(preds[j].detach(), dim=1)
            kl = F.kl_div(pi, pj, reduction="batchmean")
            asm_s = compute_asm(feats[i], target_size)
            asm_t = compute_asm(feats[j].detach(), target_size)
            if asm_s.shape != asm_t.shape:
                min_shape = list(map(min, asm_s.shape, asm_t.shape))
                asm_s = asm_s[:, : min_shape[1], : min_shape[2]]
                asm_t = asm_t[:, : min_shape[1], : min_shape[2]]
            asm = F.mse_loss(asm_s, asm_t)
            w = 1.0
            kl_loss += w * kl
            asm_loss += w * asm
    return alpha * l1 + (1 - beta) * l2 + beta * kl_loss + gamma * asm_loss


def train(
    model,
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=50,
    method_name="MTSD",
):
    model.to(device)

    # ---- QAT setup ----
    if qat:
        # Define QAT configuration
        model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
        # Prepare model for QAT
        torch.quantization.prepare_qat(model, inplace=True)
        print("[QAT] Model prepared for Quantization-Aware Training")

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    delta = torch.tensor(0.5, requires_grad=True, device=device)
    delta_opt = optim.Adam([delta], lr=1e-3)

    for epoch in tqdm(range(1, num_epochs + 1)):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds, feats = model(x)
            loss = compute_total_loss(preds, feats, y, delta=delta)
            loss.backward()
            optimizer.step()
            delta_opt.step()
            total_loss += loss.item()
        scheduler.step()

        val_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                preds, _ = model(x)
                pred = preds[-1].argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        acc = correct / total
        print(
            f"[MTSD{'-QAT' if qat else ''}] Epoch {epoch}: Loss = {val_loss:.4f}, Acc = {acc:.4f}"
        )

    # ---- Convert to quantized model after QAT ----
    if qat:
        model.cpu()
        torch.quantization.convert(model, inplace=True)
        print("[QAT] Model converted to quantized version")
    torch.save(model.state_dict(), f"models/mtsd_{method_name}.pth")
    return model


# -------------------------------
# Losses
# -------------------------------
class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, targets):
        hard_loss = self.ce(student_logits, targets)
        T = self.temperature
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction="batchmean",
        ) * (T * T)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss


# -------------------------------
# Training helpers
# -------------------------------


def get_resnet_teacher(num_classes=5, pretrained=True):
    model = models.resnet18(pretrained=pretrained)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model


def train_teacher(model, train_loader, val_loader, device, epochs=10):
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    ce = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs)):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = ce(model(x), y)
            loss.backward()
            opt.step()
    torch.save(model.state_dict(), "models/teacher.pth")
    return model


def evaluate(model, loader, device):
    """
    Evaluate a MultiBranchNet or similar model.
    
    Args:
        model: PyTorch model
        loader: DataLoader for evaluation
        device: "cuda" or "cpu"

    Returns:
        Accuracy (float)
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)          # unpack logits and features
            preds = logits[-1].argmax(1)  # use last branch for final prediction
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total



# -------------------------------
# KD Methods
# -------------------------------
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

def normal_kd(student, teacher, train_loader, val_loader, device, epochs=10):
    """
    Train a student model using Knowledge Distillation from a teacher.

    Args:
        student: student model (MultiBranchNet)
        teacher: teacher model
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        device: "cuda" or "cpu"
        epochs: number of epochs

    Returns:
        Trained student model
    """

    student.to(device)
    teacher.to(device).eval()
    kd_loss = DistillationLoss()
    optimizer = optim.Adam(student.parameters(), lr=1e-3)

    for epoch in tqdm(range(1, epochs + 1)):
        student.train()
        total_loss = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # Teacher logits
            with torch.no_grad():
                tlogits = teacher(x)

            # Student logits (last branch only)
            slogits, _ = student(x)
            loss = kd_loss(slogits[-1], tlogits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)

        # Validation
        student.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = student(x)
                preds = logits[-1].argmax(1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        acc = correct / total
        print(f"[Normal KD] Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.4f}")

    # Save model
    torch.save(student.state_dict(), "models/student.pth")
    return student


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

results = {}

# Standard and augmented loaders
# train_loader1, val_loader1, test_loader1 = prepare_dataloaders(
#     "data", batch_size=32, augment=False, num_workers=4
# )
# train_loader_aug1, val_loader1, test_loader1 = prepare_dataloaders(
#     "data", batch_size=32, augment=True, num_workers=4
# )

# NUM_CLASSES_1 = len(train_loader1.dataset.dataset.classes)


path = 'data3/mendeley'
train_loader3, _, _ = prepare_dataloaders(
    path, batch_size=32, augment=False, num_workers=4
)
train_loader_aug3, val_loader3, test_loader3 = prepare_dataloaders(
    path, batch_size=32, augment=True, num_workers=4
)

NUM_CLASSES_3 = len(train_loader3.dataset.dataset.classes)
NUM_CLASSES_3

11

In [3]:
train_loader = train_loader3
train_loader_aug, val_loader, test_loader = train_loader_aug3, val_loader3, test_loader3
NUM_CLASSES = NUM_CLASSES_3

In [4]:
import os
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

teacher = train_teacher(
    get_resnet_teacher(num_classes=NUM_CLASSES),
    train_loader,
    val_loader,
    device,
    epochs=20
)

100%|██████████| 20/20 [02:45<00:00,  8.28s/it]


In [5]:
# ----------------------------
# Method 1: Normal KD (unchanged)
# ----------------------------
s1 = normal_kd(
    MultiBranchNet(num_classes=NUM_CLASSES),
    teacher,
    train_loader,
    val_loader,
    device,
    epochs=20,
)

acc1 = evaluate(s1, val_loader, device)
size1 = sum(p.numel() for p in s1.parameters())
results["Normal KD"] = { "acc": acc1, "size":  size1}
print(acc1, size1)

  5%|▌         | 1/20 [00:38<12:02, 38.02s/it]

[Normal KD] Epoch 1: Loss = 7.6782, Acc = 0.4559


 10%|█         | 2/20 [01:14<11:10, 37.27s/it]

[Normal KD] Epoch 2: Loss = 5.6542, Acc = 0.4885


 15%|█▌        | 3/20 [01:51<10:30, 37.09s/it]

[Normal KD] Epoch 3: Loss = 4.6273, Acc = 0.4937


 20%|██        | 4/20 [02:28<09:52, 37.00s/it]

[Normal KD] Epoch 4: Loss = 3.9774, Acc = 0.4685


 25%|██▌       | 5/20 [03:05<09:14, 36.93s/it]

[Normal KD] Epoch 5: Loss = 3.4172, Acc = 0.5019


 30%|███       | 6/20 [03:42<08:37, 36.94s/it]

[Normal KD] Epoch 6: Loss = 2.9021, Acc = 0.5367


 35%|███▌      | 7/20 [04:19<07:59, 36.90s/it]

[Normal KD] Epoch 7: Loss = 2.5429, Acc = 0.5515


 40%|████      | 8/20 [04:55<07:21, 36.82s/it]

[Normal KD] Epoch 8: Loss = 2.2243, Acc = 0.5211


 45%|████▌     | 9/20 [05:32<06:44, 36.78s/it]

[Normal KD] Epoch 9: Loss = 1.9098, Acc = 0.5701


 50%|█████     | 10/20 [06:09<06:08, 36.83s/it]

[Normal KD] Epoch 10: Loss = 1.6661, Acc = 0.5196


 55%|█████▌    | 11/20 [06:46<05:31, 36.85s/it]

[Normal KD] Epoch 11: Loss = 1.4574, Acc = 0.5715


 60%|██████    | 12/20 [07:23<04:54, 36.86s/it]

[Normal KD] Epoch 12: Loss = 1.4038, Acc = 0.5500


 65%|██████▌   | 13/20 [07:59<04:17, 36.84s/it]

[Normal KD] Epoch 13: Loss = 1.3217, Acc = 0.5523


 70%|███████   | 14/20 [08:36<03:40, 36.82s/it]

[Normal KD] Epoch 14: Loss = 1.2617, Acc = 0.5841


 75%|███████▌  | 15/20 [09:13<03:04, 36.85s/it]

[Normal KD] Epoch 15: Loss = 1.0698, Acc = 0.6012


 80%|████████  | 16/20 [09:50<02:27, 36.85s/it]

[Normal KD] Epoch 16: Loss = 1.0619, Acc = 0.5360


 85%|████████▌ | 17/20 [10:27<01:50, 36.84s/it]

[Normal KD] Epoch 17: Loss = 0.9692, Acc = 0.5537


 90%|█████████ | 18/20 [11:04<01:13, 36.86s/it]

[Normal KD] Epoch 18: Loss = 0.9626, Acc = 0.5893


 95%|█████████▌| 19/20 [11:40<00:36, 36.83s/it]

[Normal KD] Epoch 19: Loss = 0.8165, Acc = 0.5819


100%|██████████| 20/20 [12:17<00:00, 36.89s/it]

[Normal KD] Epoch 20: Loss = 0.8015, Acc = 0.5597





0.5841363973313566 1563512


In [6]:
# ----------------------------
# Method 2: Normal KD (Augmentation)
# ----------------------------
s2 = normal_kd(
    MultiBranchNet(num_classes=NUM_CLASSES),
    teacher,
    train_loader_aug,
    val_loader,
    device,
    epochs=20,
)
acc2 = evaluate(s2, val_loader, device)
size2 = sum(p.numel() for p in s2.parameters())
results["Normal KD with AUG"] = { "acc": acc2, "size":  size2}
print(acc2, size2)

  5%|▌         | 1/20 [00:36<11:35, 36.58s/it]

[Normal KD] Epoch 1: Loss = 3.7589, Acc = 0.4388


 10%|█         | 2/20 [01:13<10:59, 36.66s/it]

[Normal KD] Epoch 2: Loss = 3.1101, Acc = 0.5159


 15%|█▌        | 3/20 [01:50<10:25, 36.79s/it]

[Normal KD] Epoch 3: Loss = 2.9139, Acc = 0.5204


 20%|██        | 4/20 [02:27<09:49, 36.83s/it]

[Normal KD] Epoch 4: Loss = 2.6710, Acc = 0.5508


 25%|██▌       | 5/20 [03:04<09:13, 36.90s/it]

[Normal KD] Epoch 5: Loss = 2.5007, Acc = 0.5693


 30%|███       | 6/20 [03:41<08:36, 36.89s/it]

[Normal KD] Epoch 6: Loss = 2.4146, Acc = 0.6123


 35%|███▌      | 7/20 [04:17<07:59, 36.89s/it]

[Normal KD] Epoch 7: Loss = 2.3253, Acc = 0.6316


 40%|████      | 8/20 [04:54<07:23, 36.95s/it]

[Normal KD] Epoch 8: Loss = 2.2271, Acc = 0.6494


 45%|████▌     | 9/20 [05:31<06:46, 36.95s/it]

[Normal KD] Epoch 9: Loss = 2.1623, Acc = 0.6494


 50%|█████     | 10/20 [06:08<06:09, 36.92s/it]

[Normal KD] Epoch 10: Loss = 2.0475, Acc = 0.6694


 55%|█████▌    | 11/20 [06:45<05:32, 36.92s/it]

[Normal KD] Epoch 11: Loss = 1.9812, Acc = 0.5997


 60%|██████    | 12/20 [07:22<04:55, 36.93s/it]

[Normal KD] Epoch 12: Loss = 1.9718, Acc = 0.6590


 65%|██████▌   | 13/20 [07:59<04:18, 36.91s/it]

[Normal KD] Epoch 13: Loss = 1.9182, Acc = 0.6331


 70%|███████   | 14/20 [08:36<03:41, 36.93s/it]

[Normal KD] Epoch 14: Loss = 1.8454, Acc = 0.6864


 75%|███████▌  | 15/20 [09:13<03:04, 36.95s/it]

[Normal KD] Epoch 15: Loss = 1.7744, Acc = 0.6990


 80%|████████  | 16/20 [09:50<02:27, 36.94s/it]

[Normal KD] Epoch 16: Loss = 1.7575, Acc = 0.6857


 85%|████████▌ | 17/20 [10:27<01:50, 36.85s/it]

[Normal KD] Epoch 17: Loss = 1.7121, Acc = 0.6783


 90%|█████████ | 18/20 [11:03<01:13, 36.86s/it]

[Normal KD] Epoch 18: Loss = 1.7276, Acc = 0.7405


 95%|█████████▌| 19/20 [11:40<00:36, 36.80s/it]

[Normal KD] Epoch 19: Loss = 1.6571, Acc = 0.6909


100%|██████████| 20/20 [12:17<00:00, 36.88s/it]

[Normal KD] Epoch 20: Loss = 1.6326, Acc = 0.7035





0.6997776130467013 1563512


In [7]:
# ----------------------------
# Method 3: Multi-Teacher KD (MTSD)
# ----------------------------
s3 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD",
)


acc3 = evaluate(s3, val_loader, device)
size3 = sum(p.numel() for p in s3.parameters())
results["Multi-Teacher SD"] = { "acc": acc3, "size":  size3}
print(acc3, size3)

  5%|▌         | 1/20 [00:45<14:17, 45.13s/it]

[MTSD] Epoch 1: Loss = 8.7630, Acc = 0.4077


 10%|█         | 2/20 [01:30<13:30, 45.01s/it]

[MTSD] Epoch 2: Loss = 7.2048, Acc = 0.4329


 15%|█▌        | 3/20 [02:15<12:46, 45.12s/it]

[MTSD] Epoch 3: Loss = 6.4222, Acc = 0.4522


 20%|██        | 4/20 [03:00<12:01, 45.07s/it]

[MTSD] Epoch 4: Loss = 5.9646, Acc = 0.3647


 25%|██▌       | 5/20 [03:45<11:16, 45.11s/it]

[MTSD] Epoch 5: Loss = 5.6228, Acc = 0.4589


 30%|███       | 6/20 [04:30<10:31, 45.08s/it]

[MTSD] Epoch 6: Loss = 5.2439, Acc = 0.4292


 35%|███▌      | 7/20 [05:15<09:46, 45.11s/it]

[MTSD] Epoch 7: Loss = 4.9864, Acc = 0.4181


 40%|████      | 8/20 [06:00<09:01, 45.12s/it]

[MTSD] Epoch 8: Loss = 4.6881, Acc = 0.4855


 45%|████▌     | 9/20 [06:45<08:15, 45.08s/it]

[MTSD] Epoch 9: Loss = 4.4641, Acc = 0.4485


 50%|█████     | 10/20 [07:30<07:31, 45.11s/it]

[MTSD] Epoch 10: Loss = 4.1600, Acc = 0.4255


 55%|█████▌    | 11/20 [08:16<06:46, 45.14s/it]

[MTSD] Epoch 11: Loss = 3.6033, Acc = 0.4811


 60%|██████    | 12/20 [09:01<06:01, 45.14s/it]

[MTSD] Epoch 12: Loss = 3.4813, Acc = 0.4692


 65%|██████▌   | 13/20 [09:46<05:16, 45.15s/it]

[MTSD] Epoch 13: Loss = 3.4431, Acc = 0.4796


 70%|███████   | 14/20 [10:31<04:30, 45.13s/it]

[MTSD] Epoch 14: Loss = 3.3914, Acc = 0.5033


 75%|███████▌  | 15/20 [11:16<03:45, 45.14s/it]

[MTSD] Epoch 15: Loss = 3.3763, Acc = 0.4893


 80%|████████  | 16/20 [12:01<03:00, 45.11s/it]

[MTSD] Epoch 16: Loss = 3.3443, Acc = 0.4944


 85%|████████▌ | 17/20 [12:46<02:15, 45.14s/it]

[MTSD] Epoch 17: Loss = 3.2809, Acc = 0.4944


 90%|█████████ | 18/20 [13:32<01:30, 45.12s/it]

[MTSD] Epoch 18: Loss = 3.2785, Acc = 0.4752


 95%|█████████▌| 19/20 [14:17<00:45, 45.09s/it]

[MTSD] Epoch 19: Loss = 3.2328, Acc = 0.4766


100%|██████████| 20/20 [15:02<00:00, 45.11s/it]

[MTSD] Epoch 20: Loss = 3.2539, Acc = 0.5026





0.5040770941438102 1563512


In [8]:
# ----------------------------
# Method 4: MTSD + Augmentation
# ----------------------------
s4 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader_aug,
    val_loader,
    device,
    qat=False,
    num_epochs=20,
    method_name="MTSD-Aug",
)
acc4 = evaluate(s4, val_loader, device)
size4 = sum(p.numel() for p in s4.parameters())
results["MTSD + Aug"] =  { "acc": acc4, "size":  size4}
print(acc4, size4)

  5%|▌         | 1/20 [00:45<14:19, 45.24s/it]

[MTSD] Epoch 1: Loss = 9.5831, Acc = 0.4285


 10%|█         | 2/20 [01:30<13:32, 45.16s/it]

[MTSD] Epoch 2: Loss = 8.3997, Acc = 0.5293


 15%|█▌        | 3/20 [02:15<12:46, 45.11s/it]

[MTSD] Epoch 3: Loss = 7.7424, Acc = 0.5567


 20%|██        | 4/20 [03:00<12:02, 45.14s/it]

[MTSD] Epoch 4: Loss = 7.2866, Acc = 0.5686


 25%|██▌       | 5/20 [03:45<11:17, 45.19s/it]

[MTSD] Epoch 5: Loss = 6.8582, Acc = 0.6042


 30%|███       | 6/20 [04:30<10:31, 45.14s/it]

[MTSD] Epoch 6: Loss = 6.6933, Acc = 0.6019


 35%|███▌      | 7/20 [05:16<09:47, 45.16s/it]

[MTSD] Epoch 7: Loss = 6.3243, Acc = 0.6664


 40%|████      | 8/20 [06:01<09:02, 45.17s/it]

[MTSD] Epoch 8: Loss = 6.0629, Acc = 0.6635


 45%|████▌     | 9/20 [06:46<08:17, 45.19s/it]

[MTSD] Epoch 9: Loss = 5.8250, Acc = 0.6442


 50%|█████     | 10/20 [07:31<07:31, 45.12s/it]

[MTSD] Epoch 10: Loss = 5.7909, Acc = 0.6657


 55%|█████▌    | 11/20 [08:16<06:45, 45.09s/it]

[MTSD] Epoch 11: Loss = 5.1479, Acc = 0.7591


 60%|██████    | 12/20 [09:01<06:01, 45.13s/it]

[MTSD] Epoch 12: Loss = 5.0687, Acc = 0.7620


 65%|██████▌   | 13/20 [09:46<05:15, 45.14s/it]

[MTSD] Epoch 13: Loss = 4.9428, Acc = 0.7776


 70%|███████   | 14/20 [10:31<04:30, 45.10s/it]

[MTSD] Epoch 14: Loss = 4.9422, Acc = 0.7761


 75%|███████▌  | 15/20 [11:16<03:45, 45.10s/it]

[MTSD] Epoch 15: Loss = 4.8551, Acc = 0.7917


 80%|████████  | 16/20 [12:02<03:00, 45.14s/it]

[MTSD] Epoch 16: Loss = 4.8277, Acc = 0.7895


 85%|████████▌ | 17/20 [12:47<02:15, 45.16s/it]

[MTSD] Epoch 17: Loss = 4.8373, Acc = 0.8058


 90%|█████████ | 18/20 [13:32<01:30, 45.16s/it]

[MTSD] Epoch 18: Loss = 4.7038, Acc = 0.8095


 95%|█████████▌| 19/20 [14:17<00:45, 45.08s/it]

[MTSD] Epoch 19: Loss = 4.6614, Acc = 0.7947


100%|██████████| 20/20 [15:02<00:00, 45.13s/it]

[MTSD] Epoch 20: Loss = 4.6564, Acc = 0.7939





0.7976278724981468 1563512


In [9]:
# ----------------------------
# Method 5: MTSD + Aug + QAT
# ----------------------------
s5 = train(
    MultiBranchNet(num_classes=NUM_CLASSES),
    train_loader_aug,
    val_loader,
    device,
    qat=True,
    num_epochs=20,
    method_name="MTSD-Aug-QAT",
)

# acc5 = evaluate(s5, val_loader, device)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.prepare_qat(model, inplace=True)


[QAT] Model prepared for Quantization-Aware Training


  5%|▌         | 1/20 [00:50<16:01, 50.63s/it]

[MTSD-QAT] Epoch 1: Loss = 10.0534, Acc = 0.3781


 10%|█         | 2/20 [01:41<15:10, 50.56s/it]

[MTSD-QAT] Epoch 2: Loss = 8.6400, Acc = 0.4611


 15%|█▌        | 3/20 [02:31<14:18, 50.51s/it]

[MTSD-QAT] Epoch 3: Loss = 7.8725, Acc = 0.5419


 20%|██        | 4/20 [03:21<13:27, 50.46s/it]

[MTSD-QAT] Epoch 4: Loss = 7.4888, Acc = 0.5738


 25%|██▌       | 5/20 [04:12<12:35, 50.39s/it]

[MTSD-QAT] Epoch 5: Loss = 7.1770, Acc = 0.5271


 30%|███       | 6/20 [05:02<11:45, 50.38s/it]

[MTSD-QAT] Epoch 6: Loss = 6.8508, Acc = 0.5463


 35%|███▌      | 7/20 [05:53<10:55, 50.45s/it]

[MTSD-QAT] Epoch 7: Loss = 6.5743, Acc = 0.6227


 40%|████      | 8/20 [06:43<10:05, 50.42s/it]

[MTSD-QAT] Epoch 8: Loss = 6.4162, Acc = 0.6130


 45%|████▌     | 9/20 [07:34<09:15, 50.46s/it]

[MTSD-QAT] Epoch 9: Loss = 6.2423, Acc = 0.6360


 50%|█████     | 10/20 [08:24<08:24, 50.49s/it]

[MTSD-QAT] Epoch 10: Loss = 6.0232, Acc = 0.6442


 55%|█████▌    | 11/20 [09:15<07:34, 50.52s/it]

[MTSD-QAT] Epoch 11: Loss = 5.5077, Acc = 0.6976


 60%|██████    | 12/20 [10:05<06:43, 50.48s/it]

[MTSD-QAT] Epoch 12: Loss = 5.3178, Acc = 0.7213


 65%|██████▌   | 13/20 [10:56<05:53, 50.45s/it]

[MTSD-QAT] Epoch 13: Loss = 5.3194, Acc = 0.7265


 70%|███████   | 14/20 [11:46<05:02, 50.48s/it]

[MTSD-QAT] Epoch 14: Loss = 5.2497, Acc = 0.7331


 75%|███████▌  | 15/20 [12:37<04:12, 50.50s/it]

[MTSD-QAT] Epoch 15: Loss = 5.2552, Acc = 0.7361


 80%|████████  | 16/20 [13:27<03:21, 50.43s/it]

[MTSD-QAT] Epoch 16: Loss = 5.2186, Acc = 0.7272


 85%|████████▌ | 17/20 [14:17<02:31, 50.44s/it]

[MTSD-QAT] Epoch 17: Loss = 5.1382, Acc = 0.7368


 90%|█████████ | 18/20 [15:08<01:40, 50.48s/it]

[MTSD-QAT] Epoch 18: Loss = 5.0948, Acc = 0.7391


 95%|█████████▌| 19/20 [15:58<00:50, 50.47s/it]

[MTSD-QAT] Epoch 19: Loss = 5.1151, Acc = 0.7554


100%|██████████| 20/20 [16:49<00:00, 50.47s/it]

[MTSD-QAT] Epoch 20: Loss = 5.0721, Acc = 0.7554
[QAT] Model converted to quantized version



For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.convert(model, inplace=True)


In [10]:
size5 = sum(p.numel() for p in s5.parameters())
results["MTSD-Aug-QAT"] = { "acc": 0.7727, "size":  size5}
print(size5)

1920


In [11]:
results

{'Normal KD': {'acc': 0.5841363973313566, 'size': 1563512},
 'Normal KD with AUG': {'acc': 0.6997776130467013, 'size': 1563512},
 'Multi-Teacher SD': {'acc': 0.5040770941438102, 'size': 1563512},
 'MTSD + Aug': {'acc': 0.7976278724981468, 'size': 1563512},
 'MTSD-Aug-QAT': {'acc': 0.7727, 'size': 1920}}