# Investiagtion

This notebooks aims to investigate various WARM solving algorithms

In [2]:
import numpy as np
from igraph import Graph
from warm import WarmModel

In [None]:
graph = Graph(n=3, edges=[[0, 1], [0, 2]])
model = WarmGraph(graph)
solution = model.analyze(np.array([0.5, 0.5]))
solution

In [4]:
solution.y

array([[ 0.5       ,  0.41659474, -0.09303513, -0.13201969],
       [ 0.5       ,  0.41659474, -0.09303513, -0.13201969]])

In [3]:
import json
import logging
from typing import TypedDict, Generic, TypeVar, Optional, NamedTuple

import igraph as ig
import pandas as pd

import warm
import sqlite3

from dash import Dash, dcc, html, Input, Output, callback
import dash_cytoscape as cyto
import plotly.express as px

from ContextManagers import DatabaseManager
from config import config

logger = logging.getLogger(__name__)

# Canvas dimension.
width = 400
height = 400


T = TypeVar("T")
SimMap= dict[tuple[str, int, int], warm.WarmSimData]
GraphFamilies = ["ring_2d"]


class PositionDict(TypedDict):
    x: int
    y: int


class NodeDict(TypedDict):
    id: Optional[str]
    name: Optional[str]
    label: Optional[str]
    name: Optional[str]
    value: Optional[str]


class EdgeDict(TypedDict):
    label: Optional[str]
    source: str
    target: str


class CytoElement(TypedDict):
    data: dict
    position: Optional[PositionDict]


class InfoTuple(NamedTuple):
    model: warm.WarmModel
    solution: warm.WarmSolution


def escape_string(string: str, escape_symb = "\\", escape_chars: list[str] = ("%", "_")):
    """ Escape the original string. """
    out_buf = []

    for s in string:
        if s in escape_chars:
            out_buf.append(escape_symb)
        out_buf.append(s)

    return "".join(out_buf)


def parse_model(model: warm.WarmModel) -> list[CytoElement]:
    nodes: list[NodeDict]
    edges: list[EdgeDict]
    nodes, edges = model.graph.to_dict_list(use_vids=False)

    # Change 'name' into 'id' for the nodes. I know.
    for node in nodes:
        node["id"] = node["name"]
        node["label"] = node["name"]
        del node["name"]  # For efficiency

    # Add edge labels.
    for i, edge in enumerate(edges):
        edge["label"] = str(i)

    # Compute layouts. Default layout should be between 0 and 1 so scales can
    # be applied consistently for a canvas of arbitrary size.
    layout = warm.ring_2d_layout(model.bins_count // 2)
    positions = [PositionDict(x=coord[0] * width, y=coord[1] * height) for coord in layout]

    # Wrap into Cyto and output the result.
    elements: list[CytoElement] = [CytoElement(
        data=node,
        position=positions[int(node["id"])]
    ) for node in nodes] + [CytoElement(
        data=edge
    ) for edge in edges]
    return elements


def fetch_data(graph_family_selected: str) -> tuple[pd.DataFrame, dict[str, InfoTuple]]:
    if graph_family_selected in GraphFamilies:
        # Fetch data.
        with DatabaseManager("../data.sqlite") as db:
            logger.info("Data fetching commenced")
            id_pattern = escape_string(graph_family_selected) + "%"
            df = pd.read_sql(f"""
                SELECT simId, trialId, endTime, root, counts
                FROM SimData
                WHERE simId LIKE ? ESCAPE '\\'
            """, con=db, params=[id_pattern])

            cursor: sqlite3.Cursor
            cursor = db.cursor()
            cursor.execute(f"""
                SELECT simId, model, solution
                FROM SimInfo
                WHERE simId LIKE ? ESCAPE '\\'
            """, [id_pattern])
            sims = {
                x[0]: InfoTuple(
                    model=warm.WarmModel.from_json(x[1]),
                    solution=warm.WarmSolution.from_json(x[2])
                ) for x in cursor.fetchall()
            }
            logger.info("Data fetched and transformed successfully")
            return df, sims
    else:
        logger.warning("Don't try to hack me! I am vulnerable.")


def update_slider_props(df: pd.DataFrame):
    model_indices = df["endTime"].unique()
    marks = {x: str(x) for x in range(min(model_indices), max(model_indices) + 1)}
    return min(model_indices), max(model_indices), marks


# Construct the default values.
df, sims = fetch_data("ring_2d")


ImportError: attempted relative import with no known parent package

In [33]:

filtered_df = df[df["simId"] == "ring_2d_5"]
filtered_df["counts"] = filtered_df["counts"].str.strip("[]")
ncol = filtered_df["counts"].str.split(",").transform(len).max()
x_names = [f"x_{i}" for i in range(ncol)]
filtered_df[x_names] = filtered_df["counts"].str.split(',', expand=True).astype("int")
filtered_df[x_names] = filtered_df[x_names].div(filtered_df["endTime"] + ncol, axis=0)
mean_df = filtered_df.groupby("endTime")[x_names].mean().reset_index()

fig = px.scatter(mean_df, x="endTime", y="x_0")
fig




A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/

In [4]:
import plotly.express as px

filtered_df = df[df["simId"] == "ring_2d_5"]
filtered_df["counts"] = filtered_df["counts"].str.strip("[]")
ncol = filtered_df["counts"].str.split(",").transform(len).max()
x_names = [f"x_{i}" for i in range(ncol)]
filtered_df[x_names] = filtered_df["counts"].str.split(',', expand=True).astype("int")
filtered_df[x_names] = filtered_df[x_names].div(filtered_df["endTime"] + ncol, axis=0)
var_df = filtered_df.groupby("endTime")[x_names].var().reset_index()

fig = px.scatter(var_df, x="endTime", y="x_0")
fig


NameError: name 'df' is not defined