In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import memtorch
from memtorch.utils import LoadCIFAR10


class Net(nn.Module):
    def __init__(self, inflation_ratio=1):
        super(Net, self).__init__()
        self.conv0 = nn.Conv2d(in_channels=3, out_channels=128*inflation_ratio, kernel_size=3, stride=1, padding=1)
        self.bn0 = nn.BatchNorm2d(num_features=128*inflation_ratio)
        self.act0 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=128*inflation_ratio, out_channels=128*inflation_ratio, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=128*inflation_ratio)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=128*inflation_ratio, out_channels=256*inflation_ratio, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=256*inflation_ratio)
        self.act2 = nn.ReLU()
        self.conv3 = nn.Conv2d(in_channels=256*inflation_ratio, out_channels=256*inflation_ratio, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=256*inflation_ratio)
        self.act3 = nn.ReLU()
        self.conv4 = nn.Conv2d(in_channels=256*inflation_ratio, out_channels=512*inflation_ratio, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(num_features=512*inflation_ratio)
        self.act4 = nn.ReLU()
        self.conv5 = nn.Conv2d(in_channels=512*inflation_ratio, out_channels=512, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(num_features=512)
        self.act5 = nn.ReLU()
        self.fc6 = nn.Linear(in_features=512*4*4, out_features=1024)
        self.bn6 = nn.BatchNorm1d(num_features=1024)
        self.act6 = nn.ReLU()
        self.fc7 = nn.Linear(in_features=1024, out_features=1024)
        self.bn7 = nn.BatchNorm1d(num_features=1024)
        self.act7 = nn.ReLU()
        self.fc8 = nn.Linear(in_features=1024, out_features=10)

    def forward(self, input):
        x = self.act0(self.bn0(self.conv0(input)))
        x = self.act1(self.bn1(F.max_pool2d(self.conv1(x), 2)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(F.max_pool2d(self.conv3(x), 2)))
        x = self.act4(self.bn4(self.conv4(x)))
        x = self.act5(self.bn5(F.max_pool2d(self.conv5(x), 2)))
        x = x.view(x.size(0), -1)
        x = self.act6(self.bn6(self.fc6(x)))
        x = self.act7(self.bn7(self.fc7(x)))
        return self.fc8(x)


def test(model, test_loader):
    correct = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        output = model(data.to(device))
        pred = output.data.max(1)[1]
        correct += pred.eq(target.to(device).data.view_as(pred)).cpu().sum()

    return 100. * float(correct) / float(len(test_loader.dataset))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 50
train_loader, validation_loader, test_loader = LoadCIFAR10(batch_size=256, validation=False)
model = Net().to(device)
if device == 'cuda':
    model = torch.nn.DataParallel(model)

criterion = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_accuracy = 0
for epoch in range(0, epochs):
    print('Epoch: [%d]\t\t' % (epoch + 1), end='')
    if epoch % 20 == 0:
        learning_rate = learning_rate * 0.1
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()

    accuracy = test(model, test_loader)
    print('%2.2f%%' % accuracy)
    if accuracy > best_accuracy:
        torch.save(model.state_dict(), 'trained_model.pt')
        best_accuracy = accuracy

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data\cifar-10-python.tar.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting data\cifar-10-python.tar.gz to data
Files already downloaded and verified
Epoch: [1]		71.47%
Epoch: [2]		78.67%
Epoch: [3]		82.17%
Epoch: [4]		83.17%
Epoch: [5]		84.16%
Epoch: [6]		84.38%
Epoch: [7]		84.73%
Epoch: [8]		84.14%
Epoch: [9]		84.93%
Epoch: [10]		84.92%
Epoch: [11]		84.17%
Epoch: [12]		84.99%
Epoch: [13]		85.28%
Epoch: [14]		85.00%
Epoch: [15]		84.81%
Epoch: [16]		85.16%
Epoch: [17]		86.29%
Epoch: [18]		85.81%
Epoch: [19]		84.72%
Epoch: [20]		85.35%
Epoch: [21]		86.64%
Epoch: [22]		86.93%
Epoch: [23]		87.00%
Epoch: [24]		87.02%
Epoch: [25]		87.01%
Epoch: [26]		87.13%
Epoch: [27]		87.13%
Epoch: [28]		87.09%
Epoch: [29]		87.15%
Epoch: [30]		87.18%
Epoch: [31]		87.20%
Epoch: [32]		87.17%
Epoch: [33]		87.24%
Epoch: [34]		87.29%
Epoch: [35]		87.28%
Epoch: [36]		87.32%
Epoch: [37]		87.31%
Epoch: [38]		87.35%
Epoch: [39]		87.36%
Epoch: [40]		87.38%
Epoch: [41]		87.39%
Epoch: [42]		87.43%
Epoch: [43]		87.43%
Epoch: [44]		87.43%
Epoch: [45]		87.46%
Epoch: [46]		87.46%
Epoc

In [3]:
import copy
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


model = Net().to(device)
try:
    model.load_state_dict(torch.load('trained_model.pt'), strict=False)
except:
    raise Exception('trained_model.pt has not been found.')

print('Test Set Accuracy: \t%2.2f%%' % test(model, test_loader))

Test Set Accuracy: 	87.48%


In [4]:
import seaborn as sns


palette = ["#DA4453", "#8CC152", "#4A89DC", "#F6BB42", "#B600B0", "#535353"]




In [9]:
from memtorch.mn.Module import patch_model
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities


def trial(r_on, r_off, sigma):
    model_ = copy.deepcopy(model)
    reference_memristor = memtorch.bh.memristor.VTEAM
    reference_memristor_params = {'time_series_resolution': 1e-10,
                                  'r_off': memtorch.bh.StochasticParameter(loc=r_off, scale=sigma*2, min=1),
                                  'r_on': memtorch.bh.StochasticParameter(loc=r_on, scale=sigma, min=1)}    

    patched_model = patch_model(copy.deepcopy(model_),
                              memristor_model=reference_memristor,
                              memristor_model_params=reference_memristor_params,
                              module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                              mapping_routine=naive_map,
                              transistor=True,
                              programming_routine=None)

    patched_model.tune_()
    return test(patched_model, test_loader)

df = pd.DataFrame(columns=['sigma', 'test_set_accuracy'])
r_on = 200
r_off = 500
sigma_values = np.linspace(0, 100, 21)
for sigma in sigma_values:
    df = df.append({'sigma': sigma, 'test_set_accuracy': trial(r_on, r_off, sigma)}, ignore_index=True)

df.to_csv('variability.csv', index=False)

Patched Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) -> bh.Conv2d(in_channels=3, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Patched Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) -> bh.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Patched Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) -> bh.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Patched Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) -> bh.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Patched Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) -> bh.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


KeyboardInterrupt: 