In [None]:
from dataclasses import dataclass, field


@dataclass
class ModelDefinition:
    """Bundle of the components that make up one model configuration.

    NOTE(review): the registry entry visible in this file fills the
    dataset / iterator / loader / sender / receiver fields with classes
    rather than instances, and passes ``None`` for ``bounding_box_loader``.
    The annotations below therefore describe the kind of component, not
    necessarily the exact runtime type -- confirm against the code that
    consumes a ``ModelDefinition``.
    """

    # Dataset component; the visible registry entry passes a dataset class.
    dataset: Dataset
    # Extra arguments for the dataset; presumably forwarded to its
    # constructor as keyword arguments -- TODO confirm at the call site.
    dataset_args: dict
    # Flag controlling dataset splitting; exact semantics are decided by the
    # consumer of this definition (not visible here).
    split_dataset: bool
    # Batch iterator component; the visible registry entry passes a class.
    iterator: GameBatchIterator
    # Loader used for images; the visible registry entry passes a class.
    image_loader: ImageLoader
    # Loader used for bounding boxes; the visible registry entry passes None.
    bounding_box_loader: ImageLoader
    # Sender network component and its arguments.
    sender: Module
    sender_args: dict
    # Receiver network component and its arguments.
    receiver: Module
    receiver_args: dict
    # Callable used as the loss of the model.
    loss_function: Callable


class CoordinatePredictorGameDataset(CoordinatePredictorDataset):
    """Dataset whose items are read on demand from the HDF5 file at ``self.file``."""

    def __getitem__(self, index):
        """Return the ``CoordinatePredictorSample`` stored at ``index``.

        The HDF5 file is opened read-only for each lookup and closed again
        before the sample is returned.
        """
        # Every one of these datasets is converted with load_tensor.
        tensor_fields = (
            "image",
            "target_pixels",
            "target_region",
            "attribute_tensor",
            "locations",
            "masked_image",
            "bounding_boxes",
        )

        with h5py.File(self.file, "r") as handle:
            # The id is decoded from its stored byte form into text.
            sample_id = str(handle["image_id"][index], "utf-8")
            tensors = {
                name: load_tensor(handle[name][index]) for name in tensor_fields
            }

        return CoordinatePredictorSample(image_id=sample_id, **tensors)

def _zero_scalar() -> torch.Tensor:
    """Return a fresh scalar zero tensor used as the "not provided" placeholder."""
    return torch.tensor(0)


@dataclass
class CoordinatePredictorSample:
    """A single dataset sample: image, prediction targets and optional extras.

    The optional fields default to a scalar zero tensor. Each instance gets
    its own placeholder (via ``default_factory``) so that an in-place
    operation on one sample's default can never leak into another sample.
    """

    image_id: str
    image: torch.Tensor

    # target
    target_pixels: torch.Tensor
    target_region: torch.Tensor

    # additional (optional) information
    attribute_tensor: torch.Tensor = field(default_factory=_zero_scalar)
    locations: torch.Tensor = field(default_factory=_zero_scalar)
    masked_image: torch.Tensor = field(default_factory=_zero_scalar)
    bounding_boxes: torch.Tensor = field(default_factory=_zero_scalar)


def load_tensor(data):
    """Convert ``data`` into a ``torch.Tensor`` where that makes sense.

    Conversion rules, checked in this order:

    * a ``torch.Tensor`` (or subclass) is returned unchanged;
    * a ``numpy.ndarray`` (or subclass) is wrapped with ``torch.from_numpy``;
    * a non-empty list whose first element is a string is returned
      unchanged, since strings cannot be stored in a tensor;
    * any other non-empty list is converted element by element and stacked;
    * everything else, including an empty list, goes through
      ``torch.tensor``.

    Returns:
        A tensor, or the original list when it holds strings.
    """
    # isinstance (rather than exact type equality) keeps subclasses such as
    # nn.Parameter or numpy.str_ on the branch meant for their base type.
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, numpy.ndarray):
        return torch.from_numpy(data)
    # The emptiness check guards the data[0] lookup; an empty list falls
    # through to torch.tensor and becomes an empty tensor.
    if isinstance(data, list) and data:
        if isinstance(data[0], str):
            return data
        return torch.stack([load_tensor(item) for item in data])
    return torch.tensor(data)

class SingleObjectImageMasker(ImageMasker):
    """Masker that highlights a square window around one object of a scene."""

    def get_masked_image(self, image, scene, target_object):
        """Return a copy of ``image`` reduced to a white-on-black mask.

        Pixels within a square window centred on the pixel coordinates of
        ``scene["objects"][target_object]`` become white; every other pixel
        becomes black. The window reaches one tenth of the image width in
        each direction from the centre. The input image is not modified.

        NOTE(review): pixels are written as 3-tuples, so this presumably
        expects an RGB-mode image -- confirm against callers.
        """
        masked_image = image.copy()
        half_extent = masked_image.size[0] / 10
        x_center, y_center, _ = scene["objects"][target_object]["pixel_coords"]

        # Window bounds are constant for the whole scan, so compute them once.
        left, right = x_center - half_extent, x_center + half_extent
        top, bottom = y_center - half_extent, y_center + half_extent

        pixels = masked_image.load()
        for column in range(masked_image.size[0]):
            for row in range(masked_image.size[1]):
                outside = (
                    column < left or column > right or row < top or row > bottom
                )
                pixels[column, row] = (0, 0, 0) if outside else (255, 255, 255)

        return masked_image

    def __repr__(self) -> str:
        """Return the class name followed by empty parentheses."""
        return type(self).__name__ + "()"

class FeatureImageLoader(ImageLoader):
    """Loader returning an image together with features stored in an HDF5 file."""

    def __init__(self, feature_file, image_dir) -> None:
        """Remember both locations and read ``image_size`` from the feature file.

        ``image_size`` is taken from the attributes of the ``"features"``
        dataset inside ``feature_file``.
        """
        self.image_dir = image_dir
        self.feature_file = feature_file

        with h5py.File(feature_file, "r") as handle:
            self.image_size = handle["features"].attrs["image_size"]

    def get_image(self, image_id):
        """Return ``(image, features, image_size)`` for ``image_id``.

        The last six characters of ``image_id`` are parsed as an integer and
        used as the row index into the ``"features"`` dataset. The image
        itself is read from ``<image_dir>/<image_id>.png`` and converted to
        RGB.
        """
        row = int(image_id[-6:])

        with h5py.File(self.feature_file, "r") as handle:
            feature_array = handle["features"][row]

        image_path = os.path.join(self.image_dir, image_id + ".png")
        rgb_image = Image.open(image_path).convert("RGB")

        return rgb_image, torch.from_numpy(feature_array), self.image_size

class AttentionPredictorGameBatchIterator(GameBatchIterator):
    """Iterator yielding ``n_batches`` randomly sampled batches per pass.

    Each batch is drawn without replacement from ``loader.dataset`` using a
    private random generator seeded with ``seed``.
    """

    def __init__(self, loader, batch_size, n_batches, train_mode, seed) -> None:
        """Store the configuration and create the seeded random generator.

        ``train_mode`` is stored but not read anywhere in this class.
        """
        self.loader = loader
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.batches_generated = 0
        self.train_mode = train_mode
        # Despite its name this is a random.Random instance, not a seed.
        self.random_seed = random.Random(seed)

    def __next__(self):
        """Return the next batch, or stop once ``n_batches`` were produced."""
        # The counter starts at 0, so ">=" is what stops after exactly
        # n_batches batches; ">" would hand out one batch too many.
        if self.batches_generated >= self.n_batches:
            raise StopIteration()

        batch_data = self.get_batch()
        self.batches_generated += 1
        return batch_data

    def get_batch(self):
        """Sample ``batch_size`` dataset items and collate them.

        Returns:
            A 4-tuple ``(sender_input, target_regions, receiver_input, aux)``
            where both inputs are the stacked sample images, and ``aux`` is
            a dict with the stacked ``"masked_image"`` and
            ``"attribute_tensor"`` tensors plus an ``"image_id"`` tensor
            built from the integer value of the last six characters of each
            sample's id.

        Raises:
            ValueError: from ``random.Random.sample`` when ``batch_size``
                exceeds the dataset length.
        """
        dataset = self.loader.dataset
        sampled_indices = self.random_seed.sample(
            range(len(dataset)), self.batch_size
        )
        samples: list[CoordinatePredictorSample] = [
            dataset[i] for i in sampled_indices
        ]

        # Sender and receiver inputs are stacked separately so that each side
        # gets its own tensor, even though both hold the same images.
        sender_inputs = torch.stack([sample.image for sample in samples])
        target_regions = torch.stack([sample.target_region for sample in samples])
        receiver_inputs = torch.stack([sample.image for sample in samples])

        masked_images = torch.stack([sample.masked_image for sample in samples])
        attribute_tensors = torch.stack(
            [sample.attribute_tensor for sample in samples]
        )
        image_ids = torch.tensor([int(sample.image_id[-6:]) for sample in samples])

        return (
            sender_inputs,
            target_regions,
            receiver_inputs,
            {
                "masked_image": masked_images,
                "attribute_tensor": attribute_tensors,
                "image_id": image_ids,
            },
        )

class MaskedCoordinatePredictorSender(nn.Module):
    """
    Sender that combines an image with its masked counterpart.

    Input:
     - x: the image batch
     - aux_input["masked_image"]: the masked image batch

    Output:
     - hidden vector produced by the final linear layer

    Both inputs go through their own encoder; the encodings are joined along
    dim 1, flattened, and passed through two lazily sized linear layers.
    """

    def __init__(
        self,
        image_encoder: ImageEncoder,
        masked_image_encoder: ImageEncoder,
        sender_image_embedding: int,
        sender_hidden,
        *_args,
        **_kwargs,
    ) -> None:
        super().__init__()

        # Submodules are registered in a fixed order: both encoders first,
        # then the two projection stages.
        self.image_encoder = image_encoder
        self.masked_image_encoder = masked_image_encoder

        flatten_and_embed = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(sender_image_embedding),
        )
        self.reduction = flatten_and_embed

        self.linear = nn.LazyLinear(sender_hidden)

    def forward(self, x, aux_input):
        """Encode both images and return the projected joint representation."""
        mask_batch = aux_input["masked_image"]

        image_code = self.image_encoder(x)
        mask_code = self.masked_image_encoder(mask_batch)

        joint = torch.cat((image_code, mask_code), dim=1)
        embedded = self.reduction(joint)

        return self.linear(embedded)


# Registry entry for the "masked_attention_predictor" configuration: it names
# the dataset, loaders, batch iterator, sender, receiver and loss to use.
# Components are given as classes; the *_args dicts carry their arguments.
"masked_attention_predictor": ModelDefinition(
        dataset=CoordinatePredictorGameDataset,
        dataset_args={
            "image_masker": SingleObjectImageMasker(),
            # NOTE(review): meaning of number_regions is defined by the
            # dataset class, which is not visible here -- confirm there.
            "number_regions": 14,
        },
        split_dataset=False,
        image_loader=FeatureImageLoader,
        # No bounding-box loader is configured for this model.
        bounding_box_loader=None,
        iterator=AttentionPredictorGameBatchIterator,
        sender=MaskedCoordinatePredictorSender,
        # Only the two encoders are given here; the sender's remaining
        # constructor parameters (sender_image_embedding, sender_hidden) are
        # presumably supplied elsewhere -- TODO confirm.
        sender_args={
            "image_encoder": ClevrImageEncoder(
                feature_extractor=DummyFeatureExtractor(), max_pool=True
            ),
            # The masked image goes through a ResNet feature extractor
            # configured with pretrained=True and fine_tune=False.
            "masked_image_encoder": ClevrImageEncoder(
                feature_extractor=ResnetFeatureExtractor(
                    pretrained=True,
                    avgpool=False,
                    fc=False,
                    fine_tune=False,
                    number_blocks=3,
                ),
                max_pool=True,
            ),
        },
        receiver=AttentionPredictorReceiver,
        # Unlike the sender's image encoder, this one sets max_pool=False.
        receiver_args={
            "image_encoder": ClevrImageEncoder(
                feature_extractor=DummyFeatureExtractor(), max_pool=False
            ),
        },
        loss_function=attention_loss,
    ),