In [None]:
import numpy as np
import yaml
import matplotlib.pyplot as plt

#Import hardware classes
from pyRTC.hardware.OOPAOInterface import OOPAOInterface
from pyRTC.SlopesProcess import SlopesProcess
from pyRTC.Pipeline import *
from pyRTC.Loop import Loop
from pyRTC.hardware.LoopGymInterface import LoopGymInterface

import gymnasium
from rtgym import DEFAULT_CONFIG_DICT
import time

In [None]:
"""
Shared memory in python is a bit annoying, we are required to unlink it from the garbage collector
so that it will stick around in between runs, however sometime you can get into a situation where 
the SHM is not intialized properly. Usually you will see an error like: 
TypeError: buffer is too small for requested array

To reset a SHM you can run the following code. Note: it will throw some garbage collector errors.
"""
# shm_names = ["wfs", "wfsRaw", "wfc", "wfc2D", "signal", "signal2D", "psfShort", "psfLong"] #list of SHMs to reset
# clear_shms(shm_names)

In [None]:
# Load the configuration file
def read_yaml_file(file_path):
    """Parse the YAML document stored at ``file_path`` and return its contents.

    ``yaml.safe_load`` is used, so the file cannot instantiate arbitrary
    Python objects. The file handle is closed before returning.
    """
    with open(file_path, 'r') as stream:
        return yaml.safe_load(stream)

#Now we can read our YAML config file 
conf = read_yaml_file("simple_OOPAO_config.yaml")

# Split the config into one section per AO loop component, in the order
# loop, wfs, wfc, psf, slopes (a missing key raises KeyError at that point).
confLoop, confWFS, confWFC, confPSF, confSlopes = (
    conf[section] for section in ("loop", "wfs", "wfc", "psf", "slopes")
)

# Show each section on its own line.
print(confLoop, confWFS, confWFC, confPSF, confSlopes, sep="\n")

In [None]:
# Build the OOPAO simulation parameters (`param`), which are passed to
# OOPAOInterface in the next cell.
# NOTE(review): the star import pulls every public name of simpleParamFile into
# the notebook namespace; only initializeParameterFile appears to be used here —
# consider an explicit import once nothing else is confirmed to rely on it.
from simpleParamFile import *
param = initializeParameterFile()

In [None]:
"""
Create the OOPAO simulation interface object 
Running this cell will initialize the dm, wfs, psf, and slopes objects, 
but will not start their real time computations. This inialization includes
the creation of the Shared Memory Objects, and the simulation inialization.
"""
sim = OOPAOInterface(conf=conf, param=param)
wfs, dm, psf = sim.get_hardware()

plt.imshow(wfs.wfs.cam.frame)
plt.show()

plt.imshow(dm.layout)
plt.show()

print(f"NUM VALID ACT: {np.sum(dm.layout)}")

In [None]:
"""
It's important to set the full basis and number of possible modes before
initializing the loop object. Here I define a KL basis for the system
"""
from OOPAO.calibration.compute_KL_modal_basis import compute_KL_basis

NUM_MODES = 18

M2C_KL = compute_KL_basis(sim.tel, sim.atm, sim.dm)
dm.setM2C(M2C_KL[:,:NUM_MODES])

"""
"""
slopes = SlopesProcess(conf=conf)

""" 
"""
#Initialize our AO loop object
loop = Loop(conf=conf)

In [None]:
"""
Start the processes. Here the real-time computations selected in
the config will begin.
"""
dm.start()
dm.flatten()

wfs.start()
slopes.start()

#Take new reference slopes while dm is flat.
# time.sleep(1)
# slopes.takeRefSlopes()

print(sim.dm.OPD.shape)
psf.start()

In [None]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class CustomEnv(gym.Env):
    """Gymnasium environment that drives the pyRTC loop's wavefront corrector.

    Actions are modal commands written to ``loop.wfcShm``. Observations are the
    slopes (``loop.slopesShm``) and the 2D corrector command (``loop.wfc2DShm``),
    each scaled to unit L2 norm and returned as float32. The reward is
    ``exp(-var(slopes))`` computed on the raw (un-normalized) slopes.

    Parameters
    ----------
    loop : Loop
        pyRTC loop providing the shared-memory handles and array geometry.
    render_mode : str or None
        Stored for the Gymnasium API; ``render`` draws nothing.
    buffer : int
        Stored as ``self.buffer``; not used by the environment itself.
    timestep_limit : int
        Episode length after which ``step`` reports ``truncated=True`` (default 100).
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, loop, render_mode=None, buffer=0, timestep_limit=100):
        super().__init__()
        self.loop = loop
        self.buffer = buffer
        self.render_mode = render_mode

        num_modes = self.loop.confWFC['numModes']
        # All-zero modal command, sent on reset to flatten the mirror.
        self.default_action = np.zeros(num_modes, dtype=np.float32)
        self.flat2D = np.zeros((self.loop.wfc2D_width, self.loop.wfc2D_height))
        # Modes with index >= active_modes are forced to zero in _send_control.
        self.active_modes = self.loop.numActiveModes

        self.timestep_limit = timestep_limit
        self.current_step = 0

        # Both observations are L2-normalized in _get_obs, so every element lies in [-1, 1].
        self.observation_space = spaces.Dict(
            {
                "slopes": spaces.Box(low=-1, high=1,
                                     shape=(self.loop.slopes_width, self.loop.slopes_height),
                                     dtype=np.float32),
                "command": spaces.Box(low=-1, high=1,
                                      shape=(self.loop.wfc2D_width, self.loop.wfc2D_height),
                                      dtype=np.float32),
            }
        )

        self.action_space = spaces.Box(low=-1, high=1, shape=(num_modes,), dtype=np.float32)

    @staticmethod
    def _normalize(array):
        """Return ``array`` scaled to unit L2 norm, as float32.

        An all-zero array is returned unscaled (avoids dividing by a zero norm).
        The norm is computed before the float32 cast so small values do not
        underflow first.
        """
        array = np.asarray(array)
        if not np.all(array == 0):
            array = array / np.linalg.norm(array)
        # The observation space declares float32 Boxes.
        return array.astype(np.float32, copy=False)

    def _get_obs(self):
        """Build the observation dict from the latest slopes and 2D command reads."""
        self._normalized_slopes = self._normalize(self._slopes_obs)
        self._normalized_cmd2D = self._normalize(self._cmd2D_obs)
        return {"slopes": self._normalized_slopes, "command": self._normalized_cmd2D}

    def _get_info(self, truncated=False):
        """Return the info dict; ``TimeLimit.truncated`` mirrors the truncation flag."""
        return {'TimeLimit.truncated': truncated}

    def _send_control(self, control):
        """Write a modal command to the corrector, zeroing modes beyond ``active_modes``.

        The command is copied first, so the caller's array (the agent's action
        or ``self.default_action``) is never modified in place.
        """
        control = np.array(control, copy=True)
        control[self.active_modes:] = 0

        self.control = control
        self.loop.wfcShm.write(control)

    def step(self, action):
        """Apply ``action``; return ``(observation, reward, terminated, truncated, info)``."""
        self.current_step += 1

        # Send the command to the mirror.
        self._send_control(action)

        # Make a blocking read of the wfc to make sure the action has been set.
        self._cmd2D_obs = self.loop.wfc2DShm.read()

        # Read the slopes.
        self._slopes_obs = self.loop.slopesShm.read()

        observation = self._get_obs()

        # Reward lies in [0, 1]: 1 for zero slope variance, decaying as variance grows.
        reward = np.exp(-np.var(self._slopes_obs), dtype=np.float32).item()

        # Divergent slopes end the episode (reward below exp(-9), i.e. variance above ~9).
        terminated = bool(reward < np.exp(-9))

        # Time-limit truncation.
        truncated = self.current_step >= self.timestep_limit
        info = self._get_info(truncated)

        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Flatten the mirror, start a new episode and return ``(observation, info)``."""
        # Seeds self.np_random as required by the Gymnasium Env API.
        super().reset(seed=seed)
        self.current_step = 0

        # Flatten the mirror.
        self._send_control(self.default_action)

        # Sleep to make sure the blocking order is reset.
        time.sleep(0.01)

        # Read the current command and the slopes.
        self._cmd2D_obs = self.loop.wfc2DShm.read()
        self._slopes_obs = self.loop.slopesShm.read()

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def render(self):
        """No rendering is implemented; always returns None."""
        return None

    def close(self):
        """Nothing to release; the environment holds no resources of its own."""
        return

In [None]:
# from pyRTC.hardware.GymEnv import CustomEnv
# NOTE(review): the commented import above presumably points to a packaged copy
# of this environment; the CustomEnv class defined in this notebook is used instead.

# Wrap the running AO loop in the Gymnasium environment.
env = CustomEnv(loop=loop)

In [None]:
# Smoke test: start an episode, apply ten random actions, then reset again.
env.reset()

for i in range(10):
    obs, reward = env.step(env.action_space.sample())[:2]

env.reset()

In [None]:
# Take one random step from a fresh episode and inspect the normalized slopes observation.
env.reset()

obs, *_ = env.step(env.action_space.sample())

fig, ax = plt.subplots()
slopes_img = ax.imshow(obs['slopes'])
ax.set_title("Normalized slopes after one random action")
fig.colorbar(slopes_img, ax=ax)
plt.show()

In [None]:
from stable_baselines3.common import env_checker

# Run Stable-Baselines3's environment checker on the Gymnasium env.
# The render check is skipped because CustomEnv.render draws nothing.
env_checker.check_env(env, warn=True, skip_render_check=True)

In [None]:
import gymnasium as gym
import torch as th
from torch import nn

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCombinedExtractor(BaseFeaturesExtractor):
    """Feature extractor for CustomEnv's Dict observation.

    Each 2D observation ("slopes", "command") is passed through a small CNN
    followed by a linear layer producing ``num_features`` features. The
    per-key features are concatenated into a ``(batch, features_dim)`` tensor.

    NOTE(review): the same ``self.cnn`` instance is reused for every key, so
    the convolution weights are shared between the slopes and command
    branches. Confirm this is intended; build one CNN per key otherwise.

    Parameters
    ----------
    observation_space : gym.spaces.Dict
        Dict of 2D Box spaces, each of shape (width, height).
    num_features : int
        Number of output features per observation key (default 16).
    """

    def __init__(self, observation_space: gym.spaces.Dict, num_features: int = 16):
        # features_dim is only known after going over all keys; pass a
        # placeholder now and set it at the end. PyTorch requires calling
        # nn.Module.__init__ before adding modules.
        super().__init__(observation_space, features_dim=1)

        self._numFeatures = num_features

        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Flatten(),
        )

        extractors = {}
        for key, subspace in observation_space.spaces.items():
            if key not in ("slopes", "command"):
                continue
            # Flattened CNN output size for a single (1, 1, W, H) sample.
            with th.no_grad():
                probe = th.as_tensor(subspace.sample()[None, None, :, :]).float()
                n_flatten = self.cnn(probe).shape[1]
            extractors[key] = nn.Sequential(
                self.cnn,
                nn.Linear(n_flatten, self._numFeatures),
                nn.ReLU(),
                nn.Flatten(),
            )

        self.extractors = nn.ModuleDict(extractors)

        # One num_features block per extractor, concatenated in forward().
        self._features_dim = self._numFeatures * len(self.extractors)

    def forward(self, observations) -> th.Tensor:
        """Encode every observation key and concatenate along the feature axis.

        Each ``observations[key]`` is expected batched as (B, W, H) (tensor or
        array); a channel axis is inserted to get (B, 1, W, H) for the CNN.
        Returns a (B, features_dim) tensor.
        """
        encoded_tensor_list = []
        for key, extractor in self.extractors.items():
            # as_tensor leaves existing tensors (and their device) untouched.
            batch = th.as_tensor(observations[key]).float().unsqueeze(1)
            # Keep the batch axis so B inputs yield B feature rows.
            encoded_tensor_list.append(extractor(batch).flatten(start_dim=1))
        return th.cat(encoded_tensor_list, dim=1)

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv

# Wrap the single CustomEnv in a DummyVecEnv (used by the PPO cells below).
# The lambda returns the existing `env` object rather than constructing a new one.
env1 = DummyVecEnv([lambda: env])

In [None]:
# Scratch check: indexing with None (np.newaxis) inserts a channel axis,
# turning a (batch, W, H) array into (batch, 1, W, H).
x = np.empty((64, 5, 5))

print(x[:, None].shape)


In [None]:
from collections import OrderedDict

env1.reset()

obs, *_ = env1.step([env.action_space.sample()])

print(obs)

def merge_ordered_dicts(d1, d2):
    """Stack two batched observation dicts along the batch axis (axis 0).

    Every key of ``d1`` maps to ``np.concatenate([d1[key], d2[key]], axis=0)``.
    For VecEnv observations (presumably (B, W, H) arrays) this gives a
    (B1 + B2, W, H) batch the feature extractor can index directly; a plain
    Python list of arrays cannot be indexed as ``[:, np.newaxis, :, :]``.
    """
    merged_dict = OrderedDict()
    for key in d1:
        merged_dict[key] = np.concatenate([d1[key], d2[key]], axis=0)
    return merged_dict

# Merging the dictionaries gives a batch of 2 observations, exercising the extractor with B > 1.
merged_dict = merge_ordered_dicts(obs, obs)

print(CustomCombinedExtractor(env.observation_space)(merged_dict).shape)

In [None]:
from stable_baselines3 import PPO

# PPO on the vectorized env, using the custom CNN feature extractor followed by
# net_arch hidden layers of sizes [32, 32].
policy_kwargs = {
    "features_extractor_class": CustomCombinedExtractor,
    "net_arch": [32, 32],
}

model = PPO(
    "MultiInputPolicy",
    env1,
    policy_kwargs=policy_kwargs,
    verbose=1,
)

In [None]:
# Short smoke-test training run.
# NOTE(review): PPO collects full rollouts of n_steps (SB3 default 2048) before
# each update, so total_timesteps=10 likely still runs at least one full rollout — confirm.
model.learn(total_timesteps=10)

In [None]:
# Add the atmosphere to the simulation.
sim.addAtmosphere()

In [None]:
# Poke amplitude used by loop.computeIM (presumably the interaction-matrix measurement).
# NOTE(review): units follow pyRTC's Loop.pokeAmp convention — confirm.
POKE_AMP = 1e-7

#Remove the atmosphere from the simulation
sim.removeAtmosphere()
try:
    loop.pokeAmp = POKE_AMP

    #Compute the IM, blocking
    loop.computeIM()
finally:
    # Add the atmosphere back even if the IM computation fails or is interrupted,
    # so the simulation is never left without it.
    sim.addAtmosphere()

In [None]:
# Plot the IM computed in the previous cell (presumably the interaction matrix).
loop.plotIM()

In [None]:
# Closed-loop run parameters.
LOOP_GAIN = 0.3
LOOP_RUN_SECONDS = 10

dm.flatten()
time.sleep(1e-2)
loop.setGain(LOOP_GAIN)
loop.start()
try:
    time.sleep(LOOP_RUN_SECONDS)
finally:
    # Always stop the loop and flatten the mirror, even if the cell is interrupted.
    loop.stop()
    dm.flatten()

In [None]:
# NOTE(review): presumably pushes mode/actuator index 10 with amplitude 1e-6 —
# confirm against the pyRTC dm.push signature.
dm.push(10, 1e-6)

In [None]:
# loop.saveIM("simpleIM.npy")