In [2]:
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision.models import alexnet


In [3]:
# Hyper-parameters shared by the cells below.
batch_size = 100
num_epochs = 10
learning_rate = 0.001
# Number of target classes; the model head emits 2 * class_num values
# (class_num logits followed by class_num per-class std terms).
class_num = 10

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs end-to-end on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Fixed seed so that weight initialisation and noise sampling are repeatable.
torch.manual_seed(777)
# When True, the model cell freezes the backbone parameters before the
# replaced layers are attached (the model cell may override this flag).
feature_extract = True

In [4]:
# Directory where the Fashion-MNIST files are stored (downloaded on first use).
root = './MNIST_Fashion'

# Resize every image to 224 px, convert it to a tensor, then normalise the
# single channel with mean 0.5 / std 0.5.
transform = transforms.Compose([transforms.Resize(224),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5,), std=(0.5,))])

# Training split: shuffled, and the last incomplete batch is dropped so every
# optimisation step sees exactly `batch_size` samples.
train_data = dset.FashionMNIST(root=root, train=True, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

# Test split: keep every sample (drop_last=False) so the reported accuracy is
# computed over the whole test set even when batch_size does not divide it.
test_data = dset.FashionMNIST(root=root, train=False, transform=transform, download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

In [5]:
# Cross-entropy classification loss, moved onto `device`.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [6]:
# PRETRAIN X: AlexNet with randomly initialised weights (no pretrained weights).
feature_extract = False

# No `pretrained=` keyword: its default already means "no pretrained weights",
# and passing it explicitly is deprecated in recent torchvision releases.
model = alexnet(num_classes=1000)

# Optionally freeze the backbone; layers replaced below are created afterwards
# and therefore stay trainable.
if feature_extract:
    for param in model.parameters():
        param.requires_grad = False

# First convolution rebuilt with a single input channel (grayscale input).
model.features[0] = nn.Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))

# Final layer emits 2 * class_num values: the first class_num are used as
# logits and the remaining class_num as per-class std terms by later cells.
head_in_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(in_features=head_in_features, out_features=2 * class_num, bias=True)
model.to(device)
print(model)
criterion = torch.nn.CrossEntropyLoss().to(device)




AlexNet(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [7]:
# Gather every parameter that still requires gradients, in model order.
trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]

print("Params to learn:")
for name, _ in trainable:
    print("\t",name)

# In feature-extraction mode only the trainable parameters are handed to the
# optimizer; otherwise the model's full parameter iterator is passed.
if feature_extract:
    params_to_update = [param for _, param in trainable]
else:
    params_to_update = model.parameters()

optimizer = torch.optim.Adam(params_to_update, lr=learning_rate)

Params to learn:
	 features.0.weight
	 features.0.bias
	 features.3.weight
	 features.3.bias
	 features.6.weight
	 features.6.bias
	 features.8.weight
	 features.8.bias
	 features.10.weight
	 features.10.bias
	 classifier.1.weight
	 classifier.1.bias
	 classifier.4.weight
	 classifier.4.bias
	 classifier.6.weight
	 classifier.6.bias


In [8]:
# Train the model with noise-perturbed logits and record the mean loss per epoch.
costs = []
total_batch = len(train_loader)
LogSoftmax = torch.nn.LogSoftmax(dim=1)
num_epochs = 3

# Mean of the standard-normal noise; one value per class, built once.
noise_mean = torch.zeros(class_num)

# Make the training mode explicit in case an earlier cell changed it.
model.train()

for epoch in range(num_epochs):
    total_cost = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        # The head output is split into logits (first class_num columns) and
        # per-class std terms (remaining columns).
        sampling_outputs = model(imgs)
        sampling_logit = sampling_outputs[:, :class_num]
        sampling_std = sampling_outputs[:, class_num:]

        # One N(0, 1) draw per class, broadcast over the batch.
        # NOTE(review): the same noise vector is applied to every sample in the
        # batch; confirm whether per-sample noise was intended.
        e = torch.normal(noise_mean, 1).to(device)
        noised_outputs = sampling_logit + sampling_std * e
        loss = criterion(noised_outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a plain Python float so that no tensor (and no autograd
        # history) is carried from one iteration to the next.
        total_cost += loss.item()
    avg_cost = total_cost / total_batch
    print("Epoch:", "%03d" % (epoch+1), "Cost =", "{:.9f}".format(avg_cost))
    costs.append(avg_cost)

Epoch: 001 Cost = 0.731570721
Epoch: 002 Cost = 0.387100071
Epoch: 003 Cost = 0.352068782


In [12]:
# Evaluate on the test set with T stochastic forward passes per batch, rank the
# samples by an uncertainty score, reject the `rr` fraction with the highest
# score in each batch, and compare accuracy before and after rejection.

# The model is put in train mode although this is evaluation.
# NOTE(review): presumably deliberate, so that dropout stays active and the T
# forward passes differ (MC-dropout style) - confirm.
model.train()
T = 2       # number of stochastic forward passes per batch
rr = 0.1    # fraction of each batch to reject
sm = torch.nn.Softmax(dim = 1)

# Mean of the standard-normal noise; one value per class, built once.
noise_mean = torch.zeros(class_num)

with torch.no_grad():
    correct = 0
    safe_correct = 0
    total = 0
    reject = 0
    for imgs, labels in test_loader:

        imgs, labels = imgs.to(device), labels.to(device)
        batch_len = len(imgs)

        ## Sampling
        sampling_out = torch.zeros([T, batch_len, class_num], device=device)
        sampling_data_uncertainity = torch.zeros([T, batch_len, class_num], device=device)

        for t in range(T):
            outputs = model(imgs)
            e = torch.normal(noise_mean, 1).to(device)
            sampling_out[t] = outputs[:, :class_num]
            # NOTE(review): this term is the predicted std multiplied by signed
            # N(0, 1) noise, so it can be negative and averages towards zero;
            # confirm that this is the intended data-uncertainty measure.
            sampling_data_uncertainity[t] = outputs[:, class_num:] * e

        # Per-batch statistics over the T passes, each of shape (batch, class_num):
        # mean logits, mean data-uncertainty term, std of logits, softmax probs.
        outputs = torch.mean(sampling_out, dim = 0)
        data_uncertainity = torch.mean(sampling_data_uncertainity, dim = 0)
        outputs_std = torch.std(sampling_out, dim = 0)
        outputs_prob = sm(outputs)

        # Pick the arg-max of the mean logits and the std / probability /
        # data-uncertainty value that belongs to that predicted class.
        _, argmax = torch.max(outputs, 1)
        rows = torch.arange(batch_len, device=device)
        max_std = outputs_std[rows, argmax]
        max_prob = outputs_prob[rows, argmax]
        max_data_uncertainity = data_uncertainity[rows, argmax]

        # Rescaled score: model uncertainty (std / prob) + data uncertainty.
        # Computed for the whole batch at once instead of element by element.
        uncertainty = max_std / max_prob + max_data_uncertainity

        # Sort the rescaled score in ascending order to obtain the ranking.
        _, index = uncertainty.sort(dim = 0)

        # Outputs and labels re-ordered from lowest to highest uncertainty.
        sorted_outputs = outputs[index]
        sorted_labels = labels[index]

        # Cut off the most uncertain results, keeping the first (1 - rr) fraction.
        keep_count = int(batch_len * (1 - rr))
        safe_outputs = sorted_outputs[:keep_count]
        safe_labels = sorted_labels[:keep_count]
        _, safe_argmax = torch.max(safe_outputs, 1)

        ## end sampling
        total += imgs.size(0)
        reject += imgs.size(0) - safe_outputs.size(0)

        safe_correct += (safe_labels == safe_argmax).sum().item()
        correct += (labels == argmax).sum().item()

    print('Accuracy for total images {}, rejcets images: {} reject rate : {}  {:.2f}% -> {:.2f}%'.format(total, reject,  rr, correct / (total) * 100, safe_correct / (total - reject) * 100))

Accuracy for total images 10000, rejcets images: 1000 reject rate : 0.1  87.60% -> 86.60%
