In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger
from glob import glob
import time

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets_torch_geometric.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader

In [None]:
# Weights & Biases coordinates of the run(s) to inspect. `project` is used both
# for the W&B API path and as the local folder that holds the checkpoints.
entity = "haraghi"
project = "sweep EST (FAN1VS3) (multi val test num 20)"
# Alternative projects, kept for quick switching:
# project = "sweep EST (FAN1VS3) 25000 (multi val test num 20)"
# project = "sweep_EST_NCALTECH101_1024_multi20"
# One W&B run id per model; configs and checkpoints are collected in this order.
run_ids = ['i8ng54ek']

In [None]:
# Locate the single local checkpoint of every run under
# <project>/<run_id>/checkpoints/. `artifact_dirs` follows the order of `run_ids`.
if not run_ids:
    raise ValueError("run_ids must contain at least one W&B run id")

artifact_dirs = []
for run_id in run_ids:
    glob_results = glob(osp.join(project, run_id, "checkpoints", "*"))
    # Exactly one entry is expected; zero or several would make the choice ambiguous.
    assert len(glob_results) == 1, (
        f"expected exactly one checkpoint for run {run_id!r}, "
        f"found {len(glob_results)}: {glob_results}"
    )
    artifact_dirs.append(glob_results[0])

In [None]:
# Handle on the W&B public API; used to fetch the config recorded for each run.
api = wandb.Api()

In [None]:
# Build one fully-resolved config per run:
#   1. load the defaults from config_bare.yaml;
#   2. fetch the config W&B recorded for each run;
#   3. if the recorded config names a `cfg_path`, layer that file over the
#      defaults, otherwise fall back to the recorded config itself;
#   4. merge the recorded config on top, so values logged by the run win.
cfg_bare = OmegaConf.load("config_bare.yaml")
# W&B run paths are always "/"-separated ("entity/project/run_id"). Building
# them with os.path.join would produce backslashes on Windows.
configs = [api.run(f"{entity}/{project}/{run_id}").config for run_id in run_ids]
cfgs = [OmegaConf.create(config) for config in configs]
cfg_files = []
for cfg in cfgs:
    if "cfg_path" in cfg.keys():
        print(cfg.cfg_path)
        cfg_files.append(OmegaConf.merge(cfg_bare, OmegaConf.load(cfg.cfg_path)))
    else:
        cfg_files.append(cfg)

# Later arguments to OmegaConf.merge take precedence: the run's own values
# override whatever came from the config file.
cfgs = [OmegaConf.merge(cfg_file, cfg) for cfg_file, cfg in zip(cfg_files, cfgs)]
# Uncomment to inspect the resolved config of the first run:
# print(OmegaConf.to_yaml(cfgs[0]))

In [None]:
def recursive_dict_compare(all_cfg, other_cfg):
    """
    Return the entries of ``other_cfg`` that are absent from, or differ in, ``all_cfg``.

    The comparison is one-directional: keys that exist only in ``all_cfg`` are
    never reported. When both sides hold a dict under the same key, the
    comparison recurses and the nested result is kept only if it is non-empty.

    Special case: a ``num_classes`` entry that is ``None`` in ``other_cfg`` but
    set in ``all_cfg`` is not reported as a difference.

    Args:
        all_cfg: reference dictionary.
        other_cfg: dictionary whose entries are checked against ``all_cfg``.

    Returns:
        A (possibly nested) dict holding the values taken from ``other_cfg``.
    """
    differences = {}

    for name in other_cfg:
        candidate = other_cfg[name]

        # Key unknown to the reference: report it as is.
        if name not in all_cfg:
            differences[name] = candidate
            continue

        reference = all_cfg[name]

        # Two nested dicts: descend, keep only non-empty sub-diffs.
        if isinstance(reference, dict) and isinstance(candidate, dict):
            sub_differences = recursive_dict_compare(reference, candidate)
            if sub_differences:
                differences[name] = sub_differences
            continue

        # Plain values: report a mismatch unless it is the tolerated
        # "num_classes unset on the other side" case.
        if reference != candidate:
            tolerated = name == "num_classes" and candidate is None and reference is not None
            if not tolerated:
                differences[name] = candidate

    return differences


In [None]:
# For every run, show which entries of the file-based config differ from
# (or are missing in) the final merged config.
print([
    recursive_dict_compare(
        OmegaConf.to_object(merged_cfg),
        OmegaConf.to_object(file_cfg),
    )
    for merged_cfg, file_cfg in zip(cfgs, cfg_files)
])

In [None]:
# Seed all RNGs. This alone does not make training fully deterministic.
# NOTE(review): with several runs, each call presumably overrides the previous
# seed, so only the last run's seed would remain in effect. Confirm this is intended.
for cfg in cfgs:
    pl.seed_everything(cfg.seed, workers=True)

# Sanity check: every other run should use the same dataset settings as the
# first one. A diff that is empty, or that only touches `num_workers`, is
# accepted silently; anything else is printed (not raised) for inspection.
for cfg in cfgs[1:]:
    compare_dict = recursive_dict_compare(
        OmegaConf.to_object(cfgs[0].dataset),
        OmegaConf.to_object(cfg.dataset),
    )
    if compare_dict and set(compare_dict) != {'num_workers'}:
        print(compare_dict, cfg.dataset, cfgs[0].dataset, sep="\n")

# Build the data module from the first run's config and propagate the number
# of classes it reports to every run's config.
gdm = GraphDataModule(cfgs[0])
for cfg in cfgs:
    cfg.dataset.num_classes = gdm.num_classes

In [None]:
# Override `num_workers` in the first run's dataset config for this session
# (presumably the DataLoader worker count; confirm against GraphDataModule).
cfgs[0].dataset.num_workers = 2

In [None]:
# Rebuild the data module from the first run's config, replacing the earlier
# `gdm` (presumably so the changed `num_workers` is picked up).
gdm = GraphDataModule(cfgs[0])

In [None]:
# Output folder for everything the landscape plots need.
# NOTE(review): the folder name mentions NCALTECH101 / vzjrvjlz while `run_ids`
# points at i8ng54ek. Confirm this is the intended destination.
folder_address = "landscape_plots/EST_NCALTECH101_1024_vzjrvjlz"
# torch.save does not create missing directories, so make sure the folder exists.
os.makedirs(folder_address, exist_ok=True)

# Persist the train loader and the first test loader as pickled objects.
trainloader = gdm.train_dataloader()
torch.save(trainloader, osp.join(folder_address, "trainloader.pt"))
# test_dataloader() is indexed, i.e. it yields several loaders; only the first is kept.
testloader = gdm.test_dataloader()[0]
torch.save(testloader, osp.join(folder_address, "testloader.pt"))

In [None]:
# Store the resolved config of the first run in the output folder as cfg.yaml.
OmegaConf.save(cfgs[0], osp.join(folder_address, "cfg.yaml"))

In [None]:
# Instantiate one network per run from its config, then wrap each in a Runner
# restored from that run's checkpoint.
models = [model_factory.factory(run_cfg) for run_cfg in cfgs]
runners = [
    Runner.load_from_checkpoint(ckpt_path, cfg=run_cfg, model=net)
    for ckpt_path, run_cfg, net in zip(artifact_dirs, cfgs, models)
]
# Only the first run's network weights are written to the output folder.
torch.save(
    runners[0].model.state_dict(),
    osp.join(folder_address, "state_dict.pt"),
)

In [None]:
# Reload a previously saved train loader.
# NOTE(review): this path differs from `folder_address` used when saving.
# Confirm it points at the intended experiment.
# CAUTION: torch.load unpickles arbitrary objects; only load files you trust.
trainer_dl = torch.load("landscape_plots/EST_FAN1VS3_1024_i8ng54ek_092/trainloader.pt")

In [None]:
# Build new networks and Runners straight from the configs, without loading
# any checkpoint.
# NOTE(review): this rebinds `models` and `runners`, so any checkpoint-restored
# runners created earlier are replaced. Confirm that is intended before testing.
models = [model_factory.factory(run_cfg) for run_cfg in cfgs]
runners = [
    Runner(cfg=run_cfg, model=net)
    for run_cfg, net in zip(cfgs, models)
]

In [None]:
# Trainer used to evaluate the runners in this notebook.
trainer = pl.Trainer(
    enable_progress_bar=True,
    # No explicit strategy is set; uncomment to use DDP inside a notebook.
    # strategy="ddp_notebook",
    # torch.cuda.device_count() is 0 on CPU-only machines and Lightning rejects
    # devices=0, so fall back to a single device in that case.
    devices=torch.cuda.device_count() or 1,
    accelerator="auto",
)

In [None]:
# Run the test loop for every runner on the data module, printing the runner's
# model config first so the reported results can be told apart.
for runner in runners:
    print(runner.cfg.model)
    trainer.test(runner, datamodule=gdm)