In [None]:
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# from scipy.fftpack import rfft as fft, irfft as ifft
from scipy.fftpack import fft2 as fft, ifft2 as ifft
from PIL import Image as pil

In [None]:
# import cv2
# vidcap = cv2.VideoCapture('/Users/apricewhelan/Downloads/48728220883_b87ea8cf30_vm.mp4')
# success, image = vidcap.read()
# count = 0
# while success:
#     if count == 0 or count == 31:
#         cv2.imwrite("frame%d.jpg" % count, image)     # save frame as JPEG file      
#     success, image = vidcap.read()
#     count += 1

In [None]:
# Read the two video frames as float64 arrays; presumably (rows, cols, 3) for
# RGB JPEGs -- later cells index bands with rgb[..., band] for band in range(3).
# NOTE(review): the commented-out extraction cell above writes frame0.jpg and
# frame31.jpg (count == 31), not frame30.jpg -- confirm which frame is intended.
FRAME_FILES = ['frame0.jpg', 'frame30.jpg']
rgbs = []
for filename in FRAME_FILES:
    # The context manager closes the file handle once the pixels are copied out.
    with pil.open(filename) as im:
        rgbs.append(np.array(im, dtype=np.float64))
# Later cells combine the two frames element-wise, so their shapes must match.
assert rgbs[0].shape == rgbs[1].shape, 'frames must have identical shapes'
rgbs[0].shape

In [None]:
# Restrict both frames to a 128x128 window (rows 900..1027, cols 1200..1327)
# to test how local the interpolation is; skip this cell to use full frames.
# NOTE: this rebinds `rgbs`; re-running it without re-running the loading cell
# crops the already-cropped arrays again (giving empty arrays).
ROW0, COL0, SIZE = 900, 1200, 128
window = (slice(ROW0, ROW0 + SIZE), slice(COL0, COL0 + SIZE))
rgbs = [frame[window] for frame in rgbs]
rgbs[0].shape

In [None]:
# Display band 0 (red) of each frame, one square figure per frame.
for idx in range(len(rgbs)):
    red_band = rgbs[idx][..., 0]
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(red_band, cmap='Greys')
    ax.set_aspect('equal')

In [None]:
# 2-D FFT of every colour band of every frame, flattened frame-major:
# all_ffts == [R0, G0, B0, R1, G1, B1], i.e. index = 3 * frame + band.
all_ffts = [fft(rgb[..., band]) for rgb in rgbs for band in range(3)]

In [None]:
def get_ims_at_new_times_stupidly(data1, data2, times, write=False,
                                  outdir='scratch'):
    """Baseline interpolation: blend two frames linearly in pixel space.

    ``data1`` and ``data2`` are the 2-D FFTs of the frames at time=0 and
    time=1. Both are inverse-transformed, and each requested time ``t``
    yields ``img1 + (img2 - img1) * t``. Times outside [0, 1] extrapolate
    linearly.

    Parameters
    ----------
    data1, data2 : complex ndarray
        FFTs (as returned by ``fft``) of the frames at time=0 and time=1.
    times : iterable of float
        Times at which to build frames.
    write : bool, optional
        If True, also save each frame as
        ``<outdir>/sn_frame_stupid{j:03d}.png`` (j = index into ``times``).
    outdir : str, optional
        Directory for the saved frames; created if it does not exist.

    Returns
    -------
    list of ndarray
        One real-valued image per entry of ``times``.
    """
    import os

    # Keep only the real part; the inputs are FFTs of real images (see the
    # FFT cell), so the imaginary part should be round-off only.
    img1 = ifft(data1).real
    img2 = ifft(data2).real
    delta = img2 - img1  # loop-invariant

    if write:
        # savefig does not create missing directories; without this a fresh
        # checkout fails with FileNotFoundError.
        os.makedirs(outdir, exist_ok=True)

    ims = []
    for j, time in enumerate(times):
        newimg = img1 + delta * time
        ims.append(newimg)
        if write:
            fig, ax = plt.subplots(figsize=(10, 8))
            ax.imshow(newimg, cmap='Greys')
            ax.set_aspect('equal')
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            fig.tight_layout()
            fig.savefig(os.path.join(outdir, 'sn_frame_stupid{:03d}.png'.format(j)),
                        dpi=150)
            plt.close(fig)
    return ims

In [None]:
def get_ims_at_new_times(data1, data2, times, write=False, outdir='scratch'):
    """Interpolate two frames by rotating their Fourier phases.

    ``data1`` and ``data2`` are the 2-D FFTs of the frames at time=0 and
    time=1. For each time ``t``, every Fourier coefficient gets:

    * amplitude: a linear blend of ``|data1|`` and ``|data2|`` for
      0 <= t <= 1, held at ``|data1|`` for t < 0 and at ``|data2|`` for t > 1;
    * phase: ``theta1 + dtheta * t``, where ``dtheta`` is the phase change
      from ``data1`` to ``data2`` wrapped to [-pi, pi]. The phase keeps
      extrapolating linearly outside [0, 1].

    The frame is the real part of the inverse FFT of the new coefficients.

    Parameters
    ----------
    data1, data2 : complex ndarray
        FFTs (as returned by ``fft``) of the frames at time=0 and time=1.
    times : iterable of float
        Times at which to build frames.
    write : bool, optional
        If True, also save each frame as ``<outdir>/sn_frame{j:03d}.png``
        (j = index into ``times``).
    outdir : str, optional
        Directory for the saved frames; created if it does not exist.

    Returns
    -------
    list of ndarray
        One real-valued image per entry of ``times``.
    """
    import os

    amps1 = np.abs(data1)
    amps2 = np.abs(data2)

    thetas1 = np.angle(data1)
    thetas2 = np.angle(data2)
    # A zero-amplitude coefficient has no defined phase. Normalising it to a
    # unit vector would give 0/0 = NaN, and a single NaN coefficient turns
    # every pixel of the inverse FFT into NaN. Instead, borrow the other
    # frame's phase so such a coefficient only changes in amplitude
    # (dtheta = 0).
    thetas1 = np.where(amps1 > 0, thetas1, thetas2)
    thetas2 = np.where(amps2 > 0, thetas2, thetas1)
    # Signed phase change wrapped to [-pi, pi], i.e.
    # arctan2(sin(theta2 - theta1), cos(theta2 - theta1)).
    dthetas = np.angle(np.exp(1j * (thetas2 - thetas1)))

    if write:
        # savefig does not create missing directories; without this a fresh
        # checkout fails with FileNotFoundError.
        os.makedirs(outdir, exist_ok=True)

    ims = []
    for j, time in enumerate(times):
        # Amplitudes are clamped to the endpoint values outside [0, 1].
        if time < 0:
            newamps = amps1
        elif time > 1:
            newamps = amps2
        else:
            newamps = amps1 + (amps2 - amps1) * time
        newthetas = thetas1 + dthetas * time
        newdata = newamps * np.exp(1j * newthetas)
        ims.append(ifft(newdata).real)
        if write:
            fig, ax = plt.subplots(figsize=(10, 8))
            ax.imshow(ims[-1], cmap='Greys')
            ax.set_aspect('equal')
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            fig.tight_layout()
            fig.savefig(os.path.join(outdir, 'sn_frame{:03d}.png'.format(j)),
                        dpi=150)
            plt.close(fig)

    return ims

In [None]:
# 21 evenly spaced times from t=-2 to t=3 (step 0.25); t < 0 and t > 1
# extrapolate beyond the two observed frames (t=0 and t=1).
# Only band 0 is interpolated here (all_ffts[0] -> all_ffts[3]); the other
# bands would be all_ffts[1] -> [4] and all_ffts[2] -> [5].
times = np.linspace(-2., 3., 21)
Rims = get_ims_at_new_times(data1=all_ffts[0], data2=all_ffts[3],
                            times=times, write=True)
len(Rims)

In [None]:
# Baseline for comparison: the same band-0 sequence built by plain
# pixel-space blending, reusing `times` from the cell above.
Rims_s = get_ims_at_new_times_stupidly(data1=all_ffts[0], data2=all_ffts[3],
                                       times=times, write=True)
len(Rims_s)