In [1]:
import sys
sys.path.append("../")
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import torch
import torch.nn as nn
import numpy as np

from HyperSINDy import Net
from baseline import Trainer
from library_utils import Library
from Datasets import SyntheticDataset
from other import init_weights, set_random_seed

from exp_utils import get_equations, log_equations

import seaborn as sns
import matplotlib.pyplot as plt
from tabulate import tabulate

# Apply seaborn's default theme to matplotlib figures created afterwards.
sns.set()

In [2]:
def load_model(device, z_dim, poly_order, include_constant,
               noise_dim, hidden_dim, stat_size, batch_size,
               num_hidden, batch_norm, cp_path):
    """Build a HyperSINDy ``Net`` on a GPU and load its weights from a checkpoint.

    Args:
        device: CUDA device index made current via ``torch.cuda.set_device``.
        z_dim, poly_order, include_constant: forwarded to ``Library`` as
            ``n``, ``poly_order`` and ``include_constant``.
        noise_dim, hidden_dim, num_hidden, batch_norm: forwarded to ``Net``.
        stat_size: forwarded to ``Net`` as ``statistic_batch_size``.
        batch_size: unused; kept so existing positional calls keep working.
        cp_path: path to a ``torch.save`` checkpoint whose ``'model'`` entry
            is the network's state dict.

    Returns:
        ``(net, library, device)`` where ``device`` is the integer index
        reported by ``torch.cuda.current_device()``.
    """
    torch.cuda.set_device(device=device)
    device = torch.cuda.current_device()

    library = Library(n=z_dim, poly_order=poly_order, include_constant=include_constant)

    # The network is moved to the GPU once here; load_state_dict below copies
    # the checkpoint values into these existing tensors, so no second .to() is needed.
    net = Net(library, noise_dim=noise_dim, hidden_dim=hidden_dim,
              statistic_batch_size=stat_size,
              num_hidden=num_hidden, batch_norm=batch_norm).to(device)

    # map_location places every saved tensor on the current GPU regardless of
    # which device it was saved from. torch.load unpickles the file: only load
    # trusted checkpoints.
    checkpoint = torch.load(cp_path, map_location="cuda:" + str(device))
    net.load_state_dict(checkpoint['model'])

    return net, library, device

def gather_data(coefs_mean, coefs_std, feature_names, nonzero, z_dim):
    """Collect the retained terms of every equation for tabulation.

    Returns a list with one inner list per equation index ``i`` in
    ``range(z_dim)``. Each inner list holds ``(feature_name, mean, std)`` for
    every feature ``j`` where ``nonzero[i][j]`` is truthy, with mean and std
    rounded to two decimals.
    """
    rounded_mean = np.round(coefs_mean, 2)
    rounded_std = np.round(coefs_std, 2)
    return [
        [
            (name, rounded_mean[eq_idx][k], rounded_std[eq_idx][k])
            for k, name in enumerate(feature_names)
            if nonzero[eq_idx][k]
        ]
        for eq_idx in range(z_dim)
    ]

def get_coef_stats(net, batch_size=1000, device=2):
    """Sample masked coefficients from ``net`` and summarize them per coefficient.

    Args:
        net: model exposing ``get_masked_coefficients(batch_size=..., device=...)``
            that returns a torch tensor.
        batch_size: forwarded to ``get_masked_coefficients`` (default 1000).
        device: forwarded to ``get_masked_coefficients``.

    Returns:
        ``(coefs_mean, coefs_std, nonzero)``. Axis 0 of the returned tensor is
        moved last by ``transpose((2, 1, 0))`` and reduced, so a ``(S, A, B)``
        tensor yields ``(B, A)`` arrays (presumably ``(z_dim, n_features)`` —
        TODO confirm against Net). ``nonzero`` marks entries whose mean is not 0.
    """
    # Honor the caller's batch_size instead of a hard-coded sample count.
    coefs = net.get_masked_coefficients(batch_size=batch_size, device=device).detach().cpu().numpy()
    coefs_t = np.transpose(coefs, (2, 1, 0))
    coefs_mean = coefs_t.mean(axis=2)
    coefs_std = coefs_t.std(axis=2)
    # Reuse the mean rather than recomputing it for the mask.
    nonzero = coefs_mean != 0
    return coefs_mean, coefs_std, nonzero

def _lorenz96_ground_truth(z_dim, forcing=8):
    """Return the true terms of each equation, keyed "dx1" .. "dx<z_dim>".

    Equation i (indices taken cyclically) is
        dx_i = forcing - x_i + x_{i+1} x_{i-1} - x_{i-2} x_{i-1},
    stored as a list of (true coefficient, feature name) pairs with the
    constant term keyed by the empty string "".
    """
    names = ['x' + str(i + 1) for i in range(z_dim)]
    return {
        "dx" + str(i + 1): [
            (forcing, ""),
            (-1, names[i]),
            (1, names[(i + 1) % z_dim] + names[i - 1]),
            (-1, names[i - 2] + names[i - 1]),
        ]
        for i in range(z_dim)
    }


def _same_term(learned, true):
    """Return True if feature name ``learned`` denotes the same term as ``true``.

    Two-factor products may be written in either order ("x1x3" vs "x3x1"),
    so ``learned`` is also compared after swapping its two factors.
    """
    if learned == true:
        return True
    split = learned.find('x', 1)  # start of the second factor, -1 if none
    if split == -1:
        return False
    return learned[split:] + learned[:split] == true


def build_table(net, batch_size=1000, device=2, print_file=False, print_fancy=True,
                print_latex=False, seed=None):
    """Tabulate the learned coefficients of ``net`` against the Lorenz-96 ground truth.

    Produces one row per learned (nonzero-mean) term with columns
    EQUATION, TERM, TRUE, MEAN, STD. Each true term the model missed gets an
    extra row with mean and std 0. The equation label is shown only on the
    first row of each equation.

    Args:
        net: trained model exposing ``z_dim`` and ``library.get_feature_names()``
            plus whatever ``get_coef_stats`` requires.
        batch_size, device: forwarded to ``get_coef_stats``.
        print_file: also write the fancy table to "table.txt" and the LaTeX
            table to "table_latex.txt" in the working directory.
        print_fancy: print the "fancy_outline" table.
        print_latex: print the LaTeX table.
        seed: seed for ``set_random_seed``; defaults to the notebook's ``SEED``.
    """
    set_random_seed(SEED if seed is None else seed)
    coefs_mean, coefs_std, nonzero = get_coef_stats(net, batch_size, device)

    # Use the model's own dimension, not the notebook-global z_dim, so the
    # labels and ground truth always line up with the gathered equations.
    z_dim = net.z_dim
    feature_names = net.library.get_feature_names()
    data = gather_data(coefs_mean, coefs_std, feature_names, nonzero, z_dim)

    eq_starts = ["dx" + str(i + 1) for i in range(z_dim)]
    gts = _lorenz96_ground_truth(z_dim)

    table = []
    for eq_start, eq in zip(eq_starts, data):
        true_eq = gts[eq_start]
        rows = []
        for name, mean, std in eq:
            # 0 means "absent from the true model"; if several true terms
            # match, the last one wins.
            true_coef = 0
            for coef, true_name in true_eq:
                if _same_term(name, true_name):
                    true_coef = coef
            rows.append([name, str(true_coef), mean, std])

        # True terms the learned model dropped are listed with zero mean/std.
        for coef, true_name in true_eq:
            if not any(_same_term(learned[0], true_name) for learned in eq):
                rows.append([true_name, str(coef), 0, 0])

        # Label the first row of the equation even when it is a missed term.
        for k, row in enumerate(rows):
            table.append([eq_start if k == 0 else ""] + row)

    headers = ["EQUATION", "TERM", "TRUE", "MEAN", "STD"]
    fancy_table = tabulate(table, headers, tablefmt="fancy_outline")
    latex_table = tabulate(table, headers, tablefmt="latex")
    if print_fancy:
        print(fancy_table)
    if print_latex:
        print(latex_table)
    if print_file:
        with open("table.txt", "w") as f:
            print(fancy_table, file=f)
        with open("table_latex.txt", "w") as f:
            print(latex_table, file=f)

# Load Model

In [3]:
# Global random seed: build_table passes it to set_random_seed and the
# equation-extraction cell passes it to get_equations.
SEED = 5281998

In [4]:
# Experiment configuration. The architecture settings below are passed to
# load_model in the next cell.
# NOTE(review): data_folder, dt, adam_reg and gamma_factor are not used by the
# cells shown here; confirm before removing them.
data_folder = "../data/"
model = "HyperSINDy"          # model name passed to get_equations
dt = 0.01
hidden_dim = 128
stat_size = 250               # Net's statistic_batch_size (via load_model)
num_hidden = 5
z_dim = 10                    # number of state variables x1 .. x10
adam_reg = 1e-2
gamma_factor = 0.999
poly_order = 3
include_constant = True
device = 2                    # CUDA device index
batch_norm = False
noise_dim = 20
runs = "../runs/lorenz96"     # directory holding the cp_*.pt checkpoints
# `library` is not built here: load_model in the next cell constructs and
# returns it, so an instance created in this cell would be discarded unused.

In [5]:
# Restore the trained network from the run's first checkpoint. stat_size fills
# both the statistic batch size and load_model's (unused) batch_size slot.
net1, library, device = load_model(
    device=device,
    z_dim=z_dim,
    poly_order=poly_order,
    include_constant=include_constant,
    noise_dim=noise_dim,
    hidden_dim=hidden_dim,
    stat_size=stat_size,
    batch_size=stat_size,
    num_hidden=num_hidden,
    batch_norm=batch_norm,
    cp_path=runs + "/cp_1.pt",
)

# Table

In [6]:
# Extract the learned equations of net1 (get_equations receives SEED as its seed).
# all_eqs collects one equation list per model; each list is later formatted by reformat.
eq1 = get_equations(net1, library, model, device, seed=SEED)
all_eqs = [eq1]

# Print the equations in a LaTeX-friendly format

In [7]:
# Ground-truth Lorenz-96 terms for each equation "dx1" .. "dx<z_dim>": a list of
# (true coefficient, feature name) pairs, the constant term keyed by "".
eq_starts = [f"dx{k + 1}" for k in range(z_dim)]
terms = np.array([f"x{k + 1}" for k in range(z_dim)])
gts = {
    start: [
        (8, ""),
        (-1, terms[k]),
        (1, terms[(k + 1) % z_dim] + terms[k - 1]),
        (-1, terms[k - 2] + terms[k - 1]),
    ]
    for k, start in enumerate(eq_starts)
}

In [8]:
# Display the ground-truth terms of every equation as a sanity check.
gts

{'dx1': [(8, ''), (-1, 'x1'), (1, 'x2x10'), (-1, 'x9x10')],
 'dx2': [(8, ''), (-1, 'x2'), (1, 'x3x1'), (-1, 'x10x1')],
 'dx3': [(8, ''), (-1, 'x3'), (1, 'x4x2'), (-1, 'x1x2')],
 'dx4': [(8, ''), (-1, 'x4'), (1, 'x5x3'), (-1, 'x2x3')],
 'dx5': [(8, ''), (-1, 'x5'), (1, 'x6x4'), (-1, 'x3x4')],
 'dx6': [(8, ''), (-1, 'x6'), (1, 'x7x5'), (-1, 'x4x5')],
 'dx7': [(8, ''), (-1, 'x7'), (1, 'x8x6'), (-1, 'x5x6')],
 'dx8': [(8, ''), (-1, 'x8'), (1, 'x9x7'), (-1, 'x6x7')],
 'dx9': [(8, ''), (-1, 'x9'), (1, 'x10x8'), (-1, 'x7x8')],
 'dx10': [(8, ''), (-1, 'x10'), (1, 'x1x9'), (-1, 'x8x9')]}

In [9]:
def _latex_equation(eq):
    r"""Convert one ``"dx<i> = <term> + <term> ..."`` string into a LaTeX-friendly line.

    The left-hand side ``dx<i>`` becomes ``\dot{x}_{<i>}``. Each term is a
    signed coefficient followed by ``x<k>`` factors (e.g. ``-0.99x9x10``) and is
    rendered as ``0.99x_{9}x_{10}`` preceded by ``+ `` or ``- ``; the first
    term only gets a bare ``-`` when negative. Every term is followed by one
    space, so the returned line ends with a trailing space.
    """
    result = ""
    first_term = True
    for token in eq.split(" "):
        if token.startswith("dx"):
            result += r"\dot{x}_{" + token[2:] + "}"
        elif token == "=":
            result += " = "
        elif token in ("+", ""):
            # separators (and empty tokens from repeated spaces) carry no content
            continue
        else:
            negative = token.startswith("-")
            coef, *factors = (token[1:] if negative else token).split("x")
            term = coef + "".join("x_{" + factor + "}" for factor in factors)
            if first_term:
                result += ("-" if negative else "") + term + " "
            else:
                result += ("- " if negative else "+ ") + term + " "
            first_term = False
    return result


def reformat(eqs, gts, eq_starts, filename=None):
    r"""Print learned equations as LaTeX-friendly lines.

    Each entry of ``eqs`` is an equation string tokenized on single spaces,
    e.g. ``"dx1 = 6.75 + -0.75x1 + 0.99x2x10"`` (format inferred from this
    parser — NOTE(review): confirm against get_equations); the marker entries
    "MEAN" and "STD" are skipped. An equation like the one above is printed as
    ``\dot{x}_{1} = 6.75 - 0.75x_{1} + 0.99x_{2}x_{10}`` followed by a space.
    One blank line is printed after all equations.

    Args:
        eqs: iterable of equation strings and "MEAN"/"STD" markers.
        gts, eq_starts: accepted for backward compatibility; not needed for
            formatting.
        filename: optional open, writable file object; every printed line is
            also written to it.
    """
    for eq in eqs:
        if eq in ("MEAN", "STD"):
            continue
        line = _latex_equation(eq)
        print(line)
        if filename is not None:
            print(line, file=filename)
    print()
    if filename is not None:
        print(file=filename)

In [10]:
# Write the LaTeX-formatted equations of every model in all_eqs to the results
# file, creating the results directory first so a fresh checkout can run this cell.
os.makedirs("../results", exist_ok=True)
with open("../results/lorenz96.txt", "w") as f:
    for curr_eqs in all_eqs:
        reformat(curr_eqs, gts, eq_starts, f)

\dot{x}_{1} = 6.75 - 0.75x_{1} + 0.99x_{2}x_{10} - 0.99x_{9}x_{10} 
\dot{x}_{2} = 8.46 - 0.74x_{2} + 0.98x_{1}x_{3} - 0.99x_{1}x_{10} 
\dot{x}_{3} = 7.59 - 0.71x_{3} - 0.99x_{1}x_{2} + 0.98x_{2}x_{4} 
\dot{x}_{4} = 7.89 - 0.77x_{4} - 0.97x_{2}x_{3} + 0.97x_{3}x_{5} 
\dot{x}_{5} = 6.7 - 0.81x_{5} - 0.97x_{3}x_{4} + 0.97x_{4}x_{6} 
\dot{x}_{6} = 7.86 - 0.76x_{6} - 0.99x_{4}x_{5} + 0.99x_{5}x_{7} 
\dot{x}_{7} = 7.31 - 0.75x_{7} - 0.99x_{5}x_{6} + 0.98x_{6}x_{8} 
\dot{x}_{8} = 7.56 - 0.76x_{8} - 0.98x_{6}x_{7} + 0.98x_{7}x_{9} 
\dot{x}_{9} = 7.45 - 0.74x_{9} - 0.98x_{7}x_{8} + 0.99x_{8}x_{10} 
\dot{x}_{10} = 6.49 - 0.75x_{10} + 0.99x_{1}x_{9} - 0.99x_{8}x_{9} 
\dot{x}_{1} = 8.44 + 0.04x_{1} + 0.02x_{2}x_{10} + 0.01x_{9}x_{10} 
\dot{x}_{2} = 7.98 + 0.05x_{2} + 0.02x_{1}x_{3} + 0.02x_{1}x_{10} 
\dot{x}_{3} = 8.1 + 0.05x_{3} + 0.01x_{1}x_{2} + 0.02x_{2}x_{4} 
\dot{x}_{4} = 7.72 + 0.04x_{4} + 0.01x_{2}x_{3} + 0.01x_{3}x_{5} 
\dot{x}_{5} = 7.65 + 0.04x_{5} + 0.02x_{3}x_{4} + 0.01x_{4}x_{6} 
\do