In [None]:
import math
import os
import sys

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy import optimize as opt
from scipy import stats
import pyPyrTools as ppt
%matplotlib inline

# make the local `sfp` package (one directory up) importable
sys.path.append('..')
import sfp

In [None]:
# Global seaborn look for every figure in this notebook.
sns.set_style(style='whitegrid')

In [None]:
# Location of the data tree. Set the SFP_DATA_DIR environment variable to point
# elsewhere instead of editing a hardcoded absolute path.
DATA_DIR = os.environ.get('SFP_DATA_DIR', '/home/billbrod/Data/spatial_frequency_preferences')
SUBJECT, SESSION = 'sub-wlsubj045', 'ses-02'
summary_path = os.path.join(DATA_DIR, 'derivatives', 'first_level_analysis', 'stim_class', 'posterior',
                            SUBJECT, SESSION, '%s_%s_task-sfp_v1_e1-12_summary.csv' % (SUBJECT, SESSION))
df = pd.read_csv(summary_path)

In [None]:
# The five best-fit V1 voxels (R2 > 55), one row per voxel, highest R2 first.
(df[(df.varea == 1) & (df.R2 > 55)]
 .drop_duplicates('voxel')
 .sort_values('R2', ascending=False)
 .head(5)
 [['voxel', 'R2', 'varea', 'hemi', 'angle', 'eccen']])

In [None]:
# Pick a V1 voxel with a good R2 (voxel 53, chosen from the table above).
# .copy() makes voxel_df an independent frame; later cells add columns to it,
# which on a bare boolean-mask slice of `df` raises SettingWithCopyWarning.
voxel_df = df[df.voxel == 53].copy()
voxel_df.head()

In [None]:
def scatter_sizes(x, y, s, plot_color=False, cmap=None, size_scale=80, **kwargs):
    """Scatter plot whose marker areas encode a third variable; usable with FacetGrid.map.

    Parameters
    ----------
    x, y : array-like
        Point coordinates.
    s : array-like supporting arithmetic (numpy array / pandas Series)
        Value per point; marker area is ``s * size_scale``. Presumably already
        scaled to [0, 1] (e.g. a min-max normalised response) -- TODO confirm.
    plot_color : bool
        If True, also colour the points by ``s`` using ``cmap``. Any ``color``
        kwarg is dropped so it does not override the colormap.
    cmap : str or Colormap, optional
        Colormap used when ``plot_color`` is True; defaults to 'Blues'.
    size_scale : float
        Multiplier from ``s`` to marker area (points**2). Default 80.
    **kwargs
        Passed through to ``plt.scatter``.
    """
    if plot_color:
        # Default of None: calling this directly (no `color` kwarg) must not raise KeyError.
        kwargs.pop('color', None)
        if cmap is None:
            cmap = 'Blues'
        plt.scatter(x, y, s=s*size_scale, c=s, cmap=cmap, **kwargs)
    else:
        plt.scatter(x, y, s=s*size_scale, **kwargs)

# Min-max normalise this voxel's response amplitudes to [0, 1]; used as marker size.
amp = voxel_df['amplitude_estimate_median']
# assign() returns a new frame, so this never writes into a slice of `df`
# (direct column assignment on such a slice can raise SettingWithCopyWarning).
voxel_df = voxel_df.assign(normalized_resp=(amp - amp.min()) / (amp.max() - amp.min()))
g = sns.FacetGrid(voxel_df, size=5, aspect=1, hue='stimulus_superclass')
g.map(scatter_sizes, 'local_w_x', 'local_w_y', 'normalized_resp', plot_color=False)
g.add_legend()
scatter_ax = plt.gca()
scatter_ax.set_aspect('equal')

In [None]:
# Same responses in (local_w_r, local_w_a) coordinates on base-2 symlog axes
# (linear within +/- 2**-3 of zero).
g = sns.FacetGrid(voxel_df, hue='stimulus_superclass', size=5, aspect=1)
g.map(scatter_sizes, 'local_w_r', 'local_w_a', 'normalized_resp')
g.add_legend()
scatter_ax = plt.gca()
try:
    # matplotlib < 3.5 spelling (deprecated in 3.3). Tried first because older
    # versions silently ignore the newer `base`/`linthresh` keywords.
    scatter_ax.set_xscale('symlog', basex=2, linthreshx=2**(-3))
    scatter_ax.set_yscale('symlog', basey=2, linthreshy=2**(-3))
except TypeError:
    # matplotlib >= 3.5 rejects basex/linthreshx; use the current keywords.
    scatter_ax.set_xscale('symlog', base=2, linthresh=2**(-3))
    scatter_ax.set_yscale('symlog', base=2, linthresh=2**(-3))

In [None]:
def loggaussian_donut(r_th, amplitude, major_axis, minor_axis, major_axis_sigma, minor_axis_sigma, rotation_angle):
    """numpy log-Gaussian "donut": a Gaussian in log2(r) whose centre and width vary elliptically with angle.

    Parameters
    ----------
    r_th : tuple (r, th)
        Frequency magnitudes ``r`` (log2 is taken here) and directions ``th``
        in radians; broadcastable numpy arrays or scalars. Neither is modified.
    amplitude : float
        Peak response.
    major_axis, minor_axis : float
        Ellipse axes that set the Gaussian centre as a function of angle.
    major_axis_sigma, minor_axis_sigma : float
        Gaussian widths along the major / minor axis directions.
    rotation_angle : float
        Added to ``th`` (radians) before the ellipse is evaluated.

    Returns
    -------
    numpy array (or scalar) of responses with the broadcast shape of ``r`` and ``th``.

    Notes
    -----
    Takes the (r, th) pair as one positional argument, like the original
    Python-2 tuple-parameter signature, so existing calls keep working.
    NOTE(review): here the Gaussian is centred on the ellipse radius ``ctr``
    itself (axes in log2 units), whereas ``LogGaussianDonut.evaluate`` centres
    it on ``log2(ctr)`` (axes in linear units) -- confirm which is intended.
    """
    r, th = r_th
    log_r = np.log2(r)
    # Out-of-place: `th += ...` would overwrite the caller's array (and fails for int arrays).
    th = th + rotation_angle

    # transform angles based on ellipse axes
    transformed_theta = np.arctan2(major_axis*np.sin(th), minor_axis*np.cos(th))

    # Gaussian center as function of angle
    ctr = np.sqrt((major_axis*np.cos(transformed_theta))**2 + (minor_axis*np.sin(transformed_theta))**2)

    # rotational sigma
    sigma = np.sqrt((major_axis_sigma*np.cos(transformed_theta))**2 + (minor_axis_sigma*np.sin(transformed_theta))**2)

    return amplitude*np.exp(-(log_r-ctr)**2 / (2*sigma**2))

def torch_meshgrid(x, y=None):
    """Build 2d coordinate grids from 1d coordinate vectors (numpy 'xy' layout).

    Adapted from https://github.com/pytorch/pytorch/issues/7580

    Parameters
    ----------
    x : 1d array-like or torch.Tensor of length m
    y : 1d array-like or torch.Tensor of length n, optional
        Defaults to ``x``.

    Returns
    -------
    grid_x, grid_y : torch.Tensor, each of shape (n, m)
        ``grid_x[i, j] == x[j]`` and ``grid_y[i, j] == y[i]``, like ``np.meshgrid(x, y)``.
        Both are expanded views of private copies of the inputs.
    """
    def _copy_to_tensor(v):
        # torch.tensor(<tensor>) emits a UserWarning recommending clone().detach();
        # both branches produce a fresh, graph-free copy.
        return v.clone().detach() if torch.is_tensor(v) else torch.tensor(v)

    if y is None:
        y = x
    x = _copy_to_tensor(x)
    y = _copy_to_tensor(y)
    m, n = x.size(0), y.size(0)
    grid_x = x[None].expand(n, m)
    grid_y = y[:, None].expand(n, m)
    return grid_x, grid_y

class LogGaussianDonut(torch.nn.Module):
    """Log-Gaussian "donut" response model over 2d frequency space, in pytorch.

    The response at frequency magnitude ``r`` and direction ``theta`` is
    ``amplitude * exp(-(log2(r) - log2(ctr))**2 / (2*sigma**2))``. Both ``ctr``
    and ``sigma`` depend on direction:

    - ``ctr`` is elliptically interpolated between ``major_axis`` and ``minor_axis``.
    - ``sigma`` is interpolated the same way between ``major_axis_sigma`` and
      ``minor_axis_sigma``.

    ``rotation_angle`` is added to ``theta`` first and is kept clamped to [0, pi].
    All six constructor arguments become learnable scalar ``torch.nn.Parameter``s.
    """
    def __init__(self, amplitude, major_axis, minor_axis, major_axis_sigma, minor_axis_sigma, rotation_angle):
        super(LogGaussianDonut, self).__init__()
        self.amplitude = torch.nn.parameter.Parameter(torch.tensor(amplitude, dtype=torch.float))
        self.major_axis = torch.nn.parameter.Parameter(torch.tensor(major_axis, dtype=torch.float))
        self.minor_axis = torch.nn.parameter.Parameter(torch.tensor(minor_axis, dtype=torch.float))
        self.major_axis_sigma = torch.nn.parameter.Parameter(torch.tensor(major_axis_sigma, dtype=torch.float))
        self.minor_axis_sigma = torch.nn.parameter.Parameter(torch.tensor(minor_axis_sigma, dtype=torch.float))
        self.rotation_angle = torch.nn.parameter.Parameter(torch.clamp(torch.tensor(rotation_angle, dtype=torch.float), 0, np.pi))

    def create_image(self, extent=None, n_samps=1001):
        """Evaluate the model on an ``n_samps`` x ``n_samps`` Cartesian frequency grid.

        Both x and y run linearly over ``extent`` (default (-10, 10)). Row i holds
        ``y = linspace(*extent)[i]`` and column j holds ``x[j]``, so row 0 is the
        *lowest* y: display with ``imshow(..., origin='lower')``.
        """
        if extent is None:
            extent = (-10, 10)
        x = torch.linspace(extent[0], extent[1], n_samps)
        x, y = torch_meshgrid(x)
        r = torch.sqrt(torch.pow(x, 2) + torch.pow(y, 2))
        th = torch.atan2(y, x)
        return self.evaluate(r, th)

    def evaluate(self, r, theta):
        """Model response at magnitudes ``r`` (linear; log2 is taken here) and directions ``theta`` (radians).

        Non-tensor inputs are converted to float tensors. The inputs are not
        modified. As a side effect, the per-point centre and width are stored
        on ``self.ctr`` and ``self.sigma``.
        """
        if not torch.is_tensor(r):
            r = torch.tensor(r, dtype=torch.float)
        if not torch.is_tensor(theta):
            theta = torch.tensor(theta, dtype=torch.float)
        log_r = torch.log2(r)
        # Out-of-place: `theta += ...` would overwrite the caller's tensor (e.g. the
        # training inputs) on every call, shifting them by rotation_angle each step.
        theta = theta + self.rotation_angle
        # transform angles based on ellipse axes
        theta = torch.atan2(self.major_axis*torch.sin(theta), self.minor_axis*torch.cos(theta))
        # Gaussian center as function of angle
        self.ctr = torch.sqrt(torch.pow(self.major_axis*torch.cos(theta), 2) + torch.pow(self.minor_axis*torch.sin(theta), 2))
        # rotational sigma
        self.sigma = torch.sqrt(torch.pow(self.major_axis_sigma*torch.cos(theta), 2) + torch.pow(self.minor_axis_sigma*torch.sin(theta), 2))
        return self.amplitude*torch.exp(-(log_r-torch.log2(self.ctr))**2 / (2*self.sigma**2))

    def forward(self, x):
        """Evaluate the model on ``x = (r, theta)``; see ``evaluate``.

        Before evaluating, projects ``rotation_angle`` back into [0, pi].
        """
        r, th = x
        # Clamp in place on the existing Parameter. Wrapping the clamped value in a
        # *new* Parameter would leave any optimizer holding the old object, so the
        # rotation angle would never be updated.
        with torch.no_grad():
            self.rotation_angle.clamp_(0, np.pi)
        return self.evaluate(r, th)

In [None]:
# Visualise an example donut on a Cartesian frequency grid.
donut = LogGaussianDonut(20, 3, 1, .2, .5, 0)
x = np.linspace(-5, 5, 1001)
xgrid, ygrid = np.meshgrid(x, x)
# The model's Parameters require grad, so detach() the image from the autograd graph
# before plotting. create_image puts y = extent[0] in row 0, so origin='lower' is
# needed for the picture's y axis to agree with `extent`.
plt.imshow(donut.create_image((x.min(), x.max()), len(x)).detach().numpy(),
           extent=(x.min(), x.max(), x.min(), x.max()), origin='lower', cmap='Reds')
plt.colorbar()

In [None]:
# 1d profiles of the image above along the lines x = a*y (a = 0 is the vertical
# axis), plotted against signed distance from the origin.
img = donut.create_image((x.min(), x.max()), len(x)).detach().numpy()
xgrid, ygrid = np.meshgrid(x, x)
# Distance from the origin, signed by y. On a line x = a*y with a >= 0, sign(y)
# separates the two halves of the line. Signing by x fails on the a = 0 line,
# where x == 0 everywhere and the profile would fold onto itself.
R = np.hypot(xgrid, ygrid)
R[ygrid < 0] *= -1
fig, axes = plt.subplots(2, 2, figsize=(10, 5))
for ax, a in zip(axes.flatten(), [0, 1, 2, 3]):
    # isclose rather than ==: a*y matches grid x values only up to rounding error
    idx = np.where(np.isclose(xgrid, a*ygrid))
    ax.plot(R[idx], img[idx])
    ax.set(xlim=(-8, 8), title='x = %d*y' % a)

In [None]:
# Overlay the initial model guess (contours) on this voxel's responses, then set up the fit.
g = sns.FacetGrid(voxel_df, size=5, aspect=1)
g.map(scatter_sizes, 'local_w_x', 'local_w_y', 'normalized_resp')
scatter_ax = plt.gca()
scatter_ax.set_aspect('equal')

donut = LogGaussianDonut(voxel_df.amplitude_estimate_median.max(), 1, 1, .2, .2, 0)
x = np.linspace(-3, 3, 101)
# detach() the image from the autograd graph (the model's Parameters require grad)
c = scatter_ax.contour(x, x, donut.create_image((x.min(), x.max()), len(x)).detach().numpy(), cmap="Reds")
g.fig.colorbar(c, shrink=.5)
scatter_ax.set(xlim=(-4.5, 4.5), ylim=(-.5, 5))
scatter_ax.set_aspect('equal')

# Model inputs (columns local_sf_magnitude / local_sf_direction) and target amplitudes.
x1 = torch.tensor(voxel_df.local_sf_magnitude.values, dtype=torch.float)
x2 = torch.tensor(voxel_df.local_sf_direction.values, dtype=torch.float)
y = torch.tensor(voxel_df.amplitude_estimate_median.values, dtype=torch.float)

# Summed (not averaged) squared error; reduction='sum' replaces the deprecated
# positional `size_average=False`.
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(donut.parameters(), lr=1e-3)

In [None]:
# Model parameters before fitting, as (name, Parameter) pairs.
[(name, param) for name, param in donut.named_parameters()]

In [None]:
# Fit with Adam. Stop when the relative change in loss between steps drops below
# `thresh`, or after `n_steps` steps.
loss_prev = 0.01
n_steps = 2000
thresh = .00001
for t in range(n_steps):
    y_pred = donut((x1, x2))
    loss = loss_fn(y_pred, y)
    # Keep a plain float for bookkeeping: storing the tensor in loss_prev would keep
    # this step's autograd graph alive into the next step.
    loss_val = loss.item()
    if t % 100 == 0:
        print(t, loss_val)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if abs((loss_val - loss_prev) / loss_prev) < thresh:
        break
    loss_prev = loss_val
print("Final loss: %02f" % loss_val)

In [None]:
# Model parameters after fitting, as (name, Parameter) pairs.
[(name, param) for name, param in donut.named_parameters()]

In [None]:
# Data vs. model prediction vs. residual. Colour shows the signed value; marker area
# shows its magnitude, min-max scaled to [0, 1] within each panel.
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
y_np = y.detach().numpy()
y_pred_np = y_pred.detach().numpy()
vals = [y_np, y_pred_np, y_np - y_pred_np]
titles = ['ground truth', 'predicted', 'ground truth - predicted']
for ax, v, title in zip(axes.flatten(), vals, titles):
    mag = abs(v)
    # Scale the magnitude, not the signed value: (v - min|v|) made negative values
    # get larger markers than positive values of the same size.
    scaled_mag = (mag - mag.min()) / (mag.max() - mag.min())
    pts = ax.scatter(voxel_df['local_w_x'], voxel_df['local_w_y'], s=scaled_mag*50, c=v, cmap='RdBu_r',
                     norm=sfp.plotting.MidpointNormalize(midpoint=0))
    ax.set_aspect('equal')
    plt.colorbar(pts, ax=ax, shrink=.6)
    ax.set(xlim=(-4.5, 4.5), ylim=(-.5, 5))
    ax.set_title(title)

In [None]:
# The fitted model on a larger grid, with a diverging colormap centred on zero.
x = np.linspace(-8, 8, 1001)
# detach() the image from the autograd graph. create_image puts y = extent[0] in
# row 0, so origin='lower' keeps the image upright. This matters once
# rotation_angle != 0 breaks the up/down symmetry.
plt.imshow(donut.create_image((x.min(), x.max()), len(x)).detach().numpy(),
           extent=(x.min(), x.max(), x.min(), x.max()), origin='lower', cmap='RdBu_r',
           norm=sfp.plotting.MidpointNormalize(midpoint=0))
ax = plt.gca()
ax.set(xlim=(-4.5, 4.5), ylim=(-.5, 5))
plt.colorbar(shrink=.7)