## Required imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [None]:
!pip install memtorch

## Define and train the CNN

In [None]:
import torch
from torch.autograd import Variable
import memtorch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from memtorch.utils import LoadCIFAR10
import numpy as np


class Net(nn.Module):
    """VGG-style CIFAR-10 classifier: six conv/BN/ReLU stages, then three FC layers.

    Every convolution is 3x3 with padding 1, so only the 2x2 max-pools that
    follow conv1, conv3 and conv5 shrink the feature maps. For 32x32 inputs
    this leaves a 512 x 4 x 4 map, which is the input size of ``fc6``.

    Args:
        inflation_ratio: width multiplier for conv0-conv4 (and the input width
            of conv5). conv5 always emits 512 channels, so the classifier head
            does not depend on this ratio.
    """

    def __init__(self, inflation_ratio=1):
        super().__init__()
        # Channel count entering stage k is widths[k]; leaving it is widths[k + 1].
        widths = [3] + [base * inflation_ratio for base in (128, 128, 256, 256, 512)] + [512]
        # Registered as conv{k}, bn{k}, act{k} in stage order so module names,
        # state_dict keys and parameter-initialisation order match the checkpoint.
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            setattr(self, 'conv%d' % stage,
                    nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1))
            setattr(self, 'bn%d' % stage, nn.BatchNorm2d(num_features=c_out))
            setattr(self, 'act%d' % stage, nn.ReLU())

        # Classifier head: flattened 512*4*4 features -> 1024 -> 1024 -> 10 classes.
        self.fc6 = nn.Linear(in_features=512 * 4 * 4, out_features=1024)
        self.bn6 = nn.BatchNorm1d(num_features=1024)
        self.act6 = nn.ReLU()
        self.fc7 = nn.Linear(in_features=1024, out_features=1024)
        self.bn7 = nn.BatchNorm1d(num_features=1024)
        self.act7 = nn.ReLU()
        self.fc8 = nn.Linear(in_features=1024, out_features=10)

    def forward(self, input):
        features = input
        for stage in range(6):
            features = getattr(self, 'conv%d' % stage)(features)
            # Odd stages (1, 3, 5) halve the spatial size before normalisation.
            if stage % 2 == 1:
                features = F.max_pool2d(features, 2)
            features = getattr(self, 'act%d' % stage)(getattr(self, 'bn%d' % stage)(features))
        flat = features.flatten(start_dim=1)
        hidden = self.act6(self.bn6(self.fc6(flat)))
        hidden = self.act7(self.bn7(self.fc7(hidden)))
        return self.fc8(hidden)

def test(model, test_loader, device=None):
    """Return the top-1 accuracy (percent) of ``model`` over ``test_loader``.

    The model is put in evaluation mode, so BatchNorm layers use (and do not
    update) their running statistics. Forward passes run under
    ``torch.no_grad()``, so no autograd graph is kept. The model is left in
    evaluation mode; call ``model.train()`` before resuming training.

    Args:
        model: network to evaluate.
        test_loader: iterable of ``(data, target)`` batches whose ``.dataset``
            length is the number of evaluated samples.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (or buffer), or the CPU if it has neither.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``test_loader.dataset`` is empty.
    """
    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        raise ValueError('test_loader.dataset is empty; accuracy is undefined.')

    if device is None:
        first = next(model.parameters(), None)
        if first is None:
            first = next(model.buffers(), None)
        device = first.device if first is not None else torch.device('cpu')

    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(device))
            pred = output.argmax(dim=1)
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()

    return 100. * float(correct) / float(num_samples)

device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
epochs = 60
train_loader, validation_loader, test_loader = LoadCIFAR10(batch_size=512, validation=False)
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_accuracy = 0
for epoch in range(epochs):
    print('Epoch: [%d]\t\t' % (epoch + 1), end='')
    # Step decay: divide the learning rate by 10 every 20 epochs. NOTE: the
    # condition is also true at epoch 0, so training already starts at
    # learning_rate * 0.1 (1e-3), then 1e-4 from epoch 20 and 1e-5 from epoch 40.
    if epoch % 20 == 0:
        learning_rate = learning_rate * 0.1
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()

    # Evaluate with BatchNorm in inference mode (running statistics, not
    # updated from test batches) and without recording autograd graphs.
    model.eval()
    with torch.no_grad():
        accuracy = test(model, test_loader)
    print('%2.2f%%' % accuracy)
    # NOTE(review): the checkpoint is selected on the test set (validation=False),
    # so best_accuracy is an optimistic estimate.
    if accuracy > best_accuracy:
        print('Saving model...')
        torch.save(model.state_dict(), 'trained_model.pt')
        best_accuracy = accuracy

## Validate the CNN

In [None]:
import torch
import memtorch
from memtorch.utils import LoadCIFAR10
import numpy as np


def test(model, test_loader, device=None):
    """Return the top-1 accuracy (percent) of ``model`` over ``test_loader``.

    The model is put in evaluation mode, so BatchNorm layers use (and do not
    update) their running statistics. Forward passes run under
    ``torch.no_grad()``, so no autograd graph is kept. The model is left in
    evaluation mode; call ``model.train()`` before resuming training.

    Args:
        model: network to evaluate.
        test_loader: iterable of ``(data, target)`` batches whose ``.dataset``
            length is the number of evaluated samples.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (or buffer), or the CPU if it has neither.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``test_loader.dataset`` is empty.
    """
    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        raise ValueError('test_loader.dataset is empty; accuracy is undefined.')

    if device is None:
        first = next(model.parameters(), None)
        if first is None:
            first = next(model.buffers(), None)
        device = first.device if first is not None else torch.device('cpu')

    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(device))
            pred = output.argmax(dim=1)
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()

    return 100. * float(correct) / float(num_samples)

device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
train_loader, validation_loader, test_loader = LoadCIFAR10(batch_size=264, validation=False)
model = Net().to(device)
try:
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only build.
    state_dict = torch.load('trained_model.pt', map_location=device)
except FileNotFoundError as err:
    # Only a missing file gets this message; other load errors keep their own type.
    raise FileNotFoundError('trained_model.pt has not been found.') from err
model.load_state_dict(state_dict, strict=False)

# Inference mode: BatchNorm uses its trained running statistics.
model.eval()
with torch.no_grad():
    print('Test Set Accuracy: \t%2.2f%%' % test(model, test_loader))

## Device endurance (gradual) simulations

In [None]:
import enum
from enum import Enum, auto
from memtorch.mn.Module import supported_module_parameters
import math
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from memtorch.mn.Module import patch_model
from memtorch.map.Module import naive_tune
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities
from memtorch.bh.crossbar.Crossbar import init_crossbar
import copy
from pprint import pprint


def minimal_tune(model):
    """Attach a ``naive_tune`` output correction to every tunable memristive layer.

    Every module exposing a ``tune`` attribute first gets an identity
    ``transform_output``. For ``memtorch.mn.Conv2d`` layers this is replaced by
    ``naive_tune`` with probe shape ``(4, in_channels, 8, 8)``. For
    ``memtorch.mn.Linear`` layers it uses probe shape ``(64, in_features)``.

    Tuning is best-effort. If ``naive_tune`` raises for a layer, the identity
    transform is kept and a warning naming the layer is emitted. The failure
    is no longer silently swallowed.

    Returns the same ``model`` object, modified in place.
    """
    import warnings

    for name, m in list(model.named_modules()):
        if not hasattr(m, 'tune'):
            continue
        m.transform_output = lambda input: input
        if isinstance(m, memtorch.mn.Conv2d):
            probe_shape = (4, m.in_channels, 8, 8)
        elif isinstance(m, memtorch.mn.Linear):
            probe_shape = (64, m.in_features)
        else:
            continue
        try:
            m.transform_output = naive_tune(m, probe_shape)
        except Exception as exc:
            warnings.warn('naive_tune failed for layer %r (%s); keeping identity transform_output.'
                          % (name, exc))

    return model
    
def update_patched_model(patched_model, model):
    """Re-program each memristive layer of ``patched_model`` from ``model``'s weights.

    For every ``memtorch.mn.Conv2d``/``memtorch.mn.Linear`` module, the weights
    of the same-named module in ``model`` are mapped with ``naive_map``. The
    mapping uses the double-column scheme and the ``r_on``/``r_off`` globals,
    and yields a positive and a negative conductance matrix. These are written
    to ``crossbars[0]`` and ``crossbars[1]``, and the layer's ``weight.data``
    is then set to the source weights.

    Note: ``getattr(model, name)`` only resolves top-level module names.
    Returns ``patched_model``.
    """
    memristive_types = (memtorch.mn.Conv2d, memtorch.mn.Linear)
    for name, layer in list(patched_model.named_modules()):
        if not isinstance(layer, memristive_types):
            continue
        source_weight = getattr(model, name).weight.data
        pos_matrix, neg_matrix = naive_map(source_weight, r_on, r_off,
                                           scheme=memtorch.bh.Scheme.DoubleColumn)
        layer.crossbars[0].write_conductance_matrix(pos_matrix, transistor=True, programming_routine=None)
        layer.crossbars[1].write_conductance_matrix(neg_matrix, transistor=True, programming_routine=None)
        layer.weight.data = source_weight

    return patched_model
    
from scipy.interpolate import interp1d

# Linear map of the supported v_stop range [1.3, 1.9] onto [0, 1]. interp1d's
# default bounds_error=True makes values outside the range raise ValueError.
scale_input = interp1d([1.3, 1.9], [0, 1])


def scale_p_0(p_0, p_1, v_stop, cell_size=10, v_reference=1.50):
    """Rescale ``p_0`` of the threshold model ``y = p_0 * exp(p_1 * cell_size)`` for ``v_stop``.

    ``log10(y)`` is modelled as an inverted parabola in the normalised
    ``v_stop``. It is 0 (y = 1) at both ends of the range and peaks at the
    centre. It is calibrated so that ``y`` is unchanged when
    ``v_stop == v_reference``. The new ``y`` is then back-solved for ``p_0``.

    Args:
        p_0, p_1: fitted parameters of the exponential threshold model; the
            product ``p_0 * exp(p_1 * cell_size)`` must be positive.
        v_stop: value to scale for; must lie in [1.3, 1.9].
        cell_size: cell size used in the threshold model.
        v_reference: ``v_stop`` at which the fitted parameters apply; must lie
            strictly inside the range.

    Returns:
        The rescaled ``p_0`` as a float.
    """
    scaled_input = scale_input(v_stop)
    reference_input = scale_input(v_reference)
    y = p_0 * math.exp(p_1 * cell_size)
    k = math.log10(y) / (1 - (2 * reference_input - 1) ** 2)
    new_y = 10 ** (k * (1 - (2 * scaled_input - 1) ** 2))
    # Backsolve for p_0 (interp1d yields 0-d arrays; return a plain float).
    return float(new_y / math.exp(p_1 * cell_size))

def gradual(input, cycle_count, p_1, p_2, p_3, cell_size):
    """Apply the gradual endurance power law to the resistances in ``input``.

    Returns ``input * (cycle_count / threshold) ** (p_3 * cell_size)`` with
    ``threshold = p_1 * exp(p_2 * cell_size)``. This equals
    ``10 ** (p_3*cell_size*log10(cycle_count/threshold) + log10(input))``, but
    avoids the log10/pow round trips and their float32 rounding.
    Resistances are assumed positive.
    """
    threshold = p_1 * math.exp(p_2 * cell_size)
    # Reduce to a plain Python float so the tensor keeps its own dtype.
    factor = (float(cycle_count) / float(threshold)) ** (p_3 * cell_size)
    return input * factor
       
def model_gradual(layer, cycle_count, v_stop):
    """Apply gradual endurance degradation to every crossbar of a patched layer.

    Conductances are converted to resistances and split at
    ``convergence_point`` into a low-resistance (LRS) group and a
    high-resistance (HRS) group. Each group has its own fitted parameters;
    the pre-factors are rescaled for ``v_stop`` with ``scale_p_0``. A group is
    degraded with ``gradual`` once ``cycle_count`` exceeds its threshold.
    Cells exactly at the convergence point are left unchanged.

    Both masks are computed before any update. An LRS cell whose resistance
    grows past the convergence point is therefore not degraded a second time
    with the HRS parameters.

    Each crossbar's ``conductance_matrix`` is reassigned; returns ``layer``.
    """
    cell_size = 10
    convergence_point = 1e4
    p_1_lrs = 1.0399076623425807
    p_2_lrs = 0.9171208448973687
    p_3_lrs = 0.0143551595777695
    p_1_hrs = 4.3590883730463410
    p_2_hrs = 0.7738077425228179
    p_3_hrs = -0.018865423084966
    p_1_lrs = scale_p_0(p_1_lrs, p_2_lrs, v_stop)
    p_1_hrs = scale_p_0(p_1_hrs, p_2_hrs, v_stop)
    threshold_lrs = p_1_lrs * math.exp(p_2_lrs * cell_size)
    threshold_hrs = p_1_hrs * math.exp(p_2_hrs * cell_size)
    for crossbar in layer.crossbars:
        resistance = 1 / crossbar.conductance_matrix
        # Classify cells once, on the pre-degradation resistances.
        lrs = resistance < convergence_point
        hrs = resistance > convergence_point
        if cycle_count > threshold_lrs and lrs.any():
            resistance[lrs] = gradual(resistance[lrs], cycle_count, p_1_lrs, p_2_lrs, p_3_lrs, cell_size)
        if cycle_count > threshold_hrs and hrs.any():
            resistance[hrs] = gradual(resistance[hrs], cycle_count, p_1_hrs, p_2_hrs, p_3_hrs, cell_size)

        crossbar.conductance_matrix = 1 / resistance

    return layer

def model_degradation(model, cycle_count, v_stop):
    """Apply ``model_gradual`` to every supported memtorch layer of ``model``.

    Layers are matched by exact type against ``supported_module_parameters``.
    Each match is degraded in place and re-assigned on ``model`` under its
    ``named_modules`` name. Returns ``model``.
    """
    supported_types = supported_module_parameters.values()
    targets = [(name, module) for name, module in model.named_modules()
               if type(module) in supported_types]
    for name, module in targets:
        setattr(model, name, model_gradual(module, cycle_count, v_stop))

    return model

# Same device rule as the training/validation cells (CPU-only memtorch builds use the CPU).
device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
batch_size = 64
train_loader, validation_loader, test_loader = LoadCIFAR10(batch_size=batch_size, validation=False)
reference_memristor = memtorch.bh.memristor.VTEAM
r_on = 4400
r_off = 65000
reference_memristor_params = {'time_series_resolution': 1e-10,
                              'r_off': r_off,
                              'r_on': r_on}
times_to_reprogram = 10 ** np.arange(1, 10, dtype=np.float64)
v_stop_values = np.linspace(1.3, 1.9, 10, endpoint=True)
columns = ['times_reprogramed', 'v_stop', 'test_set_accuracy']
df = pd.DataFrame(columns=columns)
# DataFrame.append was removed in pandas 2.0: collect rows and rebuild the frame.
results = []
for time_to_reprogram in times_to_reprogram:
    cycle_count = len(train_loader.dataset) * time_to_reprogram
    for v_stop in v_stop_values:
        print('time_to_reprogram: %f, v_stop: %f' % (time_to_reprogram, v_stop))
        model = Net().to(device)
        model.load_state_dict(torch.load('trained_model.pt', map_location=device), strict=False)
        patched_model = patch_model(model,
                                  memristor_model=reference_memristor,
                                  memristor_model_params=reference_memristor_params,
                                  module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                                  mapping_routine=naive_map,
                                  transistor=True,
                                  programming_routine=None,
                                  p_l=None,
                                  scheme=memtorch.bh.Scheme.DoubleColumn)
        patched_model = model_degradation(patched_model, cycle_count, v_stop)
        patched_model = minimal_tune(patched_model)
        # Inference mode so the (unpatched) BatchNorm layers use running statistics.
        patched_model.eval()
        with torch.no_grad():
            accuracy = test(patched_model, test_loader)
        del patched_model
        del model
        results.append({'times_reprogramed': time_to_reprogram, 'v_stop': v_stop, 'test_set_accuracy': accuracy})
        # Rewrite the CSV after every run so partial results survive an interruption.
        df = pd.DataFrame(results, columns=columns)
        df.to_csv('endurance_gradual.csv', index=False)

## Device endurance (sudden) simulations

In [None]:
import enum
from enum import Enum, auto
from memtorch.mn.Module import supported_module_parameters
import math
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from memtorch.mn.Module import patch_model
from memtorch.map.Module import naive_tune
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities
from memtorch.bh.crossbar.Crossbar import init_crossbar
import copy

def minimal_tune(model):
    """Attach a ``naive_tune`` output correction to every tunable memristive layer.

    Every module exposing a ``tune`` attribute first gets an identity
    ``transform_output``. For ``memtorch.mn.Conv2d`` layers this is replaced by
    ``naive_tune`` with probe shape ``(4, in_channels, 8, 8)``. For
    ``memtorch.mn.Linear`` layers it uses probe shape ``(64, in_features)``.

    Tuning is best-effort. If ``naive_tune`` raises for a layer, the identity
    transform is kept and a warning naming the layer is emitted. The failure
    is no longer silently swallowed.

    Returns the same ``model`` object, modified in place.
    """
    import warnings

    for name, m in list(model.named_modules()):
        if not hasattr(m, 'tune'):
            continue
        m.transform_output = lambda input: input
        if isinstance(m, memtorch.mn.Conv2d):
            probe_shape = (4, m.in_channels, 8, 8)
        elif isinstance(m, memtorch.mn.Linear):
            probe_shape = (64, m.in_features)
        else:
            continue
        try:
            m.transform_output = naive_tune(m, probe_shape)
        except Exception as exc:
            warnings.warn('naive_tune failed for layer %r (%s); keeping identity transform_output.'
                          % (name, exc))

    return model
    
def update_patched_model(patched_model, model):
    """Re-program each memristive layer of ``patched_model`` from ``model``'s weights.

    For every ``memtorch.mn.Conv2d``/``memtorch.mn.Linear`` module, the weights
    of the same-named module in ``model`` are mapped with ``naive_map``. The
    mapping uses the double-column scheme and the ``r_on``/``r_off`` globals,
    and yields a positive and a negative conductance matrix. These are written
    to ``crossbars[0]`` and ``crossbars[1]``, and the layer's ``weight.data``
    is then set to the source weights.

    Note: ``getattr(model, name)`` only resolves top-level module names.
    Returns ``patched_model``.
    """
    memristive_types = (memtorch.mn.Conv2d, memtorch.mn.Linear)
    for name, layer in list(patched_model.named_modules()):
        if not isinstance(layer, memristive_types):
            continue
        source_weight = getattr(model, name).weight.data
        pos_matrix, neg_matrix = naive_map(source_weight, r_on, r_off,
                                           scheme=memtorch.bh.Scheme.DoubleColumn)
        layer.crossbars[0].write_conductance_matrix(pos_matrix, transistor=True, programming_routine=None)
        layer.crossbars[1].write_conductance_matrix(neg_matrix, transistor=True, programming_routine=None)
        layer.weight.data = source_weight

    return patched_model

from scipy.interpolate import interp1d

# Linear map of the supported v_stop range [1.3, 1.9] onto [0, 1]. interp1d's
# default bounds_error=True makes values outside the range raise ValueError.
scale_input = interp1d([1.3, 1.9], [0, 1])


def scale_p_0(p_0, p_1, v_stop, cell_size=10, v_reference=1.45):
    """Rescale ``p_0`` of the linear threshold model ``y = p_0 * cell_size + p_1`` for ``v_stop``.

    ``log10(y)`` is modelled as an inverted parabola in the normalised
    ``v_stop``. It is 0 (y = 1) at both ends of the range and peaks at the
    centre. It is calibrated so that ``y`` is unchanged when
    ``v_stop == v_reference``. The new ``y`` is then back-solved for ``p_0``.

    Args:
        p_0, p_1: fitted parameters of the linear threshold model; ``y`` must
            be positive.
        v_stop: value to scale for; must lie in [1.3, 1.9].
        cell_size: cell size used in the threshold model.
        v_reference: ``v_stop`` at which the fitted parameters apply; must lie
            strictly inside the range.

    Returns:
        The rescaled ``p_0`` as a float.
    """
    scaled_input = scale_input(v_stop)
    reference_input = scale_input(v_reference)
    y = p_0 * cell_size + p_1
    k = math.log10(y) / (1 - (2 * reference_input - 1) ** 2)
    new_y = 10 ** (k * (1 - (2 * scaled_input - 1) ** 2))
    # Backsolve the linear model for p_0 (interp1d yields 0-d arrays; return a float).
    return float((new_y - p_1) / cell_size)
       
def model_sudden(layer, cycle_count, v_stop):
    """Apply sudden endurance failure to every crossbar of a patched layer.

    The failure threshold is ``p_0 * cell_size``. ``p_0`` starts at
    ``2e7 / cell_size`` and is rescaled for ``v_stop`` by ``scale_p_0`` with
    ``p_1 = 0``. Once ``cycle_count`` exceeds the threshold, every conductance
    below ``1 / 2e4`` (resistance above 2e4) is raised to ``1 / 2e4``.
    Crossbars are modified in place and ``layer`` is returned.
    """
    cell_size = 10
    threshold = scale_p_0(2e7 / cell_size, 0, v_stop) * cell_size
    if not cycle_count > threshold:
        return layer

    floor = 1 / 2e4
    for crossbar in layer.crossbars:
        conductance = crossbar.conductance_matrix
        conductance[conductance < floor] = floor
        crossbar.conductance_matrix = conductance

    return layer

def model_degradation(model, cycle_count, v_stop):
    """Apply ``model_sudden`` to every supported memtorch layer of ``model``.

    Layers are matched by exact type against ``supported_module_parameters``.
    Each match is degraded in place and re-assigned on ``model`` under its
    ``named_modules`` name. Returns ``model``.
    """
    supported_types = supported_module_parameters.values()
    targets = [(name, module) for name, module in model.named_modules()
               if type(module) in supported_types]
    for name, module in targets:
        setattr(model, name, model_sudden(module, cycle_count, v_stop))

    return model

# Same device rule as the training/validation cells (CPU-only memtorch builds use the CPU).
device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
batch_size = 64
train_loader, validation_loader, test_loader = LoadCIFAR10(batch_size=batch_size, validation=False)
reference_memristor = memtorch.bh.memristor.VTEAM
r_on = 2.00e4
r_off = 10.75e4
reference_memristor_params = {'time_series_resolution': 1e-10,
                              'r_off': r_off,
                              'r_on': r_on}
times_to_reprogram = 10 ** np.arange(1, 10, dtype=np.float64)
v_stop_values = np.linspace(1.3, 1.9, 10, endpoint=True)
columns = ['times_reprogramed', 'v_stop', 'test_set_accuracy']
df = pd.DataFrame(columns=columns)
# DataFrame.append was removed in pandas 2.0: collect rows and rebuild the frame.
results = []
for time_to_reprogram in times_to_reprogram:
    cycle_count = len(train_loader.dataset) * time_to_reprogram
    for v_stop in v_stop_values:
        print('time_to_reprogram: %f, v_stop: %f' % (time_to_reprogram, v_stop))
        model = Net().to(device)
        model.load_state_dict(torch.load('trained_model.pt', map_location=device), strict=False)
        patched_model = patch_model(model,
                                  memristor_model=reference_memristor,
                                  memristor_model_params=reference_memristor_params,
                                  module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                                  mapping_routine=naive_map,
                                  transistor=True,
                                  programming_routine=None,
                                  p_l=None,
                                  scheme=memtorch.bh.Scheme.DoubleColumn)
        patched_model = model_degradation(patched_model, cycle_count, v_stop)
        patched_model = minimal_tune(patched_model)
        # Inference mode so the (unpatched) BatchNorm layers use running statistics.
        patched_model.eval()
        with torch.no_grad():
            accuracy = test(patched_model, test_loader)
        del patched_model
        del model
        results.append({'times_reprogramed': time_to_reprogram, 'v_stop': v_stop, 'test_set_accuracy': accuracy})
        # Rewrite the CSV after every run so partial results survive an interruption.
        df = pd.DataFrame(results, columns=columns)
        df.to_csv('endurance_sudden.csv', index=False)

## Device retention (gradual) simulations

In [None]:
import enum
from enum import Enum, auto
import math
import pandas as pd
from memtorch.mn.Module import supported_module_parameters
import math
from memtorch.mn.Module import patch_model
from memtorch.map.Module import naive_tune
from memtorch.map.Parameter import naive_map
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.nonideality.NonIdeality import apply_nonidealities
from memtorch.bh.crossbar.Crossbar import init_crossbar
import copy


class OperationMode(Enum):
    """Degradation regime selected when calling ``model_degradation``."""

    sudden = 1
    gradual = 2

    
def minimal_tune(model):
    """Attach ``naive_tune`` output corrections to every tunable memristive layer.

    Modules exposing a ``tune`` attribute first get an identity
    ``transform_output``. For ``memtorch.mn.Conv2d`` it is replaced by
    ``naive_tune`` with probe shape ``(4, in_channels, 8, 8)``. For
    ``memtorch.mn.Linear`` the probe shape is ``(64, in_features)``.
    Exceptions from ``naive_tune`` propagate to the caller.

    Returns ``model``, modified in place.
    """
    tunable = [module for _, module in model.named_modules() if hasattr(module, 'tune')]
    for module in tunable:
        module.transform_output = lambda input: input
        if isinstance(module, memtorch.mn.Conv2d):
            module.transform_output = naive_tune(module, (4, module.in_channels, 8, 8))
        if isinstance(module, memtorch.mn.Linear):
            module.transform_output = naive_tune(module, (64, module.in_features))

    return model
    
def gradual(input, x, p_0, p_1, p_2, p_3, cell_size, tempurature):
    """Apply the gradual retention model to the resistances in ``input``.

    If ``tempurature`` exceeds 298, each resistance ``R`` is first scaled by
    ``10 ** ((p_0 / R) * (tempurature - 298))``. The driver loop passes the
    temperature offset by +273. The result is then multiplied by
    ``(x / threshold) ** (p_3 * cell_size)`` with
    ``threshold = p_1 * exp(p_2 * cell_size)``.

    This is algebraically identical to a formulation built from
    ``10 ** (... + log10(R) ...)`` terms. It avoids the redundant log10/pow
    round trips and their float32 rounding. Resistances are assumed positive.
    """
    if tempurature > 298:
        input = input * torch.pow(10.0, (p_0 / input) * (tempurature - 298))

    threshold = p_1 * math.exp(p_2 * cell_size)
    # Reduce to a plain Python float so the tensor keeps its own dtype.
    factor = (float(x) / float(threshold)) ** (p_3 * cell_size)
    return input * factor
    
def model_gradual(layer, x, tempurature):
    """Apply gradual retention degradation to every crossbar of a patched layer.

    Conductances are converted to resistances and split at
    ``convergence_point`` into a low-resistance (LRS) group and a
    high-resistance (HRS) group. Once ``x`` exceeds the activation threshold,
    each group is degraded by ``gradual`` with its own ``p_0``/``p_3``.
    Cells exactly at the convergence point are left unchanged.

    Two fixes relative to the earlier version:
    - The ``x`` argument is what is passed to ``gradual``; previously the
      global ``time_`` was read instead.
    - Both masks are computed before any update, so an LRS cell pushed past
      the convergence point is not degraded again as HRS.

    Each crossbar's ``conductance_matrix`` is reassigned; returns ``layer``.
    """
    cell_size = 10
    convergence_point = 180000
    p_0_lrs = 0.000801158151673717 * 2400
    p_0_hrs = 0.00420717061486765 * 55000
    p_1 = 0.5
    p_2 = 0.7600902459542083
    p_3_lrs = 0.006489105105825544
    p_3_hrs = -0.007240917429683966
    # LRS and HRS share p_1/p_2, hence a single activation threshold.
    threshold = p_1 * math.exp(p_2 * cell_size)
    for crossbar in layer.crossbars:
        resistance = 1 / crossbar.conductance_matrix
        # Classify cells once, on the pre-degradation resistances.
        lrs = resistance < convergence_point
        hrs = resistance > convergence_point
        if x > threshold:
            if lrs.any():
                resistance[lrs] = gradual(resistance[lrs], x, p_0_lrs, p_1, p_2, p_3_lrs, cell_size, tempurature)
            if hrs.any():
                resistance[hrs] = gradual(resistance[hrs], x, p_0_hrs, p_1, p_2, p_3_hrs, cell_size, tempurature)

        crossbar.conductance_matrix = 1 / resistance

    return layer
    
def model_degradation(model, time_, operation_mode, tempurature):
    """Apply retention degradation to every supported memtorch layer of ``model``.

    Layers are matched by exact type against ``supported_module_parameters``.
    Each match is degraded with ``model_gradual`` and re-assigned by name. If
    ``model`` has a ``module`` attribute (presumably an ``nn.DataParallel``-style
    wrapper), the assignment goes to ``model.module`` using the second
    dotted-name component. Returns ``model``.

    Raises:
        NotImplementedError: for ``OperationMode.sudden``. The sudden model in
            scope (``model_sudden``) takes ``v_stop``, not a temperature, and
            passing a temperature there fails inside ``scale_input``.
        ValueError: for any other unrecognised ``operation_mode``.
    """
    if operation_mode == OperationMode.sudden:
        raise NotImplementedError('Sudden retention degradation is not implemented: '
                                  'model_sudden models endurance and expects v_stop, not a temperature.')
    if operation_mode != OperationMode.gradual:
        raise ValueError('Unsupported operation_mode: %r' % (operation_mode,))

    target = model.module if hasattr(model, 'module') else model
    for name, m in list(model.named_modules()):
        if type(m) not in supported_module_parameters.values():
            continue
        if len(name.split('.')) > 1:
            name = name.split('.')[1]
        setattr(target, name, model_gradual(m, time_, tempurature))

    return model

# Same device rule as the training/validation cells (CPU-only memtorch builds use the CPU).
device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
train_loader, validation_loader, test_loader = LoadCIFAR10(batch_size=64, validation=False)
reference_memristor = memtorch.bh.memristor.VTEAM
r_on = 9e4
r_off = 330000
reference_memristor_params = {'time_series_resolution': 1e-10,
                              'r_off': r_off,
                              'r_on': r_on}
times = 10 ** np.arange(1, 10, dtype=np.float64)
tempuratures = np.linspace(75, 175, 10, endpoint=True)
columns = ['time', 'tempurature', 'test_set_accuracy']
df = pd.DataFrame(columns=columns)
# DataFrame.append was removed in pandas 2.0: collect rows and rebuild the frame.
results = []
for time_ in times:
    print(time_)
    for tempurature in tempuratures:
        # Offset by 273 (presumably Celsius -> kelvin); gradual() compares against 298.
        tempurature += 273
        model = Net().to(device)
        model.load_state_dict(torch.load('trained_model.pt', map_location=device), strict=False)
        patched_model = patch_model(model,
                                  memristor_model=reference_memristor,
                                  memristor_model_params=reference_memristor_params,
                                  module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d],
                                  mapping_routine=naive_map,
                                  transistor=True,
                                  programming_routine=None,
                                  p_l=None,
                                  scheme=memtorch.bh.Scheme.DoubleColumn)

        # Tune first, then degrade: tuning reflects the freshly programmed devices.
        patched_model = minimal_tune(patched_model)
        patched_model = model_degradation(patched_model, time_, OperationMode.gradual, tempurature)
        # Inference mode so the (unpatched) BatchNorm layers use running statistics.
        patched_model.eval()
        with torch.no_grad():
            accuracy = test(patched_model, test_loader)
        del patched_model
        del model
        results.append({'time': time_, 'tempurature': tempurature, 'test_set_accuracy': accuracy})
        # Rewrite the CSV after every run so partial results survive an interruption.
        df = pd.DataFrame(results, columns=columns)
        df.to_csv('retention_gradual.csv', index=False)