In [1]:
import RingDataset
from Models import CNNModel, RNNGenerator, Distiller, MLP, RNNModel, QGRU
import os
import numpy as np
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import re
import time

In [25]:
# Test split and a classifier instance (both are re-created by the evaluation cell).
testset = RingDataset.RingDataset('core4ToSlice3_test.pkl', threshold=42)

testloader = DataLoader(testset, batch_size=256, num_workers=4)
classifier_test = CNNModel(42, dim=256).cuda()

# Rebuild the QGRU student and restore its trained weights.
studentdim = 32
gen = QGRU(42, scale=0.25, dim=studentdim, drop=0.0)
# Build the checkpoint path once so the existence check and the load cannot drift apart.
ckpt_path = './models/best_{}_{}.pth'.format('qgru', studentdim)
assert os.path.isfile(ckpt_path), ckpt_path
# `gen` is never moved to the GPU in this notebook, so deserialize the checkpoint
# onto the CPU regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
gen.load_state_dict(torch.load(ckpt_path, map_location='cpu'))



<All keys matched successfully>

In [3]:
def shifter(arr, window=32):
    """Stack ``window`` progressively shifted copies of every row of ``arr``.

    For a 2-D tensor of shape ``(N, L)`` the result has shape
    ``(N, window, L - window + 1)`` and satisfies
    ``result[n, i, j] == arr[n, i + j]``.

    Args:
        arr: 2-D tensor of shape ``(N, L)``.
        window: number of shifted copies to stack; must satisfy
            ``1 <= window <= L``.

    Returns:
        A new contiguous tensor (it does not share memory with ``arr``).

    Raises:
        ValueError: if ``arr`` is not 2-D or ``window`` is out of range.
    """
    if arr.dim() != 2:
        raise ValueError("shifter expects a 2-D tensor, got {} dims".format(arr.dim()))
    length = arr.size(1)
    if not 1 <= window <= length:
        raise ValueError(
            "window must be in [1, {}], got {}".format(length, window)
        )
    # unfold(1, size, 1) yields out[n, i, j] = arr[n, i + j] directly, i.e. the
    # same values the expand/reshape trick produced, without first
    # materialising an (N, L, L + 1) intermediate copy.
    return arr.unfold(1, length - window + 1, 1).clone()

In [10]:
class RNNGen2(nn.Module):
    """Re-wrap a generator's sub-modules between quant/dequant stubs.

    The wrapper reuses ``gen.encoder``, ``gen.resblock`` and ``gen.decoder``
    (the modules are shared, not copied) and computes
    ``relu(decoder(encoder(x) + resblock(encoder(x), hidden)))``.

    Args:
        gen: module exposing ``encoder``, ``resblock`` and ``decoder``.
        num_layers: first dimension of the hidden state built by
            ``init_hidden`` (default 2, the previously hard-coded value).
        hidden_dim: last dimension of that hidden state (default 16, the
            previously hard-coded value).
    """

    def __init__(self, gen, num_layers=2, hidden_dim=16):
        super().__init__()
        self.encoder = gen.encoder
        # One stub per quantized input: a QuantStub acts as an observer during
        # calibration, so sharing a single stub between `x` and `hidden` would
        # pool their statistics into one scale/zero-point.
        self.quant = torch.quantization.QuantStub()
        self.quant_hidden = torch.quantization.QuantStub()
        self.resblock = gen.resblock

        self.decoder = gen.decoder
        self.dequant = torch.quantization.DeQuantStub()

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

    def init_hidden(self, bsz):
        """Return a zero hidden state of shape ``(num_layers, bsz, hidden_dim)``.

        The tensor is created from the first model parameter, so it inherits
        that parameter's dtype and device.
        """
        weight = next(self.parameters())
        return weight.new_zeros(self.num_layers, bsz, self.hidden_dim)

    def forward(self, x, hidden):
        """Run the generator on ``x`` with the given initial ``hidden`` state.

        ``x`` is permuted from (N, C, S) to (N, S, C) before encoding. The
        decoder output is flattened to ``(N, -1)`` and passed through ReLU.
        The updated hidden state returned by ``resblock`` is discarded.
        """
        x = self.quant(x)
        hidden = self.quant_hidden(hidden)
        encoded = self.encoder(x.permute(0, 2, 1))  # N,C,S -> N,S,C
        res, hidden = self.resblock(encoded, hidden)
        out = encoded + res  # N,S,C (residual connection around resblock)
        out = self.decoder(out).view(out.size(0), -1)

        return self.dequant(torch.relu(out))
# NOTE(review): RNNGen2 reads gen.encoder, gen.resblock and gen.decoder, but the QGRU
# printed by the quantization cell lists only encoder/gru1/gru2/decoder -- confirm
# that `gen` still exposes `resblock` when this cell is run on a fresh kernel.
gen2 = RNNGen2(gen)

In [26]:
# Dynamically quantize the generator's GRUCell and Linear layers to int8 and
# compare the result against the fp32 model on a single test batch.
model = gen
model.eval()

model_int8 = torch.quantization.quantize_dynamic(
    model, {nn.GRUCell, nn.Linear}, dtype=torch.qint8
)
print(model_int8)

input_fp32, _ = next(iter(testloader))
shifted = shifter(input_fp32)

# A static-quantization variant (qconfig / prepare / calibrate / convert) was
# sketched here as an inert string literal; it referenced an undefined
# `hidden_fp32` and was never executed, so it has been dropped.

# Pure inference: no_grad keeps autograd from recording a graph for the
# comparison (the difference tensor previously carried a grad_fn).
with torch.no_grad():
    res = model_int8(shifted)  # quantized layers run in int8
    res2 = model(shifted)      # fp32 reference
print(res - res2)

QGRU(
  (encoder): Sequential(
    (0): DynamicQuantizedLinear(in_features=32, out_features=32, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (1): Dropout(p=0.0, inplace=False)
  )
  (gru1): DynamicQuantizedGRUCell(32, 32)
  (gru2): DynamicQuantizedGRUCell(32, 32)
  (decoder): Sequential(
    (0): DynamicQuantizedLinear(in_features=32, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  )
)
tensor([[ 0.1273,  0.0283,  0.0000,  ..., -0.0147,  0.0943, -0.1522],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.3346,  0.0323, -0.0957],
        [-0.5370,  0.0000,  0.0000,  ..., -0.1543,  0.5100, -0.4143],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.9470, -0.3757,  0.4048],
        [ 0.0000, -0.1592,  0.0000,  ...,  0.4305, -0.1351,  0.4698],
        [-0.0582,  0.0000,  0.0000,  ...,  0.0000, -0.2967, -0.8013]],
       grad_fn=<SubBackward0>)


In [27]:
def quantizer(arr, std=8):
    """Snap each element of ``arr`` to the nearest multiple of ``1/std``."""
    scaled = arr * std
    snapped = torch.round(scaled)
    return snapped / std

# Train a fresh CNN classifier on inputs perturbed by the int8 generator
# (validation split), and score it on the test split after every epoch.
testset = RingDataset.RingDataset('core4ToSlice3_test.pkl', threshold=42)
valset = RingDataset.RingDataset('core4ToSlice3_valid.pkl', threshold=42)

testloader = DataLoader(testset, batch_size=128, num_workers=4)
valloader = DataLoader(valset, batch_size=128, num_workers=4)

classifier_test = CNNModel(42, dim=256).cuda()
criterion = nn.CrossEntropyLoss()

# Means of test accuracy / perturbation over the final 10 epochs.
lastacc = 0.0
lastnorm = 0.0
optim_c2 = torch.optim.Adam(classifier_test.parameters(), lr=1e-4)

halfstudent = model_int8
cooldown = 30
for e in range(cooldown):
    # ---- train the classifier for one pass over the validation split ----
    classifier_test.train()
    for x, y in valloader:
        xdata, ydata = x.cuda(), y.cuda()
        # The generator is fed CPU tensors and is not optimised here: take the
        # batch as delivered by the loader (no GPU round trip) and skip
        # autograd bookkeeping for its forward pass.
        with torch.no_grad():
            shifted = shifter(x.cpu())
            perturb = halfstudent(shifted).view(shifted.size(0), -1).cuda()
        optim_c2.zero_grad()
        # The first 31 samples are dropped so the input lines up with the
        # perturbation (shifter's default window of 32 shortens the sequence by 31).
        # NOTE(review): training adds the raw perturbation while evaluation
        # below rounds it with quantizer() first -- confirm this is intended.
        output = classifier_test(xdata[:, 31:] + perturb.detach().float())
        loss_c = criterion(output, ydata)
        loss_c.backward()
        optim_c2.step()

    # ---- evaluate on the test split ----
    mloss = 0.0
    totcorrect = 0
    totcount = 0
    mnorm = 0.0
    zerocorrect = 0
    zerocount = 0
    onecorrect = 0
    onecount = 0

    with torch.no_grad():
        classifier_test.eval()
        for x, y in testloader:
            xdata, ydata = x.cuda(), y.cuda()
            shifted = shifter(x.cpu())
            perturb = halfstudent(shifted).view(shifted.size(0), -1).cuda()
            perturb = quantizer(perturb)
            norm = torch.mean(perturb)
            output = classifier_test(xdata[:, 31:] + perturb.float())
            loss_c = criterion(output, ydata)
            pred = output.argmax(dim=-1)
            # Per-batch means are averaged over the number of batches.
            mnorm += norm.item() / len(testloader)
            mloss += loss_c.item() / len(testloader)
            totcorrect += (pred == ydata).sum().item()
            totcount += y.size(0)
            zerocorrect += ((pred == 0) & (ydata == 0)).sum().item()
            zerocount += (ydata == 0).sum().item()
            onecorrect += ((pred == 1) & (ydata == 1)).sum().item()
            onecount += (ydata == 1).sum().item()
        # Report NaN rather than crash when a class (or the whole split) is empty.
        macc = totcorrect / totcount if totcount else float('nan')
        zacc = zerocorrect / zerocount if zerocount else float('nan')
        oacc = onecorrect / onecount if onecount else float('nan')
        print("epoch {} \t zacc {:.6f}\t oneacc {:.6f}\t loss {:.6f}\t Avg perturb {:.6f}\n".format(e+1, zacc, oacc, mloss, mnorm))
        # Accumulate the mean over the last 10 epochs (e = cooldown-10 .. cooldown-1).
        if cooldown - e <= 10:
            lastacc += macc / 10
            lastnorm += mnorm / 10
print("Last 10 acc: {:.6f}\t perturb: {:.6f}".format(lastacc,lastnorm))



epoch 1 	 zacc 0.000000	 oneacc 1.000000	 loss 0.693106	 Avg perturb 1.681546

epoch 2 	 zacc 0.000000	 oneacc 1.000000	 loss 0.693088	 Avg perturb 1.682122

epoch 3 	 zacc 0.000000	 oneacc 1.000000	 loss 0.693102	 Avg perturb 1.682553

epoch 4 	 zacc 0.000000	 oneacc 1.000000	 loss 0.693103	 Avg perturb 1.682439

epoch 5 	 zacc 0.000000	 oneacc 1.000000	 loss 0.693083	 Avg perturb 1.683011

epoch 6 	 zacc 0.000000	 oneacc 1.000000	 loss 0.693082	 Avg perturb 1.681719

epoch 7 	 zacc 0.000000	 oneacc 1.000000	 loss 0.693087	 Avg perturb 1.681907

epoch 8 	 zacc 0.000827	 oneacc 0.998759	 loss 0.693056	 Avg perturb 1.682583

epoch 9 	 zacc 0.004136	 oneacc 0.994210	 loss 0.693048	 Avg perturb 1.682250

epoch 10 	 zacc 0.031431	 oneacc 0.978495	 loss 0.693018	 Avg perturb 1.681912

epoch 11 	 zacc 0.114557	 oneacc 0.914392	 loss 0.692978	 Avg perturb 1.682365

epoch 12 	 zacc 0.234905	 oneacc 0.792390	 loss 0.692896	 Avg perturb 1.682087

epoch 13 	 zacc 0.403226	 oneacc 0.622002	 loss 0