In [1]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import random
import numpy as np
import copy
import json
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [2]:
import wandb
wandb.init(mode="disabled")

In [3]:
from das.Classification_Model import (MLPForClassification,
                                  train_model,
                                  eval_model,
                                  test_model,
                                  make_model)
from das.Helper_Functions import set_seed
from das.Dataset_Generation import (make_model_dataset,
                                make_model_dataset_AndOrAnd,
                                make_intervention_dataset_variable_intervention_all,
                                make_intervention_dataset_variable_intervention_first,
                                make_intervention_dataset_first_input_intervention,
                                make_intervention_dataset_AndOrAnd,
                                make_intervention_dataset_AndOr)
from das.RevNet import RevNet
from das.Rotation_Model import Rotation
from das.DAS import phi_class
from das.DAS_MLP import Distributed_Alignment_Search_MLP

In [4]:
DEVICE  = "cpu" #"cuda"/"cpu"
num_classes=2

In [5]:
def register_intervention_hook(Save_array,Pos,layer):
    def hook_fn(module, input, output):
        Save_array[Pos].append(output.detach().cpu())
    layer.register_forward_hook(hook_fn)

In [6]:
Full_results=[]
for acseed in [4287, 3837, 9097, 2635, 5137, 6442, 5234, 4641, 8039, 2266]:
    set_seed(acseed)
    X_train,y_train = make_model_dataset(1048576,4,DEVICE)#1048576
    X_eval,y_eval   = make_model_dataset(10000,4,DEVICE)#10000
    X_test,y_test   = make_model_dataset(10000,4,DEVICE)
    X_inj,y_inj   = make_model_dataset(10000,4,DEVICE)#5000
   
    model,accuracy=make_model(X_train,y_train,X_eval,y_eval,X_test,y_test,input_size=16,epochs=20,device=DEVICE)
    Layers=[]
    Layers.append(("Layer1",model.mlp.h[0]))
    Layers.append(("Layer2",model.mlp.h[1]))
    Layers.append(("Layer3",model.mlp.h[2]))


    Layer_Save=[[],[],[],[]]
    Results=[]
    Results.append([0,0,0,0,0,0,0])
    Hooks=[]
    predicted_classes_set = set()
    for acpos,aclayer in enumerate(Layers):
        Layer_Save.append([])
        Results.append([0,0,0,0,0,0,0])
        Hooks.append(register_intervention_hook(Layer_Save,acpos+1,aclayer[1]))
    
    
    test_dataset = TensorDataset(X_inj,y_inj)
    test_loader = DataLoader(test_dataset, batch_size=5, shuffle=False)#6400
    
    
    model.eval()
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            Layer_Save[0].append(X_batch)
            output=model(X_batch)
            predicted = torch.argmax(output, dim=1)
            Layer_Save[-3].append(predicted)
            Layer_Save[-2].append(torch.all(X_batch[:, :4] == X_batch[:, 4:8], dim=1))
            Layer_Save[-1].append(torch.all(X_batch[:, 8:12] == X_batch[:, 12:16], dim=1))
        
            # Update set with predictions from this batch
            predicted_classes_set.update(predicted.tolist())
    for i in range(len(Layer_Save)):
        Layer_Save[i]=torch.cat(Layer_Save[i])
    for i in tqdm(range(len(Layer_Save[0]))):
        for j in range(i+1,len(Layer_Save[0])):
            Results[0][0]+=1
            Results[0][1]+=torch.norm(Layer_Save[0][i]-Layer_Save[0][j], p=2).item()
            if Layer_Save[-3][i]==Layer_Save[-3][j]:
                Results[0][3]+=1
                Results[0][4]+=torch.norm(Layer_Save[0][i]-Layer_Save[0][j], p=2).item()
            if Layer_Save[-2][i]==Layer_Save[-2][j] and Layer_Save[-1][i]==Layer_Save[-1][j]:
                Results[0][5]+=1
                Results[0][6]+=torch.norm(Layer_Save[0][i]-Layer_Save[0][j], p=2).item()
            if torch.equal(Layer_Save[0][i],Layer_Save[0][j]):
                Results[0][2]+=1
            for k in range(1,len(Layer_Save)-3):
                Results[k][0]+=1
                Results[k][1]+=torch.norm(Layer_Save[k][i]-Layer_Save[k][j], p=2).item()
                if torch.equal(Layer_Save[k][i],Layer_Save[k][j]):
                    Results[k][2]+=1
                if Layer_Save[-3][i]==Layer_Save[-3][j]:
                    Results[k][3]+=1
                    Results[k][4]+=torch.norm(Layer_Save[k][i]-Layer_Save[k][j], p=2).item()
                if Layer_Save[-2][i]==Layer_Save[-2][j] and Layer_Save[-1][i]==Layer_Save[-1][j]:
                    Results[k][5]+=1
                    Results[k][6]+=torch.norm(Layer_Save[k][i]-Layer_Save[k][j], p=2).item()
    
    all_classes_set = set(range(num_classes))
    missing_classes_set = all_classes_set - predicted_classes_set
    Results_processed=[]
    for i in Results:
        Results_processed.append([i[1]/i[0],i[2]/i[0],i[4]/i[3],i[6]/i[5]])
    print("Surjectivity: Missing:", missing_classes_set,"Found:",predicted_classes_set)
    print("Injectivity", Results_processed)  
    print("Injectivity", Results) 
    Full_results.append({})
    Full_results[-1]["Results"]=Results
    Full_results[-1]["Results_processed"]=Results_processed
    Full_results[-1]["missing_classes_set"]=list(missing_classes_set)
    Full_results[-1]["predicted_classes_set"]=list(predicted_classes_set)
    with open('results.json', 'w') as f:
        json.dump(Full_results, f)

  return torch.tensor(model_inputs, dtype=torch.float32).to(device),torch.tensor(labels, dtype=torch.float32).to(device)


Epoch 1, Loss: 0.39049801870714873 steps without improvement: 7 best accuracy: 0.9837
Epoch 2, Loss: 0.026408711896237946 steps without improvement: 72 best accuracy: 0.9964
Epoch 3, Loss: 0.008641325565577063 steps without improvement: 37 best accuracy: 0.9986
Epoch 4, Loss: 0.004503298025667846 steps without improvement: 60 best accuracy: 0.9991
Epoch 5, Loss: 0.0028202552880998155 steps without improvement: 138 best accuracy: 0.9997
Epoch 6, Loss: 0.0019203961838059058 steps without improvement: 301 best accuracy: 0.9998


100%|███████████████████████████████████| 10000/10000 [3:01:04<00:00,  1.09s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6062524000963159, 0.0, 1.6062073364124536, 1.6027513844150898], [1.1187233052468357, 0.0, 1.0942841669223096, 0.9170955777466178], [2.0019155553167427, 0.0, 1.905394463454935, 1.2899510248964001], [4.199196621574812, 0.0, 3.22916042301307, 2.124650866420493]]
Injectivity [[49995000, 80304588.74281532, 0, 24995900, 40148597.96023205, 12499010, 20032805.58131805], [49995000, 55930571.64581555, 0, 24995900, 27352617.607973356, 12499010, 11462786.797210753], [49995000, 100085768.18806055, 0, 24995900, 47627049.469073206, 12499010, 16123110.759690354], [49995000, 209938835.0956327, 0, 24995900, 80715771.0175924, 12499010, 26556032.425898407]]
Epoch 1, Loss: 0.4398490324092563 steps without improvement: 4 best accuracy: 0.9509
Epoch 2, Loss: 0.04545272669111 steps without improvement: 24 best accuracy: 0.9961
Epoch 3, Loss: 0.008210613968117286 steps without improvement: 0 best accuracy: 0.9986
Epoch 4, Loss: 0.003960001421773995 ste

100%|███████████████████████████████████| 10000/10000 [3:04:28<00:00,  1.11s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6116805801255343, 0.0, 1.6116501203351647, 1.608146495537712], [0.9382596691128412, 0.0, 0.9196931557965017, 0.7773773681690705], [0.9897321527266767, 0.0, 0.9171501031767214, 0.5658387308307294], [1.9914755162070028, 0.0, 1.214609791948695, 0.6593067594242125]]
Injectivity [[49995000, 80575970.60337609, 0, 24995400, 40283839.41782558, 12496505, 20096210.722219497], [49995000, 46908292.15729649, 0, 24995400, 22988098.306395877, 12496505, 9714500.168211631], [49995000, 49481658.9755702, 0, 24995400, 22924533.688943423, 12496505, 7071006.529019864], [49995000, 99563818.4327691, 0, 24995400, 30359657.593674406, 12496505, 8239030.215678468]]
Epoch 1, Loss: 0.5161472786130616 steps without improvement: 1 best accuracy: 0.9309
Epoch 2, Loss: 0.10854338564968202 steps without improvement: 17 best accuracy: 0.9841
Epoch 3, Loss: 0.045250272887642495 steps without improvement: 22 best accuracy: 0.9911
Epoch 4, Loss: 0.026091900404935586

100%|███████████████████████████████████| 10000/10000 [3:04:12<00:00,  1.11s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6062736542884029, 0.0, 1.6062561515001899, 1.6027905496743713], [1.3293357958496836, 0.0, 1.3052414809290975, 1.14719372238951], [2.8090146347961844, 0.0, 2.7471793859454383, 2.1965310130855986], [5.99068023564102, 0.0, 5.621912692000358, 3.8167821266459603]]
Injectivity [[49995000, 80305651.3461487, 0, 24997401, 40152229.127767, 12496938, 20029974.12626654], [49995000, 66460143.11350494, 0, 24997401, 32627644.700618505, 12496938, 14336408.822690917], [49995000, 140436686.66663525, 0, 24997401, 68672344.72941189, 12496938, 27449911.885607917], [49995000, 299504058.3808728, 0, 24997401, 140533205.94892246, 12496938, 47698089.596202716]]
Epoch 1, Loss: 0.4501486698136432 steps without improvement: 0 best accuracy: 0.9342
Epoch 2, Loss: 0.08092447123453894 steps without improvement: 2 best accuracy: 0.9903
Epoch 3, Loss: 0.023304959614961263 steps without improvement: 119 best accuracy: 0.9964
Epoch 4, Loss: 0.011987315442411273 s

100%|███████████████████████████████████| 10000/10000 [3:04:32<00:00,  1.11s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6092224155436392, 0.0, 1.6091902043148012, 1.6058842879395137], [1.1175274106780626, 0.0, 1.1021288412328047, 0.9725840055522534], [1.5899488942311077, 0.0, 1.5272682948808867, 1.0190041406426145], [3.436918770485683, 0.0, 2.6899388799511663, 1.4618381111539387]]
Injectivity [[49995000, 80453074.66510424, 0, 24995676, 40222796.96942657, 12500986, 20075137.00115183], [49995000, 55870782.89684974, 0, 24995676, 27548455.42571063, 12500986, 12158259.037232643], [49995000, 79489494.96708423, 0, 24995676, 38175103.4639151, 12500986, 12738556.496115355], [49995000, 171828753.93043172, 0, 24995676, 67236840.70306225, 12500986, 18274417.76180183]]
Epoch 1, Loss: 0.5061693647876382 steps without improvement: 10 best accuracy: 0.911
Epoch 2, Loss: 0.13824304291483713 steps without improvement: 5 best accuracy: 0.9788
Epoch 3, Loss: 0.0563202187877323 steps without improvement: 92 best accuracy: 0.9912
Epoch 4, Loss: 0.03151635768608685 st

100%|███████████████████████████████████| 10000/10000 [3:03:43<00:00,  1.10s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6073220052450587, 0.0, 1.6070941358542443, 1.6035541371801847], [1.2794275598710592, 0.0, 1.258537555604157, 1.088796755990033], [2.927911102028927, 0.0, 2.8540916791602444, 2.3583705197391094], [4.258049850565859, 0.0, 3.8284332979387052, 2.8725594984414733]]
Injectivity [[49995000, 80358063.65222672, 0, 25033416, 40231056.05399981, 12512764, 20064894.479759276], [49995000, 63964980.85575361, 0, 25033416, 31505494.181061994, 12512764, 13623856.851668868], [49995000, 146380915.5459362, 0, 25033416, 71447664.30655693, 12512764, 29509733.73805282], [49995000, 212881202.27904013, 0, 25033416, 95838763.37555155, 12512764, 35943659.079956524]]
Epoch 1, Loss: 0.3828369175980697 steps without improvement: 19 best accuracy: 0.9963
Epoch 2, Loss: 0.0072656655906939704 steps without improvement: 10 best accuracy: 0.9993
Epoch 3, Loss: 0.002645535667625154 steps without improvement: 144 best accuracy: 0.9997
Epoch 4, Loss: 0.0014668943836

100%|███████████████████████████████████| 10000/10000 [3:04:49<00:00,  1.11s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.607259651586084, 0.0, 1.6072802594868445, 1.603664866986854], [0.9532000300431138, 0.0, 0.9407256447458406, 0.8004543707036372], [1.9033119486194316, 0.0, 1.74371223277668, 0.8214054080382402], [5.069972775294422, 0.0, 3.9814548072617217, 1.6212243874009242]]
Injectivity [[49995000, 80354946.28104627, 0, 24998025, 40178832.10865863, 12500554, 20046699.267671987], [49995000, 47655235.50200547, 0, 24998025, 23516283.18549764, 12500554, 10006123.085516835], [49995000, 95156080.87122849, 0, 24998025, 43589361.987757266, 12500554, 10268022.659074057], [49995000, 253473288.90084466, 0, 24998025, 99528506.8082987, 12500554, 20266203.00082217]]
Epoch 1, Loss: 0.4623644460370997 steps without improvement: 9 best accuracy: 0.9617
Epoch 2, Loss: 0.04989305537856126 steps without improvement: 33 best accuracy: 0.9955
Epoch 3, Loss: 0.012785509368313797 steps without improvement: 145 best accuracy: 0.9982
Epoch 4, Loss: 0.006159185772503406

100%|███████████████████████████████████| 10000/10000 [3:05:11<00:00,  1.11s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6079940678433229, 0.0, 1.607914767321643, 1.6042424822107568], [1.1897857630820758, 0.0, 1.1683225430739173, 0.9852627927514517], [1.9513173833679216, 0.0, 1.8640219741382482, 1.2520198990757334], [3.676828120744925, 0.0, 2.9665157407614475, 1.7265219072565339]]
Injectivity [[49995000, 80391663.42182693, 0, 25001889, 40200906.53403655, 12510021, 20069107.141548693], [49995000, 59483339.225288376, 0, 25001889, 29210270.5381318, 12510021, 12325658.227839308], [49995000, 97556112.58147924, 0, 25001889, 46604070.49096535, 12510021, 15662795.229855306], [49995000, 183823021.89664254, 0, 25001889, 74168497.26727049, 12510021, 21598825.31673929]]
Epoch 1, Loss: 0.44204726628231583 steps without improvement: 6 best accuracy: 0.9684
Epoch 2, Loss: 0.05056346689525526 steps without improvement: 74 best accuracy: 0.9924
Epoch 3, Loss: 0.019781852772212005 steps without improvement: 55 best accuracy: 0.9973
Epoch 4, Loss: 0.012084425419743

100%|███████████████████████████████████| 10000/10000 [3:04:09<00:00,  1.10s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6061571386875442, 0.0, 1.6061696141760247, 1.6028435674240267], [1.2241925686496185, 0.0, 1.2020133305599345, 1.0290868504151802], [2.721948233475916, 0.0, 2.6377805771929097, 2.0457959256244997], [4.927898843784046, 0.0, 4.534750237042202, 3.3554525440732865]]
Injectivity [[49995000, 80299826.14868377, 0, 24997304, 40149910.121120796, 12497797, 20032013.528421298], [49995000, 61203507.46963767, 0, 24997304, 30047092.636059172, 12497797, 12861318.551858287], [49995000, 136083801.93262842, 0, 24997304, 65937402.97338663, 12497797, 25567942.1818821], [49995000, 246370302.69498336, 0, 24997304, 113356530.23941597, 12497797, 41935764.73896149]]
Epoch 1, Loss: 0.4044290292367805 steps without improvement: 0 best accuracy: 0.9618
Epoch 2, Loss: 0.048057374328891456 steps without improvement: 33 best accuracy: 0.9949
Epoch 3, Loss: 0.011773959729907801 steps without improvement: 81 best accuracy: 0.9981
Epoch 4, Loss: 0.00615440726539

100%|███████████████████████████████████| 10000/10000 [3:04:57<00:00,  1.11s/it]


Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6043662035976995, 0.0, 1.60433194435454, 1.600887214108796], [1.0593074544599934, 0.0, 1.036685836156183, 0.8731391906782866], [1.2158356967631732, 0.0, 1.168732509342229, 0.8686391420728741], [3.0718428998910547, 0.0, 2.356094126087442, 1.6921212678935482]]
Injectivity [[49995000, 80210288.34886698, 0, 24995400, 40100918.68191947, 12496178, 20004971.585427627], [49995000, 52960076.18572737, 0, 24995400, 25912377.14905826, 12496178, 10910902.74549181], [49995000, 60785705.65967484, 0, 24995400, 29212936.564012747, 12496178, 10854669.337109923], [49995000, 153576785.7800533, 0, 24995400, 58891515.11920604, 12496178, 21145048.561183464]]
Epoch 1, Loss: 0.42178907026391244 steps without improvement: 7 best accuracy: 0.9588
Epoch 2, Loss: 0.054393591018197185 steps without improvement: 30 best accuracy: 0.9955
Epoch 3, Loss: 0.010968930734179594 steps without improvement: 116 best accuracy: 0.9989
Epoch 4, Loss: 0.00507179076180364

100%|███████████████████████████████████| 10000/10000 [3:05:08<00:00,  1.11s/it]

Surjectivity: Missing: set() Found: {0, 1}
Injectivity [[1.6074383553947849, 0.0, 1.6073831647598422, 1.6038395700585908], [0.8667285248388393, 0.0, 0.8474861772377078, 0.7255274033345864], [0.8236977808144543, 0.0, 0.7835643753990937, 0.5405897412703982], [2.5808393145957327, 0.0, 1.839171771947828, 1.1665662371696497]]
Injectivity [[49995000, 80363880.57796226, 0, 24995289, 40177006.73690687, 12495665, 20041041.98119618], [49995000, 43332092.599317774, 0, 24995289, 21183161.92356173, 12495665, 9065947.380388875], [49995000, 41180770.55181865, 0, 24995289, 19585418.01320484, 12495665, 6755028.30935157], [49995000, 129029061.53321366, 0, 24995289, 45970629.96047805, 12495665, 14577020.89998249]]





In [13]:
processed_results=[]
for ac_results in Full_results:
    for i in range(len(ac_results["Results_processed"])):
        if len(processed_results)-1<i:
            processed_results.append([])
        for j in range(len(ac_results["Results_processed"][i])):
            if len(processed_results[i])-1<j:
                processed_results[i].append([])
            processed_results[i][j].append(ac_results["Results_processed"][i][j])

In [14]:

for i in range(len(processed_results)):
    for j in range(len(processed_results[i])):
        processed_results[i][j]=np.array(processed_results[i][j])
        processed_results[i][j]=[round(np.mean(processed_results[i][j]),2),round(np.std(processed_results[i][j]),2)]

In [15]:
processed_results

[[[1.61, 0.0], [0.0, 0.0], [1.61, 0.0], [1.6, 0.0]],
 [[1.11, 0.15], [0.0, 0.0], [1.09, 0.14], [0.93, 0.13]],
 [[1.89, 0.72], [0.0, 0.0], [1.81, 0.71], [1.3, 0.64]],
 [[3.92, 1.16], [0.0, 0.0], [3.23, 1.24], [2.05, 0.95]]]