In [None]:
%load_ext autoreload
%autoreload 2

import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from tqdm.notebook import tqdm

import astro_dynamo.analysesnap
import astro_dynamo.analytic_potentials
import astro_dynamo.snap
import astro_dynamo.target
from astro_dynamo.grid import Grid, ForceGrid
from astro_dynamo.snap import SnapShot


In [None]:
# Load the N-body model; read_nemo_snapshot returns two snapshots (presumably stars and dark matter).
stars, dm = astro_dynamo.snap.read_nemo_snapshot('../input_model_construction/outM80')
# Symmetrize the stellar snapshot (see astro_dynamo.snap.symmetrize_snap for which symmetries).
stars = astro_dynamo.snap.symmetrize_snap(stars)
# The return value is discarded, so this relies on align_bar modifying `stars` in place.
astro_dynamo.analysesnap.align_bar(stars)
# Measure the bar pattern speed; omegaerr (presumably its uncertainty) is not used below.
omega, omegaerr = astro_dynamo.analysesnap.patternspeed(stars)
# Store the pattern speed on the snapshot as a 1-element float32 tensor.
stars.omega = torch.Tensor([omega]).type(torch.float32)
print(stars.omega)

In [None]:
def einasto(m,rhor0,m0,alpha):
    """Einasto density profile normalised at the reference radius 8.2.

    Evaluates ``rho0 * exp(-(2/alpha) * ((m/m0)**alpha - 1))`` where ``rho0`` is chosen
    so that ``einasto(8.2, rhor0, m0, alpha) == rhor0``.

    Parameters
    ----------
    m : array-like or torch.Tensor
        Coordinate(s) at which to evaluate the profile. Integer input is promoted to
        the default floating-point dtype.
    rhor0 : float or tensor
        Density at ``m == 8.2``.
    m0 : float or tensor
        Scale coordinate of the profile.
    alpha : float or tensor
        Einasto shape index.

    Returns
    -------
    torch.Tensor
        Density with the shape, dtype and device of ``m``.
    """
    m = torch.as_tensor(m)
    # The parameters are cast to m's dtype below; for an integer m that would truncate
    # them (e.g. alpha=0.7 -> 0), so promote m to a floating dtype first.
    if not torch.is_floating_point(m):
        m = m.to(torch.get_default_dtype())
    rhor0, m0, alpha = (torch.as_tensor(var, dtype=m.dtype, device=m.device)
                        for var in (rhor0, m0, alpha))
    # Normalisation so that the profile equals rhor0 at the reference radius 8.2.
    rho0 = rhor0 / (torch.exp(-(2 / alpha) * ((8.2 / m0) ** alpha - 1)))
    return rho0 * torch.exp(-(2 / alpha) * ((m / m0) ** alpha - 1))

# Fit the einasto profile defined above to the dark-matter particles. init_parms are presumably
# the starting values for (rhor0, m0, alpha) in einasto's argument order; plot=True asks for a diagnostic plot.
dm_pot = astro_dynamo.analytic_potentials.fit_potential_to_snap(dm,einasto,init_parms=[1e-3,8.0,0.7],plot=True)
# Move the fitted potential to the GPU, then (per the method name) tabulate its accelerations on a grid.
dm_pot_gpu = dm_pot.to('cuda')
dm_pot_gpu.grid_accelerations()

In [None]:
# Number of grid cells per axis for the stellar self-gravity grid: n=(n, n, nz).
n=256
nz=256
# Stellar self-gravity grid. NOTE(review): grid_edges looks like the half-extent per axis,
# which would make 20/n the in-plane cell width and the smoothing 0.2 of a cell -- confirm.
star_pot=ForceGrid(n=(n,n,nz),
                    grid_edges=torch.tensor([10.,10.,2.5],dtype=torch.float32),
                    smoothing=0.2*20/n)
# Deposit the stellar mass on the grid with the 'cic' scheme; masses are detached so the
# gridded density carries no autograd history.
_=star_pot.grid_data(stars.positions,weights=stars.masses.detach(),method='cic')
star_pot.grid_accelerations()
# The grid is filled and its accelerations computed before it is moved to the GPU.
star_pot_gpu=star_pot.to('cuda')

In [None]:
class SurfaceDensityTarget(nn.Module):
    """Surface density of a snapshot in equal-width annuli of cylindrical radius.

    ``r_range`` is split into ``r_bins`` annuli of width ``dr``. ``forward`` sums the
    particle masses falling in each annulus (differentiably with respect to the masses)
    and divides by the annulus area. Particles outside ``r_range`` are ignored.

    Parameters
    ----------
    r_range : (float, float)
        Inner and outer cylindrical radius of the binned region.
    r_bins : int
        Number of annuli.
    """

    def __init__(self, r_range=(0, 10), r_bins=50):
        super().__init__()
        self.dr = (r_range[1] - r_range[0]) / r_bins
        self.r_min = r_range[0]
        self.r_bins = r_bins
        redge = self.r_min + torch.arange(self.r_bins + 1) * self.dr
        # Annulus areas pi*(r_out^2 - r_in^2); registered as a buffer so .to(device) moves it.
        self.register_buffer('area', math.pi * (redge[1:] ** 2 - redge[:-1] ** 2))

    def forward(self, snap):
        """Return the surface density per annulus for ``snap`` (needs ``positions`` and ``masses``)."""
        r_cyl = (snap.positions[:, 0] ** 2 + snap.positions[:, 1] ** 2).sqrt()
        i = ((r_cyl - self.r_min) / self.dr).floor().type(torch.long)
        gd = (i >= 0) & (i < self.r_bins)
        masses = snap.masses[gd]
        # Dense, device-agnostic and differentiable scatter-sum; replaces the deprecated
        # (CPU-only, float32-only) torch.sparse.FloatTensor(...).to_dense() construction.
        mass_in_bin = torch.zeros(self.r_bins, dtype=masses.dtype,
                                  device=masses.device).index_add(0, i[gd], masses)
        surface_density = mass_in_bin / self.area
        return surface_density

    def extra_repr(self):
        return f'r_min={self.r_min}, r_max={self.r_min+self.dr*self.r_bins}, r_bins={self.r_bins}'

    @property
    def rmid(self):
        """Annulus centre radii, on the same device/dtype as the ``area`` buffer."""
        return self.r_min + self.dr / 2 + self.dr * torch.arange(self.r_bins, device=self.area.device,
                                                                 dtype=self.area.dtype)

    def evaluate_function(self, surface_density):
        """Evaluate a callable profile ``surface_density(r)`` at the annulus centres."""
        return surface_density(self.rmid)

    # Backward-compatible alias for the original (misspelled) method name.
    evalulate_function = evaluate_function

In [None]:
class Model(nn.Module):
    """Bundle a snapshot with the potentials it moves in and the targets it is fitted to.

    Calling the model evaluates every target on the snapshot; ``integrate`` advances the
    snapshot and then, optionally, partially re-grids the first potential.

    Parameters
    ----------
    snap : snapshot object
        Must provide ``leapfrog_steps``, ``positions`` and ``masses``.
    potentials : sequence of nn.Module
        Passed to ``snap.leapfrog_steps``; the first entry is re-gridded by ``integrate``.
    targets : sequence of nn.Module
        Each is called as ``target(snap)``.
    self_gravity_update : float or None
        Forwarded as ``fractional_update`` to ``potentials[0].grid_data``; ``None`` disables re-gridding.
    """

    def __init__(self, snap, potentials, targets, self_gravity_update=0.2):
        super().__init__()
        # Assignment order fixes the submodule (and hence parameter) order, so keep it.
        self.snap = snap
        self.targets = nn.ModuleList(targets)
        self.potentials = nn.ModuleList(potentials)
        self.self_gravity_update = self_gravity_update

    def forward(self):
        outputs = []
        for tgt in self.targets:
            outputs.append(tgt(self.snap))
        return outputs

    def integrate(self, steps=256):
        snap = self.snap
        snap.leapfrog_steps(potentials=self.potentials, steps=steps)
        if self.self_gravity_update is None:
            return
        # Masses are detached so the gridded potential carries no autograd history.
        self_gravity = self.potentials[0]
        self_gravity.grid_data(snap.positions, snap.masses.detach(),
                               fractional_update=self.self_gravity_update)

# Fit target: stellar surface density in 50 annuli between r=4 and r=9.
target = SurfaceDensityTarget(r_range=(4.,9.),r_bins=50)
# The stellar grid must be potentials[0]: Model.integrate re-grids the snapshot into potentials[0].
model = Model(stars,[star_pot_gpu,dm_pot_gpu],[target]).to('cuda')

In [None]:
# Anchor the target profile to the model's own surface density in a thin ring
# fiducial_r +/- dr, measured with a single-bin SurfaceDensityTarget.
fiducial_r,dr = 4.,0.1
fiducial_sig = SurfaceDensityTarget(r_range=(fiducial_r-dr,fiducial_r+dr),r_bins=1).to('cuda')(stars).item()
# Target: exponential profile with scale length 2.4 that equals fiducial_sig at r = fiducial_r.
surface_density_func=lambda x: fiducial_sig*torch.exp(-(x-fiducial_r)/2.4)
# Evaluate the target profile at the annulus centres (rmid) of the fitting target.
target_surface_density = target.evalulate_function(surface_density_func)

In [None]:
def plot_radialprofile(ax,model,target,vmin=1e-5,vmax=1):
    """Plot the model's surface-density profile against a target profile on a log y-axis.

    The model profile is measured with a default-constructed ``SurfaceDensityTarget``.
    The target is plotted at the annulus centres of ``model.targets[0]``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    model : Model
        Its ``snap`` is measured; ``targets[0].rmid`` gives the radii of ``target``.
    target : torch.Tensor
        Target surface density, one value per annulus of ``model.targets[0]``.
    vmin, vmax : float
        y-axis limits.
    """
    device = model.snap.masses.device
    surface_density_full = SurfaceDensityTarget().to(device)
    model_profile = surface_density_full(model.snap).detach().cpu()
    ax.semilogy(surface_density_full.rmid.cpu(),model_profile,label='Model')
    # detach() so a target that still tracks gradients can be plotted too.
    ax.semilogy(model.targets[0].rmid.cpu(),torch.as_tensor(target).detach().cpu(),label='Target')
    ax.set_ylim(vmin,vmax)
    ax.set_xlabel('r')
    # Raw string: '\S' in a plain string literal is an invalid escape sequence.
    ax.set_ylabel(r'$\Sigma$')
    ax.legend()

def plot_snap_projections(axs,snap,plotmax=10.,vmin=1e-5,vmax=1e-2,particle_plot_i=None):
    """Draw mass-weighted hexbin maps of the x-y, x-z and y-z projections of ``snap``.

    Parameters
    ----------
    axs : sequence of three matplotlib Axes
        Receive the (x, y), (x, z) and (y, z) projections, in that order.
    snap : snapshot object
        Must provide ``x``, ``y``, ``z`` and ``masses`` tensors.
    plotmax : float
        Half-width of the square region shown in every panel.
    vmin, vmax : float
        Limits of the logarithmic colour scale (summed mass per hexagon). These were
        previously ignored in favour of hard-coded 1e-6 / 1e-2.
    particle_plot_i : index tensor/array or None
        Particles to overplot as black dots. It may live on any device; it is moved to
        the CPU because the coordinates are indexed as CPU tensors.
    """
    x=snap.x.detach().cpu()
    y=snap.y.detach().cpu()
    z=snap.z.detach().cpu()
    m=snap.masses.detach().cpu()
    if particle_plot_i is not None:
        # Indexing a CPU tensor with CUDA indices raises a device-mismatch error.
        particle_plot_i = torch.as_tensor(particle_plot_i).cpu()
    projections = ((x,y),(x,z),(y,z))
    projection_labels = (('x','y'),('x','z'),('y','z'))

    for ax,(horizontal,vertical),(h_label,v_label) in zip(axs,projections,projection_labels):
        # Hexagon colour = summed particle mass (C=m reduced with np.sum) on a log scale.
        ax.hexbin(horizontal,vertical,C=m,bins='log',
                  extent=(-plotmax,plotmax,-plotmax,plotmax),reduce_C_function=np.sum,
                  vmin=vmin,vmax=vmax,cmap='nipy_spectral')
        ax.set_xlabel(h_label)
        ax.set_ylabel(v_label)
        if particle_plot_i is not None:
            ax.plot(horizontal[particle_plot_i],vertical[particle_plot_i],'ko',markersize=4)
        ax.set_xlim(-plotmax,plotmax)
        ax.set_ylim(-plotmax,plotmax)

def plot_fit_step(model,step,prefix='fit_step',particle_plot_i=None,target_profile=None,plot_dir='plots'):
    """Save (and show) a 2x2 diagnostic figure for one fitting step.

    The panels are the x-y (top-left), x-z (bottom-left) and y-z (top-right) projections
    of ``model.snap``, plus the radial surface-density profile (bottom-right).

    Parameters
    ----------
    model : Model
        Model whose snapshot is plotted.
    step : int
        Step number, zero-padded to 5 digits in the file name.
    prefix : str
        File name is ``{plot_dir}/{prefix}_{step:05}.png``.
    particle_plot_i : index tensor or None
        Particles to highlight in the projection panels.
    target_profile : tensor or None
        Target surface density; defaults to the notebook-global ``target_surface_density``.
    plot_dir : str
        Output directory; it is created if it does not exist.
    """
    if target_profile is None:
        target_profile = target_surface_density
    os.makedirs(plot_dir, exist_ok=True)
    f,axs = plt.subplots(2,2,figsize=(9,9))
    plot_snap_projections((axs[0,0],axs[1,0],axs[0,1]),model.snap,
                          particle_plot_i=particle_plot_i)
    plot_radialprofile(axs[1,1],model,target_profile)
    f.tight_layout()
    f.savefig(os.path.join(plot_dir, f'{prefix}_{step:05}.png'),dpi=150)
    # NOTE(review): on a non-GUI (e.g. inline) backend Figure.show() may only emit a warning,
    # and the figure is closed right afterwards -- confirm the figure is actually displayed.
    f.show()
    plt.close(f)

# Highlight up to 7 particles with 3 < rcyl < 5 in the projection panels. The indices are
# flattened and moved to the CPU because the plots index CPU copies of the coordinates.
particle_plot_i=((model.snap.rcyl>3) & (model.snap.rcyl<5)).nonzero().squeeze(1)[0:7].cpu()
plot_fit_step(model,0,particle_plot_i=particle_plot_i)

In [None]:
# Distribution of the stellar log-masses, with a log-scaled count axis.
fig, ax = plt.subplots()
ax.hist(stars.logmasses.detach().cpu())
ax.set_yscale('log')
ax.set_xlabel('logmasses')
ax.set_ylabel('Number of particles');

In [None]:
# Display the model's module structure (nn.Module repr).
model

In [None]:
# Optimise model.parameters() (presumably the snapshot's particle masses -- confirm) so that the
# model's surface density matches target_surface_density.
optimizer = torch.optim.SGD(model.parameters(), lr=2e4, momentum=0.7, nesterov=True)
epochs=100
lossvec = torch.zeros((epochs,))
# Initial integration (800*8 leapfrog steps) before the first loss evaluation.
model.integrate(steps=800*8)

for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()   # zero the gradient buffers
    # Pearson chi-square-like loss: squared residual divided by the target in each annulus.
    loss = ((target_surface_density - model()[0])**2/target_surface_density).sum()
    print(f'loss {loss.item()}')
    loss.backward()
    optimizer.step()    # Does the update
    model.integrate()
    lossvec[epoch]=loss.detach()
    if epoch % 5 == 0:
        plot_fit_step(model,epoch,prefix='refactor',particle_plot_i=particle_plot_i)
    if epoch % 10 == 9:
        # Recompute the self-gravity accelerations, then rescale each particle's velocity by the
        # ratio of new to old circular-velocity estimates sqrt(-a . x), presumably so the orbits
        # stay close to equilibrium in the updated potential.
        print('Recomputing Potential')
        old_accelerations = model.snap.get_accelerations(model.potentials,model.snap.positions)
        old_vc=torch.sum(-old_accelerations*model.snap.positions,dim=-1).sqrt()
        model.potentials[0].grid_accelerations()
        new_accelerations = model.snap.get_accelerations(model.potentials,model.snap.positions)
        new_vc=torch.sum(-new_accelerations*model.snap.positions,dim=-1).sqrt()
        # Skip particles whose estimate is undefined (-a.x < 0 gives NaN) or whose old estimate
        # is zero, which would make the ratio infinite.
        gd = torch.isfinite(old_vc) & torch.isfinite(new_vc) & (old_vc > 0)
        model.snap.velocities[gd,:]*=(new_vc/old_vc)[gd,None]


In [None]:
# Epochs at whose end the training loop recomputes the self-gravity potential (epoch % 10 == 9).
potenial_updates=np.arange(9,epochs,10)

In [None]:
# Loss history of the fit, with the epochs that ended in a potential recomputation marked.
fig, ax = plt.subplots()
loss_history = lossvec.cpu().numpy()
update_epochs = np.asarray(potenial_updates)
update_epochs = update_epochs[update_epochs < len(loss_history)]
ax.semilogy(loss_history,label='Loss')
ax.semilogy(update_epochs,loss_history[update_epochs],'o',label='Potential Update')
ax.legend()
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
# `prefix` is not defined at notebook level; use the same prefix as the training-loop plots.
fig.savefig('plots/refactor_loss.png')

In [None]:
# Inspect the stellar force grid: print its parameters first, then its buffers.
star_grid = model.potentials[0]
for named_tensors in (star_grid.named_parameters(), star_grid.named_buffers()):
    for name, value in named_tensors:
        print(f'{name}: {value}')