In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision.models import alexnet


In [None]:
# Training hyper-parameters shared by the cells below.
batch_size = 100
num_epochs = 10
learning_rate = 0.001
class_num = 10  # number of classes; the classifier head emits 2 * class_num values
T = 10  # not read by the cells shown here; the training cells reassign T = 1
rr = 0.1  # not read by the cells shown here -- presumably a reject rate; TODO confirm

# Use the GPU when one is present, otherwise fall back to the CPU so that the
# later `.to(device)` calls do not fail on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(777)  # fixed seed for torch's random number generator
feature_extract = True  # later cells read this flag to decide whether to freeze the backbone

In [None]:

# Directory the Fashion-MNIST files are downloaded into.
root = './MNIST_Fashion'

# Resize so the shorter side is 224 px, convert to a tensor, then normalise the
# single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training split: shuffled, and the last incomplete batch is dropped.
train_data = dset.FashionMNIST(root=root, train=True, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

# Test split: fixed order. Every sample is kept (drop_last=False) so the reported
# accuracy covers the whole split even when batch_size does not divide its length.
test_data = dset.FashionMNIST(root=root, train=False, transform=transform, download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./MNIST_Fashion/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))


Extracting ./MNIST_Fashion/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./MNIST_Fashion/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./MNIST_Fashion/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))


Extracting ./MNIST_Fashion/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST_Fashion/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./MNIST_Fashion/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))


Extracting ./MNIST_Fashion/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST_Fashion/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./MNIST_Fashion/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))


Extracting ./MNIST_Fashion/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST_Fashion/FashionMNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:

# AlexNet with pretrained weights and its original 1000-way head.
model = alexnet(pretrained=True, num_classes=1000)

# Freeze every loaded parameter; the two layers swapped in below are newly
# constructed modules, so their parameters remain trainable.
for param in model.parameters():
    param.requires_grad_(False)

# New first convolution taking a single input channel, and a new final layer
# producing 2 * class_num outputs (later split into logits and std values).
model.features[0] = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=11, stride=4, padding=2)
model.classifier[6] = nn.Linear(4096, 2 * class_num)
model.to(device)
print(model)

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-4df8aa71.pth


HBox(children=(FloatProgress(value=0.0, max=244418560.0), HTML(value='')))


AlexNet(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)


In [None]:
# Gather the parameters that still require gradients and list their names.
print("Params to learn:")
params_to_update = []

for name, param in model.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)
        print("\t", name)


# Only the parameters collected above are handed to the optimizer;
# parameters with requires_grad=False are left out.
optimizer = torch.optim.Adam(params_to_update, lr=learning_rate)

Params to learn:
	 features.0.weight
	 features.0.bias
	 classifier.6.weight
	 classifier.6.bias


In [None]:
# Cross-entropy loss applied to the (noised) class logits during training.
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
# Train for num_epochs, recording the mean loss of each epoch in `costs`.
costs = []
total_batch = len(train_loader)
LogSoftmax = torch.nn.LogSoftmax(dim=1)  # not used in this cell
T = 1  # not used in this cell
model.train()  # make sure dropout is active even if an eval cell ran earlier
for epoch in range(num_epochs):
    total_cost = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        # The head emits 2 * class_num values: first half logits, second half std.
        sampling_outputs = model(imgs)
        sampling_logit = sampling_outputs[:, :class_num]
        sampling_std = sampling_outputs[:, class_num:]

        # One standard-normal vector of length class_num is drawn per batch,
        # directly on `device`, and broadcast over all samples in the batch.
        # NOTE(review): per-sample noise (torch.randn_like(sampling_std)) would
        # give a lower-variance estimate -- confirm which one is intended.
        e = torch.randn(class_num, device=device)
        noised_outputs = sampling_logit + sampling_std * e
        loss = criterion(noised_outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, so the running sum does not hold
        # on to autograd history and `costs` contains floats, not device tensors.
        total_cost += loss.item()
    avg_cost = total_cost / total_batch
    print("Epoch:", "%03d" % (epoch+1), "Cost =", "{:.9f}".format(avg_cost))
    costs.append(avg_cost)

Epoch: 001 Cost = 0.787873447
Epoch: 002 Cost = 0.605133832
Epoch: 003 Cost = 0.566681862
Epoch: 004 Cost = 0.533443093
Epoch: 005 Cost = 0.522577643
Epoch: 006 Cost = 0.512051582
Epoch: 007 Cost = 0.514473975
Epoch: 008 Cost = 0.497740448
Epoch: 009 Cost = 0.497214258
Epoch: 010 Cost = 0.481660992


In [None]:
# Evaluate on the test loader using noised logits (logits + std * noise).
model.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for imgs, labels in test_loader:

        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)

        # Noise length follows class_num so it always matches the std slice;
        # one vector is drawn per batch and broadcast over the samples.
        e = torch.randn(class_num, device=device)

        noised_outputs = outputs[:, :class_num] + outputs[:, class_num:] * e
        # argmax() avoids assigning to `_`, which IPython uses for the last output.
        argmax = noised_outputs.argmax(dim=1)

        total += imgs.size(0)
        correct += (labels == argmax).sum().item()


    print('Accuracy for total images {} {:.2f}%'.format(total, correct / (total) * 100))

Accuracy for total images 10000, rejcets images: 0 reject rate : 0.1  85.10% -> 0.00%


In [None]:
# With the cost defined exactly as in the paper, accuracy was 44%

Accuracy for total images 10000, rejcets images: 0 reject rate : 0.1  44.02% -> 0.00%


In [None]:
# No pretraining: AlexNet is created without pretrained weights.
feature_extract = False

model = alexnet(pretrained=False, num_classes=1000)


# Freeze the backbone only in feature-extraction mode (disabled above).
if feature_extract:
    for param in model.parameters():
        param.requires_grad = False

# New single-channel first convolution, and a final layer sized from class_num
# (2 * class_num outputs: logits followed by std values) instead of a literal 20.
model.features[0] = nn.Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
model.classifier[6] = nn.Linear(in_features=4096, out_features=2 * class_num, bias=True)
model.to(device)
print(model)
criterion = torch.nn.CrossEntropyLoss().to(device)


AlexNet(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [None]:
# List every parameter that requires gradients, collecting them along the way.
print("Params to learn:")
trainable_params = []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    trainable_params.append(param)
    print("\t", name)

# In feature-extraction mode the optimizer gets the collected list;
# otherwise it gets the model's full parameter iterator.
params_to_update = trainable_params if feature_extract else model.parameters()
optimizer = torch.optim.Adam(params_to_update, lr=learning_rate)

Params to learn:
	 features.0.weight
	 features.0.bias
	 features.3.weight
	 features.3.bias
	 features.6.weight
	 features.6.bias
	 features.8.weight
	 features.8.bias
	 features.10.weight
	 features.10.bias
	 classifier.1.weight
	 classifier.1.bias
	 classifier.4.weight
	 classifier.4.bias
	 classifier.6.weight
	 classifier.6.bias


In [None]:
# Train for num_epochs, recording the mean loss of each epoch in `costs`.
costs = []
total_batch = len(train_loader)
LogSoftmax = torch.nn.LogSoftmax(dim=1)  # not used in this cell
T = 1  # not used in this cell
model.train()  # make sure dropout is active even if an eval cell ran earlier
for epoch in range(num_epochs):
    total_cost = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        # The head emits 2 * class_num values: first half logits, second half std.
        sampling_outputs = model(imgs)
        sampling_logit = sampling_outputs[:, :class_num]
        sampling_std = sampling_outputs[:, class_num:]

        # One standard-normal vector of length class_num is drawn per batch,
        # directly on `device`, and broadcast over all samples in the batch.
        # NOTE(review): per-sample noise (torch.randn_like(sampling_std)) would
        # give a lower-variance estimate -- confirm which one is intended.
        e = torch.randn(class_num, device=device)
        noised_outputs = sampling_logit + sampling_std * e
        loss = criterion(noised_outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, so the running sum does not hold
        # on to autograd history and `costs` contains floats, not device tensors.
        total_cost += loss.item()
    avg_cost = total_cost / total_batch
    print("Epoch:", "%03d" % (epoch+1), "Cost =", "{:.9f}".format(avg_cost))
    costs.append(avg_cost)

Epoch: 001 Cost = 0.723779082
Epoch: 002 Cost = 0.369990408
Epoch: 003 Cost = 0.319598585
Epoch: 004 Cost = 0.293592244
Epoch: 005 Cost = 0.273541629
Epoch: 006 Cost = 0.267542511
Epoch: 007 Cost = 0.254703373
Epoch: 008 Cost = 0.240575030
Epoch: 009 Cost = 0.229396299
Epoch: 010 Cost = 0.224156931


In [None]:
# Evaluate on the test loader using noised logits (logits + std * noise).
model.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for imgs, labels in test_loader:

        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)

        # Noise length follows class_num so it always matches the std slice;
        # one vector is drawn per batch and broadcast over the samples.
        e = torch.randn(class_num, device=device)

        noised_outputs = outputs[:, :class_num] + outputs[:, class_num:] * e
        # argmax() avoids assigning to `_`, which IPython uses for the last output.
        argmax = noised_outputs.argmax(dim=1)

        total += imgs.size(0)
        correct += (labels == argmax).sum().item()


    print('Accuracy for total {} images {:.2f}% '.format(total,  correct / (total) * 100))

Accuracy for total 10000 images 91.22% 


In [None]:
# Remove data uncertainty: predict from the logits alone, with no noise term.
model.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for imgs, labels in test_loader:

        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)

        # Only the first class_num outputs (the logits) are used; the remaining
        # std outputs are ignored, so no random noise enters the prediction.
        argmax = outputs[:, :class_num].argmax(dim=1)

        total += imgs.size(0)
        correct += (labels == argmax).sum().item()


    print('Accuracy for total {} images {:.2f}% '.format(total,  correct / (total) * 100))