In [1]:
import argparse
import os
import random
import shutil
import time
import warnings
import json
import re

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

from model_analysis_nets import *
from CKA import linear_CKA, kernel_CKA

import numpy as np
import sympy
import pandas as pd
import seaborn

import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster, fclusterdata
import scipy
import networkx as nx
import gif


In [2]:
def merge_sublists_with_shared_items(data):
    merged = True
    while merged:
        merged = False
        for i in range(len(data)):
            for j in range(i + 1, len(data)):
                if set(data[i]).intersection(data[j]):
                    data[i] = list(set(data[i]).union(data[j]))
                    del data[j]
                    merged = True
                    break
            if merged:
                break
    return data


In [3]:

import gudhi as gd
from ripser import Rips
def getGradients(modelpth : str = ""):
    if not modelpth:
        return {}
    model1 = LeNet().to(device)
    model1.load_state_dict(torch.load(modelpth))
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model1.parameters(), lr=0.01, momentum=0.5)

    for X, Y in test_loader:
        X_test = X.to(device)
        Y_test = Y.to(device)
        # print(Y)
        break

    # Forward pass
    outputs = model1(X_test)

    # Compute the loss
    loss = criterion(outputs, Y_test)

    # Zero the gradients
    optimizer.zero_grad()

    # Backward pass
    loss.backward()

    # Access the gradients
    gradients1 = {}
    for name, param in model1.named_parameters():
        gradients1[name] = param.grad
    return gradients1
trans_mnist = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))])
dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist)
test_loader = torch.utils.data.DataLoader(
                dataset_test, batch_size=32,
                num_workers=2, pin_memory=True, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
attack_pth = "../TDA-NN/3090/LG-FedAvg/save_attack_ub/mnist/lenet_iidTrue_num100_C0.8_le2/shard2/pattern11-30--17-49-20/local_attack_save/"
normal_pth = "../TDA-NN/3090/LG-FedAvg/save_attack_ub/mnist/lenet_iidTrue_num100_C0.8_le2/shard2/pattern11-30--17-49-20/local_normal_save/"
round_set = []
client_set_attack = []
client_set_normal = []
modelpth_set = []
for dirpath, dirnames, filenames in os.walk(attack_pth):
    for filename in filenames:
        # if int(re.findall(r'-?\d+', filename)[0]) < round_max and int(re.findall(r'-?\d+', filename)[1]) < client_max:
        modelpth_set.append(os.path.join(dirpath, filename).replace("\\","/"))
        round_set.append(int(re.findall(r'-?\d+', filename)[0]))
        client_set_attack.append(int(re.findall(r'-?\d+', filename)[1]))

round_set = list(set(round_set))
client_set_attack = list(set(client_set_attack))

for dirpath, dirnames, filenames in os.walk(normal_pth):
    for filename in filenames:
        # if int(re.findall(r'-?\d+', filename)[0]) < round_max and int(re.findall(r'-?\d+', filename)[1]) < client_max:
        modelpth_set.append(os.path.join(dirpath, filename).replace("\\","/"))
        client_set_normal.append(int(re.findall(r'-?\d+', filename)[1]))

client_set_normal = list(set(client_set_normal))

print("rounds", round_set, len(round_set))
print("attack", client_set_attack, len(client_set_attack))
print("normal", client_set_normal, len(client_set_normal))
print(len(modelpth_set))

modelpth_dict = {}
for idx, r in enumerate(round_set):
    current_set_normal = []
    current_set_attack = []
    for mt in modelpth_set:
        # print(mt)
        if int(re.findall(r'-?\d+', mt)[-2]) == r and int(re.findall(r'-?\d+', mt)[-1]) in client_set_normal:
            # print(int(re.findall(r'-?\d+', mt)[-1]))
            current_set_normal.append(mt)
        if int(re.findall(r'-?\d+', mt)[-2]) == r and int(re.findall(r'-?\d+', mt)[-1]) in client_set_attack:
            # print(int(re.findall(r'-?\d+', mt)[-1]))
            current_set_attack.append(mt)
        
    modelpth_dict[f'{r}_normal'] = current_set_normal
    modelpth_dict[f'{r}_attack'] = current_set_attack
# print(modelpth_dict)
client_max = 14
corr_set = []
for idx, r in enumerate(round_set):
    if idx > 15:
        break
    r_modelpth_set = modelpth_dict[f'{r}_normal'][0:15] + modelpth_dict[f'{r}_attack'][0:15]
    # print(len(r_modelpth_set))
    r_corr = np.zeros((client_max * 2, client_max * 2, 1))

    gradientlist = []
    for idx, modelp1 in enumerate(r_modelpth_set):
        gradientlist.append(getGradients(modelp1))
    
    for id1, modelp1 in enumerate(gradientlist):
        if id1 >= client_max * 2:
            break
        for id2, modelp2 in enumerate(gradientlist):
            if id2 >= client_max * 2:
                break
            if id2 <= id1:
                continue
            gradients1 = gradientlist[id1]
            gradients2 = gradientlist[id2]
            
            corrlist = []

            acts1 = gradients1["layer3.0.weight"]
            acts2 = gradients2["layer3.0.weight"]
            acts1 = acts1.cpu().reshape((50,-1))
            acts2 = acts2.cpu().reshape((50,-1))
            l = linear_CKA(acts1.T, acts2.T)
            k = kernel_CKA(acts1.T, acts2.T)
            corrlist.append((l+k)/2)

            r_corr[id1][id2] = np.array(corrlist)
    corr_set.append(r_corr)

print(len(corr_set))

rounds [8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50] 22
attack [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 20
normal [21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99] 27
833
16


In [26]:
for corr in corr_set:
    data = 1 - corr.reshape(-1, 2 * client_max) - corr.reshape(-1, 2 * client_max).T
    print(data.shape)
    # df=pd.DataFrame(data)
    # plot=seaborn.heatmap(df)
    # plt.show()

    rips_complex = gd.RipsComplex(distance_matrix=data, max_edge_length=100)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence()

    # filtration = simplex_tree.get_filtration()
    # gd.plot_persistence_barcode(persistence=persistence)
    # plt.show()
    
    count = 0
    results = []
    for (birth_value, death_value), d in [(i[1], i[0]) for i in persistence if i[0] == 0]:
        tolerance = 1e-6
        birth_simplices = []
        death_simplices = []
        filtration = simplex_tree.get_filtration()

        for simplex, filtration_value in filtration:
            # print(filtration_value, birth_value)
            if abs(filtration_value - birth_value) < tolerance:
                # print(simplex)
                birth_simplices.append(simplex)
            elif abs(filtration_value - death_value) < tolerance or (death_value == np.Inf and d != 0):
                death_simplices.append(simplex)
        # print(birth_simplices)

        # involved_points = set()

        # for birth_simplex in birth_simplices:
        #     for death_simplex in death_simplices:
        #         if set(birth_simplex).issubset(death_simplex):
        #             involved_points.update(birth_simplex)
        #             involved_points.update(death_simplex)
        
        # involved_points_b = (merge_sublists_with_shared_items(birth_simplices))
        involved_points_d = (merge_sublists_with_shared_items(death_simplices))

        # if len(involved_points_b) >= 0 : 
        #     print('b',count, birth_value, death_value, involved_points_b) 
        if len(involved_points_d) >= 0 : 
            print('d', d, count, birth_value, death_value, involved_points_d) 
            results.append(involved_points_d)
        count += 1
    
    a = []
    for concern_points in results:
        for c in concern_points:
            # print(c)
            a.append(c)
            
    for i in range(0, len(a)):
        if (np.unique(a[i:len(a)]).shape[0] == 28):
            continue
        print(merge_sublists_with_shared_items(a[i:len(a)]))
        for j in merge_sublists_with_shared_items(a[i:len(a)]):
            print(np.mean(np.array(j) >= 14), end=' ')
        print()


    group = np.zeros((28,28))
    for i in merge_sublists_with_shared_items(a)[0]:
        group[i][i] = 1
    # df=pd.DataFrame(group)
    # plot=seaborn.heatmap(df)
    # plt.show()


    

(28, 28)
d 0 0 0.0 inf []
d 0 1 0.0 0.18628089909290302 [[18, 20]]
d 0 2 0.0 0.15767802743773407 [[17, 25]]
d 0 3 0.0 0.1464803589213347 [[16, 24]]
d 0 4 0.0 0.1357047638713328 [[21, 23]]
d 0 5 0.0 0.13110019680628682 [[17, 26]]
d 0 6 0.0 0.12137554366074566 [[17, 27]]
d 0 7 0.0 0.12122116428865182 [[1, 19]]
d 0 8 0.0 0.09722612051139601 [[23, 27]]
d 0 9 0.0 0.09145980529430686 [[1, 2]]
d 0 10 0.0 0.07929121286219432 [[8, 0, 11, 9], [20, 26]]
d 0 11 0.0 0.07769103933473143 [[4, 12]]
d 0 12 0.0 0.0706996564895932 [[0, 1]]
d 0 13 0.0 0.07047794164252541 [[8, 10]]
d 0 14 0.0 0.0703839008710212 [[11, 15]]
d 0 15 0.0 0.06883810725922057 [[8, 9]]
d 0 16 0.0 0.06736501387020255 [[19, 26]]
d 0 17 0.0 0.06667039427475785 [[16, 26]]
d 0 18 0.0 0.06305502401773744 [[11, 12]]
d 0 19 0.0 0.05494566360974773 [[0, 8]]
d 0 20 0.0 0.052635023100906286 [[22, 27]]
d 0 21 0.0 0.04914569180478634 [[3, 4]]
d 0 22 0.0 0.04460676524839835 [[2, 5]]
d 0 23 0.0 0.042304983114682226 [[0, 14]]
d 0 24 0.0 0.0421111

# how to evaluate the cluster accuracy
a > 14 ? 0 : 1
mean(a[i] > 14)