#### Note: This notebook runs only if the movies are saved as "<start_numor>.npz" files. See the notebook SANS_to_npz.ipynb for code that helps with this conversion.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
from ipywidgets import interact, IntSlider
from filter_functions import mask_and_blur_images, filter_function, normalize_min_max, find_k_value, grid_search_offset_cw, grid_search_offset_ccw
from bg_mpl_stylesheets.styles import all_styles
# Apply the "bg-style" matplotlib style from bg_mpl_stylesheets to all figures created below.
plt.style.use(all_styles["bg-style"])


In [None]:
# Map each measurement series (dataset name) to the list of numors (movie start
# numbers, as strings) that belong to it. Every numor is expected to exist as
# "<numor>.npz" in the npz data directory and as a key of movie_to_title and
# filtered_indices. Lists are ordered by increasing |field| (field sweeps) or
# increasing temperature (temp sweeps), per movie_to_title; plot colours follow
# this list order.
folder_movies = {
    # field sweeps
    '55.3 K 50 mW Negative Field Sweep': ['54261', '54863', '55162', '55461', '55760', '56059', '56358'],
    '55.3 K 50 mW Positive Field Sweep': ['73131', '73430', '73729', '74028', '74327', '74626'],
    '55.8 K 25 mW Negative Field Sweep': ['68463', '68762', '69061', '69360', '69659', '69958'],
    '55.8 K 25 mW Positive Field Sweep': ['70797', '71096', '71395', '71694', '71993', '72292'],

    # temperature sweeps (only including those with signal)
    '-29 mT 10 mW Temp Sweep': ['61221', '61759', '62417'],
    '-29 mT 25 mW Temp Sweep': ['64690', '64152', '63614', '63076'],
    '+29 mT 25 mW Temp Sweep': ['76899', '76361', '75823', '75285'],
    '-29 mT 50 mW Temp Sweep': ['52666', '59887', '52128', '59349', '51590', '58811', '51052', '58273', '50514', '57735', '49976', '57197', '49438'],
    '+29 mT 50 mW Temp Sweep': ['47963', '47425', '67624', '46887', '67086', '46349', '45811', '66670', '45268', '44730', '44192']
}
# Numors excluded above as bad data (66304 also has no entry in movie_to_title):
# '-29 mT 25 mW Temp Sweep': 65766, 65228, 66304


# Movie numor (string) -> plot title "<numor> <T> K <B> mT <P> mW".
# process_numors picks a clockwise grid search when the title contains '+'
# (positive field) and counter-clockwise otherwise; extract_variable parses the
# field / temperature out of these titles for legend labels.
# NOTE: 65766 and 65228 have titles here but are excluded from folder_movies as bad data.
movie_to_title = {
    # --- Field Sweeps ---
    
    # 55.3 K 50 mW Negative Field Sweep
    '54261': '54261 55.3 K -23 mT 50 mW',
    '54863': '54863 55.3 K -25 mT 50 mW',
    '55162': '55162 55.3 K -27 mT 50 mW',
    '55461': '55461 55.3 K -29 mT 50 mW',
    '55760': '55760 55.3 K -31 mT 50 mW',
    '56059': '56059 55.3 K -33 mT 50 mW',
    '56358': '56358 55.3 K -35 mT 50 mW',

    # 55.3 K 50 mW Positive Field Sweep
    '73131': '73131 55.3 K +23 mT 50 mW',
    '73430': '73430 55.3 K +25 mT 50 mW',
    '73729': '73729 55.3 K +27 mT 50 mW',
    '74028': '74028 55.3 K +29 mT 50 mW',
    '74327': '74327 55.3 K +31 mT 50 mW',
    '74626': '74626 55.3 K +33 mT 50 mW',

    # 55.8 K 25 mW Negative Field Sweep
    '68463': '68463 55.8 K -23 mT 25 mW',
    '68762': '68762 55.8 K -25 mT 25 mW',
    '69061': '69061 55.8 K -27 mT 25 mW',
    '69360': '69360 55.8 K -29 mT 25 mW',
    '69659': '69659 55.8 K -31 mT 25 mW',
    '69958': '69958 55.8 K -33 mT 25 mW',

    # 55.8 K 25 mW Positive Field Sweep
    '70797': '70797 55.8 K +23 mT 25 mW',
    '71096': '71096 55.8 K +25 mT 25 mW',
    '71395': '71395 55.8 K +27 mT 25 mW',
    '71694': '71694 55.8 K +29 mT 25 mW',
    '71993': '71993 55.8 K +31 mT 25 mW',
    '72292': '72292 55.8 K +33 mT 25 mW',

    # --- Temperature Sweeps ---

    # -29 mT 10 mW Temp Sweep
    '61221': '61221 56.0 K -29 mT 10 mW',
    '61759': '61759 56.3 K -29 mT 10 mW',
    '62417': '62417 57.5 K -29 mT 10 mW',

    # -29 mT 25 mW Temp Sweep
    '65766': '65766 54.8 K -29 mT 25 mW',
    '65228': '65228 55.3 K -29 mT 25 mW',
    '64690': '64690 55.8 K -29 mT 25 mW',
    '64152': '64152 56.3 K -29 mT 25 mW',
    '63614': '63614 56.8 K -29 mT 25 mW',
    '63076': '63076 57.3 K -29 mT 25 mW',

    # +29 mT 25 mW Temp Sweep
    '76899': '76899 55.8 K +29 mT 25 mW',
    '76361': '76361 56.3 K +29 mT 25 mW',
    '75823': '75823 56.8 K +29 mT 25 mW',
    '75285': '75285 57.3 K +29 mT 25 mW',

    # -29 mT 50 mW Temp Sweep
    '52666': '52666 54.3 K -29 mT 50 mW',
    '59887': '59887 54.5 K -29 mT 50 mW',
    '52128': '52128 54.8 K -29 mT 50 mW',
    '59349': '59349 55.0 K -29 mT 50 mW',
    '51590': '51590 55.3 K -29 mT 50 mW',
    '58811': '58811 55.5 K -29 mT 50 mW',
    '51052': '51052 55.8 K -29 mT 50 mW',
    '58273': '58273 56.0 K -29 mT 50 mW',
    '50514': '50514 56.3 K -29 mT 50 mW',
    '57735': '57735 56.5 K -29 mT 50 mW',
    '49976': '49976 56.8 K -29 mT 50 mW',
    '57197': '57197 57.0 K -29 mT 50 mW',
    '49438': '49438 57.3 K -29 mT 50 mW',

    # +29 mT 50 mW Temp Sweep
    '47963': '47963 53.8 K +29 mT 50 mW',
    '47425': '47425 54.3 K +29 mT 50 mW',
    '67624': '67624 54.5 K +29 mT 50 mW',
    '46887': '46887 54.8 K +29 mT 50 mW',
    '67086': '67086 55.0 K +29 mT 50 mW',
    '46349': '46349 55.3 K +29 mT 50 mW',
    '45811': '45811 55.8 K +29 mT 50 mW',
    '66670': '66670 55.9 K +29 mT 50 mW',
    '45268': '45268 56.3 K +29 mT 50 mW',
    '44730': '44730 56.8 K +29 mT 50 mW',
    '44192': '44192 57.3 K +29 mT 50 mW',
}
# Reference numors used for spot checks: 46349 ("hard"), 69958 ("easy"), 56059 ("med").
# (Previously a bare no-op list expression sat here.)

# Bad frames to drop before analysis, keyed by numor string. Indices are 0-based
# positions along axis 0 of each movie's 'data' array; process_numors deletes
# them before masking and running the grid search.
filtered_indices = {
    # --- Field Sweeps ---

    # 55.3 K 50 mW Negative Field Sweep
    '54261': [],
    '54863': [],
    '55162': [],
    '55461': [15, 16],
    '55760': [],
    '56059': [],
    '56358': [],

    # 55.3 K 50 mW Positive Field Sweep
    '73131': [7],
    '73430': [],
    '73729': [],
    '74028': [],
    '74327': [129, 130, 214, 215],
    '74626': [],

    # 55.8 K 25 mW Negative Field Sweep
    '68463': [265, 266],
    '68762': [39, 40, 175, 176, 206],
    '69061': [],
    '69360': [254, 255],
    '69659': [132, 133, 252, 253],
    '69958': [271, 272],

    # 55.8 K 25 mW Positive Field Sweep
    '70797': [47, 48, 216, 265, 266],
    '71096': [],
    '71395': [],
    '71694': [35, 36],
    '71993': [265, 266],
    '72292': [],

    # --- Temperature Sweeps ---

    # -29 mT 10 mW Temp Sweep
    '61221': [],
    '61759': [],
    '62417': [],

    # -29 mT 25 mW Temp Sweep
    '65766': [],
    '65228': [],
    '64690': [],
    '64152': [],
    '63614': [],
    '63076': [],

    # +29 mT 25 mW Temp Sweep
    '76899': [],
    '76361': [],
    '75823': [],
    '75285': [],

    # -29 mT 50 mW Temp Sweep
    '52666': [],
    '59887': [],
    '52128': [],
    '59349': [],
    '51590': [],
    '58811': [],
    '51052': [],
    '58273': [],
    '50514': [],
    '57735': [],
    '49976': [],
    '57197': [],
    '49438': [],

    # +29 mT 50 mW Temp Sweep
    '47963': [],
    '47425': [],
    '67624': [],
    '46887': [],
    '67086': [],
    '46349': [],
    '45811': [],
    '66670': [],
    '45268': [],
    '44730': [],
    '44192': []
}

print(filtered_indices.keys())


In [None]:
# Reference numors by difficulty: "easy": 69958, "med": 56059, "hard": 46349
numor = 46349
path_to_npz = f'/Users/cadenmyers/billingelab/dev/sym_adapted_filts/experimental_data/npz_sept_numor_data/{numor}.npz'
data = np.load(path_to_npz)['data']
print(data[0].shape)

# Pixel grid of a single frame and its nominal (integer-midpoint) centre.
h, w = data[0].shape[:2]
cx, cy = w // 2, h // 2
x_grid, y_grid = np.meshgrid(np.arange(w), np.arange(h))

# The DPs sit slightly off-centre: after this shift the angles below are measured
# about the effective centre (cx + shift_x, cy + shift_y).
shift_x, shift_y = 0, -2
x_grid_shifted, y_grid_shifted = x_grid - shift_x, y_grid - shift_y
# Polar angle (radians) of every pixel about that effective centre.
DATA_THETA = np.arctan2(y_grid_shifted - cy, x_grid_shifted - cx)

# n-fold rotational symmetry and its angular period.
n_folds = 6
OFFSET_ADJUSTMENT = int(360 / n_folds)  # degrees
OFFSET_ADJUSTMENT_rad = np.deg2rad(OFFSET_ADJUSTMENT)
resolution = 10.8

dataset = '55.3 K 50 mW Negative Field Sweep'
numors = folder_movies[dataset]
print(numor)

In [None]:
# Quick check: re-run the clockwise grid search on the first 50 frames of `data`
# and compare against the ratchet offsets previously saved for this numor.
# NOTE(review): process_numors runs the search on mask_and_blur_images(data) with
# tolerance_reverse=5; here the frames are unmasked and tolerance_reverse=3, so
# differences from the saved offsets are expected — confirm this is intended.
subset_data = data[:50]
gt_offsets = np.load(f'/Users/cadenmyers/billingelab/dev/sym_adapted_filts/working_code/ratchet_model/data/ratchet_offsets_{numor}.npz')
gt_offset1, gt_offset2 = gt_offsets['offset1'], gt_offsets['offset2']
offsets1, offsets2 = grid_search_offset_cw(
    dps=subset_data,
    n_folds=6,
    tolerance_forward=20,
    tolerance_reverse=3,
)

In [None]:
# Overlay this cell's grid-search offset1 on the saved ratchet offset1 for the
# same frames (slice to len(offsets1) so the curves cover the same range).
plt.plot(offsets1, label='grid search offset1')
plt.plot(gt_offset1[:len(offsets1)], label='gt offset1')
# plt.plot(offsets2)
plt.legend()

## Grid search through list of numors

In [None]:
# TEST BELOW FOR LOOP ON INDIVIDUAL SWEEPS
# Directory holding the "<numor>.npz" movies; file names are appended directly,
# so the trailing slash is required.
path_to_npzs = '/Users/cadenmyers/billingelab/dev/sym_adapted_filts/experimental_data/npz_sept_numor_data/'

# Search window passed to grid_search_offset_cw/ccw (presumably degrees — confirm
# against filter_functions): how far the offset may advance / step back per frame.
tolerance_forward, tolerance_reverse = 20, 5
# Not used by any active code in this notebook (only by the commented-out grid_search_offset).
jump_threshold_offset1 = jump_threshold_offset2 = 20

def process_numors(numors, tolerance_reverse=tolerance_reverse, tolerance_forward=tolerance_forward, 
                   path_to_npzs=path_to_npzs, filtered_indices=filtered_indices):
    """Run the ratchet grid search on every numor in ``numors``.

    For each numor: load ``f"{path_to_npzs}{numor}.npz"`` (array key ``'data'``),
    drop the frames listed in ``filtered_indices[str(numor)]`` (if any), apply
    ``mask_and_blur_images``, then grid-search the two domain offsets —
    clockwise when the title in ``movie_to_title`` contains '+' or 'Positive',
    counter-clockwise otherwise.

    Args:
        numors: Iterable of numors (str or int).
        tolerance_reverse, tolerance_forward: Passed to the grid search.
        path_to_npzs: Directory prefix (with trailing slash) of the .npz movies.
        filtered_indices: {numor_str: [frame indices to delete]}.

    Returns:
        offsets_info (dict): {numor: {"offset1": ..., "offset2": ...}}, keyed by
        numor exactly as it was passed in.

    Raises:
        KeyError: if ``str(numor)`` has no entry in ``movie_to_title``.
    """
    offsets_info = {}
    print(f'Processing {len(numors)} numors...')

    for numor in numors:
        print(f'Processing numor: {numor}')
        key = str(numor)
        # Context manager closes the underlying .npz file once 'data' is read.
        with np.load(f"{path_to_npzs}{numor}.npz") as npz:
            data = npz['data']

        # Look up bad frames with the same string key used for the title, so int
        # numors are filtered too. Deduplicate so a repeated index cannot delete
        # two different frames; delete in one call.
        bad_frames = sorted(set(filtered_indices.get(key, [])), reverse=True)
        for index in bad_frames:
            print(f'    Deleting index {index} from {numor}')
        if bad_frames:
            data = np.delete(data, bad_frames, axis=0)

        dps = mask_and_blur_images(data)
        title = movie_to_title[key]
        if '+' in title or 'Positive' in title:
            print('   Positive field detected, cw search', title)
            grid_search = grid_search_offset_cw
        else:
            print('   Negative field detected, ccw search', title)
            grid_search = grid_search_offset_ccw
        offsets1, offsets2 = grid_search(dps, offset_step=0.5, tolerance_forward=tolerance_forward,
                                         tolerance_reverse=tolerance_reverse)
        offsets_info[numor] = {"offset1": offsets1, "offset2": offsets2}
    print('Processing complete.')
    return offsets_info  # Dictionary: {numor: {"offset1": array, "offset2": array}}


def plot_numor_offsets(offsets_info, tolerance_reverse=tolerance_reverse, tolerance_forward=tolerance_forward,
                       title=None):
    """Plot offset1 (solid, labelled) and offset2 (dotted) for every numor on the current figure.

    Args:
        offsets_info: {numor: {"offset1": array, "offset2": array}}, as returned by
            process_numors. Each numor gets one viridis colour, in insertion order.
        tolerance_reverse, tolerance_forward: Unused; kept for backward compatibility.
        title: Plot title. Defaults to the module-level ``dataset`` (the previous,
            implicit behaviour).
    """
    n_numors = len(offsets_info)
    print("Generating plots...")
    for i, (numor, offsets) in enumerate(offsets_info.items()):
        color = plt.cm.viridis(i / n_numors)
        plt.plot(offsets["offset1"], label=f'offset1 {numor}', color=color)
        plt.plot(offsets["offset2"], color=color, linestyle='dotted')
    plt.xlabel("Index")
    plt.title(f"{dataset}" if title is None else title)
    plt.ylabel("Offset (degrees)")
    plt.minorticks_on()
    plt.tick_params(direction='in')
    plt.legend()
    plt.grid(True)




## Process all datasets

In [None]:
import os

# np.savez / plt.savefig below write into ./data and fail if it does not exist.
os.makedirs('data', exist_ok=True)

for i, (dataset, numors) in enumerate(folder_movies.items()):
    # To skip datasets by index, e.g.:
    # if i in (0, 2):
    #     print(f'skipping datatset {dataset}')
    #     continue
    print(f'{i}:', dataset)
    print('-------------------------------------')
    print(numors)
    offset_info = process_numors(numors)
    for numor, offsets in offset_info.items():
        np.savez(f"data/ratchet_offsets_{numor}.npz",
                 offset1=offsets["offset1"],
                 offset2=offsets["offset2"])
        print(f'    data saved as: data/ratchet_offsets_{numor}.npz')
    plot_numor_offsets(offset_info)
    plt.title(f"{dataset}")
    plt.savefig(f'data/{dataset}.png')
    print(f'    fig saved as: data/{dataset}.png')
    plt.show()

## Plot and Save with good labels

In [None]:


import re

def extract_variable(title, dataset):
    """Return the swept quantity from a movie title, for use as a legend label.

    Args:
        title: Movie title such as '54261 55.3 K -23 mT 50 mW'.
        dataset: Dataset name; containing 'Field Sweep' or 'Temp Sweep' selects
            which quantity to extract.

    Returns:
        '<field> mT' for field sweeps (sign kept, e.g. '-23 mT'), '<T> K' for
        temperature sweeps (e.g. '55.3 K'). Returns the full title when the
        dataset type is not recognised or the value cannot be found.
    """
    if 'Field Sweep' in dataset:
        # Optional sign and optional decimal part; an integer-only pattern would
        # turn '-29.5 mT' into '5 mT'.
        field_match = re.search(r'([+-]?\d+(?:\.\d+)?)\s*mT', title)
        return f"{field_match.group(1)} mT" if field_match else title
    elif 'Temp Sweep' in dataset:
        # Accept integer temperatures ('56 K') as well as decimal ones ('56.0 K').
        temp_match = re.search(r'(\d+(?:\.\d+)?)\s*K', title)
        return f"{temp_match.group(1)} K" if temp_match else title
    else:
        return title  # Return full title if dataset type is not recognized

fontsize = 16


def plot_offsets(dataset, numors, movie_to_title,
                 path_to_offsets='/Users/cadenmyers/billingelab/dev/sym_adapted_filts/working_code/ratchet_model/data/'):
    """Plot saved ratchet offsets vs time for every numor in a dataset and save the figure.

    Reads ``f"{path_to_offsets}ratchet_offsets_{numor}.npz"`` (keys 'offset1',
    'offset2') for each numor; offset1 is drawn solid and labelled with the swept
    quantity from ``extract_variable``, offset2 dotted in the same colour. The
    figure is saved to ``data/<dataset with spaces -> underscores>_offsets.png``,
    shown, and closed.

    Args:
        dataset: Dataset name (used for the title, file name and label type).
        numors: Sequence of numor strings; colours follow list position.
        movie_to_title: {numor: title} mapping.
        path_to_offsets: Directory prefix (with trailing slash) of the saved offsets.
    """
    plt.figure(figsize=(10, 7))
    n_numors = len(numors)

    for i, numor in enumerate(numors):
        with np.load(f'{path_to_offsets}ratchet_offsets_{numor}.npz') as saved:
            offset1 = saved['offset1']
            offset2 = saved['offset2']

        # One colour per numor by list position (enumerate instead of
        # numors.index, which is O(n) per call and wrong for duplicates).
        color = plt.cm.viridis(i / n_numors)

        # Extract the variable label based on the dataset type
        label = extract_variable(movie_to_title[numor], dataset)
        # Assumes 10 s per frame (as elsewhere in this notebook): frame j -> t = 10*(j+1) s.
        time1 = (np.arange(offset1.shape[0]) + 1) * 10
        time2 = (np.arange(offset2.shape[0]) + 1) * 10

        plt.plot(time1, offset1, label=f'{label}', color=color)
        plt.plot(time2, offset2, color=color, linestyle='dotted')

    plt.xlabel("Time (s)", fontsize=fontsize)
    plt.ylabel("Offset (degrees)", fontsize=fontsize)
    plt.title(dataset, fontsize=fontsize)
    plt.minorticks_on()
    plt.tick_params(direction='in')
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    plt.legend(fontsize=fontsize)  # bbox_to_anchor=(1.05, 1), loc='upper left'
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f'data/{dataset.replace(" ", "_")}_offsets.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()


# Plot for each dataset
for dataset, numors in folder_movies.items():
    plot_offsets(dataset, numors, movie_to_title)
    print(f"Plot generated for {dataset}")


## Interactive plotting

In [None]:
numor = '74028'
path_to_offsets = '/Users/cadenmyers/billingelab/dev/sym_adapted_filts/working_code/ratchet_model/data/'
filename = f'ratchet_offsets_{numor}.npz'
with np.load(path_to_offsets + filename) as saved_offsets:
    offset1 = saved_offsets['offset1']
    offset2 = saved_offsets['offset2']

# `dps` was never defined at module level (it only exists inside process_numors),
# so build it here. Rebuild it the way process_numors does (drop bad frames, then
# mask/blur) so frame indices line up with the saved offsets — assuming the file
# above was written by the "Process all datasets" cell.
with np.load(f"{path_to_npzs}{numor}.npz") as raw_npz:
    raw_dps = raw_npz['data']
bad_frames = sorted(set(filtered_indices.get(numor, [])))
if bad_frames:
    raw_dps = np.delete(raw_dps, bad_frames, axis=0)
dps = mask_and_blur_images(raw_dps)


def interactive_plot(frame_idx):
    '''Plot both filter overlays and the normalised DP for frame ``frame_idx``.

    Uses module-level ``offset1``/``offset2`` (degrees) and ``dps``. Panel 1
    overlays the offset1 filter on the normalised DP, panel 2 the offset2 filter,
    panel 3 shows the normalised DP alone. The suptitle assumes 10 s per frame.
    '''
    n_folds = 6
    k = find_k_value(n_folds=n_folds)
    image1 = filter_function(k, np.deg2rad(offset1[frame_idx]), n_folds=n_folds)
    image2 = filter_function(k, np.deg2rad(offset2[frame_idx]), n_folds=n_folds)
    dp_norm = normalize_min_max(dps[frame_idx])

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))

    # first domain
    ax1.imshow(image1 + dp_norm)
    ax1.set_title(f"Frame {frame_idx+1}: first domain")
    ax1.axis("off")

    # second domain
    ax2.imshow(image2 + dp_norm)
    ax2.set_title("2nd domain")
    ax2.axis("off")

    # raw DP
    ax3.imshow(dp_norm)
    ax3.set_title(f"Frame {frame_idx+1}: Intensity Data")
    ax3.axis("off")

    # Implicit concatenation: backslash continuations inside one f-string would
    # embed the source indentation as runs of spaces in the title.
    # NOTE(review): tolerance_reverse is the current global, not necessarily the
    # value used when the loaded offsets were computed.
    fig.suptitle(
        f"numor={numor}, "
        f"Offset1={round(offset1[frame_idx], 2)}°, "
        f"Offset2={round(offset2[frame_idx], 2)}°, "
        f"Time={(frame_idx+1)*10}s, "
        f"Reverse tol={tolerance_reverse}"
    )

    plt.tight_layout()
    plt.show()


# Slider limited to frames that have a DP and both offsets.
n_frames = min(len(dps), len(offset1), len(offset2))
interact(
    interactive_plot,
    frame_idx=IntSlider(value=0, min=0, max=n_frames - 1, step=1, description='Frame Index')
);


## Plot comparison with hand-calc data

In [None]:
# load hand clicked
seed_path = "/Users/cadenmyers/billingelab/dev/sym_adapted_filts/working_code/seed_generating/seed_data/"
analysis_path = "/Users/cadenmyers/billingelab/dev/sym_adapted_filts/analysis-generated_data/"

numor = '69958'
# Hand-clicked seeds per annotator (cm, yc, dr, nc) and the GD ("oob") offsets.
cm_offset1_path = f"{seed_path}cm_{numor}_offset1_seed.npz"
cm_offset2_path = f"{seed_path}cm_{numor}_offset2_seed.npz"
yc_offset1_path = f"{seed_path}yc_{numor}_offset1_seed.npz"
yc_offset2_path = f"{seed_path}yc_{numor}_offset2_seed.npz"
dr_offset1_path = f"{seed_path}dr_{numor}_offset1_seed.npz"
dr_offset2_path = f"{seed_path}dr_{numor}_offset2_seed.npz"
nc_offset1_path = f"{seed_path}nc_{numor}_offset1_seed.npz"
nc_offset2_path = f"{seed_path}nc_{numor}_offset2_seed.npz"
# Both GD offsets live in the same file.
oob_offset1_path = f"{analysis_path}{numor}_offsets.npz"
oob_offset2_path = f"{analysis_path}{numor}_offsets.npz"


def _load_npz_array(path, key):
    """Return array ``key`` from the .npz at ``path``, closing the file afterwards."""
    with np.load(path) as npz:
        return npz[key]


# Load data
offset1cm = _load_npz_array(cm_offset1_path, 'data')
offset2cm = _load_npz_array(cm_offset2_path, 'data')
offset1yc = _load_npz_array(yc_offset1_path, 'data')
offset2yc = _load_npz_array(yc_offset2_path, 'data')
offset1dr = _load_npz_array(dr_offset1_path, 'data')
offset2dr = _load_npz_array(dr_offset2_path, 'data')
offset1nc = _load_npz_array(nc_offset1_path, 'data')
offset2nc = _load_npz_array(nc_offset2_path, 'data')
offset1oob = _load_npz_array(oob_offset1_path, 'offset1')
offset2oob = _load_npz_array(oob_offset2_path, 'offset2')

window_length = 5
polyorder = 4


def compute_smoothed_derivative(offset, window_length=window_length, polyorder=polyorder, frame_interval=10):
    '''Return d(offset)/dt after Savitzky-Golay smoothing.

    Args:
        offset: 1-D sequence of angles, one per frame (array or list).
        window_length, polyorder: Passed to scipy.signal.savgol_filter
            (window_length must exceed polyorder).
        frame_interval: Time between frames in seconds (default 10). Frame j
            (0-based) is placed at t = (j + 1) * frame_interval.

    Returns:
        np.ndarray of the same length as ``offset``: angular velocity in offset
        units per second.

    NOTE(review): with the module defaults (polyorder = window_length - 1) the
    degree-4 polynomial passes through all 5 window points, so savgol_filter
    returns the input unchanged and no smoothing happens. A lower polyorder or a
    longer window is needed for actual smoothing.
    '''
    offset = np.asarray(offset, dtype=float)
    smoothed_angle = savgol_filter(offset, window_length=window_length, polyorder=polyorder)
    time = (np.arange(offset.shape[0]) + 1) * frame_interval
    return np.gradient(smoothed_angle, time)

colors = {
    "cm": "blue",
    "yc": "red",
    "dr": "green",
    "nc": "purple",
    "oob": 'orange',
    'ratchet': 'hotpink'
}

# plot over different tolerances
ratchet_data_path = '/Users/cadenmyers/billingelab/dev/sym_adapted_filts/working_code/ratchet_model/data/'

# NOTE(review): 5 appears twice, so the tol=5 curve is drawn twice in two colours;
# confirm the intended list (e.g. [8, 7, 6, 5, 4, 3]?).
tols = [5, 7, 6, 5, 4, 3]
viri = [plt.cm.viridis(j / len(tols)) for j in range(len(tols))]
plt.figure(figsize=(7,7))
for i, tol in enumerate(tols):
    with np.load(f'{ratchet_data_path}reverse_tol_{tol}_{numor}_offsets.npz') as tol_offsets:
        offset1r = tol_offsets['offset1']
        offset2r = tol_offsets['offset2']
    plt.plot((offset1r - offset1r[0]), color=viri[i], label=f'offset1 ratchet tol={tol}')
    # velo2 = compute_smoothed_derivative(offset2r)
    # velo = compute_smoothed_derivative(offset1r)
    # plt.plot(velo, label=f'velo {numor}')
    # plt.plot(velo2, linestyle=':')
    plt.plot((offset2r - offset2r[0]), color=viri[i], linestyle=':')

# Ratchet offsets for *this* numor. Previously this plotted the offset1/offset2
# globals left over from the interactive cell, which belong to a different numor.
with np.load(f'{ratchet_data_path}ratchet_offsets_{numor}.npz') as updated_offsets:
    offset1_updated = updated_offsets['offset1']
    offset2_updated = updated_offsets['offset2']
plt.plot(offset1_updated - offset1_updated[0], label='offset1 updated', color=colors['ratchet'])
plt.plot(offset2_updated - offset2_updated[0], color=colors['ratchet'], linestyle=':')
# plt.plot(offset1oob - offset1oob[0], color=colors['oob'], label='offset1 GD')
# # plt.plot(offset2oob - offset2oob[0], color=colors['oob'], linestyle=':')
# plt.plot(offset1cm - offset1cm[0], label='offset1 cm', color=colors['cm'])
# # plt.plot(offset2cm - offset2cm[0], color=colors['cm'], linestyle=':')
# plt.plot(offset1yc - offset1yc[0], label='offset1 yc', color=colors["yc"])
# # plt.plot(offset2yc - offset2yc[0], color=colors["yc"], linestyle=':')
# plt.plot(offset1dr - offset1dr[0], label='offset1 dr', color=colors["dr"])
# # plt.plot(offset2dr - offset2dr[0], color=colors["dr"], linestyle=':')
# plt.plot(offset1nc-offset1nc[0], label=f'offset1 nc', color=colors["nc"])
# # plt.plot(offset2nc-offset2nc[0], color=colors["nc"], linestyle=':')

plt.grid(True)
plt.ylabel('offset')
# plt.ylabel('angular velo deg/sec')
plt.xlabel('index')
plt.legend()
plt.title(f'{numor}')
plt.show()


# # plot over different numors
# tol = 5
# numor_list = ['46349', '56059', '54261', '69958']
# plt.figure(figsize=(15,5))
# for i, numor in enumerate(numor_list):
#     viri = [plt.cm.viridis(i / len(numor_list)) for i in range(len(numor_list))]
#     offset1r = np.load(f'/Users/cadenmyers/billingelab/dev/sym_adapted_filts/working_code/ratchet_model/data/reverse_tol_{tol}_{numor}_offsets.npz')['offset1']
#     offset2r = np.load(f'/Users/cadenmyers/billingelab/dev/sym_adapted_filts/working_code/ratchet_model/data/reverse_tol_{tol}_{numor}_offsets.npz')['offset2']
#     # plt.plot(-(offset1r - offset1r[0]), color=viri[i], label=f'offset1 ratchet tol={tol}')
#     # velo2 = compute_smoothed_derivative(offset2r)
#     # velo = compute_smoothed_derivative(offset1r)
#     # plt.plot(velo, color=viri[i], label=f'velo {numor}')
#     # plt.axhline(0, color='black')
#     # plt.plot(velo2, color=viri[i]linestyle=':')
#     # plt.plot(-(offset2r - offset2r[0]), color=viri[i], linestyle=':')

# plt.grid(True)
# plt.ylabel('angular velo deg/sec')
# plt.xlabel('index')
# plt.legend()
# plt.title(f'{numor_list}, window length={window_length}, poly order={polyorder}')
# plt.show()



# --MISC FUNCTION BELOW--

## Use GD to tune ratchet model

In [None]:
import torch
MAX_ITER_OFFSET1 = 50
MAX_ITER_OFFSET2 = 50
LR = 1e-2
GD_numor = '44192'
offset1_init = torch.tensor(0., requires_grad=True)
offset2_init = torch.tensor(0., requires_grad=True)

# Filter sharpness used by filter_function_torch in the GD loop. Defined here so
# this cell does not depend on a `k` global that no other cell creates.
k = find_k_value(n_folds=n_folds)

# DPs for GD_numor. Previously this loaded `numor`, i.e. whatever an earlier cell
# left behind, so images and seeds could come from different movies.
with np.load(f'/Users/cadenmyers/billingelab/dev/sym_adapted_filts/experimental_data/npz_sept_numor_data/{GD_numor}.npz') as npz:
    dps_seed = npz['data']
# Drop bad frames as process_numors does, so frame i matches seed offset i.
bad_frames = sorted(set(filtered_indices.get(GD_numor, [])))
if bad_frames:
    dps_seed = np.delete(dps_seed, bad_frames, axis=0)
# NOTE(review): the ratchet seeds were computed on mask_and_blur_images(data);
# GD here refines on unmasked frames — confirm this is intended.

# Seeds use the same file naming as the save cell (data/ratchet_offsets_{numor}.npz).
with np.load(f'/Users/cadenmyers/billingelab/dev/sym_adapted_filts/working_code/ratchet_model/data/ratchet_offsets_{GD_numor}.npz') as seeds:
    offset1_seed_array = seeds['offset1']
    offset2_seed_array = seeds['offset2']
i = 0
offset1_seed = offset1_seed_array[i]
offset2_seed = offset2_seed_array[i]

def filter_function_torch(k, offset, n_folds=n_folds):
    """Differentiable (torch) n-fold angular filter over the DATA_THETA pixel grid.

    Computes exp(k * log(cos(n_folds / 2 * (DATA_THETA + offset)) ** 2)), i.e.
    cos(...)**2 raised to the power k. Gradients flow back to ``offset``.

    Args:
        k: Filter sharpness exponent.
        offset: Rotation offset in radians (scalar tensor).
        n_folds: Rotational symmetry order.

    Returns:
        Tensor with the same shape as DATA_THETA.
    """
    theta = torch.tensor(DATA_THETA)
    phase = n_folds / 2 * (theta + offset)
    log_cos_sq = torch.log(torch.cos(phase) ** 2)
    return torch.exp(k * log_cos_sq)

def _adam_refine_offset(intensity_t, seed_deg, n_iter, k):
    """Refine one offset with Adam by maximising its filter overlap; degrees in, degrees out."""
    offset = torch.tensor(np.deg2rad(seed_deg), dtype=torch.float32, requires_grad=True)
    opt = torch.optim.Adam([offset], lr=LR)
    for _ in range(n_iter):
        # Negative overlap between the intensity and the filter at `offset`.
        loss = -(intensity_t * filter_function_torch(k, offset)).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return np.rad2deg(offset.detach().item())


def gradient_descent_optimize_seeded_offset(intensity, offset1_seed, offset2_seed, k=None):
    '''
    Takes in offset seeds from ratchet model and refines the offset.

    Each offset is refined independently with Adam (learning rate LR) for
    MAX_ITER_OFFSET1 / MAX_ITER_OFFSET2 steps against the same ``intensity``.
    NOTE(review): nothing keeps the two offsets apart, so seeds close together
    could converge to the same peak.

    Args:
        intensity: 2-D intensity image (array-like), same shape as DATA_THETA.
        offset1_seed, offset2_seed: Seeds in degrees.
        k: Filter sharpness; defaults to find_k_value(n_folds=n_folds)
            (previously read from an undefined global ``k``).

    all offset inputs [=] degrees
    all offset outputs [=] degrees
    '''
    if k is None:
        k = find_k_value(n_folds=n_folds)
    # Build the intensity tensor once rather than on every iteration.
    intensity_t = torch.tensor(intensity, dtype=torch.float32)
    offset1 = _adam_refine_offset(intensity_t, offset1_seed, MAX_ITER_OFFSET1, k)
    offset2 = _adam_refine_offset(intensity_t, offset2_seed, MAX_ITER_OFFSET2, k)
    return offset1, offset2

# Refine every frame's ratchet seed with gradient descent.
# Fail early with a clear message instead of an IndexError part-way through.
if len(dps_seed) > min(len(offset1_seed_array), len(offset2_seed_array)):
    raise ValueError(
        f"{len(dps_seed)} frames but only {len(offset1_seed_array)}/{len(offset2_seed_array)} "
        "seed offsets; frames and seeds are misaligned"
    )

refined_offset1, refined_offset2 = [], []

for dp, seed1, seed2 in zip(dps_seed, offset1_seed_array, offset2_seed_array):
    off1, off2 = gradient_descent_optimize_seeded_offset(dp, seed1, seed2)
    refined_offset1.append(off1)
    refined_offset2.append(off2)

In [None]:
# def grid_search_offset(
#     dps,
#     tolerance_forward,
#     tolerance_reverse,
#     jump_threshold_offset1,
#     jump_threshold_offset2,
#     offset_step=0.5,
# ):
#     '''
#     positive angles = forward rotation
#     negative angles = reverse rotation
#     returns offset1 and offset2 in degrees
#     '''
#     optimal_offsets1 = []
#     optimal_offsets2 = []
#     prev_offset1 = 0
#     prev_offset2 = 0

#     for index, dp in enumerate(dps):
#         loss_list1 = []
#         offset_list_deg1 = []
#         # find initial offset
#         if index == 0:
#             offset1_angles_deg = np.arange(-360/(2*n_folds), 360/(2*n_folds), offset_step)
#         else:
#             offset1_angles_deg = np.arange(prev_offset1 -  tolerance_reverse, prev_offset1 + tolerance_forward, offset_step)

#         # Compute loss function
#         for offset1 in offset1_angles_deg:
#             offset_rad = np.deg2rad(offset1)  # convert to radians
#             filt = filter_function(k, offset_rad)
#             loss = -(dp * filt).sum()
#             loss_list1.append(loss)
#             offset_list_deg1.append(offset1)
#         min_loss_idx = loss_list1.index(min(loss_list1))
#         best_offset1 = offset_list_deg1[min_loss_idx]

#         # Correct overrotations of 60 degrees for offset1
#         # delta1 = abs(best_offset1 - prev_offset1)
#         # if delta1 > jump_threshold_offset1 and index != 0:
#         #     best_offset1 -= 360/n_folds # - for cw rotating data, + for ccw rotating data
#         #     print(f"{index}: Threshold exceeded for offset1. Using closest local max: {best_offset1}")

#         # filter out first signal
#         filt1 = filter_function(k, np.deg2rad(best_offset1))
#         filt1 = np.where(filt1 > 0.01, 0, 1)
#         dp_filtered = dp * filt1
#         # Use second derivative to determine the number of domains
#         offset_range, loss_values, first_derivative, second_derivative = compute_loss_near_offset1(dp, np.deg2rad(best_offset1))
#         approx_offset_values = get_approx_offset_values(offset_range, second_derivative)
#         number_of_domains = len(approx_offset_values)
#         if number_of_domains == 1:
#             # run loop again, but using prev_offset2
#             # print(f'{index}: 1 domain')
#             loss_list2 = []
#             offset_list_deg2 = []
#             offset2_angles_deg = np.arange(prev_offset2 - 30, prev_offset2 + 30, offset_step)
#             for offset2 in offset2_angles_deg:
#                 offset_rad = np.deg2rad(offset2)
#                 filt2 = filter_function(k, offset_rad)
#                 loss = -(dp * filt2).sum()
#                 loss_list2.append(loss)
#                 offset_list_deg2.append(offset2)
#             min_loss_idx = loss_list2.index(min(loss_list2))
#             best_offset2 = offset_list_deg2[min_loss_idx]
#             # best_offset2 = prev_offset2 + best_offset1 % 60  # mod60 term is added to account for the
#         else:                                               # movement of offset2 from previous to current frame
#             # Second domain search
#             loss_list2 = []
#             offset_list_deg2 = []
#             offset2_angles_deg = np.arange(prev_offset2 - tolerance_reverse, prev_offset2 + tolerance_forward, offset_step)
#             for offset2 in offset2_angles_deg:
#                 offset_rad = np.deg2rad(offset2)
#                 filt2 = filter_function(k, offset_rad)
#                 loss = -(dp_filtered * filt2).sum()
#                 loss_list2.append(loss)
#                 offset_list_deg2.append(offset2)
#             min_loss_idx = loss_list2.index(min(loss_list2))
#             best_offset2 = offset_list_deg2[min_loss_idx]

#             # Correct overrotations of 60 degrees for offset2
#             # delta2 = abs(best_offset2) - abs(prev_offset2)
#             # if delta2 > jump_threshold_offset2:
#             #     best_offset2 -= 360/n_folds # - for cw rotating data, + for ccw rotating data
#             #     print(f"{index}: Threshold2 exceeded for offset2. Adjusting by {360/n_folds} degrees: {best_offset2}")
#             #     print(delta2, '>', jump_threshold_offset2)

#         # plotting
#         # norm_int = normalize_min_max(dp)
#         # filt1 = filter_function(k, np.deg2rad(best_offset1))
#         # filt2 = filter_function(k, np.deg2rad(best_offset2))
#         # signal = norm_int + filt1 + 2*filt2
#         # fig, axs = plt.subplots(1, 3, figsize=(12, 6))
#         # axs[0].imshow(signal, cmap='viridis', vmin=0, vmax=2)
#         # axs[1].imshow(norm_int, cmap='viridis')
#         # axs[2].imshow(dp_filtered)
#         # plt.show()
#         # print(f'{index*10}s: ', best_offset1, best_offset2)

#         optimal_offsets1.append(best_offset1)
#         optimal_offsets2.append(best_offset2)
#         prev_offset1 = best_offset1
#         prev_offset2 = best_offset2
#     return optimal_offsets1, optimal_offsets2