In [32]:
import heapq

class Ball:
    """A ball at ``pos`` whose ``weight`` grows as it absorbs objects.

    NOTE(review): ``can_eat`` accepts any object lighter than *twice* the
    ball's weight, while ``Puzzle.get_eatable_objects`` requires the ball to
    be *more than twice* the object's weight.  ``Ball`` is not referenced by
    the code in this notebook; confirm which rule is intended before using it.
    """

    def __init__(self, weight, pos):
        self.pos = pos
        self.weight = weight

    def can_eat(self, w_i):
        """Return True when ``w_i`` is strictly below twice the ball's weight."""
        return w_i < self.weight * 2

    def eat(self, w_i):
        """Absorb an object of weight ``w_i`` if allowed.

        Returns True (and increases ``weight`` by ``w_i``) when the object
        was eaten, False otherwise (weight unchanged).
        """
        if not self.can_eat(w_i):
            return False
        self.weight += w_i
        return True
    
class State:
    """A search node: ball position, weight, eaten object indices and elapsed time.

    Equality and hashing use ``(pos, weight, eaten)`` only, so the same
    configuration reached at different times compares equal.  Ordering uses
    ``time`` only; ``heapq`` falls back to it when two queue entries share the
    same priority value.
    """

    def __init__(self, pos, weight, eaten, time):
        self.pos = pos
        self.weight = weight
        # Frozen so the state (and its eaten set) can be hashed / used as a key.
        self.eaten = frozenset(eaten)
        self.time = time

    def _key(self):
        """Return the tuple that defines this state's identity (time excluded)."""
        return (self.pos, self.weight, self.eaten)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        # Let Python try the reflected operation / fall back to identity
        # instead of crashing with AttributeError on foreign types.
        if not isinstance(other, State):
            return NotImplemented
        return self._key() == other._key()

    def __lt__(self, other):
        if not isinstance(other, State):
            return NotImplemented
        return self.time < other.time

    def __repr__(self):
        return (
            f"{type(self).__name__}(pos={self.pos!r}, weight={self.weight!r}, "
            f"eaten={sorted(self.eaten)!r}, time={self.time!r})"
        )
        
def manhattan_distance(pos_a, pos_b):
    """Return the taxicab (L1) distance between two 2-D points.

    Only the first two coordinates of each point are used.
    """
    dx = pos_a[0] - pos_b[0]
    dy = pos_a[1] - pos_b[1]
    return abs(dx) + abs(dy)

class Puzzle:
    """Plan the order in which a ball eats objects placed on a grid.

    Each object is an ``(x, y, w)`` triple.  A ball of weight ``W`` may eat an
    uneaten object of weight ``w`` when ``W > 2 * w``; eating adds ``w`` to the
    ball's weight and costs one time unit, and moving costs the Manhattan
    distance travelled.  ``solve`` returns the eating order that reaches the
    largest weight, preferring the smallest total time among those.
    """

    def __init__(self, objects_params, w_0):
        """Store the ``(x, y, w)`` objects and the ball's initial weight ``w_0``."""
        self.objects = [(x, y, w) for x, y, w in objects_params]
        self.initial_weight = w_0
        self.n_objects = len(self.objects)

    def get_eatable_objects(self, weight, eaten):
        """Return indices (ascending) of uneaten objects a ball of ``weight`` can eat."""
        return [
            idx
            for idx, (_, _, w) in enumerate(self.objects)
            if idx not in eaten and weight > 2 * w
        ]

    def heuristic(self, state):
        """Lower bound on the time before the next object can be eaten.

        Returns 0 if nothing is currently eatable, otherwise the minimum of
        ``distance + 1`` (travel plus the one-unit eat) over eatable objects.
        """
        eatable = self.get_eatable_objects(state.weight, state.eaten)
        return min(
            (
                manhattan_distance(state.pos, self.objects[idx][:2]) + 1
                for idx in eatable
            ),
            default=0,
        )

    def get_max_possible_weight(self, current_weight, eaten):
        """Return the weight reached by greedily eating remaining objects lightest-first.

        Remaining objects are visited once in ascending ``(w, index)`` order;
        each is eaten if the running weight is more than twice its weight.
        """
        max_weight = current_weight
        remaining = sorted(
            (w, idx)
            for idx, (_, _, w) in enumerate(self.objects)
            if idx not in eaten
        )
        for obj_weight, _ in remaining:
            if max_weight > 2 * obj_weight:
                max_weight += obj_weight
        return max_weight

    def solve(self, start=(0, 0)):
        """Return the eating order (object indices) that maximises the final weight.

        Best-first search over states ordered by ``time + heuristic``.  The
        search does not stop early: every configuration that is not already
        settled at an equal or earlier time is expanded, and among states of
        maximal weight the one reached in the least time is kept.

        Args:
            start: starting position of the ball; defaults to ``(0, 0)``.

        Returns:
            List of object indices in eating order (empty if nothing is eaten).
        """
        initial_state = State(start, self.initial_weight, set(), 0)
        pq = [(0, initial_state, [])]
        # Earliest time at which each (pos, weight, eaten) configuration was expanded.
        visited = {}

        best_weight = self.initial_weight
        best_solution = []
        best_time = 0

        while pq:
            _, current_state, path = heapq.heappop(pq)
            state_key = (current_state.pos, current_state.weight, current_state.eaten)

            # Already expanded at the same or an earlier time: nothing new here.
            if state_key in visited and visited[state_key] <= current_state.time:
                continue
            visited[state_key] = current_state.time

            # Prefer heavier results; break weight ties by smaller elapsed time.
            if current_state.weight > best_weight or (
                current_state.weight == best_weight and current_state.time < best_time
            ):
                best_weight = current_state.weight
                best_solution = path.copy()
                best_time = current_state.time

            for obj_idx in self.get_eatable_objects(current_state.weight, current_state.eaten):
                x, y, w = self.objects[obj_idx]
                # Travel to the object, then spend one time unit eating it.
                total_time = (
                    current_state.time + manhattan_distance(current_state.pos, (x, y)) + 1
                )
                new_eaten = current_state.eaten | {obj_idx}
                new_weight = current_state.weight + w

                # visited times only ever decrease, so an entry that is already
                # dominated now would be discarded when popped; don't enqueue it.
                new_key = ((x, y), new_weight, new_eaten)
                if new_key in visited and visited[new_key] <= total_time:
                    continue

                new_state = State((x, y), new_weight, new_eaten, total_time)
                priority = total_time + self.heuristic(new_state)
                heapq.heappush(pq, (priority, new_state, path + [obj_idx]))

        return best_solution

In [33]:
def parse_input(path):
    """Read a multi-case puzzle file.

    The file is a whitespace-separated stream of integers: the number of
    cases ``T``, then for each case ``n W`` followed by ``n`` triples
    ``x y w``.

    Args:
        path: path of the input file.

    Returns:
        List of ``(n, W, items)`` tuples, ``items`` being a list of
        ``(x, y, w)`` integer triples.

    Raises:
        ValueError: if a token is not an integer, or the file ends before
            every declared value has been read.
    """
    with open(path) as f:
        tokens = f.read().split()
    it = iter(map(int, tokens))

    def take():
        # Turn the iterator's StopIteration into a meaningful error instead of
        # leaking it (it would become RuntimeError inside any generator caller).
        try:
            return next(it)
        except StopIteration:
            raise ValueError(
                f"{path}: input ended before all declared values were read"
            ) from None

    cases = []
    for _ in range(take()):
        n, W = take(), take()
        items = [(take(), take(), take()) for _ in range(n)]
        cases.append((n, W, items))
    return cases

In [34]:
# Forward slashes work on both Windows and POSIX, and avoid the invalid
# escape sequence "\e" that a backslash path literal would contain.
test = parse_input('Katamari/example.in')
for case in test:
    _, w_0, items = case
    puzzle = Puzzle(items, w_0)
    solution = puzzle.solve()
    print(case)
    print(solution)


(3, 100, [(1, 1, 49), (-1, -1, 50), (2, 3, 95)])
[0, 1, 2]
(2, 100, [(1, 0, 100), (2, 0, 40)])
[1]


  test = parse_input('Katamari\example.in')
