In [None]:
import os
import pickle
from pathlib import Path

from gnn_tracking.models.graph_construction import GraphConstructionFCNN
from tcn_trainer_mdmm import TCNTrainer

from gnn_tracking.metrics.losses import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.utils.loading import get_loaders, TrackingDataset

In [None]:
# Directory with the pre-built graphs. Set the GNN_GRAPH_DIR environment variable to
# point elsewhere instead of editing this machine-specific default.
# Raw string: "\D", "\P", "\d", "\g" are invalid escape sequences in a normal string
# (SyntaxWarning on Python 3.12+); the raw form yields exactly the same path.
graph_dir = Path(os.environ.get("GNN_GRAPH_DIR", r"D:\Devdoot\Princeton RSE\dataset\graph constructed new"))
assert graph_dir.is_dir(), f"Graph directory not found: {graph_dir}"

In [None]:
# Train/validation split by graph index. start/stop presumably select a half-open
# index range of the graph files (train: [0, 810), val: [810, 900)) -- confirm in
# TrackingDataset.
datasets = dict(
    train=TrackingDataset(graph_dir, stop=810),
    val=TrackingDataset(graph_dir, start=810, stop=900),
)
loaders = get_loaders(datasets, batch_size=1)

In [None]:
def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """Transpose a sequence of dicts into one dict mapping each key to a list of its values.

    Keys appear in first-seen order and values keep the order of the input dicts.
    A key absent from some dicts simply collects fewer values (lists are not padded).
    """
    merged = {}
    for record in list_of_dicts:
        for name, item in record.items():
            if name not in merged:
                merged[name] = []
            merged[name].append(item)
    return merged

In [None]:
def train_model_mdmm(epsilon,
                     damping,
                     main_weight,
                     constraint_weight,
                     loaders,
                     main_loss="attractive",
                     constraint_loss="repulsive",
                     num_epochs=50,
                     lr=0.005,
                     ):
    """Train a freshly built GraphConstructionFCNN with TCNTrainer (MDMM variant).

    A ``GraphConstructionHingeEmbeddingLoss`` is registered twice. It serves as the main
    objective weighted on the ``main_loss`` component. It also serves as a constraint on
    the ``constraint_loss`` component.

    Args:
        epsilon: Second element of the constraint tuple for ``constraint_loss``;
            presumably the MDMM constraint bound -- confirm in TCNTrainer.
        damping: Third element of the constraint tuple; presumably the MDMM damping factor.
        main_weight: Weight given to the ``main_loss`` component of the main objective.
        constraint_weight: First element of the constraint tuple for ``constraint_loss``.
        loaders: Data loaders passed unchanged to ``TCNTrainer`` (e.g. from ``get_loaders``).
        main_loss: Name of the hinge-loss component used as the main objective.
        constraint_loss: Name of the hinge-loss component used as the constraint.
        num_epochs: Number of epochs passed to ``trainer.train``.
        lr: Learning rate passed to ``TCNTrainer``. The default keeps the previously
            hard-coded value of 0.005.

    Returns:
        The result of ``trainer.train`` transposed by ``list_of_dicts_to_dict_of_lists``.
        This assumes ``trainer.train`` returns a list of per-epoch metric dicts -- TODO
        confirm in tcn_trainer_mdmm.
    """
    main_loss_functions = {
        "embedding_loss": (GraphConstructionHingeEmbeddingLoss(), {main_loss: main_weight}),
    }
    # The (weight, epsilon, damping) order must match what tcn_trainer_mdmm.TCNTrainer
    # unpacks -- verify there before reordering.
    constraint_loss_functions = {
        "embedding_loss": (GraphConstructionHingeEmbeddingLoss(), {constraint_loss: (constraint_weight, epsilon, damping)}),
    }

    # Architecture is fixed across the sweep. in_dim presumably has to equal the node
    # feature dimension of the graphs in graph_dir -- confirm.
    # NOTE(review): no random seed is set, so runs also differ in weight initialisation.
    model = GraphConstructionFCNN(
        in_dim=14,
        hidden_dim=64,
        out_dim=10,
        depth=4,
        beta=0.4,
    )

    trainer = TCNTrainer(
        model=model,
        loaders=loaders,
        main_loss_functions=main_loss_functions,
        constraint_loss_functions=constraint_loss_functions,
        lr=lr,
    )

    loss_history = trainer.train(epochs=num_epochs)
    return list_of_dicts_to_dict_of_lists(loss_history)

In [None]:
def save_loss_history(constraints, path, main_loss="attractive", constraint_loss="repulsive",
                      num_epochs=10, loaders=None):
    """Train one model per constraint tuple and pickle each run's loss history into ``path``.

    Each run is written to ``{epsilon}_{damping}_{main_weight}.pkl``. That filename is
    kept for compatibility with existing result files. The file holds the loss history
    together with every hyperparameter of the run.

    Args:
        constraints: Iterable of ``(epsilon, damping, main_weight, constraint_weight)`` tuples.
        path: Output directory; created if it does not exist.
        main_loss: Loss component optimised as the main objective.
        constraint_loss: Loss component imposed as the constraint.
        num_epochs: Epochs per training run (defaults to the previously hard-coded 10).
        loaders: Data loaders to train on. Defaults to the notebook-global ``loaders``.

    Raises:
        ValueError: Two constraints would write the same output file (e.g. they differ
            only in ``constraint_weight``). This is checked before any training starts,
            so no run is wasted and no result is silently overwritten.
    """
    constraints = list(constraints)  # materialise so a generator can be iterated twice
    file_paths = [
        os.path.join(path, f'{constraint[0]}_{constraint[1]}_{constraint[2]}.pkl')
        for constraint in constraints
    ]

    # The filename omits constraint_weight, so distinct runs could collide on disk.
    seen_paths = set()
    for file_path in file_paths:
        if file_path in seen_paths:
            raise ValueError(
                f"Multiple constraints map to {file_path!r}; their results would overwrite each other"
            )
        seen_paths.add(file_path)

    if loaders is None:
        loaders = globals()["loaders"]  # loaders built in the dataset cell of this notebook
    os.makedirs(path, exist_ok=True)

    for constraint, file_path in zip(constraints, file_paths):
        epsilon, damping, main_weight, constraint_weight = constraint
        print(f'Training for scaling coefficients = {constraint}')
        loss_history = train_model_mdmm(epsilon=epsilon,
                                        damping=damping,
                                        main_weight=main_weight,
                                        constraint_weight=constraint_weight,
                                        loaders=loaders,
                                        main_loss=main_loss,
                                        constraint_loss=constraint_loss,
                                        num_epochs=num_epochs)
        model_dict = {
            'loss_history': loss_history,
            'epsilon': epsilon,
            'damping': damping,
            'weight': main_weight,  # original key name, kept for existing readers
            'constraint_weight': constraint_weight,
            'main_loss': main_loss,
            'constraint_loss': constraint_loss,
        }

        # The with-block closes the file; no explicit close() needed.
        with open(file_path, "wb") as f:
            pickle.dump(model_dict, f)
        print("\n")

In [None]:
# Damping 10.0 sweep over five epsilon values, with unit weights.
# Tuple layout: (epsilon, damping, main_weight, constraint_weight).
constraints_1 = [
    (eps, 10.0, 1.0, 1.0)
    for eps in (
        0.0009653154573041118,
        0.0004771597001106582,
        0.00028893887271952276,
        0.00016839923191582784,
        0.00010062895363475553,
    )
]

In [None]:
# Damping 5.0 sweep over five epsilon values, with unit weights.
# Tuple layout: (epsilon, damping, main_weight, constraint_weight).
constraints_2 = [
    (eps, 5.0, 1.0, 1.0)
    for eps in (
        0.0009653154573041118,
        0.0004771597001106582,
        0.00028893887271952276,
        0.00016839923191582784,
        0.00010062895363475553,
    )
]

In [None]:
# Damping 1.0 sweep over five epsilon values, with unit weights.
# Tuple layout: (epsilon, damping, main_weight, constraint_weight).
constraints_3 = [
    (eps, 1.0, 1.0, 1.0)
    for eps in (
        0.0009653154573041118,
        0.0004771597001106582,
        0.00028893887271952276,
        0.00016839923191582784,
        0.00010062895363475553,
    )
]

In [None]:
# Run the damping-10 sweep; one pickle per epsilon in loss_histories/damping_10.
save_loss_history(constraints=constraints_1, path="loss_histories/damping_10")

In [None]:
# Run the damping-5 sweep; one pickle per epsilon in loss_histories/damping_05.
save_loss_history(constraints=constraints_2, path="loss_histories/damping_05")

In [None]:
# Run the damping-1 sweep; one pickle per epsilon in loss_histories/damping_01.
save_loss_history(constraints=constraints_3, path="loss_histories/damping_01")

In [None]:
# Damping 1.0 sweep with a different epsilon grid. It is used below for the run with
# swapped main/constraint losses.
# Tuple layout: (epsilon, damping, main_weight, constraint_weight).
constraints_4 = [
    (eps, 1.0, 1.0, 1.0)
    for eps in (
        0.0009300140692725962,
        0.0008121351376578304,
        0.0006466396967880428,
        0.0005346179523010864,
        0.000420255135826732,
    )
]

In [None]:
# Swapped roles: "repulsive" is the main loss and "attractive" is the constraint.
save_loss_history(
    constraints=constraints_4,
    path="loss_histories/damping_01_attractive",
    main_loss="repulsive",
    constraint_loss="attractive",
)