In [None]:
import importlib
import os
import sys

import torch
# Device selection: set PREFER_CUDA = False to force the CPU even when a GPU is present.
PREFER_CUDA = True
use_cuda = PREFER_CUDA and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Use cuda." if use_cuda else "Use cpu.")

# Parameters of the pre-trained target model (torchvision's resnet18, 11 classes).
TARGET_MODEL_DICT_PATH = "./data/teacher_resnet18_from_scratch.bin"
# Save path for the student network's parameters.
STUDENT_MODEL_DICT_SAVED_PATH = "./data/student_net.bin"
# The image data is shared with hw3.
TRAIN_DIR = "../hw3/data/training"
VALIDATION_DIR = "../hw3/data/validation"
# Best-so-far checkpoints written by train_model.
TARGET_MODEL_SAVE_PATH = "./data/target_model.bin"
SMART_MODEL_SAVE_PATH = "./data/smart_model.bin"

In [None]:
# The dataset is the same as hw3's, so reuse hw3's helper modules
# (e.g. image_set, model_manager) instead of duplicating them here.
HW3_DIR = '../hw3'
# Guard the append so re-running this cell does not keep growing sys.path.
if HW3_DIR not in sys.path:
    sys.path.append(HW3_DIR)
import model_manager
import image_set
# Reload so edits to the hw3 modules take effect without restarting the kernel.
importlib.reload(image_set)
importlib.reload(model_manager)

def calc_right_percent(model, dir, image_size=(224, 224)):
    """Return the fraction of images under `dir` that `model` classifies correctly.

    Ground-truth labels come from an ``image_set.LearningSet`` built over `dir`.
    Predictions come from ``model_manager.predict`` over the same directory.
    The two sequences are compared position by position. This assumes both
    helpers enumerate the directory's images in the same order (TODO confirm
    in hw3's image_set / model_manager).

    Args:
        model: the network to evaluate.
        dir: image directory to score (e.g. TRAIN_DIR or VALIDATION_DIR).
        image_size: size argument passed to both LearningSet and predict;
            defaults to (224, 224).

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if `dir` yields no labels, or if the number of predictions
            differs from the number of labels.
    """
    # Build a dataset over `dir` only to obtain its ground-truth labels.
    labelled_set = image_set.LearningSet(dir, image_size)
    labels = labelled_set.GetLabels()
    num = len(labels)
    if num == 0:
        raise ValueError("no labelled images found in %r" % (dir,))

    # Predict every image in the directory.
    y_pred = model_manager.predict(model, device, dir, image_size)
    # A length mismatch means predictions and labels cannot be aligned, so fail loudly
    # rather than comparing a prefix or hitting an IndexError.
    if len(y_pred) != num:
        raise ValueError(
            "got %d predictions for %d labels in %r" % (len(y_pred), num, dir))

    # Count the positions where the prediction equals the label.
    right_count = sum(1 for pred, label in zip(y_pred, labels) if pred == label)
    return right_count / num


def train_model(model, iters, savepath, opt=0, lr=0.001, epochs=5, nbatch=128,
                weight_decay=0.001):
    """Train `model` on TRAIN_DIR for `iters` rounds, checkpointing the best one.

    Each round runs ``model_manager.train_model`` for `epochs` epochs. It then
    scores the model on VALIDATION_DIR and, for reference, on TRAIN_DIR.
    Whenever validation accuracy beats the best seen so far, the model's
    state_dict is written to `savepath` with ``torch.save``. The baseline for
    that comparison is the accuracy of the model as passed in.

    Args:
        model: network to train.
        iters: number of training rounds.
        savepath: file the best-so-far state_dict is saved to.
        opt: optimizer selector forwarded to ``model_manager.train_model``.
            Its meaning is defined by hw3's model_manager; presumably different
            values pick different optimizers (confirm there).
        lr, epochs, nbatch, weight_decay: forwarded to
            ``model_manager.train_model`` on every round.

    Returns:
        The model after the final round. This is NOT necessarily the best one;
        reload `savepath` to get the best checkpoint.
    """
    data_train = image_set.LearningSet(TRAIN_DIR, (224, 224))
    best_accuracy = calc_right_percent(model, VALIDATION_DIR)
    for i in range(iters):
        print("[iters %d/%d]:" % (i + 1, iters))
        model = model_manager.train_model(
            model,
            data_train,
            device=device,
            lr=lr,
            epochs=epochs,
            nbatch=nbatch,  # tune to fit GPU memory and model size
            weight_decay=weight_decay,
            opt=opt,  # forward the caller's optimizer choice
        )
        # After each round (`epochs` epochs), validate; only improvements are saved.
        print("waiting for validation...")
        accuracy = calc_right_percent(model, VALIDATION_DIR)
        print("train accuracy: %f%%" % (100 * calc_right_percent(model, TRAIN_DIR)))
        print("validation accuracy: %f%%" % (100 * accuracy))
        print("accuracy_pre:%f, accuracy:%f" % (best_accuracy, accuracy))
        # Save the model only when validation accuracy beats the best so far.
        if accuracy > best_accuracy:
            torch.save(model.state_dict(), savepath)
            best_accuracy = accuracy
            print("Got a better model and saved it.")
    return model

In [None]:
# Build torchvision's resnet18 architecture with 11 output classes
# (pretrained=False: no downloaded ImageNet weights).
import torchvision.models as models
target_model = models.resnet18(pretrained=False, num_classes = 11)
# Resume from a previously saved checkpoint if one exists. map_location lets a
# checkpoint written on a GPU machine load when only the CPU is available.
if os.path.exists(TARGET_MODEL_SAVE_PATH):
    target_model.load_state_dict(torch.load(TARGET_MODEL_SAVE_PATH, map_location=device))
    print("target_model has been loaded from file.")


In [None]:
# Train the target model.
# The provided pre-trained model is faulty, so a fresh one is trained here.
TRAIN_TARGET_MODEL = False  # leave False once a good checkpoint has been trained
if TRAIN_TARGET_MODEL:
    print("waiting for training target model...")
    target_model = train_model(target_model, 20, TARGET_MODEL_SAVE_PATH)

In [None]:
# Reload the best checkpoint at TARGET_MODEL_SAVE_PATH (train_model writes it
# whenever validation accuracy improves) and report its accuracy.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
target_model.load_state_dict(torch.load(TARGET_MODEL_SAVE_PATH, map_location=device))
print("waiting for validation of target model...")
print("train accuracy: %f%%" % (100 * calc_right_percent(target_model,TRAIN_DIR)))
print("validation accuracy: %f%%" % (100 * calc_right_percent(target_model,VALIDATION_DIR)))

In [None]:
# Train the compressed model directly on the training data.
import model_architecture
importlib.reload(model_architecture)

smart_model = model_architecture.SmartResnet18()
# Resume from a saved checkpoint if there is one; map_location lets a
# GPU-saved checkpoint load on a CPU-only machine.
if os.path.exists(SMART_MODEL_SAVE_PATH):
    smart_model.load_state_dict(torch.load(SMART_MODEL_SAVE_PATH, map_location=device))
    print("smart_model has been loaded from file.")
TRAIN_SMART_MODEL = True  # set False once a good checkpoint has been trained
if TRAIN_SMART_MODEL:
    print("waiting for training smart model...")
    # opt=1 selects a different optimizer setting in hw3's model_manager (confirm there).
    smart_model = train_model(smart_model, 20, SMART_MODEL_SAVE_PATH, opt=1)

In [None]:
# Reload the best checkpoint at SMART_MODEL_SAVE_PATH (train_model writes it
# whenever validation accuracy improves) and report its accuracy.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
smart_model.load_state_dict(torch.load(SMART_MODEL_SAVE_PATH, map_location=device))
print("waiting for validation of smart model...")
print("train accuracy: %f%%" % (100 * calc_right_percent(smart_model,TRAIN_DIR)))
print("validation accuracy: %f%%" % (100 * calc_right_percent(smart_model,VALIDATION_DIR)))