In [None]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import matplotlib.pyplot as plt
print(torch.__version__)
from pydynmod.grid import Grid, ForceGrid
from pydynmod.snap import SnapShot, ParticleType
import pydynmod.target
import pydynmod.analysesnap 
import math
from tqdm import tqdm_notebook 
import torch.nn.functional as F 

%aimport -math,torch,numpy,matplotlib.pyplot,sys

Load our input N-body model and compute its pattern speed

In [None]:
# Load the input N-body snapshot. omega=1. is only an initial value here; it is
# overwritten further down with the pattern speed measured from the stars.
snap=SnapShot('../inputmodels/M85_0',omega=1.)
# Rebuild the particle-type array: every particle starts as Star, then the
# particles whose type code in the loaded snapshot equals 0 are relabelled as
# DarkMatter.
particletype = torch.full((snap.n,),ParticleType.Star,dtype=torch.uint8)
particletype[snap.particletype==0]=ParticleType.DarkMatter
snap.particletype = particletype
# patternspeed returns a pair (presumably the value and its uncertainty — confirm
# in pydynmod.analysesnap); omegaerr is not used again in the cells shown here.
omega,omegaerr = pydynmod.analysesnap.patternspeed(snap.stars)
# Store the measured pattern speed on the snapshot as a 1-element float32 tensor.
snap.omega = torch.Tensor([omega]).type(torch.float32)
print(snap.omega)

Compute the potential of the N-body model

In [None]:
# Discard any `potential` left over from an earlier run of this cell before a
# new grid is built (presumably so its memory can be reclaimed first — confirm).
globals().pop('potential', None)

# Mesh resolution: n cells along x and y, nz cells along z.
n = nz = 512

potential = ForceGrid(
    n=(n, n, nz),
    gridedges=torch.tensor([10., 10., 10.], dtype=torch.float32),
    # 0.5*20/n: presumably half a cell width if the grid spans [-10, 10] on
    # each axis — TODO confirm against ForceGrid.
    smoothing=0.5*20/n,
)

# Deposit the particle masses onto the mesh with the 'cic' method (the return
# value is not needed), then compute the gridded accelerations.
_ = potential.griddata(snap.positions, weights=snap.masses, method='cic')
potential.grid_accelerations()

Set up the targets that we wish to optimise

In [None]:
# Everything in this cell lives on the same device as the GPU copy of the snapshot.
device = 'cuda'
gpusnap = snap.to(device)

# Radial profile built with default settings; used to read off the surface
# density of the current model at the fiducial radius.
full_radial_profile = pydynmod.target.RadialProfile(device=device)

fiducial_r = 4.
fiducial_sig = full_radial_profile.interpolate_surface_density(gpusnap.stars, fiducial_r)


def surface_density(x):
    """Exponential target profile with scale length 2.4, equal to fiducial_sig at fiducial_r."""
    return fiducial_sig*torch.exp(-(x-fiducial_r)/2.4)


# The target profile is only defined over fiducial_r <= r <= 9.
target = pydynmod.target.RadialProfile(
    rrange=(fiducial_r, 9),
    surface_density=surface_density,
    device=device,
)



In [None]:
def plot_radialprofile(ax,full_radial_profile,target,snap,vmin=1e-5,vmax=1):
    """Plot radial surface-density profiles of `snap` on a log-y axis.

    Three curves are drawn on `ax`:
      * the profile observed by `full_radial_profile` for `snap`,
      * the stored target values `target.target`,
      * the profile observed by `target` for `snap` (in red).

    Parameters
    ----------
    ax : matplotlib axes (or any object with the same plotting methods).
    full_radial_profile, target : objects exposing `rmid`, `observe(snap)`
        and, for `target`, a `target` tensor.
    snap : object passed straight through to the `observe` calls.
    vmin, vmax : lower and upper limits of the y axis.
    """
    # detach() so that a grad-tracking result can still be converted to numpy,
    # matching how the target observation is handled below.
    surface_density_full = full_radial_profile.observe(snap).detach().cpu().numpy()
    ax.semilogy(full_radial_profile.rmid.cpu().numpy(),surface_density_full,label='Initial')
    ax.semilogy(target.rmid.cpu().numpy(),target.target.cpu().numpy(),label='Target')
    ax.semilogy(target.rmid.cpu().numpy(),target.observe(snap).detach().cpu().numpy(),
            'r',label='Snapshot')
    ax.set_ylim(vmin,vmax)
    ax.set_xlabel('r')
    # Raw string: '\S' is not a valid escape sequence in a normal string literal.
    ax.set_ylabel(r'$\Sigma$')
    ax.legend()

def plot_snap_projections(axs,snap,plotmax=10.,vmin=1e-6,vmax=1e-2):
    """Draw mass-weighted hexbin maps of the stars in the x-y, x-z and y-z planes.

    Parameters
    ----------
    axs : iterable of three axes, used in the order (x,y), (x,z), (y,z).
    snap : object whose `stars` attribute exposes `x`, `y`, `z` and `masses` tensors.
    plotmax : half-width of the plotted square; the extent is
        (-plotmax, plotmax) on both axes.
    vmin, vmax : colour-scale limits forwarded to `hexbin`.
    """
    stars = snap.stars
    x=stars.x.cpu()
    y=stars.y.cpu()
    z=stars.z.cpu()
    # Take the masses from the snapshot that was passed in (not from a notebook
    # global); detach so grad-tracking masses can still be handed to matplotlib.
    m=stars.masses.detach().cpu()
    projections = ((x,y),(x,z),(y,z))
    projection_labels = (('x','y'),('x','z'),('y','z'))

    for ax,projection,projection_label in zip(axs,projections,projection_labels):
        # Each hexagon shows the summed mass of the particles falling inside it.
        ax.hexbin(projection[0],projection[1],C=m,bins='log',
                   extent=(-plotmax,plotmax,-plotmax,plotmax),reduce_C_function=np.sum,
                     vmin=vmin,vmax=vmax)
        ax.set_xlabel(projection_label[0])
        ax.set_ylabel(projection_label[1])

def plot_fit_step(snap,step,prefix='fit_step'):
    """Save a 2x2 diagnostic figure (three projections plus the radial profile).

    The figure is written to f'{prefix}_{step:05}.png' at 150 dpi. The figure
    is not closed here, so callers that produce many frames should close it.

    NOTE(review): the `snap` argument is currently unused — the body plots the
    notebook globals `gpusnap`, `full_radial_profile` and `target`. At least one
    caller passes the CPU `snap` and relies on this, so update the callers
    together with any change that makes the function honour its argument.
    """
    f,axs = plt.subplots(2,2,figsize=(9,9))
    # Panels: x-y top-left, x-z bottom-left, y-z top-right, radial profile bottom-right.
    plot_snap_projections((axs[0,0],axs[1,0],axs[0,1]),gpusnap.stars)
    plot_radialprofile(axs[1,1],full_radial_profile,target,gpusnap.stars)
    f.tight_layout()
    f.savefig(f'{prefix}_{step:05}.png',dpi=150)
    
# Write the step-0 frame for the model before any fitting is done.
plot_fit_step(gpusnap,0)


In [None]:
# Force a garbage-collection pass (presumably to release memory held by objects
# from earlier cells before the GPU runs below — confirm).
import gc
gc.collect()

In [None]:
# Inspect the `dt` attribute of the stellar particles in gpusnap.
gpusnap.stars.dt

In [None]:
# Integrate the unmodified model in the fixed potential and save one
# diagnostic frame per outer iteration (files prefixed 'fixed_dt').
device = 'cuda'
gpupotential = potential.to(device)

# Reset every particle's dt to +inf before copying the snapshot to the device
# (presumably this marks the timestep as not yet chosen — confirm in pydynmod).
snap.dt = torch.full(snap.masses.shape, float('inf'), dtype=snap.positions.dtype)
gpusnap = snap.to(device)

progress = tqdm_notebook(range(20), total=20)
for step in progress:
    # Plot first, so frame 0 shows the state before any leapfrog steps.
    plot_fit_step(gpusnap, step, prefix='fixed_dt')
    gpusnap.stars.leapfrog_steps(potential=gpupotential, steps=20)



In [None]:
# Fit the stellar masses to the target radial profile by gradient descent,
# integrating the particles between updates and periodically recomputing the
# potential from the re-weighted particles.
device='cuda'
gpupotential=potential.to(device)

# Start from a fresh GPU copy of the snapshot with every dt reset to +inf, and
# make the masses the quantity that gradients are tracked for.
snap.dt = torch.full(snap.masses.shape,float('inf'),dtype=snap.positions.dtype)
gpusnap=snap.to(device)
gpusnap.masses=gpusnap.masses.requires_grad_(True)

print(f'Using pattern speed {gpusnap.omega[0]:.4f}')

learning_rate = 1e-2

tvec = torch.linspace(0.,100,101,device=device)
# Interactive mode off while frames are generated; restored after the loop.
plt.ioff()
progress = tqdm_notebook(enumerate(tvec),total=len(tvec))
lossvec = torch.zeros_like(tvec)

# Index of the step at which the potential was last recomputed.
potential_step=0
verbose=True

for step, time in progress:
    # Loss and gradient with respect to the masses for the current particle state.
    loss = target.loss(gpusnap.stars)
    loss.backward()
    lossvec[step] = loss.detach()

    with torch.no_grad():
        # Fixed number of leapfrog steps per iteration. An alternative that was
        # tried: gpusnap.integrate(time=time,potential=gpupotential,verbose=verbose)
        gpusnap.stars.leapfrog_steps(potential=gpupotential, steps=64)
        # Plot the evolving GPU snapshot (not the untouched CPU `snap`), then
        # close the figure so frames do not accumulate in memory.
        plot_fit_step(gpusnap,step,prefix='fixed_steps')
        plt.close()

        # Multiplicative update: each mass moves by a fraction of its own value.
        gpusnap.masses -= learning_rate * gpusnap.masses * gpusnap.masses.grad
        gpusnap.masses.grad.zero_()

        # Re-grid the density with the updated masses (fractional_update=0.3
        # presumably blends new and old grid data — confirm in ForceGrid.griddata).
        gpupotential.griddata(gpusnap.positions,weights=gpusnap.masses.detach(),
                              method='nearest',fractional_update=0.3)

        # Relative change of the loss since the potential was last recomputed.
        fractional_loss_change = (lossvec[potential_step]-loss.detach()).abs()/loss.detach()
        progress.write(f'Loss: {loss:.4f}, Fractional loss change: {fractional_loss_change:.4f}')

        # Recompute the potential after more than 25 steps, or after more than 5
        # steps once the loss has changed by more than 50%.
        if step - potential_step > 25 or (step - potential_step > 5 and fractional_loss_change > 0.5):
            progress.write('Recomputing Potential')

            # sqrt of (acceleration . position) per particle, before the update.
            # This is NaN wherever the dot product is negative; such particles
            # fail the mask below and keep their velocities.
            old_accelerations = gpupotential.get_accelerations(gpusnap.positions)
            old_vc=torch.sum(old_accelerations*gpusnap.positions,dim=-1).sqrt()

            # Accelerations are recomputed on the CPU grid from the GPU grid's
            # data, and the result is copied back to the device.
            potential.data=gpupotential.data.cpu()
            potential.grid_accelerations()
            del gpupotential
            gpupotential=potential.to(device)

            new_accelerations = gpupotential.get_accelerations(gpusnap.positions)
            new_vc=torch.sum(new_accelerations*gpusnap.positions,dim=-1).sqrt()
            # Rescale velocities by new_vc/old_vc, skipping particles whose new
            # value is not finite so that NaN/inf never reaches the velocities.
            gd = (old_vc>0) & torch.isfinite(new_vc)
            gpusnap.velocities[gd,:]*=(new_vc/old_vc)[gd,None]
            potential_step=step
            gpusnap.stars.resample(gpupotential,verbose=verbose)

plt.ion()

In [None]:
plt.semilogy(lossvec.cpu().numpy())

In [None]:
%%timeit
# Time converting the stellar x coordinates into integer indices.
# NOTE(review): `nx` is not defined in any cell shown here (the grid size above
# is called `n`) — confirm the intended name before running on a fresh kernel.
ix=(nx*snap.stars.x/10).round().type(torch.long)

In [None]:
%load_ext line_profiler

In [None]:
%%timeit
# Time a scalar multiply/divide over the stellar x coordinates of gpusnap.
# NOTE(review): if these tensors live on CUDA the kernels launch asynchronously,
# so without a synchronize this may mostly measure launch overhead — confirm.
(32*gpusnap.stars.x/10)

In [None]:
# Profile the same scaling expression as the %%timeit cell, this time with %prun.
%prun (32*gpusnap.stars.x/10)

In [None]:
# Line-profile grid.griddata_sparse; needs the line_profiler extension loaded.
# NOTE(review): `grid` and `stars` are not defined in any cell shown here —
# confirm which objects were intended before running on a fresh kernel.
%lprun -f grid.griddata_sparse grid.griddata_sparse(stars.positions,weights=stars.masses)