# <img src="https://github.com/JuliaLang/julia-logo-graphics/raw/master/images/julia-logo-color.png" height="100" /> _Colab Notebook Template_

## Instructions
1. Work on a copy of this notebook: _File_ > _Save a copy in Drive_ (you will need a Google account). Alternatively, you can download the notebook using _File_ > _Download .ipynb_, then upload it to [Colab](https://colab.research.google.com/).
2. If you need a GPU: _Runtime_ > _Change runtime type_ > _Hardware accelerator_ = _GPU_.
3. Execute the following cell (click on it and press Ctrl+Enter) to install Julia, IJulia and other packages (if needed, update `JULIA_VERSION` and the other parameters). This takes a couple of minutes.
4. Reload this page (press Ctrl+R, or ⌘+R, or the F5 key) and continue to the next section.

_Notes_:
* If your Colab Runtime gets reset (e.g., due to inactivity), repeat steps 2, 3 and 4.
* After installation, if you want to change the Julia version or activate/deactivate the GPU, you will need to reset the Runtime (_Runtime_ > _Factory reset runtime_), then repeat steps 3 and 4.

In [None]:
!git clone https://github.com/john-graves-les-mills/LAIF_Part_II.git

In [None]:
%%shell
set -e

#---------------------------------------------------#
JULIA_VERSION="1.8.2" # any version ≥ 0.7.0
JULIA_PACKAGES="IJulia BenchmarkTools"
JULIA_PACKAGES_IF_GPU="CUDA" # or CuArrays for older Julia versions
JULIA_NUM_THREADS=2
#---------------------------------------------------#

if [ -z `which julia` ]; then
  # Install Julia
  JULIA_VER=`cut -d '.' -f -2 <<< "$JULIA_VERSION"`
  echo "Installing Julia $JULIA_VERSION on the current Colab Runtime..."
  BASE_URL="https://julialang-s3.julialang.org/bin/linux/x64"
  URL="$BASE_URL/$JULIA_VER/julia-$JULIA_VERSION-linux-x86_64.tar.gz"
  wget -nv $URL -O /tmp/julia.tar.gz # -nv means "not verbose"
  tar -x -f /tmp/julia.tar.gz -C /usr/local --strip-components 1
  rm /tmp/julia.tar.gz

  # Install Packages
  nvidia-smi -L &> /dev/null && export GPU=1 || export GPU=0
  if [ $GPU -eq 1 ]; then
    JULIA_PACKAGES="$JULIA_PACKAGES $JULIA_PACKAGES_IF_GPU"
  fi
  for PKG in `echo $JULIA_PACKAGES`; do
    echo "Installing Julia package $PKG..."
    julia -e 'using Pkg; pkg"add '$PKG'; precompile;"' &> /dev/null
  done

  # Install kernel and rename it to "julia"
  echo "Installing IJulia kernel..."
  julia -e 'using IJulia; IJulia.installkernel("julia", env=Dict(
      "JULIA_NUM_THREADS"=>"'"$JULIA_NUM_THREADS"'"))'
  KERNEL_DIR=`julia -e "using IJulia; print(IJulia.kerneldir())"`
  KERNEL_NAME=`ls -d "$KERNEL_DIR"/julia*`
  mv -f $KERNEL_NAME "$KERNEL_DIR"/julia

  echo ''
  echo "Successfully installed `julia -v`!"
  echo "Please reload this page (press Ctrl+R, ⌘+R, or the F5 key) then"
  echo "jump to the 'Checking the Installation' section."
fi

Installing Julia 1.8.2 on the current Colab Runtime...
2024-03-05 20:51:19 URL:https://julialang-s3.julialang.org/bin/linux/x64/1.8/julia-1.8.2-linux-x86_64.tar.gz [135859273/135859273] -> "/tmp/julia.tar.gz" [1]
Installing Julia package IJulia...
Installing Julia package BenchmarkTools...
Installing IJulia kernel...
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mInstalling julia kernelspec in /root/.local/share/jupyter/kernels/julia-1.8

Successfully installed julia version 1.8.2!
Please reload this page (press Ctrl+R, ⌘+R, or the F5 key) then
jump to the 'Checking the Installation' section.




# Checking the Installation
The `versioninfo()` function should print your Julia version and some other info about the system:

In [1]:
versioninfo()

Julia Version 1.8.2
Commit 36034abf260 (2022-09-29 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 2 × Intel(R) Xeon(R) CPU @ 2.20GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, broadwell)
  Threads: 2 on 2 virtual cores
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_NUM_THREADS = 2
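
The same information can be checked programmatically, which is handy after a Runtime reset. The snippet below is a minimal sketch, not part of the original template: the variable names are illustrative, and looking for `nvidia-smi` on the `PATH` is only a weaker stand-in for the `nvidia-smi -L` test used in the install cell.

```julia
# Thread count requested via the environment, falling back to the actual count
# when the variable is unset or is not a plain integer.
requested = something(tryparse(Int, get(ENV, "JULIA_NUM_THREADS", "")), Threads.nthreads())
actual = Threads.nthreads()

# Assumption: presence of the nvidia-smi binary hints at a GPU runtime.
has_nvidia_smi = Sys.which("nvidia-smi") !== nothing

println("Julia version:    ", VERSION)
println("Threads (actual): ", actual, " (requested: ", requested, ")")
println("nvidia-smi found: ", has_nvidia_smi)

# Warn when the kernel did not pick up the requested thread count.
requested == actual || @warn "Thread count differs from JULIA_NUM_THREADS" requested actual
```

If the warning appears, the kernel was probably started before the install cell registered `JULIA_NUM_THREADS`; reloading the page as described in step 4 should fix it.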


This notebook is a Colab implementation of [T-maze_Generalized.ipynb](https://github.com/biaslab/LAIF/blob/main/Part2/T-maze_Generalized.ipynb), which is discussed in [Realising Synthetic Active Inference Agents, Part II: Variational Message Updates](https://arxiv.org/abs/2306.02733).

In [2]:
cd("LAIF_Part_II")

In [3]:
using Pkg
Pkg.activate(".")
Pkg.instantiate()

[32m[1m  Activating[22m[39m project at `/content/LAIF_Part_II`
[32m[1m   Installed[22m[39m Calculus ───────────────────────── v0.5.1
[32m[1m   Installed[22m[39m x265_jll ───────────────────────── v3.5.0+0
[32m[1m   Installed[22m[39m JpegTurbo_jll ──────────────────── v2.1.91+0
[32m[1m   Installed[22m[39m libfdk_aac_jll ─────────────────── v2.0.2+0
[32m[1m   Installed[22m[39m HypergeometricFunctions ────────── v0.3.16
[32m[1m   Installed[22m[39m OffsetArrays ───────────────────── v1.12.9
[32m[1m   Installed[22m[39m Libmount_jll ───────────────────── v2.35.0+0
[32m[1m   Installed[22m[39m GR_jll ─────────────────────────── v0.72.5+0
[32m[1m   Installed[22m[39m LERC_jll ───────────────────────── v3.0.0+1
[32m[1m   Installed[22m[39m StatsFuns ──────────────────────── v1.3.0
[32m[1m   Installed[22m[39m GraphPPL ───────────────────────── v3.1.0
[32m[1m   Installed[22m[39m Opus_jll ───────────────────────── v1.3.2+0
[32m[1m   Installed[22m

In [4]:
using RxInfer, LinearAlgebra, Plots

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling RxInfer [86711068-29c9-4ff7-b620-ae75d7495b3d]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]


## Agent

In [5]:
include("./goal_observation.jl")

# Define the generative model
@model function t_maze(A_s, D_s, x)
    u = datavar(Matrix{Int64}, 2) # Policy for evaluations
    z = randomvar(2) # Latent states
    c = datavar(Vector{Float64}, 2) # Goal prior statistics

    z_0 ~ Categorical(D_s) # State prior
    A ~ MatrixDirichlet(A_s) # Observation matrix prior

    z_k_min = z_0
    for k=1:2
        z[k] ~ Transition(z_k_min, u[k])
        c[k] ~ GoalObservation(z[k], A) where {
            meta=GeneralizedMeta(x[k]),
            pipeline=GeneralizedPipeline(vague(Categorical,8))} # With breaker message

        z_k_min = z[k] # Reset for next slice
    end
end

# Define constraints on the variational density
@constraints function structured(approximate::Bool)
    q(z_0, z, A) = q(z_0, z)q(A)
    if approximate # Sampling approximation on A required for t<3
        q(A) :: SampleList(20)
    end
end

structured (generic function with 1 method)
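
The prior `A ~ MatrixDirichlet(A_s)` treats each column of the observation matrix as Dirichlet-distributed, so the expected observation matrix under the prior is the column-normalised parameter matrix. The plain-Julia sketch below illustrates this with a small, made-up parameter matrix. `A_s_demo` and `dirichlet_column_mean` are hypothetical names; neither is the output of `constructPriors()` nor part of RxInfer.

```julia
# Hypothetical prior parameters (3 outcomes x 2 states), for illustration only.
A_s_demo = [2.0 1.0;
            1.0 1.0;
            1.0 2.0]

# Mean of a column-wise Dirichlet: divide each column by its sum.
dirichlet_column_mean(a) = a ./ sum(a, dims=1)

A_mean = dirichlet_column_mean(A_s_demo)

# Each column of the mean is a categorical distribution over outcomes.
@show A_mean
@show sum(A_mean, dims=1)
```

Larger entries in `A_s` therefore encode stronger prior belief in the corresponding outcome for a given state.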

## Simulation

In [6]:
# Define experimental setting
α = 0.9; c = 2.0 # Reward probability and utility
S = 100 # Number of trials
seed = 666 # Randomizer seed

include("helpers.jl")
include("environment.jl")
include("agent.jl")

(A, B, C, D) = constructABCD(α, c)
(A_0, D_0) = constructPriors() # Construct prior statistics for A and D

rs = generateGoalSequence(seed, S) # Sets random seed and returns reproducible goal sequence
(reset, execute, observe) = initializeWorld(A, B, C, D, rs) # Let there be a world
(infer, act) = initializeAgent(A_0, B, C, D_0) # Let there be a constrained agent

# Step through the experimental protocol
As = Vector{Matrix}(undef, S) # Posterior statistics for A
Gs = [Vector{Matrix}(undef, 3) for s=1:S] # Free energy values per time
as = [Vector{Int64}(undef, 2) for s=1:S] # Actions per time
os = [Vector{Vector}(undef, 2) for s=1:S] # Observations (one-hot) per time
for s = 1:S
    reset(s) # Reset world
    for t=1:2
        (Gs[s][t], _) = infer(t, as[s], os[s])
             as[s][t] = act(t, Gs[s][t])
                        execute(as[s][t])
             os[s][t] = observe()
    end
    (Gs[s][3], As[s]) = infer(3, as[s], os[s]) # Learn at t=3
end
;
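
After the loop, `as` holds one two-step action vector per trial. One simple way to summarise the agent's behaviour is to count how often each action pair was chosen. The sketch below assumes nothing about what the action indices mean; `as_demo` and `count_policies` are hypothetical and stand in for the real simulation results.

```julia
# Hypothetical action log: one [first_action, second_action] vector per trial.
as_demo = [[4, 2], [4, 3], [4, 2], [2, 1]]

# Count occurrences of each (first, second) action pair.
function count_policies(actions)
    counts = Dict{Tuple{Int,Int},Int}()
    for a in actions
        key = (a[1], a[2])
        counts[key] = get(counts, key, 0) + 1
    end
    return counts
end

policy_counts = count_policies(as_demo)

# Print the pairs from most to least frequent.
for (policy, n) in sort(collect(policy_counts), by=last, rev=true)
    println(policy, " => ", n)
end
```

Calling `count_policies(as)` on the real results would give the same kind of summary over all `S` trials.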

## Results

In [7]:
include("visualizations.jl")
plotObservationStatistics(As[S], A_0)
savefig("figures/GFE_A.png")

"/content/LAIF_Part_II/figures/GFE_A.png"

In [8]:
include("visualizations.jl")
plotFreeEnergyMinimum(Gs, os, legend=100, ylim=(5,25))
savefig("figures/GFE_FE.png")

"/content/LAIF_Part_II/figures/GFE_FE.png"

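Because `Gs[s][t]` is stored as a matrix of free energy values, the per-trial minimum and the entry that attains it can be extracted directly. The sketch below uses random stand-in data. The 4×4 shape, `Gs_demo`, `min_G` and `best_policy` are assumptions for illustration, and it is not confirmed that `plotFreeEnergyMinimum` computes its curve this way.

```julia
using Random

Random.seed!(1)

# Hypothetical free energy tables: 5 trials, 3 time steps, 4x4 values each.
Gs_demo = [[rand(4, 4) .+ 10 for t in 1:3] for s in 1:5]

# Minimum free energy at time step t for every trial.
min_G(Gs, t) = [minimum(G[t]) for G in Gs]

# Row and column indices of the smallest entry in one table.
best_policy(G) = Tuple(argmin(G))

G1 = min_G(Gs_demo, 1)
println("Minimum per trial at t=1: ", G1)
println("Minimising entry, trial 1: ", best_policy(Gs_demo[1][1]))
```

Applied to the real `Gs`, `min_G(Gs, 1)` would give one value per trial that can be plotted against the trial index.
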
## Log

In [9]:
using Dates;now()

2024-03-05T22:28:17.555