In [None]:
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

import os

# Root directory for scratch output. Set the IMAGE_PHASE_SANDBOX_DIR environment
# variable to point elsewhere instead of editing this hard-coded default.
FILEPATH = os.environ.get("IMAGE_PHASE_SANDBOX_DIR", "/Users/dhogg/ImagePhaseSandbox/ipynb")

# from scipy.fftpack import rfft as fft, irfft as ifft
from scipy.fftpack import fft2 as fft, ifft2 as ifft
from PIL import Image as pil

In [None]:
# import cv2
# vidcap = cv2.VideoCapture('/Users/apricewhelan/Downloads/48728220883_b87ea8cf30_vm.mp4')
# success, image = vidcap.read()
# count = 0
# while success:
#     if count == 0 or count == 31:
#         cv2.imwrite("frame%d.jpg" % count, image)     # save frame as JPEG file      
#     success, image = vidcap.read()
#     count += 1

In [None]:
# Load the two video frames to compare as float64 pixel arrays. They are assumed
# to be (H, W, 3) colour images; the FFT cell below indexes bands 0..2.
# NOTE(review): the commented-out extraction cell saves frame0/frame31, not
# frame30 -- confirm which frame file is intended.
rgbs = []
for filename in ['frame0.jpg', 'frame30.jpg']:
    # Use a context manager so the file handle PIL keeps open is released.
    with pil.open(filename) as im:
        rgbs.append(np.array(im).astype(np.float64))

In [None]:
# for rgb in rgbs:
#     fig, ax = plt.subplots(figsize=(10, 10))
#     ax.imshow(rgb[..., 0], cmap='Greys')
#     ax.set_aspect('equal')

In [None]:
# One complex 2-D FFT per colour band, frame-major order:
# [frame0 band0, frame0 band1, frame0 band2, frame30 band0, frame30 band1, frame30 band2]
all_ffts = []
for rgb in rgbs:
    all_ffts.extend(fft(rgb[..., channel]) for channel in range(3))
    
#     fig, ax = plt.subplots(figsize=(6, 6))
#     ax.imshow(ifft(fft(rgb)), cmap='Greys')
#     ax.set_aspect('equal')

In [None]:
def get_ims_at_new_times_stupidly(data1, data2, times, write=False):
    """Cross-fade two images linearly in pixel space (naive baseline).

    Parameters
    ----------
    data1, data2 : complex ndarray
        FFTs (as produced by this notebook's ``fft``) of the images at
        t=0 and t=1 respectively.
    times : iterable of float
        Times at which to build frames. Values outside [0, 1] extrapolate
        linearly; nothing is clamped.
    write : bool, optional
        If True, also save each frame as a PNG under ``FILEPATH/scratch``.

    Returns
    -------
    list of ndarray
        One real-valued image per entry of ``times``.
    """
    start = ifft(data1).real
    end = ifft(data2).real
    delta = end - start

    frames = []
    for index, t in enumerate(times):
        frame = start + delta * t
        frames.append(frame)
        if not write:
            continue
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.imshow(frame, cmap='Greys')
        ax.set_aspect('equal')
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_visible(False)
        fig.tight_layout()
        out_path = FILEPATH + '/scratch/sn_frame_stupid{:03d}.png'.format(index)
        fig.savefig(out_path, dpi=150)
        plt.close(fig)
    return frames

In [None]:
def get_ims_at_new_times(data1, data2, times, write=False):
    """Interpolate between two images by rotating their Fourier phases.

    Each Fourier coefficient's phase moves from that of ``data1`` (t=0) toward
    that of ``data2`` (t=1) along the shortest arc (step wrapped to (-pi, pi]),
    extrapolating for t outside [0, 1]. Its amplitude is interpolated linearly
    for 0 <= t <= 1 and held at the nearer endpoint outside that range.

    Parameters
    ----------
    data1, data2 : complex ndarray
        FFTs (as produced by this notebook's ``fft``) of the images at t=0
        and t=1, with the same shape.
    times : iterable of float
        Times at which to build frames.
    write : bool, optional
        If True, also save each frame as a PNG under ``FILEPATH/scratch``.

    Returns
    -------
    list of ndarray
        One real-valued image per entry of ``times``.
    """
    amps1 = np.abs(data1)
    amps2 = np.abs(data2)

    # Phase at t=0 and the signed shortest-arc phase step to t=1.
    # angle(data2 * conj(data1)) equals arctan2(sin dtheta, cos dtheta) without
    # dividing by the amplitudes. A zero-amplitude coefficient therefore cannot
    # create a NaN, which the inverse FFT would spread over the whole frame.
    thetas1 = np.angle(data1)
    dthetas = np.angle(data2 * np.conj(data1))
    # A coefficient that is zero at one end has no defined phase there. Give it
    # the other end's phase and do not rotate it.
    thetas1 = np.where(amps1 > 0, thetas1, np.angle(data2))
    dthetas = np.where((amps1 > 0) & (amps2 > 0), dthetas, 0.0)

    ims = []
    for j, time in enumerate(times):
        # Amplitudes are clamped to the endpoints; phases keep rotating.
        if time < 0:
            newamps = amps1
        elif time > 1:
            newamps = amps2
        else:
            newamps = amps1 + (amps2 - amps1) * time
        newthetas = thetas1 + dthetas * time
        newdata = newamps * np.exp(1j * newthetas)
        ims.append(ifft(newdata).real)
        if write:
            fig, ax = plt.subplots(figsize=(10, 8))
            ax.imshow(ims[-1], cmap='Greys')
            ax.set_aspect('equal')
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            fig.tight_layout()
            fig.savefig(FILEPATH + '/scratch/sn_frame{:03d}.png'.format(j),
                        dpi=150)
            plt.close(fig)

    return ims

In [None]:
# 13 evenly spaced times from -1 to 2 in steps of 0.25; t=0 is frame 0, t=1 is frame 30.
times = np.linspace(-1.0, 2.0, 13)
# Phase-interpolated band-0 frames; write=True also saves each one as a PNG.
Rims = get_ims_at_new_times(all_ffts[0], all_ffts[3], times, write=True)
# Gims = get_ims_at_new_times(all_ffts[1], all_ffts[4], times)
# Bims = get_ims_at_new_times(all_ffts[2], all_ffts[5], times)

In [None]:
# Same band-0 pair, as a naive pixel-space cross-fade for comparison (also writes PNGs).
Rims_s = get_ims_at_new_times_stupidly(data1=all_ffts[0], data2=all_ffts[3],
                                       times=times, write=True)
# Gims_s = get_ims_at_new_times_stupidly(all_ffts[1], all_ffts[4], times)
# Bims_s = get_ims_at_new_times_stupidly(all_ffts[2], all_ffts[5], times)

In [None]:
# Number of phase-interpolated frames: one per entry of `times`.
len(Rims)

In [None]:
# Show the phase-interpolated band-0 frame closest to t = 0.5 (halfway between
# the two input frames). This uses Rims from the cells above; `ims` is only bound
# further down, and there it holds None because get_ims saves frames to disk.
mid = int(np.argmin(np.abs(np.asarray(times) - 0.5)))
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(Rims[mid], cmap='Greys')
ax.set_title('phase interpolation, t = {:.2f}'.format(times[mid]));

In [None]:
def get_ims(data1, data2, nslice, save_dont_return=False):
    """Build a phase-rotation movie between two images, with a lead-in and lead-out.

    Frames are indexed by an integer step ``i`` running over
    ``-4*nslice .. 5*nslice - 1`` (``9*nslice`` frames in total). At step ``i``:

    - Every Fourier phase of ``data1`` is rotated by ``i / nslice`` of its
      shortest-arc phase difference to ``data2``. This extrapolates for
      ``i < 0`` and ``i > nslice``.
    - Amplitudes are held at ``|data1|`` for ``i < 0``, ramp linearly to
      ``|data2|`` over ``0 <= i <= nslice``, and are held at ``|data2|``
      afterwards.

    Step 0 therefore reproduces image 1 and step ``nslice`` reproduces image 2.

    Parameters
    ----------
    data1, data2 : complex ndarray
        FFTs (as produced by this notebook's ``fft``) of the two images, with
        the same shape.
    nslice : int
        Number of steps in the transition; must be >= 1.
    save_dont_return : bool, optional
        If True, write each frame to
        ``FILEPATH/scratch/supernova-phase-bump/sn_frame_NNN.png`` and return None.

    Returns
    -------
    ndarray or None
        Real frames stacked along axis 0, or None when ``save_dont_return`` is True.

    Raises
    ------
    ValueError
        If ``nslice < 1``.
    """
    if nslice < 1:
        raise ValueError('nslice must be >= 1, got {}'.format(nslice))

    amp1 = np.abs(data1)
    amp2 = np.abs(data2)

    # Full signed phase difference in (-pi, pi]. arctan2 semantics keep the
    # quadrant, which arcsin(cross) loses for differences beyond pi/2. Working on
    # the complex product avoids dividing by amplitudes (zero amplitude -> NaN).
    phase1 = np.angle(data1)
    dphase = np.angle(data2 * np.conj(data1))
    # A coefficient that is zero at one end has no defined phase there. Give it
    # the other end's phase and do not rotate it.
    phase1 = np.where(amp1 > 0, phase1, np.angle(data2))
    dphase = np.where((amp1 > 0) & (amp2 > 0), dphase, 0.0)
    # The phase step matches the amplitude ramp, so step `nslice` lands on data2.
    dtheta = dphase / nslice

    frames = np.concatenate((np.arange(-4 * nslice, 0),
                             np.arange(0, nslice),
                             np.arange(nslice, nslice + 4 * nslice)))

    ims = []
    for j, i in enumerate(frames):
        if i < 0:
            fac = amp1
        elif i > nslice:
            fac = amp2
        else:
            fac = (amp2 - amp1) / nslice * i + amp1

        # Rotating a unit phasor by theta is multiplication by exp(1j * theta).
        frame = ifft(fac * np.exp(1j * (phase1 + dtheta * i))).real

        if save_dont_return:
            fig, ax = plt.subplots(figsize=(10, 8))
            ax.imshow(frame, cmap='Greys')
            ax.set_aspect('equal')
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            fig.tight_layout()
            fig.savefig(FILEPATH + '/scratch/supernova-phase-bump/sn_frame_{j:03d}.png'.format(j=j),
                        dpi=150)
            plt.close(fig)
        else:
            ims.append(frame)

    if not save_dont_return:
        return np.array(ims)
    return None

In [None]:
import glob, os

# Normalize numbered frame files to zero-padded sn_frame_NNN.png names, so that
# lexical order matches frame order for movie-making tools.
frame_dir = os.path.join(FILEPATH, 'scratch', 'supernova-phase-bump')
for filename in glob.glob(os.path.join(frame_dir, '*')):
    # Parse the number from the basename only; dots or underscores in
    # FILEPATH must not affect the parse.
    stem = os.path.splitext(os.path.basename(filename))[0]
    try:
        num = int(stem.split('_')[-1])
    except ValueError:
        continue  # not a numbered frame file; leave it alone
    new_filename = os.path.join(frame_dir, 'sn_frame_{:03d}.png'.format(num))
    if new_filename != filename:
        # os.replace avoids building a shell command from file names.
        os.replace(filename, new_filename)

In [None]:
nslice = 16
# With save_dont_return=True the frames are written to disk and get_ims returns None.
ims = get_ims(data1=all_ffts[0], data2=all_ffts[3], nslice=nslice,
              save_dont_return=True)

In [None]:
# for im in [ims[0], ims[nslice], ims[nslice+nslice], ims[-1]]:
#     fig, ax = plt.subplots(figsize=(10, 8))
#     ax.imshow(im.real, cmap='Greys')
#     ax.set_aspect('equal')
#     ax.xaxis.set_visible(False)
#     ax.yaxis.set_visible(False)
#     fig.tight_layout()

In [None]:
# Per-pixel 2-D rotation matrices: for an array of angles `theta`, R has shape
# (2, 2) + theta.shape. `theta` is an illustrative grid of angles defined here,
# so this cell does not depend on leftover kernel state.
theta = np.linspace(0.0, np.pi / 2, 6).reshape(2, 3)
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])
R.shape

In [None]:
# Phase-swap check: band-0 Fourier amplitudes of frame 0 combined with the band-0
# Fourier phases of frame 30, shown as a difference from frame 0's band 0.
amp1 = np.abs(all_ffts[0])
phase2 = np.angle(all_ffts[3])
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(ifft(amp1 * np.exp(1j * phase2)).real - rgbs[0][..., 0], cmap='Greys');
# plt.colorbar()

In [None]:
# Phase-swap check: band-0 Fourier amplitudes of frame 30 combined with the
# band-0 Fourier phases of frame 0.
amp2 = np.abs(all_ffts[3])
phase1 = np.angle(all_ffts[0])
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(ifft(amp2 * np.exp(1j * phase1)).real, cmap='Greys');

In [None]:
# scipy.fftpack.fft signature, for reference: fft(x, n=None, axis=-1, overwrite_x=False)


In [None]:
# scipy.fftpack.fft2 signature, for reference: fft2(x, shape=None, axes=(-2, -1), overwrite_x=False)