In [2]:
import torch.optim as optim
import torch.nn.functional as F
import random
import torch
import torch as tc
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import random
import torchvision
from torchvision import transforms
import torchvision.datasets as datasets
from Classes import *
from torch.distributions.beta import Beta
from torch.distributions.bernoulli import Bernoulli
import matplotlib.pyplot as plt
import numpy as np



# Three-channel normalisation (these look like the usual ImageNet statistics).
# It is defined here but is not part of `transform_train` below.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Preprocessing shared by the train and test split: resize to 32, then convert
# the image to a tensor. Augmentations that are currently switched off:
#   transforms.RandomAffine(30, translate=(0.1, 0.1), scale=(0.8, 1.3))
#   transforms.RandomResizedCrop(32, scale=(0.9, 1))
#   transforms.RandomHorizontalFlip(p=0.5)
transform_train = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])

# Both splits are downloaded to ./data on first use.
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform_train,
)


class Masked_data(Dataset):
    """Dataset wrapper that randomly replaces pixels with values from other samples.

    For every item a Bernoulli(p) mask is drawn over the flattened image.
    Positions where the mask is 1 keep their original value; positions where
    the mask is 0 are overwritten with the value found at the *same* position
    in an independently drawn random sample of the wrapped dataset.

    Args:
        p: probability that a pixel is kept (mask value 1).
        data: indexable dataset yielding ``(image_tensor, target)`` pairs.
            Defaults to the module-level ``mnist_trainset``.
    """

    def __init__(self, p= 0.5, data=None):
        self.data = mnist_trainset if data is None else data
        self.len = len(self.data)
        self.p = p
        first_image = self.data[0][0]
        self.image_shape = first_image.shape
        # Number of entries of the flattened image. Using numel() instead of
        # shape[1]**2 keeps this correct for non-square / multi-channel images.
        self.image_length = first_image.numel()

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(mask, masked_image, target)`` for sample ``idx``.

        ``mask`` is a flat float tensor of length ``image_length``;
        ``masked_image`` has the original image shape.
        """
        mask = tc.bernoulli(tc.ones(self.image_length) * self.p)

        image, target_idx = self.data[idx]
        # Work on a copy: flatten() may return a view, and writing through it
        # would corrupt a wrapped dataset that hands out its stored tensors.
        masked_data_idx = image.flatten().clone()

        hidden = mask == 0
        masked_data_idx[hidden] = self.get_random_values(hidden)

        return mask, masked_data_idx.reshape(self.image_shape), target_idx

    def get_random_values(self, vector):
        """Draw one replacement value for every truthy position of ``vector``.

        Each position gets the value at that position in its own randomly
        chosen sample, so the cost is one dataset lookup per masked pixel.

        Returns:
            1-D tensor with one entry per truthy position, in index order.
        """
        positions = tc.nonzero(vector, as_tuple=False).flatten().tolist()
        if not positions:
            return tc.tensor([])

        randomvalues = [
            self.data[random.randint(0, self.len - 1)][0].flatten()[i]
            for i in positions
        ]
        return tc.stack(randomvalues)



In [3]:
class Block(nn.Module):
    """Convolutional stage used as a building block of the classifier.

    Pipeline: 3x3 conv (padding 1) -> ReLU -> BatchNorm -> 3x3 conv with
    stride 2 and no padding -> ReLU.

    Args:
        input_dim: number of input channels.
        output_dim: number of channels produced by both convolutions.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        stages = [
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(output_dim),
            # The strided, unpadded convolution shrinks the spatial size.
            nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2),
            nn.ReLU(),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the resulting feature map."""
        features = self.layers(x)
        return features


class Net(nn.Module):
    """Small CNN classifier: two ``Block`` stages, global average pooling, linear head.

    Args:
        input_dim: number of channels of the input image.
        dim1: channels produced by the first block.
        dim2: channels produced by the second block.
        in_features: input width of the linear head. The feature map is pooled
            to 1x1 and flattened, so this has to match ``dim2``.
        n_classes: number of output classes.
    """

    def __init__(self, input_dim, dim1, dim2, in_features, n_classes):
        super(Net, self).__init__()

        self.layers = nn.Sequential(
            Block(input_dim, dim1),
            Block(dim1,dim2))

        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.FCN = nn.Linear(in_features, n_classes)

    def forward(self,x):
        """Return per-class log-probabilities of shape ``(batch, n_classes)``."""
        x = self.layers(x)
        x = self.avg(x)
        x = x.view(x.size(0), -1)
        # Name the class dimension explicitly; calling log_softmax without
        # `dim` relies on a deprecated implicit choice and emits a warning.
        return F.log_softmax(self.FCN(x), dim=1)
        
class ShapleyEstimator(nn.Module):
    """Linear model mapping a flat pixel mask to one value per class.

    Args:
        input_dim: length of the flattened mask (defaults to a 32x32 image).
        n_outputs: number of values predicted per mask (defaults to 10 classes).
    """

    def __init__(self, input_dim=32**2, n_outputs=10):
        super(ShapleyEstimator,self).__init__()

        # Kept inside a Sequential so parameter names stay `layers.0.*`.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, n_outputs),
            )

    def forward(self,x):
        """Return predictions of shape ``(batch, n_outputs)`` for masks ``x``."""
        return self.layers(x)


In [17]:
def train_with_eval(epoch):
    """Train the global ``net`` for one pass over ``train_loader`` and report accuracy.

    Relies on the module-level ``net``, ``train_loader`` and ``optimizer``.
    Prints the epoch number and the fraction of correctly classified training
    samples (measured on the forward passes made during training).

    Args:
        epoch: value echoed in the progress message; not used otherwise.
    """
    net.train()
    # Follow the model's device instead of assuming CUDA.
    device = next(net.parameters()).device
    criterion = F.nll_loss
    correct = 0  # plain int, so reporting also works for an empty loader

    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        out = net(data)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

        # Predictions come from the forward pass made before the update.
        prediction = out.detach().argmax(dim=1)
        correct += (prediction == target).sum().item()

    print('Epoche:', epoch)

    n_samples = len(train_loader.dataset)
    precision = correct / n_samples if n_samples else float('nan')
    print('training_precision:', precision)
    
def test_with_eval(net, data, target):
    net.eval()
    correct = 0
    data = data.cuda()
    target = target.cuda()
    optimizer.zero_grad()
    out = net(data)
    prediction = out.data.max(1, keepdim=True)[1]

    criterion = F.nll_loss
    return criterion(out, target, reduction= 'none')



def test_relevance(net,shapleyestimator):
    criterion = F.mse_loss
    net.cuda(), shapleyestimator.cuda()
    shapleyoptimizer = optim.SGD(shapleyestimator.parameters(), lr = 0.0001)
    for mask,masked_data, target in masked_loader:
        shapleyestimator.zero_grad()
        mask, masked_data,target = mask.cuda(), masked_data.cuda(), target.cuda()
        
        with torch.no_grad():
            target_loss = test_with_eval(net, masked_data, target)

        
        loss_prediction = shapleyestimator(mask)
        Shapleyloss = criterion(target_loss[target], loss_prediction[target])
        print(target)
        
        Shapleyloss.backward()
        shapleyoptimizer.step()

    print(loss)
    
def return_relevance_map(sorter):
    """Collect one 32x32 relevance map per class from ``sorter``.

    ``One_mask`` is not defined in this notebook (presumably it comes from the
    ``Classes`` star import); it is assumed to yield 32*32 mask tensors so
    that a single batch covers the whole image. TODO confirm.

    Args:
        sorter: callable taking ``(mask_batch, class_index_batch)`` on the GPU
            and returning 32*32 values that are reshaped to ``(32, 32)``.

    Returns:
        ``numpy`` array stacking one ``(32, 32)`` map per class and per batch.
    """
    target_size = 10
    side = 32
    n_pixels = side ** 2

    one_mask_dataset = One_mask()
    one_mask_dataloader = DataLoader(one_mask_dataset, batch_size=n_pixels)

    per_target_relevances = []
    for i in range(target_size):
        # The same class index is paired with every mask of the batch.
        class_ids = tc.tensor([i] * n_pixels).cuda()
        for data in one_mask_dataloader:
            unsorted_relevances = sorter(data.cuda(), class_ids)
            relevance_map = unsorted_relevances.reshape((side, side)).detach().cpu()
            per_target_relevances.append(np.array(relevance_map))

    return np.array(per_target_relevances)

In [18]:
# Classifier for single-channel images: 32 then 64 conv channels, a
# 64-wide linear head and 10 classes, placed on the GPU.
net = Net(input_dim=1, dim1=32, dim2=64, in_features=64, n_classes=10).cuda()
optimizer = optim.Adam(net.parameters(), lr=0.001)
epoch = 5

# Mask-to-loss estimator and the masked view of the training data it learns from.
shapleyestimator = ShapleyEstimator()
masked_data = Masked_data()

# Small batches for the masked data, large batches for classifier training.
masked_loader = DataLoader(masked_data, batch_size=10)
train_loader = DataLoader(mnist_trainset, batch_size=512)



In [19]:
# `os` is not among the imports visible at the top of the notebook; import it
# here so this cell also works on a fresh kernel.
import os

# Reuse a cached classifier when one exists, otherwise train for 10 epochs and
# write a checkpoint after every epoch.
# NOTE: torch.load unpickles a whole module object, which can execute
# arbitrary code -- only load checkpoint files you created yourself.
# NOTE(review): after loading, `optimizer` still refers to the parameters of
# the previously constructed `net`; rebuild it before any further training.
if os.path.isfile('MNIST_net.pt'):
    print('net found')
    net = torch.load('MNIST_net.pt')
else:
    print('net not found')
    for epoch in range(10):
        train_with_eval(epoch)
        torch.save(net, 'MNIST_net.pt')
        



net found


In [20]:
# Fit `shapleyestimator` on the batches of `masked_loader`, using `net` to
# compute the per-sample losses it is regressed on.
test_relevance(net, shapleyestimator)



tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], device='cuda:0')
tensor([3, 5, 3, 6, 1, 7, 2, 8, 6, 9], device='cuda:0')
tensor([4, 0, 9, 1, 1, 2, 4, 3, 2, 7], device='cuda:0')
tensor([3, 8, 6, 9, 0, 5, 6, 0, 7, 6], device='cuda:0')
tensor([1, 8, 7, 9, 3, 9, 8, 5, 9, 3], device='cuda:0')
tensor([3, 0, 7, 4, 9, 8, 0, 9, 4, 1], device='cuda:0')
tensor([4, 4, 6, 0, 4, 5, 6, 1, 0, 0], device='cuda:0')
tensor([1, 7, 1, 6, 3, 0, 2, 1, 1, 7], device='cuda:0')
tensor([9, 0, 2, 6, 7, 8, 3, 9, 0, 4], device='cuda:0')
tensor([6, 7, 4, 6, 8, 0, 7, 8, 3, 1], device='cuda:0')
tensor([5, 7, 1, 7, 1, 1, 6, 3, 0, 2], device='cuda:0')


KeyboardInterrupt: 