### Core function for testing

In [73]:
from numpy.random import choice, rand, randn
import numpy as np
import lea  # probability calculations, see https://pypi.org/project/lea/
from collections import defaultdict
from sklearn import linear_model
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from dataclasses import dataclass, field
from typing import Optional
import matplotlib.pyplot as plt
import copy
import contourpy as cp
import pandas as pd
import random
import numba
from numba import jit, njit
from numba import types
from numba.typed import Dict

import math
def counts(A, V, time_steps, repeats, windowsize=3, pairs=2):
    """Count how often each A/V state occurs in every trial.

    Parameters
    ----------
    A, V : array-like
        Indexed as ``[trial][time step]``; every entry must be -1, 0 or 1.
    time_steps : int
        Only compared against 0 (an all-zero matrix is returned in that
        case).  The number of steps actually counted comes from A and V.
    repeats : int
        Number of trials.  Sets the number of rows of the result for
        ``pairs=2`` and for ``time_steps == 0``.
    windowsize : int
        Only used when ``pairs == 2``: number of consecutive (A, V) pairs
        that make up one state.  Any integer >= 1 is accepted.
    pairs : {0, 1, 2}
        0: A and V counted separately -> 6 columns (3 for A, then 3 for V).
        1: joint (A, V) counts -> 9 columns, column = (A+1) + 3*(V+1).
        2: sliding windows of ``windowsize`` pairs ->
           ``3**(2*windowsize)`` columns.  The column is the base-3 number
           with digits (A_t, V_t, A_{t-1}, V_{t-1}, ...) mapped
           -1 -> 0, 0 -> 1, 1 -> 2, the most recent A being the most
           significant digit.

    Returns
    -------
    numpy.ndarray
        One row per trial, one column per state.  Integer dtype for
        ``pairs`` 0 and 1, float for ``pairs=2`` and for the all-zero
        ``time_steps == 0`` result.

    Raises
    ------
    ValueError
        If ``pairs`` is not 0, 1 or 2, if ``windowsize < 1`` or if A/V
        contain values other than -1, 0, 1.
    TypeError
        If ``pairs == 2`` and ``windowsize`` is not an integer.
    """
    if pairs not in (0, 1, 2):
        raise ValueError(f"pairs must be 0, 1 or 2, got {pairs!r}")

    # Width of the result, so that the time_steps == 0 shortcut returns the
    # same number of columns as a regular call with the same arguments.
    if pairs == 0:
        n_columns = 6
    elif pairs == 1:
        n_columns = 9
    else:
        if isinstance(windowsize, bool) or not isinstance(windowsize, (int, np.integer)):
            raise TypeError(f"windowsize must be an integer, got {windowsize!r}")
        if windowsize < 1:
            raise ValueError(f"windowsize must be >= 1, got {windowsize}")
        windowsize = int(windowsize)
        n_columns = 3 ** (2 * windowsize)

    if time_steps == 0:
        return np.zeros((repeats, n_columns))

    if pairs == 0:
        CA = np.apply_along_axis(np.bincount, 1, np.asarray(A) + 1, minlength=3)  # (trials, 3)
        CV = np.apply_along_axis(np.bincount, 1, np.asarray(V) + 1, minlength=3)  # (trials, 3)
        return np.concatenate((CA, CV), axis=1)

    if pairs == 1:
        AV = (np.asarray(A) + 1) + 3 * (np.asarray(V) + 1)  # one code 0..8 per step
        return np.apply_along_axis(np.bincount, 1, AV, minlength=9)  # (trials, 9)

    # pairs == 2: windows of `windowsize` consecutive (A, V) pairs.
    C = np.zeros((repeats, n_columns))
    allowed = (-1, 0, 1)
    for trial in range(repeats):
        a = np.asarray(A[trial])
        v = np.asarray(V[trial])
        if not (np.isin(a, allowed).all() and np.isin(v, allowed).all()):
            raise ValueError("A and V may only contain the values -1, 0 and 1")

        # Two base-3 digits (A first, then V) form one base-9 digit per step.
        step_codes = 3 * (a.astype(np.int64) + 1) + (v.astype(np.int64) + 1)

        n_windows = step_codes.size - windowsize + 1
        if n_windows <= 0:
            continue  # trial shorter than the window: nothing to count

        # A window ending at step t is encoded from the current step (lag 0,
        # most significant) back to step t - windowsize + 1 (least significant).
        states = np.zeros(n_windows, dtype=np.int64)
        for lag in range(windowsize):
            start = windowsize - 1 - lag
            states = 9 * states + step_codes[start:start + n_windows]

        C[trial] = np.bincount(states, minlength=n_columns)

    return C
"""
A = np.load('A.npy')
V = np.load('V.npy')
M = np.load('M.npy')
"""

# Each file stores one pickled Python object inside a 0-d object array;
# .item() unwraps that single object.
# NOTE(review): allow_pickle unpickles arbitrary objects on load, so only
# open files from a trusted source.
test = np.load('test.npy', allow_pickle=True).item()
train = np.load('train.npy', allow_pickle=True).item()

In [166]:
# Distinct values stored in test.M; a single distinct value prints 'skip'.
u = np.unique(test.M)
if u.size == 1:
    print('skip')

In [167]:
# Number of distinct values found in test.M.
len(u)

3

### Observe outcomes of pairs = [0, 1, 2]

In [54]:
def getcounts(A, V, timesteps, repeats, windowsize=1):
    """Run ``counts`` in all three modes on the same data.

    Parameters
    ----------
    A, V, timesteps, repeats
        Passed straight through to ``counts``.
    windowsize : int, default 1
        Window length used for the ``pairs=2`` run only; the default keeps
        the previous hard-coded value.

    Returns
    -------
    tuple
        ``(C0, C1, C2)``: the results of ``counts`` for ``pairs=0``,
        ``pairs=1`` and ``pairs=2`` respectively.
    """
    # windowsize is irrelevant for pairs 0 and 1, hence None.
    C0 = counts(A, V, timesteps, repeats, windowsize=None, pairs=0)
    C1 = counts(A, V, timesteps, repeats, windowsize=None, pairs=1)
    C2 = counts(A, V, timesteps, repeats, windowsize=windowsize, pairs=2)
    return C0, C1, C2

In [66]:
# Training set: count matrices for pairs = 0, 1 and 2
Atrain, Vtrain = train.A, train.V
C0train, C1train, C2train = getcounts(Atrain, Vtrain, 100, 100)

# Test set: same three count matrices
Atest, Vtest = test.A, test.V
C0test, C1test, C2test = getcounts(Atest, Vtest, 100, 100)


In [67]:
# Reference counts of -1 / 0 / 1 in the first training trial (A, then V).
print(np.bincount(Atrain[0] + 1), np.bincount(Vtrain[0] + 1), sep='\n')

[38 31 31]
[38 36 26]


In [68]:
# pairs=0 row for the first training trial: three A counts, then three V counts.
print(C0train[0] )


[38 31 31 38 36 26]


In [69]:
# Both rows hold the same nine joint (A, V) counts in a different column order:
# pairs=1 indexes a pair as (A+1) + 3*(V+1), whereas pairs=2 with windowsize=1
# encodes it as 3*(A+1) + (V+1).
print(C1train[0])
print(C2train[0])

[19 12  7 14  9 13  5 10 11]
[19. 14.  5. 12.  9. 10.  7. 13. 11.]


In [70]:
# Reference counts of -1 / 0 / 1 in the first test trial (A, then V).
print(np.bincount(Atest[0] + 1), np.bincount(Vtest[0] + 1), sep='\n')

[34 36 30]
[33 35 32]


In [71]:
# pairs=0 row for the first test trial: three A counts, then three V counts.
print(C0test[0] )

[34 36 30 33 35 32]


In [72]:
# Same nine joint (A, V) counts in two column orders: pairs=1 uses
# (A+1) + 3*(V+1), pairs=2 with windowsize=1 uses 3*(A+1) + (V+1).
print(C1test[0])
print(C2test[0])

[ 6 15 12 15 11  9 13 10  9]
[ 6. 15. 13. 15. 11. 10. 12.  9.  9.]


In [76]:
# Per-step (A, V) table of the first trial, i.e. the frame counts() builds
# internally for trial 0 before encoding states.  counts() itself returns
# the count matrix (an ndarray), which has no .head() and no A/V columns.
dftrain = pd.DataFrame({'A': Atrain[0], 'V': Vtrain[0]})
dftest = pd.DataFrame({'A': Atest[0], 'V': Vtest[0]})

In [78]:
# Preview the first rows of dftrain.
dftrain.head()

Unnamed: 0,A,V
0,0,-1
1,0,-1
2,-1,0
3,1,1
4,0,-1


In [79]:
# Preview the first rows of dftest.
dftest.head()

Unnamed: 0,A,V
0,0,0
1,0,-1
2,1,0
3,-1,1
4,1,-1


In [48]:
# Work on a copy: a plain `df1 = df` only creates a second name for the same
# frame, so the lag columns would also be written into `df` (and into every
# other alias of it).
df1 = df.copy()
# shift(1) moves each column down one row, so row t holds the values of
# step t-1; the first row has no predecessor and becomes NaN.
df1['A-1'], df1['V-1'] = df1['A'].shift(1), df1['V'].shift(1)
df1.head()

Unnamed: 0,A,V,A-1,V-1
0,-1,1,,
1,1,-1,-1.0,1.0
2,0,-1,1.0,-1.0
3,-1,0,0.0,-1.0
4,0,-1,-1.0,0.0


In [58]:
# Work on a copy so this experiment does not overwrite columns of `df` or of
# other frames that alias it.
df2 = df.copy()
# shift(-1) moves each column UP one row, so row t holds the values of step
# t+1.  NOTE(review): the column labels 'A-1'/'V-1' are kept for comparison
# with the shift(1) variant, but here they actually contain the NEXT step.
df2['A-1'], df2['V-1'] = df2['A'].shift(-1), df2['V'].shift(-1)
df2.head(10)

Unnamed: 0,A,V,A-1,V-1
0,-1,1,1.0,-1.0
1,1,-1,0.0,-1.0
2,0,-1,-1.0,0.0
3,-1,0,0.0,-1.0
4,0,-1,1.0,1.0
5,1,1,-1.0,0.0
6,-1,0,-1.0,0.0
7,-1,0,1.0,-1.0
8,1,-1,-1.0,1.0
9,-1,1,0.0,-1.0


In [53]:
# First ten entries of the last row of A and of V.
# NOTE(review): A and V are not assigned in the visible cells (the np.load
# lines are disabled inside a string) - presumably left over from an earlier
# session; confirm before re-running.
print(A[-1][:10])
print(V[-1][:10])

[-1  1  0 -1  0  1 -1 -1  1 -1]
[ 1 -1 -1  0 -1  1  0  0 -1  1]


### Test nested functions within core fn

In [162]:
def calculate_state_(draw_sequence):
    """Encode a sequence of draws from {-1, 0, 1} as a base-3 integer.

    Each draw becomes one base-3 digit (-1 -> 0, 0 -> 1, 1 -> 2); the first
    draw is the most significant digit.  An empty sequence encodes to 0 and
    any other draw value raises ``KeyError``.
    """
    digit_for = {-1: 0, 0: 1, 1: 2}

    state = 0
    for value in draw_sequence:
        digit = digit_for[value]
        # Shift the digits collected so far one base-3 place to the left
        # and append the new digit.
        state = 3 * state + digit
    return state

def apply_state_(row):
    """Return the base-3 state of one row (for use with ``DataFrame.apply(axis=1)``)."""
    # tolist() turns the row into a plain Python list, in column order.
    draws = row.tolist()
    return calculate_state_(draws)

# Rebuild the per-step (A, V) table of the first trial for each data set.
# counts() returns the finished count matrix (an ndarray), which supports
# neither .apply(..., axis=1) nor column assignment, so the table that
# counts() builds internally for trial 0 is recreated here explicitly.
dftrain = pd.DataFrame({'A': Atrain[0], 'V': Vtrain[0]})
dftest = pd.DataFrame({'A': Atest[0], 'V': Vtest[0]})

# Line1: encode every row (one A/V pair per time step) as a base-3 state.
dftrain['state'] = dftrain.apply(apply_state_, axis=1)
dftest['state'] = dftest.apply(apply_state_, axis=1)

windowsize = 1
max_state = 3 ** (2 * windowsize)  # number of distinct states: 3**(2n)
C = np.zeros((100, max_state))

In [113]:
# For every (A, V) combination, collect the distinct state codes it produced;
# a correct encoding yields exactly one code per combination.
summarytrain = dftrain.groupby(['A', 'V'])['state'].unique()
summarytest = dftest.groupby(['A', 'V'])['state'].unique()

# Show both summaries, training first.
print(summarytrain, summarytest, sep='\n')


A   V 
-1  -1    [0]
     0    [1]
     1    [2]
 0  -1    [3]
     0    [4]
     1    [5]
 1  -1    [6]
     0    [7]
     1    [8]
Name: state, dtype: object
A   V 
-1  -1    [0]
     0    [1]
     1    [2]
 0  -1    [3]
     0    [4]
     1    [5]
 1  -1    [6]
     0    [7]
     1    [8]
Name: state, dtype: object


In [160]:
def basenum(draw_sequence):
    """Print the base-3 number whose digits are ``draw + 1`` for each draw.

    The first draw is the most significant digit.  Nothing is returned.
    """
    total = 0
    for value in draw_sequence:
        # Shift previous digits one place left, then add the new digit.
        total = 3 * total + (value + 1)
    print(total)

# Probe the positional weights: one smaller draw is moved through a run of 1s.
# Earlier positions are more significant, so the printed value drops further
# the more to the left the smaller draw sits.

# Four draws: the single 0 moves from the first to the last position.
basenum([0, 1, 1, 1])

basenum([1, 0, 1, 1])

basenum([1, 1, 0, 1])

basenum([1, 1, 1, 0])

# Six draws: 0 in the last two positions.
basenum([1, 1, 1, 1, 0, 1])

basenum([1, 1, 1, 1, 1, 0])

# Six draws: all 1s (the largest six-digit value), then a -1 near the front.
basenum([1, 1, 1, 1, 1, 1])
basenum([1, -1, 1, 1, 1, 1])
basenum([-1, 1, 1, 1, 1, 1])

53
71
77
79
725
727
728
566
242


In [148]:
# Line2: 
# Calculate value counts
# value_counts() returns one entry per state that actually occurs, ordered by
# frequency (most common first) rather than by state code.
state_countstrain = dftrain['state'].value_counts()
state_countstest = dftest['state'].value_counts()

# Display the training counts.
state_countstrain


state
0    19
1    14
7    13
3    12
8    11
5    10
4     9
6     7
2     5
Name: count, dtype: int64

In [149]:
# Line3:
# Every state code that can occur: 0 .. max_state - 1.
all_possible_states = range(max_state)
all_possible_states

range(0, 9)

In [163]:
# Line4:
# Reindex the value counts to include all possible states
# Fill missing values (states with 0 occurrences) with 0
# Reindexing also puts the counts in ascending state order, so position i
# holds the count of state i.
state_countstrain = state_countstrain.reindex(all_possible_states, fill_value=0)

state_countstest = state_countstest.reindex(all_possible_states, fill_value=0)
# Compare the reindexed training counts with the first rows of the pairs=1
# and pairs=2 count matrices.
print(state_countstrain)
print(C1train[0])
print(C2train[0])

state
0    19
1    14
2     5
3    12
4     9
5    10
6     7
7    13
8    11
Name: count, dtype: int64
[19 12  7 14  9 13  5 10 11]
[19. 14.  5. 12.  9. 10.  7. 13. 11.]


In [143]:
# Test-set state counts next to the first row of the pairs=1 count matrix.
print(state_countstest)
print(C1test[0])

state
0     6
1    15
2    13
3    15
4    11
5    10
6    12
7     9
8     9
Name: count, dtype: int64
[ 6 15 12 15 11  9 13 10  9]


In [None]:


# NOTE(review): this cell was never executed (In [None]).  `trialnum` and
# `state_counts` are only defined inside counts() in the visible code, so
# running it as-is presumably raises NameError - confirm the intended names
# (e.g. a trial index and state_countstrain) before use.
#state_counts = state_counts.values.reshape(1,-1)
C[trialnum,:] = state_counts

In [111]:
# Display the dftest table.
dftest

Unnamed: 0,A,V
0,0,0
1,0,-1
2,1,0
3,-1,1
4,1,-1
...,...,...
95,1,-1
96,1,-1
97,-1,0
98,-1,0
