In [None]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import matplotlib.pyplot as plt
print(torch.__version__)
from pydynmod.grid import Grid, ForceGrid
from pydynmod.snap import SnapShot, ParticleType
from tqdm import tqdm_notebook
#import mwtools
#import mwtools.nemo

import pydynmod.analysesnap 
import math
%aimport -math,torch,numpy,matplotlib.pyplot,sys
potential=None

In [None]:
snap = SnapShot('../inputmodels/M85_0', omega=1.)
# Default every particle to a star, then flag particles whose original type
# code is 0 as dark matter.
types = torch.full((snap.n,), ParticleType.Star, dtype=torch.uint8)
types[snap.particletype == 0] = ParticleType.DarkMatter
snap.particletype = types
# Pattern speed estimated from the stellar component; the error estimate is
# kept but not used further.
omega, omegaerr = pydynmod.analysesnap.patternspeed(snap.stars)
snap.omega = torch.Tensor([omega]).to(torch.float32)
print(snap.omega)

In [None]:
# Release any previously built grid first so its memory can be reclaimed
# before the new (large) grid is allocated.
potential = None
n, nz = 512, 512
# NOTE(review): gridedges presumably gives the half-width of the box (the
# plots use +/-10 extents), which would make 0.5*20/n half a cell width —
# confirm against ForceGrid.
potential = ForceGrid(
    n=(n, n, nz),
    gridedges=torch.tensor([10., 10., 10.], dtype=torch.float32),
    smoothing=0.5 * 20 / n,
)
# Deposit all particle masses with cloud-in-cell, then precompute accelerations.
_ = potential.griddata(snap.positions, weights=snap.masses, method='cic')
potential.grid_accelerations()

In [None]:
gpusnap = snap.to('cuda')
# NOTE(review): the first argument (3.) is passed straight through to
# corotating_frame — presumably a time; confirm.
gpusnap.positions = gpusnap.corotating_frame(3., gpusnap.omega, gpusnap.positions)
plt.ion()

# Left panel: x-y (face-on); right panel: x-z (edge-on). Mass-weighted, log colour scale.
f, ax = plt.subplots(1, 2, sharex=True, sharey=True)
for panel, vertical in zip(ax, ('y', 'z')):
    panel.hexbin(gpusnap.stars.x.cpu(), getattr(gpusnap.stars, vertical).cpu(),
                 C=gpusnap.stars.masses.cpu(), bins='log',
                 extent=(-10, 10, -10, 10), reduce_C_function=np.sum)
for panel in ax:
    panel.set_aspect('equal', 'box')
    panel.set(xlim=(-10, 10), ylim=(-10, 10))


In [None]:
# Evolve the stars in the frozen potential on the GPU and write one PNG frame
# (x-y and x-z projections) after every block of leapfrog steps.
device='cuda'
gpusnap=snap.to(device)
gpupotential=potential.to(device)
#gpusnap.omega=torch.zeros_like(gpusnap.omega)
print(f'Using pattern speed {gpusnap.omega[0]:.4f}')

plotmax=10.
extent=(-plotmax,plotmax,-plotmax,plotmax)
# Only the length of tvec is used: it sets the number of frames.
tvec = torch.linspace(0.,100,101,device=device)
plt.ioff()  # frames are written to disk only, not displayed
try:
    for i in tqdm_notebook(range(len(tvec))):
        gpusnap.stars.leapfrog_steps(potential=gpupotential, steps=16)
        x=gpusnap.stars.x.cpu()
        y=gpusnap.stars.y.cpu()
        z=gpusnap.stars.z.cpu()
        m=gpusnap.stars.masses.cpu()
        f,axs = plt.subplots(1,2,sharex=True,sharey=True)
        try:
            for ax,vertical in zip(axs,(y,z)):
                # Fixed vmin/vmax keep the colour scale identical across frames.
                ax.hexbin(x,vertical,C=m,bins='log',extent=extent,
                          reduce_C_function=np.sum,vmin=1e-6,vmax=1e-2)
                ax.set_aspect('equal', 'box')
                ax.set(xlim=(-plotmax, plotmax), ylim=(-plotmax, plotmax))
            f.savefig(f'fixvelrot_frame{i:04}.png')
        finally:
            # Close even if plotting/saving fails, so figures do not pile up.
            plt.close(f)
finally:
    # Restore interactive mode even when the loop is interrupted or raises.
    plt.ion()

In [None]:
import copy
device='cpu'
maxt=200   # end value of the nominal tvec axis (not used by the integrator)
nt=101     # number of stored outputs
nparticles=100
plt.ion()
# Up to `nparticles` particles from the 7 < r < 8 shell, deep-copied so the
# integration does not modify `snap`.
smallsnap = copy.deepcopy(snap[(snap.r>7) & (snap.r<8)][0:nparticles]).to(device)
# The shell may contain fewer particles than requested; size buffers to match.
nparticles = smallsnap.positions.shape[0]
# NOTE(review): tvec is a nominal, evenly spaced axis. Each output below is a
# fixed number of leapfrog steps with per-particle dt, so the time actually
# reached is tracked separately in particle_times.
tvec = torch.linspace(0.,maxt,nt,device=device)
positions = torch.zeros((nt,nparticles,3),device=device)
velocities = torch.zeros((nt,nparticles,3),device=device)
dt = torch.zeros((nt,nparticles),device=device)
step_times = torch.zeros((nt,nparticles),device=device)

for i,t in enumerate(tvec):
    if i % 100 == 1:
        print(t)
    #smallsnap.integrate(time=t,potential=potential,verbose=(i % 100 == 1))
    # Time elapsed during this block of steps (kept separate from the loop variable).
    elapsed = smallsnap.leapfrog_steps(potential=potential, steps=10, return_time=True)
    positions[i,:,:] = smallsnap.positions
    velocities[i,:,:] = smallsnap.velocities
    dt[i,:] = smallsnap.dt
    step_times[i,:] = elapsed
# Cumulative time reached by each particle after every output. The previous
# hard-coded `if i<99` carry-forward left the last row holding only the final
# step's time instead of the running total.
particle_times = torch.cumsum(step_times, dim=0)
positions=positions.cpu()
velocities=velocities.cpu()
dt=dt.cpu()
particle_times=particle_times.cpu()

In [None]:
# Inspect the timestep attribute left on smallsnap after the orbit integration
# (the same attribute whose history is stored in `dt` above).
smallsnap.dt

In [None]:
# Spherical radii of the tracked particles at the first and the last stored
# output. Shared bins make the two overlaid histograms directly comparable.
r_initial=positions[0,:,:].norm(dim=-1).cpu().numpy()
r_final=positions[-1,:,:].norm(dim=-1).cpu().numpy()
r_all=np.concatenate((r_initial,r_final))
rbins_hist=np.linspace(r_all.min(),r_all.max(),11)
plt.hist(r_initial,bins=rbins_hist,alpha=0.5,label='Initial')
plt.hist(r_final,bins=rbins_hist,alpha=0.5,label='Final')
plt.xlabel('r')
plt.ylabel('N')
plt.legend()

In [None]:
# Full stored velocity and position history of one tracked particle.
# idx is set here (same particle as the orbit plot below) so this cell does not
# depend on a later cell having been run. All rows are printed rather than
# slicing with `maxt`, which holds a time (200) in the integration cell and an
# index (-1) in the plotting cell.
idx=24
print(velocities[:,idx])
print(positions[:,idx])

In [None]:
idx=24
# NOTE(review): slicing with [:maxt] and maxt=-1 omits the last stored sample;
# set maxt=None to plot the full track.
maxt=-1

positions=positions.cpu()
velocities=velocities.cpu()
particle_times=particle_times.cpu()
tvec=tvec.cpu()

orbit=positions[:maxt,idx]          # stored positions of the tracked particle
orbit_v=velocities[:maxt,idx]       # matching velocities
start=positions[0,idx]              # initial position, drawn as a dot
orbit_R=orbit[:,:2].norm(dim=-1)    # cylindrical radius sqrt(x^2 + y^2)

f,ax = plt.subplots(2,2)
ax[0,0].plot(orbit[:,0].numpy(),orbit[:,1].numpy(),'-')
ax[0,0].plot(start[0].numpy(),start[1].numpy(),'o')
ax[0,0].plot(0,0,'o')
ax[0,0].set_xlabel('x')
ax[0,0].set_ylabel('y')
ax[0,0].set_aspect('equal', 'box')

ax[0,1].plot(orbit[:,0].numpy(),orbit[:,2].numpy(),'-')
ax[0,1].plot(start[0].numpy(),start[2].numpy(),'o')
ax[0,1].plot(0,0,'o')
ax[0,1].set_xlabel('x')
ax[0,1].set_ylabel('z')
ax[0,1].set_aspect('equal', 'box')

ax[1,0].plot(orbit_R.numpy(),orbit[:,2].numpy(),'-')
ax[1,0].plot(start[:2].norm().numpy(),start[2].numpy(),'o')
ax[1,0].set_xlabel('R')
ax[1,0].set_ylabel('z')
ax[1,0].set_aspect('equal', 'box')

# Vertical position and velocity against the time each particle actually reached.
times=particle_times[:maxt,idx].numpy()
ax[1,1].plot(times,orbit[:,2].numpy(),'-',label='$z$')
ax[1,1].plot(times,orbit_v[:,2].numpy(),'-',label='$v_z$')
ax[1,1].set_xlabel('t')
ax[1,1].set_ylabel('$z$ and $v_z$')
ax[1,1].legend()
f.tight_layout()

In [None]:
# Time reached by particle 1 versus output index. .cpu() keeps this working
# when the integration was run on a non-CPU device.
plt.plot(particle_times[:maxt,1].cpu().numpy())
plt.xlabel('output step')
plt.ylabel('particle time')

In [None]:
idx=47
orbit=positions[:,idx].cpu()        # full stored track of the particle
orbit_v=velocities[:,idx].cpu()
orbit_R=orbit[:,:2].norm(dim=-1)    # cylindrical radius sqrt(x^2 + y^2)
# NOTE(review): tvec is the nominal evenly spaced axis from the integration
# cell, not the per-particle time actually reached (see particle_times).
t_axis=tvec.cpu().numpy()

f,ax = plt.subplots(2,2)

ax[0,0].plot(orbit[:,0].numpy(),orbit[:,1].numpy())
ax[0,0].set_xlabel('x')
ax[0,0].set_ylabel('y')

ax[0,1].plot(orbit[:,0].numpy(),orbit[:,2].numpy())
ax[0,1].set_xlabel('x')
ax[0,1].set_ylabel('z')

ax[1,0].plot(orbit_R.numpy(),orbit[:,2].numpy())
ax[1,0].set_xlabel('R')
ax[1,0].set_ylabel('z')

ax[1,1].plot(t_axis,orbit[:,0].numpy(),label='$x$')
ax[1,1].plot(t_axis,orbit_v[:,0].numpy(),label='$v_x$')
ax[1,1].set_xlabel('t')
ax[1,1].set_ylabel('$x$ and $v_x$')
ax[1,1].legend()
f.tight_layout()

In [None]:
import math
def circular_velocity(potential,rvec=torch.linspace(0,10,100),thetavec=torch.linspace(0,math.pi,60)):
    """Angle-averaged circular-velocity curve of a potential in the z=0 plane.

    Positions are sampled on a polar grid (every radius in ``rvec`` times every
    angle in ``thetavec``) at x = r sin(theta), y = r cos(theta), z = 0. The
    accelerations there are projected onto the radial unit vector and averaged
    over angle, and ``sqrt(<a . r_hat> * r)`` is returned for each radius.

    Parameters
    ----------
    potential : object with ``get_accelerations(positions)``
        Called once with an ``(len(rvec)*len(thetavec), 3)`` tensor; the result
        must be reshapeable to ``(len(rvec), len(thetavec), 3)``.
    rvec : 1-D tensor
        Radii at which to evaluate the curve.
    thetavec : 1-D tensor
        Angles (radians) averaged over at each radius.

    Returns
    -------
    1-D tensor of length ``len(rvec)``. At r == 0 the value is 0.

    NOTE(review): the square root is real only where the angle-averaged
    a . r_hat is >= 0. If get_accelerations returns inward-pointing
    accelerations for an attractive potential, the result is NaN — confirm
    ForceGrid's sign convention.
    """
    r,theta=torch.meshgrid(rvec,thetavec)
    x,y,z=r*torch.sin(theta),r*torch.cos(theta),torch.zeros_like(r)
    posvcirc=torch.stack((x.flatten(),y.flatten(),z.flatten()),dim=0).t()
    # reshape (not view) also copes with a non-contiguous acceleration tensor.
    accvcirc=potential.get_accelerations(posvcirc).reshape(r.shape+(3,))
    # Radial projection a . r_hat. At r == 0 this was 0/0 = NaN; dividing by 1
    # there instead gives 0, because x = y = z = 0 make the numerator exactly 0.
    safe_r = torch.where(r != 0, r, torch.ones_like(r))
    fr = (accvcirc[...,0]*x + accvcirc[...,1]*y + accvcirc[...,2]*z)/safe_r
    vcirc = (fr.mean(dim=1)*rvec).sqrt()
    return vcirc

In [None]:
def _grid_potential(positions, masses, n=128, smoothing=0.3*20/256):
    """Build an n^3 ForceGrid from particle positions/masses and precompute accelerations.

    Masses are deposited with cloud-in-cell on the same +/-10 box used elsewhere
    in this notebook, then ``grid_accelerations`` is called so the grid is ready
    for ``get_accelerations``.

    NOTE(review): the default smoothing 0.3*20/256 is scaled for a 256-cell grid,
    while n defaults to 128 (elsewhere smoothing is written as frac*20/n).
    This is kept as-is to preserve results; confirm whether 0.3*20/n was intended.
    """
    grid = ForceGrid(n=(n,n,n),
                     gridedges=torch.tensor([10.,10.,10.],dtype=torch.float32),
                     smoothing=smoothing)
    grid.griddata(positions,weights=masses,method='cic')
    grid.grid_accelerations()
    return grid

# Separate dark-matter and stellar potentials for rotation-curve decomposition.
dmpotential=_grid_potential(snap.dm.positions, snap.dm.masses)
stellarpotential=_grid_potential(snap.stars.positions, snap.stars.masses)

In [None]:
# Spherically averaged mass density of each component: mass per radial bin
# divided by the bin's shell volume.
rbins = np.linspace(0,10,100)
mid=0.5*(rbins[1:]+rbins[:-1])
vol=4*np.pi/3*(rbins[1:]**3-rbins[:-1]**3)
for label,component in (('dark matter',snap.dm),('stars',snap.stars)):
    H,edges = np.histogram(component.r,rbins,weights=component.masses)
    plt.loglog(mid,H/vol,label=label)
plt.xlabel('r')
plt.ylabel(r'$\rho(r)$')
plt.legend()

In [None]:
# Circular-velocity curves of the total potential and of each component.
rvec=torch.linspace(0,10,100)
curves=(('total',potential),('dark matter',dmpotential),('stars',stellarpotential))
for label,component_potential in curves:
    plt.plot(rvec,circular_velocity(component_potential,rvec),label=label)
plt.xlabel('r')
plt.ylabel(r'$v_\mathrm{circ}$')
plt.legend()
plt.ylim([0,2.5])

In [None]:
# Edge-on (x-z) projection of the stellar mass, log colour scale.
plt.hexbin(snap.stars.x,snap.stars.z,C=snap.stars.masses,
           bins='log',reduce_C_function=np.sum,extent=(-10,10,-10,10))
plt.xlabel('x')
plt.ylabel('z')
plt.title('Stars')

In [None]:
# Edge-on (x-z) projection of the dark-matter mass, log colour scale.
plt.hexbin(snap.dm.x,snap.dm.z,C=snap.dm.masses,
           bins='log',reduce_C_function=np.sum,extent=(-10,10,-10,10))
plt.xlabel('x')
plt.ylabel('z')
plt.title('Dark matter')

In [None]:
%%timeit
acc=grid.get_acc(positions)

In [None]:
# Sanity check: one particle at the origin on a 256^3 grid, then evaluate the
# acceleration at x = -9.9 (presumably near the -x edge of a +/-10 box).
# Uses the same ForceGrid methods as the rest of the notebook
# (grid_accelerations / get_accelerations).
testpos=torch.tensor([[0.,0.,0.]],dtype=torch.float32)
grid=ForceGrid(n=(256,256,256),gridedges=torch.tensor([10.,10.,10.],dtype=torch.float32),smoothing=0.3*20/256)
_=grid.griddata(testpos,method='cic')
grid.grid_accelerations()
grid.get_accelerations(torch.tensor([[-9.9,0.,0.]],dtype=torch.float32))

In [None]:
# Log of `rho` summed over its third axis, transposed before display.
# NOTE(review): `rho` is not defined in any cell of this notebook (the
# griddata return values above are discarded into `_`), so this raises
# NameError on a fresh kernel. Any zero-valued cells will give log(0) = -inf.
plt.imshow(torch.log(rho.sum(2).type(torch.float)).transpose(0,1))


In [None]:
# Side-by-side log images of the index-128 slice along the third axis of `pot`
# and `rho` (presumably the mid-plane of a 256-cell grid — confirm).
# NOTE(review): neither `pot` nor `rho` is defined in any cell of this
# notebook, so this fails on a fresh kernel. If `pot` holds a potential with
# negative values, .log() yields NaN there.
f,ax = plt.subplots(1,2)
ax[0].imshow(pot[:,:,128].log())
ax[1].imshow(rho[:,:,128].log())