#### Dataset Creation

In [1]:
import random
import torch
import numpy as np

# Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators with
# the same value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [2]:
import json
import os

# Root folder of the raw datasets and the individual dump files below it.
datasets_dir = 'datasets'
ecore_json_path = os.path.join(datasets_dir, 'ecore_555/ecore_555.jsonl')
mar_json_path = os.path.join(datasets_dir, 'mar-ecore-github/ecore-github.jsonl')
modelsets_uml_json_path = os.path.join(datasets_dir, 'modelset/uml.jsonl')
modelsets_ecore_json_path = os.path.join(datasets_dir, 'modelset/ecore.jsonl')

In [3]:
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
import fasttext
from scipy.sparse import csr_matrix

from re import finditer


SEP = ' '
def camel_case_split(identifier):
    """Split a camelCase / PascalCase identifier into its word pieces.

    A boundary is placed between a lowercase letter and a following uppercase
    letter, and before the last capital of an uppercase run that is followed
    by a lowercase letter. An empty string yields an empty list.
    """
    pattern = '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)'
    pieces = []
    for match in finditer(pattern, identifier):
        pieces.append(match.group(0))
    return pieces


def doc_tokenizer(doc):
    """Tokenize a document into lowercase word pieces.

    The text is split on whitespace, each chunk is split on underscores
    (empty parts dropped), and each part is split on camelCase boundaries.
    """
    tokens = []
    for chunk in doc.split():
        for part in chunk.split('_'):
            if part == '':
                continue
            for piece in camel_case_split(part):
                if piece != '':
                    tokens.append(piece.lower())
    return tokens


class TFIDFEncoder:
    """TF-IDF encoder that tokenizes documents with ``doc_tokenizer``.

    Terms occurring in fewer than three documents are ignored (``min_df=3``).
    Lowercasing is left to the tokenizer.
    """

    def __init__(self, X=None):
        """Create the vectorizer; if ``X`` is given and non-empty, fit on it."""
        self.encoder = TfidfVectorizer(
            lowercase=False, tokenizer=doc_tokenizer, min_df=3
        )

        # Explicit None/length test so NumPy arrays work too (their truth
        # value is ambiguous), while an empty input is still skipped.
        if X is not None and len(X) > 0:
            self.encode(X)

    def encode(self, X):
        """Fit the vectorizer on ``X`` and return the TF-IDF matrix as CSR.

        NOTE: every call re-fits the vocabulary and IDF weights on ``X``.
        """
        X_t = self.encoder.fit_transform(X)
        # The result is already sparse; converting directly avoids building a
        # dense (n_documents x vocabulary) array just to sparsify it again.
        return csr_matrix(X_t)


class BertTokenizerEncoder:
    """Thin wrapper around a Hugging Face tokenizer loaded by name."""

    def __init__(self, name, X=None):
        """Load the tokenizer ``name``; optionally tokenize ``X`` once."""
        self.tokenizer = AutoTokenizer.from_pretrained(name)

        if X:
            self.encode(X)

    def encode(self, X, batch_encode=False, percentile=100, max_size=512):
        """Tokenize ``X``.

        Args:
            X: text or list of texts handed to the tokenizer.
            batch_encode: when True, pad and truncate all sequences to a
                common length; otherwise return the raw tokenizer output.
            percentile: percentile of the token lengths used as the target
                length (100 means the longest sequence).
            max_size: upper bound on the target length.

        Returns:
            The tokenizer output for ``X``.
        """
        if not batch_encode:
            return self.tokenizer(X)

        # First pass without padding, only to measure the sequence lengths.
        lengths = [len(ids) for ids in self.tokenizer(X)['input_ids']]
        size = int(np.percentile(lengths, percentile)) if percentile < 100 else max(lengths)
        if size > max_size:
            print(f'WARNING: Max size is {size}. Truncating to {max_size}')
        # Cap the length (the previous max() left oversized lengths in place
        # and padded short batches up to 512).
        size = min(size, max_size)

        return self.tokenizer(
            X,
            padding=True,
            truncation=True,
            max_length=size
        )


class BertTFIDF:
    """TF-IDF computed over BERT token ids instead of raw words."""

    def __init__(self, name, X=None):
        """Build the tokenizer ``name`` and a fresh TF-IDF encoder."""
        self.bert = BertTokenizerEncoder(name)
        self.tfidf = TFIDFEncoder()

        if X:
            self.encode(X)

    def encode(self, X):
        """Tokenize ``X``, render each id sequence as text, TF-IDF encode it."""
        token_id_rows = self.bert.encode(X)['input_ids']
        id_documents = []
        for row in token_id_rows:
            # One document per text: its token ids joined by SEP.
            id_documents.append(f"{SEP}".join(map(str, row)))
        return self.tfidf.encode(id_documents)


class FasttextEncoder:
    """Sentence-vector encoder backed by a saved fastText model."""

    def __init__(self, model_name, X=None):
        """Load the fastText model stored at ``model_name``."""
        self.model = fasttext.load_model(model_name)
        if X:
            self.encode(X)

    def encode(self, X):
        """Return an array with one fastText sentence vector per document.

        Each document is normalised with ``doc_tokenizer`` and re-joined with
        single spaces before it is embedded.
        """
        normalised = []
        for doc in X:
            normalised.append(" ".join(doc_tokenizer(doc)))

        vectors = []
        for sentence in normalised:
            vectors.append(self.model.get_sentence_vector(sentence))
        return np.array(vectors)


class ClassLabelEncoder(LabelEncoder):
    """``LabelEncoder`` with an ``encode`` method matching the text encoders."""

    def __init__(self, y=None) -> None:
        """Optionally fit on ``y`` right away (skipped when None or empty)."""
        super().__init__()
        # Explicit None/length test: the truth value of a NumPy array with
        # more than one element is ambiguous and would raise.
        if y is not None and len(y) > 0:
            self.fit(y)

    def encode(self, y):
        """Fit on ``y`` and return the integer-encoded labels."""
        return self.fit_transform(y)

In [4]:
from typing import List
import networkx as nx
from tqdm.auto import tqdm
import pickle
import numpy as np
from random import shuffle
from sklearn.model_selection import StratifiedKFold
import random
import torch

from abc import abstractmethod

# Separator placed between node texts that belong to the same distance bucket.
SEP = ' '


# Re-seed all random generators for this cell.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Values of a node's 'eClass' attribute.
EGenericType = 'EGenericType'
EPackage = 'EPackage'
EClass = 'EClass'
EAttribute = 'EAttribute'
EReference = 'EReference'
EEnum = 'EEnum'
EEnumLiteral = 'EEnumLiteral'
EOperation = 'EOperation'
EParameter = 'EParameter'
EDataType = 'EDataType'
# Node kinds that EcoreNxG drops when remove_generic_nodes is set.
GenericNodes = [EGenericType, EPackage]




class LangGraph(nx.DiGraph):
    """Directed graph base class for modelling-language graphs.

    Subclasses provide graph construction and the textual rendering of nodes.
    NOTE(review): the ``@abstractmethod`` markers document intent only; the
    class does not derive from ``abc.ABC``, so they appear not to be enforced.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def create_graph(self):
        """Populate the graph with nodes and edges."""
        pass

    @abstractmethod
    def get_graph_node_text(self, node):
        """Return the text representing a single node."""
        pass

    @abstractmethod
    def find_node_str_upto_distance(self, node, distance=1):
        """Return the text of ``node`` and its neighbourhood up to ``distance``."""
        pass

    @abstractmethod
    def get_node_texts(self, distance=1):
        """Return one neighbourhood text per node of the graph."""
        pass

    def find_nodes_within_distance(self, n, distance=1):
        """Breadth-first search along outgoing edges starting at ``n``.

        Args:
            n: start node.
            distance: maximum number of hops to follow.

        Returns:
            A list of ``(node, hops)`` pairs sorted by hop count; the start
            node is included with distance 0.
        """
        from collections import deque

        visited = {n: 0}
        # deque gives O(1) pops from the front (list.pop(0) is O(n)).
        queue = deque([(n, 0)])

        while queue:
            node, d = queue.popleft()
            if d == distance:
                # Nodes at the distance limit are not expanded further.
                continue
            for neighbor in self.neighbors(node):
                if neighbor not in visited:
                    visited[neighbor] = d + 1
                    queue.append((neighbor, d + 1))

        return sorted(visited.items(), key=lambda x: x[1])


class EcoreNxG(LangGraph):
    """Graph of one Ecore model built from a JSON record.

    The record supplies ``ids``, ``model_type``, ``labels``, ``is_duplicated``,
    ``txt`` and a JSON-encoded ``graph`` string with ``nodes``, ``links`` and
    ``directed`` entries.
    """

    def __init__(
            self,
            json_obj: dict,
            use_type=True,
            remove_generic_nodes=True,
        ):
        """Build the graph from ``json_obj``.

        Args:
            json_obj: dataset record describing the model.
            use_type: prefix node names with their ``eClass`` in node texts.
            remove_generic_nodes: drop nodes whose ``eClass`` is listed in
                ``GenericNodes`` once all edges have been added.
        """
        super().__init__()
        self.use_type = use_type
        self.remove_generic_nodes = remove_generic_nodes
        self.json_obj = json_obj
        self.graph_id = json_obj.get('ids')
        self.graph_type = json_obj.get('model_type')
        self.label = json_obj.get('labels')
        self.is_duplicated = json_obj.get('is_duplicated')
        # create_graph also sets self.directed from the parsed graph, so the
        # embedded JSON string is decoded only once.
        self.create_graph(json_obj)
        self.text = json_obj.get('txt')

    def create_graph(self, json_obj):
        """Add the nodes and edges stored in ``json_obj['graph']``."""
        graph = json.loads(json_obj['graph'])
        self.directed = graph.get('directed')

        generic_nodes = []
        for node in graph['nodes']:
            self.add_node(node['id'], **node)
            if node['eClass'] in GenericNodes:
                generic_nodes.append(node['id'])

        for edge in graph['links']:
            self.add_edge(edge['source'], edge['target'], **edge)

        if self.remove_generic_nodes:
            self.remove_nodes_from(generic_nodes)

    def get_graph_node_text(self, node):
        """Return ``eClass(name)`` when ``use_type`` is set, else the name."""
        data = self.nodes[node]
        node_class = data.get('eClass')
        node_name = data.get('name', '')

        if self.use_type:
            return f'{node_class}({node_name})'

        return node_name

    def find_node_str_upto_distance(self, node, distance=1):
        """Describe ``node`` by the texts of the nodes within ``distance`` hops.

        Texts are grouped by hop count. Within a group duplicates are removed
        and texts are joined with ``SEP``; groups are joined with ``" | "`` in
        order of increasing distance.
        """
        nodes_with_distance = self.find_nodes_within_distance(
            node,
            distance=distance
        )
        # A dict per distance de-duplicates while keeping discovery order, so
        # the output is the same on every run (set order varies with string
        # hash randomisation).
        buckets = {d: {} for _, d in nodes_with_distance}
        for n, d in nodes_with_distance:
            node_text = self.get_graph_node_text(n)
            if node_text:
                buckets[d][node_text] = None

        ordered = sorted(buckets.items(), key=lambda x: x[0])
        node_buckets = [f"{SEP}".join(texts) for _, texts in ordered]
        return " | ".join(node_buckets)

    def get_node_texts(self, distance=1):
        """Return the neighbourhood text of every node, in node order."""
        return [
            self.find_node_str_upto_distance(node, distance=distance)
            for node in self.nodes
        ]

    def __repr__(self):
        return f'{self.json_obj}\nGraph({self.graph_id}, nodes={self.number_of_nodes()}, edges={self.number_of_edges()})'


class Dataset:
    """Collection of ``EcoreNxG`` graphs with an on-disk pickle cache."""

    def __init__(
            self,
            dataset_name: str,
            dataset_dir = datasets_dir,
            save_dir = 'datasets/pickles',
            reload=False,
            remove_duplicates=False,
            use_type=False,
            remove_generic_nodes=False,
            extension='.jsonl'
        ):
        """Load the dataset from the pickle cache or from its raw files.

        Args:
            dataset_name: sub-folder of ``dataset_dir`` and cache file name.
            dataset_dir: folder holding the raw dataset folders.
            save_dir: folder for the pickle cache (created if missing).
            reload: rebuild from the raw files even if a cache exists.
            remove_duplicates: drop graphs flagged ``is_duplicated``.
            use_type: forwarded to ``EcoreNxG``.
            remove_generic_nodes: forwarded to ``EcoreNxG``.
            extension: suffix of the raw files to read.

        NOTE: the cache is keyed by name only, so a cached pickle is reused
        even when ``use_type`` / ``remove_generic_nodes`` differ; pass
        ``reload=True`` after changing them.
        """
        self.name = dataset_name
        self.dataset_dir = dataset_dir
        self.save_dir = save_dir
        self.extension = extension
        os.makedirs(save_dir, exist_ok=True)

        dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
        if reload or not dataset_exists:
            self.graphs: List[EcoreNxG] = []
            data_path = os.path.join(dataset_dir, dataset_name)
            # Sorted so the graph order (and every index-based split) does not
            # depend on the file system's listing order.
            for file in sorted(os.listdir(data_path)):
                if file.endswith(self.extension) and file.startswith('ecore'):
                    # Context manager closes the handle (it used to leak).
                    with open(os.path.join(data_path, file)) as f:
                        json_objects = json.load(f)
                    self.graphs += [
                        EcoreNxG(
                            g,
                            use_type=use_type,
                            remove_generic_nodes=remove_generic_nodes
                        ) for g in tqdm(json_objects, desc=f'Loading {dataset_name.title()}')
                    ]
            self.save()

        else:
            self.load()

        if remove_duplicates:
            self.remove_duplicates()

        print(f'Graphs: {len(self.graphs)}')

    def remove_duplicates(self):
        """Drop graphs flagged as duplicates, in place."""
        self.graphs = self.dedup()

    def dedup(self) -> List[EcoreNxG]:
        """Return the graphs that are not flagged as duplicates."""
        return [g for g in self.graphs if not g.is_duplicated]

    def get_train_test_split(self, train_size=0.8):
        """Return shuffled ``(train_idx, test_idx)`` index lists.

        ``train_size`` is the fraction of graphs assigned to the train part.
        """
        n = len(self.graphs)
        train_size = int(n * train_size)
        idx = list(range(n))
        shuffle(idx)
        train_idx = idx[:train_size]
        test_idx = idx[train_size:]
        return train_idx, test_idx

    def _stratification_targets(self, k):
        """Return per-graph targets used to stratify a ``k``-fold split.

        Falls back to constant targets (an unstratified shuffled split) when
        a label is missing or when no class has at least ``k`` members.
        """
        n = len(self.graphs)
        labels = [g.label for g in self.graphs]
        if n == 0 or any(label is None for label in labels):
            return np.zeros(n)
        # Stringify so arbitrary label objects become a plain multiclass target.
        targets = np.asarray([str(label) for label in labels])
        _, counts = np.unique(targets, return_counts=True)
        if counts.max() < k:
            return np.zeros(n)
        return targets

    def k_fold_split(
            self,
            k=10
        ):
        """Yield ``(train_idx, test_idx)`` index arrays for ``k`` folds.

        Folds are stratified by graph label where possible; the split used to
        receive all-zero targets, which disabled stratification entirely.
        """
        kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
        n = len(self.graphs)
        targets = self._stratification_targets(k)
        for train_idx, test_idx in kfold.split(np.zeros(n), targets):
            yield train_idx, test_idx

    @property
    def data(self):
        """Return ``(texts, labels)`` as two parallel lists."""
        X, y = [], []
        for g in self.graphs:
            X.append(g.text)
            y.append(g.label)

        return X, y

    def __repr__(self):
        return f'Dataset({self.name}, graphs={len(self.graphs)})'

    def __getitem__(self, key):
        return self.graphs[key]

    def __iter__(self):
        return iter(self.graphs)

    def __len__(self):
        return len(self.graphs)

    def save(self):
        """Pickle the graph list to ``<save_dir>/<name>.pkl``."""
        print(f'Saving {self.name} to pickle')
        with open(os.path.join(self.save_dir, f'{self.name}.pkl'), 'wb') as f:
            pickle.dump(self.graphs, f)
        print(f'Saved {self.name} to pickle')

    def load(self):
        """Load the graph list from ``<save_dir>/<name>.pkl``.

        SECURITY: unpickling executes code contained in the file; only load
        caches written by this class on a trusted machine.
        """
        print(f'Loading {self.name} from pickle')
        with open(os.path.join(self.save_dir, f'{self.name}.pkl'), 'rb') as f:
            self.graphs = pickle.load(f)

        print(f'Loaded {self.name} from pickle')
    

# Build (or load from the pickle cache) the three datasets; only modelset has
# its flagged duplicates removed.
reload = False
ecore = Dataset('ecore_555', reload=reload)
modelset = Dataset('modelset', reload=reload, remove_duplicates=True)
mar = Dataset('mar-ecore-github', reload=reload)


# Short name -> dataset, used by the cells below.
datasets = {
    'ecore': ecore,
    'modelset': modelset,
    'mar': mar
}

Loading Ecore_555:   0%|          | 0/548 [00:00<?, ?it/s]

Saving ecore_555 to pickle
Saved ecore_555 to pickle
Graphs: 548


Loading Modelset:   0%|          | 0/4127 [00:00<?, ?it/s]

Saving modelset to pickle
Saved modelset to pickle
Graphs: 2043


Loading Mar-Ecore-Github:   0%|          | 0/18110 [00:00<?, ?it/s]

Saving mar-ecore-github to pickle
Saved mar-ecore-github to pickle
Graphs: 18110


In [6]:
# Neighbourhood text of node 8 of the first ecore graph, up to two hops.
ecore[0].find_node_str_upto_distance(8, 2)

'Article | Entry BIBTEX | key Publisher Journal Chapter Url Type MastersThesis Title Inproceedings Volume Address Authors Proceedings Issn Day LocatedElement Institution Number Manual Edition Text Inbook Techreport Month AbstractField Field PhdThesis Howpublished Year Booklet Isbn Doi Organization Series AuthorUrls School Book Editor Bibtex Incollection fields Misc Pages Note BookTitle'

#### Training Fasttext

##### Fasttext classification

In [None]:
import fasttext
from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score

# 10-fold fastText text classification on the labelled datasets.
# NOTE: fold files and models are cached on disk by dataset name and fold
# number; delete them after changing the split or the tokenizer.
for name, dataset in datasets.items():
    if name not in ['ecore', 'modelset']:
        continue
    print("Dataset: ", name)
    X, y = dataset.data
    os.makedirs('models', exist_ok=True)
    accuracies, bal_accuracies = [], []
    # k_fold_split yields index arrays, so texts and labels are selected here.
    for i, (train_idx, test_idx) in enumerate(dataset.k_fold_split()):
        print("Fold number: ", i+1)
        f_train = f'datasets/fasttext_train_{name}_{i}.txt'
        f_test = f'datasets/fasttext_test_{name}_{i}.txt'
        if not os.path.exists(f_train):
            with open(f_train, 'w') as f:
                for j in train_idx:
                    text = " ".join(doc_tokenizer(X[j]))
                    f.write(f"__label__{y[j]} {text}\n")

        if not os.path.exists(f_test):
            with open(f_test, 'w') as f:
                for j in test_idx:
                    text = " ".join(doc_tokenizer(X[j]))
                    f.write(f"__label__{y[j]} {text}\n")

        model_path = f'models/{name}_{i}.bin'
        if os.path.exists(model_path):
            model = fasttext.load_model(model_path)
        else:
            model = fasttext.train_supervised(
                input=f_train,
                epoch=100,
                lr=0.2,
                wordNgrams=2,
            )
            model.save_model(model_path)

        # Read the test file once and close it again.
        with open(f_test) as f:
            test_lines = [line.strip() for line in f]
        y_pred = model.predict(test_lines)[0]
        y_true = [line.split()[0].split('__label__')[1] for line in test_lines]
        y_pred = [labels[0].split('__label__')[1] for labels in y_pred]

        accuracy = accuracy_score(y_true, y_pred)
        bal_accuracy = balanced_accuracy_score(y_true, y_pred)
        print(f"Accuracy: {accuracy}, Balanced Accuracy: {bal_accuracy}")
        accuracies.append(accuracy)
        bal_accuracies.append(bal_accuracy)

    print(f"Average Accuracy: {np.mean(accuracies)}, Average Balanced Accuracy: {np.mean(bal_accuracies)}")
        

##### Fasttext word embeddings

In [42]:
# Unlabelled corpus for the fastText embeddings: every distinct model text of
# all datasets, tokenized and written one document per line. Sorting keeps the
# line order (and so the training input) identical between runs.
X_udata = sorted({g.text for dataset in datasets.values() for g in dataset})
X_udata = [f"{SEP}".join(doc_tokenizer(x)) for x in X_udata]
f_udata = 'datasets/fasttext_udata.txt'
with open(f'{f_udata}', 'w') as f:
    f.writelines(f"{x}\n" for x in X_udata)


In [57]:
# Train 128-dimensional fastText embeddings on the unlabelled corpus, with
# character n-grams of length 2 to 5, and save the model.
model = fasttext.train_unsupervised(
    input=f_udata, 
    epoch=500, 
    lr=0.1,
    minn=2,
    maxn=5,
    dim=128
)
model.save_model("models/uml_fasttext.bin")

Read 0M words
Number of words:  8120
Number of labels: 0
Progress: 100.0% words/sec/thread:    7026 lr: -0.000001 avg.loss:  1.111646 ETA:   0h 0m 0s 60.3% words/sec/thread:    7034 lr:  0.039708 avg.loss:  1.177690 ETA:   0h 3m12s100.0% words/sec/thread:    7026 lr:  0.000000 avg.loss:  1.111496 ETA:   0h 0m 0s


#### Model Encoding

In [102]:
# One instance of every text encoder plus the label encoder used by train_svm.
tf_idf_encoder = TFIDFEncoder()
bert_encoder = BertTokenizerEncoder('bert-base-uncased')
bert_tfidf_encoder = BertTFIDF('bert-base-uncased')
fasttext_encoder = FasttextEncoder('models/uml_fasttext.bin')
class_label_encoder = ClassLabelEncoder()

In [103]:
from sklearn import svm
from sklearn.metrics import (
    accuracy_score, 
    balanced_accuracy_score
)
from typing import Union


def train_svm(dataset: Dataset, encoder: Union[TFIDFEncoder, BertTFIDF, FasttextEncoder]):
    """Cross-validate a linear SVM on ``dataset`` using ``encoder`` features.

    Prints the mean accuracy and mean balanced accuracy over all folds.

    NOTE(review): the encoder is applied to the complete dataset before
    splitting, so an encoder that fits on its input (e.g. TF-IDF) also sees
    the test documents of every fold.
    """
    # Encoding does not depend on the fold, so it is done once up front.
    texts, labels = dataset.data
    X = encoder.encode(texts)
    y = class_label_encoder.encode(labels)

    accuracies, bal_accuracies = [], []
    for train_idx, test_idx in dataset.k_fold_split():
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        svm_classifier = svm.SVC(kernel='linear')  # You can change the kernel as needed
        svm_classifier.fit(X_train, y_train)
        y_pred = svm_classifier.predict(X_test)

        accuracies.append(accuracy_score(y_test, y_pred))
        bal_accuracies.append(balanced_accuracy_score(y_test, y_pred))

    print(f'Mean Accuracy: {np.mean(accuracies)}')
    print(f'Mean Balanced Accuracy: {np.mean(bal_accuracies)}')


In [None]:
# Cross-validated SVM on the modelset texts with TF-IDF features.
train_svm(modelset, tf_idf_encoder)

In [106]:
# Five nearest neighbours of 'petrinet' in whichever model `model` is bound to.
model.get_nearest_neighbors('petrinet', k=5)

[(0.5969380140304565, 'petrinetv3'),
 (0.5963557362556458, 'petrinetv1'),
 (0.5946762561798096, 'petrinetv2'),
 (0.5399251580238342, 'petri'),
 (0.5047121047973633, 'tokens')]

In [115]:
from transformers import Trainer

In [6]:
from transformers import BertTokenizer
import numpy as np

# Module-level tokenizer used by split_into_chunks.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def split_into_chunks(text, max_length=512):
    """Split ``text`` into consecutive chunks of at most ``max_length`` tokens.

    The text is tokenized with the module-level ``tokenizer``; each chunk is
    returned as its tokens joined by single spaces.
    """
    tokens = tokenizer.tokenize(text)
    pieces = []
    for start in range(0, len(tokens), max_length):
        window = tokens[start:start + max_length]
        pieces.append(' '.join(window))
    return pieces

# Example usage: number of 512-token chunks needed for the longest modelset text.
long_text = max(modelset, key=lambda x: len(x.text)).text
chunks = split_into_chunks(long_text)
len(chunks)

15

In [43]:
from transformers import (
    Trainer, 
    TrainingArguments
)
from transformers import (
    AutoModelForSequenceClassification, 
    AutoTokenizer
)
import torch
import numpy as np
import random
from data_loading.dataset import Dataset
from settings import device, seed
from sklearn.preprocessing import LabelEncoder
from trainers.metrics import compute_metrics


# Seed all random generators with the seed imported from settings.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Sequence length (in tokens) each supported checkpoint is padded/truncated to.
max_length_map = {
    'bert-base-uncased': 512,
    'allenai/longformer-base-4096': 4096
}


# Create your dataset
class CustomDataset(torch.utils.data.Dataset):
    """Torch dataset holding pre-tokenized texts together with their labels."""

    def __init__(self, texts, labels, tokenizer, max_length):
        """Tokenize all ``texts`` up front, padded/truncated to ``max_length``."""
        encoded = tokenizer(
            texts,
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=max_length
        )
        # Labels are stored next to the tokenizer fields under 'labels'.
        encoded['labels'] = torch.tensor(labels, dtype=torch.long)
        self.inputs = encoded

    def __len__(self):
        """Number of examples."""
        return len(self.inputs['input_ids'])

    def __getitem__(self, index):
        """Return the fields of example ``index`` as a dict."""
        item = {}
        for key, val in self.inputs.items():
            item[key] = val[index]
        return item


def train_hf(model_name, model_ds: Dataset, epochs):
    """Fine-tune a sequence classifier on the first fold of ``model_ds``.

    Args:
        model_name: checkpoint name; must be a key of ``max_length_map``.
        model_ds: dataset providing ``data`` and ``k_fold_split``.
        epochs: number of training epochs.

    NOTE(review): the held-out fold is used as evaluation set and
    ``load_best_model_at_end`` is enabled, so the reported metrics come from
    the data that also selects the checkpoint.
    """
    max_len = max_length_map[model_name]
    print(f'Device used: {device}')

    # Fold-independent work is done once.
    X, y = model_ds.data
    y = LabelEncoder().fit_transform(y)
    num_labels = len(set(y))
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Mixed precision needs a CUDA device; requesting it elsewhere makes
    # TrainingArguments fail.
    use_fp16 = str(device).startswith('cuda')

    for fold, (train_idx, test_idx) in enumerate(model_ds.k_fold_split()):
        print(f'Fold number: {fold+1}')
        print(f'X: {len(X)}, y: {len(y)}')
        X_train, X_test = [X[j] for j in train_idx], [X[j] for j in test_idx]
        y_train, y_test = [y[j] for j in train_idx], [y[j] for j in test_idx]

        print(f'Train: {len(X_train)}, Test: {len(X_test)}')

        # A fresh model per fold.
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        model.to(device)

        train_ds = CustomDataset(X_train, y_train, tokenizer, max_length=max_len)
        test_ds = CustomDataset(X_test, y_test, tokenizer, max_length=max_len)

        # Training arguments
        training_args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=epochs,
            eval_strategy="epoch",
            save_strategy="epoch",
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=10,
            load_best_model_at_end=True,
            save_total_limit=1,
            fp16=use_fp16,
            seed=seed
        )

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=test_ds,
            compute_metrics=compute_metrics
        )

        # Train the model
        trainer.train()
        results = trainer.evaluate()
        print(results)

        # Only the first fold is trained.
        break

In [None]:
# Fine-tune bert-base-uncased on modelset for 10 epochs.
train_hf('bert-base-uncased', modelset, 10)

In [9]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from settings import device
from data_loading.dataset import EncodingDataset

# Rebuild the first fold and load the classifier from a saved checkpoint.
model_name = 'bert-base-uncased'
max_len = 512

i = 0
for train_idx, test_idx in modelset.k_fold_split():
    print(f'Fold number: {i+1}')
    X, y = modelset.data
    print(f'X: {len(X)}, y: {len(y)}')
    y = LabelEncoder().fit_transform(y)
    X_train, X_test = [X[i] for i in train_idx], [X[i] for i in test_idx]
    y_train, y_test = [y[i] for i in train_idx], [y[i] for i in test_idx]

    print(f'Train: {len(X_train)}, Test: {len(X_test)}')

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): the checkpoint path is hard-coded; presumably written by
    # train_hf into ./results - confirm it matches this fold.
    model = AutoModelForSequenceClassification.from_pretrained('results/checkpoint-1380', num_labels=len(set(y)))
    model.to(device)

    train_ds = EncodingDataset(tokenizer, X_train, y_train, max_length=max_len)
    test_ds = EncodingDataset(tokenizer, X_test, y_test, max_length=max_len)

    # Only the first fold is needed.
    break

Fold number: 1
X: 2043, y: 2043
Train: 1838, Test: 205


In [49]:
# Shape of the token-id tensor of the whole test split.
test_ds[:]['input_ids'].shape

torch.Size([205, 512])

In [51]:
# Predict the class of every test example in one forward pass.
with torch.no_grad():
    model.eval()
    # Move every tensor of the test split onto the device.

    # NOTE: this rebinds test_ds from the dataset object to a plain dict.
    test_ds = {k: v.to(device) for k, v in test_ds[:].items()}

    outputs = model(**test_ds)
    pred_classes = torch.argmax(outputs.logits, dim=1)

In [53]:
# Predicted class indices as a NumPy array on the host.
y_pred = pred_classes.cpu().numpy()

[4, 27, 10]

In [26]:
from transformers import AutoModel, AutoTokenizer
import torch
from typing import List, Union
from data_loading.dataset import EncodingDataset
from torch.utils.data import DataLoader
from settings import device


from abc import abstractmethod
from typing import List, Union


class Embedder:
    """Base class for text embedders.

    ``finetuned`` tells whether the embedder uses fine-tuned weights; it
    defaults to False and subclasses may overwrite it.
    """

    # The constructor was spelled ``___init__`` (three underscores) and so
    # never ran; ``finetuned`` was undefined on plain instances.
    def __init__(self):
        self.finetuned = False

    @abstractmethod
    def embed(self, text: Union[str, List[str]], aggregate='mean'):
        """Return embeddings for ``text``, pooled as named by ``aggregate``."""
        pass


class BertEmbedder(Embedder):
    """Embeds texts with a Hugging Face encoder, optionally from a checkpoint."""

    _AGGREGATIONS = ('mean', 'max', 'cls', 'pool')

    def __init__(self, model_name, ckpt=None):
        """Load the tokenizer of ``model_name`` and the weights of ``ckpt``.

        When ``ckpt`` is None the weights of ``model_name`` are used.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(ckpt if ckpt else model_name)
        self.model.to(device)
        self.finetuned = ckpt is not None

    @staticmethod
    def _aggregate(hidden, aggregate):
        """Reduce (batch, tokens, dim) hidden states to one vector per text."""
        if aggregate == 'cls':
            return hidden[:, 0, :]
        if aggregate == 'max':
            # torch.max over a dim returns (values, indices); keep the values.
            return hidden.max(dim=1).values
        # 'mean' and 'pool' both average over the token axis.
        # NOTE(review): padded positions are included in the average - confirm
        # this is intended.
        return hidden.mean(dim=1)

    def embed(self, text: Union[str, List[str]], aggregate='mean'):
        """Embed ``text`` and return the pooled embeddings on the CPU.

        Args:
            text: text(s) passed on to ``EncodingDataset``.
            aggregate: one of 'mean', 'max', 'cls' or 'pool'.

        Raises:
            ValueError: if ``aggregate`` is not a known method.
        """
        # Reject a bad method before running the model.
        if aggregate not in self._AGGREGATIONS:
            raise ValueError(f'Unknown aggregation method: {aggregate}')

        dataset = EncodingDataset(self.tokenizer, texts=text)
        loader = DataLoader(dataset, batch_size=128)

        pooled = []
        with torch.no_grad():
            for batch in loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                outputs = self.model(input_ids, attention_mask)
                # Pool per batch and move to the CPU straight away so the full
                # token-level states of all batches never sit on the device.
                pooled.append(
                    self._aggregate(outputs.last_hidden_state, aggregate).cpu()
                )

        return torch.cat(pooled, dim=0)

In [25]:
# Same base model twice: once with checkpoint weights, once as published.
ft_embedder = BertEmbedder('bert-base-uncased', 'results/checkpoint-1380')
bert_embedder = BertEmbedder('bert-base-uncased')

In [12]:
# Sanity check: one neighbourhood text per node of the first graph.
texts = modelset[0].get_node_texts()
len(texts), modelset[0].number_of_nodes()

(27, 27)

In [None]:
# Display the first graph through its __repr__.
modelset[0]

In [98]:
from torch_geometric.data import Data


class TorchGraph:
    """Tensor form of a graph: node embeddings plus an edge index.

    Results are cached in a folder below ``save_dir`` named after the graph id
    (with a ``_finetuned`` suffix for fine-tuned embedders).
    """

    def __init__(
            self,
            graph: "LangGraph",
            embedder: "Embedder",
            save_dir: str,
        ):
        """Load the cached tensors of ``graph`` or compute and store them."""
        self.graph = graph
        self.embedder = embedder

        self.save_dir = save_dir
        self.process_graph()

    def process_graph(self):
        """Compute ``embeddings`` and ``edge_index`` unless they are cached."""
        if self.load():
            return
        texts = self.graph.get_node_texts()
        self.embeddings = self.embedder.embed(texts)

        # Edge endpoints must refer to rows of ``embeddings`` (the position of
        # a node in ``graph.nodes``), not to raw node ids, which need not be
        # contiguous, e.g. after nodes have been removed.
        position = {node: pos for pos, node in enumerate(self.graph.nodes)}
        pairs = [(position[u], position[v]) for u, v in self.graph.edges]
        if pairs:
            self.edge_index = torch.tensor(
                pairs, dtype=torch.long).t().contiguous()
        else:
            # Keep the (2, num_edges) layout for graphs without edges.
            self.edge_index = torch.empty((2, 0), dtype=torch.long)
        self.save()

    @property
    def name(self):
        """Graph id with '/' replaced by '_' and the last extension removed."""
        return '.'.join(self.graph.graph_id.replace('/', '_').split('.')[:-1])

    @property
    def save_path(self):
        """Cache folder of this graph."""
        path = os.path.join(self.save_dir, f'{self.name}')
        if self.embedder.finetuned:
            path = f'{path}_finetuned'
        return path

    def load(self):
        """Load cached tensors; return True if both files were present."""
        embeddings_file = f"{self.save_path}/embeddings.pt"
        edge_index_file = f"{self.save_path}/edge_index.pt"
        # Require both files so an incomplete cache folder is recomputed.
        if os.path.exists(embeddings_file) and os.path.exists(edge_index_file):
            self.embeddings = torch.load(embeddings_file)
            self.edge_index = torch.load(edge_index_file)
            return True
        return False

    def save(self):
        """Write both tensors into the cache folder."""
        os.makedirs(self.save_path, exist_ok=True)
        torch.save(self.embeddings, f"{self.save_path}/embeddings.pt")
        torch.save(self.edge_index, f"{self.save_path}/edge_index.pt")


class GraphDataset(torch.utils.data.Dataset):
    """Torch dataset of graphs with embedded node texts and class labels."""

    def __init__(
            self,
            models_dataset: Dataset,
            embedder: Embedder,
            save_dir='datasets/graph_data',
        ):
        """Convert every graph of ``models_dataset`` into tensors.

        Args:
            models_dataset: source graphs; each needs a ``label``.
            embedder: produces the node embeddings.
            save_dir: root of the tensor cache; a sub-folder named after the
                dataset is created in it.
        """
        self.save_dir = f'{save_dir}/{models_dataset.name}'
        os.makedirs(self.save_dir, exist_ok=True)
        self.graphs = [
            TorchGraph(g, embedder, save_dir=self.save_dir)
            for g in tqdm(models_dataset, desc=f'Processing {models_dataset.name}')
        ]

        # Sort the distinct labels so each class gets the same index on every
        # run; set iteration order changes with string hash randomisation.
        distinct_labels = sorted({g.label for g in models_dataset}, key=str)
        self._c = {label: j for j, label in enumerate(distinct_labels)}
        self.labels = torch.tensor([self._c[g.label] for g in models_dataset], dtype=torch.long)
        self.num_classes = len(self._c)
        self.num_features = self.graphs[0].embeddings.shape[-1]

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, index: int):
        """Return graph ``index`` as a ``Data`` object with x, edge_index, y."""
        return Data(
            x=self.graphs[index].embeddings,
            edge_index=self.graphs[index].edge_index,
            y=self.labels[index]
        )

In [99]:
# Graph datasets over modelset, one per embedder.
graph_dataset_ft = GraphDataset(modelset, ft_embedder)
graph_dataset = GraphDataset(modelset, bert_embedder)

Processing modelset:   0%|          | 0/2043 [00:00<?, ?it/s]

Processing modelset:   0%|          | 0/2043 [00:00<?, ?it/s]

In [None]:
# Random 80/20 split of the modelset graph indices.
train_idx, test_idx = modelset.get_train_test_split()

In [109]:
# Materialise both parts of the split from the checkpoint-embedded dataset.
train_dataset = [graph_dataset_ft[i] for i in train_idx]
test_dataset = [graph_dataset_ft[i] for i in test_idx]

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch import nn


class GraphSAGEBlock(nn.Module):
    """One GraphSAGE layer: sum-aggregating SAGEConv, then ReLU, then dropout."""

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        self.conv = SAGEConv(in_dim, out_dim, aggr='sum')
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Apply convolution, activation and dropout to node features ``x``."""
        convolved = self.conv(x, edge_index)
        return self.dropout(self.relu(convolved))


class GraphSAGE(nn.Module):
    """Graph classifier: stacked GraphSAGE blocks, mean pooling, linear head."""

    def __init__(
            self,
            input_dim,
            hidden_dim,
            output_dim,
            num_layers,
            dropout=0.1
        ):
        super().__init__()
        # The first block maps input_dim -> hidden_dim; the others keep hidden_dim.
        blocks = []
        in_dim = input_dim
        for _ in range(num_layers):
            blocks.append(GraphSAGEBlock(in_dim, hidden_dim, dropout))
            in_dim = hidden_dim
        self.gnn = nn.ModuleList(blocks)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        """Return per-graph log-probabilities over the classes."""
        hidden = x
        for block in self.gnn:
            hidden = block(hidden, edge_index)
        pooled = global_mean_pool(hidden, batch)
        logits = self.fc(pooled)
        return F.log_softmax(logits, dim=1)


def train(loader):
    """Run one training epoch over ``loader`` and return the mean batch loss.

    Uses the module-level ``model``, ``optimizer`` and ``device``.
    """
    model.train()
    running_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        log_probs = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = F.nll_loss(log_probs, batch.y)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)


def test(loader):
    """Return the accuracy of the module-level ``model`` on ``loader``.

    Accuracy is the number of correctly classified graphs divided by the size
    of ``loader.dataset``.
    """
    model.eval()
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)



# NOTE(review): this draws a new random split, but train_dataset and
# test_dataset used below are not rebuilt in this cell; the new indices are
# not used here.
train_idx, test_idx = modelset.get_train_test_split()

batch_size = 32
hidden_dim = 64
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True
)

# NOTE(review): shuffling is not needed for evaluation.
test_dataloader = DataLoader(
    test_dataset, 
    batch_size=batch_size, 
    shuffle=True
)

# Training the model
# Three-layer GraphSAGE sized from the checkpoint-embedded graph dataset.
model = GraphSAGE(
    input_dim=graph_dataset_ft.num_features,
    hidden_dim=hidden_dim, 
    output_dim=graph_dataset_ft.num_classes,
    num_layers=3,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [None]:
# Train for 200 epochs, reporting mean loss and test accuracy after each.
for epoch in range(1, 201):
    loss = train(train_dataloader)
    test_acc = test(test_dataloader)
    # NOTE(review): '03f' zero-pads to width 3 with default precision;
    # '.3f' was presumably intended - confirm.
    print(f'Epoch: {epoch:03d}, Loss: {loss:03f} Test Acc: {test_acc:.4f}')