In [1]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import random
import numpy as np
import copy
import json
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [2]:
import wandb
wandb.init(mode="disabled")

In [3]:
from das.Classification_Model import (MLPForClassification,
                                  train_model,
                                  eval_model,
                                  test_model,
                                  make_model)
from das.Helper_Functions import set_seed
from das.Dataset_Generation import (make_model_dataset,
                                make_model_dataset_AndOrAnd,
                                make_intervention_dataset_variable_intervention_all,
                                make_intervention_dataset_variable_intervention_first,
                                make_intervention_dataset_first_input_intervention,
                                make_intervention_dataset_AndOrAnd,
                                make_intervention_dataset_AndOr)
from das.RevNet import RevNet
from das.Rotation_Model import Rotation
from das.DAS import phi_class
from das.DAS_MLP import Distributed_Alignment_Search_MLP

In [4]:
DEVICE  = "cpu" #"cuda"/"cpu"
num_classes=2

In [5]:
def register_intervention_hook(Save_array,Pos,layer):
    def hook_fn(module, input, output):
        Save_array[Pos].append(output.detach().cpu())
    layer.register_forward_hook(hook_fn)

In [6]:
results=[]
for acseed in [4287]:#, 3837, 9097, 2635, 5137, 6442, 5234, 4641, 8039, 2266]:
    results.append({})
    set_seed(acseed)
    X_train,y_train = make_model_dataset(1048576,4,DEVICE)#1048576
    X_eval,y_eval   = make_model_dataset(10000,4,DEVICE)#10000
    X_test,y_test   = make_model_dataset(10000,4,DEVICE)
    X_inj,y_inj   = make_model_dataset(5000,4,DEVICE)#5000
   
    model,accuracy=make_model(X_train,y_train,X_eval,y_eval,X_test,y_test,input_size=16,epochs=20,device=DEVICE)
    Layers=[]
    Layers.append(("Layer1",model.mlp.h[0]))
    Layers.append(("Layer2",model.mlp.h[1]))
    Layers.append(("Layer3",model.mlp.h[2]))


    Layer_Save=[[],[],[],[]]
    Results=[]
    Results.append([0,0,0,0,0,0,0])
    Hooks=[]
    predicted_classes_set = set()
    for acpos,aclayer in enumerate(Layers):
        Layer_Save.append([])
        Results.append([0,0,0,0,0,0,0])
        Hooks.append(register_intervention_hook(Layer_Save,acpos+1,aclayer[1]))
    
    
    test_dataset = TensorDataset(X_inj,y_inj)
    test_loader = DataLoader(test_dataset, batch_size=5, shuffle=False)#6400
    
    
    model.eval()
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            Layer_Save[0].append(X_batch)
            output=model(X_batch)
            predicted = torch.argmax(output, dim=1)
            Layer_Save[-3].append(predicted)
            Layer_Save[-2].append(torch.all(X_batch[:, :4] == X_batch[:, 4:8], dim=1))
            Layer_Save[-1].append(torch.all(X_batch[:, 8:12] == X_batch[:, 12:16], dim=1))
        
            # Update set with predictions from this batch
            predicted_classes_set.update(predicted.tolist())
    for i in range(len(Layer_Save)):
        Layer_Save[i]=torch.cat(Layer_Save[i])
    for i in tqdm(range(len(Layer_Save[0]))):
        for j in range(i+1,len(Layer_Save[0])):
            Results[0][0]+=1
            Results[0][1]+=torch.norm(Layer_Save[0][i]-Layer_Save[0][j], p=2).item()
            if Layer_Save[-3][i]==Layer_Save[-3][j]:
                Results[0][3]+=1
                Results[0][4]+=torch.norm(Layer_Save[0][i]-Layer_Save[0][j], p=2).item()
            if Layer_Save[-2][i]==Layer_Save[-2][j] and Layer_Save[-1][i]==Layer_Save[-1][j]:
                Results[0][5]+=1
                Results[0][6]+=torch.norm(Layer_Save[0][i]-Layer_Save[0][j], p=2).item()
            if torch.equal(Layer_Save[0][i],Layer_Save[0][j]):
                Results[0][2]+=1
            for k in range(1,len(Layer_Save)-3):
                Results[k][0]+=1
                Results[k][1]+=torch.norm(Layer_Save[k][i]-Layer_Save[k][j], p=2).item()
                if torch.equal(Layer_Save[k][i],Layer_Save[k][j]):
                    Results[k][2]+=1
                if Layer_Save[-3][i]==Layer_Save[-3][j]:
                    Results[k][3]+=1
                    Results[k][4]+=torch.norm(Layer_Save[k][i]-Layer_Save[k][j], p=2).item()
                if Layer_Save[-2][i]==Layer_Save[-2][j] and Layer_Save[-1][i]==Layer_Save[-1][j]:
                    Results[k][5]+=1
                    Results[k][6]+=torch.norm(Layer_Save[k][i]-Layer_Save[k][j], p=2).item()
    
    all_classes_set = set(range(num_classes))
    missing_classes_set = all_classes_set - predicted_classes_set
    Results_processed=[]
    for i in Results:
        Results_processed.append([i[1]/i[0],i[2]/i[0],i[4]/i[3],i[6]/i[5]])
    print("Surjectivity: Missing:", missing_classes_set,"Found:",predicted_classes_set)
    print("Injectivity", Results_processed)  
    print("Injectivity", Results)          

  return torch.tensor(model_inputs, dtype=torch.float32).to(device),torch.tensor(labels, dtype=torch.float32).to(device)


Epoch 1, Loss: 0.39049801870714873 steps without improvement: 7 best accuracy: 0.9837
Epoch 2, Loss: 0.026408711896237946 steps without improvement: 72 best accuracy: 0.9964
Epoch 3, Loss: 0.008641325565577063 steps without improvement: 37 best accuracy: 0.9986
Epoch 4, Loss: 0.004503298025667846 steps without improvement: 60 best accuracy: 0.9991
Epoch 5, Loss: 0.0028202552880998155 steps without improvement: 138 best accuracy: 0.9997
Epoch 6, Loss: 0.0019203961838059058 steps without improvement: 301 best accuracy: 0.9998


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [45:42<00:00,  1.82it/s]

Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6054958673748023, 0.0, 1.6054210510295241, 1.6019198240301147], [1.117766764482549, 0.0, 1.0932176194103445, 0.9145127258176602], [1.9995332732080362, 0.0, 1.9030948216695343, 1.280946735286285], [4.2082161881334255, 0.0, 3.2319988631395575, 2.1088286168429353]]
Injectivity [[12497500, 20064684.60251659, 0, 6247669, 10030139.332464576, 3123817, 5004104.378942281], [12497500, 13969290.139120657, 0, 6247669, 6830061.831043807, 3123817, 2856770.399625546], [12497500, 24989167.08191743, 0, 6247669, 11889906.521405278, 3123817, 4001443.1877817973], [12497500, 52592181.81119748, 0, 6247669, 20192459.105272256, 3123817, 6587594.683380447]]





In [8]:
Results

[[12497500,
  20064684.60251659,
  0,
  6247669,
  10030139.332464576,
  3123817,
  5004104.378942281],
 [12497500,
  13969290.139120657,
  0,
  6247669,
  6830061.831043807,
  3123817,
  2856770.399625546],
 [12497500,
  24989167.08191743,
  0,
  6247669,
  11889906.521405278,
  3123817,
  4001443.1877817973],
 [12497500,
  52592181.81119748,
  0,
  6247669,
  20192459.105272256,
  3123817,
  6587594.683380447]]

In [9]:
Results_processed

[[1.6054958673748023, 0.0, 1.6054210510295241, 1.6019198240301147],
 [1.117766764482549, 0.0, 1.0932176194103445, 0.9145127258176602],
 [1.9995332732080362, 0.0, 1.9030948216695343, 1.280946735286285],
 [4.2082161881334255, 0.0, 3.2319988631395575, 2.1088286168429353]]