In [1]:
from typing import cast

import networkx as nx
import numpy as np
import psycopg2
import torch
from dotenv import dotenv_values
from locus.data.QuadTree import CellState
from locus.models.dataloader import LDoGIDataLoader
from locus.models.dataset import LDoGIDataset
from locus.utils.paths import PROCESSED_DATA_DIR, PROJECT_ROOT, SQL_DIR
from networkx import DiGraph
from torch.utils.data import BatchSampler, DataLoader, RandomSampler


In [2]:
# Connection settings come from the project's .env file rather than being hard-coded here.
config = dotenv_values(PROJECT_ROOT / ".env")

# Each psycopg2.connect keyword paired with the .env variable that supplies its value.
# NOTE(review): neither `conn` nor `cur` is closed anywhere in this notebook.
conn = psycopg2.connect(
    **{
        connect_kwarg: config[env_var]
        for connect_kwarg, env_var in (
            ("host", "DB_HOST"),
            ("port", "DB_PORT"),
            ("dbname", "DB_NAME"),
            ("user", "DB_USER"),
            ("password", "DB_PASSWORD"),
        )
    }
)
cur = conn.cursor()

In [3]:
# Read the query text first (explicit encoding, not locale-dependent) so the file
# handle is already closed when the query is sent to the database.
with open(SQL_DIR / "select_max_id.sql", encoding="utf-8") as f:
    select_max_id_query = f.read()
cur.execute(select_max_id_query)

# Retrieve query results: only the first column of the first row is used (presumably a
# single-row aggregate, as the file name suggests), so fetch one row instead of all of them.
max_id_row = cur.fetchone()
if max_id_row is None:
    raise RuntimeError("select_max_id.sql returned no rows")
max_id = max_id_row[0]
max_id

4233900

In [4]:
# GML file (under PROCESSED_DATA_DIR/LDoGI/quadtrees/) whose ACTIVE cells define the label set.
QUADTREE: str = "qt_min50_max2000_df1.gml"
# Number of samples per loader batch.
BATCH_SIZE: int = 32

In [5]:
G = cast(DiGraph, nx.read_gml(PROCESSED_DATA_DIR / f"LDoGI/quadtrees/{QUADTREE}"))

# Keep only quadtree cells whose "state" attribute equals CellState.ACTIVE's value.
# Their count is the label dimension (it matches the labels shape printed further down).
active_cells = [
    node
    for node, node_attrs in G.nodes(data=True)
    if node_attrs["state"] == CellState.ACTIVE.value
]
num_classes = len(active_cells)
num_classes

6398

In [6]:
def cf(*mini_batches):
    """Collate function that merges the items fetched for one step into one batch.

    ``DataLoader`` passes the list of fetched items as the first positional
    argument. Only ``mini_batches[0]`` is used; any extra positional arguments
    are ignored. Each item is indexed as ``(ids, images, labels, label_names)``.
    Ids and label names are joined with ``np.concatenate``, images and labels
    with ``torch.cat`` (all along the first axis). Each item is therefore
    presumably already a small batch — confirm against ``LDoGIDataset``.

    Returns:
        tuple: ``(ids, images, labels, label_names)`` covering every item, in order.
    """
    items = mini_batches[0]

    def gather(position):
        # Pull one field out of every item, preserving item order.
        return [item[position] for item in items]

    return (
        np.concatenate(gather(0)),
        torch.cat(gather(1)),
        torch.cat(gather(2)),
        np.concatenate(gather(3)),
    )

In [7]:
# Training dataset over database ids 1..10000 (whether to_id is inclusive is decided by
# LDoGIDataset, which is not visible here), labelled with the QUADTREE cells.
train_data = LDoGIDataset(
    quadtree=QUADTREE,
    from_id=1,
    to_id=10000,
    env=PROJECT_ROOT / ".env",
)

In [8]:
def print_batch_stats(batch):
    """Print a quick shape/dtype summary of one loader batch.

    ``batch`` is read positionally as ``(ids, images, labels, label_names)``;
    only the first four entries are used. ``ids`` and ``label_names`` are
    printed in full, and ``labels`` additionally reports its sum. Sections are
    separated by a blank line. Returns ``None``.
    """
    ids, images, labels, label_names = batch[0], batch[1], batch[2], batch[3]

    # Each print emits one section; a trailing "" argument yields the blank separator line.
    print(f"ids: {ids}", f"ids shape: {ids.shape}", f"ids dtype: {ids.dtype}", "", sep="\n")
    print(f"ims shape: {images.shape}", f"ims dtype: {images.dtype}", "", sep="\n")
    print(
        f"labels shape: {labels.shape}",
        f"labels sum: {labels.sum()}",
        f"labels dtype: {labels.dtype}",
        "",
        sep="\n",
    )
    print(
        f"labels name: {label_names}",
        f"labels name shape: {label_names.shape}",
        f"labels name dtype: {label_names.dtype}",
        sep="\n",
    )

In [9]:
# Individual sampling with the stock torch DataLoader: RandomSampler yields one dataset
# index at a time, BATCH_SIZE fetched items are grouped per step, and `cf` merges them
# into a single (ids, images, labels, label_names) batch.
# Optional tuning (currently disabled): num_workers=1, prefetch_factor=5.
# NOTE(review): RandomSampler is unseeded here, and `ind_loader` is rebound by the next
# cell, so this loader is never iterated.
ind_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    sampler=RandomSampler(train_data),
    collate_fn=cf,
)

In [10]:
# Project loader fetching items individually. shuffle=False keeps the order fixed, so its
# first batch can be compared with the batched loader below.
ind_loader = LDoGIDataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    fetch_mode="individual",
    shuffle=False,
)

In [11]:
# Fresh iterator over the individual-fetch loader.
A_iter = iter(ind_loader)

In [12]:
# First batch, indexed as (ids, images, labels, label_names) by print_batch_stats.
A = next(A_iter)

In [13]:
# Shapes/dtypes of the first individually-fetched batch.
print_batch_stats(A)

ids: [   1    2    3    4    5    6    7    8    9   10   11   12   13   14
 9954   16   17   18   19   20   21   22   23   24   25   26   27   28
   29   30   31   32]
ids shape: (32,)
ids dtype: int64

ims shape: torch.Size([32, 3, 224, 224])
ims dtype: torch.float32

labels shape: torch.Size([32, 6398])
labels sum: 32.0
labels dtype: torch.float32

labels name: ['1200120013211231' '10223033123022' '12001003' '120012120' '21300002'
 '1310221303012' '1200020310201210' '12102' '313010' '31321011102202'
 '0123220' '1022202130101023' '1200102000' '030012310110212' '13203201333'
 '021032330' '021032012' '03002311332' '120012003' '021310130213'
 '0300010' '1003' '0302213211' '03113130' '0302232301123' '31332'
 '02112200213' '0211032303230' '31120' '021032333111' '03020332021'
 '10203122310321']
labels name shape: (32,)
labels name dtype: <U16


In [14]:
# Batched sampling with the stock torch DataLoader. BatchSampler yields a whole list of
# BATCH_SIZE indices per step, so the dataset is indexed with that list and each fetched
# item is already a full batch. DataLoader's automatic batching (default batch_size=1)
# wraps that single item in a list, and the collate_fn just unwraps it again.
# Optional tuning (currently disabled): num_workers=1, prefetch_factor=5.
# NOTE(review): a lambda collate_fn cannot be pickled for worker processes started with
# "spawn", and `batched_loader` is rebound by the next cell, so this loader is never iterated.
batched_loader = DataLoader(
    train_data,
    sampler=BatchSampler(RandomSampler(train_data), batch_size=BATCH_SIZE, drop_last=False),
    collate_fn=lambda *mini_batches: mini_batches[0][0],
)

In [15]:
# Project loader with its default fetch mode (presumably batched, given the name — confirm
# in LDoGIDataLoader). shuffle=False keeps the order comparable with `ind_loader`.
batched_loader = LDoGIDataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [16]:
# Fresh iterator over the default-mode loader.
B_iter = iter(batched_loader)

In [17]:
# First batch, indexed as (ids, images, labels, label_names) by print_batch_stats.
B = next(B_iter)

In [18]:
print_batch_stats(B)

# Both loaders read the same dataset with shuffle=False, so their first batches should agree
# (the saved stats for A and B are identical). Check the fields that are printed in full
# instead of relying on eyeballing. Images are not compared, in case the dataset applies
# non-deterministic transforms (not visible from this notebook).
assert np.array_equal(A[0], B[0]), "individual and batched loaders returned different ids"
assert np.array_equal(A[3], B[3]), "individual and batched loaders returned different label names"

ids: [   1    2    3    4    5    6    7    8    9   10   11   12   13   14
 9954   16   17   18   19   20   21   22   23   24   25   26   27   28
   29   30   31   32]
ids shape: (32,)
ids dtype: int64

ims shape: torch.Size([32, 3, 224, 224])
ims dtype: torch.float32

labels shape: torch.Size([32, 6398])
labels sum: 32.0
labels dtype: torch.float32

labels name: ['1200120013211231' '10223033123022' '12001003' '120012120' '21300002'
 '1310221303012' '1200020310201210' '12102' '313010' '31321011102202'
 '0123220' '1022202130101023' '1200102000' '030012310110212' '13203201333'
 '021032330' '021032012' '03002311332' '120012003' '021310130213'
 '0300010' '1003' '0302213211' '03113130' '0302232301123' '31332'
 '02112200213' '0211032303230' '31120' '021032333111' '03020332021'
 '10203122310321']
labels name shape: (32,)
labels name dtype: <U16
