In [None]:
import jax
import jax.numpy as jnp
import jraph

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models

import numpy as np
import pandas as pd
import cv2
import networkx as nx

import os
import sys
import matplotlib.pyplot as plt
from IPython.display import display

from pathlib import Path
from typing import *

from logging import getLogger, basicConfig, INFO, DEBUG

# Send log records to the notebook's stdout (basicConfig defaults to stderr) at
# INFO level, and create the project-wide "Manebu" logger.
basicConfig(stream=sys.stdout, level=INFO)
logger = getLogger(name="Manebu")
logger.setLevel(level=INFO)

In [None]:
# Inspect the top-level layers of a torchvision ResNet-34 (reference for the stem below).
[*models.resnet34().children()]

In [None]:
# # utils
# def generate_int32():
#     return np.random.randint(- 2 ** 31, 2 ** 31)

# def generate_random_key(seed=None):
#     if seed is None:
#         seed = generate_int32()
#         logger.info(f"SEED:{seed}")
#     return jax.random.PRNGKey(seed)

# def convert_jraph_to_networkx(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
#     nodes, edges, receivers, senders, _, _, _ = jraph_graph
#     nx_graph = nx.DiGraph()
#     if nodes is None:
#         for n in range(jraph_graph.n_node[0]):
#             nx_graph.add_node(n)
#     else:
#         for n in range(jraph_graph.n_node[0]):
#             nx_graph.add_node(n, node_feature=nodes[n])
#     if edges is None:
#         for e in range(jraph_graph.n_edge[0]):
#             nx_graph.add_edge(int(senders[e]), int(receivers[e]))
#     else:
#         for e in range(jraph_graph.n_edge[0]):
#             nx_graph.add_edge(int(senders[e]), int(receivers[e]), edge_feature=edges[e])

#     return nx_graph

# def draw_jraph_by_networkx(jraph_graph: jraph.GraphsTuple, **kwargs) -> None:
#     nx_graph = convert_jraph_to_networkx(jraph_graph)
#     node_feature = {node:node_attr['node_feature'] for node, node_attr in nx_graph.nodes.items()}
#     pos = kwargs.get("pos", nx.spring_layout(nx_graph))
#     return nx.draw(nx_graph, pos=pos, with_labels=kwargs.get("with_labels", True),
#             node_size=kwargs.get("node_size", 400), font_color=kwargs.get("font_color", "black"))

# generate_random_key()

In [None]:
# Dataset root. The default is a machine-specific local path; set the
# MANEBU_DATASET_DIR environment variable to point elsewhere without editing code.
DATASET_DIR = Path(os.environ.get("MANEBU_DATASET_DIR", "E:\\Dataset")) / "open-images-v6"

# Path.iterdir() yields entries in arbitrary, filesystem-dependent order; sort so
# positional lookups such as train_0_dir[3000] select the same file on every machine.
train_0_dir = tuple(sorted((DATASET_DIR / "train" / "data" / "train_0").iterdir()))
len(train_0_dir)

In [None]:
# Load one sample image. cv2.imread does not raise on a missing or unreadable file;
# it returns None, which would otherwise surface as an opaque cvtColor error.
sample_path = train_0_dir[3000]
sample_image_bgr = cv2.imread(str(sample_path))
if sample_image_bgr is None:
    raise FileNotFoundError(f"cv2.imread could not read image: {sample_path}")

# OpenCV decodes to BGR channel order; convert to RGB for matplotlib display.
sample_image = cv2.cvtColor(sample_image_bgr, cv2.COLOR_BGR2RGB)

In [None]:
class SimpleImageAttention(nn.Module):
    """Small spatial-attention classifier on top of a ResNet-style stem.

    Pipeline: conv stem -> 1x1 conv + sigmoid producing a single-channel
    spatial attention map -> features re-weighted by that map -> global
    average pooling -> dropout + linear head.

    Args:
        num: number of output logits produced by the head.
    """

    def __init__(self, num):
        super().__init__()
        # Output channels of the stem; the attention conv and the head must match it.
        channels = 64

        # Same layer hyper-parameters as the ResNet stem (conv1/bn1/relu/maxpool)
        # listed by `models.resnet34().children()`.
        self.features = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            nn.BatchNorm2d(channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
        )

        # 1x1 conv collapsing the feature channels into one attention logit per
        # spatial location; sigmoid maps it into (0, 1).
        self.attention_conv = nn.Sequential(
            nn.Conv2d(channels, 1, 1), nn.Sigmoid()
        )

        self.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(channels, num))
        # Most recent attention map (detached, on CPU), set by forward().
        self.mask_ = None

    def forward(self, x):
        """Run the model and remember the attention map in ``self.mask_``.

        Args:
            x: image batch of shape [B, 3, H, W].

        Returns:
            Logits of shape [B, num].
        """
        x = self.features(x)                  # [B, C, H', W']
        attn = self.attention_conv(x)         # [B, 1, H', W']
        self.mask_ = attn.detach().cpu()      # kept for save_attention_mask()

        x = x * attn                          # broadcast the map over channels
        x = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)  # [B, C]
        return self.fc(x)

    def save_attention_mask(self, x, path, max_rows=4):
        """Plot input images next to their attention maps, save and show the figure.

        The forward pass runs in the model's current mode; call ``eval()`` first
        to avoid updating BatchNorm running statistics.

        Args:
            x: image batch [B, 3, H, W]. The un-normalisation below assumes the
                ImageNet mean/std normalisation was applied — TODO confirm.
            path: destination file for the figure, or None to only display it.
            max_rows: maximum number of samples to plot.
        """
        with torch.no_grad():
            self.forward(x)

        mean = torch.tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
        # Undo normalisation and clip into imshow's valid float range.
        images = (x.detach().cpu() * std + mean).clamp(0, 1)

        n_rows = min(max_rows, images.shape[0])
        fig, axs = plt.subplots(n_rows, 2, figsize=(6, 2 * n_rows), squeeze=False)
        for i in range(n_rows):
            axs[i, 0].imshow(images[i].permute(1, 2, 0))
            axs[i, 1].imshow(self.mask_[i][0])
            axs[i, 0].axis("off")
            axs[i, 1].axis("off")

        # Save before show(): some backends clear the figure once it is shown.
        if path is not None:
            fig.savefig(path)
        plt.show()
        plt.close(fig)

In [None]:
# Instantiate the attention model with a 10-way output head.
model = SimpleImageAttention(num=10)

In [None]:
class NodeBase:
    """Root of the node hierarchy; accepts and ignores arbitrary arguments."""

    def __init__(self, *args, **kwargs):
        pass


class PatchNode(NodeBase):
    """Node intended to represent a patch."""

    def __init__(self, *args, **kwargs):
        # `self` must be an explicit positional parameter: zero-argument super()
        # raises "RuntimeError: super(): no arguments" when there is none.
        super().__init__(*args, **kwargs)


class ImagePatchNode(PatchNode):
    """Patch node carrying a colour triple (Lab, per the parameter name)."""

    def __init__(self, lab_color: tuple[int, int, int], *args, **kwargs):
        """
        Args:
            lab_color: three-component colour of the patch, stored as ``self.lab_color``.
        """
        super().__init__(*args, **kwargs)
        self.lab_color = lab_color


class GraphBase:
    """Root of the graph hierarchy; takes arbitrary arguments and ignores them."""

    def __init__(self, *args, **kwargs):
        return None
    
class ImageGraph(GraphBase):
    """Wrapper around an image array, stored in a chosen colour space.

    ``origin_data`` holds the pixels in the colour space named by
    ``color_mode``; the ``RGB`` and ``LAB`` properties return converted copies.
    """

    VALID_COLOR_TYPE = Literal["BGR", "RGB", "LAB"]

    # COLOR_CONVERT_CONSTANT[source][target] -> OpenCV cvtColor conversion code.
    COLOR_CONVERT_CONSTANT = {
        "BGR": {
            "RGB": cv2.COLOR_BGR2RGB,
            "LAB": cv2.COLOR_BGR2LAB
        },
        "RGB": {
            "BGR": cv2.COLOR_RGB2BGR,
            "LAB": cv2.COLOR_RGB2LAB
        },
        "LAB": {
            "RGB": cv2.COLOR_LAB2RGB,
            "BGR": cv2.COLOR_LAB2BGR
        },
    }

    def __init__(self, data: np.ndarray, color_mode: VALID_COLOR_TYPE = "BGR", *args, **kwargs):
        """
        Args:
            data: image array in BGR channel order (as produced by cv2.imread).
            color_mode: colour space to store the image in; ``data`` is
                converted from BGR when this is not "BGR".

        Raises:
            ValueError: if ``data`` is None or ``color_mode`` is unsupported.
        """
        if data is None:
            raise ValueError("data is None (was the image file readable?)")
        if color_mode not in ImageGraph.COLOR_CONVERT_CONSTANT:
            raise ValueError(
                f"unsupported color_mode {color_mode!r}; "
                f"expected one of {tuple(ImageGraph.COLOR_CONVERT_CONSTANT)}"
            )
        if color_mode != "BGR":
            data = cv2.cvtColor(data, ImageGraph.COLOR_CONVERT_CONSTANT["BGR"][color_mode])
        # Store the converted pixels together with the colour space they are in.
        self.origin_data: np.ndarray = data
        self.color_mode: ImageGraph.VALID_COLOR_TYPE = color_mode
        self.patches: list[np.ndarray] = []
        super().__init__(*args, **kwargs)

    def __repr__(self):
        return f"<ImageGraph Shape:{self.origin_data.shape} CMode={self.color_mode}>"

    @staticmethod
    def open(file_path: Path, color_mode: VALID_COLOR_TYPE = "RGB", *args, **kwargs):
        """Read an image file with OpenCV and wrap it in an ImageGraph.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``file_path`` (imread returns None).
        """
        data = cv2.imread(str(file_path))
        if data is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {file_path}")
        return ImageGraph(data, color_mode, *args, **kwargs)

    @property
    def h(self):
        """Image height in pixels."""
        return self.origin_data.shape[0]

    @property
    def w(self):
        """Image width in pixels."""
        return self.origin_data.shape[1]

    def _converted(self, target):
        """Return a copy of the image in colour space ``target``."""
        if self.color_mode == target:
            return self.origin_data.copy()
        return cv2.cvtColor(self.origin_data, ImageGraph.COLOR_CONVERT_CONSTANT[self.color_mode][target])

    @property
    def RGB(self):
        """Copy of the image in RGB channel order."""
        return self._converted("RGB")

    @property
    def LAB(self):
        """Copy of the image in Lab colour space."""
        return self._converted("LAB")

    def generate_patches(self):
        """Placeholder for patch generation (presumably meant to fill ``self.patches``).

        TODO: not implemented; currently a no-op apart from unpacking the shape
        (which requires a 3-D array).
        """
        alpha = 0.5  # NOTE(review): unused so far; presumably a planned patch parameter
        h, w, c = self.origin_data.shape

    def imshow(self):
        """Display the image with matplotlib (in RGB order)."""
        plt.imshow(self.RGB)
        plt.show()

In [None]:
# Wrap the same sample file as an ImageGraph (explicitly in RGB), preview it,
# run patch generation, then draw the image again.
img = ImageGraph.open(train_0_dir[3000], color_mode="RGB")
img.imshow()
img.generate_patches()
plt.imshow(img.RGB)

In [None]:
# @dataclass
# class GraphParams:
#     nodes: jnp.ndarray
#     edges: jnp.ndarray

#     senders: jnp.ndarray
#     receivers: jnp.ndarray

#     n_node: int
#     n_edge: int


# class ManebuCore(object):
#     def __init__(self, init_params: GraphParams):
#         self.graph = jraph.GraphsTuple(
#             nodes=init_params.nodes, edges=init_params.edges,
#             n_node=init_params.n_node, n_edge=init_params.n_edge,
#             senders=init_params.senders, receivers=init_params.receivers,
#             globals=dict(init_params=init_params)
#         )
    
#     def convert_to_networkx(self) -> nx.Graph:
#         nodes, edges, receivers, senders, _, _, _ = self.graph
 
#         nx_graph = nx.DiGraph()
#         if nodes is None:
#             for n in range(self.graph.n_node[0]):
#                 nx_graph.add_node(n)
#         else:
#             for n in range(self.graph.n_node[0]):
#                 nx_graph.add_node(n, node_feature=nodes[n]) # create node_feature
 
#         if edges is None:
#             for e in range(self.graph.n_edge[0]):
#                 nx_graph.add_edge(int(senders[e]), int(receivers[e]))
#         else:
#             for e in range(self.graph.n_edge[0]):
#                 nx_graph.add_edge(int(senders[e]), int(receivers[e]), edge_feature=edges[e])

#         return nx_graph

    
#     def draw(self, **kwargs) -> None:
#         nx_graph = self.convert_to_networkx()
        
#         node_feature = {
#            node:node_attr['node_feature'] for node, node_attr in nx_graph.nodes.items()
#         }
#         nfdf = pd.DataFrame.from_dict(node_feature)
#         print(nfdf)
        
        
# #         return nx.draw(nx_graph, pos=pos, with_labels=kwargs.get("with_labels", True),
# #                 node_size=kwargs.get("node_size", 400), font_color=kwargs.get("font_color", "black"))


In [None]:
# nodes = [[0, 1], 2, 4]
# n_node = len(nodes)

# p = GraphParams(
#     nodes=nodes, edges=[3],
#     n_node=[n_node], n_edge=[1],
#     senders=[0], receivers=[2],
# )
# core = ManebuCore(p)
# core.draw()

In [8]:
# @dataclass
# class GraphParams:
#     nodes: jnp.ndarray
#     edges: jnp.ndarray

#     senders: jnp.ndarray
#     receivers: jnp.ndarray

#     n_node: int
#     n_edge: int


# class ManebuCore(object):
#     def __init__(self, init_params: GraphParams):
#         self.graph = jraph.GraphsTuple(
#             nodes=init_params.nodes, edges=init_params.edges,
#             n_node=init_params.n_node, n_edge=init_params.n_edge,
#             senders=init_params.senders, receivers=init_params.receivers,
#             globals=dict(init_params=init_params)
#         )
    
#     def convert_to_networkx(self) -> nx.Graph:
#         nodes, edges, receivers, senders, _, _, _ = self.graph
 
#         nx_graph = nx.DiGraph()
#         if nodes is None:
#             for n in range(self.graph.n_node[0]):
#                 nx_graph.add_node(n)
#         else:
#             for n in range(self.graph.n_node[0]):
#                 nx_graph.add_node(n, node_feature=nodes[n]) # create node_feature
 
#         if edges is None:
#             for e in range(self.graph.n_edge[0]):
#                 nx_graph.add_edge(int(senders[e]), int(receivers[e]))
#         else:
#             for e in range(self.graph.n_edge[0]):
#                 nx_graph.add_edge(int(senders[e]), int(receivers[e]), edge_feature=edges[e])

#         return nx_graph

    
#     def draw(self, **kwargs) -> None:
#         nx_graph = self.convert_to_networkx()
        
#         node_feature = {
#            node:node_attr['node_feature'] for node, node_attr in nx_graph.nodes.items()
#         }
#         nfdf = pd.DataFrame.from_dict(node_feature)
#         print(nfdf)
        
        
# #         return nx.draw(nx_graph, pos=pos, with_labels=kwargs.get("with_labels", True),
# #                 node_size=kwargs.get("node_size", 400), font_color=kwargs.get("font_color", "black"))


In [9]:
# nodes = [[0, 1], 2, 4]
# n_node = len(nodes)

# p = GraphParams(
#     nodes=nodes, edges=[3],
#     n_node=[n_node], n_edge=[1],
#     senders=[0], receivers=[2],
# )
# core = ManebuCore(p)
# core.draw()