In [9]:
import networkx as nx
from tqdm import tqdm
import numpy as np
import random

from karateclub import DeepWalk
from karateclub.dataset import GraphSetReader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

In [10]:
def getembedding(graph_list):
    scaler = StandardScaler()
    dim = 16
    embedlist = []
    newlistt = []

    for x in graph_list:
        model = DeepWalk(dimensions=dim)
        model.fit(x)
        x_embed = model.get_embedding()
        x_embed = scaler.fit_transform(x_embed)
        x_embed = torch.from_numpy(x_embed)
        embedlist.append(x_embed)
    seq_len = [len(x) for x in graph_list]
    max_len = max(seq_len)
    for x in embedlist:
        x = F.pad(x, (0, 0, 0 ,max_len - x.shape[0]), "constant", 0)
        temp = x.reshape([1, dim, max_len])
        newlistt.append(temp)
    return newlistt

def chunkwise(t, size=2):
    it = iter(t)
    newzipped = zip(*[it]*size)
    newlist = [list(x) for x in newzipped]
    return newlist

In [11]:
reader = GraphSetReader("reddit10k") #  two types - discussion and non-discussion based ones.
graphlist = reader.get_graphs()[:100]
targetlist = reader.get_target()[:100]

In [12]:
bag = [(x, y) for x, y in zip(graphlist, targetlist) 
       #if len(x.edges()) > 50 and len(x.edges()) < 80
      ]

In [13]:
embed_x = getembedding([x for x, y in bag])
embed_y = [y for x,y in bag]

In [14]:
X_train, X_test, y_train, y_test = train_test_split(
    embed_x, embed_y, train_size=0.5, test_size=0.5, random_state=42)
print(
"x_train length:", len(X_train),"\n"
"x_test length:", len(X_test)
)

x_train length: 50 
x_test length: 50


In [15]:
def pair_sampler(sample_split, classpick):
    newbag = []
    sample = []
    loader = []

    for c in classpick:
        for x,y in sample_split:
            if y==c:
                newbag.append((x, y))
                break

    for img1, y1 in newbag:
        for img2, y2 in newbag:
            if y1 == y2:
                sample.append((img1, img2, 1)) # label 0 if pairs are same 
            if y1 != y2:
                sample.append((img1, img2, 0)) # label 1 if pairs are not same
    
    for img1, img2, y in sample:
        dataload = Data(
            img1=img1,
            img2=img2,
            y=y
        )
        loader.append(dataload)

    return newbag, loader

def validation_sampler(sample_split, classpick=None):
    newbag = []
    loader = []

    if classpick:
    
        for c in classpick:
            for x,y in sample_split:
                if y==c:
                    newbag.append((x, y))
                    break
                    
    else:
        newbag = sample_split

                
    for img, y in newbag:
        dataload = Data(
            img=img,
            y=y
        )
        loader.append(dataload)
    
    return loader

In [16]:
validationload, trainload = pair_sampler(
    zip(X_train, y_train), 
    [0,0,0,1,1,1,1],
    #np.random.binomial(1, 0.4, 5)
)
train_loader = DataLoader(trainload, batch_size=1, shuffle=False)
print(len(trainload))

49


In [17]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
    def forward(self, output1, output2, label):
        #euclidean_distance = torch.cdist(output1, output2)
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

In [18]:
# from https://github.com/delijati/pytorch-siamese/blob/master/net.py
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork,self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv1d(16, 32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(32),
            
            nn.Conv1d(32, 32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(32),
            
            nn.Conv1d(32, 32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(32),
        )

        self.fc1 = nn.Sequential(
            nn.Linear(3008, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500,500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 2)
            
        )

    def forward_once(self, x):        
        x = self.cnn1(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
    
        return x

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

In [None]:
net = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

counter = []
loss_history = [] 
iteration_number= 0

for epoch in range(0,100):
    net.train()
    
    for i, data in enumerate(train_loader,0):
        optimizer.zero_grad()
        output1, output2 = net(data.img1, data.img2)
        loss_contrastive = criterion(output1, output2, data.y)
        loss_contrastive.backward()
        optimizer.step()
        
        if i %10 == 0 :
            print("Epoch number {}\n Current loss {}\n".format(epoch,loss_contrastive.item()))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.item())

Epoch number 0
 Current loss 3.9999942779541016

Epoch number 0
 Current loss 197.32015991210938

Epoch number 0
 Current loss 76.34922790527344

Epoch number 0
 Current loss 3892.815185546875

Epoch number 0
 Current loss 3.9999942779541016

Epoch number 1
 Current loss 3.9999942779541016

Epoch number 1
 Current loss 6.991643905639648

Epoch number 1
 Current loss 8.079567909240723

Epoch number 1
 Current loss 2.884491443634033

Epoch number 1
 Current loss 3.9999942779541016

Epoch number 2
 Current loss 3.9999942779541016

Epoch number 2
 Current loss 1.0152783393859863

Epoch number 2
 Current loss 1.9365854263305664

Epoch number 2
 Current loss 0.43034735321998596

Epoch number 2
 Current loss 3.9999942779541016

Epoch number 3
 Current loss 3.9999942779541016

Epoch number 3
 Current loss 2.426431894302368

Epoch number 3
 Current loss 0.9521488547325134

Epoch number 3
 Current loss 0.12152929604053497

Epoch number 3
 Current loss 3.9999942779541016

Epoch number 4
 Current 

Epoch number 33
 Current loss 3.9999942779541016

Epoch number 33
 Current loss 0.14430008828639984

Epoch number 33
 Current loss 0.19563239812850952

Epoch number 33
 Current loss 0.7199106812477112

Epoch number 33
 Current loss 3.9999942779541016

Epoch number 34
 Current loss 3.9999942779541016

Epoch number 34
 Current loss 0.13839678466320038

Epoch number 34
 Current loss 0.21049699187278748

Epoch number 34
 Current loss 0.7284717559814453

Epoch number 34
 Current loss 3.9999942779541016

Epoch number 35
 Current loss 3.9999942779541016

Epoch number 35
 Current loss 0.13142810761928558

Epoch number 35
 Current loss 0.22156953811645508

Epoch number 35
 Current loss 0.7192065715789795

Epoch number 35
 Current loss 3.9999942779541016

Epoch number 36
 Current loss 3.9999942779541016

Epoch number 36
 Current loss 0.12351879477500916

Epoch number 36
 Current loss 0.23825062811374664

Epoch number 36
 Current loss 0.6432142853736877

Epoch number 36
 Current loss 3.9999942779

Epoch number 65
 Current loss 3.9999942779541016

Epoch number 66
 Current loss 3.9999942779541016

Epoch number 66
 Current loss 2.7926628589630127

Epoch number 66
 Current loss 1.5497777462005615

Epoch number 66
 Current loss 0.3728814125061035

Epoch number 66
 Current loss 3.9999942779541016

Epoch number 67
 Current loss 3.9999942779541016

Epoch number 67
 Current loss 2.265895128250122

Epoch number 67
 Current loss 1.3573726415634155

Epoch number 67
 Current loss 0.4236902892589569

Epoch number 67
 Current loss 3.9999942779541016

Epoch number 68
 Current loss 3.9999942779541016

Epoch number 68
 Current loss 2.7888736724853516

Epoch number 68
 Current loss 1.3603814840316772

Epoch number 68
 Current loss 0.41164782643318176

Epoch number 68
 Current loss 3.9999942779541016

Epoch number 69
 Current loss 3.9999942779541016

Epoch number 69
 Current loss 2.4079394340515137

Epoch number 69
 Current loss 1.2532507181167603

Epoch number 69
 Current loss 0.4265110492706299



In [None]:
gt = []
pred = []
for data in train_loader:
    output1,output2 = net(data.img1, data.img2)
    euclidean_distance = F.pairwise_distance(output1, output2).item()
    if euclidean_distance < 1:
        pred.append(0)
    else:
        pred.append(1)
    
    gt.append(data.y.item())
        
AC = accuracy_score(gt, pred)
f1_grid = precision_recall_fscore_support(gt, pred, average='macro')
prec = f1_grid[0]
rec = f1_grid[1]
f1 = f1_grid[2]

print('Accuracy: ', AC)
print('precision: ', prec)
print('recall: ', rec)
print('fscore: ', f1)

In [None]:
testload = validation_sampler(
    zip(X_test, y_test),
    [0,0,0,1,1,1,0,0,1,1,0,0,0,0,1],
    #[random.randint(0,1) for _ in range(20)]
    )
test_loader = DataLoader(testload, batch_size=1, shuffle=False)

In [None]:
class0 = [x for x, y in validationload if y == 0][0]
class1 = [x for x, y in validationload if y == 1][0]

gt = []
pred = []
with torch.no_grad():
    net.eval()
    for data in test_loader:
        output1, output2 = net(class0, data.img)
        class0_euc_score = F.pairwise_distance(output1, output2).item()
        output1, output2 = net(class1, data.img)
        class1_euc_score = F.pairwise_distance(output1, output2).item()

        print(class0_euc_score, class1_euc_score)
        if class0_euc_score < class1_euc_score:
            pred.append(0)
        elif class0_euc_score > class1_euc_score:
            pred.append(1)

        gt.append(data.y[0])
        
        
AC = accuracy_score(gt, pred)
auc = roc_auc_score(gt, pred)
f1_grid = precision_recall_fscore_support(gt, pred, average='macro')
prec = f1_grid[0]
rec = f1_grid[1]
f1 = f1_grid[2]

print('Accuracy: ', AC)
print('AUC', auc)
print('precision: ', prec)
print('recall: ', rec)
print('fscore: ', f1)
print("Pred: ", pred)
print("Pred: ", gt)