In [None]:
#Basic Packages
import sys
import torch
import numpy as np

# import time
import matplotlib.pyplot as plt 

# Image wranglers
import imageio
from collections import OrderedDict
from PIL import Image

#Tocohpy functions 
import Optical_Components as comp
import Optical_Propagators as prop
import Helper_Functions as HF

## GPU info for pytorch ##
# Index of the CUDA device to use when a GPU is present.
gpu_no = 0
# Only select CUDA when this machine actually exposes a CUDA device; a hard-coded
# True makes every later `.to(device)` call fail on CPU-only machines.
# Set use_cuda = False by hand to force CPU execution.
use_cuda = torch.cuda.is_available()
# All tensors and optical components in the notebook are placed on this device.
device = torch.device("cuda:"+str(gpu_no) if use_cuda else "cpu")


# Use matplotlib's built-in dark theme for every figure drawn after this point.
plt.style.use('dark_background')
# Notebook banner: this is a bare string expression (not a comment), left exactly as written.
""" This Notebook is an example notebook for Optical NN learning with Tocohpy - Lionel Fiske  """
#Last update: 9/15/2021

Tocohpy is a library of functions designed to prototype the design of optical systems which use learnable phase and absorption gratings. This notebook is a series of example problems which demonstrate how to use the tocohpy functions to build optical systems. 

In [None]:
## In this example a 4f system will be modeled.
## This cell defines the simulation grid, the optical constants and a batch of
## two input fields built from sample images.

# units
mm = 1e-3 #meters
nm = 1e-9 #meters


### Setting Problem Parameters ###

s = (1024, 1024)     # size of simulation in pixels (rows, columns)
sim_size = 5*mm      # physical side length of the wavefront / simulation box

#coordinates
x_c = torch.linspace(-sim_size/2, sim_size/2, s[0])   #X coordinates
y_c = torch.linspace(-sim_size/2, sim_size/2, s[1])   #Y coordinates
[X,Y] = torch.meshgrid(x_c,y_c)    #Generate coordinate mesh, shape (s[0], s[1])
R2 = X**2 + Y**2                   #Squared radial coordinate
dx = x_c[1] - x_c[0]               #grid spacing


# Wavelength
lamb = 500.0*nm

#Lens focal length
focal_length = 25*mm

#Define the input images

#Each image is resampled to half the grid size and zero-padded back to the full grid
s2 = (int(s[0]/2), int(s[1]/2))
field_in = torch.zeros(2, 1, s[0], s[1], device=device, dtype=torch.cfloat)


def _load_centered_image(uri, image_size, canvas_size, channel=None, noise_amplitude=0):
    """Load an image, resample it and zero-pad it to the centre of a larger canvas.

    Parameters
    ----------
    uri : str
        Location handed to ``imageio.imread``.
    image_size : tuple of int
        (rows, columns) of the resampled image.
    canvas_size : tuple of int
        (rows, columns) of the padded result; must be >= ``image_size``.
    channel : int or None
        When given, only this colour channel of the loaded array is used.
    noise_amplitude : float
        Scale of the uniform random noise added before padding (0 disables it
        but still promotes the integer pixel values to floating point).

    Returns
    -------
    (resized, padded) : tuple of torch.Tensor
        The resampled image and its zero-padded version.
    """
    pixels = imageio.imread(uri)
    if channel is not None:
        pixels = pixels[:, :, channel]
    # PIL's resize expects (width, height), i.e. (columns, rows).
    resized = torch.tensor(np.array(Image.fromarray(pixels).resize((image_size[1], image_size[0]))))
    noisy = resized + noise_amplitude*torch.rand(image_size)
    pad_rows = canvas_size[0] - image_size[0]
    pad_cols = canvas_size[1] - image_size[1]
    # F.pad lists the pads starting from the LAST dimension: (left, right, top, bottom).
    padding = (pad_cols//2, pad_cols - pad_cols//2, pad_rows//2, pad_rows - pad_rows//2)
    padded = torch.nn.functional.pad(noisy, padding, mode='constant', value=0)
    return resized, padded


#Here I am loading in a few images and resampling them so they are the same size as my field
resized_image, image_noise_padding = _load_centered_image('imageio:camera.png', s2, s)
resized_image2, image_noise_padding2 = _load_centered_image('imageio:astronaut.png', s2, s, channel=0)


## Tocohpy uses pytorch, so it can handle batch and channel dimensions.
## load each image into a different batch number (pixel values rescaled by 1/255
## and placed on the imaginary axis of the complex field)
field_in[0,0,:,:] = ((1j)*image_noise_padding/255.0).to(device)
field_in[1,0,:,:] = ((1j)*image_noise_padding2/255.0).to(device)


To create an optical path we define a dictionary of optical components and propagation steps

In [None]:
####################################################

#### Build the 4f optical system using an ordered dict data structure ####
# Each entry is one step of the optical path; torch.nn.Sequential applies them in insertion order.
Optical_Path = OrderedDict()

# Propagate the displayed image 1 focal length to the first lens. Each step gets a name of your choice.
# We adopt the convention that propagation steps begin with P, lenses with L, SLMs with S and absorption masks with A
Optical_Path['P0'] = prop.ASM_Prop( wavelength = lamb,dx = dx, distance = 1*focal_length , N=s[0] , H=None, device = device)

# Apply a thin lens phase delay and propagate 1 more f to the Fourier plane
Optical_Path['L1'] = comp.Thin_Lens(f= 1*focal_length, wavelength=lamb, R2 = R2 , device = device )
Optical_Path['P1'] = prop.ASM_Prop( wavelength = lamb,dx = dx , distance = 1*focal_length , N=s[0] , H=None, device = device)


#At the Fourier plane we apply a low pass absorption mask.
#The mask transmits 1 inside the aperture radius and 0 outside of it.
aperture_radius_sq = .25*torch.max( R2 )
mask = (R2 <= aperture_radius_sq).to(torch.get_default_dtype())

#We set an absorption grating with this mask. Since we wont optimize for this variable
#we pass fixed_pattern = True
Optical_Path['A0'] = comp.Absorption_Mask( transmission = mask.clone(), fixed_pattern = True , device =device )


#Next we propagate out of the Fourier plane, through a second lens and to the detector
Optical_Path['P2'] = prop.ASM_Prop( wavelength = lamb,dx =dx, distance = 1*focal_length , N=s[0] , padding = 0, H=None, device = device)
Optical_Path['L2'] = comp.Thin_Lens(f= 1*focal_length, wavelength=lamb, R2 = R2  , device = device )
Optical_Path['P3'] = prop.ASM_Prop( wavelength = lamb,dx = dx, distance = 1*focal_length , N=s[0] , padding = 0, H=None, device = device)

#Once the Optical Path is defined we can combine it into a single network.
four_f_model=torch.nn.Sequential( Optical_Path ).to(device)


#View the results for both 'batch' images
#Run the whole batch through the model once and reuse the result for every panel.
four_f_output = four_f_model( field_in )

f,ax_arr = plt.subplots(2,2,figsize=(12,12))

#One row per batch entry: input on the left, system output on the right.
for img in range(2):
    ax_arr[img,0].imshow( field_in[img,0,:,:].abs().cpu().detach() )
    ax_arr[img,1].imshow( four_f_output[img,0,:,:].abs().cpu().detach() )

    ax_arr[img,0].set_title('Input field intensity')
    ax_arr[img,1].set_title('Output intensity of 4F system with aperture')

    ax_arr[img,0].axis('off')
    ax_arr[img,1].axis('off')


plt.show()

Because a 4f system can be written as Fourier transforms, we can use the FT_Lens_NC class to rewrite it in a more compact way. The FT approach to lenses has a different grid spacing at the observation and Fourier planes. When a method has a different grid spacing it is labeled with "\_NC" (new coordinates) and exposes the attribute dx_new for the new grid spacing.

In [None]:
#Set input images (same two-image batch as in the previous example)
field_in[0,0,:,:] = ((1j)*image_noise_padding/255.0  ).to(device)
field_in[1,0,:,:] = ((1j)*image_noise_padding2/255.0  ).to(device)

####################################################

#### Build the same 4f system from two Fourier-transform lenses ####
Optical_Path_2 = OrderedDict()

#Prop light from the rear focal plane of a lens to the Fourier plane
Optical_Path_2['FT0'] =  comp.FT_Lens_NC(f= focal_length, wavelength= lamb, dx = dx , N = s[0], device = device   )

#At the Fourier plane we apply a low pass absorption mask but we need to do it in our rescaled coordinates

#new grid spacing reported by the lens
dx_new = Optical_Path_2['FT0'].dx_new

#magnification factor between the Fourier-plane grid and the input grid
mag = dx_new/dx

#The aperture radius is divided by the magnification so it is expressed on the rescaled grid.
aperture_radius_sq =  (.25 / (mag**2) ) *torch.max( R2 )
mask = (R2 <= aperture_radius_sq).to(torch.get_default_dtype())

#We set an absorption grating with this mask. Since we wont optimize for this variable
#we pass fixed_pattern = True
Optical_Path_2['A0'] = comp.Absorption_Mask( transmission = mask.clone(), fixed_pattern = True , device =device )

#Prop light from the Fourier plane to the observation plane. We need to take into account the new
#coordinates (NC) by passing dx_new as the input grid spacing
Optical_Path_2['FT1'] =  comp.FT_Lens_NC(f= focal_length, wavelength= lamb, dx = dx_new , N = s[0], device = device   )


#Once the Optical Path is defined we can combine it into a single network.
four_f_model_2=torch.nn.Sequential( Optical_Path_2 ).to(device)


#View the results
#Run the whole batch through the model once and reuse the result for every panel.
four_f_output_2 = four_f_model_2( field_in )

f,ax_arr = plt.subplots(2,2,figsize=(12,12))

#One row per batch entry: input on the left, system output on the right.
for img in range(2):
    ax_arr[img,0].imshow( field_in[img,0,:,:].abs().cpu().detach() )
    ax_arr[img,1].imshow( four_f_output_2[img,0,:,:].abs().cpu().detach() )

    ax_arr[img,0].set_title('Input field intensity')
    ax_arr[img,1].set_title('Output intensity of 4F system with aperture')

    ax_arr[img,0].axis('off')
    ax_arr[img,1].axis('off')


plt.show()

In this example we will optimize a phase SLM pattern to display an image via diffraction.

In [None]:
#Problem parameters for the SLM optimization example

# units (meters)
mm = 1e-3
nm = 1e-9


## Simulation grid ##

s = (800,800)       # size of simulation in pixels
sim_size = 5*mm     # size of wavefront

#coordinates
x_c = torch.linspace(-sim_size/2, sim_size/2, s[0])  #X coordinates
y_c = torch.linspace(-sim_size/2, sim_size/2, s[1])  #Y coordinates
[X,Y] = torch.meshgrid(x_c,y_c)    #coordinate mesh
R2 = X**2 + Y**2                   #squared radial coordinate
dx = x_c[1] - x_c[0]               #grid spacing


## Optical constants ##

lamb = 500.0*nm               # wavelength
focal_length = 30*mm          # lens focal length
detector_distance = 20*mm     # propagation distance to the detector


## Target and input fields ##

s2 = (int(s[0]/2), int(s[1]/2))

#The target is the sample camera image resampled to the full grid and stored as a complex tensor
camera = Image.fromarray(imageio.imread('imageio:camera.png')).resize(s)
target_image = torch.tensor(np.array(camera), dtype=torch.cfloat).to(device)

#The incident field approximates a delta function: a 6x6 pixel patch of ones at the grid centre
c0, c1 = int(s[0]/2), int(s[1]/2)
field_in = torch.zeros(s, dtype=torch.cfloat, requires_grad=False).to(device)
field_in[c0 - 3 : c0 + 3, c1 - 3 : c1 + 3] = 1


plt.imshow(target_image.abs().cpu())
plt.title('Target output image')

Our optical setup will consist of a point source going through a Fourier-transform lens, then through an SLM, and finally propagated to a detector.

In [None]:
#### Build the diffraction system using an ordered dict data structure ####
Optical_Path_3 = OrderedDict()

#Prop light from the rear focal plane of a lens to the Fourier plane
Optical_Path_3['FT0'] =   comp.FT_Lens_NC(f= focal_length, wavelength= lamb, dx = dx , N = s[0] , device = device   )

#At the Fourier plane we apply a phase SLM. The initial pattern is random (uniform in [0, 10));
#the phase delay will be learned later
Optical_Path_3['S0'] = comp.SLM( phase_delay = 10* torch.rand(s).to(device) ,  device =device )

#Next we propagate from the SLM to the detector. There is no second lens in this path.
#The grid spacing changed at the NC lens, so the propagator is given dx_new.
Optical_Path_3['P0'] = prop.ASM_Prop( wavelength = lamb,dx = Optical_Path_3['FT0'].dx_new, distance = 1*detector_distance , N=s[0] , H=None, device = device)


#Once the Optical Path is defined we can combine it into a single network.
diffraction_model=torch.nn.Sequential( Optical_Path_3 ).to(device)


#View the results of the un-optimized system
f,ax_arr = plt.subplots(1,2,figsize=(10,5))

ax_arr[0].imshow(( field_in[:,:] ).abs().cpu().detach())
ax_arr[1].imshow(diffraction_model( field_in)[:,:] .abs().cpu().detach())

ax_arr[0].set_title('Input field intensity')
#Title corrected: this is the SLM diffraction setup, not the 4F system with aperture
ax_arr[1].set_title('Output intensity at detector (random SLM phase)')

ax_arr[0].axis('off')
ax_arr[1].axis('off')


plt.show()

Tocohpy arranges optical paths similarly to a neural network. All components are on the computation graph, and gradients can be backpropagated through them, allowing optimization with PyTorch optimizers. Unless fixed_pattern=True, the SLM class will always contain learnable parameters.


In [None]:
#### To optimize for the diffraction grating we use a loop structure very similar to optimizing a Neural network ####
#We select our optimizer for the model parameters. Currently, SLM and Absorption Gratings have learnable patterns by default

optimizer = torch.optim.Adam(diffraction_model.parameters() , lr=0.5)

#Rescale target image to match incident light budget:
#after this step the target has the same norm as the input field
light_budget = torch.norm( field_in )
target_image = light_budget * target_image /torch.norm( target_image)

#Loop and optimize
for t in range(1000):

    #Loss: standard deviation of the difference between output and target amplitudes.
    #Note that a standard deviation is insensitive to a constant offset between the two.
    L_fun = 1*( diffraction_model(field_in).abs() - target_image.abs() ).std()

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    L_fun.backward()
    optimizer.step()

    #Report progress at iterations 1, 101, 201, ...
    if t % 100 == 1:
        print(t, 'Loss:' , L_fun.item() )

##
##Examine the results
##
f, ax = plt.subplots(1,4,figsize=(22,12))

#(image, title) for each panel, left to right
panels = [
    ( field_in.abs().detach().cpu(), 'Input field at 0' ),
    #first model parameter, shown as the learned SLM pattern
    ( list(diffraction_model.parameters() )[0].clone().detach().abs().cpu(), 'SLM pattern at Fourier plane' ),
    ( diffraction_model(field_in)[:,:].clone().detach().abs().cpu(), 'Detected image' ),
    ( target_image.abs().detach().cpu(), 'Target image' ),
]

for panel_ax, (panel_image, panel_title) in zip(ax, panels):
    panel_ax.imshow( panel_image )
    panel_ax.axis('off')
    panel_ax.set_title(panel_title, fontsize = 20)

Optical neural networks have become an increasingly popular topic. Several papers use SBN and other nonlinear crystals. Tocohpy supports split-step propagation for solving the nonlinear Schrödinger equation. Gradients can still be propagated through these nonlinear methods.

In [None]:
# units (meters)
mm = 1e-3
nm = 1e-9


## Setting Problem Parameters ##

s= (800,800)        # size of simulation in pixels
sim_size = 5*mm     # size of wavefront

#coordinates
x_c = torch.linspace(-sim_size/2,sim_size/2, s[0] )  #XCoordinates
y_c = torch.linspace(-sim_size/2,sim_size/2, s[1] )  #YCoordinates
[X,Y] = torch.meshgrid(x_c,y_c)    #Generate Coordinate Mesh
R2= X**2 + Y**2                    #Squared radial coordinate
dx = x_c[1] - x_c[0]               #grid spacing


# Wavelength
lamb  = 500.0* nm

#Lens focal length
focal_length = 50* mm

#distance from SLM to Crystal
# NOTE(review): defined but not used by any step below - confirm whether an
# SLM-to-crystal propagation step is missing from the optical path.
crystal_distance = 100* mm

#Distance from crystal to detector
detector_distance = 100* mm

#Define output target image
s2=( int( s[0]/2)  ,int( s[1]/2 ) )
target_image = torch.tensor( np.array(Image.fromarray(imageio.imread('imageio:camera.png')).resize(s)), dtype =torch.cfloat ).to(device)

#define the incident field as an approximate delta function (6x6 pixel patch of ones at the centre)
field_in = torch.zeros(s, dtype = torch.cfloat, requires_grad = False).to(device)
field_in[ int(s[0]/2) - 3 : int(s[0]/2) + 3 , int(s[1]/2) - 3 : int(s[1]/2) + 3 ] = 1

# We now set the parameters of the nonlinear crystal
gamma = -500                      # nonlinearity coefficient handed to SB_Crystal_Update
crystal_length = 10* mm
num_steps = 50                    # number of split steps through the crystal
h = crystal_length / num_steps    # thickness of one split step


####################################################

#### We now will build an optical system using an ordered dict data structure ####
Optical_Path_NLS = OrderedDict()

#Prop light from the rear focal plane of a lens to the Fourier plane
Optical_Path_NLS['FT0'] =  comp.FT_Lens_NC(f= focal_length, wavelength= lamb, dx = dx , N= s[0], device = device  )

#The NC lens changes the grid spacing; every step after it must use the new spacing
#(same convention as the diffraction example, which passes dx_new to its propagator).
dx_fourier = Optical_Path_NLS['FT0'].dx_new

#At the Fourier plane we apply a phase SLM with a random initial pattern
Optical_Path_NLS['S0'] = comp.SLM( phase_delay = torch.rand(s).to(device) ,  device =device )

# March light through the SB crystal (1cm thick) with operator splitting:
# each step is a propagation over h followed by a crystal update over h
for i in range(0,num_steps):
    Optical_Path_NLS['Cp' + str(i)] = prop.NLS_Prop( wavelength = lamb,dx = dx_fourier, distance = 1*h , N=s[0] , H=None, padding = 1/8,  device = device)
    Optical_Path_NLS['Cu' + str(i)] =comp.SB_Crystal_Update( gamma=gamma , distance = h, device = device  )

#Next we propagate from the crystal to the detector
Optical_Path_NLS['P0'] = prop.ASM_Prop( wavelength = lamb,dx = dx_fourier, distance = 1*detector_distance , N=s[0] , H=None, device = device)

#Once the Optical Path is defined we can combine it into a single network.
crystal_model=torch.nn.Sequential( Optical_Path_NLS ).to(device)



#View the results of the un-optimized system
f,ax_arr = plt.subplots(1,2,figsize=(10,5))

ax_arr[0].imshow(( ( field_in) ).abs().cpu().detach())
ax_arr[1].imshow( (crystal_model( field_in )[:,:]  ) .abs().cpu().detach())

ax_arr[0].set_title('Input field intensity')
#Title corrected: this is the nonlinear crystal setup, not the 4F system with aperture
ax_arr[1].set_title('Output intensity after nonlinear crystal (random SLM phase)')

ax_arr[0].axis('off')
ax_arr[1].axis('off')


plt.show()

In [None]:
#### To optimize the SLM pattern we use a loop structure very similar to optimizing a Neural network ####

optimizer = torch.optim.Adam(crystal_model.parameters() , lr=0.35)

#Rescale target image to match incident light budget:
#after this step the target has the same norm as the input field
light_budget = torch.norm( field_in )
target_image = light_budget * target_image /torch.norm( target_image)

#Loop and optimize
for t in range(500):

    #Loss: standard deviation of the difference between output and target amplitudes.
    #Note that a standard deviation is insensitive to a constant offset between the two.
    L_fun = 1*( crystal_model(field_in).abs() - target_image.abs() ).std()

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    L_fun.backward()
    optimizer.step()

    #Report progress at iterations 1, 101, 201, ...
    if t % 100 == 1:
        print(t, 'Loss:' , L_fun.item() )


##
##Examine the results
##
f, ax = plt.subplots(1,4,figsize=(22,12))

#(image, title) for each panel, left to right
panels = [
    ( field_in.abs().detach().cpu(), 'Input field at 0' ),
    #first model parameter, shown as the learned SLM pattern
    ( list(crystal_model.parameters() )[0].clone().detach().abs().cpu(), 'SLM pattern at Fourier plane' ),
    ( crystal_model(field_in)[:,:].clone().detach().abs().cpu(), 'Detected image' ),
    ( target_image.abs().detach().cpu(), 'Target image' ),
]

for panel_ax, (panel_image, panel_title) in zip(ax, panels):
    panel_ax.imshow( panel_image )
    panel_ax.axis('off')
    panel_ax.set_title(panel_title, fontsize = 20)
        

In some holography applications, it may be necessary to have some components which are learned for entire batches of images and some components which are learned for specific images. This can be accomplished by adding or using batch/channel dimensions to learnable optical components.


In [None]:
# units (meters)
mm = 1e-3
nm = 1e-9


## Setting Problem Parameters ##

s= (100,100)        # size of simulation in pixels
sim_size = 5*mm     # size of wavefront

#coordinates
x_c = torch.linspace(-sim_size/2,sim_size/2, s[0] )  #XCoordinates
y_c = torch.linspace(-sim_size/2,sim_size/2, s[1] )  #YCoordinates
[X,Y] = torch.meshgrid(x_c,y_c)    #Generate Coordinate Mesh
R2= X**2 + Y**2                    #Squared radial coordinate
dx = x_c[1] - x_c[0]               #grid spacing


# Wavelength
lamb  = 500.0* nm

#Lens focal length
focal_length = 100* mm

#distance from SLM to Crystal (not used by any step in this cell)
crystal_distance = 40* mm

#Distance from crystal Detector (not used by any step in this cell)
detector_distance = 100* mm

# The P3 step below subtracts crystal_length from the focal length. It used to be
# picked up from the nonlinear-crystal cell; it is defined here with the same value
# so this cell also runs on its own.
# NOTE(review): this path contains no crystal - confirm whether P3 should simply
# propagate one focal length.
crystal_length = 10* mm

#Batch of 10 random 3-channel fields; this batch is what gets fed through the model below
s2=( int( s[0]/2)  ,int( s[1]/2 ) )
target_image=torch.rand(10,3,s[0],s[1]).to(device)

#define the incident field as an approximate delta function (6x6 pixel patch of ones at the centre)
#(kept for parity with the other examples; the model below is evaluated on target_image)
field_in = torch.zeros(s, dtype = torch.cfloat, requires_grad = False).to(device)
field_in[ int(s[0]/2) - 3 : int(s[0]/2) + 3 , int(s[1]/2) - 3 : int(s[1]/2) + 3 ] = 1


####################################################

#### We now will build an optical system using an ordered dict data structure ####
Optical_Path_Batch = OrderedDict()

#Prop input field
Optical_Path_Batch['P0'] = prop.ASM_Prop( wavelength = lamb,dx = dx, distance = 1*focal_length , N=s[0] , H=None, device = device)

# apply lens
Optical_Path_Batch['L0'] = comp.Thin_Lens(f= 1*focal_length, wavelength=lamb, R2 = R2  , device = device )


#Prop light
Optical_Path_Batch['P1'] = prop.ASM_Prop( wavelength = lamb,dx = dx, distance = 1*focal_length , N=s[0] , H=None, device = device)


#SLM learns 1 pattern for all inputs (Controlled by the shape of phase delay)
Optical_Path_Batch['S0'] =comp.SLM( phase_delay = torch.rand(s).to(device) ,  device =device )

#Prop light
Optical_Path_Batch['P2'] = prop.ASM_Prop( wavelength = lamb,dx = dx, distance = 1*focal_length , N=s[0] , H=None, device = device)

#BATCH SLM learns a different pattern for each of the 10 inputs but 1 pattern per channel
Optical_Path_Batch['S1'] =comp.SLM( phase_delay = torch.rand( 10, 1 , s[0],s[1] ).to(device) ,  device =device )

#Prop light
Optical_Path_Batch['P3'] = prop.ASM_Prop( wavelength = lamb,dx = dx, distance = 1*focal_length-crystal_length , N=s[0] , H=None, device = device)

#apply lens
Optical_Path_Batch['L1'] = comp.Thin_Lens(f= 1*focal_length, wavelength=lamb, R2 = R2 ,  device = device )


#prop to detector
Optical_Path_Batch['P4'] = prop.ASM_Prop( wavelength = lamb,dx = dx, distance = 1*focal_length , N=s[0] , H=None, device = device)



Batch_SLM_Model=torch.nn.Sequential( Optical_Path_Batch ).to(device)



#View a sample of the results: one (batch, channel) slice of the input and of the output
f,ax_arr = plt.subplots(1,2,figsize=(10,5))

chan = 0
batch = 3

ax_arr[0].imshow(( ( target_image[batch,chan,:,:]) ).abs().cpu().detach())
ax_arr[1].imshow(( Batch_SLM_Model( target_image )[batch,chan,:,:].abs()).cpu().detach())

ax_arr[0].set_title('Input field intensity')
#Title corrected: this is the batched SLM system, not the 4F system with aperture
ax_arr[1].set_title('Output intensity of batched SLM system')

ax_arr[0].axis('off')
ax_arr[1].axis('off')


plt.show()