In [7]:
from itertools import product
import torch

In [8]:
w, h = 3, 3

ls = [[0, 1] for _ in range(w*h)]

In [9]:
indicator_dataset = None

for idx, i in enumerate(product(*ls)):
    arr = torch.tensor(list(i)).reshape(1, 3, 3)
    if indicator_dataset is None:
        indicator_dataset = arr
    else:
        indicator_dataset = torch.concatenate([indicator_dataset, arr])

print(indicator_dataset.shape)

torch.Size([512, 3, 3])


In [33]:
stack_arr = torch.stack([indicator_dataset, indicator_dataset], dim=1)

In [32]:
# filter out cells which is confident [0, epsilon] OR [1-epsilon, 1]

epsilon = 0.1

stack_arr[(stack_arr[:, 1, 2, 2] < epsilon) + (stack_arr[:, 1, 2, 2] > 1-epsilon)].shape

torch.Size([512, 2, 3, 3])

In [31]:
net: torch.nn.Module

# out_ = net(indicator_dataset) # out shape: [512, 2, 3, 3]
out_ = indicator_dataset.clone().float()[:, None, ...] # out shape: [512, 2, 3, 3]
# calculate probability of 1
out_ = torch.softmax(out_, dim=1)[:, 1, ...] # out shape: [512, 1, 3, 3]

out_.shape

IndexError: index 1 is out of bounds for dimension 1 with size 1

In [34]:
stack_arr = torch.stack([indicator_dataset, out_], dim=1)
stack_arr.shape

RuntimeError: stack expects each tensor to be equal size, but got [512, 3, 3] at entry 0 and [512, 1, 3, 3] at entry 1

In [36]:

# 过滤输出 confidence 足够高的，过滤出高置信度对
filtered_arr = stack_arr[(stack_arr[:, 1, 1, 1] < epsilon) + (stack_arr[:, 1, 1, 1] > 1-epsilon)].transpose(1, 0)

# 做用旋转不变性做商集
filtered_arr.shape

torch.Size([2, 512, 3, 3])

In [37]:
# 权宜之计：对确定的高置信度规则直接 copy，对不确定的规则随机指派结果 （0 或者 1）
# 造一个字典映射，对于确定性的规则，对应直接的结果，对于不确定性的规则，随机指派结果。

map_dict = {
   i: j[..., 2, 2] for (i,j) in zip(*filtered_arr)
}

In [40]:
for idx, (i,j) in enumerate(zip(*filtered_arr)):
    
    print(f"Index: {idx}")
    print(f"Original Pair: {filtered_arr[:, idx, ...]}")
    print(f"Input:\n{i}")
    print(f"Output Probability:\n{j[..., 2, 2]}\n")

Index: 0
Original Pair: tensor([[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]])
Input:
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])
Output Probability:
0

Index: 1
Original Pair: tensor([[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 1]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 1]]])
Input:
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 1]])
Output Probability:
1

Index: 2
Original Pair: tensor([[[0, 0, 0],
         [0, 0, 0],
         [0, 1, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 1, 0]]])
Input:
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 1, 0]])
Output Probability:
0

Index: 3
Original Pair: tensor([[[0, 0, 0],
         [0, 0, 0],
         [0, 1, 1]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 1, 1]]])
Input:
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 1, 1]])
Output Probability:
1

Index: 4
Original Pair: tensor([[[0, 0, 0],
         [0,

In [120]:
import torch.nn as nn
from einops import rearrange, reduce

# get all 3x3 windows of a 2D array.

count_kernel = nn.Conv2d(in_channels=1, out_channels=9, kernel_size=3, stride=1, padding=1, padding_mode="circular", bias=False)
count_kernel.weight.data = rearrange(torch.eye(9, 9), "c (w h) -> c 1 w h", w=3, h=3)

In [101]:
# test input array

arr = torch.randint(0, 2, (1, 100, 100)).float()
patches_arr = rearrange(count_kernel(arr), "n w h -> w h n")

In [105]:
f = lambda x: torch.randint(0, 2, (1,)).item() if x not in map_dict else map_dict[x]

In [118]:
from functools import reduce
from operator import add

torch.tensor(list(map(f, reduce(add, [[j.long() for j in i] for i in patches_arr])))).reshape(*arr.shape)

tensor([[[1, 1, 0,  ..., 0, 1, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 0, 1, 0],
         ...,
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 1, 0, 0],
         [0, 1, 0,  ..., 0, 1, 0]]])