In [1]:
from functools import reduce
import time
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchsummary import summary as summary_

train_data = datasets.MNIST(root='./', train=True, download=False, transform=transforms.ToTensor())
test_data = datasets.MNIST(root='./ ', train=False, download=True, transform=transforms.ToTensor())

train_dataloader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=64, shuffle=True)

is_draw_feature_map = False

def inference(
        input_img,
        input_target,
        weight_1,
        bias_1,
        weight_2,
        bias_2,
        weight_3,
        bias_3,
        in_linear_weight,
        in_linear_bias,
        out_linear_weight,
        out_linear_bias,
        iterations_
    ):

    tot_time = 0
    tot_acc = 0
    accuracy = Accuracy(num_classes=10)

    for _ in range(iterations_):
        torch.manual_seed(888)
        np.random.seed(888)
        random.seed(888)

        start_time = time.perf_counter_ns()
        res = F.conv2d(input=input_img,
                       weight=weight_1,
                       bias=bias_1,
                       stride=1,
                       padding='same')
        res = F.relu(res)
        res = F.max_pool2d(input=res, kernel_size=2)
        end_time = time.perf_counter_ns()

        tot_time += (end_time - start_time)


        if is_draw_feature_map:
            fig, axes = plt.subplots(1, 4)

            for features, ax in zip(res, axes.flat):
                ax.imshow(features, cmap='gray')
                ax.axis('off')
                ax.grid('off')

            plt.show()

        start_time = time.perf_counter_ns()
        res = F.conv2d(input=res,
                       weight=weight_2,
                       bias=bias_2,
                       stride=1,
                       padding='same')
        res = F.relu(res)
        res = F.max_pool2d(input=res, kernel_size=2)
        end_time = time.perf_counter_ns()
        tot_time += (end_time - start_time)


        if is_draw_feature_map:
            fig, axes = plt.subplots(2, 4, figsize=(6, 2.8))

            for features, ax in zip(res, axes.flat):
                ax.imshow(features, cmap='gray')
                ax.axis('off')
                ax.grid('off')

            plt.show()


        start_time = time.perf_counter_ns()
        res = F.conv2d(input=res,
                       weight=weight_3,
                       bias=bias_3,
                       stride=1,
                       padding='same')
        res = F.relu(res)
        res = F.max_pool2d(input=res, kernel_size=2)
        end_time = time.perf_counter_ns()
        tot_time += (end_time - start_time)


        if is_draw_feature_map:
            fig, axes = plt.subplots(4, 4, figsize=(6, 6))

            for features, ax in zip(res, axes.flat):
                ax.imshow(features, cmap='gray')
                ax.axis('off')
                ax.grid('off')

            plt.show()

        start_time = time.perf_counter_ns()
        res = torch.flatten(res, start_dim=1)
        res = F.linear(res,
                      weight=in_linear_weight,
                      bias=in_linear_bias)
        res = F.relu(res)
        res = F.linear(res,
                      weight=out_linear_weight,
                      bias=out_linear_bias)
        res = F.softmax(res, dim=0)
        probabilty, predicted = torch.max(res, 1)

        end_time = time.perf_counter_ns()
        tot_time += (end_time - start_time)

        tot_acc += accuracy(predicted, input_target)

    return tot_acc / iterations_, tot_time / iterations_ / 1000 / 1000


from anytree import Node, RenderTree
from functools import reduce

def find_weight_relations(weight):
    left_to_right = (np.pad(weight[:, 1:], ((0, 0), (0, 1))) - weight)[:, :-1] # left -> right
    top_to_bottom = (np.pad(weight[1:, :], ((0, 1), (0, 0))) - weight)[:-1] # top -> bottom

    return [[(i * (left_to_right.shape[1]+1)+1) + j,
               i * (left_to_right.shape[1]+1) + j + 2,
                    left_to_right[i, j]] for i in range(left_to_right.shape[0]) for j in range(left_to_right.shape[1])] + \
            [[(i*(top_to_bottom.shape[1])+1) + j, ((i+1)*(top_to_bottom.shape[1])+1) + j,
                    top_to_bottom[i, j]] for i in range(top_to_bottom.shape[0]) for j in range(top_to_bottom.shape[1])]


def find_inner(feed_weight, find_method):
    zero_counts = 0
    selected_weight = np.zeros(shape=(feed_weight.shape[0] * feed_weight.shape[1]), dtype=np.float32)

    w = find_weight_relations(feed_weight)
    connections = sorted(w, reverse=False, key=lambda x: x[-1])
    connections = list(map(lambda x: [x[2], str(x[0]), str(x[1])], connections))

    res = find_method(connections)

    nodes = []
    idx = 1
    for i in range(feed_weight.shape[0]):
        for j in range(feed_weight.shape[1]):
            nodes.append(Node(idx, weight=feed_weight[i, j]))
            idx += 1

    for r in res:
        start, end = map(int, r[1:])
        nodes[end-1].parent = nodes[start-1]

    # mst가 space에 몇개 있는지
    roots = {x.root.name for x in nodes}

    for root in roots:
        # for pre, fill, node in RenderTree(nodes[root-1]):
        #     print("%s%s" % (pre, node.name))

        # 가지에 연결되지 않은 노드는 0으로 설정한다.
        if len(nodes[root-1].descendants) == 0:
            zero_counts += 1
            continue

        selected_weight[root-1] = nodes[root-1].weight

        for des in nodes[root-1].descendants:
            selected_weight[des.name-1] = des.weight

    selected_weight = selected_weight.reshape(feed_weight.shape)

    return selected_weight, zero_counts


def find_best_feature(weights):
    valid_params = reduce(lambda res, x: res*x, weights.shape)

    if weights.ndim == 2:
        w, zero_counts = find_inner(weights)
        valid_params -= zero_counts
        print('valid_param: ', valid_params)
        return w

    else:
        tot_weight = np.zeros(shape=weights.shape, dtype=np.float32)

        for out_ in range(weights.shape[0]):
            for in_ in range(weights.shape[1]):
                tot_weight[out_, in_], zero_counts = find_inner(weights[out_, in_])
                valid_params -= zero_counts

        print('valid_param: ', valid_params)
        return tot_weight

model = torch.load('./compressed_base_model_weight.pt')['model']

In [115]:
from pprint import pprint

def find_floyd(weight):
    adj_matrix_weight = np.full(shape=(weight.shape[0] * weight.shape[1], weight.shape[0] * weight.shape[1]), fill_value=np.inf, dtype=np.float32)
    adj_matrix_vert  = np.full_like(adj_matrix_weight, fill_value=np.nan)

    connections = find_weight_relations(weight)

    for start, end, w in connections:
        adj_matrix_weight[start-1, end-1] = w
        adj_matrix_vert[start-1, end-1] = start-1

    for i in range(weight.shape[0] * weight.shape[1]):
        for j in range(weight.shape[0] * weight.shape[1]):
            for k in range(weight.shape[0] * weight.shape[1]):
                if adj_matrix_weight[j, k] > adj_matrix_weight[j, i] + adj_matrix_weight[i, k]:
                    adj_matrix_weight[j, k] = adj_matrix_weight[j, i] + adj_matrix_weight[i, k]
                    adj_matrix_vert[j, k] = adj_matrix_vert[i, k]

    for i in range(9):
        if i == 0:
            continue

        print('0 -> ', end='')
        if not np.isnan(adj_matrix_vert[0, i]):
            print(int(adj_matrix_vert[0, i]), end='')
        print(f' -> {i}')

for layer in model:
    for sep in layer:
        if isinstance(sep, torch.nn.Sequential):
            for l in sep:
                if isinstance(l, nn.Conv2d):
                    weight = l.weight.detach().cpu().numpy()

                    for in_ in range(weight.shape[1]):
                        for out_ in range(weight.shape[0]):
                            find_floyd(weight[out_, in_])
                    break
        else:
            print('nono')
            print(sep)
        break
    break

0 -> 0 -> 1
0 -> 1 -> 2
0 -> 0 -> 3
0 -> 1 -> 4
0 -> 4 -> 5
0 -> 3 -> 6
0 -> 4 -> 7
0 -> 5 -> 8
0 -> 0 -> 1
0 -> 1 -> 2
0 -> 0 -> 3
0 -> 1 -> 4
0 -> 2 -> 5
0 -> 3 -> 6
0 -> 4 -> 7
0 -> 5 -> 8
0 -> 0 -> 1
0 -> 1 -> 2
0 -> 0 -> 3
0 -> 1 -> 4
0 -> 2 -> 5
0 -> 3 -> 6
0 -> 6 -> 7
0 -> 5 -> 8
0 -> 0 -> 1
0 -> 1 -> 2
0 -> 0 -> 3
0 -> 1 -> 4
0 -> 2 -> 5
0 -> 3 -> 6
0 -> 6 -> 7
0 -> 5 -> 8


In [None]:
'''

0 -> 1
0 -> 1 -> 2
0 -> 3
0 -> 1 -> 4
0 -> 1 -> 4 -> 5
0 -> 3 -> 6
0 -> 1 -> 4 -> 7
0 -> 1 -> 4 -> 5 -> 8

0
| - 1
|   | - 2
|   | - 4
|       | - 5
|       |   | - 8
|       |
|       | - 7
|
| - 3
    | - 6

'''