# MARA - IMDB_mlh dataset tests - by Bartosz Trojan
The implementation will be based on the official MARA paper
Right now I don't have much to show, but this notebook will be updated

## Imports and data preprocessing

In [1]:
# os.environ['TORCH'] = torch.__version__
# print(torch.__version__)

# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
import os
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from utils.read_data_new import IMDB_mlh

imdb = IMDB_mlh()
imdb.info()

IMDB movie type dataset:
 Number of nodes: 2807
 Number of edges: layer1: 752, layer2: 2276
 Number of features: 1000
 Number of classes: 3
 Number of nodes per class: tensor([ 320, 1219, 1268])


## Model architecture

In [None]:
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(imdb.num_features, 512)
        self.conv2 = GCNConv(512, 256)
        self.conv3 = GCNConv(256, 52)
        self.classifier = Linear(52, imdb.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # Final GNN embedding space.

        # Apply a final (linear) classifier.
        out = torch.sigmoid(self.classifier(h))

        return out, h

model = GCN()
print(model)

GCN(
  (conv1): GCNConv(1000, 512)
  (conv2): GCNConv(512, 256)
  (conv3): GCNConv(256, 52)
  (classifier): Linear(in_features=52, out_features=3, bias=True)
)


## Simple model training

In [None]:
model = GCN()

out, h = model(imdb.node_features, imdb.layer_1.t())

print(out.shape)
print(h.shape)

torch.Size([2807, 3])
torch.Size([2807, 52])


In [None]:
# tymczasowo dla przyśpieszenia testów

from config import config
import torch

class MARA():
    def __init__(self, simplificaton_type=config["simplification_type"], simplification_stages=config["simplification_stages"], simplification_strategy=config["simplification_strategy"], DE_p=config["DE_p"], NS_k=config["NS_k"]):
        self.simplification_type = simplificaton_type
        self.simplification_stages = simplification_stages
        self.simplification_strategy = simplification_strategy
        self.DE_p = DE_p
        self.NS_k = NS_k

    def simplify(self, nodes_for_each_layer, edges_for_each_layer, cross_layer_edges, node_classes):
        if(self.simplification_strategy == "DE"):
            if(self.simplification_type == "l-b-l"):
                simplified = []
                for layer in range(len(edges_for_each_layer)):
                    print(edges_for_each_layer[layer].shape)
                    mask = torch.rand(1, edges_for_each_layer[layer].shape[0]) > self.DE_p
                    simplified.append(edges_for_each_layer[layer][mask.squeeze()].clone())
                    print(simplified[layer].shape)
                return simplified


In [None]:
mara = MARA()

siplified_edges = mara.simplify(imdb.node_features, [imdb.layer_1, imdb.layer_2], [], imdb.classes)

torch.Size([752, 2])
torch.Size([609, 2])
torch.Size([2276, 2])
torch.Size([1811, 2])


In [None]:
model = GCN()
criterion = torch.nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 

def accuracy(preds, labels):
    predicted_labels = torch.argmax(preds, dim=1)
    accuracy = (predicted_labels == labels).float().mean()

    return accuracy

def train(data):
    optimizer.zero_grad()
    out, h = model(data.node_features, data.layer_1.t()) 
    train_mask = data.get_training_mask(mask_size=0.5)

    loss = criterion(out[train_mask], data.classes[train_mask])
    acc = accuracy(out[train_mask], data.classes[train_mask])

    loss.backward()
    optimizer.step()

    return loss, acc

for epoch in range(201):
    loss, acc = train(imdb)
    if (epoch+1)%10 == 0:
        print("======== ",epoch+1," ========")
        print(f"Loss: {loss}")
        print(f"Accuracy: {acc}")

Loss: 0.7814971804618835
Accuracy: 0.783345103263855
Loss: 0.7094255089759827
Accuracy: 0.8564493656158447
Loss: 0.6735900044441223
Accuracy: 0.8836413621902466
Loss: 0.665148913860321
Accuracy: 0.9021126627922058
Loss: 0.6645676493644714
Accuracy: 0.8951048851013184
Loss: 0.6600362658500671
Accuracy: 0.8955672383308411
Loss: 0.6429871320724487
Accuracy: 0.9092229604721069
Loss: 0.650898277759552
Accuracy: 0.9011064767837524
Loss: 0.6468911170959473
Accuracy: 0.9045910835266113
Loss: 0.6388868093490601
Accuracy: 0.9132047295570374
Loss: 0.6359266638755798
Accuracy: 0.9148044586181641
Loss: 0.6373798847198486
Accuracy: 0.9122301936149597
Loss: 0.6442964673042297
Accuracy: 0.9060734510421753
Loss: 0.6262832283973694
Accuracy: 0.9236860871315002
Loss: 0.6194504499435425
Accuracy: 0.9311568737030029
Loss: 0.6366981267929077
Accuracy: 0.913165271282196
Loss: 0.6318746209144592
Accuracy: 0.9183955788612366
Loss: 0.6418827772140503
Accuracy: 0.9066374897956848
Loss: 0.634825587272644
Accuracy