In [4]:
import os
import ast

def find_imports(repo_path):
    imports = set()

    for root, _, files in os.walk(repo_path):
        for file in files:
            if file.endswith('.py'):
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, 'r') as f:
                        tree = ast.parse(f.read())
                        for node in ast.walk(tree):
                            if isinstance(node, ast.Import):
                                for alias in node.names:
                                    imports.add(f"import {alias.name}")
                            elif isinstance(node, ast.ImportFrom):
                                module = node.module if node.module else ''
                                names = ', '.join(alias.name for alias in node.names)
                                imports.add(f"from {module} import {names}")
                except Exception as e:
                    print(f"Error parsing {file_path}: {e}")

    return sorted(imports)

# Usage
repo_path = r"C:\repo\nranthony\hamiltonian-nn"
all_imports = find_imports(repo_path)
for imp in all_imports:
    print(imp)

from data import get_dataset
from hnn import HNN
from hnn import HNN, PixelHNN
from nn_models import MLP
from nn_models import MLPAutoencoder, MLP
from urllib.request import urlretrieve
from utils import L2_loss
from utils import L2_loss, rk4
from utils import L2_loss, to_pickle, from_pickle
from utils import choose_nonlinearity
from utils import read_lipson, str2array
from utils import rk4
from utils import to_pickle, from_pickle
import argparse
import jax
import jax.numpy
import gym
import imageio
import numpy
import os
import pickle
import scipy
import scipy.integrate
import scipy.misc
import shutil
import sys
import torch
import zipfile


#### Hamiltonian Gradient Field


In [8]:
import numpy as np
import jax
import jax.numpy as jnp

def hamiltonian_fn(coords):
    q, p = jnp.split(coords,2)
    H = p**2 + q**2 # spring hamiltonian (linear oscillator)
    return H

def dynamics_fn(t, coords):
    dcoords = jax.grad(hamiltonian_fn)(coords)
    dqdt, dpdt = np.split(dcoords,2)
    S = np.concatenate([dpdt, -dqdt], axis=-1)
    return S

def get_field(xmin=-1.2, xmax=1.2, ymin=-1.2, ymax=1.2, gridsize=20):
    field = {'meta': locals()}

    # meshgrid to get vector field
    b, a = jnp.meshgrid(jnp.linspace(xmin, xmax, gridsize), jnp.linspace(ymin, ymax, gridsize))
    ys = jnp.stack([b.flatten(), a.flatten()])
    
    # get vector directions
    dydt = [dynamics_fn(None, y) for y in ys.T]
    dydt = np.stack(dydt).T

    field['x'] = ys.T
    field['dx'] = dydt.T
    return field

In [None]:
get_field()

In [15]:
xmin=-1.2
xmax=1.2
ymin=-1.2
ymax=1.2
gridsize=20

b, a = jnp.meshgrid(jnp.linspace(xmin, xmax, gridsize), jnp.linspace(ymin, ymax, gridsize))
ys = jnp.stack([b.flatten(), a.flatten()])

In [21]:
pq_vect = [(None, y) for y in ys.T]
pq_vect[0][1]

Array([-1.2, -1.2], dtype=float32)

In [26]:
def hamiltonian_fn_qp(q,p):
    return p**2 + q**2
grad_hamiltonian_q = jax.grad(hamiltonian_fn_qp, argnums=0)
grad_hamiltonian_p = jax.grad(hamiltonian_fn_qp, argnums=1)

In [25]:
def dynamics_fn_qp(coords):
    dqdt, dpdt = grad_hamiltonian_q(coords[0], coords[1]), grad_hamiltonian_p(coords[0], coords[1])
    return jnp.concatenate([dpdt, -dqdt], axis=-1)


In [29]:
coords = [-1.2, -1.2]
dqdt, dpdt = grad_hamiltonian_q(coords[0], coords[1]), grad_hamiltonian_p(coords[0], coords[1])
print(dqdt, dpdt)
print(hamiltonian_fn_qp(coords[0], coords[1]))

-2.4 -2.4
2.88


In [27]:
dynamics_fn_qp([-1.2, -1.2])

ValueError: Zero-dimensional arrays cannot be concatenated.