# AGNews Graph version

### Sources
- Mátyás Reuters Graph notebook

In [3]:
import torch
import numpy as np

from data_prep.agnews_graph import AGNewsGraph

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
agnews = AGNewsGraph(device, train_doc=1000)

Prepare AGNews dataset


Using custom data configuration default
Reusing dataset ag_news (/home/mat/.cache/huggingface/datasets/ag_news/default/0.0.0/0eeeaaa5fb6dffd81458e293dfea1adba2881ffcbdc3fb56baeb5a892566c29a)


Compute tf.idf
Compute PMI scores
Generate edges
Generate masks
Generate feature matrix
Features mtx is 0.5503716000000001 GBs in size


In [15]:
print(np.unique(agnews.data.y[agnews.data.train_mask].cpu(), return_counts=True))
print(np.unique(agnews.data.y[agnews.data.val_mask].cpu(), return_counts=True))

(array([0, 1, 2, 3]), array([269, 222, 241, 268]))
(array([0, 1, 2, 3]), array([17, 23, 30, 30]))


In [16]:
print(sum(agnews.data.val_mask * agnews.data.train_mask * agnews.data.test_mask))
print(sum(agnews.data.train_mask))
print(sum(agnews.data.val_mask))
print(sum(agnews.data.test_mask))

tensor(0, device='cuda:0')
tensor(1000, device='cuda:0')
tensor(100, device='cuda:0')
tensor(100, device='cuda:0')


In [7]:
# GraphConv, GATConv
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(len(agnews.iton), 200)
        self.conv2 = GCNConv(200, 8)

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        return x

def eval(model, data, mask):
    _, pred = model(data).max(dim=1)
    #print(pred[mask])
    #print(data.y[mask])
    correct = pred[mask].eq(data.y[mask]).sum().item()
    acc = correct / mask.sum()
    print('Accuracy: {:.4f}'.format(acc))

In [8]:
model = Net()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)#, weight_decay=5e-4)
agnews.data.to(device)

Data(edge_attr=[694792], edge_index=[2, 694792], test_mask=[11730], train_mask=[11730], val_mask=[11730], x=[11730, 11730], y=[11730])

In [9]:
from tqdm.notebook import tqdm

model.train()
for epoch in tqdm(range(40)):
    optimizer.zero_grad()
    out = model(agnews.data)
    # We might want to use the "weight" parameter for the loss with unbalanced dataset
    # since with a low learning rate the model just assigns every doc to class "earn"
    # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
    loss = F.cross_entropy(out[agnews.data.train_mask], agnews.data.y[agnews.data.train_mask])
    print('Loss:', loss.item())
    loss.backward()
    optimizer.step()
    print(np.unique(out[agnews.data.train_mask].max(dim=1)[1].detach().cpu().numpy(), return_counts=True))
    #print(np.unique(r8.data.y[r8.data.train_mask].detach().cpu().numpy(), return_counts=True))
    eval(model, agnews.data, agnews.data.val_mask)

HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))

Loss: 2.079390048980713
(array([0, 1, 2, 3, 4, 5, 6, 7]), array([ 29, 125,  97, 345, 129, 128,  90,  57]))
Accuracy: 0.4600
Loss: 1.6133803129196167
(array([0, 1, 2, 3]), array([484, 119,  73, 324]))
Accuracy: 0.4000
Loss: 1.3256406784057617
(array([0, 1, 2, 3]), array([560,   3,  20, 417]))
Accuracy: 0.4500
Loss: 1.1572171449661255
(array([0, 1, 2, 3]), array([372, 120, 140, 368]))
Accuracy: 0.7300
Loss: 0.9445360898971558
(array([0, 1, 2, 3]), array([263, 232, 243, 262]))
Accuracy: 0.6600
Loss: 0.6910749077796936
(array([0, 1, 2, 3]), array([267, 222, 232, 279]))
Accuracy: 0.7600
Loss: 0.44048789143562317
(array([0, 1, 2, 3]), array([269, 222, 246, 263]))
Accuracy: 0.8300
Loss: 0.268040269613266
(array([0, 1, 2, 3]), array([266, 224, 241, 269]))
Accuracy: 0.8100
Loss: 0.14889079332351685
(array([0, 1, 2, 3]), array([270, 223, 243, 264]))
Accuracy: 0.8200
Loss: 0.08041746914386749
(array([0, 1, 2, 3]), array([268, 222, 241, 269]))
Accuracy: 0.7800
Loss: 0.049951814115047455
(array([0,

In [10]:
eval(model, agnews.data, agnews.data.test_mask)

Accuracy: 0.8600
