In [None]:
# Environment setup (Colab-style). NOTE(review): pip is upgraded and then
# immediately pinned back to 20.2.4, so the first command has no lasting effect.
!python -m pip install --upgrade pip
!python -m pip install pip==20.2.4
# PyG extension wheels from the torch-1.8.0+cu101 index; presumably the runtime's
# torch must be 1.8.0 built for CUDA 10.1 for these to load — verify before running.
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
# NOTE(review): torch-geometric itself is unpinned; pin a version compatible with
# the wheels above to keep the notebook reproducible.
!pip install torch-geometric

In [None]:
import numpy as np
import networkx as nx
import os
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding
from torch.nn import Parameter
from torch_geometric.data import Data,DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.utils.convert import to_networkx
from torch_geometric.utils import to_undirected

In [None]:
os.environ['KAGGLE_USERNAME'] = "karthikapv" # username from the json file
os.environ['KAGGLE_KEY'] = "cc11b8fcbb2e177d31cd566bbabe382a" # key from the json file
!kaggle datasets download -d ellipticco/elliptic-data-set
!unzip elliptic-data-set.zip
!mkdir elliptic_bitcoin_dataset_cont

In [None]:
import time
import tracemalloc


def get_memory_and_execution_time_details(func, is_teacher):
    """
    Run ``func(teacher=is_teacher)`` while measuring elapsed time and the
    memory traced by ``tracemalloc``.

    NOTE(review): tracemalloc only sees allocations made through Python's
    allocators; tensor storage allocated natively by torch is likely not
    included in these numbers.

    :param func: callable accepting a ``teacher`` keyword argument
    :param is_teacher (bool): forwarded to ``func`` as ``teacher``
    :return: (current, peak, exec_time) — traced bytes when ``func`` returned,
        peak traced bytes during the call, and elapsed seconds
    :raises: whatever ``func`` raises; tracing is stopped in every case
    """
    tracemalloc.start()
    try:
        start_time = time.perf_counter()  # monotonic, high-resolution interval clock
        func(teacher=is_teacher)
        exec_time = time.perf_counter() - start_time
        # Snapshot before printing so the report's own allocations are not counted.
        current, peak = tracemalloc.get_traced_memory()
    finally:
        # Without this, an exception in func would leave tracing on (slowing
        # every later cell) and make the next tracemalloc.start() a no-op.
        tracemalloc.stop()

    print("Model Evaluation Time: ")
    print(exec_time)
    print(f"Current memory usage is {current / 10 ** 3}KB; Peak was {peak / 10 ** 3}KB")

    return current, peak, exec_time

In [None]:
import torch
from torch.nn import Parameter
from torch_geometric.nn import GCNConv
import torch.nn.functional as F


class GCN(torch.nn.Module):
    """
    Two-layer graph convolutional network producing 2-class node probabilities.

    :param num_node_features (int): size of each input node feature vector
    :param hidden_channels (list[int]): hidden sizes; only hidden_channels[0] is used
    :param use_skip (bool): if True, add a learned linear projection of the raw
        input node features to the second convolution's output before softmax
    """

    def __init__(self, num_node_features, hidden_channels, use_skip=False):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels[0])
        self.conv2 = GCNConv(hidden_channels[0], 2)
        self.use_skip = use_skip
        if self.use_skip:
            # Shape (num_node_features, 2): maps raw input features onto the 2 classes.
            self.weight = torch.nn.init.xavier_normal_(Parameter(torch.Tensor(num_node_features, 2)))

    def forward(self, data):
        """
        Return per-node class probabilities (softmax over the last dim, 2 classes).

        NOTE(review): the output is already a softmax, yet the training code
        feeds it to CrossEntropyLoss (which applies log_softmax again); consider
        returning logits if that double normalisation is unintended.
        """
        x = self.conv1(data.x, data.edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, data.edge_index)
        if self.use_skip:
            # The skip connection must project the *input* features: self.weight
            # has num_node_features rows, while x now has only 2 columns.
            x = x + torch.matmul(data.x, self.weight)
        return F.softmax(x, dim=-1)

    def embed(self, data):
        """Return the first-layer GCN embedding (before ReLU/dropout)."""
        x = self.conv1(data.x, data.edge_index)
        return x

In [None]:
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.utils.tensorboard import SummaryWriter


class BaseClass:
    """
    Basic implementation of a general Knowledge Distillation framework.

    Each batch yielded by the loaders is handled as a PyG graph exposing ``x``,
    ``edge_index``, ``y`` and a boolean ``mask`` attribute; losses, accuracy and
    metrics are computed only on the nodes selected by ``mask``.

    :param teacher_model (torch.nn.Module): Teacher model (may be None)
    :param student_model (torch.nn.Module): Student model
    :param train_loader (torch.utils.data.DataLoader): Dataloader for training
    :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
    :param optimizer_student (torch.optim.*): Optimizer used for training student
    :param loss_fn (torch.nn.Module): Loss Function used for distillation
    :param temp (float): Temperature parameter for distillation
    :param distil_weight (float): Weight parameter for distillation loss
    :param device (str): Device used for training; 'cpu' for cpu and 'cuda' for gpu
    :param log (bool): True if logging required
    :param logdir (str): Directory for storing logs
    """

    # Label treated as the positive class by the "Illicit" metrics. In the
    # Elliptic dataset class '1' is illicit and class '2' is licit, and
    # train_test_split stores ``class - 1``, so illicit transactions are 0.
    illicit_label = 0

    # Keys of the per-batch metric dictionaries.
    _METRIC_NAMES = (
        "illicit_f1",
        "micro_f1",
        "illicit_precision",
        "micro_precision",
        "illicit_recall",
        "micro_recall",
    )

    def __init__(
            self,
            teacher_model,
            student_model,
            train_loader,
            val_loader,
            optimizer_teacher,
            optimizer_student,
            loss_fn=nn.KLDivLoss(),
            temp=20.0,
            distil_weight=0.5,
            device="cpu",
            log=False,
            logdir="./Experiments",
    ):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer_teacher = optimizer_teacher
        self.optimizer_student = optimizer_student
        self.temp = temp
        self.distil_weight = distil_weight
        self.log = log
        self.logdir = logdir

        if self.log:
            self.writer = SummaryWriter(logdir)

        # Probe the requested device: .to() raises for an unknown device string
        # or when CUDA is requested but unavailable.
        try:
            torch.Tensor(0).to(device)
            self.device = device
        except (RuntimeError, AssertionError, TypeError):
            print(
                "Either an invalid device or CUDA is not available. Defaulting to CPU."
            )
            self.device = torch.device("cpu")

        if teacher_model is None:
            print("Warning!!! Teacher is NONE.")
            self.teacher_model = None
        else:
            self.teacher_model = teacher_model.to(self.device)
        self.student_model = student_model.to(self.device)
        try:
            self.loss_fn = loss_fn.to(self.device)
            self.ce_fn = nn.CrossEntropyLoss().to(self.device)
        except AttributeError:
            # loss_fn is a plain callable rather than an nn.Module.
            self.loss_fn = loss_fn
            self.ce_fn = nn.CrossEntropyLoss()
            print("Warning: Loss Function can't be moved to device.")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _seed_everything(seed):
        """Seed torch and numpy and force deterministic cuDNN kernels."""
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory containing ``path`` (no-op for a bare filename)."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    @staticmethod
    def _node_mask(data, device):
        """Return ``data.mask`` as a flat boolean tensor on ``device``."""
        mask = data.mask
        if isinstance(mask, (list, tuple)):
            # Non-tensor attributes (e.g. a pandas Series) are collated into a list.
            mask = np.concatenate([np.asarray(m, dtype=bool).reshape(-1) for m in mask])
        if not torch.is_tensor(mask):
            mask = torch.as_tensor(np.asarray(mask, dtype=bool))
        return mask.to(device=device, dtype=torch.bool).reshape(-1)

    def _prepare_batch(self, data):
        """Move every tensor of a batch to ``self.device``; return (data, labels, mask)."""
        mask = self._node_mask(data, self.device)
        data = data.to(self.device)  # moves x, edge_index and y together
        return data, data.y, mask

    @classmethod
    def _new_metric_store(cls):
        """Empty per-epoch accumulator: one list per metric name."""
        return {name: [] for name in cls._METRIC_NAMES}

    def _masked_metrics(self, pred, label, mask):
        """Illicit-class and micro-averaged F1/precision/recall on the masked nodes."""
        # sklearn's signature is (y_true, y_pred); the order matters for P/R.
        y_true = label[mask].cpu().numpy()
        y_pred = pred[mask].cpu().numpy()
        pos = self.illicit_label
        return {
            "illicit_f1": f1_score(y_true, y_pred, pos_label=pos, zero_division=0),
            "micro_f1": f1_score(y_true, y_pred, average="micro"),
            "illicit_precision": precision_score(y_true, y_pred, pos_label=pos, zero_division=0),
            "micro_precision": precision_score(y_true, y_pred, average="micro"),
            "illicit_recall": recall_score(y_true, y_pred, pos_label=pos, zero_division=0),
            "micro_recall": recall_score(y_true, y_pred, average="micro"),
        }

    def _append_metrics(self, store, pred, label, mask):
        """Compute one batch's metrics, append them to ``store`` and return them."""
        batch_metrics = self._masked_metrics(pred, label, mask)
        for name, value in batch_metrics.items():
            store[name].append(value)
        return batch_metrics

    @staticmethod
    def _print_epoch(ep, epoch_loss, store):
        """Print the loss and the mean of each metric over this epoch's batches."""
        mean = {k: (float(np.mean(v)) if v else float("nan")) for k, v in store.items()}
        print(
            'Epoch: {:1d}, Epoch Loss: {:.4f}, Illicit Precision: {:.4f}, Illicit Recall: '
            '{:.4f}, Illicit f1: {:.4f}, F1: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'
            .format(ep + 1, epoch_loss, mean["illicit_precision"],
                    mean["illicit_recall"], mean["illicit_f1"], mean["micro_f1"],
                    mean["micro_precision"], mean["micro_recall"]))

    # ----------------------------------------------------------------- training

    def train_teacher(
            self,
            epochs=10,
            plot_losses=True,
            save_model=True,
            save_model_pth="./models/teacher.pt",
    ):
        """
        Function that will be training the teacher.
        The weights with the best training accuracy (on masked nodes) are restored at the end.
        :param epochs (int): Number of epochs you want to train the teacher
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the teacher model
        :param save_model_pth (str): Path where you want to store the teacher model
        """
        self.teacher_model.train()
        loss_arr = []
        best_acc = 0.0
        self.best_teacher_model_weights = deepcopy(self.teacher_model.state_dict())

        if save_model:
            self._ensure_parent_dir(save_model_pth)

        print("Training Teacher... ")

        for ep in range(epochs):
            epoch_loss = 0.0
            correct = 0
            n_labelled = 0
            metrics = self._new_metric_store()  # reset so prints are per-epoch
            self._seed_everything(ep)
            for data in self.train_loader:
                data, label, mask = self._prepare_batch(data)

                out = self.teacher_model(data)
                if isinstance(out, tuple):
                    out = out[0]

                loss = self.ce_fn(out[mask], label[mask])

                self.optimizer_teacher.zero_grad()
                loss.backward()
                self.optimizer_teacher.step()

                pred = out.detach().argmax(dim=1)
                correct += int(pred[mask].eq(label[mask]).sum().item())
                n_labelled += int(mask.sum().item())
                self._append_metrics(metrics, pred, label, mask)

                # .item() detaches: accumulating the tensor would keep every
                # batch's autograd graph alive and break plt.plot below.
                epoch_loss += loss.item()

            epoch_acc = correct / n_labelled if n_labelled else 0.0
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                self.best_teacher_model_weights = deepcopy(
                    self.teacher_model.state_dict()
                )

            if self.log:
                self.writer.add_scalar("Training loss/Teacher", epoch_loss, ep)
                self.writer.add_scalar("Training accuracy/Teacher", epoch_acc, ep)

            loss_arr.append(epoch_loss)
            self._print_epoch(ep, epoch_loss, metrics)

            self.post_epoch_call(ep)

        self.teacher_model.load_state_dict(self.best_teacher_model_weights)
        if save_model:
            torch.save(self.teacher_model.state_dict(), save_model_pth)
        if plot_losses:
            plt.plot(loss_arr)

    def _train_student(
            self,
            epochs=10,
            plot_losses=True,
            save_model=True,
            save_model_pth="./models/student.pt",
    ):
        """
        Function to train student model - for internal use only.
        The teacher is kept in eval mode and its forward pass runs without gradients.
        :param epochs (int): Number of epochs you want to train the student
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the student model
        :param save_model_pth (str): Path where you want to save the student model
        """
        self.teacher_model.eval()
        self.student_model.train()
        loss_arr = []
        best_acc = 0.0
        self.best_student_model_weights = deepcopy(self.student_model.state_dict())

        if save_model:
            self._ensure_parent_dir(save_model_pth)

        print("Training Student...")

        for ep in range(epochs):
            epoch_loss = 0.0
            correct = 0
            n_labelled = 0
            metrics = self._new_metric_store()
            self._seed_everything(ep)
            for data in self.train_loader:
                data, label, mask = self._prepare_batch(data)

                student_out = self.student_model(data)
                if isinstance(student_out, tuple):
                    student_out = student_out[0]
                # The teacher is frozen: no autograd graph is needed for its outputs.
                with torch.no_grad():
                    teacher_out = self.teacher_model(data)
                if isinstance(teacher_out, tuple):
                    teacher_out = teacher_out[0]

                loss = self.calculate_kd_loss(student_out[mask], teacher_out[mask], label[mask])

                self.optimizer_student.zero_grad()
                loss.backward()
                self.optimizer_student.step()

                pred = student_out.detach().argmax(dim=1)
                correct += int(pred[mask].eq(label[mask]).sum().item())
                n_labelled += int(mask.sum().item())
                self._append_metrics(metrics, pred, label, mask)

                epoch_loss += loss.item()

            epoch_acc = correct / n_labelled if n_labelled else 0.0
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                self.best_student_model_weights = deepcopy(
                    self.student_model.state_dict()
                )

            if self.log:
                self.writer.add_scalar("Training loss/Student", epoch_loss, ep)
                self.writer.add_scalar("Training accuracy/Student", epoch_acc, ep)

            loss_arr.append(epoch_loss)
            self._print_epoch(ep, epoch_loss, metrics)

        self.student_model.load_state_dict(self.best_student_model_weights)
        if save_model:
            torch.save(self.student_model.state_dict(), save_model_pth)
        if plot_losses:
            plt.plot(loss_arr)

    def train_student(
            self,
            epochs=10,
            plot_losses=True,
            save_model=True,
            save_model_pth="./models/student.pt",
    ):
        """
        Function that will be training the student
        :param epochs (int): Number of epochs you want to train the student
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the student model
        :param save_model_pth (str): Path where you want to save the student model
        """
        self._train_student(epochs, plot_losses, save_model, save_model_pth)

    def calculate_kd_loss(self, y_pred_student, y_pred_teacher, y_true):
        """
        Custom loss function to calculate the KD loss for various implementations
        :param y_pred_student (Tensor): Predicted outputs from the student network
        :param y_pred_teacher (Tensor): Predicted outputs from the teacher network
        :param y_true (Tensor): True labels
        """

        raise NotImplementedError

    # --------------------------------------------------------------- evaluation

    def _evaluate_model(self, model, verbose=False):
        """
        Evaluate the given model over the validation loader.
        For internal use only.
        :param model (nn.Module): Model to be used for evaluation
        :param verbose (bool): Display per-snapshot metrics
        :return: (list of per-batch outputs, accuracy over labelled nodes)
        """
        model.eval()
        correct = 0
        n_labelled = 0
        outputs = []
        metrics = self._new_metric_store()

        seed_val = 35

        with torch.no_grad():
            for data in self.val_loader:
                self._seed_everything(seed_val)

                data, target, mask = self._prepare_batch(data)

                output = model(data)
                if isinstance(output, tuple):
                    output = output[0]
                outputs.append(output)

                pred = output.argmax(dim=1)
                correct += int(pred[mask].eq(target[mask]).sum().item())
                n_labelled += int(mask.sum().item())
                batch_metrics = self._append_metrics(metrics, pred, target, mask)

                if verbose:
                    print("-" * 80)
                    print(f"Iteration: {seed_val-34}")
                    print("-" * 80)
                    print("Illicit F1: {:.4f}".format(batch_metrics["illicit_f1"]))
                    print("Illicit Precision: {:.4f}".format(batch_metrics["illicit_precision"]))
                    print("Illicit Recall: {:.4f}".format(batch_metrics["illicit_recall"]))
                    print("Micro Avg F1: {:.4f}".format(batch_metrics["micro_f1"]))
                    print("Micro Avg Precision: {:.4f}".format(batch_metrics["micro_precision"]))
                    print("Micro Avg Recall: {:.4f}".format(batch_metrics["micro_recall"]))

                seed_val += 1

        # Defined even for an empty loader (previously an UnboundLocalError).
        accuracy = correct / n_labelled if n_labelled else 0.0

        print("-" * 80)
        print("-" * 80)
        print("Final Result")
        print("-" * 80)
        print("-" * 80)

        print("Illicit F1: {:.4f}".format(np.mean(metrics["illicit_f1"])))
        print("Illicit Precision: {:.4f}".format(np.mean(metrics["illicit_precision"])))
        print("Illicit Recall: {:.4f}".format(np.mean(metrics["illicit_recall"])))
        print("Micro Avg F1: {:.4f}".format(np.mean(metrics["micro_f1"])))
        print("Micro Avg Precision: {:.4f}".format(np.mean(metrics["micro_precision"])))
        print("Micro Avg Recall: {:.4f}".format(np.mean(metrics["micro_recall"])))
        return outputs, accuracy

    def evaluate(self, teacher=False):
        """
        Evaluate method for printing accuracies of the trained network
        :param teacher (bool): True if you want accuracy of the teacher network
        :return: accuracy over the labelled nodes of the validation loader
        """
        if teacher:
            model = deepcopy(self.teacher_model).to(self.device)
        else:
            model = deepcopy(self.student_model).to(self.device)
        _, accuracy = self._evaluate_model(model=model, verbose=False)

        return accuracy

    def get_parameters(self):
        """
        Get the number of parameters for the teacher and the student network
        """
        teacher_params = (
            sum(p.numel() for p in self.teacher_model.parameters())
            if self.teacher_model is not None else 0
        )
        student_params = sum(p.numel() for p in self.student_model.parameters())

        print("-" * 80)
        print(f"Total parameters for the teacher network are: {teacher_params}")
        print(f"Total parameters for the student network are: {student_params}")

    def post_epoch_call(self, epoch):
        """
        Any changes to be made after an epoch is completed.
        :param epoch (int) : current epoch number
        :return            : nothing (void)
        """

        pass

In [None]:
import torch.nn as nn
import torch.nn.functional as F



class VanillaKD(BaseClass):
    """
    Vanilla knowledge distillation as described in Hinton et al.,
    "Distilling the Knowledge in a Neural Network" (https://arxiv.org/pdf/1503.02531.pdf).

    The student loss mixes the hard-label cross entropy with ``loss_fn`` applied to
    temperature-softened teacher/student distributions, weighted by ``distil_weight``.

    :param teacher_model (torch.nn.Module): Teacher model
    :param student_model (torch.nn.Module): Student model
    :param train_loader (torch.utils.data.DataLoader): Dataloader for training
    :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
    :param optimizer_student (torch.optim.*): Optimizer used for training student
    :param loss_fn (torch.nn.Module): Compares the softened distributions (MSE by default)
    :param temp (float): Temperature parameter for distillation
    :param distil_weight (float): Weight of the distillation term
    :param device (str): Device used for training; 'cpu' for cpu and 'cuda' for gpu
    :param log (bool): True if logging required
    :param logdir (str): Directory for storing logs
    """

    def __init__(self, teacher_model, student_model, train_loader, val_loader,
                 optimizer_teacher, optimizer_student, loss_fn=nn.MSELoss(),
                 temp=20.0, distil_weight=0.5, device="cpu", log=False,
                 logdir="./Experiments"):
        super().__init__(
            teacher_model, student_model, train_loader, val_loader,
            optimizer_teacher, optimizer_student, loss_fn, temp,
            distil_weight, device, log, logdir,
        )

    def calculate_kd_loss(self, y_pred_student, y_pred_teacher, y_true):
        """
        Weighted sum of the hard-label loss and the temperature-scaled soft loss.

        :param y_pred_student (torch.FloatTensor): Prediction made by the student model
        :param y_pred_teacher (torch.FloatTensor): Prediction made by the teacher model
        :param y_true (torch.LongTensor): Original labels
        :return: scalar loss tensor
        """
        t = self.temp
        alpha = self.distil_weight

        teacher_soft = F.softmax(y_pred_teacher / t, dim=1)
        student_soft = F.softmax(y_pred_student / t, dim=1)

        hard_term = (1 - alpha) * F.cross_entropy(y_pred_student, y_true)
        # The t**2 factor keeps the soft-target gradients on the same scale as the hard term.
        soft_term = (alpha * t * t) * self.loss_fn(teacher_soft, student_soft)
        return hard_term + soft_term

In [None]:
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import to_undirected


def train_test_split(data_dir='elliptic_bitcoin_dataset',
                     out_dir='elliptic_bitcoin_dataset_cont',
                     train_end_step=34,
                     num_time_steps=49):
    """
    Build one PyG graph per time step from the Elliptic Bitcoin CSV files.

    Raw transaction ids are re-indexed to contiguous integers. The re-indexed
    tables are written as header-less CSVs into ``out_dir``. Each snapshot
    carries ``x`` (node features), ``edge_index`` (undirected), ``y`` (``class - 1``,
    'unknown' -> 2) and a boolean ``mask`` selecting the labelled nodes.

    :param data_dir (str): directory with elliptic_txs_{edgelist,classes,features}.csv
    :param out_dir (str): directory for the re-indexed CSV exports (created if missing)
    :param train_end_step (int): snapshots with time step < this go to training,
        the rest to testing
    :param num_time_steps (int): snapshots are built for time steps 1..num_time_steps
    :return: (train_loader, test_loader, num_node_features)
    """
    df_edge = pd.read_csv(os.path.join(data_dir, 'elliptic_txs_edgelist.csv'))
    df_class = pd.read_csv(os.path.join(data_dir, 'elliptic_txs_classes.csv'))
    df_features = pd.read_csv(os.path.join(data_dir, 'elliptic_txs_features.csv'), header=None)

    # Setting column names: id, time step, 93 transaction + 72 aggregated features
    df_features.columns = (['id', 'time step']
                           + [f'trans_feat_{i}' for i in range(93)]
                           + [f'agg_feat_{i}' for i in range(72)])

    print('Number of edges: {}'.format(len(df_edge)))

    # Map every raw transaction id to a contiguous integer index
    all_nodes = list(
        set(df_edge['txId1']).union(set(df_edge['txId2'])).union(set(df_class['txId'])).union(set(df_features['id'])))
    nodes_df = pd.DataFrame(all_nodes, columns=['id']).reset_index()

    print('Number of nodes: {}'.format(len(nodes_df)))

    # Replace raw ids by their contiguous index in every table
    df_edge = df_edge.join(nodes_df.rename(columns={'id': 'txId1'}).set_index('txId1'), on='txId1', how='inner') \
        .join(nodes_df.rename(columns={'id': 'txId2'}).set_index('txId2'), on='txId2', how='inner', rsuffix='2') \
        .drop(columns=['txId1', 'txId2']) \
        .rename(columns={'index': 'txId1', 'index2': 'txId2'})

    df_class = df_class.join(nodes_df.rename(columns={'id': 'txId'}).set_index('txId'), on='txId', how='inner') \
        .drop(columns=['txId']).rename(columns={'index': 'txId'})[['txId', 'class']]

    df_features = df_features.join(nodes_df.set_index('id'), on='id', how='inner') \
        .drop(columns=['id']).rename(columns={'index': 'id'})
    df_features = df_features[['id'] + list(df_features.drop(columns=['id']).columns)]

    # Each edge takes the time step of its source transaction
    df_edge_time = df_edge.join(df_features[['id', 'time step']].rename(columns={'id': 'txId1'}).set_index('txId1'),
                                on='txId1', how='left', rsuffix='1')
    df_edge_time_fin = df_edge_time[['txId1', 'txId2', 'time step']].rename(
        columns={'txId1': 'source', 'txId2': 'target', 'time step': 'time'})

    # Export the re-indexed tables as header-less CSVs
    os.makedirs(out_dir, exist_ok=True)
    df_features.drop(columns=['time step']).to_csv(
        os.path.join(out_dir, 'elliptic_txs_features.csv'), index=False, header=None)
    df_class.rename(columns={'txId': 'nid', 'class': 'label'})[['nid', 'label']].sort_values(by='nid').to_csv(
        os.path.join(out_dir, 'elliptic_txs_classes.csv'), index=False, header=None)
    df_features[['id', 'time step']].rename(columns={'id': 'nid', 'time step': 'time'})[['nid', 'time']].sort_values(
        by='nid').to_csv(os.path.join(out_dir, 'elliptic_txs_nodetime.csv'), index=False, header=None)
    df_edge_time_fin[['source', 'target', 'time']].to_csv(
        os.path.join(out_dir, 'elliptic_txs_edgelist_timed.csv'), index=False, header=None)

    # Graph preprocessing: labels become class - 1, with 'unknown' mapped to 2
    node_label = df_class.rename(columns={'txId': 'nid', 'class': 'label'})[['nid', 'label']].sort_values(by='nid').merge(
        df_features[['id', 'time step']].rename(columns={'id': 'nid', 'time step': 'time'}), on='nid', how='left')
    node_label['label'] = node_label['label'].apply(lambda x: '3' if x == 'unknown' else x).astype(int) - 1

    merged_nodes_df = node_label.merge(
        df_features.rename(columns={'id': 'nid', 'time step': 'time'}).drop(columns=['time']), on='nid', how='left')

    train_dataset = []
    test_dataset = []

    num_node_features = 0
    for step in range(1, num_time_steps + 1):
        nodes_df_tmp = merged_nodes_df[merged_nodes_df['time'] == step].reset_index()
        nodes_df_tmp['index'] = nodes_df_tmp.index  # local 0..n-1 index within this snapshot
        df_edge_tmp = df_edge_time_fin.join(
            nodes_df_tmp.rename(columns={'nid': 'source'})[['source', 'index']].set_index('source'), on='source',
            how='inner') \
            .join(nodes_df_tmp.rename(columns={'nid': 'target'})[['target', 'index']].set_index('target'), on='target',
                  how='inner', rsuffix='2') \
            .drop(columns=['source', 'target']) \
            .rename(columns={'index': 'source', 'index2': 'target'})
        # Remaining columns are 'time' plus the feature columns.
        x = torch.tensor(np.array(nodes_df_tmp.sort_values(by='index').drop(columns=['index', 'nid', 'label'])),
                         dtype=torch.float)
        edge_index = torch.tensor(np.array(df_edge_tmp[['source', 'target']]).T, dtype=torch.long)
        edge_index = to_undirected(edge_index)
        # A bool tensor (not a pandas Series) so PyG collates it and moves it with .to(device).
        mask = torch.tensor(np.array(nodes_df_tmp['label'] != 2), dtype=torch.bool)
        y = torch.tensor(np.array(nodes_df_tmp['label']), dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, mask=mask, y=y)
        num_node_features = data.num_node_features
        if step < train_end_step:
            train_dataset.append(data)
        else:
            test_dataset.append(data)

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    return train_loader, test_loader, num_node_features

In [None]:
import torch.nn.functional as F

class Network(nn.Module):
    """
    MLP student: Linear -> ReLU -> Linear -> Sigmoid on node features only.

    The graph structure (edge_index) is ignored. Defaults reproduce the original
    hard-coded sizes, so existing checkpoints keep loading.

    :param num_features (int): input feature size per node (default 166)
    :param hidden_channels (int): hidden layer width (default 50)
    :param num_classes (int): number of output scores per node (default 2)
    """

    def __init__(self, num_features=166, hidden_channels=50, num_classes=2):
        super().__init__()
        # Attribute names are kept so state_dict keys stay compatible.
        self.hidden = nn.Linear(num_features, hidden_channels)
        self.output = nn.Linear(hidden_channels, num_classes)
        # NOTE(review): per-class sigmoid scores are later fed to CrossEntropyLoss,
        # which expects raw logits; confirm this squashing is intended.
        self.out = nn.Sigmoid()

    def forward(self, data):
        """Return per-node scores in [0, 1] of shape (num_nodes, num_classes)."""
        hidden = F.relu(self.hidden(data.x))
        return self.out(self.output(hidden))

In [None]:
import time
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
# Build per-time-step graph snapshots (earlier steps for training, later steps for testing)
# and record the per-node feature size for the teacher GCN.
train_loader, test_loader, num_node_features = train_test_split()



In [None]:
# Hyper-parameters shared by teacher and student.
SEED = 42
lr = 0.00001
weight_decay = 0.00005
epochs = 10

# Seed before building the models so the random weight initialisation is reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

teacher_model = GCN(num_node_features=num_node_features, hidden_channels=[100])
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=True)

student_model = Network()
student_optimizer = optim.Adam(student_model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=True)


In [None]:
# Vanilla KD: train the GCN teacher first, then distil it into the MLP student.
distiller = VanillaKD(
    teacher_model=teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer_teacher=teacher_optimizer,
    optimizer_student=student_optimizer,
)

distiller.train_teacher(
    epochs=epochs,
    plot_losses=True,
    save_model=True,
    save_model_pth='./models/teacher.pt',
)
distiller.train_student(
    epochs=epochs,
    plot_losses=True,
    save_model=True,
    save_model_pth='./models/student.pt',
)

In [None]:
# Profile evaluation time/memory for the teacher, then the student.
for is_teacher in (True, False):
    get_memory_and_execution_time_details(func=distiller.evaluate, is_teacher=is_teacher)

distiller.get_parameters()