In [1]:
import time, logging, torchvision, torch
import numpy as np

from copy import deepcopy
from multiprocessing import Pipe, Pool
from random import sample as random_sample
from random import seed
from time import localtime, strftime
from typing import Any, Generator, Sequence
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from z3 import *
from utils.dataclasses import ExperimentInfoGlobal, ExperimentInfoLocal, expr_info_global
from utils.networks import Mozafari2018
from utils.visual import log
from utils.types import MNIST_DoG_Data, MNIST_DoG_Target
from utils.SpykeTorch.SpykeTorch import utils as sutils

In [2]:
mozafari = Mozafari2018()
mozafari.load_state_dict(torch.load("saved.net"))

  mozafari.load_state_dict(torch.load("saved.net"))


<All keys matched successfully>

In [3]:
MNIST_train = sutils.CacheDataset(torchvision.datasets.MNIST(root="data", train=True, download=True,
                                                                 transform = Mozafari2018.generate_transform())) # type: ignore
MNIST_loader = DataLoader(MNIST_train, batch_size=4, num_workers=128, shuffle=False, pin_memory=True)

In [5]:
for batch in MNIST_loader:
    for sample, target in batch:
        sample = sample.squeeze(0)
        sample = sample.to(mozafari.device)
        target = target.to(mozafari.device)
        _out = mozafari(sample, 3)

ValueError: too many values to unpack (expected 2)

In [10]:
batch[1]

tensor([5, 0, 4, 1])

In [8]:
sample[0].shape

torch.Size([15, 6, 28, 28])

In [2]:
def run_test(e_info:ExperimentInfoLocal):
    log_name = f"{e_info.log_name}_{e_info.num_steps}_delta{e_info.delta}.log"
    logging.basicConfig(filename=f"{expr_info_global}/" + log_name, level=logging.INFO)
    log(e_info)

    seed(e_info.seed)
    np.random.seed(e_info.seed)
    # torch.manual_seed(e_info.seed)
    # torch.use_deterministic_algorithms(True)
    
    # mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
    # test_loader = DataLoader(mnist_test, batch_size=1, shuffle=True, drop_last=True)
    
    MNIST_train:Iterable[tuple[torch.Tensor, int]] # batch, (data, label), c, h, w
    MNIST_train = sutils.CacheDataset(torchvision.datasets.MNIST(root=e_info.data_root, train=True, download=True,
                                                                 transform = e_info.net_type.generate_transform())) # type: ignore
    
    
    log('Data is loaded')
    S = Solver()
    # spike_indicators = gen_spikes()
    spike_times = gen_spike_times()
    weights = gen_weights(weights_list)
    
    # Load equations.
    eqn_path = f'eqn/eqn_{num_steps}_{"_".join([str(i) for i in n_layer_neurons])}.txt'
    if not load_expr or not os.path.isfile(eqn_path):
        node_eqns = gen_node_eqns(weights, spike_times)
        S.add(node_eqns)
        # if cfg.np_level == 1:
        #     node_eqns.extend(gen_dnp_v2(weights, spike_indicators, potentials))
        # elif cfg.np_level == 2:
        #     node_eqns.extend(gen_gnp(weights, spike_indicators))
        if save_expr:
            try:
                with open(eqn_path, 'w') as f:
                    f.write(S.sexpr())
                    log("Node equations are saved.")
            except:
                pdb.set_trace(header="Failed to save node eqns.")
    else:
        S.from_file(eqn_path)
    log("Solver is loaded.")

    samples_no_list:List[int] = []
    sampled_imgs:List[TImage] = []
    orig_preds:List[int] = []
    for sample_no in random_sample([*range(len(images))], k=cfg.num_samples):
        log(f"sample {sample_no} is drawn.")
        samples_no_list.append(sample_no)
        img:TImage = images[sample_no]
        sampled_imgs.append(img) # type: ignore
        orig_preds.append(forward(weights_list, img))
    log(f"Sampling is completed with {num_procs} samples.")
    # data, target = next(iter(test_loader))
    # inp = spikegen.rate(data, num_steps=num_steps) # type: ignore
    # op = net.forward(inp.view(num_steps, -1))[0]
    # label = int(torch.cat(op).sum(dim=0).argmax())
    # log(f'single input ran in {time.time()-tx} sec')

    # For each delta
    for delta in cfg.deltas:
        global check_sample
        def check_sample(sample:Tuple[int, TImage, int]):
            sample_no, img, orig_pred = sample
            orig_neuron = (orig_pred, 0)
            tx = time.time()
            
            # # Input property terms
            prop = []
            # max_delta_per_neuron = min(1, delta)
            # max_delta_per_neuron = delta
            input_layer = 0
            deltas_list = []
            delta_pos = IntVal(0)
            delta_neg = IntVal(0)
            def relu(x): return If(x>0, x, 0)
            for in_neuron in get_layer_neurons_iter(input_layer):
                ## Try to avoid using abs, it makes z3 extremely slow.
                delta_pos += relu(spike_times[in_neuron, input_layer] - int(img[in_neuron]))
                delta_neg += relu(int(img[in_neuron]) - spike_times[in_neuron, input_layer])
                # neuron_spktime_delta = (
                #     typecast(ArithRef,
                #              Abs(spike_times[in_neuron, input_layer] - int(img[in_neuron]))))
                # prop.append(neuron_spktime_delta <= max_delta_per_neuron)
                # deltas_list.append(neuron_spktime_delta)
                # prop.append(spike_times[in_neuron,input_layer] == int(img[in_neuron]))
                # print(img[in_neuron], end = '\t')
            prop.append((delta_pos + delta_neg) <= delta)
            # prop.append(Sum(deltas_list) <= delta)
            log(f"Inputs Property Done in {time.time() - tx} sec")

            # Output property
            tx = time.time()
            op = []
            last_layer = len(n_layer_neurons)-1
            for out_neuron in get_layer_neurons_iter(last_layer):
                if out_neuron != orig_neuron:
                    # It is equal to Not(spike_times[out_neuron, last_layer] >= spike_times[orig_neuron, last_layer]),
                    # we are checking p and Not(q) and q = And(q1, q2, ..., qn)
                    # so Not(q) is Or(Not(q1), Not(q2), ..., Not(qn))
                    op.append(
                        spike_times[out_neuron, last_layer] <= spike_times[orig_neuron, last_layer]
                    )
            op = Or(op)
            log(f'Output Property Done in {time.time() - tx} sec')

            tx = time.time()
            S_instance = deepcopy(S)
            log(f'Network Encoding read in {time.time() - tx} sec')
            S_instance.add(op)
            S_instance.add(prop)
            log(f'Total model ready in {time.time() - tx}')

            log('Query processing starts')
            # set_param(verbose=2)
            # set_param("parallel.enable", True)
            tx = time.time()
            result = S_instance.check()
            log(f"Checking done in time {time.time() - tx}")
            if result == sat:
                log(f"Not robust for sample {sample_no} and delta={delta}")
            elif result == unsat:
                log(f"Robust for sample {sample_no} and delta={delta}")
            else:
                log(f"Unknown at sample {sample_no} for reason {S_instance.reason_unknown()}")
            log("")
            return result
        
        samples = zip(samples_no_list, sampled_imgs, orig_preds)
        if mp:
            with Pool(num_procs) as pool:
                pool.map(check_sample, samples)
                pool.close()
                pool.join()
        else:
            for sample in samples:
                check_sample(sample)

    log("")
    


In [None]:
expr_info = ExperimentInfoLocal(log_name="mnist", num_steps=15, delta=1, seed=42, data_root="data", net_type=Mozafari2018)
k = run_test(expr_info)

In [11]:
k[0][0].shape

torch.Size([15, 6, 28, 28])