In [None]:
# Install the third-party packages used below. NOTE(review): versions are not
# pinned, so a fresh runtime may pull releases with different APIs.
!pip install timm
!pip install shap
import sys
import shap
import numpy as np
import torch
import timm
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from google.colab import drive

# Mount Google Drive (Colab only); model checkpoints are read from and written
# to paths under /content/gdrive.
drive.mount('/content/gdrive')

In [None]:

# Training images go through the CIFAR-10 AutoAugment policy before being
# converted to tensors; test images are only converted to tensors.
transform_train = transforms.Compose(
    [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10), transforms.ToTensor(),])
transform_test = transforms.ToTensor()

# Hyperparameters for this cell.
batch_size = 20
lr = 0.00001
epochs = 100

# CIFAR-10 is downloaded into ./data on first use.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Preview the first two augmented training images of one batch.
batch, _ = next(iter(trainloader))
images = batch[:2]
print(images.shape)
# Give every image its own axes: two consecutive plt.imshow calls draw on the
# same axes, so only the last image would be visible.
fig, axes = plt.subplots(1, len(images), squeeze=False)
for ax, image in zip(axes[0], images):
    # imshow expects channels last, the tensors are channels first.
    ax.imshow(image.permute(1, 2, 0))

In [None]:


# Training pipeline: CIFAR-10 AutoAugment policy, tensor conversion, then a
# light 3x3 Gaussian blur. Test images are only converted to tensors.
transform_train = transforms.Compose(
    [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10), transforms.ToTensor(), transforms.GaussianBlur(3, 0.1)])
transform_test = transforms.ToTensor()

# Hyperparameters used by the training loop below.
batch_size = 20
lr = 0.000001
epochs = 100

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# pretrained=False: the checkpoint loaded on the next line replaces every
# weight (load_state_dict is strict by default), so fetching pretrained weights
# first would only cost a download.
model = timm.create_model('ecaresnet26t', num_classes=10, pretrained=False)
# map_location lets a checkpoint that was saved on a GPU be restored on the
# CPU fallback selected above instead of raising during deserialization.
model.load_state_dict(torch.load("/content/gdrive/My Drive/BasicModel/modeltrain3.pt", map_location=device))
model.to(device)
print(f"Device {device}")
def trainAccuracy():
    """Return the top-1 accuracy, in percent, of ``model`` over ``trainloader``.

    Reads the module-level ``model``, ``trainloader`` and ``device``. The model
    is switched to eval mode for the duration of the pass and then put back in
    whatever mode it was in before the call, so calling this between training
    epochs does not leave the network stuck in eval mode.

    Returns:
        float: percentage of correctly classified samples, or ``0.0`` when the
        loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for images, labels in trainloader:
                images = images.to(device)
                labels = labels.to(device)
                # run the model on the training batch to predict labels
                outputs = model(images)
                # the class with the highest score is the prediction
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # restore the caller's mode even if the pass is interrupted
        model.train(was_training)

    if total == 0:
        # nothing was evaluated; avoid dividing by zero
        return 0.0
    # accuracy over all images yielded by the training loader
    return 100 * correct / total
def testAccuracy():
    
    model.eval()
    accuracy = 0.0
    total = 0.0
    
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            # run the model on the test set to predict labels
            outputs = model(images)
            # the label with the highest energy will be our prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            accuracy += (predicted == labels).sum().item()
    
    # compute the accuracy over all test images
    accuracy = (100 * accuracy / total)
    return(accuracy)


# Fine-tune `model` with cross-entropy loss and Adam, reporting train/test
# accuracy and writing a checkpoint after every epoch.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-3)
for epoch in range(epochs):  # loop over the dataset multiple times
    # The accuracy evaluation at the end of each epoch calls model.eval();
    # switch back to training mode explicitly so that every epoch (not just the
    # first) trains with BatchNorm/dropout in training behaviour.
    model.train()

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # live progress: running mean of the loss over the batches seen so far
        running_loss += loss.item()
        sys.stdout.write(f"\repoch:{epoch + 1}, mini batch:{i + 1} loss: {running_loss/(i+1):.7f}")
        sys.stdout.flush()
    print()
    # Average over the real number of batches instead of a hard-coded 2500, so
    # the reported loss stays correct if batch_size or the dataset changes.
    print(f'[epoch:{epoch + 1}, mini batch:{i + 1}] loss: {running_loss / len(trainloader):.5f} accuracy train: {trainAccuracy()} accuracy test: {testAccuracy()}')
    # NOTE(review): checkpoint numbering restarts at 1 on every run, so a run
    # resumed from an earlier modeltrain{N}.pt overwrites those files — confirm
    # this is intended.
    torch.save(model.state_dict(), f"/content/gdrive/My Drive/BasicModel/modeltrain{epoch+1}.pt")
print('Finished Training')

In [None]:
# Rebuild the network and restore the fine-tuned CIFAR-10 weights.
# pretrained=False: the checkpoint below replaces every weight (load_state_dict
# is strict by default), so fetching pretrained weights first is wasted work.
model = timm.create_model('ecaresnet26t', num_classes=10, pretrained=False)
# Deserialize on the CPU first, then move the whole model to the GPU.
model.load_state_dict(torch.load("/content/gdrive/My Drive/BasicModel/ecaresnet26t-CIFAR10.pt", map_location="cpu"))
# A freshly created module is in training mode; switch to eval mode so the
# explanations below use the stored BatchNorm statistics rather than per-batch
# statistics.
model = model.to("cuda").eval()

In [None]:
# Explain 3 test images using the other 17 images of the batch as background.
batch = next(iter(testloader))
images, labels = batch

# Explanations must be computed in eval mode (stored BatchNorm statistics).
model.eval()
# Put the background on whatever device the model lives on.
model_device = next(model.parameters()).device
background = images[:17].to(model_device)
test_images = images[17:20]
e = shap.DeepExplainer(model, background)
shap_values = e.shap_values(test_images)
# Depending on the installed shap release, shap_values is either a list with
# one (N, C, H, W) array per class or a single array with the classes on a
# trailing axis; normalise both to (N, C, H, W, classes).
if isinstance(shap_values, list):
    shap_stack = np.stack(shap_values, axis=-1)
else:
    shap_stack = np.asarray(shap_values)
    if shap_stack.ndim == 4:
        shap_stack = shap_stack[..., np.newaxis]
# Channels last for plotting, then mean absolute attribution over the classes.
shap_numpy = np.abs(np.moveaxis(shap_stack, 1, 3)).mean(-1)
test_numpy = (test_images.numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8)
shap.plots.image(shap_numpy, test_numpy)


In [None]:
# Explain 5 test images using the other 15 images of the batch as background.
batch = next(iter(testloader))
images, labels = batch

# Explanations must be computed in eval mode (stored BatchNorm statistics).
model.eval()
# Put the background on whatever device the model lives on.
model_device = next(model.parameters()).device
background = images[:15].to(model_device)
test_images = images[15:20]
e = shap.DeepExplainer(model, background)
shap_values = e.shap_values(test_images)
# Depending on the installed shap release, shap_values is either a list with
# one (N, C, H, W) array per class or a single array with the classes on a
# trailing axis; normalise both to (N, C, H, W, classes).
if isinstance(shap_values, list):
    shap_stack = np.stack(shap_values, axis=-1)
else:
    shap_stack = np.asarray(shap_values)
    if shap_stack.ndim == 4:
        shap_stack = shap_stack[..., np.newaxis]
# Channels last for plotting, then mean absolute attribution over the classes.
shap_numpy = np.abs(np.moveaxis(shap_stack, 1, 3)).mean(-1)
test_numpy = (test_images.numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8)
shap.plots.image(shap_numpy, test_numpy)


In [None]:
# Explain 10 test images using the other 10 images of the batch as background.
batch = next(iter(testloader))
images, labels = batch

# Explanations must be computed in eval mode (stored BatchNorm statistics).
model.eval()
# Put the background on whatever device the model lives on.
model_device = next(model.parameters()).device
background = images[:10].to(model_device)
test_images = images[10:20]
e = shap.DeepExplainer(model, background)
shap_values = e.shap_values(test_images)
# Depending on the installed shap release, shap_values is either a list with
# one (N, C, H, W) array per class or a single array with the classes on a
# trailing axis; normalise both to (N, C, H, W, classes).
if isinstance(shap_values, list):
    shap_stack = np.stack(shap_values, axis=-1)
else:
    shap_stack = np.asarray(shap_values)
    if shap_stack.ndim == 4:
        shap_stack = shap_stack[..., np.newaxis]
# Channels last for plotting, then mean absolute attribution over the classes.
shap_numpy = np.abs(np.moveaxis(shap_stack, 1, 3)).mean(-1)
test_numpy = (test_images.numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8)
shap.plots.image(shap_numpy, test_numpy)
