In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    '''expand + depthwise + pointwise + squeeze-excitation'''

    def __init__(self, in_planes, out_planes, expansion, stride):
        super(Block, self).__init__()
        self.stride = stride

        planes = expansion * in_planes
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, groups=planes, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        # SE layers
        self.fc1 = nn.Conv2d(out_planes, out_planes//16, kernel_size=1)
        self.fc2 = nn.Conv2d(out_planes//16, out_planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        shortcut = self.shortcut(x) if self.stride == 1 else out
        # Squeeze-Excitation
        w = F.avg_pool2d(out, out.size(2))
        w = F.relu(self.fc1(w))
        w = self.fc2(w).sigmoid()
        out = out * w + shortcut
        return out


class EfficientNet(nn.Module):
    def __init__(self, cfg, num_classes=10):
        super(EfficientNet, self).__init__()
        self.cfg = cfg
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.linear = nn.Linear(cfg[-1][1], num_classes)

    def _make_layers(self, in_planes):
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            strides = [stride] + [1]*(num_blocks-1)
            for stride in strides:
                layers.append(Block(in_planes, out_planes, expansion, stride))
                in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def EfficientNetB0():
    # (expansion, out_planes, num_blocks, stride)
    cfg = [(1,  16, 1, 2),
           (6,  24, 2, 1),
           (6,  40, 2, 2),
           (6,  80, 3, 2),
           (6, 112, 3, 1),
           (6, 192, 4, 2),
           (6, 320, 1, 2)]
    return EfficientNet(cfg)


def test():
    net = EfficientNetB0()
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    print(y.shape)

In [0]:
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            if m.bias:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias:
                init.constant(m.bias, 0)


# _, term_width = os.popen('stty size', 'r').read().split()
term_width = 80
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f

In [0]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import numpy as np
from torch.utils.data import Dataset, DataLoader

# from models import *
# from utils import progress_bar


# parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
# parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
# parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
# args = parser.parse_args()

lr = 0.01

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)


# trainset.targets***************************************************************
import random
length = len(trainset.targets)
percentage_corruption = 10
n = length*percentage_corruption/100
corrupt_classes = np.random.randint(0,10,int(n))

corrupt_idx = random.sample(range(0,length), int(n))
a = np.array(trainset.targets)
a[corrupt_idx] = corrupt_classes
trainset.targets = list(a)

class CustomizeDataset(Dataset):
    """MosaicDataset dataset."""

    def __init__(self, dataset, corrupt):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.data = dataset.data
        self.data= np.rollaxis(self.data, 3,1)
        self.data =  np.float32(self.data)
#         print(self.data.dtype)
        #print(self.data.shape)
        self.label = dataset.targets
        self.corrupt_idx = corrupt
        self.indicator = torch.zeros(len(self.label)).byte()
        k = len(self.label)
        for i in range(k):
          if i in self.corrupt_idx :
            self.indicator[i]=1

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
#       print(self.data[idx].shape)
      return self.data[idx] , self.label[idx], self.indicator[idx]


train_data = CustomizeDataset(trainset,corrupt_idx)
trainloader = DataLoader( train_data, batch_size= 128,shuffle= False, num_workers=2)


#**********************************************************************************

testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
# net = VGG('VGG19')
# net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
# net = ShuffleNetV2(1)
net = EfficientNetB0()
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# if args.resume:
#     # Load checkpoint.
#     print('==> Resuming from checkpoint..')
#     assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
#     checkpoint = torch.load('./checkpoint/ckpt.pth')
#     net.load_state_dict(checkpoint['net'])
#     best_acc = checkpoint['acc']
#     start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr, momentum=0.9, weight_decay=5e-4)
# optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

train_acc =[]
test_acc=[]
epoch_list=[]
# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    
    train_loss = 0
    tl=0
    train_corrupt_loss = 0
    tcl=0
    train_true_loss = 0
    ttl=0
    
    correct = 0
    true_correct=0
    corrupt_correct=0
    
    total = 0
    corrupt_total = 0
    true_total=0
    
    for batch_idx, (inputs, targets, indicator ) in enumerate(trainloader):
#         batch_corrupt_loss , batch_true_loss, batch_total_loss = 0,0,0
        inputs, targets, indicator = inputs.to(device), targets.to(device), indicator.to(device)
#         print(inputs,targets.size(),indicator.size())
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        batch_total_loss = loss.item()
        tl += loss.item()
        train_loss += batch_total_loss*len(indicator)
        
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
#         print(indicator.size)
#         print(sum(~indicator),sum(indicator))
#         ***************************************************************
        corrupt_out = (outputs[indicator])
        loss_corrupt = criterion(corrupt_out, targets[indicator])
        batch_corrupt_loss = loss_corrupt.item()
        tcl += loss_corrupt.item()
        train_corrupt_loss += batch_corrupt_loss*(sum(indicator).item())
       
        _, predict_corrupt = corrupt_out.max(1)
        corrupt_total += targets[indicator].size(0)
        corrupt_correct += predict_corrupt.eq(targets[indicator]).sum().item()
        
        true_out = (outputs[~indicator])
        loss_true = criterion(true_out, targets[~indicator])
        batch_true_loss = loss_true.item()
        ttl += loss_true.item()
        train_true_loss += batch_true_loss*(sum(~indicator).item())
        
        _, predict_true = true_out.max(1)
        true_total += targets[~indicator].size(0)
        true_correct += predict_true.eq(targets[~indicator]).sum().item()
        
#         ****************************************************************************
        loss.backward()
        optimizer.step()
        

#         progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#             % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    print("corrupt_loss", train_corrupt_loss, tcl)
    print("true loss", train_true_loss, ttl)
    print("Train loss of full data", train_loss, tl)
    print("total_corrupt, total_true and total_data respectively are :  ",corrupt_total, true_total, total)
    print("Train accuracy on corrupted train-data", corrupt_correct/corrupt_total )
    print("Train accuracy on un-corrupted train-data", true_correct/true_total)
    print("Train accuracy on full train-data ", correct/total )
    train_acc.append(train_loss)

def test(epoch):
    global best_acc
#     net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

#             progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#                 % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    print("test accuracy ", correct/total , test_loss)
    test_acc.append(test_loss)

    # Save checkpoint.
#     acc = 100.*correct/total
#     if acc > best_acc:
#         print('Saving..')
#         state = {
#             'net': net.state_dict(),
#             'acc': acc,
#             'epoch': epoch,
#         }
#         if not os.path.isdir('checkpoint'):
#             os.mkdir('checkpoint')
#         torch.save(state, './checkpoint/ckpt.pth')
#         best_acc = acc


for epoch in range(start_epoch, start_epoch+170):
    epoch_list.append(epoch)
    train(epoch)
    test(epoch)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..

Epoch: 0
corrupt_loss 15464.679995417595 1209.4400079250336
true loss 71016.23705005646 617.195858836174
Train loss of full data 86480.91597366333 676.1586798429489
total_corrupt, total_true and total_data respectively are :   5000 45000 50000
Train accuracy on corrupted train-data 0.0926
Train accuracy on un-corrupted train-data 0.43233333333333335
Train accuracy on full train-data  0.39836
test accuracy  0.5088 137.9361857175827

Epoch: 1
corrupt_loss 16828.213594198227 1314.8855493068695
true loss 55244.09697389603 479.9511198401451
Train loss of full data 72072.31063652039 563.5409697294235
total_corrupt, total_true and total_data respectively are :   5000 45000 50000
Train accuracy on corrupted train-data 0.099
Train accuracy on un-corrupted train-data 0.5731555555555555
Train accuracy on full train-data  0.52574
test accuracy  0.5648 124.18820321559906

Epoch: 2


In [0]:
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
import random
length = len(testset.targets)
percentage_corruption = 10
n = length*percentage_corruption/100
corrupt_classes = np.random.randint(0,10,int(n))
# print(np.unique(corrupt_idx))
corrupt_idx = random.sample(range(0,length), int(n))
print(len(corrupt_idx) , len(np.unique(corrupt_idx)))
print(corrupt_idx)
print(len(corrupt_classes))

st1 = testset.targets
a = np.array(testset.targets)
a[corrupt_idx] = corrupt_classes
testset.targets = list(a)
st2 = testset.targets

print(st1 == st2)

Files already downloaded and verified
1000 1000
[7915, 7136, 319, 1036, 1797, 59, 9825, 1139, 2767, 4998, 5790, 1414, 9622, 3574, 5439, 3966, 5412, 2004, 3942, 6438, 2096, 7689, 1894, 9979, 3000, 3311, 2138, 4536, 9586, 290, 8778, 5082, 6629, 3768, 4443, 4588, 2584, 4821, 4960, 6491, 6281, 2736, 6960, 9690, 4656, 41, 3940, 9153, 9638, 9738, 9592, 3611, 5921, 2151, 6035, 3809, 2349, 5617, 6506, 9929, 6631, 5650, 9635, 8885, 763, 1546, 194, 2311, 4790, 6071, 5524, 8468, 6673, 5677, 9989, 8150, 9056, 2418, 6766, 9291, 4279, 5033, 4874, 5725, 6607, 4779, 2623, 8924, 4789, 8013, 207, 7740, 245, 5058, 7165, 2433, 9476, 5671, 7632, 5759, 8819, 651, 7036, 8176, 7188, 5899, 1913, 165, 1807, 6294, 8618, 7434, 1308, 5239, 4304, 2575, 4130, 1438, 5607, 5297, 5624, 2121, 3688, 4556, 125, 2535, 1355, 3184, 3569, 7241, 1826, 3133, 5780, 7134, 8654, 5900, 7399, 910, 1609, 8031, 7725, 4776, 5913, 7279, 9843, 20, 2528, 1970, 8586, 6091, 8062, 63, 2974, 9448, 5765, 2071, 6705, 3248, 8668, 7156, 5293, 492

In [0]:
test_data = CustomizeDataset(testset,corrupt_idx)
testloader = DataLoader( test_data, batch_size= 100,shuffle= False, num_workers=2)

# testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
global best_acc

test_loss = 0
tl=0
test_corrupt_loss = 0
tcl=0
test_true_loss = 0
ttl=0

correct = 0
true_correct=0
corrupt_correct=0

total = 0
corrupt_total = 0
true_total=0
with torch.no_grad():
    for batch_idx, (inputs, targets, indicator) in enumerate(testloader):
        inputs, targets, indicator = inputs.to(device), targets.to(device), indicator.to(device)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        batch_total_loss = loss.item()
        tl += batch_total_loss
        test_loss += batch_total_loss*len(indicator)

#         test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
#         ***************************************************************************
        corrupt_out = (outputs[indicator])
        loss_corrupt = criterion(corrupt_out, targets[indicator])
        batch_corrupt_loss = loss_corrupt.item()
        tcl += loss_corrupt.item()
        test_corrupt_loss += batch_corrupt_loss*(sum(indicator).item())
       
        _, predict_corrupt = corrupt_out.max(1)
        corrupt_total += targets[indicator].size(0)
        corrupt_correct += predict_corrupt.eq(targets[indicator]).sum().item()
        
        true_out = (outputs[~indicator])
        loss_true = criterion(true_out, targets[~indicator])
        batch_true_loss = loss_true.item()
        ttl += loss_true.item()
        test_true_loss += batch_true_loss*(sum(~indicator).item())
        
        _, predict_true = true_out.max(1)
        true_total += targets[~indicator].size(0)
        true_correct += predict_true.eq(targets[~indicator]).sum().item()
#         *****************************************************************************

#             progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#                 % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

print("corrupt_loss", test_corrupt_loss, tcl)
print("true loss", test_true_loss, ttl)
print("Test loss of full data", test_loss, tl)
print("total_corrupt, total_true and total_data respectively are :  ",corrupt_total, true_total, total)
print("Test accuracy on corrupted train-data", corrupt_correct/corrupt_total )
print("Test accuracy on un-corrupted train-data", true_correct/true_total)
print("test accuracy on full test-data", correct/total)
print(correct, total)


corrupt_loss 7536.210190296173 759.8516936302185
true loss 10285.424175202847 114.29510074853897
Test loss of full data 17821.634590625763 178.21634590625763
total_corrupt, total_true and total_data respectively are :   1000 9000 10000
Test accuracy on corrupted train-data 0.116
Test accuracy on un-corrupted train-data 0.7224444444444444
test accuracy on full test-data 0.6618
6618 10000
