In [1]:
import warnings
# Suppress every Python warning for the rest of this kernel session: no category,
# module or message filter is passed, so the "ignore" action applies to all of them.
warnings.filterwarnings("ignore")

import torch
import os
import numpy as np

* Load training hyper-parameters

In [2]:
import json
from utils import DictToClass

# Directory of the finished training run whose checkpoint and config are analysed below.
model_dir = 'runs/resnet56_cifar10_sgd/cosine_lr=5.00e-02_bs=128_wd=5.00e-04_corr-1.0_1500_cat[]_seed=1'
# Alternative run:
# model_dir ='runs/resnet56_cifar10_sgd/cosine_lr=4.00e-02_bs=64_wd=5.00e-04_corr-1.0_-1_cat[]_seed=1'

# Parse the hyper-parameters saved with the run directly from the file handle and
# wrap them in DictToClass (project helper; the later cells read them as attributes,
# e.g. args.arch, args.weight_decay).
with open(os.path.join(model_dir, 'config.json'), 'r') as f:
    args = DictToClass(json.load(f))

# Use the first GPU when CUDA is available and fall back to the CPU otherwise, so the
# notebook still runs on a machine without a GPU instead of failing on the first
# .to(device) call.
# To honour a GPU index stored in the config instead, use f'cuda:{args.gpu_id}'
# (assumes the saved config defines gpu_id - confirm).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

* Prepare training set (shuffled vs. unshuffled)

In [3]:
from data import cifar_dataset
from utils import cycle_loader

# Build the train/test datasets from the options stored in the run's config.
train_set, test_set = cifar_dataset(
    data_name=args.data_name,
    root=args.data_dir,
    label_corruption=args.label_corruption,
    example_per_class=args.example_per_class,
    categories=args.categories,
)

# Options shared by both training-set loaders.
common_loader_kwargs = dict(
    num_workers=args.num_workers,
    pin_memory=args.pin_memory,
)

# Fixed-order loader over the full training set (keeps the last partial batch).
train_loader_no_shuffle = torch.utils.data.DataLoader(
    train_set,
    batch_size=args.batch_size_eval,
    shuffle=False,
    drop_last=False,
    **common_loader_kwargs,
)

# Shuffled loader that drops the last partial batch; it is wrapped by the project's
# cycle_loader helper before being iterated in the training cell.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=args.batch_size_train,
    shuffle=True,
    drop_last=True,
    **common_loader_kwargs,
)
train_loader_cycle = cycle_loader(train_loader)

Files already downloaded and verified
Files already downloaded and verified


* Load pretrained neural network

In [4]:
from pytorchcv.model_provider import get_model as ptcv_get_model

# Instantiate the architecture named in the config without pretrained weights,
# then move it to the selected device.
net = ptcv_get_model(args.arch, pretrained=False)
net = net.to(device)

# Restore the weights saved by the training run.
checkpoint_path = os.path.join(model_dir, 'state_dict.pt')
state_dict = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(state_dict)

# reduction='none' keeps one loss value per example; the training loop averages it.
loss_func = torch.nn.CrossEntropyLoss(reduction='none')

# Note: with momentum enabled, ph_dim always comes out at approximately 1.
# Note: the authors use vanilla SGD with a constant learning rate.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=0.005,
    momentum=0.0,
    weight_decay=args.weight_decay,
)

* Continue training with more steps and collect the weights and training losses

In [5]:
from utils import get_params
from utils import validate

# Number of extra SGD steps to run; one weight snapshot is recorded per step.
max_points = 3000  # too large a value sometimes causes training failure
train_loss_hist = []  # per-step loss vectors; only filled if the optional lines below are enabled
weights_hist = []

# zip() with range() first stops after exactly max_points steps without drawing an
# extra, unused batch from the loader (enumerate + break fetched one more).
for j, (x, y) in zip(range(max_points), train_loader_cycle):
    # Re-enter training mode each step: validate(..., train=False) below may have
    # switched the network to eval mode (project helper - confirm).
    net.train()
    x, y = x.to(device), y.to(device)

    optimizer.zero_grad()
    yhat = net(x)
    loss = loss_func(yhat, y)  # per-example losses (reduction='none')
    loss.mean().backward()
    optimizer.step()

    # Move each snapshot to host memory as it is recorded, so the history does not
    # pile up on the training device until the final stack.
    # Assumes get_params returns a tensor on the model's device - confirm.
    weights_hist.append(get_params(net).detach().cpu())

    # Optional: record the per-example training loss at every step (expensive).
    # train_loss, train_acc, train_loss_vec = validate(net, train_loader_no_shuffle, loss_func, device, train=False)
    # train_loss_hist.append(train_loss_vec)
    if j % 500 == 0:
        train_loss, train_acc, train_loss_vec = validate(net, train_loader_no_shuffle, loss_func, device, train=False)
        print(f'iteration={j}, train_loss={train_loss:.4f}, train_acc={train_acc*100:.4f}%')

# Snapshots are already on the CPU: stack them into one array with a row per step.
weights_hist = torch.stack(weights_hist, dim=0).numpy()

iteration=0, train_loss=0.0060, train_acc=99.9133%
iteration=500, train_loss=0.0013, train_acc=100.0000%
iteration=1000, train_loss=0.0009, train_acc=100.0000%
iteration=1500, train_loss=0.0008, train_acc=100.0000%
iteration=2000, train_loss=0.0007, train_acc=100.0000%
iteration=2500, train_loss=0.0006, train_acc=100.0000%


* Compute persistent homology dimension

In [6]:
from indicator import fast_ripser
# Estimate the persistent-homology dimension from the recorded weight snapshots.
# min_points / point_jump are passed through to the project's fast_ripser helper
# (presumably the smallest sample size and the step between sample sizes - confirm).
ph_dim_euclidean = fast_ripser(
    weights_hist,
    max_points=max_points,
    min_points=200,
    point_jump=20,
)

print('PH dimension', ph_dim_euclidean)

[32m2024-07-06 18:02:25.725[0m | [34m[1mDEBUG   [0m | [36mindicator.topology[0m:[36mfast_ripser[0m:[36m84[0m - [34m[1mDistance matrix computation time: 131.51s[0m
[32m2024-07-06 18:02:45.876[0m | [34m[1mDEBUG   [0m | [36mindicator.topology[0m:[36mph_dim_from_distance_matrix[0m:[36m64[0m - [34m[1mPh Dimension Calculation has an approximate error of: 0.0004687931896550983.[0m


PH dimension 1.6849122710899136
