In [None]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import matplotlib.pyplot as plt
print(torch.__version__)

from astro_dynamo.grid import Grid, ForceGrid
from astro_dynamo.snap import SnapShot, ParticleType
import astro_dynamo.target
import astro_dynamo.analysesnap
import  astro_dynamo.analytic_potentials

import math
import os
from tqdm.notebook import trange, tqdm
import torch.nn.functional as F
import mwtools.nemo


%aimport -math,torch,numpy,matplotlib.pyplot,sys

Load the input N-body model and measure its pattern speed.

In [None]:
# Load the snapshot (particle type 0 = dark matter, 1 = stars) and measure the
# stellar pattern speed; omega=1. is only a placeholder that is overwritten below.
fullsnap=SnapShot('../inputmodels/M85_0',omega=1.,particle_type_mapping={0:ParticleType.DarkMatter,1:ParticleType.Star})
omega,omegaerr = astro_dynamo.analysesnap.patternspeed(fullsnap.stars)
# torch.tensor is the recommended constructor (torch.Tensor is the legacy one);
# float() accepts a Python float, a NumPy scalar or a one-element tensor.
fullsnap.omega = torch.tensor([float(omega)],dtype=torch.float32)
print(fullsnap.omega)

In [None]:
def ein(m,rhor0,m0,alpha):
    """Einasto density profile normalised so that its value at m = 8.2 is ``rhor0``.

    rho(m) = rho0 * exp(-(2/alpha) * ((m/m0)**alpha - 1)), with rho0 chosen so
    that rho(8.2) == rhor0.

    Args:
        m: coordinate at which to evaluate (tensor or array-like; presumably a
            spherical/ellipsoidal radius -- confirm against fit_potential_to_snap).
        rhor0: density at m = 8.2.
        m0: scale parameter of the profile.
        alpha: Einasto shape index.

    Returns:
        Tensor on m's dtype/device, broadcast of m with the parameters.
    """
    m = torch.as_tensor(m)

    def _like_m(value):
        # Match m's dtype and device so no arithmetic below mixes devices/dtypes.
        return torch.as_tensor(value, dtype=m.dtype, device=m.device)

    rhor0, m0, alpha = _like_m(rhor0), _like_m(m0), _like_m(alpha)

    def _shape(radius):
        # Unnormalised profile exp(-(2/alpha) * ((radius/m0)**alpha - 1)).
        return torch.exp(-(2 / alpha) * ((radius / m0) ** alpha - 1))

    # Divide by the value at the normalisation radius, then scale by the profile at m.
    return rhor0 / _shape(8.2) * _shape(m)

# Fit the `ein` profile to the dark-matter particles. init_parms are presumably the
# starting values in ein's parameter order (rhor0, m0, alpha) -- confirm.
dm_pot = astro_dynamo.analytic_potentials.fit_potential_to_snap(fullsnap.dm,ein,init_parms=[1e-3,8.0,0.7],plot=True)

In [None]:
# Move the fitted dark-matter potential to the GPU and call grid_accelerations()
# on the GPU copy (presumably tabulating accelerations on its grid -- confirm).
dm_pot_gpu = dm_pot.to('cuda')
dm_pot_gpu.grid_accelerations()

Compute the gridded potential of the model's stellar component (the dark matter is represented by the analytic Einasto fit above).

In [None]:
# Release the stellar grid from a previous run of this cell *before* allocating a
# new one, so two large grids are never held in memory at the same time.
try:
    del star_pot
except NameError:
    pass
n=256
nz=256
snap = fullsnap.stars
# Grid the stellar particles (mass-weighted, cloud-in-cell) and then call grid_accelerations().
star_pot=ForceGrid(n=(n,n,nz),
                    grid_edges=torch.tensor([10.,10.,4.],dtype=torch.float32),
                    # NOTE(review): 20 looks like the full x-extent (2 * grid_edges[0]),
                    # which would make this 0.2 of a cell width -- confirm.
                    smoothing=0.2*20/n)
_=star_pot.grid_data(snap.positions,weights=snap.masses,method='cic')
star_pot.grid_accelerations()

Set up the targets that the optimisation will fit.

In [None]:
# The targets live on the same device as the snapshot copy they observe.
device='cuda'
snap_gpu=snap.to(device)

# Profile observer used for the full-range curve in the diagnostic plots.
full_radial_profile=astro_dynamo.target.RadialProfile(device=device)

# Anchor the target at r = fiducial_r using the model's current surface density there.
fiducial_r = 4.
fiducial_sig=full_radial_profile.interpolate_surface_density(snap_gpu,fiducial_r)

def surface_density(x):
    """Target exponential profile: equals fiducial_sig at fiducial_r, scale length 2.4."""
    return fiducial_sig*torch.exp(-(x-fiducial_r)/2.4)

target = astro_dynamo.target.RadialProfile(r_range=(fiducial_r,9),
                                           surface_density=surface_density,device=device)


In [None]:
def plot_radialprofile(ax,full_radial_profile,target,snap,vmin=1e-5,vmax=1):
    """Plot radial surface-density profiles of ``snap`` against the target on ``ax``.

    Args:
        ax: matplotlib Axes to draw on (log-scaled y axis).
        full_radial_profile: profile observer whose ``observe(snap)`` curve is
            drawn over its own ``rmid`` bins (legend label 'Initial').
        target: target profile providing ``rmid``, ``target`` and ``observe``.
        snap: snapshot to observe.
        vmin, vmax: y-axis limits.
    """
    def _np(tensor):
        # Detach and copy to host so matplotlib receives plain numpy arrays.
        return tensor.detach().cpu().numpy()

    ax.semilogy(_np(full_radial_profile.rmid),_np(full_radial_profile.observe(snap)),label='Initial')
    target_r = _np(target.rmid)
    ax.semilogy(target_r,_np(target.target),label='Target')
    ax.semilogy(target_r,_np(target.observe(snap)),'r',label='Snapshot')
    ax.set_ylim(vmin,vmax)
    ax.set_xlabel('r')
    # Raw string: '\S' is an invalid escape sequence in a normal string literal.
    ax.set_ylabel(r'$\Sigma$')
    ax.legend()

def plot_snap_projections(axs,snap,plotmax=10.,vmin=1e-5,vmax=1e-2,particle_plot_i=None):
    """Draw mass-weighted log hexbin maps of the x-y, x-z and y-z projections of ``snap``.

    Args:
        axs: three matplotlib Axes, for the (x,y), (x,z) and (y,z) projections.
        snap: snapshot providing ``x``, ``y``, ``z`` and ``masses`` tensors (any device).
        plotmax: half-width of the plotted square region.
        vmin, vmax: colour-scale limits for the summed mass per hexagon.
        particle_plot_i: optional indices of particles to overplot as black dots;
            a tensor on any device, or anything numpy can turn into an index array.
    """
    def _np(tensor):
        # Detach and copy to host: matplotlib needs numpy data, and the inputs may be on the GPU.
        return tensor.detach().cpu().numpy()

    x, y, z = _np(snap.x), _np(snap.y), _np(snap.z)
    m = _np(snap.masses)
    if particle_plot_i is not None:
        if isinstance(particle_plot_i, torch.Tensor):
            particle_plot_i = particle_plot_i.detach().cpu().numpy()
        # nonzero() yields shape (k, 1); flatten so indexing gives 1-D point lists.
        particle_plot_i = np.asarray(particle_plot_i).ravel()

    projections = ((x,y),(x,z),(y,z))
    projection_labels = (('x','y'),('x','z'),('y','z'))

    for ax,(horizontal,vertical),(xlabel,ylabel) in zip(axs,projections,projection_labels):
        # A colormap name string replaces plt.cm.get_cmap, which matplotlib 3.9 removed.
        ax.hexbin(horizontal,vertical,C=m,bins='log',
                  extent=(-plotmax,plotmax,-plotmax,plotmax),reduce_C_function=np.sum,
                  vmin=vmin,vmax=vmax,cmap='nipy_spectral')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if particle_plot_i is not None:
            ax.plot(horizontal[particle_plot_i],vertical[particle_plot_i],'ko',markersize=4)
        ax.set_xlim(-plotmax,plotmax)
        ax.set_ylim(-plotmax,plotmax)

def plot_fit_step(snap,step,prefix='fit_step',particle_plot_i=None,outdir='plots'):
    """Save a 2x2 diagnostic figure for one fitting step.

    Panels:
        - x-y projection (top left)
        - x-z projection (bottom left)
        - y-z projection (top right)
        - radial surface-density profile (bottom right)

    The profile panel uses the notebook-level ``full_radial_profile`` and ``target``.

    Args:
        snap: snapshot to plot.
        step: step number, zero-padded to 5 digits in the file name.
        prefix: file-name prefix.
        particle_plot_i: optional particle indices to highlight.
        outdir: output directory; it is created if missing.

    Writes ``{outdir}/{prefix}_{step:05}.png`` at 150 dpi.
    """
    os.makedirs(outdir, exist_ok=True)
    f,axs = plt.subplots(2,2,figsize=(9,9))
    try:
        plot_snap_projections((axs[0,0],axs[1,0],axs[0,1]),snap,
                              particle_plot_i=particle_plot_i)
        plot_radialprofile(axs[1,1],full_radial_profile,target,snap)
        f.tight_layout()
        f.savefig(f'{outdir}/{prefix}_{step:05}.png',dpi=150)
        f.show()
    finally:
        # Always release the figure, even on error: this is called inside long loops.
        plt.close(f)

# Highlight up to 7 particles with cylindrical radius in (3, 5). The indices are moved
# to the CPU because the plotting code indexes CPU copies of the positions, and a CUDA
# index tensor cannot index a CPU tensor.
in_annulus = (snap_gpu.rcyl>3) & (snap_gpu.rcyl<5)
particle_plot_i=torch.nonzero(in_annulus, as_tuple=True)[0][:7].cpu()
plot_fit_step(snap_gpu,0,particle_plot_i=particle_plot_i)

In [None]:
# Run the garbage collector to release unreferenced objects left by earlier cells.
import gc
gc.collect()

In [None]:
# Fiducial run: evolve the unmodified model in the combined star + dark-matter
# potential for 100 blocks of 64 leapfrog steps, saving a figure every 5 blocks.
device = 'cuda'
star_pot_gpu = star_pot.to(device)
progress = trange(100)

# Reset every particle's dt to +inf (NOTE(review): presumably makes leapfrog_steps
# recompute timesteps -- confirm).
snap.dt = torch.full(snap.masses.shape,float('inf'),dtype=snap.positions.dtype)
snap_gpu = snap.to(device)
particle_plot_i = range(10)  # highlight the first ten particles

for step in progress:
    should_plot = step % 5 == 0
    if should_plot:
        plot_fit_step(snap_gpu, step, prefix='fiducial', particle_plot_i=particle_plot_i)
    snap_gpu.leapfrog_steps(potentials=[star_pot_gpu,dm_pot_gpu], steps=64, stepsperorbit=800)

In [None]:
# Profile five blocks of 64 leapfrog steps and save a Chrome trace.
# Uses the GPU snapshot and potentials defined by the earlier cells.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for step in range(5):
        snap_gpu.leapfrog_steps(potentials=[star_pot_gpu,dm_pot_gpu], steps=64)
print(prof)
prof.export_chrome_trace('testcase')

In [None]:
# Sanity check: gradient of the target loss with respect to the particle masses,
# computed on a fresh device copy of the CPU snapshot.
gpusnap=snap.to(device)
gpusnap.masses=gpusnap.masses.requires_grad_(True)
loss = target.loss(gpusnap)
loss.backward()
print(gpusnap.masses.grad)

In [None]:
# NOTE(review): repeats the trace export already done at the end of the profiling
# cell, to the same file; redundant if that cell completed.
prof.export_chrome_trace('testcase')

In [None]:
# Scratch check (needs `gpusnap` from the gradient-check cell): enable gradient
# tracking on the masses, then display the loss alongside the masses.
gpusnap.masses.requires_grad_(True)
target.loss(gpusnap),gpusnap.masses

In [None]:
# Fit the particle masses to the target profile, periodically recomputing the
# stellar potential from the re-weighted particles.
device='cuda'
prefix='fiducial'
star_pot_gpu=star_pot.to(device)

snap.dt = torch.full(snap.masses.shape,float('inf'),dtype=snap.positions.dtype)
snap_gpu=snap.to(device)
astro_dynamo.analysesnap.align_bar(snap_gpu)
# Evolve for 800*8 steps in the combined potential before the fit starts.
snap_gpu.leapfrog_steps(potentials=[star_pot_gpu,dm_pot_gpu], steps=800*8)

snap_gpu.masses.requires_grad_(True)
print(f'Using pattern speed {snap_gpu.omega[0]:.4f}')

learning_rate = 1e-2
verbose = True  # forwarded to snap_gpu.resample; loop-invariant, so set once

tvec = torch.linspace(0.,100,101,device=device)
progress = tqdm(enumerate(tvec),total=len(tvec))
lossvec = torch.zeros_like(tvec)

potential_step=0       # last step at which the stellar potential was recomputed
potenial_updates=[0]   # (sic) steps with a potential update; the name is read by the loss-plot cell

plt.ioff()  # figures are only saved inside the loop; interactive mode is restored in `finally`
try:
    for step, _t in progress:
        loss = target.loss(snap_gpu)
        loss.backward()
        lossvec[step] = loss.detach()
        with torch.no_grad():
            # Multiplicative gradient step: each mass changes in proportion to itself.
            snap_gpu.masses -= learning_rate * snap_gpu.masses * snap_gpu.masses.grad
            snap_gpu.masses.grad.zero_()
            plot_fit_step(snap_gpu,step,prefix=prefix)
            # Update the stored grid from the re-weighted particles (presumably a partial
            # blend controlled by fractional_update -- confirm); accelerations are only
            # rebuilt in the recompute branch below.
            star_pot_gpu.grid_data(snap_gpu.positions,weights=snap_gpu.masses.detach(),
                                   method='nearest',fractional_update=0.2)
            fractional_loss_change = (lossvec[potential_step]-loss.detach()).abs()/loss.detach()
            progress.write(f'Loss: {loss:.4f}, Fractional loss change: {fractional_loss_change:.4f}')

            # Recompute after more than 25 steps, or after more than 5 steps if the loss
            # has changed by over 50% since the last recompute.
            if step - potential_step > 25 or (step - potential_step > 5 and fractional_loss_change > 0.5):
                progress.write('Recomputing Potential')
                potenial_updates+=[step]

                # Optional (disabled): re-align the bar before recomputing.
                #astro_dynamo.analysesnap.align_bar(snap_gpu)

                # Circular speed before the update, from v_c^2 = -a . x
                old_accelerations = star_pot_gpu.get_accelerations(snap_gpu.positions) + \
                    dm_pot_gpu.get_accelerations(snap_gpu.positions)
                old_vc=torch.sum(-old_accelerations*snap_gpu.positions,dim=-1).sqrt()

                # Rebuild the stellar accelerations on the CPU copy, then move it back.
                star_pot.data=star_pot_gpu.data.cpu()
                star_pot.grid_accelerations()
                del star_pot_gpu
                star_pot_gpu=star_pot.to(device)

                new_accelerations = star_pot_gpu.get_accelerations(snap_gpu.positions) + \
                    dm_pot_gpu.get_accelerations(snap_gpu.positions)
                new_vc=torch.sum(-new_accelerations*snap_gpu.positions,dim=-1).sqrt()

                # Rescale velocities by new_vc/old_vc. Skip particles where either speed is
                # NaN (-a . x < 0) or where old_vc == 0, which would divide by zero.
                gd = torch.isfinite(old_vc) & torch.isfinite(new_vc) & (old_vc > 0)
                snap_gpu.velocities[gd,:]*=(new_vc/old_vc)[gd,None]
                potential_step=step
                snap_gpu.resample([star_pot_gpu,dm_pot_gpu],verbose=verbose)
            snap_gpu.leapfrog_steps(potentials=[star_pot_gpu,dm_pot_gpu], steps=256)
finally:
    plt.ion()

In [None]:
# Loss history of the mass fit, with markers at the epochs where the stellar
# potential was recomputed.
loss_history = lossvec.cpu().numpy()
update_epochs = np.array(potenial_updates)
fig, ax = plt.subplots()
ax.semilogy(loss_history, label='Loss')
ax.semilogy(update_epochs, loss_history[update_epochs], 'o', label='Potential Update')
ax.legend()
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
os.makedirs('plots', exist_ok=True)
fig.savefig(f'plots/{prefix}loss.png')