In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger
from glob import glob
import time
import yaml

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets_torch_geometric.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader
from utils.config_utils import recursive_dict_compare
from datatransforms.event_transforms import FilterNodes
import torch_geometric.transforms as T

In [None]:
entity = "haraghi"
project = "sweep EST (FAN1VS3) (multi val test num 20)"
# Other sweeps this notebook has been pointed at:
# project = "sweep EST (FAN1VS3) 25000 (multi val test num 20)"
# project = "sweep_EST_NCALTECH101_1024_multi20"
run_id = 'm529xtfz'

from glob import escape as glob_escape


def _find_checkpoint(search_roots):
    """Return the single file found under ``<root>/checkpoints/``.

    Roots are tried in order; the first root whose ``checkpoints`` folder is
    non-empty wins.

    Args:
        search_roots: directories to search, most specific first.

    Returns:
        Path of the only entry in the first non-empty ``checkpoints`` folder.

    Raises:
        ValueError: if the first non-empty ``checkpoints`` folder holds more
            than one entry (ambiguous), or if no root has any checkpoint.
    """
    for root in search_roots:
        # Escape the root: project names are free text and may contain glob
        # metacharacters such as "[" that would otherwise be treated as patterns.
        matches = glob(osp.join(glob_escape(root), "checkpoints", "*"))
        if not matches:
            continue
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one checkpoint under {root!r}, found {len(matches)}: {matches}"
            )
        return matches[0]
    raise ValueError("no checkpoint file found")


# Look under <project>/<run_id>/ first, then fall back to <run_id>/.
checkpoint_file = _find_checkpoint([osp.join(project, run_id), run_id])
print("loading checkpoint from", checkpoint_file)

In [None]:
# W&B public-API client; used to fetch the logged config of the run being analysed.
api = wandb.Api()

In [None]:
cfg_bare = OmegaConf.load("config_bare.yaml")

# W&B run paths are "entity/project/run_id" with forward slashes on every OS;
# osp.join would insert "\\" on Windows and produce an invalid run path.
run_path = "/".join((entity, project, run_id))
config = api.run(run_path).config
cfg = OmegaConf.create(config)

# If the run logged the YAML file it was launched from, rebuild that file's
# config on top of the bare defaults; otherwise use the logged config itself.
if "cfg_path" in cfg.keys():
    print(cfg.cfg_path)
    cfg_file = OmegaConf.merge(cfg_bare, OmegaConf.load(cfg.cfg_path))
else:
    cfg_file = cfg
# OmegaConf.merge gives precedence to later arguments, so the values logged to
# W&B override those from the file.
cfg = OmegaConf.merge(cfg_file, cfg)


In [None]:
def _print_cfg_diff(title, lhs, rhs):
    """Print ``title`` between separator rules, then the YAML dump of
    ``recursive_dict_compare(lhs, rhs)`` on the plain-container forms of both configs.

    NOTE(review): presumably the comparison reports entries of ``lhs`` that
    differ from ``rhs`` — confirm against utils.config_utils.
    """
    print("=" * 50)
    print(title)
    print("-" * 50)
    plain_lhs = OmegaConf.to_object(lhs)
    plain_rhs = OmegaConf.to_object(rhs)
    diff = recursive_dict_compare(plain_lhs, plain_rhs)
    print(yaml.dump(diff, default_flow_style=False))


_print_cfg_diff("cfg_file", cfg, cfg_file)
_print_cfg_diff("cfg", cfg_file, cfg)

In [None]:
# Use 2 data-loading workers for this interactive session; set before the datamodule is built.
cfg.dataset["num_workers"] = 2
gdm = GraphDataModule(cfg)
# Copy the class count reported by the datamodule back into the config.
cfg.dataset["num_classes"] = gdm.num_classes

In [None]:
folder_address = osp.join("landscape_plots", project, run_id)
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(folder_address, exist_ok=True)

# Pickle the loaders so loss-landscape/plot_surface.py can reload them
# (they are passed via --trainloader / --testloader further down).
trainloader = gdm.train_dataloader()
torch.save(trainloader, osp.join(folder_address, "trainloader.pt"))
# NOTE(review): test_dataloader() is indexed with [0] — presumably it returns
# one loader per test set; only the first one is exported.
testloader = gdm.test_dataloader()[0]
torch.save(testloader, osp.join(folder_address, "testloader.pt"))

In [None]:
# Persist the fully merged config next to the exported loaders and weights.
OmegaConf.save(config=cfg, f=osp.join(folder_address, "cfg.yaml"))

In [None]:
# Build the architecture from cfg, load the Lightning Runner from the checkpoint
# around it, and export the bare model weights (consumed as --model_file by
# loss-landscape/plot_surface.py).
state_dict_path = osp.join(folder_address, "state_dict.pt")
model = model_factory.factory(cfg)
runner = Runner.load_from_checkpoint(
    checkpoint_path=checkpoint_file,
    cfg=cfg,
    model=model,
)
model_state = runner.model.state_dict()
torch.save(model_state, state_dict_path)

In [None]:
# torch.cuda.device_count() is 0 on CPU-only machines, and Lightning rejects
# devices=0 for the CPU accelerator; fall back to a single device in that case.
trainer = pl.Trainer(
    enable_progress_bar=True,
    # Optional: DDP inside a notebook (currently disabled).
    # strategy="ddp_notebook",
    devices=torch.cuda.device_count() or 1,
    accelerator="auto",
)

In [None]:
# Show which model configuration is evaluated, then run Lightning's test loop.
model_cfg = runner.cfg.model
print(model_cfg)
results = trainer.test(model=runner, datamodule=gdm)

In [None]:
from glob import escape as glob_escape

# Name of the surface .h5 file written inside folder_address. None omits
# --surf_file so plot_surface.py chooses its own name (assumes the script treats
# --surf_file as optional, as upstream loss-landscape does — confirm).
# Previously used: "multiple_train_loss_*.h5"
surf_file = None

folder_address = osp.join("landscape_plots", project, run_id)
desired_args = [
    '--log', '--cuda', '--mpi', '--dataset', 'from_file', '--model', 'EST',
    '--x=-1:1:101',
    '--dir_type', 'states', '--xignore', 'biasbn', '--xnorm', 'filter',
    '--model_folder', folder_address,
    '--model_file', osp.join(folder_address, "state_dict.pt"),
    '--testloader', osp.join(folder_address, "testloader.pt"),
    '--trainloader', osp.join(folder_address, "trainloader.pt"),
]
if surf_file:
    desired_args.extend(['--surf_file', osp.join(folder_address, surf_file)])

# Reuse an existing state_dict.pt*.h5 file (presumably a direction file from an
# earlier plot_surface.py run) when there is exactly one.
dir_file_list = glob(osp.join(glob_escape(folder_address), "state_dict.pt*.h5"))
if len(dir_file_list) > 1:
    raise ValueError(f"expected at most one direction file, found: {dir_file_list}")
if dir_file_list:
    # extend, not append: every argv element must be a string, not a nested list.
    desired_args.extend(['--dir_file', dir_file_list[0]])

In [None]:
# Run loss-landscape/plot_surface.py with the arguments assembled in the previous cell.
import subprocess
import sys

# sys.executable is the kernel's own interpreter, so the script sees the same
# installed packages; a bare "python" on PATH may be a different environment.
command = [sys.executable, 'loss-landscape/plot_surface.py', *desired_args]

# check=True raises CalledProcessError on a non-zero exit instead of letting
# the notebook continue as if the surface had been computed.
subprocess.run(command, check=True)


In [None]:
    
class FilterNodesFixedEvents(FilterNodes):
    """``FilterNodes`` variant that reuses one random index tensor for every sample.

    The first ``get_indices`` call takes the first ``num_indices`` entries of a
    random permutation of that sample's node indices (fewer if the sample has
    fewer nodes) and caches them; every later call returns the cached tensor,
    whatever sample is passed in.

    NOTE(review): presumably ``FilterNodes`` calls ``get_indices`` and keeps only
    those nodes — confirm. If so, a later sample with fewer nodes than the largest
    cached index would index out of range. The permutation is not seeded.
    """

    def __init__(self, num_indices):
        super().__init__()
        self.num_indices = num_indices  # how many node indices to keep
        self.indices = None  # cached index tensor, drawn lazily on first use

    def get_indices(self, data):
        cached = self.indices
        if cached is not None:
            return cached
        permutation = torch.randperm(data.num_nodes)
        self.indices = permutation[: self.num_indices]
        return self.indices

In [None]:
dataset_name = "FAN1VS3"
dataset_path = osp.join('datasets_torch_geometric', dataset_name, 'data')
num_events_per_sample = cfg.transform.train.num_events_per_sample


def _make_test_dataset(transform):
    """Build the 'test' split of ``dataset_name`` with ``transform`` and 2 workers."""
    return create_dataset(
        dataset_path=dataset_path,
        dataset_name=dataset_name,
        dataset_type='test',
        transform=transform,
        num_workers=2,
    )


# `dataset` filters nodes with the cached-index FilterNodesFixedEvents;
# `dataset_random` subsamples with T.FixedPoints instead.
dataset = _make_test_dataset(T.Compose([FilterNodesFixedEvents(num_events_per_sample)]))
dataset_random = _make_test_dataset(
    T.Compose([T.FixedPoints(num_events_per_sample, replace=False, allow_duplicates=True)])
)

In [None]:
# Single-sample loader over element 0 of `dataset`, which is built from the
# 'test' split despite the "train" in this variable's name.
train_dataloader_0 = DataLoader(
    dataset=[dataset[0]],
    batch_size=1,
    num_workers=1,
    shuffle=False,
)

In [None]:
# Pickle the single-sample loader to disk so the round trip can be checked.
torch.save(obj=train_dataloader_0, f=osp.join(folder_address, "trainloader_0.pt"))

In [None]:
# The file holds a pickled DataLoader object, not just tensors, so it must be
# loaded with weights_only=False (torch.load defaults to True since PyTorch 2.6).
# Only do this for files this notebook wrote itself: unpickling can run arbitrary code.
loader_loaded_0 = torch.load(osp.join(folder_address, "trainloader_0.pt"), weights_only=False)

In [None]:
# Print each batch's `pos` cast to int, for comparison with the reloaded loader.
for data in train_dataloader_0:
    node_positions = data.pos.int()
    print(node_positions)

In [None]:
# Same printout for the loader reloaded from disk; expected to match the in-memory one.
for data in loader_loaded_0:
    node_positions = data.pos.int()
    print(node_positions)