In [1]:
from insi import Probe, Probes, Cortex 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torchvision import datasets, transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The defaults (784 -> 100 -> 10) match the flattened 28x28 MNIST images used
    in this notebook and the checkpoint loaded into ``MLP()``. The submodule
    names ``l1`` and ``l2`` are part of the state-dict keys, so keep them stable.

    Args:
        in_features: width of the flattened input vector.
        hidden_features: width of the hidden layer.
        num_classes: number of output scores.
    """

    def __init__(self, in_features=784, hidden_features=100, num_classes=10):
        super().__init__()
        self.l1 = nn.Linear(in_features, hidden_features)
        self.l2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return unnormalized class scores (logits) for ``x`` of shape ``(..., in_features)``."""
        return self.l2(F.relu(self.l1(x)))
    
# Build the network and restore its trained weights.
model = MLP()
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only kernel; nothing in this notebook moves the model to a GPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (on torch >= 1.13 consider passing weights_only=True).
model.load_state_dict(torch.load("saved/mlp_model.pth", map_location="cpu"))

def objective(pred, target=9):
    """Cross-entropy of ``pred`` against a single fixed target class.

    Args:
        pred: raw class scores, either unbatched ``(n_classes,)`` or batched
            ``(batch, n_classes)``.
        target: class index every prediction is scored against (default 9).

    Returns:
        Scalar loss tensor (mean over the batch when ``pred`` is batched).

    Raises:
        ValueError: if ``pred`` is neither 1-D nor 2-D.
    """
    if pred.dim() not in (1, 2):
        raise ValueError(f"expected pred of shape (C,) or (N, C), got {tuple(pred.shape)}")
    # One label per leading (batch) position: a 0-d tensor for unbatched input,
    # shape (N,) for batched input.
    labels = torch.full(pred.shape[:-1], target, dtype=torch.long, device=pred.device)
    return F.cross_entropy(pred, labels)

# Load the MNIST training split and turn every image into a 1-D vector holding
# the pixels of its first channel.
data = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
# NOTE: `input` shadows the builtin of the same name; a later cell passes it to cortex.tune.
input = [image[0].flatten() for image, _label in data]

In [48]:
# One probe per input feature. The probe values are added element-wise to each
# flattened image (see get_acc), so their count must equal the model's input width
# instead of repeating the magic number 784.
num_probes = model.l1.in_features
probes = {i: Probe() for i in range(num_probes)}

# Wrap the individual probes in a collection.
probes_collection = Probes(probes)

# Cortex is built from the probes, the model and the objective; it is tuned in the next cell.
cortex = Cortex(probes_collection, model, objective)

In [None]:
# Run Cortex's tuning over the flattened training images: a single epoch at learning rate 0.1.
cortex.tune(
    input=input,
    lr=0.1,
    epochs=1,
)

In [31]:
# MNIST test split, preprocessed like the training data (ToTensor only).
test_data = datasets.MNIST(
    '../data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Evaluation only counts hits, so sample order is irrelevant; a fixed order
# (shuffle=False) makes per-sample debugging repeatable.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

def get_acc(model, loader=None, target=9, probes=None):
    """Return the percentage of samples that ``model`` classifies as ``target``.

    The loader's true labels are ignored: a sample counts as a hit only when the
    prediction equals ``target``. Before the forward pass, the probe values are
    added to every flattened image.

    Args:
        model: network mapping flattened images ``(batch, features)`` to class scores.
        loader: iterable of ``(images, labels)`` batches; defaults to the global ``test_loader``.
        target: class index that counts as a hit (default 9).
        probes: object exposing ``get_values()``; defaults to the global ``probes_collection``.

    Returns:
        0-dim tensor holding the hit rate in percent.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    loader = test_loader if loader is None else loader
    probes = probes_collection if probes is None else probes

    was_training = model.training
    model.eval()
    correct = torch.tensor(0)
    total = 0
    try:
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            offset = probes.get_values()  # loop-invariant during evaluation
            for Xb, _ in loader:
                # Out-of-place add so the loader's batch tensor is not mutated.
                xcat = Xb.flatten(start_dim=1) + offset
                pred = model(xcat).argmax(dim=1)  # index of the highest score (logit)
                correct += (pred == target).sum()
                total += Xb.shape[0]
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    if total == 0:
        raise ValueError("loader yielded no samples")
    return 100.0 * correct / total


In [32]:
# Share (%) of test images the probed model predicts as the target class.
get_acc(model=model)

tensor(99.0500)