# 加载并初始化

In [1]:
import sys, json
# Load the experiment configuration; later cells read all their settings from `args`.
# UTF-8 is requested explicitly so parsing does not depend on the platform's default encoding.
with open('config.json', encoding='utf-8') as f:
    args = json.load(f)
# Put the parent directory on the import path (presumably where the `models` and
# `utils` packages imported below live — confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')
# Relative path to a results directory; not referenced by the code visible in this section.
PATH = '../RKDSC-github/results/'

import torch
import torch.nn as nn
from tqdm.auto import tqdm 
from models import GenTeacher, GenStudentCNN, GenClassifier
from models.model import AttentionAutoEncoder
from utils import save_model, test_on_val
from utils.dataset import get_datasets
from utils.channels import Channels

# Pull the runtime settings out of the parsed configuration.
device = args["global"]["device"]
teacher_model_name = args["model"]["teacher_model_name"]

# Build the student first, then the teacher selected by name in the config.
student = GenStudentCNN(device)
teacher = GenTeacher(teacher_model_name, device)

# Report each model's size in millions of parameters.
student_param_count = sum(p.numel() for p in student.parameters())
print(f"The number of parameters of student model is: {student_param_count / 1e6:.2f}M")
teacher_param_count = sum(p.numel() for p in teacher.parameters())
print(f"The number of parameters of teacher model is: {teacher_param_count / 1e6:.2f}M")

The number of parameters of student model is: 2.92M
The number of parameters of teacher model is: 87.85M


In [2]:
# Build the train/test loaders from the "data" section of the configuration.
data_cfg = args["data"]
train_loader, test_loader = get_datasets(
    data_cfg["path"],
    data_cfg["batch_size"],
    data_cfg["dataset_name"],
)

# 蒸馏学生模型

In [None]:
# Knowledge distillation: train the student to reproduce the teacher's outputs on
# the training images by minimising an MSE loss. The dataset labels are not used.
teacher.eval()
student.train()

distill_cfg = args["distillation"]
num_epochs = distill_cfg["epochs"]

# Lowest mean epoch loss seen so far. Starting at +inf guarantees the first epoch
# writes a checkpoint; a finite sentinel would silently skip saving whenever the
# loss never falls below it.
MIN_LOSS = float("inf")
optimizer_stu = torch.optim.Adam(student.parameters(), lr=distill_cfg["learning_rate"])
criterion_KD = nn.MSELoss()
# Cut the learning rate by 10x halfway through training. max(1, ...) keeps
# step_size positive when only a single epoch is configured.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_stu, step_size=max(1, num_epochs // 2), gamma=0.1)

epoch_bar = tqdm(range(num_epochs), desc="Distillation")
for epoch in epoch_bar:
    batch_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}", leave=False)
    epoch_loss_sum = 0.0
    num_batches = 0
    for i, (data, _) in batch_bar:
        data = data.to(device)

        optimizer_stu.zero_grad()
        # The teacher is never updated (no optimizer holds its parameters), so its
        # forward pass does not need an autograd graph.
        with torch.no_grad():
            teacher_output = teacher(data)
        student_output = student(data)

        loss_KD = criterion_KD(student_output, teacher_output)
        batch_loss = loss_KD.item()
        epoch_loss_sum += batch_loss
        num_batches += 1
        batch_bar.set_postfix({'KD Loss': batch_loss, 'LR': scheduler.get_last_lr()[0]})

        loss_KD.backward()
        optimizer_stu.step()

    if num_batches == 0:
        raise RuntimeError("train_loader produced no batches; nothing to distil")

    scheduler.step()

    # Select checkpoints on the mean of the batch losses over the whole epoch
    # rather than on whichever batch happened to come last.
    epoch_loss = epoch_loss_sum / num_batches
    if epoch_loss < MIN_LOSS:
        MIN_LOSS = epoch_loss
        save_model(student, args["distillation"]["model_saved_path"] + 'source_encoder_kd.pt')

    epoch_bar.set_postfix({'KD Loss': epoch_loss, 'LR': scheduler.get_last_lr()[0]})

Distillation:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/500 [00:00<?, ?it/s]

# 训练学生模型对应的分类网络

In [None]:
from tqdm.notebook import tqdm  # 或者 from tqdm.auto import tqdm

# Train a classification head on top of the student's outputs. Only the
# classifier is optimised in this cell; the student is kept in eval mode.
# NOTE(review): the arguments presumably mean 512-dim input features, 100 classes
# and no hidden layers — confirm against GenClassifier's signature.
classifier = GenClassifier(512, 100, [], device)
classifier.train()
student.eval()

optimizer_cla = torch.optim.Adam(classifier.parameters(), lr=args["classification"]["learning_rate"])
criterion_CE = nn.CrossEntropyLoss()

# Outer progress bar: one tick per epoch.
for epoch in tqdm(range(args["classification"]["epochs"]), desc="Epoch", ncols=150):
    # Inner progress bar: batches of the current epoch; leave=False clears it
    # once the epoch has finished.
    with tqdm(enumerate(train_loader),
              total=len(train_loader),
              desc=f"Epoch {epoch+1} Batch",
              ncols=150,
              leave=False) as batch_bar:
        for i, (data, target) in batch_bar:
            data, target = data.to(device), target.to(device)

            optimizer_cla.zero_grad()
            # No optimizer in this cell updates the student, so run it without
            # autograd: this avoids building a graph through it and stops
            # backward() from accumulating unused gradients on its parameters.
            with torch.no_grad():
                student_output = student(data)
            output = classifier(student_output)

            loss_CE = criterion_CE(output, target)
            acc = (output.argmax(dim=1) == target).float().mean()
            # Show the current batch's loss and accuracy on the inner bar.
            batch_bar.set_postfix({'CE Loss': loss_CE.item(), 'Accuracy': acc.item()})

            loss_CE.backward()
            optimizer_cla.step()

    # Save the classifier after every epoch. Note that the target directory is
    # read from the "distillation" section of the config, not "classification".
    save_model(classifier, args["distillation"]["model_saved_path"] + 'source_decoder_kd.pt')

# Evaluate the student + classifier pair on test_loader. The classifier is put
# into eval mode first so the result does not depend on whether test_on_val
# switches modes itself.
classifier.eval()
acc_test = test_on_val(student, classifier, test_loader, device)

In [None]:
from tqdm.notebook import tqdm

# Joint fine-tuning of student, auto-encoder and classifier. The student's
# features pass through the auto-encoder with channel.AWGN applied to the latent
# code (presumably additive white Gaussian noise at the given SNR — confirm in
# utils.channels). The decoded features are regressed onto the teacher's outputs
# and also fed to the classifier.
atten_ae = AttentionAutoEncoder(512, 2).to(device)
channel = Channels(device)
# Learning rates are fixed here rather than read from `args`.
optimizer_cla = torch.optim.Adam(classifier.parameters(), lr=5e-5)
optimizer_ae = torch.optim.Adam(atten_ae.parameters(), lr=5e-5)
optimizer_stu = torch.optim.Adam(student.parameters(), lr=5e-4)

criterion_AE = torch.nn.MSELoss()
criterion_CE = torch.nn.CrossEntropyLoss()

student.train()
classifier.train()
atten_ae.train()
# The teacher only supplies regression targets and is not optimised here; set
# eval mode explicitly so this cell does not rely on state left by earlier cells.
teacher.eval()

TRAIN_SNR_LIST = [-5, -2, 0, 5, 10, 15, 20]

# One full pass over train_loader per SNR value, in the listed order.
for epoch, snr in enumerate(tqdm(TRAIN_SNR_LIST, desc="Epoch", ncols=120)):
    batch_bar = tqdm(train_loader,
                     leave=False,
                     ncols=120,
                     desc=f"Epoch {epoch+1} Training on SNR={snr}")

    for i, (images, labels) in enumerate(batch_bar):
        images, labels = images.to(device), labels.to(device)

        optimizer_cla.zero_grad()
        optimizer_ae.zero_grad()
        optimizer_stu.zero_grad()

        features_stu = student(images)
        # Teacher outputs are targets only: no autograd graph is needed for them.
        with torch.no_grad():
            features_tea = teacher(images)

        # Encode, pass the latent code through the noisy channel, then decode.
        latent = atten_ae.encoder(features_stu)
        latent_noisy = channel.AWGN(latent, snr)
        features_hat = atten_ae.decoder(latent_noisy)

        pre = classifier(features_hat)

        loss_AE = criterion_AE(features_hat, features_tea)
        loss_CE = criterion_CE(pre, labels)
        # Optimised objective: feature reconstruction plus classification.
        loss = loss_AE + loss_CE

        acc = (pre.argmax(dim=1) == labels).float().mean()

        loss.backward()

        optimizer_cla.step()
        optimizer_ae.step()
        optimizer_stu.step()

        # Report both loss terms under explicit names so the bar is not mistaken
        # for the total objective.
        batch_bar.set_postfix(loss_AE=loss_AE.item(), loss_CE=loss_CE.item(), acc=acc.item())

In [None]:
from tqdm.notebook import tqdm  # 或者使用 from tqdm.auto import tqdm

# Measure test accuracy of the pipeline
#   student -> atten_ae.encoder -> channels.AWGN -> atten_ae.decoder -> classifier
# for every SNR value in range(-10, 25, 2).
for net in (student, classifier, atten_ae):
    net.eval()
channels = Channels(device)

# One accuracy value per SNR, in sweep order.
ACC_with_SNR = []

# Outer progress bar: iterates over the SNR values.
outer_bar = tqdm(range(-10, 25, 2), desc='Testing over SNR', ncols=120)
for snr in outer_bar:
    # Inner progress bar: iterates over every batch of test_loader.
    batch_bar = tqdm(test_loader, desc=f"Testing on SNR: {snr}", leave=False, ncols=120)
    total_correct, total_samples = 0, 0

    for images, labels in batch_bar:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            received = channels.AWGN(atten_ae.encoder(student(images)), snr)
            logits = classifier(atten_ae.decoder(received))
            hits = logits.argmax(dim=1) == labels

            total_correct += hits.float().sum().item()
            total_samples += labels.size(0)

        # Show the running accuracy for this SNR on the inner bar.
        batch_bar.set_postfix({'Acc': total_correct / total_samples})

    # Final accuracy at this SNR; record it and show it on the outer bar.
    acc = total_correct / total_samples
    ACC_with_SNR.append(acc)
    outer_bar.set_postfix({'SNR': snr, 'Acc': acc})

print(ACC_with_SNR)