In [1]:
from __future__ import print_function, division
import os
import torch


# DataLoader wraps a Dataset in an iterable so that its samples are easy to access.
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets
from torchvision.transforms import ToTensor
import torchvision.models as models 

import pprint
from datetime import datetime



import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from collections import defaultdict

from auto_encoder import MNIST_BN_32_64_256, ae_train

from active_learn import argparser

In [2]:
import numpy as np
# Load the MNIST training split; ToTensor() converts every image to a tensor.
original_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())

# Materialise the dataset once, then derive the parallel lists from it:
#   original_all     -> the raw (image, target) samples
#   original_dataset -> each image converted to a NumPy array
#   original_label   -> [target, position in the original dataset]
original_all = list(original_data)
original_dataset = [np.array(image) for image, target in original_all]
original_label = [[target, position] for position, (image, target) in enumerate(original_all)]

# Every point starts out unlabeled. These are shallow copies, so removing
# entries from the pools leaves the original lists untouched.
unlabeled_dataset = list(original_dataset)
unlabeled_dataset_label = list(original_label)

# Pool of points selected by the active-learning sampler (empty at start).
labeled_dataset, labeled_dataset_label = [], []

# Pool of points that received a label through classification (empty at start).
c_labeled_dataset, c_labeled_dataset_label = [], []

# Per-point history, keyed by original dataset index.
count_subgraph = defaultdict(list)

In [3]:
# Use the GPU only when one is actually present. Hard-coding True selects a
# "cuda" device even on CPU-only machines, where any later CUDA use fails.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Extra DataLoader keyword arguments, applied only when running on CUDA.
# (Defined once; the original re-assigned it with `if True`, ignoring use_cuda.)
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# Update this when switching to a different dataset.
ae_training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor())

# Update this when switching to a different dataset.
ae_test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor())

In [4]:
PATH = './weights/MNIST/'

# Restore the auto-encoder: first the pickled model object, then its weights.
# map_location=device lets a checkpoint that was saved on a GPU be loaded when
# `device` is the CPU (and vice versa).
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints that come from a trusted source. Recent PyTorch releases default
# to weights_only=True, which may reject the full-model pickle 'CAE.pt';
# confirm against the installed version.
CAE = torch.load(PATH + 'CAE.pt', map_location=device)
CAE.load_state_dict(torch.load(PATH + 'CAE_state_dict.pt', map_location=device))

# Request 50 samples per round, but never more than remain in the unlabeled pool.
sample_size = min(50, len(unlabeled_dataset))

In [20]:
from active_learn import active_sample
# Ask the active-learning sampler for `sample_size` points from the unlabeled pool.
# `sample_index` is used below as positions into the *current* unlabeled lists.
# NOTE(review): `radius` is presumably the cluster radius reported in the cell
# output ("Max distance from cluster") -- confirm in active_learn.active_sample.
sample_dataset, sample_index, radius = active_sample(
    unlabeled_dataset, labeled_dataset, sample_size, model=CAE, device=device)

# Capture the selected points before the pools are modified.
sample_data = [unlabeled_dataset[i] for i in sample_index]
sample_label = [unlabeled_dataset_label[i] for i in sample_index]

if len(labeled_dataset_label) == 0:
    # First round: the labeled pool is exactly this sample.
    labeled_dataset = sample_data[:]
    labeled_dataset_label = sample_label[:]
else:
    # Later rounds: append to the existing pool (this turns it into an ndarray).
    labeled_dataset = np.concatenate((labeled_dataset, sample_data), axis=0)
    labeled_dataset_label = np.concatenate((labeled_dataset_label, sample_label), axis=0)

# Remove the sampled points from the unlabeled pool. Deleting from the highest
# position downwards keeps the remaining positions valid; sorting and
# de-duplicating first makes this correct even if the sampler returns the
# positions unordered or repeated (a plain [::-1] assumed ascending order).
for position in sorted({int(i) for i in sample_index}, reverse=True):
    del unlabeled_dataset[position]
    del unlabeled_dataset_label[position]

Max distance from cluster : 10.41


In [21]:
from active_learn import make_subgraph, adjacency_subgraph

# Build the subgraph structures for the points sampled in this round.
# NOTE(review): the exact shapes/meaning of the returned values are defined in
# active_learn.make_subgraph -- confirm there.
subgraph, density_subgraph = make_subgraph(
    sample_label,
    original_dataset,
    radius,
    CAE,
)

# Derive adjacency information between the sampled points' subgraphs.
# NOTE(review): the meaning of the trailing 0 is not visible here -- see
# active_learn.adjacency_subgraph.
(dist_class,
 adj_dist,
 classified_subgraph_index,
 pseudo_class_label) = adjacency_subgraph(sample_dataset, sample_label, radius, CAE, 0)

print("Well work!")

Well work!


In [22]:
from active_learn import first_classification, check_performance
# First classification pass. The result is consumed later as a mapping whose
# values are collections of indices into `original_dataset`.
# NOTE(review): 0.99 is presumably a confidence/density threshold -- confirm in
# active_learn.first_classification.
f_classification = first_classification(
    classified_subgraph_index,
    pseudo_class_label,
    subgraph,
    density_subgraph,
    0.99,
)

# Compare the assigned classes against the ground-truth labels.
num_classification, score, dic_score = check_performance(
    f_classification,
    original_label,
)


In [23]:
# After applying the CS1 method, move the newly classified points out of the
# unlabeled pool and into the classified (c_labeled) pool.
erase_dataset_ori_index = []

# Original-dataset indices that are already in the classified pool.
pre_index = [j[1] for j in c_labeled_dataset_label]
already_classified = set(pre_index)  # built once instead of once per class

for i in f_classification.keys():
    # Keep only the indices of this class that were not classified before.
    index = list(set(f_classification[i]) - already_classified)

    new_labeled_dataset = [original_dataset[j] for j in index]
    new_labeled_dataset_label = [[i, j] for j in index]
    new_erase_original_index = list(index)

    if not index:
        # Nothing new for this class. Skipping also avoids np.concatenate
        # raising when an empty list is joined with the multi-dimensional data.
        continue

    if len(c_labeled_dataset_label) == 0:
        c_labeled_dataset = new_labeled_dataset
        c_labeled_dataset_label = new_labeled_dataset_label
    else:
        c_labeled_dataset = np.concatenate((c_labeled_dataset, new_labeled_dataset), axis=0)
        c_labeled_dataset_label = np.concatenate((c_labeled_dataset_label, new_labeled_dataset_label), axis=0)

    erase_dataset_ori_index += new_erase_original_index

# Map original index -> current position in the unlabeled pool, built in one
# pass (the previous version rebuilt a full NumPy array for every erased index).
position_by_original_index = {}
for position, entry in enumerate(unlabeled_dataset_label):
    position_by_original_index.setdefault(entry[1], position)

# Positions to delete, ascending and unique. Indices that are no longer in the
# unlabeled pool (e.g. already removed as sampled points) need no deletion and
# are skipped instead of raising IndexError; duplicates are collapsed so that
# no position is deleted twice.
erase_unlabeled_index = sorted({
    position_by_original_index[j]
    for j in erase_dataset_ori_index
    if j in position_by_original_index
})

# Delete from the highest position downwards so earlier positions stay valid.
for i in reversed(erase_unlabeled_index):
    del unlabeled_dataset[i]
    del unlabeled_dataset_label[i]

In [24]:
# Number of points accumulated in the classified (c_labeled) pool so far.
len(c_labeled_dataset_label)

3809

In [18]:
def update_count_subgraph(count_subgraph, unlabeled_dataset_label, subgraph,
                          labels=None, current_radius=None, num_classes=10):
    """Append one per-class neighbour count to the history of every unlabeled point.

    For each unlabeled point, the rows of ``subgraph`` that contain a 1 in the
    point's column are looked up in ``labels``, and the number of such rows is
    tallied per class. The entry ``[counts, current_radius]`` is appended to
    ``count_subgraph[original_index]``.

    Args:
        count_subgraph: mapping ``original index -> list of [counts, radius]``
            entries; updated in place (a defaultdict or a plain dict).
        unlabeled_dataset_label: sequence of ``[label, original_index]`` pairs;
            only the original index (element 1) is used.
        subgraph: 2-D array indexed as ``subgraph[row, original_index]``.
        labels: sequence whose element ``row`` starts with the class of that
            subgraph row. Defaults to the notebook global
            ``labeled_dataset_label`` (the original behaviour).
        current_radius: value stored next to the counts. Defaults to
            ``radius[0]`` from the notebook global ``radius``.
        num_classes: length of the count vector (10 for MNIST).

    Returns:
        The same ``count_subgraph`` object that was passed in.

    NOTE(review): ``subgraph`` is built from the current round's
    ``sample_label`` while the default ``labels`` is the cumulative
    ``labeled_dataset_label``; from the second round on the rows may no longer
    line up -- confirm, and pass ``labels=sample_label`` if so.
    """
    # Fall back to the notebook globals so existing three-argument calls
    # behave exactly as before.
    if labels is None:
        labels = labeled_dataset_label
    if current_radius is None:
        current_radius = radius[0]

    for entry in unlabeled_dataset_label:
        original_index = entry[1]
        class_counts = [0] * num_classes
        # Rows of the subgraph that contain this point.
        neighbour_rows = np.where(subgraph[:, original_index] == 1)[0]
        for row in neighbour_rows:
            class_counts[labels[row][0]] += 1

        # setdefault keeps this working for plain dicts as well as defaultdicts.
        count_subgraph.setdefault(original_index, []).append([class_counts, current_radius])

    return count_subgraph


In [25]:
# Positions in the unlabeled pool shift as points are removed, so the history
# is keyed by each point's original dataset index instead of its pool position.
# Note: every execution of this cell appends one more entry per unlabeled point.
update_count_subgraph(count_subgraph, unlabeled_dataset_label, subgraph)

defaultdict(list,
            {0: [[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 17.855932],
              [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 10.405967]],
             1: [[[1, 0, 0, 0, 0, 1, 0, 0, 1, 0], 17.855932],
              [[0, 0, 0, 0, 0, 0, 1, 0, 0, 0], 10.405967]],
             2: [[[0, 0, 0, 0, 0, 0, 0, 0, 1, 1], 17.855932],
              [[2, 0, 0, 0, 0, 0, 0, 0, 0, 0], 10.405967]],
             3: [[[0, 3, 0, 0, 0, 0, 0, 0, 0, 0], 17.855932]],
             4: [[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 17.855932],
              [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 10.405967]],
             5: [[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 17.855932],
              [[1, 0, 0, 0, 0, 0, 1, 0, 0, 0], 10.405967]],
             6: [[[0, 1, 0, 0, 0, 0, 0, 1, 0, 0], 17.855932],
              [[1, 0, 0, 0, 1, 0, 0, 0, 0, 0], 10.405967]],
             7: [[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 17.855932],
              [[0, 0, 0, 0, 0, 0, 1, 0, 0, 0], 10.405967]],
             8: [[[0, 2, 0, 0, 0, 0, 0, 1, 0, 0], 17.855932],
   

In [28]:
# Second recorded [class counts, radius] entry for the point with original index 0.
count_subgraph[0][1]

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 10.405967]