In [None]:
import sys
import os
sys.path.append(os.getcwd())
print(sys.path)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
from aggregator.attention import AttentionLoop
from torch.utils.data import Dataset, DataLoader,ConcatDataset



def getTensorData(path_to_folder,idx):

    data = torch.load(f'{path_to_folder}/pca_{idx}.pt')
    label = torch.load(f'{path_to_folder}/label.pt') 

    perm = torch.randperm(len(label))
    data = data[:,perm]
    label = label[perm]

    label_norm = F.normalize(label,p=1,dim=-1)
    label_norm.unsqueeze(0)

    center = data.matmul(label_norm.unsqueeze(-1))

    x = data
    y = label.expand_as(x)
    c = center.expand_as(x)
    y_ = label.unsqueeze(-1)
    return x,y,c

# getting a hard prediction by binarizing the affinity matrix
def getBinaryPred(model,x,beta):
    weight = model.getWeight(beta,x)
    weight = torch.nn.Threshold(0.8 * 1.0 / weight.shape[-1],0)(weight)
    weight = F.normalize(weight,p=1,dim=-1)
    predB = torch.einsum('bqi,bji -> bjq', weight, x)
    return predB
# helper class to record the accumulating loss
# +=: add loss
# print(**): print the average loss
class loss_acc():
    def __init__(self):
        self.sum = 0.0
        self.n = 0
    def __iadd__(self,x):
        self.sum+=x
        self.n+=1
        return self
    def value(self):
        return self.sum / self.n
    def __str__(self):
        return f'{self.sum/self.n:.6f}'
    
def test(model,testloader):
    lossCounter = loss_acc()
    lossCounter2 = loss_acc()
    
    model.eval()
    
    with torch.no_grad():
        for x,y,c in testloader:

            x = x.cuda()
            y = y.cuda()
            c = c.cuda()
            beta = x.median(dim=-1,keepdim=True)[0]

            pred = model.cuda()(beta,x)
            loss = loss_fn(pred,c[:,:,[0]])

            pred_b = getBinaryPred(model,x,beta)
            loss_b = loss_fn(pred_b,c[:,:,[0]])

            lossCounter+=loss.cpu().detach().numpy()
            lossCounter2+=loss_b.cpu().detach().numpy()
    # print(f'{loss:.4f},{loss_b:.6f}')
    return lossCounter, lossCounter2

def test_classes(model,testloader):
    def loss_fn(pred,gt):
        correct = (pred == gt).all(dim=-1).sum()
        n = pred.shape[0]
        return correct,n

    correct = 0
    n = 0
    
    model.eval()
    
    with torch.no_grad():
        for x,y,c in testloader:

            beta = x.median(dim=-1,keepdim=True)[0]


            weight = model.getWeight(beta,x)
            weight = torch.nn.Threshold(0.8 * 1.0 / weight.shape[-1],0)(weight)
            pred =  (weight != 0) * 1.0
            accuracy = loss_fn(pred,y[:,[0],:])
            correct+=accuracy[0]
            n+=accuracy[1]
    # print(f'{loss:.4f},{loss_b:.6f}')
    return correct * 1.0 / n


def test_classes_hamming(model,testloader):
    def loss_fn(pred,gt):
#         print(pred.shape,gt.shape)
#         print("\n\n\n",pred,"\n\n\n",gt)
        n = pred.shape[0]
        correct = (pred - gt/gt.sum(2,keepdim=True)).abs().sum()

        return correct,n

    correct = 0
    n = 0
    
    model.eval()
    
    with torch.no_grad():
        for x,y,c in testloader:

            beta = x.median(dim=-1,keepdim=True)[0]


            weight = model.getWeight(beta,x)
#             weight = torch.nn.Threshold(0.8 * 1.0 / weight.shape[-1],0)(weight)
#             pred =  (weight != 0) * 1.0
            pred = weight
           
    
            accuracy = loss_fn(pred,y[:,[0],:])
            correct+=accuracy[0]
            n+=accuracy[1]
    # print(f'{loss:.4f},{loss_b:.6f}')
    return correct * 1.0 / n

class FLdata(Dataset):

    def __init__(self, path_to_folder, indexes):
        self.path_to_folder = path_to_folder
        self.indexes = indexes
        self.size = len(indexes)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):

        sample = getTensorData(self.path_to_folder,self.indexes[idx])
        return sample


In [179]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import math

class MLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MLP, self).__init__()

        self.main = torch.nn.Sequential(nn.Linear(in_channels, 2 * in_channels),
                torch.nn.LeakyReLU(),
                nn.Linear(2 * in_channels, out_channels))

    def forward(self, x):
        vout=x
        print(x.shape)
        x=x.view(x.shape[0],-1)
        print(x.shape)
        attention_scores = self.main(x)
        attention_weights = F.softmax(attention_scores, dim=-1)
        print(vout.shape)
        print(attention_weights.shape)
        out = torch.einsum('bk,bjk -> bj' ,attention_weights,vout).unsqueeze(-1)

        return out
    
    def getWeight(self, x):
        print(x.shape)
        x=x.view(x.shape[0],-1)
        print(x.shape)
        attention_scores = self.main(x).unsqueeze(1)
        attention_weights = F.softmax(attention_scores, dim=-1)
        return attention_weights


In [195]:
(torch.rand(1)<0.5).item()

True

In [57]:
train_path=["backdoor_1(0)/gm"]
path_prefix="./AggData/dirichlet/cifar/"

In [29]:
trainDataset = ConcatDataset([FLdata(path_prefix + path_to_folder,list(range(0,30))) for path_to_folder in train_path])

In [30]:
k=trainDataset[0][0].shape[0]

In [105]:
x=torch.stack([trainDataset[0][0],trainDataset[1][0]])

In [188]:
net=MLP(k*10, 10)

In [189]:
net.getWeight(x).shape

torch.Size([2, 372, 10])
torch.Size([2, 3720])


torch.Size([2, 1, 10])

In [182]:
net(x).shape

torch.Size([2, 372, 10])
torch.Size([2, 3720])
torch.Size([2, 372, 10])
torch.Size([2, 10])


torch.Size([2, 372, 1])

In [184]:
net=AttentionLoop(k,21)

In [185]:
net(x.median(dim=-1,keepdims=True)[0],x).shape

torch.Size([2, 372, 1])

In [186]:
net.getWeight(x.median(dim=-1,keepdims=True)[0],x).shape

torch.Size([2, 1, 10])

In [87]:
x=x.view(-1)

In [93]:
x.shape

torch.Size([2, 372, 10])

In [94]:
x.view(x.shape[0],-1)

tensor([[ 0.0409,  0.1130,  0.0131,  ..., -0.0011,  0.0012,  0.0045],
        [-0.0198, -0.0205, -0.1619,  ...,  0.0019, -0.0015,  0.0057]])