# 加载并初始化

In [1]:
# Load config, build teacher/student/classifier, and restore the distilled student + classifier weights.
import sys, json
with open('config.json') as f:
    args = json.load(f)
sys.path.append('..')  # make the project-local `models` / `utils` packages importable
PATH = '../RKD-SC/results/'

import torch
import torch.nn as nn
from tqdm.auto import tqdm
from models import GenTeacher, GenStudentCNN, GenClassifier
from models.model import AttentionAutoEncoder
from utils import save_model, test_on_val
from utils.dataset import get_datasets
from utils.channels import Channels

device = args["global"]["device"]
teacher_model_name = args["model"]["teacher_model_name"]
student = GenStudentCNN(device)
teacher = GenTeacher(teacher_model_name, device)
classifier = GenClassifier(512, 100, [], device)

# Directory holding the previously distilled checkpoints.
# TODO(review): absolute, machine-specific path — consider deriving it from PATH or
# args["distillation"]["model_saved_path"] once their relationship is confirmed.
KD_CKPT_DIR = '/home/ubuntu/users/dky/RKD-SC/results/distillation/'

# map_location lets a checkpoint saved on one device (e.g. a specific GPU) load on whatever
# device config.json selects, instead of failing when that original device is unavailable.
student.load_state_dict(torch.load(KD_CKPT_DIR + 'source_encoder_kd.pt', map_location=device))
classifier.load_state_dict(torch.load(KD_CKPT_DIR + 'source_decoder_kd.pt', map_location=device))


def _params_in_millions(model):
    """Return the total number of parameters of `model`, in millions."""
    return sum(p.numel() for p in model.parameters()) / 1e6


print(f"The number of parameters of student model is: {_params_in_millions(student):.2f}M")
print(f"The number of parameters of teacher model is: {_params_in_millions(teacher):.2f}M")

The number of parameters of student model is: 2.92M
The number of parameters of teacher model is: 87.85M


In [2]:
# Build the train/test DataLoaders from the "data" section of config.json.
train_loader, test_loader = get_datasets(
    args["data"]["path"],
    args["data"]["batch_size"],
    args["data"]["dataset_name"],
)

# 蒸馏学生模型

In [3]:
# Feature-level distillation: train the student to regress the teacher's outputs with MSE.
teacher.eval()
student.train()

# Lowest mean epoch KD loss seen so far. Starting at +inf guarantees a checkpoint is written
# after the first epoch (a fixed threshold could mean nothing is ever saved).
MIN_LOSS = float('inf')
optimizer_stu = torch.optim.Adam(student.parameters(), lr=args["distillation"]["learning_rate"])
criterion_KD = nn.MSELoss()
# Decay LR by 10x at the halfway point. max(1, ...) keeps step_size >= 1: StepLR takes
# last_epoch % step_size, so step_size=0 (epochs < 2) would raise ZeroDivisionError.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer_stu,
    step_size=max(1, args["distillation"]["epochs"] // 2),
    gamma=0.1,
)

epoch_bar = tqdm(range(args["distillation"]["epochs"]), desc="Distillation")
for epoch in epoch_bar:
    batch_bar = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
    epoch_loss_sum, n_batches = 0.0, 0
    for data, _ in batch_bar:
        data = data.to(device)

        optimizer_stu.zero_grad()
        # The teacher only supplies regression targets and is not in any optimizer; computing it
        # without autograd keeps backward() from walking into (and accumulating .grad on) it.
        with torch.no_grad():
            teacher_output = teacher(data)
        student_output = student(data)

        loss_KD = criterion_KD(student_output, teacher_output)
        loss_KD.backward()
        optimizer_stu.step()

        batch_loss = loss_KD.item()
        epoch_loss_sum += batch_loss
        n_batches += 1
        batch_bar.set_postfix({'KD Loss': batch_loss, 'LR': scheduler.get_last_lr()[0]})

    scheduler.step()

    if n_batches == 0:
        continue  # empty loader: nothing was trained, nothing to checkpoint

    # Select the checkpoint on the mean loss over the epoch rather than the last (noisy,
    # possibly partial) batch.
    epoch_loss = epoch_loss_sum / n_batches
    if epoch_loss < MIN_LOSS:
        MIN_LOSS = epoch_loss
        save_model(student, args["distillation"]["model_saved_path"] + 'source_encoder_kd.pt')

    epoch_bar.set_postfix({'KD Loss': epoch_loss, 'LR': scheduler.get_last_lr()[0]})

Distillation:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 10:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 11:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 12:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 13:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 14:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 15:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 16:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 17:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 18:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 19:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 20:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 21:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 22:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 23:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 24:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 25:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 26:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 27:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 28:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 29:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 30:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 31:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 32:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 33:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 34:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 35:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 36:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 37:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 38:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 39:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 40:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 41:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 42:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 43:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 44:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 45:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 46:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 47:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 48:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 49:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 51:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 52:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 53:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 54:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 55:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 56:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 57:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 58:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 59:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 60:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 61:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 62:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 63:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 64:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 65:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 66:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 67:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 68:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 69:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 70:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 71:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 72:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 73:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 74:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 75:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 76:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 77:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 78:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 79:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 80:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 81:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 82:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 83:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 84:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 85:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 86:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 87:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 88:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 89:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 90:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 91:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 92:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 93:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 94:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 95:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 96:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 97:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 98:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 99:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 100:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 101:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 102:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 103:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 104:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 105:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 106:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 107:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 108:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 109:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 110:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 111:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 112:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 113:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 114:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 115:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 116:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 117:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 118:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 119:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 120:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 121:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 122:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 123:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 124:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 125:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 126:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 127:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 128:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 129:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 130:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 131:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 132:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 133:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 134:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 135:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 136:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 137:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 138:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 139:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 140:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 141:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 142:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 143:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 144:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 145:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 146:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 147:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 148:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 149:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 150:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 151:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 152:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 153:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 154:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 155:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 156:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 157:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 158:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 159:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 160:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 161:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 162:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 163:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 164:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 165:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 166:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 167:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 168:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 169:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 170:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 171:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 172:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 173:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 174:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 175:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 176:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 177:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 178:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 179:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 180:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 181:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 182:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 183:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 184:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 185:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 186:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 187:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 188:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 189:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 190:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 191:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 192:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 193:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 194:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 195:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 196:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 197:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 198:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 199:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 200:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 201:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 202:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 203:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 204:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 205:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 206:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 207:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 208:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 209:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 210:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 211:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 212:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 213:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 214:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 215:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 216:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 217:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 218:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 219:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 220:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 221:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 222:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 223:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 224:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 225:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 226:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 227:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 228:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 229:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 230:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 231:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 232:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 233:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 234:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 235:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 236:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 237:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 238:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 239:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 240:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 241:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 242:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 243:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 244:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 245:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 246:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 247:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 248:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 249:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 250:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 251:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 252:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 253:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 254:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 255:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 256:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 257:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 258:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 259:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 260:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 261:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 262:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 263:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 264:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 265:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 266:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 267:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 268:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 269:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 270:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 271:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 272:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 273:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 274:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 275:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 276:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 277:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 278:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 279:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 280:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 281:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 282:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 283:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 284:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 285:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 286:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 287:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 288:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 289:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 290:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 291:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 292:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 293:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 294:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 295:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 296:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 297:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 298:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 299:   0%|          | 0/500 [00:00<?, ?it/s]

# 训练学生模型对应的分类网络

In [None]:
# Train a fresh classification head on top of the (frozen) distilled student features.
classifier = GenClassifier(512, 100, [], device)
classifier.train()
student.eval()

# Only the classifier's parameters are optimized here; the student stays fixed.
optimizer_cla = torch.optim.Adam(classifier.parameters(), lr=args["classification"]["learning_rate"])
criterion_CE = nn.CrossEntropyLoss()

for epoch in tqdm(range(args["classification"]["epochs"]), desc="Epoch"):
    with tqdm(train_loader,
              desc=f"Epoch {epoch+1} Batch",
              leave=False) as batch_bar:
        for data, target in batch_bar:
            data, target = data.to(device), target.to(device)

            optimizer_cla.zero_grad()
            # The student is not in optimizer_cla, so there is no reason to build its autograd
            # graph; this saves memory and avoids accumulating stale .grad on student params.
            with torch.no_grad():
                student_output = student(data)
            output = classifier(student_output)

            loss_CE = criterion_CE(output, target)
            acc = (output.argmax(dim=1) == target).float().mean()
            batch_bar.set_postfix({'CE Loss': loss_CE.item(), 'Accuracy': acc.item()})

            loss_CE.backward()
            optimizer_cla.step()

    # Overwrites the checkpoint after every epoch (last epoch wins).
    save_model(classifier, args["distillation"]["model_saved_path"] + 'source_decoder_kd.pt')

acc_test = test_on_val(student, classifier, test_loader, device)

Testing on val... 	Accuracy:  0.6836


In [4]:
# Jointly fine-tune student + attention autoencoder + classifier through a simulated AWGN channel,
# one pass over the training set per SNR value in TRAIN_SNR_LIST.
atten_ae = AttentionAutoEncoder(512, 2).to(device)
channel = Channels(device)
optimizer_cla = torch.optim.Adam(classifier.parameters(), lr=1e-4)
optimizer_ae = torch.optim.Adam(atten_ae.parameters(), lr=1e-4)
optimizer_stu = torch.optim.Adam(student.parameters(), lr=1e-4)

criterion_AE = torch.nn.MSELoss()
criterion_CE = torch.nn.CrossEntropyLoss()

student.train()
classifier.train()
atten_ae.train()
# Set explicitly so this cell does not depend on the distillation cell having run first.
teacher.eval()

TRAIN_SNR_LIST = [-5, -2, 0, 5, 10, 15, 20]

for epoch, snr in enumerate(tqdm(TRAIN_SNR_LIST, desc="Epoch")):
    batch_bar = tqdm(train_loader,
                     leave=False,
                     desc=f"Epoch {epoch+1} Training on SNR={snr}")

    for images, labels in batch_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer_cla.zero_grad()
        optimizer_ae.zero_grad()
        optimizer_stu.zero_grad()

        features_stu = student(images)
        # Teacher features are a fixed reconstruction target (the teacher has no optimizer), so
        # skip autograd: otherwise loss_AE would backprop into the teacher and accumulate .grad.
        with torch.no_grad():
            features_tea = teacher(images)

        # student features -> encoder -> noisy channel -> decoder -> classifier
        latent = atten_ae.encoder(features_stu)
        latent = channel.AWGN(latent, snr)
        features_hat = atten_ae.decoder(latent)

        pre = classifier(features_hat)

        loss_AE = criterion_AE(features_hat, features_tea)
        loss_CE = criterion_CE(pre, labels)
        loss = loss_AE + loss_CE

        acc = (pre.argmax(dim=1) == labels).float().mean()

        loss.backward()

        optimizer_cla.step()
        optimizer_ae.step()
        optimizer_stu.step()

        # Report both loss terms; the old "loss" label showed only the AE term.
        batch_bar.set_postfix(ae_loss=loss_AE.item(), ce_loss=loss_CE.item(), acc=acc.item())

Epoch:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 1 Training on SNR=-5:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2 Training on SNR=-2:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3 Training on SNR=0:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4 Training on SNR=5:   0%|          | 0/500 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
    if w.is_alive():
  File "/home/ubuntu/users/dky/anaconda3/envs/KD/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packa

Epoch 5 Training on SNR=10:   0%|          | 0/500 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
    if w.is_alive():
  File "/home/ubuntu/users/dky/anaconda3/envs/KD/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packa

Epoch 6 Training on SNR=15:   0%|          | 0/500 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
    if w.is_alive():
  File "/home/ubuntu/users/dky/anaconda3/envs/KD/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packa

Epoch 7 Training on SNR=20:   0%|          | 0/500 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
    if w.is_alive():
  File "/home/ubuntu/users/dky/anaconda3/envs/KD/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>  

In [5]:
# Evaluate end-to-end accuracy (student -> AE encoder -> AWGN -> AE decoder -> classifier)
# on the test set for every SNR in range(-10, 25, 2); results are collected in ACC_with_SNR.
student.eval()
classifier.eval()
atten_ae.eval()
channels = Channels(device)

ACC_with_SNR = []

outer_bar = tqdm(range(-10, 25, 2), desc='Testing over SNR')
for snr in outer_bar:
    n_correct = 0
    n_seen = 0
    batch_bar = tqdm(test_loader, desc=f"Testing on SNR: {snr}", leave=False)

    for images, labels in batch_bar:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            encoded = atten_ae.encoder(student(images))
            received = atten_ae.decoder(channels.AWGN(encoded, snr))
            predicted = classifier(received).argmax(dim=1)

            n_correct += (predicted == labels).float().sum().item()
            n_seen += labels.size(0)

        batch_bar.set_postfix({'Acc': n_correct / n_seen})

    acc = n_correct / n_seen
    ACC_with_SNR.append(acc)
    outer_bar.set_postfix({'SNR': snr, 'Acc': acc})

print(ACC_with_SNR)

Testing over SNR:   0%|          | 0/18 [00:00<?, ?it/s]

Testing on SNR: -10:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: -8:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: -6:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: -4:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: -2:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 0:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 2:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 4:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 6:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 8:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 10:   0%|          | 0/100 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
    if w.is_alive():
  File "/home/ubuntu/users/dky/anaconda3/envs/KD/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280><function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/data

Testing on SNR: 12:   0%|          | 0/100 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>
Exception ignored in: Exception ignored in: Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280><function _MultiProcessingDataLoaderIter.__del__ at 0x7fe2d423d280>    

self._shutdown_workers()Traceback (most recent call last):

Traceback (most recent call last):
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
  File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
      File "/home/ubuntu/users/dky/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in:     <funct

Testing on SNR: 14:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 16:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 18:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 20:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 22:   0%|          | 0/100 [00:00<?, ?it/s]

Testing on SNR: 24:   0%|          | 0/100 [00:00<?, ?it/s]

[0.2676, 0.3524, 0.4505, 0.5205, 0.5795, 0.6251, 0.6505, 0.6761, 0.6925, 0.6981, 0.7014, 0.7058, 0.71, 0.7106, 0.7098, 0.71, 0.7109, 0.7132]
