# Testing Influencer GNN

In [8]:
%load_ext autoreload
%autoreload 2

# System imports
import json
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch_geometric.data import Batch
from tqdm import tqdm

import warnings

# Suppress every warning for the remainder of the session.
warnings.filterwarnings("ignore")

# Add the parent directory to the import search path; the project-local
# import further down presumably relies on it (verify repo layout).
sys.path.append("../")

# Pick the GPU when torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

from lightning_modules.toyGNN.submodels.influencer_gravnet import InfluencerGravnet


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Download Images

In [2]:
import wandb
import plotly.graph_objects as go
import PIL
import io

In [3]:
# Create a W&B public API client; the following cells use it to look up the
# run and list the files attached to it.
api = wandb.Api()

In [4]:
# Read the montage settings from the YAML file next to this notebook.
# Later cells index the result with the "project" and "run" keys.
# NOTE(review): yaml.load with FullLoader is kept as-is; if the config only
# ever holds plain scalars/lists/mappings, yaml.safe_load would be the
# safer choice -- confirm before switching.
with open("montage_config.yaml") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [5]:
# Locate the W&B run whose logged files will be downloaded.
#
# The download directory and the W&B entity used to be hard-coded to one
# user's setup. They can now be overridden with the optional config keys
# "download_root" and "entity"; when a key is absent the previous value is
# used, so existing config files keep working unchanged.
root_dir = config.get(
    "download_root", "/global/cfs/cdirs/m3443/usr/dtmurnane/InfluencerNet"
)
entity = config.get("entity", "murnanedaniel")

# Run path has the form "<entity>/<project>/<run>".
run = api.run(f"{entity}/{config['project']}/{config['run']}")

# Listing of the files attached to the run; filtered by name further down.
files = run.files()

In [15]:
def _name_field_as_int(remote_file, field_index):
    """Return the integer at ``field_index`` of the "_"-split file name."""
    return int(remote_file.name.split("_")[field_index])


# Files whose name contains "original_space", ordered by the integer found
# in the third "_"-separated field of the name.
original_space_files = [
    remote_file for remote_file in files if "original_space" in remote_file.name
]
original_space_files.sort(key=lambda remote_file: _name_field_as_int(remote_file, 2))

# Files whose name contains "embeddings", ordered by the integer found in
# the second "_"-separated field of the name.
embedding_files = [
    remote_file for remote_file in files if "embeddings" in remote_file.name
]
embedding_files.sort(key=lambda remote_file: _name_field_as_int(remote_file, 1))

In [16]:
# Every file of this run is stored under "<root_dir>/<run>".
download_root = f"{root_dir}/{config['run']}"

# Fetch the original-space files, overwriting any local copy.
for remote_file in tqdm(original_space_files):
    remote_file.download(replace=True, root=download_root)

# Fetch the embedding files the same way.
for remote_file in tqdm(embedding_files):
    remote_file.download(replace=True, root=download_root)

100%|██████████| 42/42 [00:12<00:00,  3.37it/s]
100%|██████████| 42/42 [00:18<00:00,  2.25it/s]


In [17]:
def _load_figure_spec(path):
    """Parse one downloaded figure file into the object handed to ``go.Figure``.

    The text is parsed as JSON first and only falls back to the original
    ``yaml.load(..., Loader=yaml.FullLoader)`` when it is not valid JSON.
    The previous YAML-only parsing took several seconds per file in the
    recorded run.

    NOTE(review): this assumes the downloaded files are JSON-serialised
    plotly figures -- confirm against the run's file listing. For valid JSON
    the fallback is never reached.
    """
    with open(path) as figure_file:
        raw_text = figure_file.read()
    try:
        return json.loads(raw_text)
    except json.JSONDecodeError:
        # NOTE(review): FullLoader is kept for backward compatibility only;
        # the files come from our own W&B run, not from untrusted input.
        return yaml.load(raw_text, Loader=yaml.FullLoader)


def _render_frames(run_files):
    """Render each downloaded figure file to a PNG-backed PIL image.

    Args:
        run_files: iterable of objects with a ``name`` attribute giving the
            file's path relative to ``<root_dir>/<run>``.

    Returns:
        A list of PIL images, one per entry, in iteration order.
    """
    run_dir = f"{root_dir}/{config['run']}"
    frames = []
    for run_file in tqdm(run_files):
        figure = go.Figure(data=_load_figure_spec(f"{run_dir}/{run_file.name}"))
        png_bytes = figure.to_image(format="png")
        # Keep the PNG in memory; PIL reads it from the BytesIO buffer.
        frames.append(PIL.Image.open(io.BytesIO(png_bytes)))
    return frames


# One image per logged figure, in the (sorted) order of the file lists.
original_space_frames = _render_frames(original_space_files)
embedding_frames = _render_frames(embedding_files)

100%|██████████| 42/42 [00:03<00:00, 12.61it/s]
100%|██████████| 42/42 [04:58<00:00,  7.11s/it]


In [20]:
def _write_gif(frames, gif_path):
    """Save ``frames`` as one animated GIF at ``gif_path``.

    The first image is the one being saved; the remaining images are
    passed as ``append_images`` so they become the following frames.
    """
    first_frame = frames[0]
    remaining_frames = frames[1:]
    first_frame.save(
        gif_path,
        save_all=True,
        append_images=remaining_frames,
        optimize=True,
        duration=400,
    )


# Write one animation per frame list, original space first.
_write_gif(original_space_frames, "original_space_test.gif")
_write_gif(embedding_frames, "embedding_test.gif")