In [1]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import random
import numpy as np
import copy
import json

In [2]:
import wandb
# Start Weights & Biases in "disabled" mode for this notebook session.
# NOTE(review): presumably the das training helpers call wandb internally and
# this keeps them from logging a run -- confirm against das.Classification_Model.
wandb.init(mode="disabled")

In [3]:
from das.Classification_Model import (MLPForClassification,
                                  train_model,
                                  eval_model,
                                  test_model,
                                  make_model)
from das.Helper_Functions import set_seed
from das.Dataset_Generation import (make_model_dataset,
                                make_model_dataset_AndOrAnd,
                                make_intervention_dataset_variable_intervention_all,
                                make_intervention_dataset_variable_intervention_first,
                                make_intervention_dataset_first_input_intervention,
                                make_intervention_dataset_AndOrAnd,
                                make_intervention_dataset_AndOr)
from das.RevNet import RevNet
from das.Rotation_Model import Rotation
from das.DAS import phi_class
from das.DAS_MLP import Distributed_Alignment_Search_MLP

In [4]:
# Device string handed to the dataset/model helpers: "cuda" or "cpu".
DEVICE = "cpu"

# Experiment setting to run; keep exactly one assignment active.
# Setting = "Both Equality Relations"
# Setting = "Left Equality Relation"
Setting = "Identity of First Argument"
# Setting = "AndOrAnd"
# Setting = "AndOr"

# "in_features" is used in the experiment cell both as the classifier's
# input_size and as the hidden-layer size given to the alignment search.
transformation_config = dict(type="identity", in_features=16)

# Training-schedule constants.
# NOTE(review): the cells visible here only mention ReduceLROnPlateau_patience,
# and only inside a commented-out scheduler; presumably these are leftovers from
# a trainable-transformation variant of this notebook -- verify before removing.
Max_Epochs = 50
Early_Stopping_Epochs = 5
early_stopping_improve_threshold = 0.001
ReduceLROnPlateau_patience = 10

In [5]:
# Settings grouped by the number of variable slots the greedy search fills.
one_variable_settings = [
    "Identity of First Argument",
    "Left Equality Relation",
]
two_variable_settings = [
    "Both Equality Relations",
    "AndOrAnd",
    "AndOr",
]

# Settings grouped by which classification-dataset generator they use.
DAS_Original_tasks = [
    "Both Equality Relations",
    "Left Equality Relation",
    "Identity of First Argument",
]
AndOrAnd_tasks = [
    "AndOrAnd",
    "AndOr",
]

# Number of variable slots for the active Setting; None when the Setting is
# in neither group.
if Setting in two_variable_settings:
    numvar = 2
elif Setting in one_variable_settings:
    numvar = 1
else:
    numvar = None

In [6]:
def Process_Data_Left_Equality(Data):
    """Keep only the first source and first intervention of every entry.

    Each element of ``Data`` must support item access for the keys
    ``"sources"`` and ``"intervention"``, and both values must be sliceable.
    The entries are modified in place.

    Returns:
        The same ``Data`` object that was passed in.
    """
    truncated_keys = ("sources", "intervention")
    for position in range(len(Data)):
        entry = Data[position]
        for key in truncated_keys:
            # Slicing (rather than indexing) keeps the container type and
            # tolerates an empty sequence.
            entry[key] = entry[key][:1]
    return Data

In [7]:
# Greedy search over hidden dimensions, repeated for ten seeds.
# For every seed a classifier is built with make_model; then, for each of its
# three hidden layers, index sets are grown greedily and scored with
# Distributed_Alignment_Search_MLP.train_test. The best (index sets, accuracy)
# per layer is stored in `results` and written to results.json.
results=[]
for acseed in [4287, 3837, 9097, 2635, 5137, 6442, 5234, 4641, 8039, 2266]:
    results.append({})
    set_seed(acseed)

    # Classification data for the task family that `Setting` belongs to.
    if Setting in DAS_Original_tasks:
        X_train,y_train = make_model_dataset(1048576,4,DEVICE)
        X_eval,y_eval   = make_model_dataset(10000,4,DEVICE)
        X_test,y_test   = make_model_dataset(10000,4,DEVICE)
    elif Setting in AndOrAnd_tasks:
        X_train,y_train = make_model_dataset_AndOrAnd(1048576,4,DEVICE)
        X_eval,y_eval   = make_model_dataset_AndOrAnd(10000,4,DEVICE)
        X_test,y_test   = make_model_dataset_AndOrAnd(10000,4,DEVICE)
    else:
        # Fail fast. Without this branch an unrecognised Setting would either
        # hit a NameError below or, in a kernel that still holds data from a
        # previous run, silently reuse that stale data.
        raise ValueError(f"Unknown Setting: {Setting!r}")
    #print("!!!!!!!!!!!!! Set and training number and epochs back to 20")
    model,accuracy=make_model(X_train,y_train,X_eval,y_eval,X_test,y_test,input_size=transformation_config["in_features"],epochs=20,device=DEVICE)

    # The three hidden layers that are searched, paired with their result keys.
    Layers=[]
    Layers.append(("Layer1",model.mlp.h[0]))
    Layers.append(("Layer2",model.mlp.h[1]))
    Layers.append(("Layer3",model.mlp.h[2]))
    inter_dims=[]

    # Intervention datasets for the chosen Setting. Only the test split is
    # large; NOTE(review): train/eval are requested with size 1, presumably
    # because train_test is only ever called with mode=2 here -- confirm.
    if Setting == "Both Equality Relations":
        DAS_Train = make_intervention_dataset_variable_intervention_all(1,4)
        DAS_Test  = make_intervention_dataset_variable_intervention_all(10000,4)
        DAS_Eval  = make_intervention_dataset_variable_intervention_all(1,4)
    elif Setting == "Left Equality Relation":
        DAS_Train = Process_Data_Left_Equality(make_intervention_dataset_variable_intervention_first(1,4))
        DAS_Test  = Process_Data_Left_Equality(make_intervention_dataset_variable_intervention_first(10000,4))
        DAS_Eval  = Process_Data_Left_Equality(make_intervention_dataset_variable_intervention_first(1,4))
    elif Setting == "Identity of First Argument":
        DAS_Train = make_intervention_dataset_first_input_intervention(1,4)
        DAS_Test  = make_intervention_dataset_first_input_intervention(10000,4)
        DAS_Eval  = make_intervention_dataset_first_input_intervention(1,4)
    elif Setting == "AndOrAnd":
        DAS_Train = make_intervention_dataset_AndOrAnd(1,4)
        DAS_Test  = make_intervention_dataset_AndOrAnd(10000,4)
        DAS_Eval  = make_intervention_dataset_AndOrAnd(1,4)
    elif Setting == "AndOr":
        DAS_Train = make_intervention_dataset_AndOr(1,4)
        DAS_Test  = make_intervention_dataset_AndOr(10000,4)
        DAS_Eval  = make_intervention_dataset_AndOr(1,4)
    else:
        # The original constructed this exception without raising it, so the
        # branch was a silent no-op.
        raise ValueError(f"Unknown Setting: {Setting!r}")

    # Classifier accuracy as returned by make_model; stored before the name
    # `accuracy` is reused for the search scores below.
    results[-1]["accuracy"]=accuracy
    for LayerName,Layer in Layers:
        results[-1][LayerName]={}

        # Identity transformation in both directions: nothing is learned, so
        # no optimizer or scheduler is supplied.
        p = torch.nn.Identity()
        p.to(DEVICE)
        p_inverse = torch.nn.Identity()
        optimizer = None #optim.Adam(p.parameters(), lr=0.001)
        scheduler = None #optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=ReduceLROnPlateau_patience)
        criterion = nn.CrossEntropyLoss()

        phi=phi_class(p,p_inverse,criterion,optimizer,scheduler)

        # Best assignment seen so far for this layer and its score.
        greedy_variables=[[]]
        greedy_accuracy=-1

        # Two-variable settings are seeded by trying every ordered pair of
        # distinct dimensions (i for the first variable, j for the second).
        if Setting in two_variable_settings :
            for i in range(transformation_config["in_features"]):
                for j in range(transformation_config["in_features"]):
                    if i==j:
                        continue
                    DAS_Experiment=Distributed_Alignment_Search_MLP(Model=model,
                                                                    Model_Layer=Layer,
                                                                    Train_Data_Raw=DAS_Train,
                                                                    Test_Data_Raw=DAS_Test,
                                                                    Eval_Data_Raw=DAS_Eval,
                                                                    Hidden_Layer_Size=transformation_config["in_features"],
                                                                    Variable_Dimensions=[[i],[j]],
                                                                    Transformation_Class=phi,
                                                                    Device=DEVICE)
                    accuracy=DAS_Experiment.train_test(batch_size=6400,
                                                       mode=2)#Test
                    DAS_Experiment.Cleanup()
                    print("*  ",accuracy,DAS_Experiment.Variable_Dimensions)
                    if greedy_accuracy<accuracy:
                        greedy_accuracy=accuracy
                        greedy_variables=DAS_Experiment.Variable_Dimensions

        # Greedy growth: each round appends the single unused dimension (to
        # whichever variable) that scores best, until every dimension is used.
        # The round's winner is kept even when it scores lower than earlier
        # rounds; greedy_variables only tracks the overall best.
        ac_var_greedy_old=copy.deepcopy(greedy_variables)
        while (
            ((Setting in one_variable_settings)
             and len(ac_var_greedy_old[0])<transformation_config["in_features"])
            or
            ((Setting in two_variable_settings)
             and len(ac_var_greedy_old[0]+ac_var_greedy_old[1])<transformation_config["in_features"])
        ):
            ac_acc_greedy=-1
            ac_var_greedy_new=None

            for i in range(numvar):
                for j in range(transformation_config["in_features"]):
                    # Skip dimensions that are already assigned to a variable.
                    if (Setting in two_variable_settings) and (j in ac_var_greedy_old[0]+ac_var_greedy_old[1]):
                        continue
                    if (Setting in one_variable_settings) and (j in ac_var_greedy_old[0]):
                        continue
                    ac_var=copy.deepcopy(ac_var_greedy_old)
                    ac_var[i].append(j)
                    DAS_Experiment=Distributed_Alignment_Search_MLP(Model=model,
                                                                    Model_Layer=Layer,
                                                                    Train_Data_Raw=DAS_Train,
                                                                    Test_Data_Raw=DAS_Test,
                                                                    Eval_Data_Raw=DAS_Eval,
                                                                    Hidden_Layer_Size=transformation_config["in_features"],
                                                                    Variable_Dimensions=ac_var,
                                                                    Transformation_Class=phi,
                                                                    Device=DEVICE)
                    accuracy=DAS_Experiment.train_test(batch_size=6400,
                                                       mode=2)#Test
                    DAS_Experiment.Cleanup()
                    print("** ",accuracy,DAS_Experiment.Variable_Dimensions)
                    if ac_acc_greedy<accuracy:
                        ac_acc_greedy=accuracy
                        ac_var_greedy_new=ac_var
            ac_var_greedy_old=ac_var_greedy_new
            if greedy_accuracy<ac_acc_greedy:
                greedy_accuracy=ac_acc_greedy
                greedy_variables=copy.deepcopy(ac_var_greedy_old)
                print("***",greedy_accuracy,greedy_variables)
        results[-1][LayerName]=(greedy_variables,greedy_accuracy)

        # Drop the last experiment object and rewrite results.json after every
        # layer, so partial results are on disk if the run is interrupted.
        DAS_Experiment=None
        with open('results.json', 'w') as f:
            json.dump(results, f)

  return torch.tensor(model_inputs, dtype=torch.float32).to(device),torch.tensor(labels, dtype=torch.float32).to(device)


Epoch 1, Loss: 0.39049801870714873 steps without improvement: 7 best accuracy: 0.9837
Epoch 2, Loss: 0.026408711896237946 steps without improvement: 72 best accuracy: 0.9964
Epoch 3, Loss: 0.008641325565577063 steps without improvement: 37 best accuracy: 0.9986
Epoch 4, Loss: 0.004503298025667846 steps without improvement: 60 best accuracy: 0.9991
Epoch 5, Loss: 0.0028202552880998155 steps without improvement: 138 best accuracy: 0.9997
Epoch 6, Loss: 0.0019203961838059058 steps without improvement: 301 best accuracy: 0.9998
**  0.505 [[0]]
**  0.5045 [[1]]
**  0.5001 [[2]]
**  0.5036 [[3]]
**  0.5057 [[4]]
**  0.5009 [[5]]
**  0.5 [[6]]
**  0.5062 [[7]]
**  0.5014 [[8]]
**  0.5087 [[9]]
**  0.4992 [[10]]
**  0.5043 [[11]]
**  0.5056 [[12]]
**  0.5073 [[13]]
**  0.5053 [[14]]
**  0.5055 [[15]]
*** 0.5087 [[9]]
**  0.51 [[9, 0]]
**  0.5101 [[9, 1]]
**  0.5049 [[9, 2]]
**  0.5059 [[9, 3]]
**  0.5105 [[9, 4]]
**  0.505 [[9, 5]]
**  0.5053 [[9, 6]]
**  0.5102 [[9, 7]]
**  0.506 [[9, 8]]
** 