In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from train_main import CIFAR10Classifier
import os

# ToTensor maps PIL images to [0, 1]; Normalize with mean=std=0.5 then shifts
# every RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 4

# Training split: shuffled mini-batches of `batch_size` images.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Test split: one image per batch, fixed order. The LGS evaluation cell indexes
# inputs[0] and squeezes the batch axis, so it depends on batch_size=1 here.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)

# Human-readable names for CIFAR-10 label indices 0..9.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [2]:
# Restore the trained ResNet-18 classifier from its Lightning checkpoint.
# map_location="cpu": the evaluation loops below feed CPU tensors and never move
# `inputs` to a device, so the weights must be on the CPU too. Without this,
# some Lightning versions restore CUDA-saved weights onto the GPU whenever one is
# available, and `model(inputs)` then fails with a device mismatch.
checkpoint_path = "experiment_results/resnet18-2022-12-06-22-09/version_0/checkpoints/epoch=49-step=54700.ckpt"
model = CIFAR10Classifier.load_from_checkpoint(checkpoint_path, map_location="cpu")
model.eval()  # inference mode: BatchNorm layers use their running statistics
print(model)

CIFAR10Classifier(
  (net): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tra

In [5]:
# Baseline: top-1 accuracy of the undefended model on the CIFAR-10 test split.
total_count = 0
correct_count = 0

# Pure inference: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for inputs, labels in testloader:
        logits = model(inputs)
        preds = torch.argmax(logits, dim=1)

        total_count += len(labels)
        # .item() keeps the running count a plain Python int, not a 0-d tensor.
        correct_count += (preds == labels).sum().item()

print(total_count, correct_count, correct_count / total_count)

10000 tensor(6248) tensor(0.6248)


In [7]:
from PIL import Image
# os.chdir('./local_gradients_smoothing')  # presumably needed if `configs`/`lgs` are not importable from cwd
from configs import Configuration
from lgs import LocalGradientsSmoothing

# Build the LGS defense from the 'DEFAULT' section of the project configuration.
cfg = Configuration()
loc_grad_smooth = LocalGradientsSmoothing(**cfg.get('DEFAULT'))

# Pixels whose LGS mask value is at or above this cut-off are zeroed out.
MASK_THRESHOLD = 0.5

# Accuracy of the same model when every test image is first passed through LGS.
total_count = 0
correct_count = 0

with torch.no_grad():
    for inputs, labels in testloader:
        # The masking below treats the batch as one image (squeeze(0), inputs[0]).
        # With a larger batch, only the first sample would be defended.
        assert inputs.shape[0] == 1, "LGS evaluation expects testloader batch_size=1"

        # NOTE(review): `inputs` has already been Normalize()-d to roughly [-1, 1],
        # not [0, 1]. Any gradient-based score LGS computes is therefore scaled
        # relative to [0, 1] input, and a zeroed pixel becomes mid-grey rather than
        # black. Confirm this is intended for the configured LGS parameters.
        grad_mask = loc_grad_smooth(inputs).squeeze(0)

        # Binarise the mask: 1 = pixel to suppress, 0 = keep.
        grad_mask = (grad_mask >= MASK_THRESHOLD).to(inputs.dtype)
        grad_mask = grad_mask.repeat((3, 1, 1))  # same mask for each RGB channel

        inputs[0] = inputs[0] * (1 - grad_mask)

        logits = model(inputs)
        preds = torch.argmax(logits, dim=1)

        total_count += len(labels)
        # .item() keeps the running count a plain Python int, not a 0-d tensor.
        correct_count += (preds == labels).sum().item()

print(total_count, correct_count, correct_count / total_count)

10000 tensor(6464) tensor(0.6464)
