In [1]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import random
import numpy as np
import copy
import json

In [2]:
from das.Classification_Model import (MLPForClassification,
                                  train_model,
                                  eval_model,
                                  test_model,
                                  make_model)
from das.Helper_Functions import set_seed
from das.Dataset_Generation import (make_model_dataset,
                                make_model_dataset_AndOrAnd,
                                make_intervention_dataset_variable_intervention_all,
                                make_intervention_dataset_variable_intervention_first,
                                make_intervention_dataset_first_input_intervention,
                                make_intervention_dataset_AndOrAnd,
                                make_intervention_dataset_AndOr)
from das.RevNet import RevNet
from das.Rotation_Model import Rotation
from das.DAS import phi_class
from das.DAS_MLP import Distributed_Alignment_Search_MLP

In [3]:
# --- Experiment configuration ------------------------------------------------

# Device handed to the dataset generators, the classifier and the DAS run.
# Use "cuda" or "cpu".
DEVICE = "cuda"

# Task to analyse. Supported values:
#   "Both Equality Relations"
#   "Left Equality Relation"
#   "Identity of First Argument"
#   "AndOrAnd"
#   "AndOr"
Setting = "AndOrAnd"

# Transformation function searched by DAS. "in_features" is also used as the
# classifier's input size and as the hidden layer size given to DAS.
transformation_config = dict(type="Rotation", in_features=24)

# Alternative transformation (disabled); replace the assignment above to use it:
# transformation_config = dict(type="RevNet",
#                              number_blocks=10,
#                              in_features=16,
#                              hidden_size=16)

# Training limits for the DAS optimisation.
Max_Epochs = 50                           # passed as `epochs` to train_test
Early_Stopping_Epochs = 5                 # passed as `early_stopping_threshold`
early_stopping_improve_threshold = 1e-3   # passed under the same keyword name
ReduceLROnPlateau_patience = 10           # `patience` of the LR scheduler

In [4]:
# Groupings of the supported `Setting` values, used to pick the matching
# dataset generators and intervention-dimension layouts.

# Grouped by which classifier-dataset generator the task uses.
DAS_Original_tasks = [
    "Both Equality Relations",
    "Left Equality Relation",
    "Identity of First Argument",
]
AndOrAnd_tasks = [
    "AndOrAnd",
    "AndOr",
]

# Grouped by how many groups of intervention dimensions are searched.
one_variable_settings = [
    "Identity of First Argument",
]
two_variable_settings = [
    "Both Equality Relations",
    "Left Equality Relation",
    "AndOrAnd",
    "AndOr",
]

In [5]:
# Run the complete experiment once per seed: train a fresh classifier, then run
# Distributed Alignment Search (DAS) on each of its three hidden layers for
# every candidate partition of intervention dimensions. The collected results
# are checkpointed to results.json after every individual DAS run.

# --- Validate the configuration before any expensive work --------------------
# Everything checked here depends only on `Setting` / `transformation_config`.
# A typo is therefore reported immediately as a ValueError instead of surfacing
# later as a NameError (or, on a re-run, as silently reused stale variables).
if Setting in DAS_Original_tasks:
    make_task_dataset = make_model_dataset
elif Setting in AndOrAnd_tasks:
    make_task_dataset = make_model_dataset_AndOrAnd
else:
    raise ValueError(f"Unknown Setting: {Setting!r}")

# One generator of intervention datasets per supported Setting.
intervention_dataset_makers = {
    "Both Equality Relations":    make_intervention_dataset_variable_intervention_all,
    "Left Equality Relation":     make_intervention_dataset_variable_intervention_first,
    "Identity of First Argument": make_intervention_dataset_first_input_intervention,
    "AndOrAnd":                   make_intervention_dataset_AndOrAnd,
    "AndOr":                      make_intervention_dataset_AndOr,
}
if Setting not in intervention_dataset_makers:
    raise ValueError(f"Unknown Setting: {Setting!r}")
make_intervention_dataset = intervention_dataset_makers[Setting]

if Setting not in two_variable_settings and Setting not in one_variable_settings:
    raise ValueError(f"Unknown Setting: {Setting!r}")

if transformation_config["type"] not in ("Rotation", "RevNet"):
    raise ValueError(
        f"Unknown transformation function: {transformation_config['type']!r}"
    )

# --- Sizes used below --------------------------------------------------------
model_train_size = 1048576   # classifier training examples (2**20)
das_train_size   = 1280000   # intervention training examples
holdout_size     = 10000     # size of every eval / test split
model_epochs     = 20        # epochs for training the classifier
das_batch_size   = 6400      # batch size for DAS training and testing
# NOTE(review): the literal 4 passed to every dataset generator is kept as-is;
# its meaning is defined by das.Dataset_Generation -- confirm there.

results = []
for acseed in [4287, 3837, 9097, 2635, 5137, 6442, 5234, 4641, 8039, 2266]:
    seed_results = {}
    results.append(seed_results)
    set_seed(acseed)

    # Classifier data. The generation order (train, eval, test) is kept as in
    # the original so that seeded runs presumably draw the same data.
    X_train, y_train = make_task_dataset(model_train_size, 4, DEVICE)
    X_eval, y_eval   = make_task_dataset(holdout_size, 4, DEVICE)
    X_test, y_test   = make_task_dataset(holdout_size, 4, DEVICE)

    model, model_accuracy = make_model(X_train, y_train,
                                       X_eval, y_eval,
                                       X_test, y_test,
                                       input_size=transformation_config["in_features"],
                                       epochs=model_epochs,
                                       device=DEVICE)

    # The three hidden layers that DAS is run on.
    Layers = [(f"Layer{idx + 1}", model.mlp.h[idx]) for idx in range(3)]

    # Candidate partitions of the transformed space, from widest to narrowest.
    # Rebuilt for every seed so each run starts from fresh list objects.
    in_features = transformation_config["in_features"]
    half = in_features // 2
    if Setting in two_variable_settings:
        inter_dims = [
            [list(range(0, half)), list(range(half, in_features))],
            [list(range(0, 2)), list(range(2, 4))],
            [list(range(0, 1)), list(range(1, 2))],
        ]
    else:
        # Validated above: this is one of `one_variable_settings`.
        inter_dims = [
            [list(range(0, half))],
            [list(range(0, 2))],
            [list(range(0, 1))],
        ]

    # Intervention data. Order (train, test, eval) kept as in the original.
    DAS_Train = make_intervention_dataset(das_train_size, 4)
    DAS_Test  = make_intervention_dataset(holdout_size, 4)
    DAS_Eval  = make_intervention_dataset(holdout_size, 4)

    # Stored under its own name so the per-run DAS accuracy assigned further
    # down cannot be confused with the classifier's accuracy.
    seed_results["accuracy"] = model_accuracy

    for LayerName, Layer in Layers:
        seed_results[LayerName] = {}
        for inter_dim in inter_dims:
            print(LayerName,":",inter_dim, flush=True)

            # Fresh transformation function for every (layer, partition) run.
            if transformation_config["type"] == "Rotation":
                p = Rotation(transformation_config["in_features"])
            else:
                # Validated above: the only other supported type is "RevNet".
                p = RevNet(number_blocks=transformation_config["number_blocks"],
                           in_features=transformation_config["in_features"],
                           hidden_size=transformation_config["hidden_size"])
            p.to(DEVICE)
            p_inverse = p.inverse

            optimizer = optim.Adam(p.parameters(), lr=0.001)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, factor=0.5, patience=ReduceLROnPlateau_patience)
            criterion = nn.CrossEntropyLoss()

            phi = phi_class(p, p_inverse, criterion, optimizer, scheduler)

            DAS_Experiment = Distributed_Alignment_Search_MLP(
                Model=model,
                Model_Layer=Layer,
                Train_Data_Raw=DAS_Train,
                Test_Data_Raw=DAS_Test,
                Eval_Data_Raw=DAS_Eval,
                Hidden_Layer_Size=transformation_config["in_features"],
                Variable_Dimensions=inter_dim,
                Transformation_Class=phi,
                Device=DEVICE)

            # mode=1: train the transformation (with early stopping).
            DAS_Experiment.train_test(
                batch_size=das_batch_size,
                epochs=Max_Epochs,
                mode=1,
                early_stopping_threshold=Early_Stopping_Epochs,
                early_stopping_improve_threshold=early_stopping_improve_threshold)

            # mode=2: evaluate on the test data.
            accuracy = DAS_Experiment.train_test(batch_size=das_batch_size,
                                                 mode=2)

            seed_results[LayerName][str(inter_dim)] = accuracy

            # Clean up the experiment and drop the reference before the next run.
            DAS_Experiment.Cleanup()
            DAS_Experiment = None

            # Checkpoint after every run so partial results survive a crash.
            with open('results.json', 'w') as f:
                json.dump(results, f)

  return torch.tensor(model_inputs, dtype=torch.float32).to(device),torch.tensor(labels, dtype=torch.float32).to(device)


Epoch 1, Loss: 0.20628716659848578 steps without improvement: 1 best accuracy: 0.978
Epoch 2, Loss: 0.031662344103096984 steps without improvement: 32 best accuracy: 0.9984
Epoch 3, Loss: 0.007456264303982607 steps without improvement: 168 best accuracy: 0.9997
Epoch 4, Loss: 0.003872038481574691 steps without improvement: 675 best accuracy: 0.9999
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Loss: 5.227875565290451 steps without improvement: 0 eval accuracy: 0.776 best eval accuracy: 0.776 learning rate: 0.001
Epoch 2, Loss: 1.672629815340042 steps without improvement: 0 eval accuracy: 0.8025 best eval accuracy: 0.8025 learning rate: 0.001
Epoch 3, Loss: 1.5925766408443451 steps without improvement: 0 eval accuracy: 0.814 best eval accuracy: 0.814 learning rate: 0.001
Epoch 4, Loss: 1.5618918687105179 steps without improvement: 1 eval accuracy: 0.8146 best eval accuracy: 0.8146 learning rate: 0.001
Epoch 5, Loss: 1.546930