In [None]:
# imports
import os
import numpy as np
from statistics import mean

import torch
import torchvision

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from scipy import stats

from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore")

## Classifier model 학습

In [None]:
# Location of the training data -- adjust PROJECT_PATH to the local environment.
# os.path.expanduser('~') replaces os.getenv('HOME'): getenv returns None when
# HOME is not set, and concatenating None with a string raises TypeError.
PROJECT_PATH = os.path.join(os.path.expanduser('~'), 'aiffel', 'dt', 'data')
MODEL_PATH = os.path.join(PROJECT_PATH, 'weights')
DATA_PATH = os.path.join(PROJECT_PATH, 'data')
TRAIN_PATH = os.path.join(DATA_PATH, 'train')
VAL_PATH = os.path.join(DATA_PATH, 'val')
TEST_PATH = os.path.join(DATA_PATH, 'test')
REJECT_PATH = os.path.join(DATA_PATH, 'reject')

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Dataloader factory
def create_dataloader(path, batch_size, istrain):
    """Build an ImageFolder dataset over `path` and a DataLoader that reads it.

    Args:
        path: Root directory handed to torchvision.datasets.ImageFolder.
        batch_size: Number of samples per batch. It is honoured for both the
            training and the evaluation loader (the evaluation loader used to
            drop it and silently fall back to DataLoader's default of 1).
        istrain: When true, the random-flip / colour-jitter transforms are
            added and the loader shuffles; otherwise the pipeline is
            resize -> tensor -> normalize and the order is kept.

    Returns:
        (dataloader, data): the DataLoader and the underlying ImageFolder.

    Side effects:
        Prints the number of samples in the dataset.
    """
    transforms = torchvision.transforms

    # Steps shared by the training and evaluation pipelines.
    resize = transforms.Resize(
        (224, 224),
        interpolation=transforms.InterpolationMode.NEAREST
    )
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    if istrain:
        steps = [
            resize,
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        # Validation and test share one augmentation-free pipeline.
        steps = [resize, transforms.ToTensor(), normalize]

    data = torchvision.datasets.ImageFolder(path, transform=transforms.Compose(steps))
    dataloader = torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        shuffle=bool(istrain)  # only training data is shuffled
    )

    print(len(data))

    return dataloader, data

In [None]:
# Batch size used for training
BATCH_SIZE = 128

# Build a shuffling dataloader over the images stored under TRAIN_PATH
train_loader, _train_data = create_dataloader(TRAIN_PATH, BATCH_SIZE, True)
# Take the class count from the dataset's own label mapping. os.listdir(TRAIN_PATH)
# counts every directory entry, including stray files that are not classes, which
# would size the classifier head differently from the labels the loader emits.
target_class_num = len(_train_data.class_to_idx)

print('target_class_num: ', target_class_num)
print('train: ', _train_data.class_to_idx)

In [None]:
# Number of files in every folder below TRAIN_PATH
# (swap in VAL_PATH or TEST_PATH to inspect the other splits)
for folder, _subfolders, files in os.walk(TRAIN_PATH):
    file_count = len(files)
    print(f'{folder} : {file_count}')

In [None]:
# Create the validation loader (evaluation pipeline, no shuffling)
val_loader, _val_data = create_dataloader(
    path=VAL_PATH,
    batch_size=BATCH_SIZE,
    istrain=False,
)
print('val: ', _val_data.class_to_idx)

In [None]:
# Create the test loader (evaluation pipeline, no shuffling)
test_loader, _test_data = create_dataloader(
    path=TEST_PATH,
    batch_size=BATCH_SIZE,
    istrain=False,
)
print('test: ', _test_data.class_to_idx)

In [None]:
# Report dataset size, number of batches and the loader's batch size, one per line
print(
    f"데이터셋 샘플 수: {len(_train_data)}",
    f"데이터로더 배치 수: {len(train_loader)}",
    f"현재 DataLoader의 배치 크기: {train_loader.batch_size}",
    sep="\n",
)

In [None]:
# Metrics helper
def calculate_metrics(trues, preds):
    """Score predictions against ground-truth labels.

    Args:
        trues: Ground-truth labels.
        preds: Predicted labels, aligned with `trues`.

    Returns:
        (accuracy, f1, precision, recall); the last three are macro-averaged.
    """
    macro_scorers = (f1_score, precision_score, recall_score)
    macro_values = [scorer(trues, preds, average='macro') for scorer in macro_scorers]
    return (accuracy_score(trues, preds), *macro_values)

In [None]:
# Train function
# Runs one pass over the dataloader; call it once per epoch
# optimizer: Adam, loss function: CrossEntropyLoss
def train(dataloader, net, learning_rate, weight_decay_level, device):
    """Train `net` for one pass over `dataloader` and print the metrics.

    Args:
        dataloader: Iterable yielding (image batch, label batch) pairs.
        net: Model to optimize; it is switched to train mode.
        learning_rate: Learning rate for Adam.
        weight_decay_level: Weight decay for Adam.
        device: Device the batches are moved to before the forward pass.

    Returns:
        (net, accuracy, f1, precision, recall) computed over all batches seen.

    NOTE(review): a new Adam optimizer is constructed on every call, so its
    internal state does not carry over from one epoch to the next.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(
        net.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay_level,
    )

    net.train()

    batch_losses, predicted, expected = [], [], []

    for images, targets in dataloader:
        images, targets = images.to(device), targets.to(device)

        adam.zero_grad()  # clear gradients from the previous step

        logits = net(images)
        batch_loss = loss_fn(logits, targets)

        batch_loss.backward()  # compute gradients
        adam.step()  # update parameters

        # Predicted class = index of the largest logit in each row
        batch_pred = torch.max(logits, 1)[1]

        batch_losses.append(batch_loss.item())
        expected.extend(targets.view(-1).cpu().numpy().tolist())
        predicted.extend(batch_pred.view(-1).cpu().detach().numpy().tolist())

    acc, f1, prec, rec = calculate_metrics(expected, predicted)

    print('\n''====== Training Metrics ======')
    print('Loss:', mean(batch_losses), 'Acc:', acc, 'F1:', f1, 'Precision:', prec, 'Recall:', rec)

    return net, acc, f1, prec, rec

In [None]:
# Test function
def test(dataloader, net, device):
    """Evaluate `net` on `dataloader` without gradient tracking and print the metrics.

    Args:
        dataloader: Iterable yielding (image batch, label batch) pairs.
        net: Model to evaluate; it is switched to eval mode.
        device: Device the batches are moved to before the forward pass.

    Returns:
        (net, accuracy, f1, precision, recall) computed over all batches seen.
    """
    loss_fn = torch.nn.CrossEntropyLoss()

    net.eval()
    batch_losses, expected, predicted = [], [], []

    with torch.no_grad():
        for images, targets in dataloader:
            images, targets = images.to(device), targets.to(device)

            logits = net(images)
            batch_loss = loss_fn(logits, targets)

            # Predicted class = index of the largest logit in each row
            batch_pred = torch.max(logits, 1)[1]

            batch_losses.append(batch_loss.item())
            expected.extend(targets.view(-1).cpu().numpy().tolist())
            predicted.extend(batch_pred.view(-1).cpu().detach().numpy().tolist())

    acc, f1, prec, rec = calculate_metrics(expected, predicted)

    print('====== Test Metrics ======')
    print('Loss:', mean(batch_losses), 'Acc:', acc, 'F1:', f1, 'Precision:', prec, 'Recall:', rec)

    return net, acc, f1, prec, rec

In [None]:
# Training driver (creates the output folder, trains, evaluates, saves the best-accuracy .pth)
def train_classifier(net, train_loader, val_loader, n_epochs, learning_rate, weight_decay, device):
    """Train for `n_epochs`, checkpointing whenever validation accuracy improves.

    Each epoch calls train() on `train_loader` and test() on `val_loader`. When
    the accuracy returned by test() exceeds the best seen so far, the model's
    state_dict is saved under the relative directory 'weights'.

    Returns:
        Path of the most recently saved checkpoint, or None if none was saved.
    """
    checkpoint_dir = 'weights'
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    best_acc = 0
    best_checkpoint = None

    print('>> Start Training Model!')
    for epoch in range(n_epochs):

        print('> epoch: ', epoch)

        net = train(train_loader, net, learning_rate, weight_decay, device)[0]
        net, val_acc = test(val_loader, net, device)[:2]

        # Nothing to save unless this epoch beats the best accuracy so far
        if not val_acc > best_acc:
            continue

        best_acc = val_acc
        print('[Notification] Best Model Updated!')

        # The accuracy, formatted to five decimals, is embedded in the file name
        checkpoint_name = 'classifier_acc_' + str('%.5f' % val_acc) + '.pth'
        best_checkpoint = os.path.join(checkpoint_dir, checkpoint_name)
        torch.save(net.state_dict(), best_checkpoint)

    return best_checkpoint

In [None]:
# ResNet-50 with its final fully connected layer replaced so that it
# outputs target_class_num values
net = torchvision.models.resnet50(pretrained=True)
fc_in_features = net.fc.in_features
net.fc = torch.nn.Linear(in_features=fc_in_features, out_features=target_class_num)

net.to(device)

In [None]:
# Set the epoch count, learning rate and weight decay, then start training
EPOCHS = 80
LEARNING_RATE = 0.01
WEIGHT_DECAY = 0.001

saved_weight_path = train_classifier(
    net=net,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    device=device,
)

## OpenMax

In [None]:
# Rebuild the architecture, load the trained weights and switch to eval mode.
# pretrained=False: load_state_dict (strict by default) replaces every parameter
# and buffer, so fetching the pretrained weights first was wasted work and made
# this cell depend on a download.
net = torchvision.models.resnet50(pretrained=False)
net.fc = torch.nn.Linear(
    net.fc.in_features,
    target_class_num
)

# This fixed path replaces whatever checkpoint path the training cell returned.
saved_weight_path = '../classifier_weights.pth'
net.load_state_dict(torch.load(saved_weight_path, map_location=device))
net.eval()
net.to(device)

In [None]:
# Batch size of 1: the extraction code below reads exactly one sample per batch
BATCH_SIZE = 1

# Build a dataloader over the images stored under TRAIN_PATH.
# NOTE(review): istrain=True keeps the random flip / colour-jitter transforms and
# shuffling, so the extracted activations can differ between runs -- confirm that
# this is intended for fitting the OpenMax statistics.
train_loader, _train_data = create_dataloader(TRAIN_PATH, BATCH_SIZE, True)
# Take the class count from the dataset's own label mapping. os.listdir(TRAIN_PATH)
# counts every directory entry, including stray files that are not classes.
target_class_num = len(_train_data.class_to_idx)

print('target_class_num: ', target_class_num)
print('train: ', _train_data.class_to_idx)

In [None]:
# Collect the per-sample data OpenMax needs from the training set.
# The loop expects one sample per batch: index [0] and .item() read a single row.
train_preds = list()
train_actvecs = list()
train_outputs_softmax = list()
train_labels = list()

with torch.no_grad():
    for img, label in train_loader:
        out = net(img.to(device))

        out_actvec = out.cpu().numpy()[0]
        out_softmax = torch.softmax(out, 1).cpu().numpy()[0]
        # .item() replaces int(<1-element ndarray>), which NumPy has deprecated
        # for arrays with ndim > 0; it still fails loudly if the batch is larger
        # than one sample.
        out_pred = int(torch.argmax(out, dim=1).item())
        out_label = int(label.item())

        train_actvecs.append(out_actvec) # component 1: activation vector before softmax
        train_preds.append(out_pred) # component 2: predicted class of the sample
        train_outputs_softmax.append(out_softmax) # component 3: softmax probabilities of the sample
        train_labels.append(out_label) # component 4: ground-truth label of the sample

train_actvecs = np.asarray(train_actvecs)
train_preds = np.asarray(train_preds)
train_outputs_softmax = np.asarray(train_outputs_softmax)
train_labels = np.asarray(train_labels)

In [None]:
# Keep only the activation vectors of samples the model classified correctly
is_correct = train_labels == train_preds
train_correct_actvecs = train_actvecs[is_correct]
train_correct_labels = train_labels[is_correct]

print('Activation vector: ', train_correct_actvecs.shape)
print('Labels: ', train_correct_labels.shape)

In [None]:
# Group the correctly classified activation vectors by class.
# For every class, fit a Weibull distribution to the (up to) 100 vectors that lie
# farthest from the class mean, and store the fitted parameters per class.
class_means = list()
dist_to_means = list()
mr_models = {}

for class_idx in np.unique(train_labels):

    in_class = train_correct_labels == class_idx
    class_act_vec = train_correct_actvecs[in_class]
    print('class_idx: ', class_idx)
    print(class_act_vec.shape)

    class_mean = class_act_vec.mean(axis=0)
    class_means.append(class_mean)

    # Squared Euclidean distance of each activation vector to the class mean,
    # sorted in ascending order
    offsets = class_act_vec - class_mean
    dist_to_mean_sorted = np.sort(np.square(offsets).sum(axis=1)).astype(np.float64)
    dist_to_means.append(dist_to_mean_sorted)

    # The tail of the sorted array holds the largest distances
    tail = dist_to_mean_sorted[-100:]
    weibull_params = stats.weibull_max.fit(tail)

    mr_models[str(class_idx)] = dict(zip(('shape', 'loc', 'scale'), weibull_params))

class_means = np.asarray(class_means)

In [None]:
# OpenMax probability function
def compute_openmax(actvec, class_means, mr_models):
    """Turn one activation vector into OpenMax probabilities.

    Args:
        actvec: 1-D activation vector with one entry per known class.
        class_means: 2-D array holding one mean activation vector per class.
        mr_models: Dict keyed by str(class index); each value holds the
            'shape', 'loc' and 'scale' parameters of that class's Weibull fit.

    Returns:
        1-D array with len(class_means) + 1 probabilities that sum to 1; the
        last entry belongs to the extra "unknown" class.
    """
    # Squared Euclidean distance from the activation vector to every class mean
    dist_to_mean = np.square(actvec - class_means).sum(axis=1)

    scores = np.asarray([
        stats.weibull_max.cdf(
            dist_to_mean[class_idx],
            mr_models[str(class_idx)]['shape'],
            mr_models[str(class_idx)]['loc'],
            mr_models[str(class_idx)]['scale']
        )
        for class_idx in range(len(class_means))
    ])

    weight_on_actvec = 1 - scores # per-class weight
    rev_actvec = np.concatenate([
        weight_on_actvec * actvec, # known classes: activation scaled by its weight
        [((1 - weight_on_actvec) * actvec).sum()] # unknown class: the activation mass removed above
    ])

    # Softmax with the maximum subtracted first. The result is mathematically the
    # same, but large activations no longer overflow np.exp into inf/inf = NaN.
    exp_shifted = np.exp(rev_actvec - rev_actvec.max())
    openmax_prob = exp_shifted / exp_shifted.sum()
    return openmax_prob

In [None]:
# Classify one activation vector, forcing the reject class when the highest
# score falls below the threshold
def inference(actvec, threshold, target_class_num, class_means, mr_models):
    """Predict a class index for `actvec`, or `target_class_num` to reject it.

    The OpenMax probabilities are passed through one more softmax before the
    comparison, so `threshold` applies to that re-normalized score.

    Returns:
        `target_class_num` when the highest score is below `threshold`,
        otherwise the argmax index of the scores.
    """
    openmax_prob = compute_openmax(actvec, class_means, mr_models)
    exp_prob = np.exp(openmax_prob)
    openmax_softmax = exp_prob / sum(exp_prob)

    if np.max(openmax_softmax) < threshold:
        return target_class_num
    return np.argmax(openmax_softmax)

In [None]:
# Helper that makes the threshold search easier
def inference_dataloader(net, data_loader, threshold, target_class_num, class_means, mr_models, is_reject=False):
    """Run thresholded OpenMax inference over every sample of `data_loader`.

    Args:
        net: Model producing the activation vectors.
        data_loader: Loader yielding (image, label) pairs one sample at a time;
            only the first row of each output is used and the label is read
            with .item().
        threshold: Rejection threshold forwarded to inference().
        target_class_num: Index used for the reject class.
        class_means: Per-class mean activation vectors.
        mr_models: Per-class Weibull parameters.
        is_reject: When true, every sample is labelled `target_class_num`
            instead of the label stored in the dataset.

    Returns:
        (result_preds, result_labels): two lists aligned sample by sample.

    Note:
        Images are moved to the module-level `device`.
    """
    result_preds = list()
    result_labels = list()

    with torch.no_grad():
        for img, label in data_loader:
            out = net(img.to(device))
            out_actvec = out.cpu().numpy()[0]
            # .item() replaces int(<1-element ndarray>), which NumPy has
            # deprecated for arrays with ndim > 0; it still raises if the batch
            # holds more than one sample.
            out_label = int(label.item())

            pred = inference(out_actvec, threshold, target_class_num, class_means, mr_models)
            result_preds.append(pred)

            if is_reject:
                result_labels.append(target_class_num)
            else:
                result_labels.append(out_label)

    return result_preds, result_labels

In [None]:
# Create the test and reject dataloaders (one sample per batch)
test_loader, _test_data = create_dataloader(TEST_PATH, 1, False)
reject_loader, _reject_data = create_dataloader(REJECT_PATH, 1, False)
# Take the class count from the test dataset's label mapping. os.listdir(TEST_PATH)
# counts every directory entry, including stray files that are not classes.
target_class_num = len(_test_data.class_to_idx)

In [None]:
# Evaluate a single threshold
single_threshold = 0.45

test_preds, test_labels = inference_dataloader(
    net, test_loader, single_threshold, target_class_num, class_means, mr_models
)
reject_preds, reject_labels = inference_dataloader(
    net, reject_loader, single_threshold, target_class_num, class_means, mr_models,
    is_reject=True
)

single_test_accuracy = accuracy_score(test_labels, test_preds)
single_reject_accuracy = accuracy_score(reject_labels, reject_preds)
print('Test Accuracy: ', single_test_accuracy)
print('Reject Accuracy: ', single_reject_accuracy)

In [None]:
# Evaluate a range of thresholds
# NOTE(review): np.arange with a float step may or may not emit a value next to
# the 0.50 stop; np.linspace would make the grid explicit -- confirm the intent.
thresholds = np.arange(0.36, 0.50, 0.02)
test_accuracies = []
reject_accuracies = []

for threshold in thresholds:
    test_preds, test_labels = inference_dataloader(
        net, test_loader, threshold, target_class_num, class_means, mr_models
    )
    reject_preds, reject_labels = inference_dataloader(
        net, reject_loader, threshold, target_class_num, class_means, mr_models,
        is_reject=True
    )

    test_accuracies.append(accuracy_score(test_labels, test_preds))
    reject_accuracies.append(accuracy_score(reject_labels, reject_preds))

# Convert the accuracy lists to np.array
test_accuracies = np.asarray(test_accuracies)
reject_accuracies = np.asarray(reject_accuracies)

In [None]:
# Among the thresholds whose test accuracy exceeds target_score, pick the one
# with the highest reject accuracy
target_score = 0.85
best_reject_accuracy = 0
best_test_accuracy = None
# Previously initialised under the misspelled name `best_threashold`, which left
# `best_threshold` undefined (NameError in the final print) whenever no threshold
# qualified.
best_threshold = None
for idx, flag in enumerate(test_accuracies > target_score):
    if flag and best_reject_accuracy < reject_accuracies[idx]:
        best_threshold = thresholds[idx]
        best_test_accuracy = test_accuracies[idx]
        best_reject_accuracy = reject_accuracies[idx]

# round() cannot take None, so only round when a threshold was actually found
best_threshold_rounded = None if best_threshold is None else round(best_threshold, 2)

print(f"Test accuracy가 {target_score} 이상일 때: ")
print(f"reject accuracy의 최대값: {best_reject_accuracy}")
print(f"Test accuracy 값: {best_test_accuracy}")
print(f"threshold 값: {best_threshold_rounded}")

In [None]:
# Plot test and reject accuracy against the threshold and save the figure
fig, ax = plt.subplots(figsize=(20, 6))
ax.plot(thresholds, test_accuracies, label='Test Accuracy', marker='o')
ax.plot(thresholds, reject_accuracies, label='Reject Accuracy', marker='x')
ax.set_xlabel('Threshold')
ax.set_ylabel('Accuracy')
ax.set_title('Test Accuracy & Reject Accuracy')
ax.legend()
ax.grid(True)
fig.savefig('./openmax_threshold.png')
plt.show()