In [2]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
warnings.filterwarnings('ignore')

ModuleNotFoundError: No module named 'seaborn'

### Load the pretrained model

In [None]:
# Checkpoint of the quantized VGG7 run (the path suggests 4-bit weights and activations).
ckpt_path = "./save/vgg7_quant/vgg7_quant_w4_a4_mode_mean_asymm_wd0.0_swipe_train/model_best.pth.tar"
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# No map_location is given, so tensors are restored to the device they were saved from.
ckpt = torch.load(ckpt_path)
state_dict = ckpt["state_dict"]

### Get the weights of the last layer

In [None]:
# NOTE(review): 'features.17.weight' is assumed to be the last layer's weight;
# confirm against the model definition.
weight = state_dict['features.17.weight']
weight_shape = list(weight.shape)
print("Weight size = {}".format(weight_shape))

### Low precision weight

In [None]:
from models import quant
# precision
nbit = 4     # weight precision in bits
cellBit = 1  # bits stored per memory cell

# Quantize without dequantizing to get the integer codes, then shift them to be
# non-negative so they can be split into per-cell bit planes.
# The shift is derived from nbit instead of hard-coded; it equals 7 for nbit=4.
# NOTE(review): assumes stats_quant(dequantize=False) returns symmetric codes in
# [-(2**(nbit-1) - 1), 2**(nbit-1) - 1]; confirm in models/quant.py.
level_offset = 2 ** (nbit - 1) - 1
weight_q, wscale = quant.stats_quant(weight, nbit=nbit, dequantize=False)
weight_q = weight_q.add(level_offset)
print("Unique levels of the {}bit weight: \n{}".format(nbit, weight_q.unique().cpu().numpy()))

weight_b = quant.decimal2binary(weight_q, nbit, cellBit)
print("\nBinary weight size = {}".format(list(weight_b.size())))

In [None]:
def binary2dec(wbit, weight_b, cellBit):
    """Recombine per-cell slices into weight levels.

    Computes ``sum_k (2**cellBit)**k * weight_b[k]`` for
    ``k = 0 .. wbit // cellBit - 1``. Slice ``k`` is treated as digit ``k``
    (least significant first) in base ``2**cellBit``.

    Args:
        wbit: Total weight precision in bits.
        weight_b: Indexable whose first index selects the cell slice (e.g. a
            tensor stacked along dim 0). Slices may hold ideal bit values or
            non-ideal (rescaled conductance) values.
        cellBit: Bits stored per cell.

    Returns:
        The weighted sum. It is a tensor when the slices are tensors, and
        ``0`` when ``wbit`` is 0.

    Raises:
        ValueError: If ``cellBit`` is not positive or does not evenly divide
            ``wbit``. Otherwise the trailing bits would be silently dropped
            and the result would be wrong.
    """
    if cellBit < 1 or wbit % cellBit != 0:
        raise ValueError(
            "wbit ({}) must be a multiple of a positive cellBit ({})".format(wbit, cellBit))
    cell_range = 2 ** cellBit
    weight_int = 0
    for k in range(wbit // cellBit):
        weight_int += cell_range ** k * weight_b[k]
    return weight_int

### Conductance

In [None]:
# Conductance bounds for the two device states (units are not stated here).
hrs = 1e-6     # high-resistance-state value
lrs = 1.66e-4  # low-resistance-state value
# Span between the two states. Later cells divide conductances by it to rescale them.
nonideal_unit = lrs - hrs

### Scenario 0: Typical values only

In [None]:
# Ideal reference: recombine the exact bit planes.
wq_ideal = binary2dec(nbit, weight_b, cellBit=cellBit)

# Typical value: map bits to conductances, rescale by the HRS/LRS span, recombine.
wb = weight_b.clone()  # bit2cond works on a copy, in case it modifies its input
w_ref = quant.bit2cond(wb, hrs, lrs)
w_ref_q = w_ref.div(nonideal_unit)
wq_typicall = binary2dec(nbit, w_ref_q, cellBit=cellBit)

### Scenario 1: SWIPE for all the levels

In [None]:
# Per the scenario headings, swipe_ll lists the levels programmed with the
# non-SWIPE scheme. [-1] is presumably a sentinel that matches no level
# (levels are shifted to be non-negative), so every level uses SWIPE.
swipe_ll = [-1]
# Program with noise, then rescale conductances by the HRS/LRS span.
w_swipe = quant.program_noise_cond(weight_q, weight_b, hrs, lrs, swipe_ll).div(nonideal_unit)

# recombine the programmed cell values into weight levels
wq_swipe = binary2dec(nbit, w_swipe, cellBit=cellBit)

In [None]:
# Ideal quantization levels, drawn as square markers on the x-axis for reference.
ql = wq_ideal.unique().cpu().numpy()
print(ql)
plt.figure(figsize=(10,6))
plt.scatter(ql, np.zeros(ql.shape), marker='s', s=100)
# sns.distplot is deprecated since seaborn 0.11. histplot with a KDE on a
# density scale is its documented replacement.
sns.histplot(wq_swipe.view(-1).cpu().numpy(), kde=True, stat="density")
plt.xticks([ii for ii in range(15)])
plt.xlabel("Programmed weight level")
plt.ylabel("Density")
plt.title("4-bit Weight Programmed with SWIPE scheme", fontsize=16, fontweight='bold')
plt.grid(True)
# savefig does not create missing directories.
os.makedirs("./save/figs", exist_ok=True)
plt.savefig("./save/figs/swipe_all_4bit.png", bbox_inches = 'tight', pad_inches = 0.1)

### Scenario 2: Non-SWIPE for level 7

In [None]:
# Level 7 (listed in swipe_ll) uses the non-SWIPE scheme; all other levels use SWIPE.
swipe_ll = [7]
# Program with noise, then rescale conductances by the HRS/LRS span.
w_swipe = quant.program_noise_cond(weight_q, weight_b, hrs, lrs, swipe_ll).div(nonideal_unit)

# recombine the programmed cell values into weight levels
wq_swipe = binary2dec(nbit, w_swipe, cellBit=cellBit)

In [None]:
plt.figure(figsize=(10,6))
# Ideal quantization levels as square markers on the x-axis.
plt.scatter(ql, np.zeros(ql.shape), marker='s', s=100)
# Replacement for the deprecated sns.distplot (histogram + KDE on a density scale).
sns.histplot(wq_swipe.view(-1).cpu().numpy(), kde=True, stat="density")
plt.xticks([ii for ii in range(15)])
plt.xlabel("Programmed weight level")
plt.ylabel("Density")
plt.title("4-bit Weight Programmed with SWIPE scheme except level 7", fontsize=16, fontweight='bold')
plt.grid(True)
# savefig does not create missing directories.
os.makedirs("./save/figs", exist_ok=True)
plt.savefig("./save/figs/nonswipe7_4bit.png", bbox_inches = 'tight', pad_inches = 0.1)

### Scenario 3: Non-SWIPE for levels 7, 8, 9

In [None]:
# Levels 7, 8 and 9 use the non-SWIPE scheme; all other levels use SWIPE.
swipe_ll = [7, 8, 9]
# Program with noise, then rescale conductances by the HRS/LRS span.
w_swipe = quant.program_noise_cond(weight_q, weight_b, hrs, lrs, swipe_ll).div(nonideal_unit)

# recombine the programmed cell values into weight levels
wq_swipe = binary2dec(nbit, w_swipe, cellBit=cellBit)

In [None]:
plt.figure(figsize=(10,6))
# Ideal quantization levels as square markers on the x-axis.
plt.scatter(ql, np.zeros(ql.shape), marker='s', s=100)
# Replacement for the deprecated sns.distplot (histogram + KDE on a density scale).
sns.histplot(wq_swipe.view(-1).cpu().numpy(), kde=True, stat="density")
plt.xticks([ii for ii in range(15)])
plt.xlabel("Programmed weight level")
plt.ylabel("Density")
plt.title("4-bit Weight Programmed with SWIPE scheme except level 7 8 9", fontsize=16, fontweight='bold')
plt.grid(True)
# savefig does not create missing directories.
os.makedirs("./save/figs", exist_ok=True)
plt.savefig("./save/figs/nonswipe789_4bit.png", bbox_inches = 'tight', pad_inches = 0.1)

### Scenario 4: Non-SWIPE for levels 6, 7, 8, 9

In [None]:
# Levels 6 through 9 use the non-SWIPE scheme; all other levels use SWIPE.
swipe_ll = [6, 7, 8, 9]
# Program with noise, then rescale conductances by the HRS/LRS span.
w_swipe = quant.program_noise_cond(weight_q, weight_b, hrs, lrs, swipe_ll).div(nonideal_unit)

# recombine the programmed cell values into weight levels
wq_swipe = binary2dec(nbit, w_swipe, cellBit=cellBit)

In [None]:
plt.figure(figsize=(10,6))
# Ideal quantization levels as square markers on the x-axis.
plt.scatter(ql, np.zeros(ql.shape), marker='s', s=100)
# Replacement for the deprecated sns.distplot (histogram + KDE on a density scale).
sns.histplot(wq_swipe.view(-1).cpu().numpy(), kde=True, stat="density")
plt.xticks([ii for ii in range(15)])
plt.xlabel("Programmed weight level")
plt.ylabel("Density")
plt.title("4-bit Weight Programmed with SWIPE scheme except level 6 7 8 9", fontsize=16, fontweight='bold')
plt.grid(True)
# savefig does not create missing directories.
os.makedirs("./save/figs", exist_ok=True)
plt.savefig("./save/figs/nonswipe6789_4bit.png", bbox_inches = 'tight', pad_inches = 0.1)

### Scenario 5: Non-SWIPE for all levels

In [None]:
# Every level 0..14 is listed, so the whole layer uses the non-SWIPE scheme.
swipe_ll = list(range(15))
# Program with noise, then rescale conductances by the HRS/LRS span.
w_swipe = quant.program_noise_cond(weight_q, weight_b, hrs, lrs, swipe_ll).div(nonideal_unit)

# recombine the programmed cell values into weight levels
wq_swipe = binary2dec(nbit, w_swipe, cellBit=cellBit)

In [None]:
plt.figure(figsize=(10,6))
# Ideal quantization levels as square markers on the x-axis.
plt.scatter(ql, np.zeros(ql.shape), marker='s', s=100)
# Replacement for the deprecated sns.distplot (histogram + KDE on a density scale).
sns.histplot(wq_swipe.view(-1).cpu().numpy(), kde=True, stat="density")
plt.xticks([ii for ii in range(15)])
plt.xlabel("Programmed weight level")
plt.ylabel("Density")
plt.title("4-bit Weight Programmed with Non-SWIPE scheme", fontsize=16, fontweight='bold')
plt.grid(True)
# savefig does not create missing directories.
os.makedirs("./save/figs", exist_ok=True)
plt.savefig("./save/figs/nonswipe_4bit.png", bbox_inches = 'tight', pad_inches = 0.1)

### Layer level statistics

In [None]:
# Per-level share (%) of this layer's quantized weights, plus the combined
# share of the levels listed in `swipe`.
total = weight_q.numel()
# Same level set as Scenario 4 and the model-level statistics cell.
# NOTE(review): in the scenario cells, the levels passed in swipe_ll are the ones
# programmed with the NON-SWIPE scheme, but the summary below reports this set
# as SWIPE. Confirm which labelling is intended.
swipe = [6, 7, 8, 9]
swipe_perc = 0
all_perc = 0
for ii in weight_q.unique():
    # Named level_perc (not `perc`) so re-running this cell cannot clobber the
    # per-level `perc` array used by the model-level cells.
    level_perc = (weight_q == ii).sum().item() / total * 100
    if ii in swipe:
        swipe_perc += level_perc
    print("Level: {}; Percentage: {:.3f}%".format(int(ii), level_perc))
    all_perc += level_perc
print("{:.2f}% of weights are programmed with SWIPE; {:.2f}% of weights are programmed by Non-SWIPE scheme".format(swipe_perc, all_perc-swipe_perc))

### Model level statistics

In [None]:
# Count how many conv weights land on each quantized level, summed over layers.
# Only 4-D weights with more than 3 input channels are included (presumably to
# skip the RGB-input first conv).
n_levels = 2 ** nbit - 1            # 15 levels (0..14) for nbit=4
level_offset = 2 ** (nbit - 1) - 1  # shift applied to the signed codes; 7 for nbit=4
total_w = 0
level_element = np.zeros(n_levels)
for k, v in state_dict.items():
    if len(v.size()) == 4 and v.size(1) > 3:
        # Discard the scale so the last-layer `wscale` from earlier is not overwritten.
        wq, _ = quant.stats_quant(v, nbit=nbit, dequantize=False)
        wq = wq.add(level_offset)
        total_w += wq.numel()

        # Count every level explicitly, with absent levels as 0, so the vector
        # always lines up with level_element. Building it from wq.unique() would
        # drop missing levels and shift counts onto the wrong level, or fail to
        # broadcast.
        layer_element = np.array([(wq == lvl).sum().item() for lvl in range(n_levels)])
        assert int(layer_element.sum()) == wq.numel(), \
            "{}: quantized levels fall outside 0..{}".format(k, n_levels - 1)
        print(k, layer_element.tolist())
        level_element += layer_element
perc = level_element / total_w * 100

In [None]:
# Combined share (%) of the model's conv weights on the listed levels;
# perc[i] is the percentage of weights on level i.
swipe = [6, 7, 8, 9]
swipe_perc = sum(p for ii, p in enumerate(perc) if ii in swipe)
print("Percentage of {} = {:.2f}".format(swipe, swipe_perc))