In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn

from astro_dynamo.grid import Grid, ForceGrid
from astro_dynamo.snap import SnapShot
import astro_dynamo.analysesnap
import astro_dynamo.snap
import astro_dynamo.analytic_potentials
import astro_dynamo.target
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm.notebook import tqdm
import matplotlib.colors as colors

from torch.utils.tensorboard import SummaryWriter

In [None]:
# Load the input N-body model. read_nemo_snapshot returns a pair, unpacked here
# as the stellar snapshot and the dark-matter snapshot.
stars, dm = astro_dynamo.snap.read_nemo_snapshot('../input_model_construction/outM80')
# Replace the stellar snapshot by the output of symmetrize_snap.
stars = astro_dynamo.snap.symmetrize_snap(stars)
# The return value is discarded, so align_bar presumably modifies `stars` in place -- TODO confirm.
astro_dynamo.analysesnap.align_bar(stars)
# Presumably the bar pattern speed and its uncertainty; only `omega` is used below.
omega, omegaerr = astro_dynamo.analysesnap.patternspeed(stars)
# Store the pattern speed on the snapshot as a one-element float32 tensor.
stars.omega = torch.Tensor([omega]).type(torch.float32)
print(stars.omega)

In [None]:
def einasto(m,rhor0,m0,alpha):
    """Einasto-style density profile normalised at the fixed radius 8.2.

    Evaluates rho0 * exp(-(2/alpha) * ((m/m0)**alpha - 1)), where rho0 is chosen
    such that the profile equals ``rhor0`` at m = 8.2.

    Args:
        m: radius (scalar or tensor-like); the result has its dtype and device.
        rhor0: density at m = 8.2.
        m0: scale radius.
        alpha: shape exponent.

    Returns:
        Tensor of densities with the same shape as ``m``.
    """
    m = torch.as_tensor(m)
    # Bring the three parameters onto the dtype/device of the radii.
    rhor0, m0, alpha = (
        torch.as_tensor(parm, dtype=m.dtype, device=m.device)
        for parm in (rhor0, m0, alpha)
    )
    two_over_alpha = 2 / alpha
    # Value of the un-normalised profile at the reference radius 8.2.
    profile_at_ref = torch.exp(-two_over_alpha * ((8.2 / m0) ** alpha - 1))
    rho0 = rhor0 / profile_at_ref
    profile = torch.exp(-two_over_alpha * ((m / m0) ** alpha - 1))
    return rho0 * profile

# Fit the einasto profile to the dark-matter snapshot. init_parms has three entries,
# presumably the starting values of (rhor0, m0, alpha) in that order -- TODO confirm.
dm_pot = astro_dynamo.analytic_potentials.fit_potential_to_snap(dm,einasto,init_parms=[1e-3,8.0,0.7],plot=False)
# The result of .to() is not reassigned: this relies on an in-place move. dm_pot is
# later placed in an nn.ModuleList, so it is an nn.Module.
dm_pot.to('cuda')
# Compute the accelerations on the potential's grid (done after the move to the GPU).
dm_pot.grid_accelerations()

In [None]:
# Number of grid cells: n along the first two axes, nz along the third.
n=256
nz=256
# Force grid for the stellar self-gravity. The smoothing shrinks as 1/n;
# the factor 20 is presumably the full x/y box width (+/-10) -- TODO confirm.
star_pot=ForceGrid(n=(n,n,nz),grid_edges=torch.tensor([10.,10.,2.5]),smoothing=0.2*20/n)
# Deposit the stars on the grid with the 'cic' method; masses are detached so no
# autograd graph is attached to the stored density.
star_pot.update_density(stars.positions,weights=stars.masses.detach(),method='cic')
star_pot.grid_accelerations()

# Density and accelerations are computed before the move to the GPU.
star_pot.to('cuda')

In [None]:
# Load the snapshot that provides the 3D density target; target_dm is not used
# in the cells visible here.
target_stars, target_dm = astro_dynamo.snap.read_nemo_snapshot('../input_model_construction/outM85')
# Same bar alignment as applied to the model stars (return value discarded).
astro_dynamo.analysesnap.align_bar(target_stars)
# The target is fixed data: switch off gradient tracking for it.
target_stars.requires_grad_(False)
# Result not reassigned: relies on .to() moving the snapshot in place -- TODO confirm.
target_stars.to('cuda')

In [None]:
class SurfaceDensity(nn.Module):
    """Surface density of a snapshot in annuli of cylindrical radius.

    The radial range ``r_range`` is split into ``r_bins`` equal-width annuli.
    Calling the module on a snapshot sums the particle masses falling in each
    annulus and divides by the annulus area.

    Args:
        r_range: (r_min, r_max) of the binned range.
        r_bins: number of equal-width radial bins.
    """
    def __init__(self, r_range=(0, 10), r_bins=50):
        super(SurfaceDensity, self).__init__()
        self.dr = (r_range[1] - r_range[0]) / r_bins
        self.r_min = r_range[0]
        self.r_bins = r_bins
        redge = self.r_min + torch.arange(self.r_bins+1)*self.dr
        # Registered as a buffer so that .to(device) moves it with the module.
        self.register_buffer('area',math.pi * (redge[1:] ** 2 - redge[:-1] ** 2))

    def forward(self,snap):
        """Return the surface density per radial bin, a tensor of shape (r_bins,).

        ``snap`` must provide ``positions`` (columns 0 and 1 are used as x, y)
        and ``masses`` (one entry per particle). Particles outside the radial
        range are ignored. The result is differentiable with respect to
        ``snap.masses``.
        """
        r_cyl = (snap.positions[:,0]**2 + snap.positions[:,1]**2).sqrt()
        i = ((r_cyl - self.r_min) / self.dr).floor().type(torch.long)
        gd = (i >= 0) & (i < self.r_bins)
        masses = snap.masses[gd]
        # Dense scatter-add of the masses into their bins. This replaces the
        # deprecated torch.sparse.FloatTensor constructor and works for any
        # floating dtype and device of the masses.
        mass_in_bin = torch.zeros(self.r_bins, dtype=masses.dtype,
                                  device=masses.device).index_add(0, i[gd], masses)
        surface_density = mass_in_bin / self.area
        return surface_density

    def extra_repr(self):
        return f'r_min={self.r_min}, r_max={self.r_min+self.dr*self.r_bins}, r_bins={self.r_bins}'

    @property
    def rmid(self):
        """Mid-point radius of every bin, on the device/dtype of the area buffer."""
        return self.r_min+self.dr/2 + self.dr*torch.arange(self.r_bins,device=self.area.device,dtype=self.area.dtype)

    def evaluate_function(self,surface_density):
        """Evaluate the callable ``surface_density`` at the bin mid-points."""
        return surface_density(self.rmid)

    # Backward-compatible alias for the original (misspelled) method name.
    evalulate_function = evaluate_function

In [None]:
class DynamicalModel(nn.Module):
    """DynamicalModel class. This contains a snapshot of the particles, the potentials
    in which they move, and the targets to which the model should be fitted.

    Attributes:
        snap:
            Should be a SnapShot whose masses will be optimised

        potentials:
            The potentials add. If self gravity is not required set self_gravity_update to None.
            If self gravity is required then the potential of the snapshot should be in potentials[0]
            and self_gravity_update represents how much to update the running average of the density on
            each iteration. Default value is 0.2 which is then an exponential average with timescale
            5 snapshots(=1/0.2).

        targets:
            A list of targets. Running
                model = DynamicalModel(snap, potentials, targets)
                current_target_list = model()
            will provide a list of these targets evaluated with the present model. These are then
            typically combined to a loss that pytorch can optimise.

    Methods:
        forward()
            Computes the targets by evaluating them on the current snapshot. Can also be called as DynamicalModel()
        integrate(steps=256)
            Integrates the model forward by steps. Updates the density associated with potentials[0]
        update_potential()
            Recomputes the accelerations from potentials[0]. Adjusts each particle's velocity by a factor vc_new/vc_old
        resample()
            Resamples the snapshot to equal mass particles.
    """
    def __init__(self, snap, potentials, targets, self_gravity_update=0.2):
        super(DynamicalModel, self).__init__()
        self.snap = snap
        self.targets = nn.ModuleList(targets)
        self.potentials = nn.ModuleList(potentials)
        self.self_gravity_update = self_gravity_update

    def forward(self):
        """Return the list of targets evaluated on the current snapshot."""
        return [target(self.snap) for target in self.targets]

    def integrate(self,steps=256):
        """Advance the snapshot by ``steps`` leapfrog steps without tracking gradients.

        If ``self_gravity_update`` is not None, the density of potentials[0] is
        then updated from the new particle positions with that fractional weight.
        """
        with torch.no_grad():
            self.snap.leapfrog_steps(potentials=self.potentials, steps=steps)
            if self.self_gravity_update is not None:
                self.potentials[0].update_density(self.snap.positions,self.snap.masses.detach(),
                                             fractional_update=self.self_gravity_update)

    def _vc(self):
        """Return sqrt(-a.r) per particle for the current accelerations (NaN where a.r > 0)."""
        accelerations = self.snap.get_accelerations(self.potentials,self.snap.positions)
        return torch.sum(-accelerations*self.snap.positions,dim=-1).sqrt()

    def update_potential(self, update_velocities=True):
        """Recompute the accelerations of potentials[0].

        Args:
            update_velocities: if True (default) every particle velocity is
                rescaled by vc_new/vc_old, where vc = sqrt(-a.r) is evaluated
                before and after the update. Particles for which either value
                is not finite are left untouched. If False only the potential
                is updated.
        """
        with torch.no_grad():
            if not update_velocities:
                self.potentials[0].grid_accelerations()
                return
            old_vc = self._vc()
            self.potentials[0].grid_accelerations()
            new_vc = self._vc()
            gd = torch.isfinite(old_vc) & torch.isfinite(new_vc)
            self.snap.velocities[gd,:]*=(new_vc/old_vc)[gd,None]

    def resample(self, velocity_perturbation=0.01):
        """Resample the model to equal mass particles.

        Note that the snapshot changes and so the parameters of
        the model also change in a way that any optimiser that keeps parameter-by-parameter information e.g.
        gradients must also be updated."""
        # NOTE(review): velocity_perturbation is accepted but not forwarded, as in the
        # original code; confirm whether snap.resample takes such an argument.
        with torch.no_grad():
            # Use this model's own potentials rather than a global named `model`.
            self.snap = self.snap.resample(self.potentials)

In [None]:
# Construct the objects that evaluate our targets on a snapshot: a radial
# surface-density profile over 4 <= r < 9 in 50 bins, and a 20x20x20 Grid.
surface_density_obj = SurfaceDensity(r_range=(4.,9.),r_bins=50).to('cuda')
grid_obj = Grid(grid_edges=torch.tensor((4.,4.,1)),n=(20,20,20)).to('cuda')

# Construct our dynamical model. star_pot is listed first because DynamicalModel
# treats potentials[0] as the self-gravity potential of the snapshot.
model = DynamicalModel(stars,[star_pot,dm_pot],[surface_density_obj,grid_obj]).to('cuda')

# Compute the numerical values of our targets.
# Surface density target: an exponential in r with scale length 2.4, normalised to the
# model's own surface density measured in the single annulus fiducial_r +/- dr.
fiducial_r,dr = 4.,0.1
fiducial_sig = SurfaceDensity(r_range=(fiducial_r-dr,fiducial_r+dr),r_bins=1).to('cuda')(stars).item()
surface_density_func=lambda x: fiducial_sig*torch.exp(-(x-fiducial_r)/2.4)

# Evaluate the target profile at the bin mid-points; assign it a 5% error.
target_surface_density = surface_density_obj.evalulate_function(surface_density_func)
target_surface_density_sig = 0.05*target_surface_density

# 3D density target: grid the target stars ('cic' method), then average the grid
# with its mirror image along each of the three axes in turn.
target_3d_density = grid_obj.grid_data(target_stars.positions,target_stars.masses,method='cic')
symmetrize_grid = lambda x,dim : (x+x.flip(dims=[dim]))/2
target_3d_density=symmetrize_grid(symmetrize_grid(symmetrize_grid(target_3d_density,dim=2),dim=1),dim=0)

# 1% error per cell, floored at the mass of the first target particle.
target_3d_density_sig = 0.01*target_3d_density
min_sig = target_stars.masses[0].item()
target_3d_density_sig[target_3d_density_sig<min_sig] = min_sig

# Give the cells on all six outer faces of the grid a huge error so that they
# contribute essentially nothing to a chi^2 computed with these errors.
big_sig = 1e5*target_3d_density_sig.max()
target_3d_density_sig[:,:,0]=target_3d_density_sig[:,:,-1]=big_sig
target_3d_density_sig[:,0,:]=target_3d_density_sig[:,-1,:]=big_sig
target_3d_density_sig[0,:,:]=target_3d_density_sig[-1,:,:]=big_sig

In [None]:
import matplotlib.colors as colors
def plot_3d_density(model):
    """Plot slices of the model and target 3D densities and of their residual.

    Draws a 3x11 grid of axes: row 0 is the model density, row 1 the target
    density and row 2 (model-target)/sigma. The first ten columns show the
    slices with index 0..9 along the last grid axis; the last column holds the
    colorbars.

    Relies on the notebook globals ``grid_obj``, ``target_3d_density``,
    ``target_3d_density_sig`` and ``symmetrize_grid``, and on the 3D density
    being the second target of ``model`` (``model()[1]``).

    Args:
        model: the DynamicalModel to evaluate.

    Returns:
        The matplotlib Figure.
    """
    f,axs = plt.subplots(3,11,sharex=True,sharey=True,figsize=(14,5))
    # Image extent in x and y, shrunk by one cell on each side to match the
    # [1:-1,1:-1] crop applied to the data below.
    extent = torch.stack((grid_obj.min[0]+grid_obj.dx[0],
                          grid_obj.max[0]-grid_obj.dx[0],
                          grid_obj.min[1]+grid_obj.dx[1],
                          grid_obj.max[1]-grid_obj.dx[1])).cpu().numpy()
    # Drop the outermost x/y cells and symmetrize along the last axis.
    plot_density = lambda x : symmetrize_grid(x[1:-1,1:-1,:],dim=2).detach().cpu()
    plot_target_density = plot_density(target_3d_density)
    plot_model_density = plot_density(model()[1])
    plot_density_sig = plot_density(target_3d_density_sig)

    vmin = plot_target_density[plot_target_density>0].min()*20
    vmax = plot_target_density.max()
    chi = (plot_model_density-plot_target_density)/plot_density_sig
    # Symmetric colour range for the residuals.
    chi_min, chi_max = -chi.abs().max(), chi.abs().max()
    for i,ax in enumerate(axs[:,:-1].T):
        im0 = ax[0].imshow(plot_model_density[:,:,i].t(),norm=colors.LogNorm(vmin=vmin,vmax=vmax), extent=extent)
        im1 = ax[1].imshow(plot_target_density[:,:,i].t(),norm=colors.LogNorm(vmin=vmin,vmax=vmax), extent=extent)
        im2 = ax[2].imshow(chi[:,:,i].t(),norm=colors.SymLogNorm(vmin=chi_min,vmax=chi_max,linthresh=10),
                     extent=extent,cmap='RdBu_r')
    # The last column only hosts the colorbars, one per row.
    for row,im in enumerate((im0,im1,im2)):
        axs[row,-1].axis('off')
        f.colorbar(im, ax=axs[row,-1],fraction=0.9, extend='both',shrink=0.8)
    axs[0,4].set_title('Model')
    axs[1,4].set_title('Target')
    # Raw string: '\s' is not a valid escape sequence in a normal string literal.
    axs[2,4].set_title(r'(Model-Target)/$\sigma$')
    f.subplots_adjust(hspace=0,wspace=0)
    return f
f=plot_3d_density(model)

In [None]:
def plot_radialprofile(ax,model,target,vmin=1e-5,vmax=1):
    """Plot the model's radial surface-density profile together with the target.

    The model curve uses a default-range ``SurfaceDensity()`` evaluated on
    ``model.snap``; the target curve is drawn at the bin mid-points of
    ``model.targets[0]``.

    Args:
        ax: matplotlib Axes to draw on.
        model: DynamicalModel whose first target provides ``rmid``.
        target: target surface-density values, one per bin of ``model.targets[0]``.
        vmin, vmax: lower and upper limits of the y axis.
    """
    device = model.snap.masses.device
    surface_density_full = SurfaceDensity().to(device)
    ax.semilogy(surface_density_full.rmid.cpu(),surface_density_full(model.snap).detach().cpu(),label='Model')
    ax.semilogy(model.targets[0].rmid.cpu(),target.cpu(),label='Target')
    ax.set_ylim(vmin,vmax)
    ax.set_xlabel('r')
    # Raw string: '\S' is not a valid escape sequence in a normal string literal.
    ax.set_ylabel(r'$\Sigma$')
    ax.legend()

def plot_snap_projections(axs,snap,plotmax=10.,vmin=1e-6,vmax=1e-2,particle_plot_i=None):
    """Plot mass-weighted x-y, x-z and y-z projections of a snapshot as hexbin maps.

    Args:
        axs: three matplotlib Axes, used for the (x,y), (x,z) and (y,z) projections.
        snap: snapshot providing ``x``, ``y``, ``z`` and ``masses`` tensors.
        plotmax: half-width of the plotted region on both axes.
        vmin, vmax: colour-scale limits passed to hexbin. The defaults equal the
            values that were previously hard-coded (1e-6, 1e-2), so calls that do
            not pass them render as before; explicit values are now honoured.
        particle_plot_i: optional index (or indices) of particles to mark with
            black dots.
    """
    x=snap.x.cpu()
    y=snap.y.cpu()
    z=snap.z.cpu()
    m=snap.masses.detach().cpu()
    projections = ((x,y),(x,z),(y,z))
    projection_labels = (('x','y'),('x','z'),('y','z'))

    for ax,projection,projection_label in zip(axs,projections,projection_labels):
        # Sum the particle masses per hexagon. The colormap is given by name
        # because plt.cm.get_cmap is no longer available in recent Matplotlib.
        ax.hexbin(projection[0],projection[1],C=m,bins='log',
                   extent=(-plotmax,plotmax,-plotmax,plotmax),reduce_C_function=np.sum,
                     vmin=vmin,vmax=vmax,cmap='nipy_spectral')
        ax.set_xlabel(projection_label[0])
        ax.set_ylabel(projection_label[1])
        if particle_plot_i is not None:
            ax.plot(projection[0][particle_plot_i],projection[1][particle_plot_i],'ko',markersize=4)
        ax.set_xlim(-plotmax,plotmax)
        ax.set_ylim(-plotmax,plotmax)

def plot_fit_step(model,particle_plot_i=None):
    """Build the 2x2 summary figure for the current state of the fit.

    Three panels show the projections of ``model.snap`` (top-left, bottom-left,
    top-right); the bottom-right panel shows the radial surface-density profile
    against the notebook global ``target_surface_density``.

    Args:
        model: the DynamicalModel to plot.
        particle_plot_i: forwarded to plot_snap_projections to mark particles.

    Returns:
        The matplotlib Figure.
    """
    fig, panels = plt.subplots(2, 2, figsize=(9, 9))
    projection_panels = (panels[0, 0], panels[1, 0], panels[0, 1])
    plot_snap_projections(projection_panels, model.snap, particle_plot_i=particle_plot_i)
    plot_radialprofile(panels[1, 1], model, target_surface_density)
    fig.tight_layout()
    return fig

f=plot_fit_step(model)

In [None]:
def chi2_loss(input, target, error, reduced=True):
    """Chi-squared between ``input`` and ``target`` given per-element errors.

    Args:
        input: model values.
        target: target values.
        error: one-sigma error of each target value.
        reduced: if True (default) divide the sum by the number of elements
            of ``target``.

    Returns:
        Scalar tensor: sum((input-target)**2/error**2), optionally divided by
        ``target.nelement()``.
    """
    squared_residual = (input - target) ** 2
    total = torch.sum(squared_residual / error ** 2)
    if not reduced:
        return total
    return total / target.nelement()

In [None]:
# Before running it is worth checking that the loss is set up reasonably, i.e. that it
# and its gradient give finite numbers.
# The model returns its targets in construction order: surface density, then 3D density.
model_surface_density, model_3d_density = model()
loss_surf = chi2_loss(model_surface_density,target_surface_density,target_surface_density_sig)
loss_3d = chi2_loss(model_3d_density,target_3d_density,target_3d_density_sig)
loss = loss_surf + loss_3d
print(f'loss {loss.item():.5f} loss_surf {loss_surf.item():.5f} loss_3d {loss_3d.item():.5f}')
loss.backward()
print(f'min grad: {model.snap.logmasses.grad.min()}, max grad {model.snap.logmasses.grad.max()}')
# Clear the gradient accumulated by this check; assigning to `_` suppresses the cell output.
_=model.snap.logmasses.grad.zero_()

In [None]:
# TensorBoard writer for this run.
writer = SummaryWriter(comment='LessResamples')

# NOTE(review): `prefix` is not used anywhere in the cells visible here.
prefix = 'three_d'
# Number of epochs, followed by the intervals (in epochs) between potential updates,
# resamplings and figure logging.
epochs,potential_update,resample,plot_figures=400,25,100,5

optim_kws={'lr':0.5, 'momentum':0.8, 'nesterov':True}
optimizer = torch.optim.SGD(model.parameters(), **optim_kws)

# Record the run configuration as TensorBoard hyper-parameters (no metrics attached).
h_parms={'optimizer':f'{type(optimizer)}','epochs':epochs,'potential_update':potential_update,'resample':resample}
h_parms.update(optim_kws)
writer.add_hparams(h_parms,{})

# Per-epoch record of (total, surface-density, 3D-density) loss.
lossvec = torch.zeros((epochs,3))
# Integrate for 800*8 steps before fitting starts (the per-epoch default is 256).
model.integrate(steps=800*8)

for epoch in tqdm(range(epochs)):
    # One gradient step on the parameters returned by model.parameters().
    optimizer.zero_grad()
    model_surface_density, model_3d_density = model()
    loss_surf = chi2_loss(model_surface_density,target_surface_density,target_surface_density_sig)
    loss_3d = chi2_loss(model_3d_density,target_3d_density,target_3d_density_sig)
    loss = loss_surf + loss_3d
    writer.add_scalar('Loss/Total', loss.item(), epoch)
    writer.add_scalar('Loss/Surface Density', loss_surf.item(), epoch)
    writer.add_scalar('Loss/3D Density', loss_3d.item(), epoch)
    loss.backward()
    optimizer.step()    
    
    with torch.no_grad():
        # Advance the particles (and the running-average density) with the updated masses.
        model.integrate()
        lossvec[epoch,:]=torch.stack((loss.detach(),loss_surf.detach(),loss_3d.detach()))

        # Log diagnostics every `plot_figures` epochs, including epoch 0.
        if epoch % plot_figures == 0:
            writer.add_histogram('Log10Masses',model.snap.masses.detach().log10(), epoch)
            writer.add_figure('Image',plot_fit_step(model), epoch)
            writer.add_figure('3DSlices',plot_3d_density(model), epoch)

        # Every `potential_update` epochs (except epoch 0) recompute the accelerations.
        # The optimizer is rebuilt, which discards its momentum state.
        if (epoch % potential_update == 0) & (epoch != 0):
            model.update_potential()
            optimizer = torch.optim.SGD(model.parameters(), **optim_kws)


        # Every `resample` epochs (except epoch 0) resample the particles. model.snap is
        # replaced, so the optimizer must be rebuilt on the new parameters.
        if (epoch % resample == 0) & (epoch != 0):
            print('Resampling')
            model.resample()
            optimizer = torch.optim.SGD(model.parameters(), **optim_kws)
    writer.flush()
writer.close()