In [1]:
# to load dataset
from datasets import Datasets
from kd_triplet_datasets import KDTripletDataset

# for network and training
from network import Net_teacher, Net_student
from network_fit import NetworkFit

# to calculate the score
import savescore
from score import Score
from score_calc import ScoreCalc

# pytorch
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# numpy and matplotlib
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Experiment configuration: every tunable value used by the cells below.

# --- data loading ---
DATASET = 'CIFAR10'
BATCH_SIZE = 100
NUM_WORKERS = 2

# --- reproducibility and run length ---
SEED = 1
EPOCH = 500

# --- SGD optimizer ---
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 0.007

# --- StepLR schedule: the learning rate is multiplied by SCHEDULER_GAMMA
#     every SCHEDULER_STEPS calls to scheduler.step() ---
SCHEDULER_STEPS = 100
SCHEDULER_GAMMA = 0.1

# --- loss settings ---
# Passed unchanged to NetworkFit.train / NetworkFit.test.
KD_LAMBDA = 2.0
# Margin handed to nn.TripletMarginLoss. The spelling "MARGINE" is kept
# because other cells refer to this exact name.
TRIPLET_MARGINE = 5.0

In [3]:
# Seed every random number generator this notebook relies on (NumPy, the
# PyTorch default generator, and all CUDA devices) with the same SEED so
# that a fresh run starts from the same random state.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [4]:
# Pick the compute device: the first CUDA GPU when one is available,
# otherwise the CPU. The chosen mode is printed for the run log.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print("gpu mode" if use_cuda else "cpu mode")

gpu mode


In [5]:
# Base names (no extension) of every artifact this run writes.
# All of them share the `codename` prefix so one experiment's files sort together.
codename = 'kd_example'

# saved student weights
fnnname = f"{codename}_fnn_model"

# one figure per tracked score
total_loss_name = f"{codename}_total_loss"
soft_loss_name = f"{codename}_soft_loss"
tri_loss_name = f"{codename}_tri_loss"
acc_name = f"{codename}_accuracy"

# raw score dump
result_name = f"{codename}_result"

In [7]:
# Build the DATASET data through the project-local Datasets helper.
instance_datasets = Datasets(DATASET, BATCH_SIZE, NUM_WORKERS, shuffle=False)
data_sets = instance_datasets.create()

# Only entries 2..5 of the returned collection are used here.
# Entries 0 and 1 were previously bound to `trainloader` / `testloader`
# (presumably ready-made loaders -- confirm in datasets.py); this notebook
# builds its own triplet loaders instead, so they are ignored.
classes, based_labels, trainset, testset = (data_sets[i] for i in (2, 3, 4, 5))

Dataset : CIFAR10
Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Wrap the plain train/test sets in KDTripletDataset and expose them
# through DataLoaders. Only the training loader shuffles.
tri_trainset = KDTripletDataset(trainset)
tri_testset = KDTripletDataset(testset)

# options common to both loaders
loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
tri_trainloader = torch.utils.data.DataLoader(tri_trainset, shuffle=True, **loader_options)
tri_testloader = torch.utils.data.DataLoader(tri_testset, shuffle=False, **loader_options)

In [6]:
# Networks, optimizer, LR schedule and loss criteria.

# Both networks live on the selected device.
model_t = Net_teacher().to(device)
model_s = Net_student().to(device)

# Restore the teacher weights from disk.
# map_location=device remaps the stored tensors onto the device selected
# above; without it a checkpoint saved from CUDA tensors cannot be
# deserialized on a CPU-only machine (the "cpu mode" branch of this notebook).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model_t.load_state_dict(torch.load("cnn_alex.pkl", map_location=device))

# Only the student's parameters are handed to the optimizer.
optimizer = optim.SGD(model_s.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Multiply the learning rate by SCHEDULER_GAMMA every SCHEDULER_STEPS scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=SCHEDULER_STEPS, gamma=SCHEDULER_GAMMA)

# Criteria passed to NetworkFit; how they are combined is defined there.
soft_criterion = nn.CrossEntropyLoss()
triplet_loss = nn.TripletMarginLoss(margin=TRIPLET_MARGINE)

In [10]:
# Bundle the teacher, the student, the student's optimizer and the two loss
# criteria into the project-local NetworkFit helper. The training cell drives
# it through fit.train(...) and fit.test(...).
# NOTE(review): the positional order below must match NetworkFit.__init__ in
# network_fit.py -- not visible from this notebook.
fit = NetworkFit(model_t, model_s, optimizer, soft_criterion, triplet_loss)

In [11]:
# Score bookkeeping: three loss trackers and one accuracy tracker, all
# managed through a single ScoreCalc.
# Judging by the plot titles used when the results are saved, the loss
# trackers are read back in the order total / softmax / triplet
# (presumably ScoreCalc preserves list order -- confirm in score_calc.py).
loss, loss_s, loss_t = (Score() for _ in range(3))
correct = Score()

score_loss = [loss, loss_s, loss_t]
score_correct = [correct]
sc = ScoreCalc(score_loss, score_correct, BATCH_SIZE)

In [12]:
# training and test
def _to_device_triplet(inputs, labels):
    """Move one triplet-loader batch to `device` in the layout NetworkFit receives.

    The first element of `inputs` / `labels` is used twice (teacher slot and
    first student slot); the second element fills the remaining student slot.
    Each slot gets its own `.to(device)` call, exactly as the original
    three copies of this code did.

    Returns:
        (images, label): two 3-tuples of tensors on `device`.
    """
    images = (inputs[0].to(device), inputs[0].to(device), inputs[1].to(device))
    label = (labels[0].to(device), labels[0].to(device), labels[1].to(device))
    return images, label


def _accumulate_scores(loader):
    """Run fit.test on every batch of `loader` and add the results to `sc`."""
    # Scoring needs no gradients; no_grad avoids building autograd graphs.
    # It is a no-op if fit.test already disables gradients itself.
    with torch.no_grad():
        for inputs, labels in loader:
            images, label = _to_device_triplet(inputs, labels)
            losses, corrects = fit.test(images, label, KD_LAMBDA)
            sc.calc_sum(losses, corrects)


for epoch in range(EPOCH):
    print('epoch', epoch+1)

    # 1) one optimisation pass over the training triplets
    for inputs, labels in tri_trainloader:
        images, label = _to_device_triplet(inputs, labels)
        fit.train(images, label, KD_LAMBDA)

    # 2) score the updated student on the training set
    _accumulate_scores(tri_trainloader)
    sc.score_print(len(trainset))
    sc.score_append(len(trainset))
    sc.score_del()

    # 3) score it on the test set
    _accumulate_scores(tri_testloader)
    sc.score_print(len(testset), train = False)
    sc.score_append(len(testset), train = False)
    sc.score_del()

    # advance the learning-rate schedule once per epoch
    scheduler.step()

epoch 1
train mean loss=6.441203088760376, accuracy=0.4472
test mean loss=6.384209885597229, accuracy=0.4507
epoch 2
train mean loss=5.3580407338142395, accuracy=0.55368
test mean loss=5.481476616859436, accuracy=0.5407
epoch 3
train mean loss=4.517965421676636, accuracy=0.6212
test mean loss=4.670205645561218, accuracy=0.6043
epoch 4
train mean loss=3.829943600177765, accuracy=0.6769
test mean loss=4.098728365898133, accuracy=0.6516
epoch 5
train mean loss=3.7483384084701536, accuracy=0.68222
test mean loss=4.095595626831055, accuracy=0.6521
epoch 6
train mean loss=3.4453044891357423, accuracy=0.70684
test mean loss=3.805477108955383, accuracy=0.6792
epoch 7
train mean loss=3.609884603500366, accuracy=0.69638
test mean loss=3.993795554637909, accuracy=0.667
epoch 8
train mean loss=3.747229362010956, accuracy=0.69416
test mean loss=4.122958419322967, accuracy=0.6609
epoch 9
train mean loss=3.20278408908844, accuracy=0.73164
test mean loss=3.6333179211616518, accuracy=0.7027
epoch 10
tr

In [13]:
# Read the per-epoch score histories back from the ScoreCalc: one
# (losses, corrects) pair for the training passes and one for the test
# passes (train = False). The cell below indexes losses with 0..2 and
# corrects with 0.
train_losses, train_corrects = sc.get_value()
test_losses, test_corrects = sc.get_value(train = False)

In [14]:
# Persist the trained student weights, then write the score graphs and the
# raw score data.
torch.save(model_s.state_dict(), fnnname + '.pth')

# One loss figure per tracked loss; the position in this tuple is the index
# into train_losses / test_losses.
loss_figures = (
    ('total loss', total_loss_name),
    ('softmax loss', soft_loss_name),
    ('triplet loss', tri_loss_name),
)
for idx, (figure_title, figure_file) in enumerate(loss_figures):
    savescore.plot_score(
        EPOCH, train_losses[idx], test_losses[idx],
        y_lim=5.0, y_label='LOSS',
        legend=['train loss', 'test loss'],
        title=figure_title, filename=figure_file,
    )

# accuracy figure
savescore.plot_score(
    EPOCH, train_corrects[0], test_corrects[0],
    y_lim=1, y_label='ACCURACY',
    legend=['train acc', 'test acc'],
    title='accuracy', filename=acc_name,
)

# raw numbers: total loss and accuracy for train and test
savescore.save_data(train_losses[0], test_losses[0], train_corrects[0], test_corrects[0], result_name)