In [25]:
from typing import List

from config import Config, EncoderConfig


def get_worker_assignments(
    encoders: List["EncoderConfig"],
    worker_id: int,
    worker_total: int,
    dry_run: bool = False,
    verbose: bool = True,
) -> List[dict]:
    """Get all encoders that would be processed by this worker.

    Workers are split evenly across encoders: each encoder gets
    ``worker_total // len(encoders)`` subworkers, and worker ``worker_id``
    serves encoder ``worker_id % len(encoders)`` as subworker
    ``worker_id // len(encoders)``. Workers left over after the even split
    (``worker_id >= subworker_total * len(encoders)``) get no assignment.

    When there are fewer workers than encoders, each worker instead takes
    every encoder ``i`` with ``i % worker_total == worker_id`` (round-robin)
    and is the only subworker for each of them (``subworker_total == 1``).

    Args:
        encoders: Encoder configs to distribute; only ``.name`` is read.
        worker_id: Index of this worker, in ``range(worker_total)``.
        worker_total: Total number of workers; must be >= 1.
        dry_run: Currently unused; kept for interface compatibility.
        verbose: If True, print progress/debug messages.

    Returns:
        List of dicts containing assignment details, with keys
        ``encoder_id``, ``encoder_name``, ``subworker_id`` and
        ``subworker_total``. Empty if ``encoders`` is empty or this worker is
        left over after the even split.

    Raises:
        ValueError: If ``worker_total < 1`` or ``worker_id`` is outside
            ``range(worker_total)``.
    """
    if worker_total < 1:
        raise ValueError(f"worker_total must be >= 1, got {worker_total}")
    if not 0 <= worker_id < worker_total:
        raise ValueError(
            f"worker_id must be in range({worker_total}), got {worker_id}"
        )
    num_encoders = len(encoders)
    if num_encoders == 0:
        return []

    # Number of workers per encoder. Clamp to 1: with fewer workers than
    # encoders, floor division gives 0 and every encoder would be dropped.
    subworker_total = max(1, worker_total // num_encoders)
    # This worker's slot within an encoder's pool (same for every encoder).
    subworker_id = worker_id // num_encoders
    if verbose:
        print(f"Subworker total: {subworker_total}")

    assignments = []
    for i, encoder_cfg in enumerate(encoders):
        if verbose:
            print(f"Looking at encoder {i} of {num_encoders}")
        # With worker_total >= num_encoders, i % worker_total == i, so this
        # selects exactly encoder worker_id % num_encoders. With fewer workers
        # than encoders, worker_id % num_encoders == worker_id, giving a
        # round-robin split of encoders across workers.
        if i % worker_total != worker_id % num_encoders:
            continue
        if verbose:
            print(f"We should work on {i}")
        # Workers beyond subworker_total * num_encoders stay idle.
        if subworker_id < subworker_total:
            assignments.append(
                {
                    "encoder_id": i,
                    "encoder_name": encoder_cfg.name,
                    "subworker_id": subworker_id,
                    "subworker_total": subworker_total,
                }
            )

    if verbose:
        print(f"Subworker got {len(assignments)} assignments")
    return assignments

In [26]:
# Project configuration; its `encoders` list is what gets partitioned below.
config: Config = Config()

In [34]:
get_worker_assignments(config.encoders, worker_id=5, worker_total=6)

Subworker total: 2
Looking at encoder 0 of 3
Looking at encoder 1 of 3
Looking at encoder 2 of 3
We should work on 2
Subworker got 1 assignments


[{'encoder_id': 2,
  'encoder_name': 'local_commonpool_s_s13m_b4k_2',
  'subworker_id': 1,
  'subworker_total': 2}]