# Novel Simulations

## 1. Define and train a VGG Convolutional Neural Network (CNN) using CIFAR-10

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import memtorch
from memtorch.utils import LoadCIFAR10


class Net(nn.Module):
    """VGG-style CNN: six 3x3 convolutions followed by three linear layers.

    ``inflation_ratio`` scales the channel width of every convolution except
    the output of ``conv5``, which is fixed at 512 channels so that the
    flattened features always match ``fc6`` (``512 * 4 * 4`` inputs).
    The network emits 10 scores per sample.
    """

    def __init__(self, inflation_ratio=1):
        super().__init__()
        narrow = 128 * inflation_ratio
        medium = 256 * inflation_ratio
        wide = 512 * inflation_ratio
        head_channels = 512  # not inflated: fc6 expects 512 * 4 * 4 features
        hidden = 1024

        # Modules are created in the same order as the attribute numbering so
        # that seeded initialisation and state_dict ordering are unaffected.
        self.conv0 = nn.Conv2d(3, narrow, kernel_size=3, stride=1, padding=1)
        self.bn0 = nn.BatchNorm2d(narrow)
        self.act0 = nn.ReLU()
        self.conv1 = nn.Conv2d(narrow, narrow, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(narrow)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(narrow, medium, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(medium)
        self.act2 = nn.ReLU()
        self.conv3 = nn.Conv2d(medium, medium, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(medium)
        self.act3 = nn.ReLU()
        self.conv4 = nn.Conv2d(medium, wide, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(wide)
        self.act4 = nn.ReLU()
        self.conv5 = nn.Conv2d(wide, head_channels, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(head_channels)
        self.act5 = nn.ReLU()
        self.fc6 = nn.Linear(head_channels * 4 * 4, hidden)
        self.bn6 = nn.BatchNorm1d(hidden)
        self.act6 = nn.ReLU()
        self.fc7 = nn.Linear(hidden, hidden)
        self.bn7 = nn.BatchNorm1d(hidden)
        self.act7 = nn.ReLU()
        self.fc8 = nn.Linear(hidden, 10)

    def forward(self, input):
        """Return the 10 class scores for a batch of images.

        Every second convolution is followed by a 2x2 max-pool, applied
        before batch normalisation and the activation.
        Assumes 3-channel 32x32 inputs (so the final map is 4x4) - TODO confirm.
        """
        x = self.conv0(input)
        x = self.act0(self.bn0(x))

        x = F.max_pool2d(self.conv1(x), 2)
        x = self.act1(self.bn1(x))

        x = self.conv2(x)
        x = self.act2(self.bn2(x))

        x = F.max_pool2d(self.conv3(x), 2)
        x = self.act3(self.bn3(x))

        x = self.conv4(x)
        x = self.act4(self.bn4(x))

        x = F.max_pool2d(self.conv5(x), 2)
        x = self.act5(self.bn5(x))

        # Flatten everything but the batch dimension for the linear head.
        x = x.view(x.size(0), -1)
        x = self.act6(self.bn6(self.fc6(x)))
        x = self.act7(self.bn7(self.fc7(x)))
        return self.fc8(x)
    
    
def test(model, test_loader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``test_loader``.

    The model is switched to evaluation mode and gradient tracking is disabled
    for the duration of the pass, so BatchNorm layers use their running
    statistics rather than per-batch statistics and are not updated from the
    evaluation data. The model's previous train/eval mode is restored before
    returning.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        test_loader: Iterable of ``(data, target)`` batches that exposes a
            ``dataset`` attribute whose length is the number of samples.

    Returns:
        Percentage of correctly classified samples, as a float.

    Note:
        Batches are moved to the module-level ``device``.
    """
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data.to(device))
                # Index of the highest score along the class dimension.
                pred = output.max(1)[1]
                correct += int(pred.eq(target.to(device).view_as(pred)).sum().item())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    return 100. * float(correct) / float(len(test_loader.dataset))

# Run on the CPU when the MemTorch version string contains 'cpu', else on CUDA.
device = torch.device('cuda' if 'cpu' not in memtorch.__version__ else 'cpu')
epochs = 50
train_loader, validation_loader, test_loader = LoadCIFAR10(batch_size=256, validation=False)
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_accuracy = 0
for epoch in range(epochs):
    print('Epoch: [%d]\t\t' % (epoch + 1), end='')
    # Step decay: divide the learning rate by 10 every 20 epochs.
    # NOTE(review): epoch 0 also satisfies this check, so the optimizer runs
    # at 1e-3 from the very first update - confirm this is intended.
    if epoch % 20 == 0:
        learning_rate *= 0.1
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

    # The checkpoint is kept only when accuracy on test_loader improves.
    accuracy = test(model, test_loader)
    print('%2.2f%%' % accuracy)
    if accuracy > best_accuracy:
        torch.save(model.state_dict(), 'trained_model.pt')
        best_accuracy = accuracy

## 2. Load and test the network

In [None]:
import copy
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


# Rebuild the network and restore the weights saved by the training cell.
model = Net().to(device)
try:
    # map_location lets a checkpoint saved on another device load onto `device`.
    state_dict = torch.load('trained_model.pt', map_location=device)
except FileNotFoundError as err:
    # Only a missing file is reported as "not found"; any other failure
    # (corrupt file, incompatible shapes, ...) surfaces with its own message.
    raise FileNotFoundError('trained_model.pt has not been found.') from err

# strict=False tolerates missing/unexpected keys, as in the original cell.
model.load_state_dict(state_dict, strict=False)

print('Test Set Accuracy: \t%2.2f%%' % test(model, test_loader))

## 3. Import seaborn and define an appropriate color-palette

In [None]:
import seaborn as sns


# Six hex colours used by the plotting cells (bar face colour / seaborn palette).
palette = [
    "#DA4453",
    "#8CC152",
    "#4A89DC",
    "#F6BB42",
    "#B600B0",
    "#535353",
]

## 4. Device-device variability investigation

In [None]:
from memtorch.mn.Module import patch_model
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities


def trial(r_on, r_off, sigma):
    """Patch a copy of the trained model with stochastic VTEAM devices and test it.

    Args:
        r_on: Nominal value passed to the stochastic ``r_on`` device parameter.
        r_off: Nominal value passed to the stochastic ``r_off`` device parameter.
        sigma: ``std`` used for ``r_on``; ``r_off`` uses ``2 * sigma``.

    Returns:
        Accuracy (percent) of the patched and tuned model, as computed by
        ``test`` on the module-level ``test_loader``.

    Note:
        Reads the module-level ``model``; it is deep-copied before patching.
    """
    reference_memristor = memtorch.bh.memristor.VTEAM
    reference_memristor_params = {
        'time_series_resolution': 1e-10,
        'r_off': memtorch.bh.StochasticParameter(r_off, std=sigma*2, min=1),
        'r_on': memtorch.bh.StochasticParameter(r_on, std=sigma, min=1),
    }

    # One deep copy is enough to keep the shared `model` untouched; the
    # previous code copied the whole network twice per trial.
    patched_model = patch_model(copy.deepcopy(model),
                                memristor_model=reference_memristor,
                                memristor_model_params=reference_memristor_params,
                                module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                                mapping_routine=naive_map,
                                transistor=True,
                                programming_routine=None)

    patched_model.tune_()
    return test(patched_model, test_loader)

# Sweep the device-to-device variability (sigma) and record the accuracy.
r_on = 200
r_off = 500
sigma_values = np.linspace(0, 100, 21)
# DataFrame.append was removed in pandas 2.0 (and copied the whole frame on
# every call), so rows are collected first and the frame is built once.
rows = []
for sigma in sigma_values:
    rows.append({'sigma': sigma, 'test_set_accuracy': trial(r_on, r_off, sigma)})

df = pd.DataFrame(rows, columns=['sigma', 'test_set_accuracy'])
df.to_csv('variability.csv', index=False)

In [None]:
# Plot test-set accuracy against sigma from the saved sweep and save it as P1.svg.
df = pd.read_csv('variability.csv')
f = plt.figure()
# Dashed reference line at 10% (presumably chance level for 10 classes).
plt.axhline(y=10, color='k', linestyle='--', zorder=1)
b = plt.bar(df['sigma'], df['test_set_accuracy'], width=2.5, zorder=2)
# Raw string: '\s' is an invalid escape sequence in an ordinary string literal.
plt.xlabel(r'$\sigma$')
plt.ylabel('CIFAR-10 Test-set Accuracy (%)')
for bar in b:
    bar.set_edgecolor('black')
    bar.set_facecolor(palette[0])
    bar.set_linewidth(1)

f.tight_layout()
plt.grid()
plt.savefig("P1.svg")
plt.show()

## 5. Finite conductance states investigation

In [None]:
from memtorch.mn.Module import patch_model
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities


def trial(r_on, r_off, finite_states):
    """Patch a copy of the trained model, limit its conductance states, and test it.

    Args:
        r_on: ``r_on`` parameter of the VTEAM reference device.
        r_off: ``r_off`` parameter of the VTEAM reference device.
        finite_states: Number of conductance states; converted with ``int``
            before being passed to ``apply_nonidealities``.

    Returns:
        Accuracy (percent) of the patched and tuned model, as computed by
        ``test`` on the module-level ``test_loader``.

    Note:
        Reads the module-level ``model``; it is deep-copied before patching.
    """
    reference_memristor = memtorch.bh.memristor.VTEAM
    reference_memristor_params = {
        'time_series_resolution': 1e-10,
        'r_off': r_off,
        'r_on': r_on,
    }

    # One deep copy is enough to keep the shared `model` untouched; the
    # previous code copied the whole network twice per trial.
    patched_model = patch_model(copy.deepcopy(model),
                                memristor_model=reference_memristor,
                                memristor_model_params=reference_memristor_params,
                                module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                                mapping_routine=naive_map,
                                transistor=True,
                                programming_routine=None)

    patched_model = apply_nonidealities(patched_model,
                                        non_idealities=[memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates],
                                        conductance_states=int(finite_states))

    patched_model.tune_()
    return test(patched_model, test_loader)

# Sweep the number of finite conductance states (1..10) and record the accuracy.
r_on = 200
r_off = 500
finite_state_values = np.linspace(1, 10, 10)
# DataFrame.append was removed in pandas 2.0 (and copied the whole frame on
# every call), so rows are collected first and the frame is built once.
rows = []
for finite_states in finite_state_values:
    rows.append({'finite_states': finite_states,
                 'test_set_accuracy': trial(r_on, r_off, finite_states)})

df = pd.DataFrame(rows, columns=['finite_states', 'test_set_accuracy'])
df.to_csv('finite_states.csv', index=False)

In [None]:
# Bar chart of test-set accuracy per number of finite states, saved as P2.svg.
df = pd.read_csv('finite_states.csv')
f = plt.figure()
plt.axhline(y=10, color='k', linestyle='--', zorder=1)
state_counts = df['finite_states']
b = plt.bar(state_counts, df['test_set_accuracy'], width=0.5, zorder=2)
plt.xticks(state_counts)  # one tick per swept state count
plt.xlabel('Number of Finite States')
plt.ylabel('CIFAR-10 Test-set Accuracy (%)')
for patch in b:
    patch.set_facecolor(palette[0])
    patch.set_edgecolor('black')
    patch.set_linewidth(1)

f.tight_layout()
plt.grid()
plt.savefig("P2.svg")
plt.show()

## 6. Device failure investigation

In [None]:
from memtorch.mn.Module import patch_model
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities


def trial(r_on, r_off, lrs_proportion, hrs_proportion):
    """Patch a copy of the trained model, inject device faults, and test it.

    Args:
        r_on: ``r_on`` parameter of the VTEAM reference device.
        r_off: ``r_off`` parameter of the VTEAM reference device.
        lrs_proportion: Forwarded as ``lrs_proportion`` to the DeviceFaults
            non-ideality.
        hrs_proportion: Forwarded as ``hrs_proportion`` to the DeviceFaults
            non-ideality.

    Returns:
        Accuracy (percent) of the patched and tuned model, as computed by
        ``test`` on the module-level ``test_loader``.

    Note:
        Reads the module-level ``model``; it is deep-copied before patching.
        ``electroform_proportion`` is always 0.
    """
    reference_memristor = memtorch.bh.memristor.VTEAM
    reference_memristor_params = {
        'time_series_resolution': 1e-10,
        'r_off': r_off,
        'r_on': r_on,
    }

    # One deep copy is enough to keep the shared `model` untouched; the
    # previous code copied the whole network twice per trial.
    patched_model = patch_model(copy.deepcopy(model),
                                memristor_model=reference_memristor,
                                memristor_model_params=reference_memristor_params,
                                module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                                mapping_routine=naive_map,
                                transistor=True,
                                programming_routine=None)

    patched_model = apply_nonidealities(patched_model,
                                        non_idealities=[memtorch.bh.nonideality.NonIdeality.DeviceFaults],
                                        lrs_proportion=lrs_proportion,
                                        hrs_proportion=hrs_proportion,
                                        electroform_proportion=0)

    patched_model.tune_()
    return test(patched_model, test_loader)

# Sweep the device-failure proportion for three cases: both states failing,
# only the LRS proportion set, and only the HRS proportion set.
r_on = 200
r_off = 500
failures = np.linspace(0, 0.25, 11)

# DataFrame.append was removed in pandas 2.0 (and copied the whole frame on
# every call), so rows are collected first and the frames are built once.
rows_lrs_hrs = []
rows_lrs = []
rows_hrs = []
for failure in failures:
    # The three trials run in this order for every failure value.
    rows_lrs_hrs.append({'failure_percentage': failure,
                         'test_set_accuracy': trial(r_on, r_off, failure, failure)})
    rows_lrs.append({'failure_percentage': failure,
                     'test_set_accuracy': trial(r_on, r_off, failure, 0)})
    rows_hrs.append({'failure_percentage': failure,
                     'test_set_accuracy': trial(r_on, r_off, 0, failure)})

failure_columns = ['failure_percentage', 'test_set_accuracy']
df_lrs_hrs = pd.DataFrame(rows_lrs_hrs, columns=failure_columns)
df_lrs = pd.DataFrame(rows_lrs, columns=failure_columns)
df_hrs = pd.DataFrame(rows_hrs, columns=failure_columns)

df_lrs_hrs.to_csv('failure_lrs_hrs.csv', index=False)
df_lrs.to_csv('failure_lrs.csv', index=False)
df_hrs.to_csv('failure_hrs.csv', index=False)

In [None]:
import seaborn as sns
from matplotlib.ticker import FormatStrFormatter


# Grouped bar chart of accuracy against device-failure percentage, saved as P3.svg.
df_lrs_hrs = pd.read_csv('failure_lrs_hrs.csv')
df_lrs = pd.read_csv('failure_lrs.csv')
df_hrs = pd.read_csv('failure_hrs.csv')
f = plt.figure()
plt.axhline(y=10, color='k', linestyle='--', zorder=1)
concat = pd.concat([df_lrs_hrs['failure_percentage'],
                    df_lrs_hrs['test_set_accuracy'],
                    df_lrs['test_set_accuracy'],
                    df_hrs['test_set_accuracy']],
                   axis=1)
concat.columns = ['failure_percentage', 'lrs_hrs', 'lrs', 'hrs']

# Reshape to one row per (failure percentage, state) for seaborn's hue grouping.
# DataFrame.append was removed in pandas 2.0, so rows are collected in a list
# and the frame is built once.
rows = []
for index, row in concat.iterrows():
    for state in ('lrs_hrs', 'lrs', 'hrs'):
        rows.append({'failure_percentage': row['failure_percentage'] * 100,
                     'state': state,
                     'test_set_accuracy': row[state]})

data = pd.DataFrame(rows, columns=['failure_percentage', 'state', 'test_set_accuracy'])
data['state'] = data['state'].map({'lrs_hrs': '$R_{ON}$ and $R_{OFF}$', 'lrs': '$R_{ON}$', 'hrs': '$R_{OFF}$'})
# linewidth is numeric (it was previously passed as the string '1').
h = sns.barplot(x="failure_percentage", hue="state", y="test_set_accuracy", data=data, zorder=2, edgecolor='black', linewidth=1, palette=sns.color_palette(palette), saturation=1)
plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%1.1f'))
plt.xticks(np.arange(11), np.arange(0, 25 + 2.5, step=2.5))
h.legend(loc=1)
plt.xlabel('Device Failure (%)')
plt.ylabel('CIFAR-10 Test-set Accuracy (%)')
f.tight_layout()
plt.grid()
plt.savefig("P3.svg")
plt.show()

## 7. First novel simulation

In [None]:
from memtorch.mn.Module import patch_model
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities


def trial(num_conductance_states, g_ratio, sigma):
    """Test a patched copy of the model with variability and finite conductance states.

    ``r_on`` is fixed at 200 and ``r_off`` is ``r_on * g_ratio``; both are
    wrapped in ``StochasticParameter``.

    Args:
        num_conductance_states: Number of conductance states; converted with
            ``int`` before being passed to ``apply_nonidealities``.
        g_ratio: Multiplier applied to ``r_on`` to obtain the nominal ``r_off``.
        sigma: ``std`` used for ``r_on``; ``r_off`` uses ``2 * sigma``.

    Returns:
        Accuracy (percent) of the patched and tuned model, as computed by
        ``test`` on the module-level ``test_loader``.

    Note:
        Reads the module-level ``model``; it is deep-copied before patching.
    """
    r_on = 200
    reference_memristor = memtorch.bh.memristor.VTEAM
    reference_memristor_params = {
        'time_series_resolution': 1e-10,
        'r_off': memtorch.bh.StochasticParameter(r_on * g_ratio, std=sigma*2, min=1),
        'r_on': memtorch.bh.StochasticParameter(r_on, std=sigma, min=1),
    }
    # One deep copy is enough to keep the shared `model` untouched; the
    # previous code copied the whole network twice per trial.
    patched_model = patch_model(copy.deepcopy(model),
                                memristor_model=reference_memristor,
                                memristor_model_params=reference_memristor_params,
                                module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                                mapping_routine=naive_map,
                                transistor=True,
                                programming_routine=None)
    patched_model = apply_nonidealities(patched_model,
                                        non_idealities=[memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates],
                                        conductance_states=int(num_conductance_states))
    patched_model.tune_()
    return test(patched_model, test_loader)

# For each sigma, sweep r_off/r_on ratio x number of conductance states and
# write one CSV per sigma.
std_devs = [0, 20, 100]
g_ratios = [2 ** n for n in range(6)]
conductance_states = np.linspace(2, 10, 9)
for std_dev in std_devs:
    # DataFrame.append was removed in pandas 2.0 (and copied the whole frame
    # on every call), so rows are collected first and the frame is built once.
    rows = []
    for g_ratio in g_ratios:
        for num_conductance_states in conductance_states:
            test_set_accuracy = trial(num_conductance_states, g_ratio, std_dev)
            rows.append({'conductance_states': num_conductance_states,
                         'g_ratio': g_ratio,
                         'test_set_accuracy': test_set_accuracy})

    df = pd.DataFrame(rows, columns=['conductance_states', 'g_ratio', 'test_set_accuracy'])
    df.to_csv('S1_std_dev_%d.csv' % std_dev, index=False)

In [None]:
import seaborn as sns
from matplotlib.ticker import FormatStrFormatter

# One panel per sigma: accuracy against number of finite states, grouped by
# g_ratio; saved as S1.svg. Uses `std_devs` and `conductance_states` from the
# sweep cell.
f = plt.figure(figsize=(16, 4))
# Build the tick labels once as a list instead of handing matplotlib a
# one-shot map iterator on every panel.
state_labels = ["%d" % n for n in conductance_states]
for plot_index, std_dev in enumerate(std_devs):
    plt.subplot(1, len(std_devs), plot_index + 1)
    plt.axhline(y=10, color='k', linestyle='--', zorder=1)
    data = pd.read_csv('S1_std_dev_%d.csv' % std_dev)
    h = sns.barplot(x="conductance_states", hue="g_ratio", y="test_set_accuracy", data=data, zorder=2, edgecolor='black', linewidth=1, palette=sns.color_palette(palette), saturation=1)
    plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%1.0f'))
    plt.xticks(np.arange(0, len(conductance_states)), state_labels)
    # Axes.legend() creates a fresh legend, so the title has to be passed
    # here; setting it on the earlier legend and then calling legend() again
    # threw the title away.
    h.legend(loc=1, title='RON/ROFF Ratio')
    # Raw string: '\s' is an invalid escape sequence in an ordinary literal.
    plt.title(r'$\sigma$ = %d' % std_dev)
    if plot_index == 0:
        plt.xlabel('Number of Finite States')
        plt.ylabel('CIFAR-10 Test-set Accuracy (%)')
    else:
        plt.xlabel('')
        plt.ylabel('')

    plt.grid()

f.tight_layout()
plt.savefig("S1.svg")
plt.show()

## 8. Second novel simulation

In [None]:
from memtorch.mn.Module import patch_model
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities


def trial(num_conductance_states, lrs_failure_rate, hrs_failure_rate, sigma):
    """Test a patched copy of the model with variability, finite states and device faults.

    ``r_on`` is fixed at 200 and ``r_off`` at 500; both are wrapped in
    ``StochasticParameter``.

    Args:
        num_conductance_states: Number of conductance states; converted with
            ``int`` before being passed to ``apply_nonidealities``.
        lrs_failure_rate: Forwarded as ``lrs_proportion`` to DeviceFaults.
        hrs_failure_rate: Forwarded as ``hrs_proportion`` to DeviceFaults.
        sigma: ``std`` used for ``r_on``; ``r_off`` uses ``2 * sigma``.

    Returns:
        Accuracy (percent) of the patched and tuned model, as computed by
        ``test`` on the module-level ``test_loader``.

    Note:
        Reads the module-level ``model``; it is deep-copied before patching.
        ``electroform_proportion`` is always 0.
    """
    r_on = 200
    r_off = 500
    reference_memristor = memtorch.bh.memristor.VTEAM
    # Fix: 'r_off' used to be built from `r_on * g_ratio`, where `g_ratio` is
    # not a parameter of this function (it only resolved to a global left
    # over from an earlier cell), and 'r_on' was built from `r_off`. Each
    # parameter now uses its own local constant.
    reference_memristor_params = {
        'time_series_resolution': 1e-10,
        'r_off': memtorch.bh.StochasticParameter(r_off, std=sigma*2, min=1),
        'r_on': memtorch.bh.StochasticParameter(r_on, std=sigma, min=1),
    }
    # One deep copy is enough to keep the shared `model` untouched; the
    # previous code copied the whole network twice per trial.
    patched_model = patch_model(copy.deepcopy(model),
                                memristor_model=reference_memristor,
                                memristor_model_params=reference_memristor_params,
                                module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                                mapping_routine=naive_map,
                                transistor=True,
                                programming_routine=None)
    patched_model = apply_nonidealities(patched_model,
                                        non_idealities=[memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates,
                                                        memtorch.bh.nonideality.NonIdeality.DeviceFaults],
                                        conductance_states=int(num_conductance_states),
                                        lrs_proportion=lrs_failure_rate,
                                        hrs_proportion=hrs_failure_rate,
                                        electroform_proportion=0)
    patched_model.tune_()
    return test(patched_model, test_loader)

# For each failure mode and sigma, sweep failure rate x number of conductance
# states and write one CSV per (mode, sigma).
std_devs = [0, 20, 100]
failure_rates = np.linspace(0, 0.25, 6)
conductance_states = np.linspace(2, 10, 9)

# (file tag, apply rate to LRS, apply rate to HRS); run in this order so the
# sequence of trials matches the three separate loops this replaces.
failure_modes = [
    ('LRS', True, False),
    ('HRS', False, True),
    ('LRS_HRS', True, True),
]

for file_tag, lrs_fails, hrs_fails in failure_modes:
    for std_dev in std_devs:
        # DataFrame.append was removed in pandas 2.0 (and copied the whole
        # frame on every call), so rows are collected and the frame built once.
        rows = []
        for failure_rate in failure_rates:
            lrs_rate = failure_rate if lrs_fails else 0
            hrs_rate = failure_rate if hrs_fails else 0
            for num_conductance_states in conductance_states:
                test_set_accuracy = trial(num_conductance_states, lrs_rate, hrs_rate, std_dev)
                rows.append({'conductance_states': num_conductance_states,
                             'failure_rate': failure_rate,
                             'test_set_accuracy': test_set_accuracy})

        df = pd.DataFrame(rows, columns=['conductance_states', 'failure_rate', 'test_set_accuracy'])
        df.to_csv('S2_%s_std_dev_%d.csv' % (file_tag, std_dev), index=False)

In [None]:
import seaborn as sns
from matplotlib.ticker import FormatStrFormatter

# 3 x len(std_devs) grid: one row per failure mode (LRS, HRS, both), one column
# per sigma; saved as S2.svg. Uses `std_devs` and `conductance_states` from the
# sweep cell.
f = plt.figure(figsize=(16, 12))
# Build the tick labels once as a list instead of handing matplotlib a
# one-shot map iterator on every panel.
state_labels = ["%d" % n for n in conductance_states]
# (tag used in the CSV file name, tag shown in the y-axis label), top to bottom.
failure_rows = [
    ('LRS', 'LRS'),
    ('HRS', 'HRS'),
    ('LRS_HRS', 'LRS and HRS'),
]

for row_index, (file_tag, label_tag) in enumerate(failure_rows):
    for plot_index, std_dev in enumerate(std_devs):
        plt.subplot(3, len(std_devs), plot_index + 1 + row_index * len(std_devs))
        plt.axhline(y=10, color='k', linestyle='--', zorder=1)
        data = pd.read_csv('S2_%s_std_dev_%d.csv' % (file_tag, std_dev))
        h = sns.barplot(x="conductance_states", hue="failure_rate", y="test_set_accuracy", data=data, zorder=2, edgecolor='black', linewidth=1, palette=sns.color_palette(palette), saturation=1)
        plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%1.0f'))
        plt.xticks(np.arange(0, len(conductance_states)), state_labels)
        # Axes.legend() creates a fresh legend, so the title has to be passed
        # here; setting it on the earlier legend and then calling legend()
        # again threw the title away.
        # NOTE(review): the CSV stores failure_rate as a proportion (0-0.25)
        # while the title says "%" - confirm the intended unit.
        h.legend(loc=1, title='Device Failure (%)')
        # Raw string: '\s' is an invalid escape sequence in an ordinary literal.
        plt.title(r'$\sigma$ = %d' % std_dev)
        if plot_index == 0:
            plt.xlabel('Number of Finite States')
            plt.ylabel('[%s] CIFAR-10 Test-set Accuracy (%%)' % label_tag)
        else:
            plt.xlabel('')
            plt.ylabel('')

        plt.grid()

f.tight_layout()
plt.savefig("S2.svg")
plt.show()