In [179]:
import fnmatch
import os
import json
import sqlite3
from tqdm.auto import tqdm
import re
import pickle
import numpy as np
import networkx as nx
import random
from torch.utils.data import Dataset, DataLoader
import torch

In [10]:
gen_db_file = '../datasets/ModelClassification/modelset/datasets/dataset.genmymodel/data/genmymodel.db'
ecore_db_file = '../datasets/ModelClassification/modelset/datasets/dataset.ecore/data/ecore.db'
db_files = [ecore_db_file, gen_db_file]

def create_dataset_json():
    count = 0
    metadata = dict()
    for file in db_files:
        conn = sqlite3.connect(file)
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM metadata;")
        rows = cursor.fetchall()
        for row in rows:
            try:
                metadata[row[0]] = json.loads(row[2])
            except:
                metadata[row[0]] = dict()
                count += 1
        conn.close()
    return metadata

data = create_dataset_json()
print(len(data))

10595


In [12]:
print(len(data))
print(list(data.keys())[-1], data[list(data.keys())[0]])
ecore_models = {k:v for k,v in data.items() if k.startswith('repo-ecore-all')}
gen_models = {k:v for k,v in data.items() if k.startswith('repo-genmymodel-uml')}
print(len(ecore_models), len(gen_models))

10595
repo-genmymodel-uml/data/40966914-1f3d-4f22-8334-5ef53c4cc64b.xmi {'type': ['behaviour'], 'category': ['petrinet'], 'tags': ['behaviour']}
5475 5120


In [55]:
from pyecore.resources import ResourceSet, URI
import networkx as nx


GENERALIZATION = 'generalization'
ASSOCIATION = 'association'
REFERENCE = 'reference'


def get_attributes(classifier):
    all_feats = set((feat.name, feat.eType.name) for feat in classifier.eAllStructuralFeatures() if type(feat).__name__ == 'EAttribute')
    return list(all_feats)

def get_model_root(file_name):
    rset = ResourceSet()
    resource = rset.get_resource(URI(file_name))
    mm_root = resource.contents[0]
    return mm_root


def get_ecore_data(file_name):
    rset = ResourceSet()
    resource = rset.get_resource(URI(file_name))
    mm_root = resource.contents[0]
    references = list()
    for classifier in mm_root.eClassifiers:
        # print(classifier.name, get_features(classifier))
        if type(classifier).__name__ == 'EClass':
            references.append((classifier.name, get_attributes(classifier)))
    super_types = list()
    for classifier in mm_root.eClassifiers:
        if type(classifier).__name__ == 'EClass':
            for supertype in classifier.eAllSuperTypes():
                super_types.append((classifier.name, supertype.name))
    return references, super_types


def create_nx_from_ecore(file_name):
    try:
        model_root = get_model_root(file_name)
    except Exception as e:
        return None
    if type(model_root).__name__ != 'EPackage':
        return None
    nxg = nx.DiGraph()
    for classifier in model_root.eClassifiers:
        if type(classifier).__name__ == 'EClass':
            if not nxg.has_node(classifier.name):
                nxg.add_node(classifier.name, name=classifier.name, type='class')

            classifier_attrs = set(feat.name for feat in classifier.eAllStructuralFeatures() if type(feat).__name__ == 'EAttribute')
            nxg.nodes[classifier.name]['attrs'] = list(classifier_attrs)
    
    for classifier in model_root.eClassifiers:
        if type(classifier).__name__ == 'EClass':
            for supertype in classifier.eAllSuperTypes():
                if not nxg.has_node(supertype.name):
                    nxg.add_node(supertype.name, name=supertype.name, type='class')
                nxg.add_edge(classifier.name, supertype.name, type=GENERALIZATION)
            
            for reference in classifier.eReferences:
                try:
                    if reference.eType is not None and not nxg.has_edge(classifier.name, reference.eType.name):
                        nxg.add_edge(
                            classifier.name, reference.eType.name, name=reference.name, \
                                type=REFERENCE if reference.containment else ASSOCIATION
                        )
                except Exception as e:
                    # print("ref", reference)
                    # raise(e)
                    pass
        
    return nxg


def get_graphs_from_dir(models_metadata, dir=None):
    graphs = list()
    count = 0
    models = models_metadata if dir is None else [os.path.join(dir, model) for model in models_metadata.keys()]
    for model_file_name in tqdm(models):
        try:
            g = create_nx_from_ecore(model_file_name)
            if g is not None:
                graphs.append(g)
        except Exception as e:
            print(model_file_name)
            count += 1
    print(count)
    return graphs

In [None]:
models_dir = '../datasets/ModelClassification/modelset/raw-data'
graphs = get_graphs_from_dir(ecore_models, models_dir)

In [186]:
filtered_graphs = [g for g in filter(lambda g: g.number_of_edges() >= 10, graphs)]
print(len(filtered_graphs))

2677


In [187]:
with open('ecore_modelset_graphs.pkl', 'wb') as f:
    pickle.dump(filtered_graphs, f)

In [None]:
all_ecore_dir = '/Users/junaid/Downloads/TUWien/Projects/CM-KB-Search-Project/MAR-Models-Repository/repo-github-ecore'

## Recursively get all ecore files from the directory
def get_all_ecore_files(dir):
    ecore_files = list()
    for root, _, filenames in os.walk(dir):
        for filename in fnmatch.filter(filenames, '*.ecore'):
            ecore_files.append(os.path.join(root, filename))
    return ecore_files

# all_ecore_files = get_all_ecore_files(all_ecore_dir)
# print(len(all_ecore_files))

with open('datasets/ecore_graph_pickles/all_ecore_files.pkl', 'rb') as f:
    all_ecore_files = pickle.load(f)

all_ecore_graphs = get_graphs_from_dir(all_ecore_files)

with open('datasets/ecore_graph_pickles/all_ecore_graphs.pkl', 'rb') as f:
    all_ecore_graphs = pickle.load(f)

In [209]:
filtered_all_ecore_graphs = [g for g in filter(lambda g: g.number_of_edges() >= 10, all_ecore_graphs)]
print(len(filtered_all_ecore_graphs))
print(max([g.number_of_nodes() for g in filtered_all_ecore_graphs]))

26393
701


In [19]:
def graph2str(g):
    return str(g.edges())

def remove_duplicates(graphs):
    return list({graph2str(g):g for g in graphs}.values())

def filter_graphs(graphs, min_edges=10):
    return [g for g in filter(lambda g: g.number_of_edges() >= min_edges, graphs)]

def clean_graph_set(graphs):
    graphs = remove_duplicates(graphs)
    graphs = filter_graphs(graphs)
    return graphs

def write_graphs_to_file(graphs, file_name):
    with open(file_name, 'wb') as f:
        pickle.dump(graphs, f)

def read_graphs_from_file(file_name):
    with open(file_name, 'rb') as f:
        graphs = pickle.load(f)
    return graphs

def write_clean_graphs_to_file(graphs, file_name):
    graphs = clean_graph_set(graphs)
    write_graphs_to_file(graphs, file_name)

def read_clean_graphs_from_file(file_name):
    graphs = read_graphs_from_file(file_name)
    graphs = clean_graph_set(graphs)
    return graphs

In [227]:
write_graphs_to_file(graphs, 'datasets/ecore_graph_pickles/ecore_modelset_graphs.pkl')
write_graphs_to_file(all_ecore_graphs, 'datasets/ecore_graph_pickles/all_ecore_graphs.pkl')
write_clean_graphs_to_file(graphs, 'datasets/ecore_graph_pickles/ecore_modelset_graphs_clean.pkl')
write_clean_graphs_to_file(all_ecore_graphs, 'datasets/ecore_graph_pickles/all_ecore_graphs_clean.pkl')

In [226]:
combined_graphs = graphs + all_ecore_graphs
write_graphs_to_file(combined_graphs, 'datasets/ecore_graph_pickles/combined_graphs.pkl')
write_clean_graphs_to_file(combined_graphs, 'datasets/ecore_graph_pickles/combined_graphs_clean.pkl')

In [20]:
combined_clean_graphs = read_clean_graphs_from_file('datasets/ecore_graph_pickles/combined_graphs_clean.pkl')
print(len(combined_clean_graphs))

6547


In [41]:
with open('datasets/ecore_graph_pickles/combined_graph.pkl', 'rb') as f:
    combined_graph = pickle.load(f)

In [56]:
from collections import deque

SEP = '->'
remove_extra_spaces = lambda txt: re.sub(r'\s+', ' ', txt.strip())

edge_type_map = {
    GENERALIZATION: 0,
    ASSOCIATION: 1,
    REFERENCE: 2,
}

def process_edge_for_string(graph, edge):
    u, v = edge
    edge_type = edge_type_map[graph.edges[edge]['type']]
    edge_name = graph.edges[edge]['name'] if 'name' in graph.edges[u, v] else ''
    # edge_str = (': ' + edge_name) if edge_name else ''
    # edge_str = remove_extra_spaces(edge_str)
    return remove_extra_spaces(edge_name), edge_type


def process_node_for_string(graph, node, add_attrs=True):
    assert graph.has_node(node) and 'name' in graph.nodes[node], "Node not found in graph or node name not found in node"
        
    node_name = graph.nodes[node]['name']
    node_attrs_str = ''
    if add_attrs:
        node_attrs = graph.nodes[node]['attrs'] if 'attrs' in graph.nodes[node] else []
        node_attrs_str = "(" + (("attributes=" + ', '.join(node_attrs)) if len(node_attrs) else "") + ")"
        
    node_str = node_name + node_attrs_str
    node_str = remove_extra_spaces(node_str)
    return node_str
    

def find_nodes_within_distance(graph, start_node, distance):
    q, visited = deque(), dict()
    q.append((start_node, 0))
    
    while q:
        n, d = q.popleft()
        if d <= distance:
            visited[n] = d
            neighbours = [
                neighbor for node, neighbor in graph.edges(n) \
                    if neighbor != n and \
                        neighbor not in visited and \
                        graph.edges[node, neighbor]['type'] != GENERALIZATION
                ]
            for neighbour in neighbours:
                if neighbour not in visited:
                    q.append((neighbour, d + 1))
    
    sorted_list = sorted(visited.items(), key=lambda x: x[1])
    return sorted_list

def get_node_neighbours(graph, start_node, distance):
    neighbours = find_nodes_within_distance(graph, start_node, distance)
    max_distance = max(distance for _, distance in neighbours)
    distance = min(distance, max_distance)
    return [node for node, d in neighbours if d == distance]


def get_triple_from_edge(g, edge, attrs=False):
    u, v = edge
    edge_str, edge_type = process_edge_for_string(g, edge)
    u_string, v_string = process_node_for_string(g, u, add_attrs=attrs), process_node_for_string(g, v, add_attrs=attrs)
    return ((u_string, edge_str, v_string), edge_type)


def get_triples_from_edges(g, edges=None, attrs=False):
    if edges is None:
        edges = g.edges()
    triples = []
    for edge in edges:
        triple = get_triple_from_edge(g, edge, attrs)
        triples.append(triple)
        
    return triples


def process_path_string(g, path, attrs=False):
    edges = list(zip(path[:-1], path[1:]))
    triples = get_triples_from_edges(g, edges, attrs)
    Xs, ys = [" ".join(t[0]) for t in triples], [t[1] for t in triples]

    return Xs, ys


def get_triples_from_node(g, n, distance=1, attrs=False):
    triples = list()
    node_neighbours = get_node_neighbours(g, n, distance)
    for neighbour in node_neighbours:
        paths = [p for p in nx.all_simple_paths(g, n, neighbour, cutoff=distance)]
        for path in paths:
            triples.append(process_path_string(g, path, attrs))
    
    return triples


def get_graph_triples(g, distance=1, attrs=False):
    triples = list()
    for node in g.nodes():
        triples += get_triples_from_node(g, node, distance, attrs)
    return triples


def get_triples(graphs, distance=1, attrs=False):
    triples = []
    for g in tqdm(graphs):
        triples += get_graph_triples(g, distance, attrs)
    return triples

def remove_duplicate_triples(triples):
    return list({t[0]:t for t in triples}.values())

In [None]:
import numpy as np


def mask_graph(g, mask_ratio=0.2):
    for edge in g.edges():
        g.edges[edge]['mask'] = False

    num_edges = g.number_of_edges()
    num_edges_to_mask = int(num_edges * mask_ratio)
    edges_to_mask = np.random.choice(num_edges, num_edges_to_mask, replace=False)
    for edge in edges_to_mask:
        g.edges[edge]['mask'] = True
    
    return g

def mask_graphs(graphs, mask_ratio=0.2):
    masked_graphs = [mask_graph(g, mask_ratio) for g in graphs]
    return masked_graphs

In [53]:
distance = [1, 2, 3]
attr_flag = [False, True]
for d in distance:
    for a in attr_flag:
        triples = get_triples(combined_clean_graphs, distance=d, attrs=a)
        triples = remove_duplicate_triples(triples)
    #     break
    # break
        print("Total triples:", len(triples))
        
        with open(f'datasets/ecore_graph_pickles/combined_graphs_triples_d{d}_attr{a}.json', 'w') as f:
            json.dump(triples, f, indent=4)

  0%|          | 0/6547 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
def combine_graphs(graphs):
    combined_graph = nx.DiGraph()
    for g in tqdm(graphs):
        for edge in g.edges():
            u, v = edge
            if not combined_graph.has_node(u):
                combined_graph.add_node(u, **g.nodes[u])
            if not combined_graph.has_node(v):
                combined_graph.add_node(v, **g.nodes[v])
                
            if not combined_graph.has_edge(*edge):
                combined_graph.add_edge(*edge, **g.edges[edge])
    return combined_graph

combined_graph = combine_graphs(combined_clean_graphs)

In [10]:
data = json.load(open('datasets/ecore_graph_pickles/combined_graphs_triples_d3_attrFalse.json'))

In [None]:
for edges, classes in tqdm(data):
    edges = edges.split(f' {SEP} ')
    for i, j in zip(edges, classes):
        try:
            if j == 0:
                assert len(i.split()) == 2
            else:
                assert len(i.split()) == 3
        except AssertionError as e:
            print(i, j)
            # raise(e)
    

In [36]:
super_types = list(set([edge[1] for g in tqdm(combined_clean_graphs) for n in g.nodes() for edge in g.edges(n) if g.edges[edge]['type'] == 'generalization']))
entities = list(set([n for g in tqdm(combined_clean_graphs) for n in g.nodes()]))

  0%|          | 0/6547 [00:00<?, ?it/s]

  0%|          | 0/6547 [00:00<?, ?it/s]

In [37]:
len(super_types), len(entities)

(13234, 62788)

In [43]:
nodes_with_super_types = dict()
for g in tqdm(combined_clean_graphs):
    for n in g.nodes():
        gen_edges = [edge[1] for edge in g.edges(n) if g.edges[edge]['type'] == 'generalization']
        other_edges = [edge[1] for edge in g.edges(n) if g.edges[edge]['type'] != 'generalization']
        if len(gen_edges) > 1:
            key_str = f"{n} {SEP} {' '.join(gen_edges)}"
            nodes_with_super_types[key_str] = (n, gen_edges)

len(nodes_with_super_types)

  0%|          | 0/6547 [00:00<?, ?it/s]

39397

In [59]:
l = [len(g.edges(n)) for g in tqdm(combined_clean_graphs) for n in g.nodes()]
max(l)

  0%|          | 0/6547 [00:00<?, ?it/s]

160

In [None]:
from collections import Counter

c = Counter(l)
sorted(c.items(), key=lambda x: x[0])

In [44]:
with open('supertypes.txt', 'w') as f:
    for _, v in nodes_with_super_types.items():
        f.write(f'{v[0]} -> {v[1]}\n')


In [62]:
relevant_nodes = [
    n for g in tqdm(combined_clean_graphs) for n in g.nodes() \
        if len([edge for edge in g.edges(n) if g.edges[edge]['type'] != 'generalization'])
]

len(relevant_nodes)

  0%|          | 0/6547 [00:00<?, ?it/s]

77972

In [63]:
len(set(relevant_nodes))

32291

In [93]:
node_triples = set()

for g in tqdm(combined_clean_graphs):
    for n in g.nodes():
        super_type_nodes = [edge[1] for edge in g.edges(n) if g.edges[edge]['type'] == 'generalization' and len(edge[1].strip())]

        if 'NamedElement' in super_type_nodes:
            super_type_nodes.remove('NamedElement')

        reference_nodes = [edge[1] for edge in g.edges(n) if g.edges[edge]['type'] != 'generalization' and len(edge[1].strip())]
        if not len(reference_nodes):
            continue

        selected_super_types = random.sample(super_type_nodes, min(5, len(super_type_nodes)))

        node_references = [edge[1] for edge in g.edges(n) if g.edges[edge]['type'] != 'generalization']
        for node_reference in node_references:
            edge_name = g.edges[n, node_reference]['name'] if 'name' in g.edges[n, node_reference] else ''
            node_triples.add((n, edge_name, node_reference, ", ".join(selected_super_types)))

node_triples = list(node_triples)
print(len(node_triples))

  0%|          | 0/6547 [00:00<?, ?it/s]

109715


In [94]:
all_super_types = set([st for t in node_triples for st in t[3].split(', ') if len(st.strip())])
print(len(all_super_types))


8366


In [95]:
all_entities = set([t[0] for t in node_triples] + [t[2] for t in node_triples])
print(len(all_entities))

42940


In [261]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(node_triples, test_size=0.1, random_state=42)

print(len(train), len(test))

98743 10972


In [262]:
train_entities, train_super_types = set([t[0] for t in train] + [t[2] for t in train]), set([st for t in train for st in t[3].split(', ')])
# val_entities, val_super_types = set([t[0] for t in val] + [t[2] for t in val]), set([st for t in val for st in t[3].split(', ')])
test_entities, test_super_types = set([t[0] for t in test] + [t[2] for t in test]), set([st for t in test for st in t[3].split(', ')])

In [263]:
print(len(train_entities), len(test_entities))
print(sum([1 for v in test_entities if v not in train_entities]))

41336 11463
1604


In [264]:
test_seen = list([v for v in test if v[0] in train_entities])

print(len(test_seen))

10117


In [266]:
test_unseen = list([v for v in test if v[0] not in train_entities])

print(len(test_unseen))

855


In [271]:
seen_entities = list(set([t[0] for t in train] + [t[0] for t in test_seen]))

unseen_entities = list(set([t[0] for t in test_unseen]))

In [270]:
seen_super_types = list(set([st for t in train for st in t[3].split(', ')] + [st for t in test_seen for st in t[3].split(', ')]))
unseen_super_types = list(set([st for t in test_unseen for st in t[3].split(', ')]))

In [269]:
data = {
    'train': train,
    'test': test,
    'test_seen': test_seen,
    'test_unseen': test_unseen,
    'seen_entities': seen_entities,
    'unseen_entities': unseen_entities,
    'seen_super_types': seen_super_types,
    'unseen_super_types': unseen_super_types,
}

with open('datasets/ecore_graph_pickles/ecore_node_triples.json', 'w') as f:
    json.dump(data, f, indent=4)

In [None]:
seen_super_types

In [118]:
len(test), sum([1 for i in test if len(i[-1].strip())])

(10972, 5729)

In [272]:
label_map = {v:k for k,v in enumerate(seen_entities)}
labels = [label_map[i[0]] for i in train]

In [273]:
stp_map = {v:k for k,v in enumerate(seen_super_types)}

In [303]:
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder


def get_stereotype_labels(triples, stereotype_map, multi_label=True):
    stp_labels = [[stereotype_map[j] for j in i[3].split(', ') if len(j.strip())] for i in triples]
    if not multi_label:
        le = LabelEncoder()
        stp_labels = le.fit_transform([i[0] if len(i) else -1 for i in stp_labels])

        stp_labels = torch.from_numpy(stp_labels)
    else:
        mlb = MultiLabelBinarizer()
        stp_labels = torch.from_numpy(mlb.fit_transform(stp_labels))
        
    return stp_labels

In [None]:
stp_map

In [304]:
print(len(stp_map))
stp_labels = get_stereotype_labels(train, stp_map, multi_label=False)
max(stp_labels)

8255


tensor(6301)

In [201]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')

In [288]:
from torch.utils.data import Dataset, DataLoader
import torch

class TripleDataset(Dataset):
    def __init__(self, triples, entity_map, stereotype_map, tokenizer, multi_label=True):
        self.labels = torch.from_numpy(np.array([entity_map[t[0]] for t in triples]))
        self.stereotype_labels = get_stereotype_labels(triples, stereotype_map, multi_label)
        
        triples = [f'{tokenizer.mask_token} {t[1]} {t[2]}' for t in triples]
        self.tokenized = tokenizer(triples, padding=True, return_tensors='pt')

        self.entity_map = entity_map
        self.stereotype_map = stereotype_map

    @property
    def num_labels(self):
        return torch.unique(self.labels).shape[0]

    @property
    def num_stereotype_labels(self):
        if len(self.stereotype_labels.shape) == 1:
            return torch.unique(self.stereotype_labels).shape[0]
        return self.stereotype_labels.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        inputs = {
            'input_ids': self.tokenized['input_ids'][idx],
            'attention_mask': self.tokenized['attention_mask'][idx],
            'labels': self.labels[idx],
        }
        entity_label = self.labels[idx]
        stereotype_label = self.stereotype_labels[idx]

        return inputs, entity_label, stereotype_label

label_map = {v:k for k,v in enumerate(seen_entities)}

In [289]:
stereotype_map = {v:k for k,v in enumerate(seen_super_types)}

In [310]:
multi_label = True

In [316]:
train_dataset = TripleDataset(data['train'], label_map, stereotype_map, tokenizer, multi_label)
test_dataset = TripleDataset(data['test_seen'], label_map, stereotype_map, tokenizer, multi_label)

In [317]:
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [318]:
class TripleClassifier(torch.nn.Module):
    """
        Classifier that first uses the model to get the pooled output
        Then applies a linear layer to get the logits for the entity classification
        Then applies another linear layer to get the logits for the stereotype classification
    """

    def __init__(self, num_labels, num_stp_labels, model_name='xlm-roberta-base'):
        super(TripleClassifier, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.num_labels = num_labels
        self.num_stp_labels = num_stp_labels
        self.dropout = torch.nn.Dropout(0.1)
        self.linear = torch.nn.Linear(self.model.config.hidden_size, self.num_labels)
        self.stp_linear = torch.nn.Linear(self.model.config.hidden_size, self.num_stp_labels)
        self.softmax = torch.nn.Softmax(dim=1)
        
    
    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.linear(pooled_output)
        stp_logits = self.stp_linear(pooled_output)
        return self.softmax(logits), self.softmax(stp_logits)
    

    def get_entity_loss(self, logits, labels):
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return loss

    def get_stereotype_loss(self, stp_logits, stp_labels):
        """
            stp_logits: (batch_size, num_stp_labels)
            stp_labels: (batch_size, num_stp_labels)
            This method calculates the loss for the stereotype classification such that,
            if stp_labels shape is (batch_size, num_stp_labels), then the loss is calculated using cross entropy loss
            else if stp_labels shape is (batch_size, num_stp_labels, k), then the loss is calculated using binary cross entropy loss
        """

        if len(stp_labels.shape) == 1:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(stp_logits, stp_labels)
        else:
            loss_fct = torch.nn.BCELoss()
            loss = loss_fct(stp_logits.float(), stp_labels.float())
        return loss
    

    def get_loss(self, logits, stp_logits, labels, stp_labels, alpha=0.5):
        entity_loss = self.get_entity_loss(logits, labels)
        stp_loss = self.get_stereotype_loss(stp_logits, stp_labels)
        loss = alpha * entity_loss + (1 - alpha) * stp_loss
        return loss    

In [319]:
model = TripleClassifier(train_dataset.num_labels, train_dataset.num_stereotype_labels, model_name='xlm-roberta-base')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model.to(device)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


TripleClassifier(
  (model): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
         

In [320]:
for batch, entity_labels, stereotype_labels in train_dataloader:
    print(batch['input_ids'].shape, batch['attention_mask'].shape, entity_labels.shape, stereotype_labels.shape)
    ### Start training loop
    optimizer.zero_grad()
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    entity_labels = entity_labels.to(device)
    stereotype_labels = stereotype_labels.to(device)

    if stereotype_labels.shape[-1] == 1:
        stp_mask = (stereotype_labels != -1)
    

    logits, stp_logits = model(input_ids, attention_mask)
    loss = model.get_loss(logits, stp_logits[stp_mask], entity_labels, stereotype_labels[stp_mask])
    loss.backward()
    optimizer.step()
    scheduler.step()
    break

torch.Size([4, 38]) torch.Size([4, 38]) torch.Size([4]) torch.Size([4, 8181])


KeyboardInterrupt: 