In [14]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import Models
from Models.embtrans_cifar import EmbTrans
from Dataset import CIFAR
from utils import colorstr, Save_Checkpoint, AverageMeter, DirectNormLoss, KDLoss

import numpy as np
from pathlib import Path
import time
import json
import random
import logging
import argparse
import warnings
from torch.utils.tensorboard import SummaryWriter
import pdb
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, precision_recall_fscore_support


def compare_model_size(teacher, student):
    """Return the total parameter counts of *teacher* and *student*.

    Counts every element of every tensor yielded by each module's
    ``parameters()`` (trainable or frozen alike).

    Returns:
        ``(teacher_params, student_params)`` as ints.
    """
    def _count(model):
        total = 0
        for tensor in model.parameters():
            total += tensor.numel()
        return total

    return _count(teacher), _count(student)

def compare_inference_time(teacher, student, dataloader, *, warmup_iters=1):
    """Measure the wall-clock time of one forward pass of each model on one batch.

    The first batch of *dataloader* and both models are moved to CUDA when it
    is available, otherwise to CPU. Each model is timed in eval mode under
    ``torch.no_grad()``, and its previous train/eval mode is restored
    afterwards. Timing therefore never updates BatchNorm running statistics.

    Args:
        teacher: model to time; moved to the selected device.
        student: model to time; moved to the selected device.
        dataloader: iterable yielding ``(inputs, labels)`` batches; only the
            first batch is used.
        warmup_iters: untimed forward passes run before the timed one, so
            one-off costs (lazy initialisation, allocator growth) are not
            charged to whichever model happens to be timed first.

    Returns:
        ``(teacher_time, student_time)`` in seconds.
    """
    inputs, _ = next(iter(dataloader))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    teacher = teacher.to(device)
    student = student.to(device)
    inputs = inputs.to(device)

    def _sync():
        # CUDA kernels run asynchronously; without a barrier the timer would
        # capture only the kernel-launch overhead, not the computation.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    def _time_forward(model):
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for _ in range(warmup_iters):
                    model(inputs)
                _sync()
                start = time.perf_counter()
                model(inputs)
                _sync()
                return time.perf_counter() - start
        finally:
            model.train(was_training)

    return _time_forward(teacher), _time_forward(student)

def compare_performance_metrics(teacher, student, dataloader):
    """Score *teacher* and *student* on every batch of *dataloader*.

    Both models are switched to eval mode and left there. Each model is called
    with a single input batch and its result is unpacked as ``(_, logits)``;
    predictions are the arg-max over dim 1 of the logits.

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'``,
        each mapping to a ``(teacher_score, student_score)`` tuple.
        Precision/recall/F1 use ``average='weighted'``. ``zero_division=0`` is
        applied to all three consistently, so undefined per-class scores count
        as 0 without emitting warnings.

    Raises:
        ValueError: if *dataloader* yields no batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    teacher.eval()
    student.eval()

    all_labels = []
    all_teacher_preds = []
    all_student_preds = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)  # one transfer, shared by both models
            _, teacher_outputs = teacher(inputs)
            _, student_outputs = student(inputs)
            all_labels.append(labels.cpu().numpy())
            all_teacher_preds.append(torch.argmax(teacher_outputs, dim=1).cpu().numpy())
            all_student_preds.append(torch.argmax(student_outputs, dim=1).cpu().numpy())

    if not all_labels:
        raise ValueError("compare_performance_metrics: dataloader yielded no batches")

    y_true = np.concatenate(all_labels)
    teacher_preds = np.concatenate(all_teacher_preds)
    student_preds = np.concatenate(all_student_preds)

    def _scores(y_pred):
        # Same averaging and zero_division policy for every weighted metric.
        return (
            accuracy_score(y_true, y_pred),
            precision_score(y_true, y_pred, average='weighted', zero_division=0),
            recall_score(y_true, y_pred, average='weighted', zero_division=0),
            f1_score(y_true, y_pred, average='weighted', zero_division=0),
        )

    names = ('accuracy', 'precision', 'recall', 'f1')
    teacher_scores = _scores(teacher_preds)
    student_scores = _scores(student_preds)
    return {name: (t, s) for name, t, s in zip(names, teacher_scores, student_scores)}

def train(student, teacher, T_EMB, train_dataloader, optimizer, criterion, kd_loss, nd_loss, args, epoch):
    """Train *student* for one epoch against a frozen *teacher*.

    Both models are called with ``embed=True`` and unpacked as
    ``(embedding, logits)``; the teacher runs under ``torch.no_grad()``.

    Per step the optimised loss is ``cls_loss + div_loss + norm_dir_loss``:
      * ``cls_loss      = criterion(s_logits, labels) * args.cls_loss_factor``
      * ``div_loss      = kd_loss(s_logits, t_logits) * min(1.0, epoch / args.warm_up)``
        (linear KD warm-up; the weight is exactly 0 when ``epoch == 0``)
      * ``norm_dir_loss = nd_loss(s_emb=..., t_emb=..., T_EMB=T_EMB, labels=labels)``

    A one-line progress bar is rewritten in place with ``\\r`` each step.

    Returns:
        ``(nd_loss_avg, kd_loss_avg, cls_loss_avg, total_loss_avg, top1_error_avg)``.
        These are epoch averages accumulated via ``AverageMeter.update(value, batch_size)``.
    """
    train_loss = AverageMeter()
    train_error = AverageMeter()

    Cls_loss = AverageMeter()
    Div_loss = AverageMeter()
    Norm_Dir_loss = AverageMeter()

    # Model on train mode
    student.train()
    teacher.eval()
    step_per_epoch = len(train_dataloader)
    # Follow the student's parameters instead of hard-coding .cuda(), so the
    # batches land on whatever device epoch_loop placed the model on.
    device = next(student.parameters()).device
    kd_weight = min(1.0, epoch / args.warm_up)  # constant within an epoch

    for step, (images, labels) in enumerate(train_dataloader):
        start = time.time()
        images, labels = images.to(device), labels.to(device)

        # compute output
        s_emb, s_logits = student(images, embed=True)

        with torch.no_grad():
            t_emb, t_logits = teacher(images, embed=True)

        # cls loss
        cls_loss = criterion(s_logits, labels) * args.cls_loss_factor
        # KD loss
        div_loss = kd_loss(s_logits, t_logits) * kd_weight
        # ND loss
        norm_dir_loss = nd_loss(s_emb=s_emb, t_emb=t_emb, T_EMB=T_EMB, labels=labels)

        loss = cls_loss + div_loss + norm_dir_loss
        # measure accuracy and record loss
        batch_size = images.size(0)
        _, pred = s_logits.data.cpu().topk(1, dim=1)
        # squeeze(1): (B, 1) -> (B,), staying 1-D even for a final batch of size 1
        train_error.update(torch.ne(pred.squeeze(1), labels.cpu()).float().sum().item() / batch_size, batch_size)
        train_loss.update(loss.item(), batch_size)

        Cls_loss.update(cls_loss.item(), batch_size)
        Div_loss.update(div_loss.item(), batch_size)
        Norm_Dir_loss.update(norm_dir_loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        t = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        s1 = '\r{} [{}/{}]'.format(t, step+1, step_per_epoch)
        s2 = ' - {:.2f}ms/step - nd_loss: {:.3f} - kd_loss: {:.3f} - cls_loss: {:.3f} - train_loss: {:.3f} - train_acc: {:.3f}'.format(
             1000 * (time.time() - start), norm_dir_loss.item(), div_loss.item(), cls_loss.item(), train_loss.val, 1-train_error.val)

        print(s1+s2, end='', flush=True)

    print()
    return Norm_Dir_loss.avg, Div_loss.avg, Cls_loss.avg, train_loss.avg, train_error.avg


def test(student, test_dataloader, criterion):
    """Evaluate *student* on *test_dataloader* without gradient tracking.

    The student is put in eval mode and called with ``embed=False``. Its
    output is used directly as the logits passed to *criterion*.

    Returns:
        ``(avg_loss, avg_top1_error)``, accumulated via
        ``AverageMeter.update(value, batch_size)``.
    """
    test_loss = AverageMeter()
    test_error = AverageMeter()

    # Model on eval mode
    student.eval()
    # Same device-selection rule as train(): follow the model, not a hard-coded .cuda().
    device = next(student.parameters()).device

    with torch.no_grad():
        for images, labels in test_dataloader:
            images, labels = images.to(device), labels.to(device)

            # compute logits
            logits = student(images, embed=False)

            loss = criterion(logits, labels)

            # measure accuracy and record loss
            batch_size = images.size(0)
            _, pred = logits.data.cpu().topk(1, dim=1)
            test_error.update(torch.ne(pred.squeeze(1), labels.cpu()).float().sum().item() / batch_size, batch_size)
            test_loss.update(loss.item(), batch_size)

    return test_loss.avg, test_error.avg


def _report_final_metrics(student, teacher, test_loader, best_path):
    """Reload the best student weights and log a teacher-vs-student report.

    Loads ``<best_path>/ckpt.pth`` (the file this script expects
    ``Save_Checkpoint`` to have written for the best epoch) into *student*.
    It then logs, through the module-level ``logger``, the following for both
    models:
      * test accuracy and weighted precision/recall/F1, as percentages;
      * single-batch inference time;
      * parameter count.
    """
    best_checkpoint = torch.load(os.path.join(best_path, 'ckpt.pth'))['model_state_dict']
    student.load_state_dict(best_checkpoint)
    metrics = compare_performance_metrics(teacher, student, test_loader)
    teacher_time, student_time = compare_inference_time(teacher, student, test_loader)
    teacher_size, student_size = compare_model_size(teacher, student)

    final_report_banner = '- - - - - METRICS REPORT - - - - -'
    teacher_metrics = "TEACHER: accuracy: {:.3f}, precision: {:.3f}, recall: {:.3f}, f1: {:.3f}, teacher_inf: {:.3f}, teacher_size: {:.3f}".format(
        100*metrics['accuracy'][0], 100*metrics['precision'][0], 100*metrics['recall'][0], 100*metrics['f1'][0],
        teacher_time, teacher_size)
    student_metrics = "STUDENT: accuracy: {:.3f}, precision: {:.3f}, recall: {:.3f}, f1: {:.3f}, student_inf: {:.3f}, student_size: {:.3f}".format(
        100*metrics['accuracy'][1], 100*metrics['precision'][1], 100*metrics['recall'][1], 100*metrics['f1'][1],
        student_time, student_size)
    logger.info(colorstr('green', final_report_banner))
    logger.info(colorstr('green', teacher_metrics))
    logger.info(colorstr('green', student_metrics))


def epoch_loop(student, teacher, train_set, test_set, args):
    """Distil *teacher* into *student* for up to ``args.epochs`` epochs.

    Setup and artefacts:
      * Both models are wrapped in ``nn.DataParallel``.
      * The student is trained with Nesterov SGD; the lr is multiplied by 0.1
        at epochs 150, 180 and 210.
      * The student is evaluated on *test_set* after every epoch.
      * Last/best checkpoints, TensorBoard scalars and ``.npy`` histories are
        written under ``args.save_dir``.
      * Training can resume from ``args.resume``.

    Two criteria are tracked independently:
      * the "best" checkpoint is the epoch with the lowest test error;
      * early stopping fires after ``args.patience`` consecutive epochs
        without a new lowest test loss.
    On early stop, or after the final epoch, the best checkpoint is reloaded
    and a teacher-vs-student metrics report is logged.

    NOTE(review): checkpoint selection and early stopping both use
    *test_set*, so the final report is not measured on a held-out split.

    Relies on the module-level names ``T_EMB`` and ``logger``, which are set
    in this script's ``__main__`` section.

    Returns:
        ``(student, teacher, test_loader)``; both models are the
        DataParallel-wrapped instances.
    """
    # data loaders
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)

    # models
    device = "cuda" if torch.cuda.is_available() else "cpu"
    student = nn.DataParallel(student)
    student.to(device)
    teacher = nn.DataParallel(teacher)
    teacher.to(device)

    # loss
    criterion = nn.CrossEntropyLoss().to(device)
    kd_loss = KDLoss(kl_loss_factor=args.kd_loss_factor, T=args.t).to(device)
    # NOTE(review): num_class is hard-coded for CIFAR-100; --dataset cifar10
    # would need the dataset's real class count here.
    nd_loss = DirectNormLoss(num_class=100, nd_loss_factor=args.nd_loss_factor).to(device)
    # optimizer
    optimizer = torch.optim.SGD(params=student.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

    # weights
    save_dir = Path(args.save_dir)
    weights = save_dir / 'weights'
    weights.mkdir(parents=True, exist_ok=True)
    last = weights / 'last'
    best = weights / 'best'
    last_path = last / 'last_'
    best_path = best / 'best_'

    # acc, loss histories
    acc_loss = save_dir / 'acc_loss'
    acc_loss.mkdir(parents=True, exist_ok=True)

    train_acc_savepath = acc_loss / 'train_acc.npy'
    train_loss_savepath = acc_loss / 'train_loss.npy'
    val_acc_savepath = acc_loss / 'val_acc.npy'
    val_loss_savepath = acc_loss / 'val_loss.npy'

    # tensorboard
    logdir = save_dir / 'logs'
    logdir.mkdir(parents=True, exist_ok=True)
    summary_writer = SummaryWriter(logdir, flush_secs=120)

    # resume
    if args.resume:
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        student.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Keep the resumed best error: resetting it would make the first
        # post-resume epoch overwrite the best checkpoint unconditionally.
        best_error = checkpoint['best_error']
        train_acc = checkpoint['train_acc']
        train_loss = checkpoint['train_loss']
        test_acc = checkpoint['test_acc']
        test_loss = checkpoint['test_loss']
        logger.info(colorstr('green', 'Resuming training from {} epoch'.format(start_epoch)))
    else:
        start_epoch = 0
        best_error = 1  # worst possible top-1 error rate
        train_acc = []
        train_loss = []
        test_acc = []
        test_loss = []

    # Early-stopping state (on test loss, independent of best_error above).
    # The counter is initialised here so a non-improving first epoch (e.g. a
    # NaN loss, which never compares < inf) cannot hit an unbound name.
    patience = args.patience
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(start_epoch, args.epochs):
        if epoch in [150, 180, 210]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        print("Epoch {}/{}".format(epoch + 1, args.epochs))
        norm_dir_loss, div_loss, cls_loss, train_epoch_loss, train_error = train(student=student,
                                                                                 teacher=teacher,
                                                                                 T_EMB=T_EMB,
                                                                                 train_dataloader=train_loader,
                                                                                 optimizer=optimizer,
                                                                                 criterion=criterion,
                                                                                 kd_loss=kd_loss,
                                                                                 nd_loss=nd_loss,
                                                                                 args=args,
                                                                                 epoch=epoch)
        test_epoch_loss, test_error = test(student=student,
                                           test_dataloader=test_loader,
                                           criterion=criterion)

        s = "Train Loss: {:.3f}, Train Acc: {:.3f}, Test Loss: {:.3f}, Test Acc: {:.3f}, lr: {:.5f}".format(
            train_epoch_loss, 1-train_error, test_epoch_loss, 1-test_error, optimizer.param_groups[0]['lr'])
        logger.info(colorstr('green', s))

        # save acc, loss
        train_loss.append(train_epoch_loss)
        train_acc.append(1-train_error)
        test_loss.append(test_epoch_loss)
        test_acc.append(1-test_error)

        # save student
        is_best = test_error < best_error
        best_error = min(best_error, test_error)
        state = {
                'epoch': epoch + 1,
                'model_state_dict': student.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_error': best_error,
                'train_acc': train_acc,
                'train_loss': train_loss,
                'test_acc': test_acc,
                'test_loss': test_loss,
            }
        Save_Checkpoint(state, last, last_path, best, best_path, is_best)

        # tensorboard
        if epoch == 1:
            images, labels = next(iter(train_loader))
            img_grid = torchvision.utils.make_grid(images)
            summary_writer.add_image('Cifar Image', img_grid)
        summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        summary_writer.add_scalar('train_loss', train_epoch_loss, epoch)
        summary_writer.add_scalar('train_error', train_error, epoch)
        summary_writer.add_scalar('val_loss', test_epoch_loss, epoch)
        summary_writer.add_scalar('val_error', test_error, epoch)

        summary_writer.add_scalar('nd_loss', norm_dir_loss, epoch)
        summary_writer.add_scalar('kd_loss', div_loss, epoch)
        summary_writer.add_scalar('cls_loss', cls_loss, epoch)

        # early stopping on test loss
        if test_epoch_loss < best_val_loss:
            best_val_loss = test_epoch_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            _report_final_metrics(student, teacher, test_loader, best_path)
            break

        if epoch == (args.epochs - 1):
            _report_final_metrics(student, teacher, test_loader, best_path)

    summary_writer.close()
    # Always write the histories: they include any resumed epochs, and
    # skipping the write when the files already exist would leave stale
    # arrays from an earlier run in place.
    np.save(train_acc_savepath, train_acc)
    np.save(train_loss_savepath, train_loss)
    np.save(val_acc_savepath, test_acc)
    np.save(val_loss_savepath, test_loss)

    return student, teacher, test_loader

    

if __name__ == "__main__":
    student_names = sorted(name for name in Models.__dict__
                           if name.islower() and not name.startswith("__")
                           and callable(Models.__dict__[name]))

    parser = argparse.ArgumentParser(description='PyTorch Cifar Training')
    parser.add_argument('-f')  # lets parse_args() run inside Jupyter/Colab, which pass an extra -f argument
    parser.add_argument("--student_name", type=str, default="resnet20_cifar", choices=student_names, help="student architecture")
    parser.add_argument("--dataset", type=str, default='cifar100')
    parser.add_argument("--epochs", type=int, default=240)
    parser.add_argument("--batch_size", type=int, default=128, help="total batch size (nn.DataParallel splits it across visible GPUs)")
    parser.add_argument('--workers', default=32, type=int, help='number of data loading workers')
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument("--weight_decay", type=float, default=5e-4)

    parser.add_argument("--teacher", type=str, default="resnet56_cifar", help="teacher architecture")
    parser.add_argument("--teacher_weights", type=str, default="./ckpt/cifar_teachers/resnet56_vanilla/ckpt_epoch_240.pth", help="teacher weights path")
    # Class-mean teacher embeddings; pick the file that matches --teacher:
    #   res56    ./ckpt/teacher/resnet56/center_emb_train.json
    #   res32x4  ./ckpt/teacher/resnet32x4/center_emb_train.json
    #   wrn40_2  ./ckpt/teacher/wrn_40_2/center_emb_train.json
    #   res50    ./ckpt/teacher/resnet50/center_emb_train.json
    parser.add_argument("--teacher_emb", type=str, default="./ckpt/teacher/resnet56/center_emb_train.json", help="teacher class-mean embedding json path")
    parser.add_argument("--cls_loss_factor", type=float, default=1.0, help="cls loss weight factor")
    parser.add_argument("--kd_loss_factor", type=float, default=1.0, help="KD loss weight factor")
    parser.add_argument("--t", type=float, default=4.0, help="temperature")
    parser.add_argument("--nd_loss_factor", type=float, default=1.0, help="ND loss weight factor")
    parser.add_argument("--warm_up", type=float, default=20.0, help='loss weight warm up epochs')
    parser.add_argument("--patience", type=int, default=5, help='early stopping patience: epochs without test-loss improvement')

    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
    parser.add_argument("--resume", type=str, help="best ckpt's path to resume most recent training")
    parser.add_argument("--save_dir", type=str, default="./run/KD++", help="save path, eg, acc_loss, weights, tensorboard, and so on")
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    logging.basicConfig(level=logging.INFO, format='%(asctime)s [line:%(lineno)d] %(message)s',
                        datefmt='%d %b %Y %H:%M:%S')
    # Module-level name: epoch_loop() logs through this logger.
    logger = logging.getLogger(__name__)

    logger.info(colorstr('green', "Distribute train, total batch size:{}, epoch:{}".format(args.batch_size, args.epochs)))

    train_set, test_set, num_class = CIFAR(name=args.dataset)
    student = Models.__dict__[args.student_name](num_class=num_class)

    if args.student_name in ['wrn40_1_cifar', 'mobilenetv2', 'shufflev1_cifar', 'shufflev2_cifar']:
        student = EmbTrans(student=student, student_name=args.student_name)

    teacher = Models.__dict__[args.teacher](num_class=num_class)

    if args.teacher_weights:
        print('Load Teacher Weights')
        # map_location='cpu' lets a GPU-saved checkpoint load on any host;
        # load_state_dict copies the tensors into the model's own parameters.
        teacher_ckpt = torch.load(args.teacher_weights, map_location='cpu')['model']
        teacher.load_state_dict(teacher_ckpt)

        for param in teacher.parameters():
            param.requires_grad = False

    # Module-level name: epoch_loop() passes T_EMB to train() for the ND loss.
    with open(args.teacher_emb, 'r') as f:
        T_EMB = json.load(f)

    logger.info(colorstr('green', 'Use ' + args.teacher + ' Training ' + args.student_name + ' ...'))
    # Train the student
    student, teacher, test_loader = epoch_loop(student=student, teacher=teacher, train_set=train_set, test_set=test_set, args=args)

    

11 Dec 2023 19:43:52 [line:423] [32mDistribute train, total batch size:128, epoch:240[0m


Files already downloaded and verified
Files already downloaded and verified


11 Dec 2023 19:43:53 [line:451] [32mUse resnet56_cifar Training resnet20_cifar ...[0m


Load Teacher Weights
Epoch 1/240
2023-12-11 19:44:18 [391/391] - 39.62ms/step - nd_loss: 0.268 - kd_loss: 0.000 - cls_loss: 3.492 - train_loss: 3.761 - train_acc: 0.162


11 Dec 2023 19:44:19 [line:261] [32mTrain Loss: 4.225, Train Acc: 0.101, Test Loss: 3.837, Test Acc: 0.144, lr: 0.10000[0m


Epoch 2/240
2023-12-11 19:44:44 [391/391] - 40.29ms/step - nd_loss: 0.240 - kd_loss: 0.451 - cls_loss: 2.943 - train_loss: 3.634 - train_acc: 0.250


11 Dec 2023 19:44:45 [line:261] [32mTrain Loss: 3.945, Train Acc: 0.223, Test Loss: 3.295, Test Acc: 0.236, lr: 0.10000[0m


Epoch 3/240
2023-12-11 19:45:11 [391/391] - 39.30ms/step - nd_loss: 0.259 - kd_loss: 1.000 - cls_loss: 2.419 - train_loss: 3.678 - train_acc: 0.350


11 Dec 2023 19:45:12 [line:261] [32mTrain Loss: 3.856, Train Acc: 0.310, Test Loss: 3.461, Test Acc: 0.252, lr: 0.10000[0m


Epoch 4/240
2023-12-11 19:45:37 [391/391] - 39.39ms/step - nd_loss: 0.246 - kd_loss: 1.146 - cls_loss: 2.548 - train_loss: 3.940 - train_acc: 0.350


11 Dec 2023 19:45:38 [line:261] [32mTrain Loss: 3.858, Train Acc: 0.371, Test Loss: 3.498, Test Acc: 0.257, lr: 0.10000[0m


Epoch 5/240
2023-12-11 19:46:03 [391/391] - 39.87ms/step - nd_loss: 0.247 - kd_loss: 1.514 - cls_loss: 2.186 - train_loss: 3.947 - train_acc: 0.350


11 Dec 2023 19:46:04 [line:261] [32mTrain Loss: 3.940, Train Acc: 0.409, Test Loss: 3.645, Test Acc: 0.274, lr: 0.10000[0m


Epoch 6/240
2023-12-11 19:46:29 [391/391] - 39.31ms/step - nd_loss: 0.242 - kd_loss: 1.624 - cls_loss: 1.571 - train_loss: 3.437 - train_acc: 0.575


11 Dec 2023 19:46:31 [line:261] [32mTrain Loss: 4.055, Train Acc: 0.443, Test Loss: 4.195, Test Acc: 0.271, lr: 0.10000[0m


Epoch 7/240
2023-12-11 19:46:55 [391/391] - 41.36ms/step - nd_loss: 0.234 - kd_loss: 1.752 - cls_loss: 2.166 - train_loss: 4.153 - train_acc: 0.450


11 Dec 2023 19:46:57 [line:261] [32mTrain Loss: 4.166, Train Acc: 0.466, Test Loss: 2.663, Test Acc: 0.386, lr: 0.10000[0m


Epoch 8/240
2023-12-11 19:47:21 [391/391] - 39.46ms/step - nd_loss: 0.250 - kd_loss: 2.471 - cls_loss: 2.116 - train_loss: 4.837 - train_acc: 0.450


11 Dec 2023 19:47:22 [line:261] [32mTrain Loss: 4.312, Train Acc: 0.487, Test Loss: 2.504, Test Acc: 0.424, lr: 0.10000[0m


Epoch 9/240
2023-12-11 19:47:47 [391/391] - 39.67ms/step - nd_loss: 0.247 - kd_loss: 2.375 - cls_loss: 2.180 - train_loss: 4.802 - train_acc: 0.463


11 Dec 2023 19:47:48 [line:261] [32mTrain Loss: 4.431, Train Acc: 0.504, Test Loss: 2.847, Test Acc: 0.398, lr: 0.10000[0m


Epoch 10/240
2023-12-11 19:48:13 [391/391] - 38.72ms/step - nd_loss: 0.262 - kd_loss: 2.608 - cls_loss: 2.301 - train_loss: 5.170 - train_acc: 0.475


11 Dec 2023 19:48:14 [line:261] [32mTrain Loss: 4.597, Train Acc: 0.515, Test Loss: 2.575, Test Acc: 0.420, lr: 0.10000[0m


Epoch 11/240
2023-12-11 19:48:38 [391/391] - 39.44ms/step - nd_loss: 0.252 - kd_loss: 2.437 - cls_loss: 1.692 - train_loss: 4.381 - train_acc: 0.512


11 Dec 2023 19:48:39 [line:261] [32mTrain Loss: 4.724, Train Acc: 0.528, Test Loss: 2.659, Test Acc: 0.425, lr: 0.10000[0m


Epoch 12/240
2023-12-11 19:49:03 [391/391] - 38.30ms/step - nd_loss: 0.262 - kd_loss: 2.908 - cls_loss: 1.595 - train_loss: 4.765 - train_acc: 0.575


11 Dec 2023 19:49:05 [line:261] [32mTrain Loss: 4.890, Train Acc: 0.536, Test Loss: 2.344, Test Acc: 0.467, lr: 0.10000[0m


Epoch 13/240
2023-12-11 19:49:29 [391/391] - 37.88ms/step - nd_loss: 0.281 - kd_loss: 3.362 - cls_loss: 1.391 - train_loss: 5.034 - train_acc: 0.550


11 Dec 2023 19:49:30 [line:261] [32mTrain Loss: 5.021, Train Acc: 0.545, Test Loss: 2.230, Test Acc: 0.484, lr: 0.10000[0m


Epoch 14/240
2023-12-11 19:49:54 [391/391] - 38.28ms/step - nd_loss: 0.278 - kd_loss: 3.751 - cls_loss: 1.827 - train_loss: 5.856 - train_acc: 0.600


11 Dec 2023 19:49:55 [line:261] [32mTrain Loss: 5.189, Train Acc: 0.554, Test Loss: 2.567, Test Acc: 0.454, lr: 0.10000[0m


Epoch 15/240
2023-12-11 19:50:19 [391/391] - 37.84ms/step - nd_loss: 0.275 - kd_loss: 3.636 - cls_loss: 1.911 - train_loss: 5.822 - train_acc: 0.5121


11 Dec 2023 19:50:21 [line:261] [32mTrain Loss: 5.341, Train Acc: 0.560, Test Loss: 2.206, Test Acc: 0.481, lr: 0.10000[0m


Epoch 16/240
2023-12-11 19:50:44 [391/391] - 38.06ms/step - nd_loss: 0.283 - kd_loss: 3.640 - cls_loss: 1.521 - train_loss: 5.444 - train_acc: 0.550


11 Dec 2023 19:50:46 [line:261] [32mTrain Loss: 5.501, Train Acc: 0.565, Test Loss: 2.098, Test Acc: 0.506, lr: 0.10000[0m


Epoch 17/240
2023-12-11 19:51:09 [391/391] - 37.66ms/step - nd_loss: 0.291 - kd_loss: 3.906 - cls_loss: 1.706 - train_loss: 5.904 - train_acc: 0.550


11 Dec 2023 19:51:10 [line:261] [32mTrain Loss: 5.657, Train Acc: 0.572, Test Loss: 2.124, Test Acc: 0.504, lr: 0.10000[0m


Epoch 18/240
2023-12-11 19:51:34 [391/391] - 37.90ms/step - nd_loss: 0.314 - kd_loss: 4.441 - cls_loss: 1.572 - train_loss: 6.327 - train_acc: 0.5372


11 Dec 2023 19:51:35 [line:261] [32mTrain Loss: 5.831, Train Acc: 0.577, Test Loss: 2.335, Test Acc: 0.487, lr: 0.10000[0m


Epoch 19/240
2023-12-11 19:51:58 [391/391] - 37.35ms/step - nd_loss: 0.332 - kd_loss: 4.051 - cls_loss: 1.723 - train_loss: 6.106 - train_acc: 0.613


11 Dec 2023 19:52:00 [line:261] [32mTrain Loss: 6.003, Train Acc: 0.582, Test Loss: 2.719, Test Acc: 0.455, lr: 0.10000[0m


Epoch 20/240
2023-12-11 19:52:23 [391/391] - 37.65ms/step - nd_loss: 0.323 - kd_loss: 4.397 - cls_loss: 1.773 - train_loss: 6.492 - train_acc: 0.575


11 Dec 2023 19:52:24 [line:261] [32mTrain Loss: 6.175, Train Acc: 0.581, Test Loss: 2.363, Test Acc: 0.485, lr: 0.10000[0m


Epoch 21/240
2023-12-11 19:52:48 [391/391] - 36.92ms/step - nd_loss: 0.359 - kd_loss: 4.584 - cls_loss: 1.478 - train_loss: 6.421 - train_acc: 0.6627


11 Dec 2023 19:52:49 [line:261] [32mTrain Loss: 6.318, Train Acc: 0.587, Test Loss: 2.090, Test Acc: 0.520, lr: 0.10000[0m


Epoch 22/240
2023-12-11 19:53:12 [391/391] - 37.19ms/step - nd_loss: 0.344 - kd_loss: 4.398 - cls_loss: 1.665 - train_loss: 6.406 - train_acc: 0.475


11 Dec 2023 19:53:14 [line:261] [32mTrain Loss: 6.256, Train Acc: 0.592, Test Loss: 2.515, Test Acc: 0.485, lr: 0.10000[0m


Epoch 23/240
2023-12-11 19:53:37 [391/391] - 37.36ms/step - nd_loss: 0.363 - kd_loss: 4.758 - cls_loss: 1.868 - train_loss: 6.989 - train_acc: 0.588


11 Dec 2023 19:53:38 [line:261] [32mTrain Loss: 6.184, Train Acc: 0.596, Test Loss: 2.418, Test Acc: 0.488, lr: 0.10000[0m


Epoch 24/240
2023-12-11 19:54:01 [391/391] - 37.12ms/step - nd_loss: 0.345 - kd_loss: 4.961 - cls_loss: 1.382 - train_loss: 6.689 - train_acc: 0.6250


11 Dec 2023 19:54:02 [line:261] [32mTrain Loss: 6.131, Train Acc: 0.597, Test Loss: 2.317, Test Acc: 0.490, lr: 0.10000[0m


Epoch 25/240
2023-12-11 19:54:26 [391/391] - 37.23ms/step - nd_loss: 0.333 - kd_loss: 4.267 - cls_loss: 1.218 - train_loss: 5.818 - train_acc: 0.637


11 Dec 2023 19:54:27 [line:261] [32mTrain Loss: 6.103, Train Acc: 0.600, Test Loss: 2.549, Test Acc: 0.475, lr: 0.10000[0m


Epoch 26/240
2023-12-11 19:54:50 [391/391] - 37.29ms/step - nd_loss: 0.327 - kd_loss: 4.255 - cls_loss: 1.721 - train_loss: 6.303 - train_acc: 0.600


11 Dec 2023 19:54:51 [line:261] [32mTrain Loss: 6.079, Train Acc: 0.601, Test Loss: 2.188, Test Acc: 0.506, lr: 0.10000[0m


Early stopping triggered at epoch 26


11 Dec 2023 19:54:54 [line:333] [32m- - - - - METRICS REPORT - - - - -[0m
11 Dec 2023 19:54:54 [line:334] [32mTEACHER: accuracy: 72.410, precision: 72.663, recall: 72.410, f1: 72.424, teacher_inf: 0.007, teacher_size: 861620.000[0m
11 Dec 2023 19:54:54 [line:335] [32mSTUDENT: accuracy: 52.000, precision: 60.450, recall: 52.000, f1: 51.904, student_inf: 0.003, student_size: 278324.000[0m
