In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import gudhi as gd
from scipy.sparse.csgraph import minimum_spanning_tree

In [2]:
def PH0(weight, mel=1000):
    if len(weight) == 2:
        weight = weight[0]
    m = weight.shape[0]
    W = weight.view(m, -1)
    rips = gd.RipsComplex(W, max_edge_length=mel)
    st = rips.create_simplex_tree(max_dimension=0)
    st.compute_persistence()
    idx = st.flag_persistence_generators()
    if len(idx[0]) == 0:
        verts = torch.empty((0, 2), dtype=int)
    else:
        verts = torch.tensor(idx[0][:, 1:])
    dgm = torch.norm(W[verts[:, 0], :] - W[verts[:, 1], :], dim=-1)
    tloss = torch.sum(dgm)

    norm = torch.norm(W, dim=1)
    nloss = torch.sum((1 - norm**2)**2)
    return nloss, tloss


def MST(weight):
    if len(weight) == 2:
        weight = weight[0]
    m = weight.shape[0]
    W = weight.view(m, -1)
    dist = torch.sqrt(torch.sum(torch.pow(W[:, None, :] - W[None, :, :], 2), dim=2))
    Tscr = minimum_spanning_tree(dist.detach().cpu().numpy())
    result = Tscr.toarray()
    mst = np.where(result > 0)
    tloss = torch.sqrt(((W[mst[0]] - W[mst[1]])**2).sum(-1)).sum()

    norm = torch.norm(W, dim=1)
    nloss = torch.sum((1 - norm**2)**2)
    return nloss, tloss

In [51]:
weight = torch.randn(128, 2048, 3, 3)

In [52]:
%%time
MST(weight)

CPU times: user 1.34 s, sys: 5.09 s, total: 6.44 s
Wall time: 321 ms


(tensor(4.3503e+10), tensor(24108.1094))

In [53]:
%%time
PH0(weight)

CPU times: user 3.88 s, sys: 211 ms, total: 4.09 s
Wall time: 3.77 s


(tensor(4.3503e+10), tensor(24108.0957))

In [54]:
weight = torch.randn(256, 1024, 3, 3)

In [55]:
%%time
MST(weight)

CPU times: user 2.42 s, sys: 9.11 s, total: 11.5 s
Wall time: 592 ms


(tensor(2.1760e+10), tensor(33985.5508))

In [56]:
%%time
PH0(weight)

CPU times: user 3.79 s, sys: 46.4 ms, total: 3.83 s
Wall time: 3.74 s


(tensor(2.1760e+10), tensor(33985.5391))

In [57]:
weight = torch.randn(512, 512, 3, 3)

In [58]:
%%time
MST(weight)

CPU times: user 3.75 s, sys: 17.2 s, total: 20.9 s
Wall time: 1.12 s


(tensor(1.0831e+10), tensor(47651.5469))

In [59]:
%%time
PH0(weight)

CPU times: user 3.66 s, sys: 57.3 ms, total: 3.72 s
Wall time: 3.44 s


(tensor(1.0831e+10), tensor(47651.5430))

In [60]:
weight = torch.randn(1024, 256, 3, 3)

In [61]:
%%time
MST(weight)

CPU times: user 7.07 s, sys: 34.6 s, total: 41.7 s
Wall time: 2.33 s


(tensor(5.4262e+09), tensor(66562.6562))

In [62]:
%%time
PH0(weight)

CPU times: user 4.63 s, sys: 200 ms, total: 4.83 s
Wall time: 4.13 s


(tensor(5.4262e+09), tensor(66562.6562))

In [63]:
weight = torch.randn(2048, 128, 3, 3)

In [64]:
%%time
MST(weight)

CPU times: user 15.1 s, sys: 1min 13s, total: 1min 28s
Wall time: 5.26 s


(tensor(2.7069e+09), tensor(92119.3359))

In [65]:
%%time
PH0(weight)

CPU times: user 7.19 s, sys: 168 ms, total: 7.35 s
Wall time: 6.51 s


(tensor(2.7069e+09), tensor(92119.3438))

In [14]:
check_size = (
    (64, 64, 3, 3), 
    (1024, 512, 3, 3),
    (1000, 10), 
    (10, 1000)
)

In [None]:
bsize = 512
num_worker = 4
epochs = 5

In [None]:
train_dataset = torchvision.datasets.CIFAR10(
                root='./DATA/', 
                transform=transforms.Compose(
                    [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(15),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                    ]),
                train=True)

val_dataset = torchvision.datasets.CIFAR10(
                root='./DATA/', 
                transform=transforms.Compose(
                    [
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                    ]),
                train=False) 

In [None]:
train_loader = DataLoader(
                    train_dataset, 
                    batch_size=bsize, 
                    shuffle=True, 
                    num_workers=num_worker, 
                    pin_memory=True, 
                    )

val_loader = DataLoader(
                    val_dataset, 
                    batch_size=bsize, 
                    shuffle=False, 
                    num_workers=num_worker, 
                    pin_memory=True, 
                    )

In [None]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])


In [None]:
epochs = 50
lamd = 0.1
lr = 0.01
weight_decay = 0.0001

In [None]:
model = ResNet18().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
for e in range(epochs):
    for x, y in train_loader:
        x = x.cuda()
        y = y.cuda()
        out = model(x)
        loss = criterion(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    test_num = 0
    hit_num = 0
    for x, y in val_loader:
        x = x.cuda()
        y = y.cuda()
        test_num += len(y)
        out = model(x)
        pred = out.argmax(dim=1, keepdim=True)
        hit_num += pred.eq(y.view_as(pred)).sum().item()
    print(f'epoch: {e}, loss: {loss}, accruacy: {hit_num/test_num}')