In [1]:
import torch as ch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import os

# Custom module imports
import utils

In [2]:
def extract_wb(m):
    combined = []
    features  = model.module.model.features
    classifier = model.module.model.classifier

    # Extract feature params
    for layer in features:
        if hasattr(layer, 'weight'):
            w, b = layer.weight.clone().detach(), layer.bias.clone().detach()
            w_, b_ = w.view(w.shape[0], -1), b.view(b.shape[0], -1)
            combined.append(ch.cat((w_, b_), -1))

    # Extract classifier params
    w, b = classifier.weight.clone().detach(), classifier.bias.clone().detach()
    w_, b_ = w.view(w.shape[0], -1), b.view(b.shape[0], -1)
    combined.append(ch.cat((w_, b_), -1))

    return combined

In [3]:
labels  = [
    0, 0,
    0, 0,
    0, 0,
    0, 0,
    0, 0,
    1, 1,
    1, 1,
    1, 1,
    1, 1,
    1, 1,
    1, 1]

paths = [
    "0p_linf", "0p_linf_2",
    "10p_linf", "10p_linf_2",
    "20p_linf", "20p_linf_2",
    "30p_linf", "30p_linf_2",
    "40p_linf", "40p_linf_2",
    "50p_linf", "50p_linf_2",
    "60p_linf", "60p_linf_2",
    "70p_linf", "70p_linf_2",
    "80p_linf", "80p_linf_2",
    "90p_linf", "90p_linf_2",
    "100p_linf", "100p_linf_2"]

In [4]:
class PIN_Model(nn.Module):
    def __init__(self):
        super(PIN_Model, self).__init__()
        self.rho = nn.Sequential(nn.Linear(33, 8), nn.ReLU(),
                                 nn.Linear(8, 1), nn.Sigmoid()).cuda()
    
    def forward(self, x):
        return self.rho(x)

In [5]:
# Use a dummy model to get required dimensionalities
constants = utils.BinaryCIFAR(None)
model = constants.get_model(None , "vgg19", parallel=True)
params = extract_wb(model)

In [6]:
def get_PIN_representations(params, phis):
    assert len(phis) == len(params)

    layer_reps = []
    prev_node_rep = None
    for phi, c in zip(phis, params):
        combined_c = c
        if prev_node_rep is not None:
            prev_nodes = ch.transpose(prev_node_rep.repeat(1, combined_c.shape[0]), 0, 1)
            combined_c = ch.cat((combined_c, prev_nodes), -1)
        node_rep = phi(combined_c.cuda()).clone().detach()
        layer_rep = ch.sum(node_rep, 0)
        layer_reps.append(layer_rep)
        prev_node_rep = node_rep

    model_rep = ch.cat(layer_reps)
    return model_rep

In [7]:
# Get phi-functions ready
phi_models = []
for i, param in enumerate(params):
    rs = param.shape[1]
    if i > 0: rs += params[i-1].shape[0]
    phi_models.append(nn.Sequential(nn.Linear(rs, 1), nn.ReLU()).cuda())

In [8]:
# Generate PI (permutation invariant) model representations
prefix = "/p/adversarialml/as9rw/new_exp_models/small/"
suffix = "checkpoint.pt.best"

reps = []
for path in tqdm(paths):
    model = constants.get_model(os.path.join(prefix, path, suffix) , "vgg19", parallel=True)
    params = extract_wb(model)
    
    model_rep = get_PIN_representations(params, phi_models)
    reps.append(model_rep)

  0%|          | 0/22 [00:00<?, ?it/s]

[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/0p_linf/checkpoint.pt.best'[0m
[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/0p_linf/checkpoint.pt.best' (epoch 136)[0m


  5%|▍         | 1/22 [00:01<00:40,  1.94s/it]

[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/0p_linf_2/checkpoint.pt.best'[0m


  9%|▉         | 2/22 [00:03<00:36,  1.83s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/0p_linf_2/checkpoint.pt.best' (epoch 11)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/10p_linf/checkpoint.pt.best'[0m


 14%|█▎        | 3/22 [00:05<00:33,  1.75s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/10p_linf/checkpoint.pt.best' (epoch 150)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/10p_linf_2/checkpoint.pt.best'[0m


 18%|█▊        | 4/22 [00:06<00:30,  1.71s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/10p_linf_2/checkpoint.pt.best' (epoch 126)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/20p_linf/checkpoint.pt.best'[0m


 23%|██▎       | 5/22 [00:08<00:28,  1.67s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/20p_linf/checkpoint.pt.best' (epoch 150)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/20p_linf_2/checkpoint.pt.best'[0m


 27%|██▋       | 6/22 [00:09<00:26,  1.64s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/20p_linf_2/checkpoint.pt.best' (epoch 16)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/30p_linf/checkpoint.pt.best'[0m


 32%|███▏      | 7/22 [00:11<00:24,  1.61s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/30p_linf/checkpoint.pt.best' (epoch 111)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/30p_linf_2/checkpoint.pt.best'[0m


 36%|███▋      | 8/22 [00:12<00:22,  1.60s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/30p_linf_2/checkpoint.pt.best' (epoch 136)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/40p_linf/checkpoint.pt.best'[0m


 41%|████      | 9/22 [00:14<00:20,  1.59s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/40p_linf/checkpoint.pt.best' (epoch 146)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/40p_linf_2/checkpoint.pt.best'[0m


 45%|████▌     | 10/22 [00:16<00:19,  1.59s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/40p_linf_2/checkpoint.pt.best' (epoch 141)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/50p_linf/checkpoint.pt.best'[0m


 50%|█████     | 11/22 [00:17<00:17,  1.59s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/50p_linf/checkpoint.pt.best' (epoch 150)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/50p_linf_2/checkpoint.pt.best'[0m


 55%|█████▍    | 12/22 [00:19<00:15,  1.59s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/50p_linf_2/checkpoint.pt.best' (epoch 146)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/60p_linf/checkpoint.pt.best'[0m


 59%|█████▉    | 13/22 [00:20<00:14,  1.58s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/60p_linf/checkpoint.pt.best' (epoch 136)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/60p_linf_2/checkpoint.pt.best'[0m


 64%|██████▎   | 14/22 [00:22<00:12,  1.57s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/60p_linf_2/checkpoint.pt.best' (epoch 121)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/70p_linf/checkpoint.pt.best'[0m


 68%|██████▊   | 15/22 [00:23<00:11,  1.57s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/70p_linf/checkpoint.pt.best' (epoch 136)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/70p_linf_2/checkpoint.pt.best'[0m


 73%|███████▎  | 16/22 [00:25<00:09,  1.58s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/70p_linf_2/checkpoint.pt.best' (epoch 141)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/80p_linf/checkpoint.pt.best'[0m


 77%|███████▋  | 17/22 [00:27<00:07,  1.57s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/80p_linf/checkpoint.pt.best' (epoch 146)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/80p_linf_2/checkpoint.pt.best'[0m


 82%|████████▏ | 18/22 [00:28<00:06,  1.57s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/80p_linf_2/checkpoint.pt.best' (epoch 150)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/90p_linf/checkpoint.pt.best'[0m


 86%|████████▋ | 19/22 [00:30<00:04,  1.57s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/90p_linf/checkpoint.pt.best' (epoch 1)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/90p_linf_2/checkpoint.pt.best'[0m


 91%|█████████ | 20/22 [00:31<00:03,  1.58s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/90p_linf_2/checkpoint.pt.best' (epoch 146)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/100p_linf/checkpoint.pt.best'[0m


 95%|█████████▌| 21/22 [00:33<00:01,  1.57s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/100p_linf/checkpoint.pt.best' (epoch 146)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/small/100p_linf_2/checkpoint.pt.best'[0m


100%|██████████| 22/22 [00:34<00:00,  1.59s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/small/100p_linf_2/checkpoint.pt.best' (epoch 146)[0m





In [9]:
reps[0].shape
all_reps = ch.stack(reps)

In [10]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

X, y = all_reps.cpu().numpy(), np.array(labels)
pim_model = DecisionTreeClassifier(random_state=0, max_depth=2)
pim_model.fit(X, y)

print(pim_model.score(X, y))

0.9545454545454546


In [11]:
# Test classifier on trained models (unseen)

prefix = "/p/adversarialml/as9rw/new_exp_models/"
suffix = "checkpoint.pt.best"

paths_test = ["10p_linf", "50p_linf"]

reps_t = []
for path in tqdm(paths_test):
    model = constants.get_model(os.path.join(prefix, path, suffix) , "vgg19", parallel=True)
    params = extract_wb(model)
    
    model_rep = get_PIN_representations(params, phi_models)
    reps_t.append(model_rep)

reps_t = ch.stack(reps_t)

  0%|          | 0/2 [00:00<?, ?it/s]

[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/10p_linf/checkpoint.pt.best'[0m


 50%|█████     | 1/2 [00:02<00:02,  2.08s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/10p_linf/checkpoint.pt.best' (epoch 136)[0m
[35m=> loading checkpoint '/p/adversarialml/as9rw/new_exp_models/50p_linf/checkpoint.pt.best'[0m


100%|██████████| 2/2 [00:03<00:00,  1.93s/it]

[35m=> loaded checkpoint '/p/adversarialml/as9rw/new_exp_models/50p_linf/checkpoint.pt.best' (epoch 146)[0m





In [14]:
print(pim_model.predict(reps_t.cpu().numpy()))

[1 1]


In [46]:
pim_model = PIN_Model()

n_epochs = 100
loss_fn = nn.BCELoss()
optimizer = ch.optim.SGD(pim_model.parameters(), lr=.01)

labels_g = ch.from_numpy(np.array(labels)).float().cuda()
pim_model.train()

for i in range(n_epochs):
    optimizer.zero_grad()
    preds = pim_model(all_reps.cuda())
    
    loss = loss_fn(preds, labels_g)
    print(loss.item(), preds)
    loss.backward()
    optimizer.step()
    

0.7241067290306091 tensor([[0.6572],
        [0.5776],
        [0.5910],
        [0.6697],
        [0.6389],
        [0.6198],
        [0.6121],
        [0.6455],
        [0.7223],
        [0.5823],
        [0.6230],
        [0.5112],
        [0.6195],
        [0.6328],
        [0.6564],
        [0.6409],
        [0.6003],
        [0.6377],
        [0.5731],
        [0.6362],
        [0.5998],
        [0.6410]], device='cuda:0', grad_fn=<SigmoidBackward>)
2.3879756927490234 tensor([[0.0141],
        [0.0059],
        [0.0121],
        [0.0139],
        [0.0140],
        [0.0065],
        [0.0126],
        [0.0142],
        [0.0175],
        [0.0135],
        [0.0135],
        [0.0129],
        [0.0141],
        [0.0144],
        [0.0149],
        [0.0147],
        [0.0138],
        [0.0147],
        [0.0038],
        [0.0146],
        [0.0138],
        [0.0145]], device='cuda:0', grad_fn=<SigmoidBackward>)
1.501547932624817 tensor([[0.9586],
        [0.9748],
        [0.9589],
        

0.6890102624893188 tensor([[0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462],
        [0.5462]], device='cuda:0', grad_fn=<SigmoidBackward>)
0.6890102028846741 tensor([[0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461],
        [0.5461]], device='cuda:0', grad_fn=<SigmoidBackward>)
0.6890101432800293 tensor([[0.5461],
        [0.5461],
        [0.5461],
       