# 9. 성능 개선
## 9.4 준지도 학습 - Pseudo Label

In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [None]:
# Colab only: mount Google Drive at /content/gdrive so the next cell can make a
# project folder on Drive the working directory.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Make the project folder on Drive the working directory, so relative paths
# such as ./data and ./models used below resolve inside it.
# This line is an IPython "automagic" command, not plain Python.
# NOTE(review): the explicit form `%cd /content/gdrive/My Drive/pytorch_dlbro`
# does not depend on automagic being enabled.
cd/content/gdrive/My Drive/pytorch_dlbro

In [None]:
# GPU vs CPU
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Same dataset template as in section 3.3, with a `transform` hook added so
# samples can be preprocessed on the fly.
class MyDataset(Dataset):
    """Map-style dataset over in-memory input/target tensors.

    Args:
        x_data: inputs, indexed along the first dimension (a float tensor here).
        y_data: targets aligned with ``x_data`` (a long tensor here).
        transform: optional callable applied to each ``(input, target)`` pair;
            it must return the pair to hand to the DataLoader.
    """

    def __init__(self, x_data, y_data, transform=None):
        self.x_data = x_data  # inputs
        self.y_data = y_data  # targets
        self.transform = transform
        self.len = len(y_data)  # dataset size is taken from the targets

    def __getitem__(self, index):
        pair = (self.x_data[index], self.y_data[index])
        # Preprocess only when a transform was supplied.
        return self.transform(pair) if self.transform else pair

    def __len__(self):
        return self.len

class TrainTransform:
    """Training-time augmentation for ``(image, label)`` samples.

    The image tensor is converted to a PIL image, randomly flipped
    horizontally (``RandomHorizontalFlip`` with its default probability) and
    converted back to a float tensor with ``ToTensor``. The label is returned
    unchanged.

    NOTE(review): ``ToPILImage`` multiplies floating-point tensors by 255, so
    float image tensors are assumed to be scaled to [0, 1].
    """

    def __init__(self):
        # Build the pipeline once instead of re-creating it for every sample.
        self.transf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

    def __call__(self, sample):
        inputs, labels = sample
        return self.transf(inputs), labels

In [None]:
def balanced_subset(data, labels, num_cls, num_data):
    """Split ``data``/``labels`` into a class-balanced subset and the remainder.

    For every class ``0 .. num_cls - 1`` the samples of that class are
    shuffled. The first ``num_data // num_cls`` go to the first split and the
    rest go to the second split, so the two splits are disjoint. Within each
    split, samples are grouped by class in ascending class order.

    Args:
        data: torch tensor of samples indexed along dim 0 (e.g. MNIST ``.data``).
        labels: 1-D CPU torch tensor of integer class ids aligned with ``data``.
        num_cls: number of classes.
        num_data: total number of samples requested for the first split.

    Returns:
        ``(data1, data2, labels1, labels2)``: the balanced split, the
        remainder, and their labels (int64).
        uint8 pixel data is converted to float32 and scaled to [0, 1], matching
        ``transforms.ToTensor``. Other dtypes are promoted to at least float32
        without rescaling, so re-splitting an already converted tensor is safe.
    """
    num_data_per_class = num_data // num_cls
    labels_np = labels.numpy()

    first_idx, second_idx = [], []
    for cls in range(num_cls):
        # Dataset positions of this class's samples, in random order.
        # (Indexing `data` with a permutation of range(len(idx)) instead of
        # idx itself would pick arbitrary, overlapping samples.)
        cls_idx = np.random.permutation(np.where(labels_np == cls)[0])
        first_idx.append(cls_idx[:num_data_per_class])
        second_idx.append(cls_idx[num_data_per_class:])

    def _as_float(x):
        # uint8 images -> [0, 1] floats. Storing raw 0..255 values as floats
        # would overflow in ToPILImage (which multiplies floats by 255) and
        # would not match the ToTensor-scaled test set.
        if x.dtype == torch.uint8:
            return x.float().div(255)
        return x.to(torch.promote_types(x.dtype, torch.float32))

    sel1 = torch.from_numpy(np.concatenate(first_idx)).long()
    sel2 = torch.from_numpy(np.concatenate(second_idx)).long()
    return (_as_float(data[sel1]), _as_float(data[sel2]),
            labels[sel1].long(), labels[sel2].long())

In [None]:
# Raw MNIST training split. No transform is passed because only its
# .data / .targets tensors are used (split by balanced_subset below).
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True) # 60,000 images

In [None]:
# Request 2,000 labeled images (2000 // 10 = 200 per class); everything else
# becomes the unlabeled pool. The labeled images are then split again into
# 1,000 training and 1,000 validation images (100 per class each).
labeled_data, unlabeled_data, labels, unlabels = balanced_subset(trainset.data, trainset.targets, num_cls=10, num_data=2000)
train_images, val_images, train_labels, val_labels = balanced_subset(labeled_data, labels, num_cls=10, num_data=1000)

In [None]:
# Insert the channel dimension expected by Conv2d(1, ...):
# (N, 28, 28) -> (N, 1, 28, 28) for MNIST-sized images.
train_images = train_images.unsqueeze(1)
val_images = val_images.unsqueeze(1)
# Training samples get a random horizontal flip via TrainTransform;
# validation samples are used as-is.
trainset = MyDataset(train_images, train_labels, transform=TrainTransform())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
validationset = MyDataset(val_images, val_labels)
valloader = torch.utils.data.DataLoader(validationset, batch_size=128, shuffle=False)

In [None]:
# Unlabeled pool, with a channel dimension added. Its true labels (unlabels)
# are stored only because MyDataset needs targets; the pseudo-label cells
# below read just the images (batch element 0).
unlabeled_images = unlabeled_data.unsqueeze(1)
unlabeledset = MyDataset(unlabeled_images, unlabels)
unlabeledloader = torch.utils.data.DataLoader(unlabeledset, batch_size=256, shuffle=True)

In [None]:
# Load the MNIST test split; ToTensor converts each image to a float tensor
# in [0, 1] with shape (1, 28, 28).
transform = transforms.Compose([transforms.ToTensor()])
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,shuffle=False)

In [None]:
class Net(nn.Module):
    """Small CNN classifier for single-channel 28x28 images (MNIST here).

    features:   conv3x3(1->64) -> ReLU -> maxpool2
                -> conv3x3 pad 1 (64->192) -> ReLU -> maxpool2,
                which maps a 28x28 input to 192 x 6 x 6 feature maps.
    classifier: dropout -> FC 6912->1024 -> ReLU -> dropout
                -> FC 1024->512 -> ReLU -> FC 512->10 (raw logits).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.features = nn.Sequential(
                        nn.Conv2d(1, 64, 3), nn.ReLU(),
                        nn.MaxPool2d(2, 2),
                        nn.Conv2d(64, 192, 3, padding=1), nn.ReLU(),
                        nn.MaxPool2d(2, 2))
        self.classifier = nn.Sequential(
                        nn.Dropout(0.5),
                        nn.Linear(192*6*6, 1024), nn.ReLU(),
                        nn.Dropout(0.5),
                        nn.Linear(1024, 512), nn.ReLU(),
                        nn.Linear(512, 10))

    def forward(self, x):
        x = self.features(x)
        # Flatten per sample, keeping the batch dimension. With view(-1, 6912),
        # a wrongly sized input could be silently regrouped into a different
        # number of rows; now it fails at the first Linear layer instead.
        x = torch.flatten(x, 1)
        return self.classifier(x)

model = Net().to(device) # instantiate the model on the selected device

In [None]:
# Cross-entropy on Net's raw logits, optimised with Adam (lr = 1e-3).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Disabled option: multiply the lr by 0.1 at scheduler-step milestones 100 and 200.
#scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,200], gamma=0.1)

In [None]:
def accuracy(dataloader, net=None):
    """Return the top-1 accuracy, in percent, of a model over ``dataloader``.

    Args:
        dataloader: iterable of batches whose first two elements are
            ``(images, labels)``.
        net: model to evaluate; defaults to the notebook-global ``model``.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples.

    The model runs in eval mode (dropout off) for the pass and is always put
    back into train mode afterwards, even on error. The training loops call
    this once per epoch and keep training after it returns.
    """
    if net is None:
        net = model
    # Evaluate on whatever device the model's weights live on.
    first_param = next(net.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    net.eval()
    try:
        with torch.no_grad():
            for data in dataloader:
                images, labels = data[0].to(dev), data[1].to(dev)
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train()
    return 100 * correct / total if total else 0.0

### 학습 데이터로만 모델 학습

In [None]:
# Baseline: train on the labeled training images only, for 501 epochs.
# A checkpoint is saved whenever validation accuracy matches or beats the best
# so far (>=, so later ties overwrite earlier ones).
# NOTE: the checkpoint is named "cifar_..." although this notebook uses MNIST;
# the path is kept because later cells load it by this name.
os.makedirs('./models', exist_ok=True)  # torch.save does not create directories
model.train()  # make sure dropout is active while training

best_acc = 0
for epoch in range(501):
    correct = 0
    total = 0
    for traindata in trainloader:
        inputs, labels = traindata[0].to(device), traindata[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_acc = 100 * correct / total if total else 0.0
    val_acc = accuracy(valloader)  # accuracy() puts the model back in train mode
    if val_acc >= best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), './models/cifar_model_for_pseudo_baseline.pth')
        print('[%d] train acc: %.2f, validation acc: %.2f - Saved the best model' %(epoch, train_acc, val_acc))
    elif epoch % 10 == 0:
        print('[%d] train acc: %.2f, validation acc: %.2f' %(epoch, train_acc, val_acc))

In [None]:
# Reload the last checkpoint that reached the best validation accuracy and
# report its test accuracy in percent.
model.load_state_dict(torch.load('./models/cifar_model_for_pseudo_baseline.pth'))
accuracy(testloader)

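## 첫 번째 방법
레이블이 없는 데이터에 대한 모델의 예측값(확률이 가장 큰 클래스)을 의사 라벨로 삼고, 이 의사 라벨과 예측값을 비교하여 손실 함수를 계산한다. 이 손실에 가중치 alpha를 곱해 레이블이 있는 데이터의 손실에 더하며, alpha는 T1 에폭 이후부터 점차 증가한다.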

In [None]:
# Method 1 starts from a fresh, randomly initialised model and a new optimizer.
model = Net().to(device) # instantiate the model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Method 1: train on labeled batches and, once alpha > 0, also on unlabeled
# batches whose targets are the model's own current predictions (argmax,
# detached). alpha is 0 up to epoch T1, ramps linearly towards alpha_t until
# epoch T2, then stays at alpha_t. A new alpha takes effect from the next epoch.
# NOTE: the checkpoint name says "cifar" although the data is MNIST; the path
# is kept because a later cell loads it.
alpha = 0
alpha_t = 1e-4
T1 = 100
T2 = 450
best_acc = 0

os.makedirs('./models', exist_ok=True)  # torch.save does not create directories
model.train()  # make sure dropout is active while training

for epoch in range(501):
    correct = 0
    total = 0
    # zip stops at the shorter loader:
    # at most min(len(trainloader), len(unlabeledloader)) steps per epoch.
    for traindata, pseudodata in zip(trainloader, unlabeledloader):
        inputs, labels = traindata[0].to(device), traindata[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        if alpha > 0:
            # The unlabeled batch only matters once it is weighted in, so it is
            # moved to the device and forwarded only then.
            pinputs = pseudodata[0].to(device)
            poutputs = model(pinputs)
            _, plabels = torch.max(poutputs.detach(), 1)
            loss = loss + alpha * criterion(poutputs, plabels)
        loss.backward()
        optimizer.step()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if T1 < epoch < T2:
        alpha = alpha_t * (epoch - T1) / (T2 - T1)
    elif epoch >= T2:
        alpha = alpha_t

    train_acc = 100 * correct / total if total else 0.0
    val_acc = accuracy(valloader)  # accuracy() puts the model back in train mode
    if val_acc >= best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), './models/cifar_model_for_pseudo_label.pth')
        print('[%d] train acc: %.2f, validation acc: %.2f - Saved the best model' %(epoch, train_acc, val_acc))
    elif epoch % 10 == 0:
        print('[%d] train acc: %.2f, validation acc: %.2f' %(epoch, train_acc, val_acc))

In [None]:
# Reload the best method-1 checkpoint (by validation accuracy) and report test accuracy in percent.
model.load_state_dict(torch.load('./models/cifar_model_for_pseudo_label.pth'))
accuracy(testloader)

## 두 번째 방법
학습 데이터로만 학습한 모델로 레이블이 없는 데이터를 예측하고, 예측 확률이 임계값(0.99) 이상인 데이터에만 의사 라벨을 붙여 추가 학습 데이터로 활용한다.

In [None]:
# Method 2 starts from the baseline weights saved earlier, with a fresh optimizer.
model = Net().to(device) # instantiate the model
model.load_state_dict(torch.load('./models/cifar_model_for_pseudo_baseline.pth'))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Build a pseudo-labelled set from the unlabeled pool with the baseline model:
# keep only images whose top softmax probability is >= pseudo_threshold and
# label them with the predicted class. Results are moved to the CPU.
pseudo_threshold = 0.99
image_chunks, label_chunks = [], []

model.eval()  # dropout off while predicting pseudo labels
try:
    with torch.no_grad():
        for data in tqdm(unlabeledloader):
            images = data[0].to(device)
            outputs = torch.nn.functional.softmax(model(images), dim=1)
            max_val, predicted = torch.max(outputs, 1)
            confident = max_val >= pseudo_threshold
            if confident.any():
                image_chunks.append(images[confident].cpu())
                label_chunks.append(predicted[confident].cpu())
finally:
    # Restore train mode; otherwise the fine-tuning loop that follows would
    # start without dropout.
    model.train()

# Concatenate once at the end. If nothing passed the threshold, keep the
# same empty float / long tensors the original cell started from.
pseudo_images = torch.cat(image_chunks, 0) if image_chunks else torch.tensor([], dtype=torch.float)
pseudo_labels = torch.cat(label_chunks, 0) if label_chunks else torch.tensor([], dtype=torch.long)

In [None]:
# Shapes of the pseudo-labelled image and label tensors (first dim = number kept).
print(pseudo_images.size(), pseudo_labels.size())

In [None]:
# Pseudo-labelled images as a dataset (no augmentation), served in shuffled batches of 256.
pseudo_dataset = MyDataset(pseudo_images, pseudo_labels)
pseudoloader = torch.utils.data.DataLoader(pseudo_dataset, batch_size=256, shuffle=True)

In [None]:
# Method 2: fine-tune the baseline model on labeled batches plus batches from
# the fixed pseudo-labelled set (pseudoloader). The pseudo loss is weighted by
# alpha: 0 up to epoch T1, then a linear ramp towards alpha_t until epoch T2,
# then alpha_t. A new alpha takes effect from the next epoch.
# NOTE: the checkpoint name says "cifar" although the data is MNIST; the path
# is kept because a later cell loads it.
alpha = 0
alpha_t = 1e-4
T1 = 20
T2 = 450
best_acc = 0

os.makedirs('./models', exist_ok=True)  # torch.save does not create directories
# The pseudo-labelling cell calls model.eval(); switch dropout back on so the
# first epoch does not train in eval mode.
model.train()

for epoch in range(501):
    correct = 0
    total = 0
    # zip stops at the shorter loader:
    # at most min(len(trainloader), len(pseudoloader)) steps per epoch.
    for traindata, pseudodata in zip(trainloader, pseudoloader):
        inputs, labels = traindata[0].to(device), traindata[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        if alpha > 0:
            # The pseudo term is scaled by alpha, so skip its forward pass while alpha == 0.
            pinputs, plabels = pseudodata[0].to(device), pseudodata[1].to(device)
            poutputs = model(pinputs)
            loss = loss + alpha * criterion(poutputs, plabels)
        loss.backward()
        optimizer.step()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if T1 < epoch < T2:
        alpha = alpha_t * (epoch - T1) / (T2 - T1)
    elif epoch >= T2:
        alpha = alpha_t

    train_acc = 100 * correct / total if total else 0.0
    val_acc = accuracy(valloader)  # accuracy() puts the model back in train mode
    if val_acc >= best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), './models/cifar_model_for_pseudo_label2.pth')
        print('[%d] train acc: %.2f, validation acc: %.2f - Saved the best model' %(epoch, train_acc, val_acc))
    elif epoch % 10 == 0:
        print('[%d] train acc: %.2f, validation acc: %.2f' %(epoch, train_acc, val_acc))

        


In [None]:
# Reload the best method-2 checkpoint (by validation accuracy) and report test accuracy in percent.
model.load_state_dict(torch.load('./models/cifar_model_for_pseudo_label2.pth'))
accuracy(testloader)