In [1]:
import numba as nb
from numba import prange
import numpy as np
from numba.typed import List

In [2]:
# считываение из файла
def read_input(file_str: str):
    with open(file_str) as f:

        lines = f.readlines()
        lines.reverse()

        n = int(lines.pop().split()[0])
        parents = [int(s) for s in lines.pop().split(' ')]
        values = [int(s) for s in lines.pop().split(' ')]

        diseases = []
        num_of_diseases = int(lines.pop().split()[0])
        for i in range(num_of_diseases):
            dis = lines.pop()
            diseases.append([int(s) for s in dis.split(' ')])

        patients = []
        num_of_patients = int(lines.pop().split()[0])
        for i in range(num_of_patients):
            dis = lines.pop()
            patients.append([int(s) for s in dis.split(' ')])

        # return [n, relationships, values, diseases, patients]
        return {
            "n": n,
            "parents": parents,
            "values": values,
            "diseases": diseases,
            "patients": patients,
        }


data = read_input('./tests/test1')

n = data["n"]
parents = data["parents"]
values = data["values"]

diseases = data["diseases"]
diseases_numba = List(List(x) for x in diseases)

patients = data["patients"]
patients_numba = List(List(x) for x in patients)

print(n)

19830


In [3]:
from collections import defaultdict

parents_hmap = np.empty(n + 1, dtype=int)  # child: parent
children_hmap = [[] for i in range(n + 1)]  # parent: [children]
values_hmap = np.empty(n + 1, dtype=int)  # node: value
is_checked_hmap = np.empty(n + 1)  # node: value

values_hmap[1] = values[0]
is_checked_hmap[1] = False
for i in range(len(parents)):
    child_index = i + 2
    parent_index = parents[i]

    parents_hmap[child_index] = parent_index
    children_hmap[parent_index].append(child_index)
    values_hmap[child_index] = values[i + 1]

for i in range(len(children_hmap)):
    children_hmap[i] = np.asarray(children_hmap[i])

In [4]:
depth_hmap = np.empty(n + 1, dtype=int)

def fill_depth_hmap(cur, index):
    depth_hmap[cur] = index
    for child in children_hmap[cur]:
        fill_depth_hmap(child, index + 1)

fill_depth_hmap(1, 0)
print(depth_hmap)

[0 0 1 ... 2 3 4]


In [5]:
@nb.njit()
def find_common_parent(i1, i2):
    i1 = int(i1)
    i2 = int(i2)
    depth1 = depth_hmap[i1]
    depth2 = depth_hmap[i2]
    if depth1 < depth2:
        depth1, depth2 = depth2, depth1
        i1, i2 = i2, i1

    while depth1 > depth2:
        depth1 -= 1
        i1 = parents_hmap[i1]

    while i1 != i2:
        depth1 -= 1
        i1 = parents_hmap[i1]
        depth2 -= 1
        i2 = parents_hmap[i2]
    return values_hmap[i1]

In [14]:
@nb.njit(parallel=True)
def foo(patients_numba, diseases_numba):
    # result = []
    result = np.zeros(len(patients_numba))
    for i in prange(len(patients_numba)):
        patient = patients_numba[i]
        IC_LCA_sum_max = 0
        best_k = -1
        for k in range(len(diseases_numba)):
            disease = diseases_numba[k]
            IC_LCA_sum = 0
            for p in patient:
                IC_LCA_max = 0
                for d in disease:
                    IC_LCA = find_common_parent(p, d)
                    IC_LCA_max = max(IC_LCA_max, IC_LCA)
                IC_LCA_sum += IC_LCA_max
            if IC_LCA_sum > IC_LCA_sum_max:
                IC_LCA_sum_max = IC_LCA_sum
                best_k = k
        # result[i] = best_k + 1
        # print(i, best_k + 1)
        result[i] = best_k + 1
        # result.append(best_k + 1)
    return result

res = foo(patients_numba, diseases_numba)
print(res)

[3342. 3546. 2219. ... 3254. 1793. 4758.]


In [18]:
with open('output/test1.txt', 'w') as f:
    for n in res:
        f.write(f"{int(n)}\n")

In [227]:
subtree_hmap = defaultdict(list)  # parent: [all children in subtree]


def fill_subtrees(cur):
    children = children_hmap[cur]
    subtree_hmap[cur].append(cur)
    if len(children) == 0:
        return [cur]

    for child in children:
        for el in fill_subtrees(child):
            subtree_hmap[cur].append(el)
    return subtree_hmap[cur]


fill_subtrees(1)
# print(subtree_hmap)

[1,
 2,
 3,
 10087,
 12930,
 14784,
 4,
 7970,
 8600,
 5,
 4200,
 5434,
 6,
 2931,
 8471,
 7,
 3843,
 5274,
 8,
 3923,
 5888,
 9,
 615,
 10,
 9643,
 11,
 609,
 16540,
 7752,
 10636,
 17164,
 19410,
 19699,
 15435,
 12,
 1181,
 5300,
 13,
 1911,
 5306,
 14,
 4198,
 17939,
 15,
 16,
 3709,
 12307,
 15009,
 19033,
 4057,
 17123,
 17,
 4588,
 5774,
 18,
 4803,
 4804,
 10169,
 16801,
 19300,
 19,
 1172,
 10614,
 13721,
 17337,
 20,
 6862,
 7432,
 10341,
 13008,
 21,
 22,
 4345,
 14329,
 15130,
 16264,
 19044,
 4627,
 10022,
 12571,
 16838,
 23,
 1913,
 6226,
 12766,
 17045,
 24,
 484,
 5135,
 25,
 85,
 26,
 7564,
 27,
 28,
 2157,
 2837,
 29,
 5667,
 30,
 31,
 4620,
 9133,
 32,
 770,
 6867,
 33,
 1617,
 17841,
 34,
 1450,
 35,
 36,
 37,
 5965,
 9612,
 15411,
 38,
 7860,
 11592,
 17022,
 13093,
 14927,
 17218,
 14373,
 39,
 6497,
 9489,
 40,
 41,
 4185,
 17907,
 42,
 8573,
 43,
 856,
 44,
 45,
 5474,
 6699,
 46,
 3518,
 13153,
 15696,
 15889,
 47,
 1262,
 2442,
 14695,
 48,
 2731,
 10697,
 49

In [None]:
common_parent_hmap = defaultdict()


def fill_common_parent_hmap(cur):
    # subtree = subtree_hmap[cur]
    subtrees = []
    for child in children_hmap[cur]:
        subtrees.append(subtree_hmap[child])
    for subtree1 in subtrees:
        for subtree2 in subtrees:
            if subtree1 != subtree2:
                for el1 in subtree1:
                    for el2 in subtree2:
                        common_parent_hmap[el1, el2] = values_hmap[cur]
                        common_parent_hmap[el2, el1] = values_hmap[cur]
    for el in subtree_hmap[cur]:
        common_parent_hmap[cur, el] = values_hmap[cur]
        common_parent_hmap[el, cur] = values_hmap[cur]
    for child in children_hmap[cur]:
        fill_common_parent_hmap(child)


fill_common_parent_hmap(1)
# print(common_parent_hmap)

In [220]:
import numpy as np

for patient in patients:
    IC_LCA_sum_max = 0
    best_k = -1
    for k in range(len(diseases)):
        disease = diseases[k]
        IC_LCA_sum = 0
        for p in patient:
            IC_LCA_max = 0
            for d in disease:
                IC_LCA = common_parent_hmap[p, d]
                IC_LCA_max = max(IC_LCA_max, IC_LCA)
            IC_LCA_sum += IC_LCA_max
        if IC_LCA_sum > IC_LCA_sum_max:
            IC_LCA_sum_max = IC_LCA_sum
            best_k = k
        # IC_LCA_sum_max = max(IC_LCA_sum, IC_LCA_sum_max)
    print(best_k + 1)

2
1
2
2
