In [1]:
# 필요한 라이브러리 import
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import dataset
import model
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 각종 path및 하이퍼 파라미터 설정
data_path = 'C:\\Users\\USER\\Desktop\\GSH_CRP\\codes\\rock_sci_paper\\data\\pic128'
save_path = 'C:\\Users\\USER\\Desktop\\GSH_CRP\\codes\\rock_sci_paper\\model_para'
epochs = 100
batch_size = 16
learning_rate = 0.01
seed = 2023

In [3]:
# seed설정
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# gpu를 사용할 수 있으면 gpu를 사용
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# transform 설정
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((128,128)),
    transforms.ColorJitter(brightness=0.3),
    transforms.ColorJitter(contrast=0.3),
    transforms.ColorJitter(saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
)

test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((128,128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
)

In [5]:
# dataset 설정
train_dataset = dataset.RockScissorsPaper(
    transform=train_transform,
    path = data_path,
    mode = 'train'
)
val_dataset = dataset.RockScissorsPaper(
    transform=test_transform,
    path = data_path,
    mode = 'val'
)
test_dataset = dataset.RockScissorsPaper(
    transform=test_transform,
    path = data_path,
    mode = 'test'
)

print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))

240
30
30


In [6]:
# dataloader 설정
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# 모델, 손실함수, 옵티마이저 설정
teacher_model = model.ResNet10(num_classes=3)
teacher_model = teacher_model.to(device)
student_model = model.ResNet10(num_classes=3)
student_model = student_model.to(device)

teacher_model.load_state_dict(torch.load(os.path.join(save_path, f'teacher.pth')))

criterion = nn.CrossEntropyLoss()
dist_criterion = nn.CosineSimilarity(dim=-1)
optimizer = optim.SGD(student_model.parameters(), lr=learning_rate)

In [8]:
def train(epoch):
    print('\nEpoch: %d'%epoch)
    # student model train mode로 전환
    # teacher model은 eval mode
    teacher_model.eval()
    student_model.train()
    running_loss = 0.0
    running_acc = 0.0
    total = 0
    for (inputs, labels) in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        t_outputs, t4, t3, t2, t1 = teacher_model(inputs)
        
        h,w = inputs.shape[-2], inputs.shape[-1]
        lr_inputs = F.interpolate(inputs, (h//4, w//4), mode='bilinear')
        lr_inputs = F.interpolate(lr_inputs, (h,w), mode='bilinear')
        
        outputs, s4, s3, s2, s1 = student_model(lr_inputs)
        _, pred = torch.max(outputs, 1)
        total += outputs.size(0)
        running_acc += (pred == labels).sum().item()
        
        classif_loss = criterion(outputs, labels)
        distil_loss4 = torch.mean(dist_criterion(t4, s4))
        distil_loss3 = torch.mean(dist_criterion(t3, s3))
        distil_loss2 = torch.mean(dist_criterion(t2, s2))
        distil_loss1 = torch.mean(dist_criterion(t1, s1))
        
        distil_loss = distil_loss4 + distil_loss3 + distil_loss2 + distil_loss1
        # distil_loss = distil_loss4
        distil_loss = torch.abs(distil_loss)
        loss = classif_loss + 0.2*distil_loss

        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    total_loss = running_loss / len(trainloader)
    total_acc = 100 * running_acc / total
    print(f'Train epoch : {epoch} cls loss : {classif_loss}  dist loss : {distil_loss} Acc : {total_acc}%')

In [9]:
def test(epoch, loader, mode='val', mode2=False):
    print('\nEpoch: %d'%epoch)
    # model eval mode로 전환
    teacher_model.eval()
    student_model.eval()
    running_loss = 0.0
    running_acc = 0.0
    total = 0
    label_dict = {0:0, 1:0, 2:0}
    correct_dict = {0:0, 1:0, 2:0}
    global BEST_SCORE
    for (inputs, labels) in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # _, t4, t3, t2, t1 = teacher_model(inputs)
        h,w = inputs.shape[-2], inputs.shape[-1]
        lr_inputs = F.interpolate(inputs, (h//4, w//4), mode='bilinear')
        lr_inputs = F.interpolate(lr_inputs, (h,w), mode='bilinear')
        
        outputs, s4, s3, s2, s1 = student_model(lr_inputs)
        _, pred = torch.max(outputs, 1)
        total += outputs.size(0)
        running_acc += (pred == labels).sum().item()
        
        if mode2:
            for i in range(len(labels)):
                label = labels[i]
                label_dict[label.item()] += 1
                if (pred==labels)[i]:
                    correct_dict[label.item()] += 1
        
        classif_loss = criterion(outputs, labels)
        # distil_loss4 = torch.mean(dist_criterion(t4, s4))
        # distil_loss3 = torch.mean(dist_criterion(t3, s3))
        # distil_loss2 = torch.mean(dist_criterion(t2, s2))
        # distil_loss1 = torch.mean(dist_criterion(t1, s1))
        
        # distil_loss = distil_loss4 + distil_loss3 + distil_loss2 + distil_loss1
        loss = classif_loss
        print(pred)

        running_loss += loss.item()
    total_loss = running_loss / len(loader)
    total_acc = 100 * running_acc / total
    if mode2:
            print(label_dict)
            print(correct_dict)
            
    if total_acc >= BEST_SCORE and not mode=='test':
        path = os.path.join(save_path, f'student.pth')
        torch.save(student_model.state_dict(), path)
        BEST_SCORE = total_acc
    print(f'Test epoch : {epoch} loss : {total_loss} Acc : {total_acc}%')

In [10]:
# 모델 학습 및 평가
BEST_SCORE = 0
for epoch in range(epochs):
    train(epoch)
    test(epoch, valloader)
    print(BEST_SCORE)


Epoch: 0


  "See the documentation of nn.Upsample for details.".format(mode)


Train epoch : 0 cls loss : 1.285164475440979  dist loss : 1.271368145942688 Acc : 35.833333333333336%

Epoch: 0
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
Test epoch : 0 loss : 1.0797734260559082 Acc : 43.333333333333336%
43.333333333333336

Epoch: 1
Train epoch : 1 cls loss : 1.1016485691070557  dist loss : 1.2206957340240479 Acc : 38.333333333333336%

Epoch: 1
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
Test epoch : 1 loss : 1.0728958249092102 Acc : 43.333333333333336%
43.333333333333336

Epoch: 2
Train epoch : 2 cls loss : 1.1097688674926758  dist loss : 1.19219970703125 Acc : 45.833333333333336%

Epoch: 2
tensor([0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 2, 0], device='cuda:0')
Test epoch : 2 loss : 0.9989242851734161 Acc

In [11]:
# test에서 평가
student_model.load_state_dict(torch.load(os.path.join(save_path, f'student.pth')))
test(-1, testloader, 'test', True)


Epoch: -1
tensor([2, 2, 1, 0, 1, 1, 1, 2, 0, 2, 2, 2, 0, 0, 2, 2], device='cuda:0')
tensor([0, 2, 0, 1, 1, 0, 2, 2, 0, 1, 0, 0, 1, 0], device='cuda:0')
{0: 11, 1: 8, 2: 11}
{0: 11, 1: 8, 2: 11}
Test epoch : -1 loss : 0.02565804310142994 Acc : 100.0%
