In [None]:
import os
import time
import sys

# If you are not using Kaggle Notebook, please import the necessary files in the proper way
sys.path.append('/kaggle/input/resnetgray/')
sys.path.append('/kaggle/input/newretransform/')

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# This line imports the existing ResNets architectures, train functions, and custom 1-channel ResNets
import resnetgray

# This line imports the extended RE transformation function
from new_retransform import RandomErasingTransform

# List every file available under the Kaggle input mount so the attached
# datasets/modules can be checked at a glance.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# Create the checkpoint directory. `exist_ok=True` makes the cell safe to re-run
# and avoids the check-then-create race of `os.path.exists` + `os.mkdir`;
# `makedirs` also creates missing parent directories.
os.makedirs('/kaggle/working/save_temp', exist_ok=True)

# Prefer the first CUDA device and fall back to the CPU when none is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Idelbayev, Y. Proper ResNet implementation for CIFAR10/CIFAR100 in PyTorch. https://github.com/akamaster/pytorch_resnet_cifar10
# Slightly modified to read the model names from the extended 1-channel ResNet architectures module

# Collect, in alphabetical order, every lowercase callable exported by
# `resnetgray` whose name starts with "resnet" (dunder names are excluded).
model_names = sorted(
    candidate
    for candidate, exported in vars(resnetgray).items()
    if candidate.islower()
    and candidate.startswith("resnet")
    and not candidate.startswith("__")
    and callable(exported)
)

print(model_names)

In [None]:
# Idelbayev, Y. Proper ResNet implementation for CIFAR10/CIFAR100 in PyTorch. https://github.com/akamaster/pytorch_resnet_cifar10

# Best top-1 validation precision seen so far; updated by the training driver.
best_prec1 = 0


class Args:
    """Plain attribute bag; the training driver fills it with its hyper-parameters."""


# Single shared instance read by `train` / `validate` and written by the driver.
args = Args()

In [None]:
# Idelbayev, Y. Proper ResNet implementation for CIFAR10/CIFAR100 in PyTorch. https://github.com/akamaster/pytorch_resnet_cifar10


def train(train_loader, model, criterion, optimizer, epoch):
    """
    Run one train epoch.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches.
        model: network to optimise; it is switched to train mode here.
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer whose ``zero_grad``/``step`` run once per batch.
        epoch: epoch index, used only in the progress line.

    Reads ``half`` and ``print_freq`` from the module-level ``args`` object.
    Every batch is moved to the GPU with ``.cuda()``, so a CUDA device is
    required. Returns ``None``; progress is reported through ``print``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    # `images` (not `input`) so the builtin `input` is not shadowed.
    for i, (images, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda()
        images = images.cuda()
        if args.half:
            images = images.half()

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Cast back to float32 so the metrics are computed in full precision
        # even when the forward pass ran in half precision.
        output = output.float()
        loss = loss.float()
        # measure accuracy and record loss; `detach()` keeps the metric
        # computation out of the autograd graph.
        prec1 = accuracy(output.detach(), target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))

In [None]:
# Idelbayev, Y. Proper ResNet implementation for CIFAR10/CIFAR100 in PyTorch. https://github.com/akamaster/pytorch_resnet_cifar10

def validate(val_loader, model, criterion):
    """
    Run evaluation.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss callable invoked as ``criterion(output, target)``.

    Returns:
        The sample-weighted average top-1 precision (in percent) over the
        whole loader.

    Reads ``half`` and ``print_freq`` from the module-level ``args`` object.
    Every batch is moved to the GPU with ``.cuda()``, so a CUDA device is
    required.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        # `images` (not `input`) so the builtin `input` is not shadowed.
        for i, (images, target) in enumerate(val_loader):
            # A single transfer is enough; the target is used for both the
            # loss and the accuracy.
            target = target.cuda()
            images = images.cuda()

            if args.half:
                images = images.half()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # Cast back to float32 so the metrics are computed in full precision.
            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss (no graph exists under no_grad)
            prec1 = accuracy(output, target)[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'
          .format(top1=top1))

    return top1.avg

In [None]:
# Idelbayev, Y. Proper ResNet implementation for CIFAR10/CIFAR100 in PyTorch. https://github.com/akamaster/pytorch_resnet_cifar10

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Serialise ``state`` to ``filename`` with ``torch.save``.

    ``is_best`` is accepted so existing call sites keep working, but it is
    not used: the same file is written regardless of its value.
    """
    torch.save(obj=state, f=filename)

In [None]:
# Idelbayev, Y. Proper ResNet implementation for CIFAR10/CIFAR100 in PyTorch. https://github.com/akamaster/pytorch_resnet_cifar10

class AverageMeter:
    """Keeps the most recent value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as seen ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
# Idelbayev, Y. Proper ResNet implementation for CIFAR10/CIFAR100 in PyTorch. https://github.com/akamaster/pytorch_resnet_cifar10

def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor; dimension 0 indexes samples and dimension 1
            is the one ``topk`` is taken over.
        target: 1-D tensor of ground-truth class indices, one per sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one 0-dim tensor per requested ``k``, each holding the
        percentage of samples whose target is among the ``k`` highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Pure metric computation: no gradients are needed.
    with torch.no_grad():
        _, pred = output.topk(maxk, 1, True, True)
        # Transpose to (maxk, batch) so row r holds every sample's r-th guess.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `reshape` instead of `view`: the transposed tensor is not
            # contiguous, so `view(-1)` raises for k > 1 on current PyTorch.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [None]:
# Idelbayev, Y. Proper ResNet implementation for CIFAR10/CIFAR100 in PyTorch. https://github.com/akamaster/pytorch_resnet_cifar10
# Slightly modified to train on the Fashion-MNIST dataset
# Please modify the transforms of the dataloader to train models with different data augmentations

def main_advanced_re_mnist(arch='resnet20gray', workers=4, epochs=200, start_epoch=0, batch_size=128, lr=0.1,
         momentum=0.9, weight_decay=1e-4, print_freq=50, resume='', evaluate=False,
         pretrained=False, half=False, save_dir='./save_temp', save_every=10,
         re_transform=None, model_tag='baseline'):
    """
    Train (or only evaluate) a `resnetgray` model on Fashion-MNIST.

    Args:
        arch: key looked up in ``resnetgray.__dict__`` to build the model.
        workers: number of DataLoader worker processes.
        epochs: total number of epochs; the learning rate is multiplied by 0.1
            at 50% and at 75% of this value.
        start_epoch: first epoch index (overwritten by the checkpoint's
            ``epoch`` when ``resume`` points to an existing file).
        batch_size: training batch size (validation always uses 128).
        lr, momentum, weight_decay: SGD hyper-parameters.
        print_freq: a progress line is printed every ``print_freq`` batches.
        resume: path of a checkpoint to resume from; empty string disables it.
        evaluate: when true, run one validation pass and return.
        pretrained: stored on ``args`` only; not used by this function.
        half: run model, loss and inputs in half precision.
        save_dir: directory receiving the checkpoint files (created if missing).
        save_every: a resumable checkpoint is written every ``save_every`` epochs.
        re_transform: optional erasing transform inserted between ``ToTensor``
            and ``Normalize`` in the training pipeline. ``None`` (default)
            trains the Baseline model.
        model_tag: suffix used in the output file names
            (``checkpoint_<tag>.th`` and ``model_re_mnist_<tag>.th``).

    Side effects: fills the module-level ``args`` object, updates the
    module-level ``best_prec1``, downloads Fashion-MNIST into ``./data`` and
    writes checkpoint files. Requires a CUDA device (``.cuda()`` is used).
    """
    global best_prec1

    # Start each run from a clean slate; otherwise a second call in the same
    # kernel would inherit the best precision of the previous run. A resumed
    # checkpoint overrides this value below.
    best_prec1 = 0

    # `args` is the module-level attribute bag read by `train` / `validate`.
    args.arch = arch
    args.workers = workers
    args.epochs = epochs
    args.start_epoch = start_epoch
    args.batch_size = batch_size
    args.lr = lr
    args.momentum = momentum
    args.weight_decay = weight_decay
    args.print_freq = print_freq
    args.resume = resume
    args.evaluate = evaluate
    args.pretrained = pretrained
    args.half = half
    args.save_dir = save_dir
    args.save_every = save_every

    # Make sure the output directory exists before the first checkpoint write.
    os.makedirs(args.save_dir, exist_ok=True)

    model = torch.nn.DataParallel(resnetgray.__dict__[args.arch]())
    model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the checkpoint path (the original printed `args.evaluate`).
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.5], std=[0.5])

    # Training pipeline: geometric augmentation -> tensor -> optional erasing -> normalise.
    #
    # Pass one of the following as `re_transform` (and a matching `model_tag`)
    # to train the corresponding model; leave it as None for the Baseline model.
    #
    #   RE:        transforms.RandomErasing()
    #   RE cs:     RandomErasingTransform(probability=0.5, value=0, sl=0.02, sh=0.33, shape="circle", blend_edges=False, blend_type="random", blend_factor=0)
    #   RE rs:     RandomErasingTransform(probability=0.5, value=0, sl=0.02, sh=0.33, shape="random", blend_edges=False, blend_type="random", blend_factor=0)
    #   RE be:     RandomErasingTransform(probability=0.5, value=0, sl=0.02, sh=0.33, shape="rectangle", blend_edges=True, blend_type="random", blend_factor=0)
    #   RE cs-be:  RandomErasingTransform(probability=0.5, value=0, sl=0.02, sh=0.33, shape="circle", blend_edges=True, blend_type="random", blend_factor=0)
    #   RE rs-be:  RandomErasingTransform(probability=0.5, value=0, sl=0.02, sh=0.33, shape="random", blend_edges=True, blend_type="random", blend_factor=0)
    train_transforms = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(28, 4),
        transforms.ToTensor(),
    ]
    if re_transform is not None:
        train_transforms.append(re_transform)
    train_transforms.append(normalize)

    train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(root='./data', train=True,
                              transform=transforms.Compose(train_transforms), download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=128, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Build the scheduler from scratch and fast-forward it to `start_epoch`.
    # Passing `last_epoch=start_epoch - 1` to a freshly created optimizer raises
    # KeyError('initial_lr') for any start_epoch > 0, because the optimizer state
    # is not part of the checkpoint. Stepping manually reproduces the schedule of
    # an uninterrupted run. PyTorch may warn once that `lr_scheduler.step()` was
    # called before `optimizer.step()`; that is expected here.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[int(epochs*0.5), int(epochs*0.75)], gamma=0.1)
    for _ in range(args.start_epoch):
        lr_scheduler.step()

    if args.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
        # then switch back. In this setup it will correspond for first epoch.
        # NOTE(review): nothing in this function restores the learning rate
        # afterwards, and the override also applies when resuming - confirm intent.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr*0.1


    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # Resumable checkpoint (includes the next epoch index), written periodically.
        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, 'checkpoint_{}.th'.format(model_tag)))

        # Latest weights, overwritten after every epoch.
        save_checkpoint({
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=os.path.join(args.save_dir, 'model_re_mnist_{}.th'.format(model_tag)))

In [None]:
# Start training: build the 'resnet20gray' architecture and train it for 300 epochs
# using 2 data-loading workers. Progress lines are printed for batches whose index
# is a multiple of print_freq (468); all other settings keep the function defaults.
main_advanced_re_mnist(arch='resnet20gray', workers=2, epochs=300, print_freq=468)

In [None]:
# Baseline Fashion-MNIST Testloader

# Evaluation pipeline without any erasing: tensor conversion, then normalisation.
normalize = transforms.Normalize(mean=[0.5], std=[0.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Fashion-MNIST test split, downloaded into ./data when it is not present yet.
testset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

# Fixed order, 4 images per batch, 2 worker processes.
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=4,
    num_workers=2,
    shuffle=False,
)

In [None]:
# Prepare testloaders
#
# Every RE test loader shares the same recipe: ToTensor -> erasing transform ->
# Normalize, applied to the Fashion-MNIST test split, 4 images per batch, fixed
# order, 2 workers. Only the erasing transform differs, so the recipe lives in
# one helper and each variant is a single call.

normalize = transforms.Normalize(mean=[0.5], std=[0.5])


def _make_re_testloader(erasing):
    """Return ``(transform, dataset, loader)`` for the test split with ``erasing`` applied after ``ToTensor``."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        erasing,
        normalize
    ])
    dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)
    return transform, dataset, loader


def _extended_re(shape, blend_edges):
    """Build the extended Random Erasing transform with the settings shared by all variants."""
    return RandomErasingTransform(probability=0.5, value=0, sl=0.02, sh=0.33, shape=shape,
                                  blend_edges=blend_edges, blend_type="random", blend_factor=0)


# RE Fashion-MNIST Testloader |standard|
transform_re, testset_re, testloader_re = _make_re_testloader(transforms.RandomErasing())

# RE Fashion-MNIST Testloader |circular shape|
transform_re_cir, testset_re_cir, testloader_re_cir = _make_re_testloader(_extended_re("circle", False))

# RE Fashion-MNIST Testloader |random shape|
transform_re_rs, testset_re_rs, testloader_re_rs = _make_re_testloader(_extended_re("random", False))

# RE Fashion-MNIST Testloader |blurred edge|
transform_re_be, testset_re_be, testloader_re_be = _make_re_testloader(_extended_re("rectangle", True))

# RE Fashion-MNIST Testloader |circular shape & blurred edge|
transform_re_cir_be, testset_re_cir_be, testloader_re_cir_be = _make_re_testloader(_extended_re("circle", True))

# RE Fashion-MNIST Testloader |random shape & blurred edge|
# Uses shape="random": the previous copy passed shape="circle", which made this
# loader a duplicate of the |circular shape & blurred edge| one.
transform_re_rs_be, testset_re_rs_be, testloader_re_rs_be = _make_re_testloader(_extended_re("random", True))

In [None]:
# Load and prepare the model for testing

# NOTE(review): the training cell builds the network with the architecture's
# default arguments, while this cell passes option='B'. If the default is not
# 'B' the parameter names may differ from the checkpoint - confirm.
model_re_mnist = resnetgray.resnet20gray(option='B')
print(model_re_mnist.conv1.weight.shape)

# Load the model you want to test

# checkpoint_re_mnist_path = '/kaggle/input/baseline-300/baseline_300.th'
# checkpoint_re_mnist_path = '/kaggle/input/re-300/re_300.th'
# checkpoint_re_mnist_path = '/kaggle/input/re-cs-300/re_cs_300.th'
# checkpoint_re_mnist_path = '/kaggle/input/re-rs-300/re_rs_300.th'
# checkpoint_re_mnist_path = '/kaggle/input/re-be-300/re_be_300.th'
# checkpoint_re_mnist_path = '/kaggle/input/re-cs-be-300/re_cs_be_300.th'
checkpoint_re_mnist_path = '/kaggle/input/re-rs-be-300/re_rs_be_300.th'


# `map_location=device` lets a checkpoint saved from GPU tensors load on a
# CPU-only machine, matching the CPU fallback used when `device` is chosen.
checkpoint_re_mnist = torch.load(checkpoint_re_mnist_path, map_location=device)

model_state_dict_re_mnist = checkpoint_re_mnist['state_dict']

print(model_state_dict_re_mnist['module.conv1.weight'].shape)

# The weights were saved from a DataParallel wrapper, so every key carries a
# leading 'module.'; strip only that prefix rather than every occurrence.
new_state_dict_re_mnist = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in model_state_dict_re_mnist.items()
}

# strict=False tolerates key mismatches, which could silently leave layers at
# their random initialisation; report any mismatch instead of hiding it.
load_result_re_mnist = model_re_mnist.load_state_dict(new_state_dict_re_mnist, strict=False)
if load_result_re_mnist.missing_keys or load_result_re_mnist.unexpected_keys:
    print('WARNING: checkpoint/model key mismatch - missing: {}, unexpected: {}'.format(
        load_result_re_mnist.missing_keys, load_result_re_mnist.unexpected_keys))
model_re_mnist = model_re_mnist.to(device)

In [None]:
# Testing model with Baseline testloader. The output is test error rates

# [RE ResNet20 | Baseline Fashion-MNIST]

# Switch the network to evaluation mode before measuring.
model_re_mnist.eval()

# Running tallies of correctly classified and total test images.
correct, total = 0, 0
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model_re_mnist(images)
        # Index of the highest score along dimension 1 is the predicted class.
        predicted = outputs.max(1)[1]
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Error rate = 100 - accuracy, both in percent.
print('Test Error [RE ResNet20 | Baseline Fashion-MNIST]: %.2f %%' % (100-(100 * correct / total)))


In [None]:
# Testing model with RE testloaders. The outputs are test error rates


def _test_error_percent(model, loader):
    """
    Return the classification error of ``model`` on ``loader`` in percent.

    The model is put in evaluation mode and run without gradient tracking;
    the predicted class is the index of the highest score along dimension 1.
    """
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 - (100 * correct / total)


# One (report format, loader) pair per erasing variant, evaluated in this order.
_re_evaluations = [
    # [RE ResNet20 | RE Fashion-MNIST]
    ('Test Error [RE ResNet20 | RE Fashion-MNIST]: %.2f %%', testloader_re),
    # [RE ResNet20 | RE Fashion-MNIST] |circular shape|
    ('Test Error [RE ResNet20 | RE Fashion-MNIST] |circular shape|: %.2f %%', testloader_re_cir),
    # [RE ResNet20 | RE Fashion-MNIST] |random shape|
    ('Test Error [RE ResNet20 | RE Fashion-MNIST] |random shape|: %.2f %%', testloader_re_rs),
    # [RE ResNet20 | RE Fashion-MNIST] |blurred edge|
    ('Test Error [RE ResNet20 | RE Fashion-MNIST] |blurred edge|: %.2f %%', testloader_re_be),
    # [RE ResNet20 | RE Fashion-MNIST] |circular shape & blurred edge|
    ('Test Error [RE ResNet20 | RE Fashion-MNIST] |circular shape & blurred edge|: %.2f %%', testloader_re_cir_be),
    # [RE ResNet20 | RE Fashion-MNIST] |random shape & blurred edge|
    ('Test Error [RE ResNet20 | RE Fashion-MNIST] |random shape & blurred edge|: %.2f %%', testloader_re_rs_be),
]

for report_format, re_loader in _re_evaluations:
    print(report_format % _test_error_percent(model_re_mnist, re_loader))
