#### Parsing Keepaway weights files

In [1]:
from GAME.utils.config import config
import numpy as np
import sys
import os
import math
import pandas as pd
import random
import copy

In [2]:
def parse_weights(weights_file):
    with open(weights_file, 'rb') as f:
        rl_mem_size = 1048576
        weights = np.fromfile(f, dtype=np.double, count=rl_mem_size)
        
        offset = rl_mem_size * 8
        f.seek(offset, os.SEEK_SET)
        m = np.fromfile(f, dtype=np.int_, count=1)

        offset += 4
        f.seek(offset, os.SEEK_SET)
        safe = np.fromfile(f, dtype=np.intc, count=1)

        offset += 4
        f.seek(offset, os.SEEK_SET)
        calls = np.fromfile(f, dtype=np.int_, count=1)

        offset += 4
        f.seek(offset, os.SEEK_SET)
        clearhits = np.fromfile(f, dtype=np.int_, count=1)

        offset += 4
        f.seek(offset, os.SEEK_SET)
        collisions = np.fromfile(f, dtype=np.int_, count=1)

        offset += 4
        f.seek(offset, os.SEEK_SET)
        data = np.fromfile(f, dtype=np.int_, count=m[0])

    return weights, m, safe, calls, clearhits, collisions, data

class collision_table:
    def __init__(self, m, safe, calls, clearhits, collisions, data):
        self.m = m
        self.safe = safe
        self.calls = calls
        self.clearhits = clearhits
        self.collisions = collisions
        self.data = data

def hash_UNH(ints:list, num_ints:int, m:int, increment:int, rndseq:list):
    i = 0
    index = 0
    sum = 0

    for i in range(num_ints):
        index = ints[i]
        index = index + (increment * i)
        index = index % 2048
        while index < 0:
            index = index + 2048
        sum = sum + int(rndseq[int(index)])
    
    index = int(sum % m)
    while index < 0:
        index = index + m
    
    return int(index)
    
def hash(ints:list, num_ints:list, ct:collision_table, rndseq:list):
    j = 0
    ccheck = 0

    ct.calls = ct.calls + 1
    j = hash_UNH(ints, num_ints, ct.m, 449, rndseq)
    ccheck = hash_UNH(ints, num_ints, sys.maxsize, 457, rndseq)
    if ccheck == ct.data[j]:
        ct.clearhits = ct.clearhits + 1
    elif ct.data[j] == -1:
        ct.clearhits = ct.clearhits + 1
        ct.data[j] = ccheck
    elif ct.safe == 0:
        ct.collisions = ct.collisions + 1
    else:
        h2 = 1 + 2 * hash_UNH(ints, num_ints, int(sys.maxsize / 4), 449, rndseq)
        i = 0
        i += 1
        while i:
            ct.collisions = ct.collisions + 1
            j = int((j + h2) % (ct.m))
            if i > ct.m:
                raise ValueError("Out of memory")
            if ccheck == ct.data[j]:
                break
            if ct.data[j] == -1:
                ct.data[j] = ccheck
                break
            i += 1
    return int(j)

def GetTiles(num_tilings:int, ctable:collision_table, floats:list, num_floats:int, ints:list, num_ints:int, rndseq:list):
    tiles = np.zeros(shape = (1, num_tilings), dtype = int)[0]
    i = 0
    j = 0
    qstate = np.zeros(shape = (1, 20))[0]
    base = np.zeros(shape = (1, 20))[0]
    coordinates = np.zeros(shape = (1, 20 * 2 + 1))[0]
    num_coordinates = num_floats + num_ints + 1

    for i in range(num_ints):
        coordinates[num_floats + 1 + i] = ints[i]

    for i in range(num_floats):
        qstate[i] = int(math.floor(floats[i] * num_tilings))
        base[i] = 0

    for j in range(num_tilings):
        for i in range(num_floats):
            coordinates[i] = qstate[i] - ((qstate[i] - base[i]) % num_tilings)
            base[i] = base[i] + 1 + (2*i)
        
        coordinates[num_floats] = j
        coordinates = coordinates[:4]
        # print(coordinates)
        tiles[j] = hash(coordinates, num_coordinates, ctable, rndseq)

    return tiles

def GetTiles1(nt:int, ct:collision_table, f1:float, h1:int, h2:int, rndseq:list):
    f_tmp_arr = [f1]
    i_tmp_arr = [h1, h2]
    return GetTiles(nt, ct, f_tmp_arr, 1, i_tmp_arr, 2, rndseq)

In [3]:
config_data = config()
# variables to identify the task
target_task_name = '4v3'
src_state_var_names = config_data['3v2_state_names']
src_action_names = config_data['3v2_action_names']
src_action_values = config_data['3v2_action_values']
target_state_var_names = config_data['4v3_state_names']
target_action_names = config_data['4v3_action_names']
target_action_values = config_data['4v3_action_values']
current_state_3v2_col_names = config_data['3v2_current_state_transition_df_col_names'.format(target_task_name)]
current_state_4v3_col_names = config_data['{}_current_state_transition_df_col_names'.format(target_task_name)]

In [4]:
src_task_data_folder_and_filename = os.path.join(config_data['data_path'], 'keepaway', "keepaway_3v2_transitions.csv")
trans_3v2_df = pd.read_csv(src_task_data_folder_and_filename, index_col = False)

In [4]:
k3v2_weights_folder = os.path.join(config_data['logs_path'], '202211031646-UbuntuXenial-3v2-weights')
k3v2_weights_file = os.path.join(k3v2_weights_folder, 'k1-weights.dat')

k4v3_weights_folder = os.path.join(config_data['logs_path'], '202211031855-UbuntuXenial-4v3-weights')
k4v3_weights_file = os.path.join(k4v3_weights_folder, 'k1-weights.dat')

In [26]:
# rndseq = np.zeros((1, 2048), dtype=int)[0]
# for k in range(len(rndseq)):
#     for i in range(4):
#         rndseq[k] = (rndseq[k] << 8) | (int(random.random() * 1000) & 0xff)

In [5]:
weights3v2, m3v2, safe3v2, calls3v2, clearhits3v2, collisions3v2, data3v2 = parse_weights(k3v2_weights_file)

In [6]:
data3v2

array([-730175833,          0,   15611312, ...,         -1,         -1,
               -1])

In [14]:
# 3v2
weights3v2, m3v2, safe3v2, calls3v2, clearhits3v2, collisions3v2, data3v2 = parse_weights(k3v2_weights_file)
ct3v2 = collision_table(m3v2, safe3v2, calls3v2, clearhits3v2, collisions3v2, data3v2)

# 4v3
weights4v3, m4v3, safe4v3, calls4v3, clearhits4v3, collisions4v3, data4v3 = parse_weights(k4v3_weights_file)
for weight_idx in range(len(weights4v3)):
    weights4v3[weight_idx] = 0
ct4v3 = collision_table(m4v3, safe4v3, calls4v3, clearhits4v3, collisions4v3, data4v3)

rndseq = np.random.rand(2048) * m3v2

state_mapping = [0, 12, 9, 1, 4, 4, 4, 1, 11, 9, 9, 9, 9, 4, 9, 9, 9, 11, 9]
action_mapping = [0, 0, 0, 0]

for _, row in trans_3v2_df.iterrows():
    for col_idx, col in enumerate(src_state_var_names):
        current_feature_val = float(row['Current-{}'.format(col)])
        current_action = int(row['Current-action'])
        tiles = GetTiles1(32, ct3v2, current_feature_val, current_action, col_idx, rndseq)
        activated_weights = weights3v2[tiles]

        # transfer
        target_state_idx = state_mapping[col_idx]
        target_state_val = current_feature_val
        target_action = action_mapping[current_action]
        tiles = GetTiles1(32, ct4v3, target_state_val, target_action, target_state_idx, rndseq)
        weights4v3[tiles] = copy.deepcopy(activated_weights)

avg_weights = np.mean([w for w in weights3v2 if w != 0])
for w_idx in range(len(weights4v3)):
    if weights4v3[w_idx] == 0.0:
        # print(weights4v3[w_idx])
        weights4v3[w_idx] = avg_weights

with open('k1-weights.dat', 'wb') as f:
    for weight in weights4v3:
        f.write(weight)
    for item in m4v3:
        f.write(item)
    for item in safe4v3:
        f.write(item)
    for item in calls4v3:
        f.write(item)
    for item in clearhits4v3:
        f.write(item)
    for item in collisions4v3:
        f.write(item)
    for item in data4v3:
        f.write(item)

In [58]:
avg_weights = np.mean([w for w in weights3v2 if w != 0])

In [59]:
avg_weights

0.14120184468217797

In [62]:
for w_idx in range(len(weights4v3)):
    if weights4v3[w_idx] == 0.0:
        # print(weights4v3[w_idx])
        weights4v3[w_idx] = avg_weights

In [None]:
for w_idx in range(len(weights4v3)):
    if weights4v3[w_idx] != avg_weights:
        print(weights4v3[w_idx])

In [17]:
with open('k4-weights.dat', 'wb') as f:
    for weight in weights4v3:
        f.write(weight)
    for item in m4v3:
        f.write(item)
    for item in safe4v3:
        f.write(item)
    for item in calls4v3:
        f.write(item)
    for item in clearhits4v3:
        f.write(item)
    for item in collisions4v3:
        f.write(item)
    for item in data4v3:
        f.write(item)

In [None]:
weights, m, safe, calls, clearhits, collisions, data = parse_weights('k2-weights.dat')