In [1]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import random
import numpy as np
import copy
import json

In [2]:
# Start a Weights & Biases run in "disabled" mode.
# NOTE(review): presumably the das.* training code logs through wandb and this
# keeps those calls from needing an account or network access -- confirm.
import wandb
wandb.init(mode="disabled")

In [3]:
from das.Classification_Model import make_model_DAS_fitted
from das.Helper_Functions import set_seed
from das.Dataset_Generation import (make_intervention_dataset_AndOrAnd,
                                make_intervention_dataset_AndOr,
                                make_intervention_dataset_AndOrAnd_DAS_Fitted,
                                make_intervention_dataset_AndOr_DAS_Fitted)
from das.RevNet import RevNet
from das.Rotation_Model import Rotation
from das.DAS import phi_class
from das.DAS_MLP import Distributed_Alignment_Search_MLP

In [4]:
# Experiment configuration: adapt the values in this cell to run a different
# Setting or a different transformation function.

# --- Training schedule ------------------------------------------------------
Max_Epochs                       = 50
Early_Stopping_Epochs            = 5
early_stopping_improve_threshold = 0.001
ReduceLROnPlateau_patience       = 10

# --- Hardware ---------------------------------------------------------------
DEVICE = "cpu"  # either "cuda" or "cpu"

# --- Task selection (the alternative value for both is "AndOr") -------------
Setting    = "AndOrAnd"
FitModelTo = "AndOrAnd"

# --- Transformation function ------------------------------------------------
# Active choice: a rotation over 24 features.
transformation_config = dict(type="Rotation", in_features=24)

# Alternative RevNet configuration, kept as an inert string literal; to use it,
# move the assignment out of the string and drop the Rotation one above.
"""
transformation_config = {"type"          : "RevNet",
                         "number_blocks" :       10,
                         "in_features"   :       24,
                         "hidden_size"   :       24}
"""

In [None]:
# Main experiment loop.
#
# For every seed:
#   1. build the "DAS fitted" datasets selected by FitModelTo and obtain a
#      classification model (plus its accuracy) from make_model_DAS_fitted,
#   2. build the intervention datasets selected by Setting,
#   3. for each of the three hooked layers and each candidate variable
#      partition, create a fresh transformation, run
#      Distributed_Alignment_Search_MLP in mode 1 (train) and mode 2 (test),
#      and store the returned test accuracy.
#
# `results` is a list with one dict per seed:
#   {"accuracy": <value returned by make_model_DAS_fitted>,
#    "Layer1": {str(partition): <mode-2 result>, ...}, "Layer2": ..., "Layer3": ...}
# It is re-written to results.json after every single experiment so that an
# interrupted run keeps everything finished so far.

# Fail fast on configuration typos. Previously an unknown value only built an
# Exception object without raising it (and FitModelTo had no fallback at all),
# so the run continued with undefined or stale datasets / transformation.
if FitModelTo not in ("AndOrAnd", "AndOr"):
    raise ValueError("Unknown FitModelTo")
if Setting not in ("AndOrAnd", "AndOr"):
    raise ValueError("Unknown Setting")
if transformation_config["type"] not in ("Rotation", "RevNet"):
    raise ValueError("Unknown transformation function")

results=[]
for acseed in [4287, 3837, 9097, 2635, 5137, 6442, 5234, 4641, 8039, 2266]:
    results.append({})
    set_seed(acseed)

    # --- Step 1: datasets for the classification model ----------------------
    if FitModelTo == "AndOrAnd":
        DAS_Train = make_intervention_dataset_AndOrAnd_DAS_Fitted(1280000,4)
        DAS_Test  = make_intervention_dataset_AndOrAnd_DAS_Fitted(10000,4)
        DAS_Eval  = make_intervention_dataset_AndOrAnd_DAS_Fitted(10000,4)
    elif FitModelTo == "AndOr":
        DAS_Train = make_intervention_dataset_AndOr_DAS_Fitted(1280000,4)
        DAS_Test  = make_intervention_dataset_AndOr_DAS_Fitted(10000,4)
        DAS_Eval  = make_intervention_dataset_AndOr_DAS_Fitted(10000,4)
    else:
        raise ValueError("Unknown FitModelTo")

    # The training schedule comes from the configuration cell, so changing it
    # there affects this call as well as the DAS runs further down.
    model,accuracy = make_model_DAS_fitted(DAS_Train=DAS_Train,
                                           DAS_Test=DAS_Test,
                                           DAS_Eval=DAS_Eval,
                                           Hidden_Layer_Size=24,
                                           inter_dim=[list(range(0,12)),list(range(12,24))],
                                           DEVICE=DEVICE,
                                           Max_Epochs=Max_Epochs,
                                           Early_Stopping_Epochs=Early_Stopping_Epochs,
                                           early_stopping_improve_threshold=early_stopping_improve_threshold,
                                           ReduceLROnPlateau_patience=ReduceLROnPlateau_patience)

    # Layers of the model that the alignment search is attached to.
    Layers=[]
    Layers.append(("Layer1",model.mlp.h[0]))
    Layers.append(("Layer2",model.mlp.h[1]))
    Layers.append(("Layer3",model.mlp.h[2]))

    # Candidate partitions of the transformed space into two variables:
    # the two halves of all features, two dimensions each, one dimension each.
    half_features = transformation_config["in_features"]//2
    inter_dims=[]
    inter_dims.append([list(range(0,half_features)),list(range(half_features,transformation_config["in_features"]))])
    inter_dims.append([list(range(0,2)),list(range(2,4))])
    inter_dims.append([list(range(0,1)),list(range(1,2))])

    # --- Step 2: datasets for the alignment search --------------------------
    # These deliberately replace the DAS_* datasets used for the model above.
    if Setting == "AndOrAnd":
        DAS_Train = make_intervention_dataset_AndOrAnd(1280000,4)
        DAS_Test  = make_intervention_dataset_AndOrAnd(10000,4)
        DAS_Eval  = make_intervention_dataset_AndOrAnd(10000,4)
    elif Setting == "AndOr":
        DAS_Train = make_intervention_dataset_AndOr(1280000,4)
        DAS_Test  = make_intervention_dataset_AndOr(10000,4)
        DAS_Eval  = make_intervention_dataset_AndOr(10000,4)
    else:
        raise ValueError("Unknown Setting")

    # Store the model accuracy now: `accuracy` is rebound inside the loop below.
    results[-1]["accuracy"]=accuracy

    # --- Step 3: one alignment search per (layer, partition) ----------------
    for LayerName,Layer in Layers:
        results[-1][LayerName]={}
        for inter_dim in inter_dims:
            print(LayerName,":",inter_dim, flush=True)

            # Initialize a fresh transformation function for every experiment
            if transformation_config["type"]=="Rotation":
                p = Rotation(transformation_config["in_features"])
            elif transformation_config["type"]=="RevNet":
                p = RevNet(number_blocks =  transformation_config["number_blocks"],
                           in_features   =  transformation_config["in_features"],
                           hidden_size   =  transformation_config["hidden_size"]
                          )
            else:
                raise ValueError("Unknown transformation function")
            p.to(DEVICE)
            p_inverse = p.inverse
            optimizer = optim.Adam(p.parameters(), lr=0.001)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=ReduceLROnPlateau_patience)
            criterion = nn.CrossEntropyLoss()

            # Bundle the transformation with its inverse and training objects.
            phi=phi_class(p,p_inverse,criterion,optimizer,scheduler)

            DAS_Experiment=Distributed_Alignment_Search_MLP(Model=model,
                                                            Model_Layer=Layer,
                                                            Train_Data_Raw=DAS_Train,
                                                            Test_Data_Raw=DAS_Test,
                                                            Eval_Data_Raw=DAS_Eval,
                                                            Hidden_Layer_Size=transformation_config["in_features"],
                                                            Variable_Dimensions=inter_dim,
                                                            Transformation_Class=phi,
                                                            Device=DEVICE)

            DAS_Experiment.train_test(batch_size=6400,
                                      epochs=Max_Epochs,
                                      mode=1,
                                      early_stopping_threshold=Early_Stopping_Epochs,
                                      early_stopping_improve_threshold=early_stopping_improve_threshold) #Train

            accuracy=DAS_Experiment.train_test(batch_size=6400,
                                               mode=2)#Test

            results[-1][LayerName][str(inter_dim)]=accuracy

            # Release the experiment before the next one is created.
            DAS_Experiment.Cleanup()
            DAS_Experiment=None

            # Checkpoint everything collected so far.
            with open('results.json', 'w') as f:
                json.dump(results, f)

  intervention_data[-1]["sources"]=torch.tensor([source0,source1], dtype=torch.float32)


start
Epoch 1/1 Training:


Loss: 0.4578, Batch Acc: 0.7781: 100%|████████| 200/200 [00:19<00:00, 10.24it/s]


Validation:


Validation Accuracy: 0.7681: 100%|████████████████| 2/2 [00:00<00:00, 11.67it/s]

Epoch 1, Avg Loss: 0.6175,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001





Loading best phi with loss: 0.9483
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.6506,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1320
Layer1 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.8515,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.2979
Layer1 : [[0], [1]]
Epoch 1, Avg Loss: 0.8512,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.3311
Layer2 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5362,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.0399
Layer2 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.6812,  Steps w/o Improvement: 0,  Eval Loss (End of Epoc

Loss: 0.3332, Batch Acc: 0.8553: 100%|████████| 200/200 [00:19<00:00, 10.38it/s]


Validation:


Validation Accuracy: 0.8588: 100%|████████████████| 2/2 [00:00<00:00, 14.64it/s]

Epoch 1, Avg Loss: 0.5797,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001





Loading best phi with loss: 0.6521
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.4898,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.8451
Layer1 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.7841,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.9803
Layer1 : [[0], [1]]
Epoch 1, Avg Loss: 0.9997,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1802
Layer2 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.4065,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.7626
Layer2 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.5908,  Steps w/o Improvement: 0,  Eval Loss (End of Epoc

Loss: 0.4249, Batch Acc: 0.7989: 100%|████████| 200/200 [00:19<00:00, 10.09it/s]


Validation:


Validation Accuracy: 0.8005: 100%|████████████████| 2/2 [00:00<00:00, 16.26it/s]

Epoch 1, Avg Loss: 0.6253,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001





Loading best phi with loss: 0.8509
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5625,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.0349
Layer1 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.7477,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1608
Layer1 : [[0], [1]]
Epoch 1, Avg Loss: 0.7835,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.2250
Layer2 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.4784,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.9359
Layer2 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.6236,  Steps w/o Improvement: 0,  Eval Loss (End of Epoc

Loss: 0.4694, Batch Acc: 0.7833: 100%|████████| 200/200 [00:26<00:00,  7.63it/s]


Validation:


Validation Accuracy: 0.7831: 100%|████████████████| 2/2 [00:00<00:00,  7.79it/s]

Epoch 1, Avg Loss: 0.6389,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001





Loading best phi with loss: 0.9341
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5906,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.0583
Layer1 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.6887,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1079
Layer1 : [[0], [1]]
Epoch 1, Avg Loss: 0.7905,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.2373
Layer2 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5334,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.0179
Layer2 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.6164,  Steps w/o Improvement: 0,  Eval Loss (End of Epoc

Loss: 0.2998, Batch Acc: 0.8755: 100%|████████| 200/200 [00:19<00:00, 10.31it/s]


Validation:


Validation Accuracy: 0.8725: 100%|████████████████| 2/2 [00:00<00:00, 14.96it/s]

Epoch 1, Avg Loss: 0.5495,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001





Loading best phi with loss: 0.6131
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5347,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.7860
Layer1 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.8260,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.0118
Layer1 : [[0], [1]]
Epoch 1, Avg Loss: 1.0216,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.4857
Layer2 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.4281,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.6907
Layer2 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.5666,  Steps w/o Improvement: 0,  Eval Loss (End of Epoc

Loss: 0.4880, Batch Acc: 0.7797: 100%|████████| 200/200 [00:20<00:00,  9.89it/s]


Validation:


Validation Accuracy: 0.7795: 100%|████████████████| 2/2 [00:00<00:00, 16.42it/s]

Epoch 1, Avg Loss: 0.6493,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001





Loading best phi with loss: 0.9614
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5820,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.0737
Layer1 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.6908,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1638
Layer1 : [[0], [1]]
Epoch 1, Avg Loss: 0.7374,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1684
Layer2 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5273,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.0390
Layer2 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.6216,  Steps w/o Improvement: 0,  Eval Loss (End of Epoc

Loss: 0.5456, Batch Acc: 0.7231: 100%|████████| 200/200 [00:28<00:00,  7.07it/s]


Validation:


Validation Accuracy: 0.7269: 100%|████████████████| 2/2 [00:00<00:00,  6.44it/s]

Epoch 1, Avg Loss: 0.6572,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001





Loading best phi with loss: 1.0798
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5861,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1244
Layer1 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.7040,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1681
Layer1 : [[0], [1]]
Epoch 1, Avg Loss: 0.7596,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.2690
Layer2 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5883,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.1536
Layer2 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.6678,  Steps w/o Improvement: 0,  Eval Loss (End of Epoc

Loss: 0.3579, Batch Acc: 0.8384: 100%|████████| 200/200 [00:17<00:00, 11.28it/s]


Validation:


Validation Accuracy: 0.8385: 100%|████████████████| 2/2 [00:00<00:00, 14.31it/s]

Epoch 1, Avg Loss: 0.5829,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001





Loading best phi with loss: 0.7193
Layer1 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.5579,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.8986
Layer1 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.7321,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.9812
Layer1 : [[0], [1]]
Epoch 1, Avg Loss: 0.9129,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 1.4622
Layer2 : [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]
Epoch 1, Avg Loss: 0.4422,  Steps w/o Improvement: 0,  Eval Loss (End of Epoch): inf , LR for next epoch (base): 0.001
Loading best phi with loss: 0.8426
Layer2 : [[0, 1], [2, 3]]
Epoch 1, Avg Loss: 0.6505,  Steps w/o Improvement: 0,  Eval Loss (End of Epoc