In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as T
import torchvision.datasets as TD

import torch_geometric.nn as gnn
import torch_geometric.transforms as TgT
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import Data, HeteroData, Batch
from torch_geometric.loader import DataLoader, LinkNeighborLoader


import networkx as nx
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import gc


from extentions.utils import log
from extentions.utils.pathes import DATASET_PATH
from extentions.string_ex import Japanese


# Runtime configuration shared by the later cells: compute device, script
# path / project root, logger, and the character vocabulary.
__device__ = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Jupyter kernels usually do not define `__file__`; in that case fall back to
# "main.py", resolved against the current working directory.
__file__ = str(Path(locals().get("__file__", "main.py")).resolve())
logger = log.initialize_logger(__file__)
logger.info(f"{__file__=}")
__root__ = Path(__file__).parent
print(f"{__device__=}", f"{__root__=}", f"{__file__=}", sep="\n")
# Character set used for the "char" node type (one node per character).
VALID_CHARS = Japanese.ENGLISH
print(f"{VALID_CHARS=}")

[37m[INFO] - "__file__='/home/n4okins/repositories/research_2/try/0711/main.py'" -  3346724717 : <module> : L31 (3346724717.py | /tmp/ipykernel_791/3346724717.py) [2023-07-11 07:50:11 AM][0m
__device__=device(type='cuda')
__root__=PosixPath('/home/n4okins/repositories/research_2/try/0711')
__file__='/home/n4okins/repositories/research_2/try/0711/main.py'
VALID_CHARS='0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'


In [2]:
def plot(graph: Data, image: torch.Tensor, labels: list[str]):
    """Show ``image`` next to a networkx drawing of ``graph``.

    Args:
        graph: PyG ``Data`` whose ``pos`` is indexed as ``pos[node, 0]`` (x) and
            ``pos[node, 1]`` (y), and whose ``x`` is passed straight through to
            networkx as the node colors.
        image: anything ``Axes.imshow`` accepts (presumably an HWC image — TODO
            confirm).
        labels: caption strings; only the first one is used, as the figure
            title. An empty list simply draws no title.
    """
    # Pull coordinates to host memory once. Flip y against the maximum y so the
    # drawing has the same up/down orientation as imshow (presumably ``pos`` is
    # in image coordinates with y growing downward — TODO confirm). The maximum
    # is computed once instead of once per node (previously O(n^2)).
    coords = graph.pos.detach().cpu()
    y_top = coords[:, 1].max()
    xs = coords[:, 0].tolist()
    ys = (y_top - coords[:, 1]).tolist()
    pos = {i: (xs[i], ys[i]) for i in range(graph.num_nodes)}  # type: ignore
    node_color = graph.x

    fig, ax = plt.subplots(1, 2, figsize=(16, 7))

    ax[0].imshow(image)
    ax[0].axis("off")

    logger.info(f"Draw graph start ... | {graph=}")

    nx.draw_networkx(
        to_networkx(graph),
        pos=pos,
        node_color=node_color,
        node_size=36,
        with_labels=False,
        arrowsize=3,
        ax=ax[1],
    )
    ax[1].axis("off")
    if labels:
        fig.suptitle(labels[0])
    # `Figure.set_tight_layout` is discouraged in recent Matplotlib;
    # `tight_layout()` works on every version and accounts for the suptitle.
    fig.tight_layout()
    plt.show()

In [8]:
def make_dataloader(path: Path):
    """Lazily ``torch.load`` every regular file directly inside ``path``.

    Files are visited in sorted path order, so "the first graph" is the same on
    every run (``Path.iterdir`` order is filesystem dependent). Sub-directories
    are skipped instead of being handed to ``torch.load``.

    Yields:
        Whatever object each file was saved with (presumably a ``HeteroData``
        graph here — TODO confirm).

    Security: ``torch.load`` unpickles its input, so only point this at trusted
    files. NOTE(review): since PyTorch 2.6 ``torch.load`` defaults to
    ``weights_only=True``, which may refuse to load pickled graph objects; pass
    ``weights_only=False`` for trusted files if that happens.
    """
    files = sorted(entry for entry in path.iterdir() if entry.is_file())
    for file in files:
        yield torch.load(file)


def make_hetero_data(train_data: HeteroData) -> HeteroData:
    """Re-pack a raw caption graph into the char/image schema used for link prediction.

    Despite the parameter name, ``train_data`` is the raw graph loaded from disk
    (not a train split).

    Args:
        train_data: graph with an ``"images"`` node store (``id`` and ``image``
            attributes) and an ``("images", "captions", "chars")`` edge store.

    Returns:
        A new ``HeteroData`` containing:
        - ``"char"`` nodes, one per entry of ``VALID_CHARS`` (only ``node_id``;
          presumably embedded later — TODO confirm);
        - ``"image"`` nodes, with ``image`` as ``x``;
        - ``("char", "caption", "image")`` edges, i.e. the source edges reversed;
        - the reverse relation added by ``ToUndirected``.
    """
    images = train_data["images"]

    hetero = HeteroData()
    hetero["char"].node_id = torch.arange(len(VALID_CHARS))
    hetero["image"].node_id = torch.arange(len(images.id))
    hetero["image"].x = images.image

    # Swap the two rows of edge_index: (image -> char) becomes (char -> image).
    image_to_char = train_data["images", "captions", "chars"].edge_index
    hetero["char", "caption", "image"].edge_index = image_to_char[[1, 0]]

    return TgT.ToUndirected()(hetero)  # type: ignore


def make_train_val_test(data: HeteroData):
    """Randomly split the char->image caption links into train / val / test.

    - 10% of the links go to validation and 10% to test.
    - 30% of the training links are held out as supervision-only edges
      (``disjoint_train_ratio``).
    - No negative links are added here; negative sampling is left to the loader.

    The split draws from the torch RNG and no seed is set here, so it differs
    between runs unless the caller seeds torch first.

    Returns:
        ``(train_data, val_data, test_data)`` as produced by ``RandomLinkSplit``.
    """
    # Train/evaluation data split.
    splitter = TgT.RandomLinkSplit(
        edge_types=("char", "caption", "image"),
        rev_edge_types=("image", "rev_caption", "char"),
        num_val=0.1,
        num_test=0.1,
        disjoint_train_ratio=0.3,
        neg_sampling_ratio=0,
        add_negative_train_samples=False,
    )
    train_split, val_split, test_split = splitter(data)  # type: ignore
    return train_split, val_split, test_split


def make_link_neighbor_loader(
    data: HeteroData,
    batch_size: int = 16,
    shuffle: bool = True,
    neg_sampling_ratio: float = 2,
    num_neighbors: tuple[int, ...] = (20, 10),
    edge_type: tuple[str, str, str] = ("char", "caption", "image"),
):
    """Build a ``LinkNeighborLoader`` over the supervision links of ``edge_type``.

    Args:
        data: a split produced by ``make_train_val_test``; it must carry
            ``edge_label_index`` and ``edge_label`` on ``edge_type``.
        batch_size: number of supervision links per mini-batch.
        shuffle: whether to shuffle the supervision links.
        neg_sampling_ratio: ratio of negative links the loader samples per
            positive link.
        num_neighbors: neighbors sampled per hop, one entry per hop. The default
            of (20, 10) matches the previously hard-coded fan-out.
        edge_type: ``(src, relation, dst)`` triple whose links are supervised.

    Returns:
        The configured ``LinkNeighborLoader``.
    """
    edge_store = data[edge_type]
    return LinkNeighborLoader(
        data=data,
        # Copy into a fresh list: the default stays an immutable tuple, and the
        # loader receives the same list type as before.
        num_neighbors=list(num_neighbors),
        neg_sampling_ratio=neg_sampling_ratio,
        edge_label_index=(edge_type, edge_store.edge_label_index),
        edge_label=edge_store.edge_label,
        batch_size=batch_size,
        shuffle=shuffle,
    )


# Quick experiment: build the train/val loaders from the first graph file only.
train_graph_dir = DATASET_PATH.COCO_GRAPH_DATA / "train"
train_data = next(make_dataloader(train_graph_dir), None)
if train_data is None:
    # Fail loudly. A silent no-op would leave the loader names below undefined,
    # or stale from an earlier run of this cell.
    raise FileNotFoundError(f"nothing to load in {train_graph_dir}")

data = make_hetero_data(train_data)
batch_train, batch_val, batch_test = make_train_val_test(data)
batch_train_neighbor_loader = make_link_neighbor_loader(batch_train)
batch_val_neighbor_loader = make_link_neighbor_loader(
    batch_val, shuffle=False, neg_sampling_ratio=1, batch_size=32
)
print(len(batch_train_neighbor_loader))
# NOTE: `plot` takes (graph, image, labels), so a bare `plot(data)` would raise
# TypeError. Drawing the full graph takes about 1 min.

1014
