In [None]:
import numpy as np
from PIL import Image, ImageSequence
import matplotlib.pyplot as plt
import torch

In [None]:
def project_theta(theta, m_values):
    """Project angles onto the harmonics (cos(m*theta), sin(m*theta)).

    Args:
        theta: tensor of angles in radians (any shape).
        m_values: iterable of harmonic orders m.

    Returns:
        Tensor of shape (len(m_values), *theta.shape, 2); the last axis holds
        (cos(m*theta), sin(m*theta)) for each m, in the order of `m_values`.
    """
    return torch.stack(
        [
            torch.stack((torch.cos(m * theta), torch.sin(m * theta)), axis=-1)
            for m in m_values
        ],
        axis=0,
    )

def evaluate_functions_on_theta(theta, coefficients_list, m_values):
    """Evaluate a truncated Fourier series at the angles `theta`.

    Computes sum_m [a_cos_m * cos(m*theta) + a_sin_m * sin(m*theta)], pairing each
    (a_cos, a_sin) entry of `coefficients_list` with the matching m of `m_values`
    (pairs are formed with zip, so surplus entries of the longer input are ignored).

    Returns:
        float32 tensor with the same shape as `theta`; terms are accumulated in place
        into that float32 buffer.
    """
    series = torch.zeros(theta.shape, dtype=torch.float32)
    for (a_cos, a_sin), m in zip(coefficients_list, m_values):
        phase = m * theta
        series.add_(a_cos * torch.cos(phase) + a_sin * torch.sin(phase))
    return series

### Note: the columns and rows appear to correspond to x and y respectively (the transpose of how we normally plot).
#### TODO: double-check this later.

## Load data and optimize offset

In [None]:
# Define parameters
# Harmonic orders m = 0..11 used in the angular Fourier expansion.
ms = torch.arange(12)
# Six reference directions spaced 2*pi/6 apart.
angles = torch.arange(0, 6) * 2 * torch.pi / 6.

# Root of the local data checkout; change this single value to run elsewhere.
DATA_ROOT = '/Users/cadenmyers/billingelab/dev/skyrmion_lattices'

# Extract data_theta: the per-pixel polar angle.
# np.load keeps .npz archives open until closed, so use it as a context manager.
with np.load(f'{DATA_ROOT}/images/image_111020.npz') as image_archive:
    data = image_archive['data']
# data[0] / data[1] are used as the x / y coordinate grids (TODO: confirm orientation).
data_theta = torch.atan2(torch.tensor(data[1]), torch.tensor(data[0]))

# Extract data intensity (unmasked)
movie = '121855.npz'
with np.load(f'{DATA_ROOT}/tests/{movie}') as movie_images:
    movie_intensity = torch.tensor(movie_images['intensity'])

# TODO: phi is not loaded yet. Preprocess phi to get the angle difference:
# phi = training_data['phi'] - 253.1473
# print("phi:", phi)
movie_intensity.shape

In [None]:
# Made offset a required argument
def optimize_offset(intensity, offset, max_iter=101, lr=1e-2,
                    theta=None, base_angles=None, m_values=None):
    """Fit a global rotation `offset` of an angular template to one image.

    The template is sum_m [C_m cos(m*theta) + S_m sin(m*theta)], where (C_m, S_m) are
    the cos/sin projections of `base_angles + offset`, summed over the base angles.
    Adam maximises the overlap sum(intensity * template) with respect to `offset`.

    Args:
        intensity: image tensor; must broadcast against `theta` (presumably the same
            pixel grid as `data_theta` -- TODO confirm).
        offset: scalar leaf tensor, updated in place by the optimiser. If it does not
            yet require grad, it is switched to requires_grad=True.
        max_iter: number of Adam steps (default 101).
        lr: Adam learning rate (default 1e-2).
        theta: per-pixel angles; defaults to the global `data_theta`.
        base_angles: template directions in radians; defaults to the global `angles`.
        m_values: harmonic orders; defaults to the global `ms`.

    Returns:
        (offset, template): the same `offset` tensor after optimisation, and the
        template evaluated at that final offset, detached from the autograd graph.
    """
    # Resolve the module-level defaults at call time, as the original code did.
    theta = data_theta if theta is None else theta
    base_angles = angles if base_angles is None else base_angles
    m_values = ms if m_values is None else m_values

    # Without grad on `offset`, loss.backward() would fail (nothing requires grad).
    if not offset.requires_grad:
        offset.requires_grad_(True)

    def template_at(current_offset):
        # Coefficients have shape (len(m_values), 2) when base_angles is 1-D.
        coefficients = project_theta(base_angles + current_offset, m_values).sum(1)
        return evaluate_functions_on_theta(theta, coefficients, m_values)

    opt = torch.optim.Adam([offset], lr=lr)
    for _ in range(max_iter):
        # Negative overlap: minimising it maximises template/image agreement.
        loss = -(intensity * template_at(offset)).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Re-evaluate after the final step so the template matches the returned offset
    # (inside the loop it lags one step behind), and do it without building a graph
    # so callers that store the result do not retain autograd buffers. This also
    # makes max_iter=0 well-defined.
    with torch.no_grad():
        final_template = template_at(offset)
    return offset, final_template

# Fit each frame in turn. optimize_offset updates `offset` in place, so each
# fit is warm-started from the previous frame's result.
offset = torch.tensor(0., requires_grad=True)
offset_list = []
for i, frame_intensity in enumerate(movie_intensity):
    offset_ang, _ = optimize_offset(frame_intensity, offset)
    print(f'{i}: offset angle=', offset_ang.item())
    offset_list.append(offset_ang.item())

# Plot once after the loop. Plotting inside the loop drew a new, overlapping
# line on the same axes for every frame.
fig, ax = plt.subplots()
ax.plot(offset_list)
ax.set(xlabel='frame', ylabel='offset (rad)', title='Fitted offset per frame');

In [None]:
# Single-frame sanity check. The initial offset must require grad so that
# loss.backward() inside optimize_offset has something to differentiate.
optimize_offset(movie_intensity[0], torch.tensor(0., requires_grad=True))

In [None]:
# Training: fit every frame, warm-starting each fit from the previous frame's offset.
offset_list, evaluate_image_theta_list = [], []
offset = torch.tensor(0., requires_grad=True)
for intensity in movie_intensity:
    offset, evaluate_image_theta = optimize_offset(intensity, offset)
    offset_list.append(offset.item())
    # Detach so the stored templates do not keep autograd graphs alive.
    evaluate_image_theta_list.append(evaluate_image_theta.detach())
print("offset in radians:", offset_list)

In [None]:
# Note: max_iter in optimize_offset (101 at the time of writing) could probably be
# decreased; it was once raised to 1501 to check whether that improves accuracy.
# Results: convert fitted offsets (radians) to degrees relative to the first frame.
# The images are shown as y vs x rather than x vs y, so the angle is mirrored
# (theta -> 90 - theta). That 90 degrees cancels against the -90 degree reference
# shift, leaving a sign flip plus FIRST_FRAME_OFFSET_DEG, which zeroes the first frame.
# NOTE(review): the constant comes from an earlier run. Confirm that it still equals
# degrees(offset_list[0]) for the current max_iter, otherwise frame 0 will not map to 0.
FIRST_FRAME_OFFSET_DEG = 15.46960665870267
offset_diff_degrees = [
    FIRST_FRAME_OFFSET_DEG - offset_rad * 180 / torch.pi for offset_rad in offset_list
]
print("offset (preprocessed) in degrees:", offset_diff_degrees)
# TODO: once phi is loaded, report the errors against it:
# print("\nAbsolute errors:", np.abs(phi - offset_diff_degrees))

In [None]:
# The weird case: this frame needs a mask before it can be retrained properly.
# Refit just this frame, starting from a fixed initial offset.
weird_frame = 51
the_offset = torch.tensor(-1.2826639413833618, requires_grad=True)
the_offset, the_evaluate_image_theta = optimize_offset(movie_intensity[weird_frame], the_offset)
print(-the_offset.item()/torch.pi*180 + 15.46960665870267)

# Overlay the max-abs-normalised template on the max-abs-normalised frame.
weird_image = movie_intensity[weird_frame]
template_norm = the_evaluate_image_theta / the_evaluate_image_theta.abs().max()
image_norm = weird_image / weird_image.abs().max()
plt.imshow((template_norm + image_norm).detach(), cmap='plasma')



In [None]:
# Grid of fitted templates overlaid on their frames (each max-abs normalised).
nrows, ncols = 6, 10
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(30, 18))
n_frames = min(len(evaluate_image_theta_list), len(movie_intensity))
# ax.flat walks the grid row-major, so panel k is frame row * ncols + col. The
# previous index i*6 + j repeated frames and never showed frames beyond 39.
for k, panel in enumerate(ax.flat):
    if k >= n_frames:
        # Fewer frames than panels: hide unused axes instead of raising IndexError.
        panel.axis('off')
        continue
    template = evaluate_image_theta_list[k]
    frame = movie_intensity[k]
    overlay = template / template.abs().max() + frame / frame.abs().max()
    panel.imshow(overlay.detach(), cmap='plasma')
    panel.set_title(f'frame {k}')

