# 9. 성능 개선
## 9.1. 과적합 - DisturbLabel



In [None]:
# Mount Google Drive at /content/gdrive so files stored there are reachable
# from this notebook session.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
cd/content/gdrive/My Drive/pytorch_dlbro

In [None]:
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
# Preprocessing: convert each image to a tensor, then normalize every one of
# the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split (downloaded into ./data), served in shuffled
# batches of 32.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

# CIFAR-10 test split, served in order (no shuffling) in batches of 32.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
testloader = DataLoader(testset, batch_size=32, shuffle=False)

In [None]:
def imshow(img):
    """Display an image tensor with matplotlib.

    The tensor is first mapped through ``0.5 * img + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalization), then its first axis is moved to the
    end before plotting.

    Args:
        img: Image tensor; assumed to be laid out as (channels, height,
            width) and to live on the CPU, since it is converted with
            ``.numpy()``.
    """
    restored = img * .5 + .5
    channels_last = restored.permute(1, 2, 0)
    plt.figure(figsize=(10, 100))
    plt.imshow(channels_last.numpy())
    plt.show()

# Fetch a single batch from the training loader and show it as an image grid
# with 8 images per row.
# The builtin next() is required here: DataLoader iterators implement
# __next__ only, so the Python-2-style `.next()` method raises AttributeError
# on current PyTorch releases.
images, labels = next(iter(trainloader))
imshow(torchvision.utils.make_grid(images, nrow=8))
print(images.size())  # check the batch and image dimensions

In [None]:
# Pick the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'{device} is available.')

In [None]:
# Build a ResNet-18 (torchvision constructor called with default arguments)
# and adapt it to 3-channel inputs with 10 output classes.
resnet = models.resnet18()
# Swap the stem for a 3x3, stride-1, padding-1 convolution producing 64 channels.
resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
num_ftrs = resnet.fc.in_features
# Classifier head: dropout followed by a linear layer with 10 outputs.
# nn.Dropout is used instead of nn.Dropout2d because the head receives
# flattened 2-D (batch, features) activations; feeding a 2-D input to
# Dropout2d is deprecated and already behaves like plain element-wise dropout.
# Dropout layers hold no parameters, so saved state_dicts stay compatible.
resnet.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(num_ftrs, 10))
resnet = resnet.to(device)

In [None]:
# Location where the trained model weights are stored.
PATH = './models/cifar_resnet_disturblabel.pth'

# Print the layer-by-layer structure of the model for inspection.
print(resnet)

In [None]:
class DisturbLabel(torch.nn.Module):
    """Randomly replace class labels with noisy ones (DisturbLabel).

    Every label is resampled from a categorical (Multinoulli) distribution
    that keeps the original class with probability
    ``p_c = 1 - (C - 1) / C * alpha / 100`` and selects each of the other
    ``C - 1`` classes with probability ``p_i = (1 - p_c) / (C - 1)``.

    Args:
        alpha: Noise rate in percent (it is divided by 100 internally).
        num_classes: Number of classes ``C``; labels must lie in ``[0, C)``.
    """

    def __init__(self, alpha, num_classes):
        super().__init__()
        self.alpha = alpha
        self.C = num_classes
        # Probability of keeping the true class (Multinoulli distribution).
        self.p_c = 1 - ((self.C - 1) / self.C) * (alpha / 100)
        # Probability of each individual wrong class. With a single class
        # there is no wrong class, so avoid dividing by zero.
        self.p_i = (1 - self.p_c) / (self.C - 1) if self.C > 1 else 0.0

    def forward(self, y):
        """Return a disturbed copy of the labels ``y``.

        Args:
            y: Tensor of integer class indices, of any shape.

        Returns:
            An int64 tensor with the same shape as ``y``, located on the same
            device as ``y`` (no forced round-trip through the CPU).
        """
        flat_labels = y.long().reshape(-1, 1)
        # One probability row per label: p_i everywhere, p_c at the true class.
        probs = torch.full(
            (flat_labels.size(0), self.C), self.p_i, device=y.device
        )
        probs.scatter_(1, flat_labels, self.p_c)
        probs = probs.view(*y.shape, self.C)
        # Sampling class indices directly is equivalent to drawing a one-hot
        # vector and taking its argmax, and it works for any label shape
        # because the class axis is always the last one.
        return torch.distributions.Categorical(probs=probs).sample()

In [None]:
# Adam over all model parameters with a learning rate of 0.001.
optimizer = optim.Adam(params=resnet.parameters(), lr=0.001)
# Cross-entropy loss between the model outputs and the (disturbed) labels.
criterion = nn.CrossEntropyLoss()
# Label-noise module: alpha of 30 (percent) over 10 classes.
disturblabels = DisturbLabel(30, 10)

In [None]:
import os

loss_ = []  # mean training loss of each epoch, kept for plotting
n = len(trainloader)  # number of batches per epoch

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before spending time on training.
save_dir = os.path.dirname(PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Switch to training mode explicitly so that re-running this cell after an
# evaluation cell (which calls resnet.eval()) still trains with dropout
# active and batch-norm statistics updating.
resnet.train()

for epoch in range(50):

    running_loss = 0.0
    for data in trainloader:

        inputs, labels = data[0].to(device), data[1].to(device)  # one batch
        optimizer.zero_grad()
        outputs = resnet(inputs)  # forward pass
        # Replace part of the labels with random classes (DisturbLabel).
        labels = disturblabels(labels).to(device)
        loss = criterion(outputs, labels)  # compute the loss
        loss.backward()  # back-propagate the loss
        optimizer.step()  # update the weights
        running_loss += loss.item()

    loss_.append(running_loss / n)
    print('[%d] loss: %.3f' %(epoch + 1, running_loss / n))

torch.save(resnet.state_dict(), PATH)
print('Finished Training')

# AlexNet: 0.207 (epoch 50)
# ResNet: 0.083 (epoch 10)

In [None]:
# Plot the recorded per-epoch training loss.
plt.plot(loss_)
plt.xlabel("epoch")
plt.title("Training Loss")
plt.show()

In [None]:
torch.save(resnet.state_dict(), PATH) # save the model's state_dict to PATH

In [None]:
# Rebuild the same architecture that was trained, then restore its weights.
resnet = models.resnet18()
resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
num_ftrs = resnet.fc.in_features
# nn.Dropout replaces nn.Dropout2d: the head works on 2-D (batch, features)
# activations, for which Dropout2d is deprecated. Dropout layers carry no
# parameters, so checkpoints saved with either layer load unchanged.
resnet.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(num_ftrs, 10))
resnet = resnet.to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can also be loaded on a CPU-only runtime.
resnet.load_state_dict(torch.load(PATH, map_location=device))  # load the model parameters

In [None]:
# Measure classification accuracy on the test loader.
resnet.eval()  # evaluation mode for the whole pass
correct, total = 0, 0
with torch.no_grad():  # gradients are not needed for evaluation
    for data in testloader:
        images = data[0].to(device)
        labels = data[1].to(device)
        outputs = resnet(images)
        # The predicted class is the index of the largest output per sample.
        predicted = outputs.max(dim=1).indices
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

print('Test accuracy: %.2f %%' % (100 * correct / total))
