In [1]:
import os
import pandas as pd
import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load datasets

## Main Survey Items (Trend and New)

In [2]:
# sheet_name=None loads every worksheet into a dict keyed by sheet name
# (the next cell iterates over main_items.keys()); skiprows=1 skips the
# first row of each sheet before the header is read.
main_items = pd.read_excel('input_data/PISA2015_TechRep_Final-AnnexA.xlsx', sheet_name=None, skiprows=1)


### Get item names from trend survey (3 sheets)

In [3]:
# The two identifier columns lead the list; item IDs are appended after them.
items = ['CNT', 'CNTSTUID']
sheets = list(main_items.keys())

# First two sheets: both the CBA and the PBA identifier column are collected,
# CBA first, sheet by sheet.
trend_id_columns = (
    'CBA item ID in main survey analysis output',
    'PBA item ID in main survey analysis output',
)
for sheet_name in sheets[:2]:
    for id_column in trend_id_columns:
        items += main_items[sheet_name][id_column].values.tolist()

# Sheets at positions 10 and 11 expose a single identifier column.
for sheet_name in sheets[10:12]:
    items += main_items[sheet_name]['Item ID\nin analysis output'].values.tolist()

print(items)

['CNT', 'CNTSTUID', 'DS131Q02C', 'DS131Q04C', 'CS252Q01S', 'CS252Q02S', 'CS252Q03S', 'CS256Q01S', 'CS268Q01S', 'DS268Q02C', 'CS268Q06S', 'DS269Q01C', 'DS269Q03C', 'CS269Q04S', 'DS304Q01C', 'CS304Q02S', 'DS304Q03aC', 'DS304Q03bC', 'DS326Q01C', 'DS326Q02C', 'CS326Q03S', 'CS326Q04S', 'CS327Q01S', 'CS408Q01S', 'DS408Q03C', 'CS408Q04S', 'CS408Q05S', 'CS413Q04S', 'CS413Q05S', 'CS413Q06S', 'CS415Q02S', 'CS415Q07S', 'CS415Q08S', 'DS416Q01C', 'CS421Q01S', 'CS421Q02S', 'CS421Q03S', 'CS425Q02S', 'DS425Q03C', 'DS425Q04C', 'CS425Q05S', 'CS428Q01S', 'CS428Q03S', 'DS428Q05C', 'PS131Q02', 'PS131Q04', 'PS252Q01S', 'PS252Q02S', 'PS252Q03S', 'PS256Q01S', 'PS268Q01S', 'S268Q02', 'PS268Q06S', 'PS269Q01', 'PS269Q03', 'PS269Q04S', 'S304Q01', 'PS304Q02S', 'S304Q03a', 'S304Q03b', 'PS326Q01', 'PS326Q02', 'PS326Q03S', 'PS326Q04S', 'PS327Q01S', 'PS408Q01S', 'PS408Q03', 'PS408Q04S', 'PS408Q05S', 'PS413Q04S', 'PS413Q05S', 'PS413Q06', 'PS415Q02S', 'PS415Q07S', 'PS415Q08S', 'S416Q01', 'S421Q01', 'S421Q02', 'S421Q03',

## Cognitive dataset

In [None]:
# Explicit column selection: the two identifiers (CNT, CNTSTUID) followed by
# 99 item columns (101 entries in total).
# NOTE(review): in the visible notebook `cols` is only referenced by the
# commented-out SAS loader below; the SPSS loader in the next cell uses
# `items` instead. Presumably these 99 items correspond to num_stats = 99
# further down — TODO confirm.
cols = ['CNT', 'CNTSTUID', 'CS601Q01S', 'CS601Q02S', 'CS601Q04S', 'CS602Q01S',
    'CS602Q02S', 'DS602Q03C', 'CS602Q04S', 'CS603Q01S', 'DS603Q02C', 
    'CS603Q03S', 'CS603Q04S', 'CS603Q05S', 'CS604Q02S', 'DS604Q04C', 
    'CS605Q01S', 'CS605Q02S', 'CS605Q03S', 'DS605Q04C', 'CS607Q01S', 
    'CS607Q02S', 'DS607Q03C', 'CS608Q01S', 'CS608Q02S', 'CS608Q03S', 
    'DS608Q04C', 'DS610Q01C', 'CS610Q02S', 'CS610Q04S', 'CS615Q01S', 
    'CS615Q02S', 'CS615Q05S', 'CS615Q07S', 'CS620Q01S', 'CS620Q02S', 
    'DS620Q04C', 'DS625Q01C', 'CS625Q02S', 'CS625Q03S', 'CS626Q01S', 
    'CS626Q02S', 'CS626Q03S', 'DS626Q04C', 'CS627Q01S', 'CS627Q03S', 
    'CS627Q04S', 'DS629Q01C', 'CS629Q02S', 'DS629Q03C', 'CS629Q04S', 
    'CS634Q01S', 'CS634Q02S', 'DS634Q03C', 'CS634Q04S', 'DS634Q05C', 
    'CS635Q01S', 'CS635Q02S', 'DS635Q03C', 'CS635Q04S', 'DS635Q05C', 
    'DS637Q01C', 'CS637Q02S', 'DS637Q05C', 'CS638Q01S', 'CS638Q02S', 
    'CS638Q04S', 'DS638Q05C', 'CS641Q01S', 'CS641Q02S', 'CS641Q03S', 
    'CS641Q04S', 'CS643Q01S', 'CS643Q02S', 'DS643Q03C', 'CS643Q04S', 
    'DS643Q05C', 'CS645Q01S', 'CS645Q03S', 'DS645Q04C', 'DS645Q05C', 
    'CS646Q01S', 'CS646Q02S', 'CS646Q03S', 'DS646Q04C', 'DS646Q05C', 
    'DS648Q01C', 'CS648Q02S', 'CS648Q03S', 'DS648Q05C', 'CS649Q01S', 
    'DS649Q02C', 'CS649Q03S', 'CS649Q04S', 'CS656Q01S', 'DS656Q02C', 
    'CS656Q04S', 'CS657Q01S', 'CS657Q02S', 'CS657Q03S', 'DS657Q04C']

# Alternative loader from the SAS export, kept disabled.
#cog = pd.read_sas('PUF_SAS_COMBINED_CMB_STU_COG/cy6_ms_cmb_stu_cog.sas7bdat')
#data_df = cog[cols]
#data_df['CNT'] = data_df['CNT'].str.decode('utf-8')


In [8]:
# Read the SPSS cognitive file, loading only the columns named in `items`
# (built from the Annex A sheets above).
cog = pd.read_spss(
    'PUF_SPSS_COMBINED_CMB_STU_COG/CY6_MS_CMB_STU_COG.sav', usecols=items)

In [9]:
# Work on a string copy so every cell can be compared against the SPSS
# value labels below.
data_df = cog.copy().astype(str)
raw_labels = data_df.values

# Labels scored 0.
NO_CREDIT_LABELS = (
    '0 - No credit',
    '01 - No credit',
    '02 - No credit',
    '03 - No credit',
    '04 - No credit',
    'No credit',
)
# Labels scored 1; partial credit is scored the same as full credit.
CREDIT_LABELS = (
    '1 - Full credit',
    '1 - Partial credit',
    '11 - Full credit',
    '11 - Partial credit',
    '12 - Full credit',
    '12 - Partial credit',
    '2 - Full credit',
    '21 - Full credit',
    'Full credit',
)

# Any cell whose label is in neither table keeps the -1 it starts with.
response = np.full(data_df.shape, -1.0)
for score, labels in ((0, NO_CREDIT_LABELS), (1, CREDIT_LABELS)):
    for label in labels:
        response[raw_labels == label] = score


In [10]:
# Overwrite the label strings with the numeric codes computed in the previous
# cell. Columns are addressed by position from index 2 onwards.
# NOTE(review): this assumes CNT and CNTSTUID are the first two columns of the
# loaded frame — confirm against the column order returned by read_spss.
data_df.iloc[:, 2:] = response[:, 2:]
data_df['CNTSTUID'] = data_df['CNTSTUID'].astype(float)
# Explicit cast to pandas' string dtype. The previous `convert_dtypes(str)`
# passed `str` as the boolean `infer_objects` argument, which only worked
# because a type object is truthy.
data_df['CNT'] = data_df['CNT'].astype('string')

In [11]:
# Persist the coded responses as a semicolon-separated CSV, header included, index omitted.
data_df.to_csv('cog_science.csv', sep=';', header=True, index=False)

## Context dataset

In [12]:
# Columns to load from the main student-questionnaire file (QQQ). CNTSTUID is
# included as the last entry alongside the questionnaire variables.
ctx_cols = ['ST004D01T', 'ST005Q01TA', 'ST006Q01TA', 'ST006Q02TA', 'ST006Q03TA',
            'ST006Q04TA', 'ST007Q01TA', 'ST008Q01TA', 'ST008Q02TA', 'ST008Q03TA',
            'ST008Q04TA', 'ST011Q01TA', 'ST011Q02TA', 'ST011Q03TA', 'ST011Q04TA',
            'ST011Q05TA', 'ST011Q06TA', 'ST011Q07TA', 'ST011Q08TA', 'ST011Q09TA',
            'ST011Q10TA', 'ST011Q11TA', 'ST011Q12TA', 'ST011Q16NA', 'ST011D17TA',
            'ST011D18TA', 'ST011D19TA', 'ST012Q01TA', 'ST012Q02TA', 'ST012Q03TA',
            'ST012Q05NA', 'ST012Q06NA', 'ST012Q07NA', 'ST012Q08NA', 'ST012Q09NA',
            'ST013Q01TA', 'ST123Q01NA', 'ST123Q02NA', 'ST123Q03NA', 'ST123Q04NA',
            'ST019AQ01T', 'ST019BQ01T', 'ST019CQ01T', 'ST021Q01TA', 'ST022Q01TA',
            'ST124Q01TA', 'ST125Q01NA', 'ST126Q01TA', 'ST127Q01TA', 'ST127Q02TA',
            'ST127Q03TA', 'ST111Q01TA', 'ST118Q01NA', 'ST118Q02NA', 'ST118Q03NA',
            'ST118Q04NA', 'ST118Q05NA', 'ST119Q01NA', 'ST119Q02NA', 'ST119Q03NA',
            'ST119Q04NA', 'ST119Q05NA', 'ST121Q01NA', 'ST121Q02NA', 'ST121Q03NA',
            'ST082Q01NA', 'ST082Q02NA', 'ST082Q03NA', 'ST082Q08NA', 'ST082Q09NA',
            'ST082Q12NA', 'ST082Q13NA', 'ST082Q14NA', 'ST034Q01TA', 'ST034Q02TA',
            'ST034Q03TA', 'ST034Q04TA', 'ST034Q05TA', 'ST034Q06TA', 'ST039Q01NA',
            'ST039Q02NA', 'ST039Q03NA', 'ST039Q04NA', 'ST039Q05NA', 'ST039Q06NA',
            'ST059Q01TA', 'ST059Q02TA', 'ST059Q03TA', 'ST060Q01NA', 'ST061Q01NA',
            'ST062Q01TA', 'ST062Q02TA', 'ST062Q03TA', 'ST071Q01NA', 'ST071Q02NA',
            'ST071Q03NA', 'ST071Q04NA', 'ST071Q05NA', 'ST031Q01NA', 'ST032Q01NA',
            'ST032Q02NA', 'ST063Q01NA', 'ST063Q01NB', 'ST063Q02NA', 'ST063Q02NB',
            'ST063Q03NA', 'ST063Q03NB', 'ST063Q04NA', 'ST063Q04NB', 'ST063Q05NA',
            'ST063Q05NB', 'ST063Q06NA', 'ST063Q06NB', 'ST064Q01NA', 'ST064Q02NA',
            'ST064Q03NA', 'ST097Q01TA', 'ST097Q02TA', 'ST097Q03TA', 'ST097Q04TA',
            'ST097Q05TA', 'ST098Q01TA', 'ST098Q02TA', 'ST098Q03NA', 'ST098Q05TA',
            'ST098Q06TA', 'ST098Q07TA', 'ST098Q08NA', 'ST098Q09TA', 'ST098Q10NA',
            'ST100Q01TA', 'ST100Q02TA', 'ST100Q03TA', 'ST100Q04TA', 'ST100Q05TA',
            'ST103Q01NA', 'ST103Q03NA', 'ST103Q08NA', 'ST103Q11NA', 'ST104Q01NA',
            'ST104Q02NA', 'ST104Q03NA', 'ST104Q04NA', 'ST104Q05NA', 'ST107Q01NA',
            'ST107Q02NA', 'ST107Q03NA', 'ST092Q01TA', 'ST092Q02TA', 'ST092Q04TA',
            'ST092Q05TA', 'ST092Q06NA', 'ST092Q08NA', 'ST092Q09NA', 'ST093Q01TA',
            'ST093Q03TA', 'ST093Q04TA', 'ST093Q05TA', 'ST093Q06TA', 'ST093Q07NA',
            'ST093Q08NA', 'ST094Q01NA', 'ST094Q02NA', 'ST094Q03NA', 'ST094Q04NA',
            'ST094Q05NA', 'ST095Q04NA', 'ST095Q07NA', 'ST095Q08NA', 'ST095Q13NA',
            'ST095Q15NA', 'ST113Q01TA', 'ST113Q02TA', 'ST113Q03TA', 'ST113Q04TA',
            'ST129Q01TA', 'ST129Q02TA', 'ST129Q03TA', 'ST129Q04TA', 'ST129Q05TA',
            'ST129Q06TA', 'ST129Q07TA', 'ST129Q08TA', 'ST131Q01NA', 'ST131Q03NA',
            'ST131Q04NA', 'ST131Q06NA', 'ST131Q08NA', 'ST131Q11NA', 'ST146Q01TA',
            'ST146Q02TA', 'ST146Q03TA', 'ST146Q04TA', 'ST146Q05TA', 'ST146Q06NA',
            'ST146Q07NA', 'ST146Q08NA', 'ST146Q09NA', 'ST076Q01NA', 'ST076Q02NA',
            'ST076Q03NA', 'ST076Q04NA', 'ST076Q05NA', 'ST076Q06NA', 'ST076Q07NA',
            'ST076Q08NA', 'ST076Q09NA', 'ST076Q10NA', 'ST076Q11NA', 'ST078Q01NA',
            'ST078Q02NA', 'ST078Q03NA', 'ST078Q04NA', 'ST078Q05NA', 'ST078Q06NA',
            'ST078Q07NA', 'ST078Q08NA', 'ST078Q09NA', 'ST078Q10NA', 'ST078Q11NA',
            'ST065Class', 'CNTSTUID']

# Columns to load from the second questionnaire file (QQ2); CNTSTUID is the
# only column name shared with ctx_cols.
ctx2_cols = ['ST016Q01NA', 'ST038Q01NA', 'ST038Q02NA', 'ST038Q03NA', 'ST038Q04NA', 
             'ST038Q05NA', 'ST038Q06NA', 'ST038Q07NA', 'ST038Q08NA', 'CNTSTUID']


In [13]:
# Load the two student-questionnaire files, restricted to the column lists
# defined in the previous cell (both lists include CNTSTUID).
ctx = pd.read_spss('PUF_SPSS_COMBINED_CMB_STU_QQQ/CY6_MS_CMB_STU_QQQ.sav', usecols=ctx_cols)
ctx2 = pd.read_spss(
    'PUF_SPSS_COMBINED_CMB_STU_QQQ/CY6_MS_CMB_STU_QQ2.sav', usecols=ctx2_cols)


## Concatenate all datasets

In [14]:
# Combine the two questionnaire extracts column-wise on the student id.
# `pd.concat([ctx, ctx2])` defaults to axis=0, which stacks the frames
# row-wise: every student would appear twice, each row holding NaN for the
# other file's variables. An outer merge keeps one row per CNTSTUID and
# retains students that appear in only one of the files.
# NOTE(review): assumes CNTSTUID is unique within each file — confirm.
ctx = ctx.merge(ctx2, on='CNTSTUID', how='outer')

## Save backup data

In [15]:
# Persist the combined context data as a semicolon-separated CSV, header included, index omitted.
ctx.to_csv('ctx_science.csv', sep=';', header=True, index=False)

# Context data coding

In [None]:
df_ctx_code = pd.read_excel(
    'Technical Report 2015 - Annex B Contrast Coding_FINAL .xlsx')

# The variables of interest are exactly the questionnaire columns loaded for
# the context dataset (ctx_cols followed by ctx2_cols) without the CNTSTUID
# identifier. Deriving the list from those two definitions replaces a second
# hand-maintained copy of the same 229 names, so the lists cannot drift apart.
ctx_code_cols = [col for col in ctx_cols + ctx2_cols if col != 'CNTSTUID']

# Keep only the contrast-coding rows whose ITEM_ID is one of those variables.
ctx_code = df_ctx_code[df_ctx_code['ITEM_ID'].isin(ctx_code_cols)]

## Save contrast code for students

In [None]:
# Persist the filtered contrast-coding rows as a semicolon-separated CSV, header included, index omitted.
ctx_code.to_csv('ctx_code.csv', sep=';', header=True, index=False)

In [None]:
# Boolean row mask for the two countries under study; `cnt` is reused by
# later cells to subset other row-aligned arrays.
cnt = data_df['CNT'].isin(['BRA', 'NOR'])

# Matching rows with a fresh 0..n-1 index.
cnt_df = data_df.loc[cnt].reset_index(drop=True)

print(cnt_df.shape)

In [None]:
# Plain-text dump of the two-country subset for visual inspection.
print(cnt_df)

In [None]:
# Append the CNT column as a new last column of `imp`.
# NOTE(review): `imp` is not defined in any cell visible in this notebook, so
# this fails on a fresh kernel; presumably it is a response matrix row-aligned
# with data_df that was built in a removed cell — TODO restore its definition.
# The cell also reads and rebinds `imp`, so each re-run appends another column.
imp = np.c_[imp, data_df['CNT']]

In [None]:
# Shape check after appending the CNT column.
print(imp.shape)

In [None]:
# Keep only the rows selected by the boolean mask `cnt` (CNT is 'BRA' or 'NOR').
response = imp[cnt]
#response = response.reset_index(drop=True)
# Inspect column index 99 — presumably the CNT column appended to `imp`
# above; TODO confirm that `imp` had 99 columns before the append.
print(response[:, 99])

In [None]:
# Row counts per country in the subset, followed by the CNT column itself.
print(cnt_df[cnt_df['CNT'] == 'BRA'].shape)
print(cnt_df[cnt_df['CNT'] == 'NOR'].shape)
print(cnt_df['CNT'])

In [None]:
# Show any missing entries. NaN never compares equal to anything (including
# itself), so `response == np.nan` is False everywhere and always selects
# nothing; pd.isna performs the element-wise missing-value test instead.
response[pd.isna(response)]


In [None]:
#data_df = pd.read_csv('cog_science.csv')
#response = np.zeros(cnt_df.shape)

# Recode the country label held in column 99 to a numeric flag: NOR -> 1,
# BRA -> 0. The assignment is restricted to that single column: indexing with
# the row mask alone (`response[mask] = 1`) overwrites every column of the
# matching rows and would erase all item responses.
# Both masks are computed before writing so the second test is not affected
# by the first assignment.
country_col = 99
is_nor = response[:, country_col] == 'NOR'
is_bra = response[:, country_col] == 'BRA'
response[is_nor, country_col] = 1
response[is_bra, country_col] = 0


In [None]:
# Visual check of the matrix after the country recode.
print(response)


In [None]:
# Shape of the matrix and the contents of its last column.
print(response.shape)
print(response[:, -1])
#print(data_df[data_df == '0.0'].shape)

In [None]:
# Snapshot of the current response matrix, written to the working directory.
cache_file = 'score_matrix.npy'
np.save(cache_file, response)

In [None]:
# Shuffle the rows of `response` with a fixed seed (42) so the order is
# reproducible: build the index vector 0..n-1, permute it in place, then
# reorder the matrix with it.
# NOTE(review): only `response` is reordered here; `cnt_df` keeps its original
# row order, so the two are no longer row-aligned after this cell.
rs = np.random.RandomState(42)
swapper = np.arange(response.shape[0])
rs.shuffle(swapper)
response = response[swapper]

# Disabled alternatives for dropping rows that contain -1 entries.
#rows_to_remove = np.sum(response, 1) == (-1 * response.shape[1])
#rows_to_remove = np.sum(response, 1) <= -1
#response = response[~rows_to_remove]
#response = response[np.all(response != -1, axis=1)]

In [None]:
# Semicolon-delimited reference table; later cells read its item-ID,
# competency, knowledge and system columns.
ref = pd.read_csv('Reference Matrix.csv', delimiter=';')

In [None]:
# Column layout: the item identifier followed by the 27 three-digit codes
# (each digit is 1, 2 or 3).
c = ['Items', '111', '112', '113', '121', '122', '123', '131', '132', '133',
              '211', '212', '213', '221', '222', '223', '231', '232', '233',
              '311', '312', '313', '321', '322', '323', '331', '332', '333']

# Build the indicator columns directly as integer zeros, one row per row of
# `ref`. Creating an empty frame and assigning into it afterwards can leave
# these columns with object dtype (depending on the pandas version), which is
# awkward for the numeric use of the matrix later on.
q_matrix = pd.DataFrame(0, index=range(len(ref)), columns=c[1:])
q_matrix.insert(0, 'Items', ref['Item ID in analysis output'].to_numpy())
#print(q_matrix)

In [None]:
# Ordered (substring, digit) tables, one per facet. The first substring found
# in the row's label decides the digit; the order matches the original
# if/elif chains.
COMPETENCY_CODES = (
    ('Explain phenomena scientifically', '1'),
    ('Interpret data and evidence scientifically', '2'),
    ('Evaluate and design scientific enquiry', '3'),
)
KNOWLEDGE_CODES = (
    ('Content', '1'),
    ('Procedural', '2'),
    ('Epistemic', '3'),
)
SYSTEM_CODES = (
    ('Physical', '1'),
    ('Living', '2'),
    ('Earth and Space', '3'),
)


def _facet_digit(label, codes, column):
    """Return the digit of the first entry in `codes` whose substring occurs in `label`.

    Raises ValueError when nothing matches. Without this check an unmatched
    knowledge or system label produced a one- or two-character code, and
    assigning to that code silently added a new column to q_matrix; an
    unmatched competency failed with an unrelated TypeError.
    """
    text = str(label)
    for needle, digit in codes:
        if needle in text:
            return digit
    raise ValueError(f"Unrecognised value {label!r} in column {column!r}")


# Flag, for every reference row, the single column whose three-digit code is
# competency + knowledge + system.
for i, r in ref.iterrows():
    seq = (
        _facet_digit(r['Competency (2015)'], COMPETENCY_CODES, 'Competency (2015)')
        + _facet_digit(r['Knowledge (2015)'], KNOWLEDGE_CODES, 'Knowledge (2015)')
        + _facet_digit(r['System (2015)'], SYSTEM_CODES, 'System (2015)')
    )

    # `i` is the index label of `ref`; it is used as a position into q_matrix.
    q_matrix.loc[q_matrix.index[i], seq] = 1

In [None]:
# Persist the Q-matrix as a semicolon-separated CSV, header included, index omitted.
q_matrix.to_csv('q_matrix_PISA15.csv', sep=';', header=True, index=False)

In [None]:
# Drop the first column ('Items') and transpose, so Q has one row per
# three-digit code and one column per item.
Q = q_matrix.iloc[:,1:].T
print(Q.shape)

In [None]:
# If a file named 'test_response.npy' exists in the working directory, it
# replaces the `response` matrix built by the cells above; otherwise
# `response` is left unchanged.
test_response = os.path.join('test_response.npy')

if os.path.isfile(test_response):
    response = np.load(test_response)

# Helper Functions


In [None]:
# Restrict connection in decoder
def q_constraint(w):
    """Kernel constraint that keeps only Q-permitted, non-negative weights.

    An entry of `w` survives when multiplying it by the matching entry of the
    global Q matrix leaves it unchanged (the Q entry is 1, or the weight is
    already 0); every other entry is zeroed. Negative survivors are then
    zeroed as well.

    Args:
        w: weight tensor with the same shape as Q.

    Returns:
        Tensor of the same shape as `w` with the constraint applied.
    """
    # Q is a pandas DataFrame; convert it to a float tensor explicitly rather
    # than relying on how `tensor * DataFrame` happens to be dispatched.
    q_mask = tf.convert_to_tensor(np.asarray(Q, dtype=keras.backend.floatx()))
    q_mask = tf.cast(q_mask, w.dtype)

    target = w * q_mask
    diff = w - target
    # diff == 0 exactly where the Q mask left the weight untouched.
    w = w * tf.cast(tf.math.equal(diff, 0), keras.backend.floatx())
    return w * tf.cast(tf.math.greater_equal(w, 0), keras.backend.floatx())

# Keep only the entries that sit on non-zero Q cells
def remove_zeros(arr):
    """Return the entries of `arr` whose position is non-zero in the global Q.

    Positions are visited row by row over range(num_skills) x range(num_stats)
    (both module-level globals); `arr[j][i]` is kept whenever `Q.iloc[j, i]`
    is not 0. The result is a flat Python list in that visiting order.
    """
    return [
        arr[skill][item]
        for skill in range(num_skills)
        for item in range(num_stats)
        if Q.iloc[skill, item] != 0
    ]


# Variables Initialization

In [None]:
# Model dimensions. num_stats is passed to the autoencoder as its output
# width; num_skills is the latent dimension and equals the number of
# three-digit code columns in the Q-matrix (27). The trailing #180 / #21 are
# earlier values kept for reference.
num_stats = 99 #180
num_skills = 27 #21

# Width of the encoder's hidden layer.
intermediate_dim=40

# Number of subjects (rows of the response matrix)
N = response.shape[0]
# Number of BRA rows in cnt_df.
# NOTE(review): in the visible notebook `tr` and `N` are only printed or
# unused; the training call uses every row of `response`.
tr = cnt_df[cnt_df['CNT'] == 'BRA'].shape[0]
# One batch spans the whole dataset (an earlier value of 50 is kept as a comment).
batch_size = response.shape[0]#50
epochs = 200

In [None]:
# Sanity check: number of BRA rows and the batch size.
print(tr)
print(batch_size)

# Preparing Input Data

## Generating Abilities

In [None]:
#with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    

In [None]:
class Sampling(layers.Layer):
    """Draws z from (z_mean, z_log_var) as z_mean + exp(0.5 * z_log_var) * epsilon.

    `epsilon` is random normal noise with one value per (row, latent
    dimension) of `z_mean`.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Noise has the same (batch, dim) shape as the mean tensor.
        noise_shape = (tf.shape(z_mean)[0], tf.shape(z_mean)[1])
        epsilon = tf.keras.backend.random_normal(shape=noise_shape)
        # exp(0.5 * log variance) is the standard deviation.
        sigma = tf.exp(0.5 * z_log_var)
        return z_mean + sigma * epsilon

class Encoder(keras.Model):
    """Maps item responses plus a population flag to latent-skill parameters.

    The input is split column-wise: every column except the last goes through
    a tanh hidden layer and two linear heads (mean and log-variance); the
    last column goes through two bias-free linear layers whose outputs are
    added to those heads. A latent sample is drawn from the summed
    parameters.

    Args:
        latent_dim: size of the latent vector.
        intermediate_dim: width of the hidden layer.
        name: Keras model name.
    """

    def __init__(self, latent_dim=num_skills, intermediate_dim=intermediate_dim, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)

        # Item-response branch.
        self.dense_proj = layers.Dense(intermediate_dim, activation="tanh")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)

        # Population branch: bias-free, so its output is a pure multiple of
        # the last input column.
        self.dense_mean_POP = layers.Dense(latent_dim, use_bias=False)
        self.dense_log_var_POP = layers.Dense(latent_dim, use_bias=False)

        self.dense_mean_added = layers.Add()
        self.dense_log_var_added = layers.Add()

        self.sampling = Sampling()

    def call(self, inputs):
        """Return (z_mean, z_log_var, z, z_mean_added, z_log_var_added).

        `z_mean` and `z_log_var` come from the item-response branch alone;
        the `_added` variants include the population contribution, and `z` is
        sampled from the `_added` pair.
        """
        # The leftover debug print of inputs.shape was removed from this method.
        item_responses = inputs[:, :-1]
        population = inputs[:, -1:]

        # Items response
        x1 = self.dense_proj(item_responses)
        z_mean = self.dense_mean(x1)
        z_log_var = self.dense_log_var(x1)

        # Population
        z_mean_pop = self.dense_mean_POP(population)
        z_log_var_pop = self.dense_log_var_POP(population)

        z_mean_added = self.dense_mean_added([z_mean, z_mean_pop])
        z_log_var_added = self.dense_log_var_added([z_log_var, z_log_var_pop])

        z = self.sampling((z_mean_added, z_log_var_added))
        return z_mean, z_log_var, z, z_mean_added, z_log_var_added


class Decoder(keras.Model):
    """Maps a latent vector to one sigmoid output per item.

    A single Dense layer does the work; its kernel is passed through
    `q_constraint` after each update. `latent_dim` is accepted but not used
    inside this class.
    """

    def __init__(self, original_dim, latent_dim=num_skills, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_output = layers.Dense(
            original_dim,
            activation="sigmoid",
            kernel_constraint=q_constraint,
        )

    def call(self, inputs):
        return self.dense_output(inputs)



class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training.

    The model input carries the item responses in every column but the last
    and a population flag in the last column. The output is the decoder's
    reconstruction of the item responses only.
    """

    def __init__(
        self,
        original_dim,
        intermediate_dim=intermediate_dim,
        latent_dim=num_skills,
        name="autoencoder"
    ):
        super().__init__(name=name)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, latent_dim=latent_dim)

    def call(self, inputs):
        # The encoder's item-branch mean and log-variance are stored on the
        # instance because vae_loss reads them.
        self.z_mean, self.z_log_var, self.z, _, _ = self.encoder(inputs)
        reconstructed = self.decoder(self.z)
        return reconstructed

    # Loss function
    def vae_loss(self, input, output):
        """Per-sample loss: scaled reconstruction error plus a gated KL term.

        Args:
            input: target tensor; last column is the population flag, the
                remaining columns are the item responses.
            output: reconstructed item responses.

        Returns:
            Tensor of shape (batch,): num_stats * binary cross-entropy plus
            the KL term multiplied by (1 - population flag), so the KL term
            only contributes for rows whose flag is 0.
        """
        #TODO: New loss function

        item_targets = input[:, :-1]
        # Rank-1 slice on purpose. Both loss terms below have shape (batch,);
        # multiplying by the (batch, 1) slice `input[:, -1:]` would broadcast
        # the result to (batch, batch) and mix every row's KL term with every
        # other row's flag.
        population = input[:, -1]

        cross_entropy_loss = (num_stats / 1.0) * keras.losses.binary_crossentropy(item_targets, output)

        kl_loss = -0.5 * tf.reduce_mean(self.z_log_var - tf.square(self.z_mean) - tf.exp(self.z_log_var) + 1, axis=-1)

        return cross_entropy_loss + kl_loss * (1 - population)

    # Get weights
    def _get_weights(self):
        """Return the decoder's trainable weights."""
        return self.decoder.trainable_weights

    def get_encoder(self):
        """Return the encoder sub-model."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder sub-model."""
        return self.decoder

# Custom Metric Function

In [None]:
# Custom Binary Accuracy Metric
# https://github.com/keras-team/keras/blob/1d12a9287cbe3f2d4cdde84918f0ef2086da9bee/keras/metrics.py#L3373

def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Per-row share of items whose thresholded prediction equals the target.

    The last column of `y_true` is dropped before the comparison, so `y_pred`
    is compared against the remaining columns only. Predictions strictly
    above `threshold` count as 1, all others as 0.
    """
    predictions = tf.convert_to_tensor(y_pred)
    cutoff = tf.cast(threshold, predictions.dtype)
    hard_predictions = tf.cast(predictions > cutoff, predictions.dtype)

    item_targets = y_true[:, :-1]
    matches = tf.equal(item_targets, hard_predictions)
    return keras.backend.mean(matches, axis=-1)


# Training Function

In [None]:
def train(dtrain):
    """Fit the Q-constrained VAE on `dtrain` and collect its estimates.

    `dtrain` is used as both input and target. Model size and training length
    come from the module-level num_stats, intermediate_dim, num_skills, epochs
    and batch_size.

    Args:
        dtrain: 2-D float tensor; the model treats the last column as the
            population flag and the remaining columns as item responses.

    Returns:
        dict with the encoder predictions ('z_mean_pred', 'z_logvar_pred',
        'z', 'z_mean_added_pred', 'z_logvar_added_pred'), their column-major
        flattenings ('theta_hat', 'logvar_theta_hat'), the decoder's first
        weight array restricted to non-zero Q cells ('discr_hat'), the negated
        second decoder weight array as a DataFrame ('diff'), the decoder
        output for the sampled z ('output'), the encoder's trainable weights
        ('weights') and decoder outputs for five fresh latent draws
        ('output_array').
    """
    vae_q = VariationalAutoEncoder(num_stats, intermediate_dim, num_skills)

    # Optimizer
    opt = tf.keras.optimizers.Adam(learning_rate=0.005, amsgrad=True)

    vae_q.compile(optimizer=opt, loss=vae_q.vae_loss, metrics=[binary_accuracy])

    # The History object returned by fit() is not used further.
    vae_q.fit(dtrain,
              dtrain,
              epochs=epochs,
              batch_size=batch_size,
              shuffle=True)

    encoder = vae_q.get_encoder()
    decoder = vae_q.get_decoder()

    # Decoder trainable weights — presumably [kernel, bias] of its single
    # Dense layer; TODO confirm the ordering.
    weights = vae_q._get_weights()
    discr = weights[0].numpy()
    negative_diff = pd.DataFrame(np.negative(weights[1].numpy()))

    # Get latent trait predictions
    z_mean_pred, z_logvar_pred, z_pred, z_mean_added_pred, z_logvar_added_pred = encoder.predict(dtrain)
    output_pred = decoder.predict(z_pred)

    # Five independent latent draws from the population-adjusted parameters,
    # each decoded to item probabilities.
    s = Sampling()
    pv = [s([z_mean_added_pred, z_logvar_added_pred]) for _ in range(5)]
    output_array = [decoder.predict(draw) for draw in pv]

    # Flatten the per-subject estimates column by column.
    theta_hat = np.transpose(z_mean_pred).flatten()
    log_var_theta_hat = np.transpose(z_logvar_pred).flatten()

    # Keep only the weights on cells allowed by the Q-matrix.
    discr_hat = remove_zeros(discr)

    return {
        'z_mean_pred': z_mean_pred,
        'z_logvar_pred': z_logvar_pred,
        'z': z_pred,
        'z_mean_added_pred': z_mean_added_pred,
        'z_logvar_added_pred': z_logvar_added_pred,
        'theta_hat': theta_hat,
        'logvar_theta_hat': log_var_theta_hat,
        'discr_hat': discr_hat,
        'diff': negative_diff,
        'output': output_pred,
        'weights': encoder.trainable_weights,
        'output_array': output_array
    }
    


# Training All

In [None]:
# Shape of the last column of `response`, sliced so that it stays 2-D.
print(response[:, -1:].shape)

In [None]:
# Cast the whole response matrix to float32 and train on it. Despite its
# name, df_train is a tensor, not a DataFrame.
df_train = tf.cast(response, tf.float32)
results = train(df_train)