In [47]:
import sys
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch.nn as nn
import plotly.graph_objects as go
from plotly.subplots import make_subplots
sys.path.append('../')
from dataset import EZData

In [48]:
dataset = EZData()

In [49]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 500)
        self.conv2 = GCNConv(500, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return x

In [70]:
model = Net()
model.train()
loss_metric = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=5e-4)
i=0
losses = []
accuracies = []
while i < 1000:
    (x,y) = next(dataset)
    out = model(x)
    optimizer.zero_grad()
    loss = loss_metric(out[0], y.float())
    loss.backward()
    optimizer.step()
    losses.append(loss.detach().numpy())
    
    if i % 10 == 0:
        model.eval()
        j=0
        correct = 0
        total = 0
        while j < 50:
            (x,y) = next(dataset)
            out = model(x)
            out = nn.functional.sigmoid(out[0]) > 0.5
            correct += (out == y)
            total += 1
            j+=1
            
        accuracy = correct/total
        accuracies.append(accuracy.detach().numpy()[0])
        model.train()
    i+=1
    
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=list(range(len(losses))), y=losses,  name='Loss'), secondary_y=False)
fig.add_trace(go.Scatter(x=[i*10 for i in range(len(accuracies))], y=accuracies, name='Accuracy'),secondary_y=True)
fig.update_layout(title='Loss vs. Steps', xaxis_title='Steps')
fig.update_yaxes(title_text="Loss", secondary_y=False)
fig.update_yaxes(title_text="Accuracy", secondary_y=True)
fig.show()


nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.



In [30]:
model.eval()
i=0
correct = 0
total = 0
while i < 50:
    (x,y) = next(dataset)
    out = model(x)
    out = nn.functional.sigmoid(out[0]) > 0.5
    correct += (out == y)
    total += 1
    i+=1


nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.



In [62]:
[i*10 for i in range(len(accuracies))]

In [63]:
losses

[array(0.6889902, dtype=float32),
 array(1.6689287e-06, dtype=float32),
 array(22.37998, dtype=float32),
 array(0., dtype=float32),
 array(5.960463e-07, dtype=float32),
 array(19.176592, dtype=float32),
 array(0.00247969, dtype=float32),
 array(1.5616295e-05, dtype=float32),
 array(20.125502, dtype=float32),
 array(0., dtype=float32),
 array(12.461135, dtype=float32),
 array(0.00010907, dtype=float32),
 array(1.4799483, dtype=float32),
 array(0.30047578, dtype=float32),
 array(0.11398519, dtype=float32),
 array(1.3508232, dtype=float32),
 array(4.311777, dtype=float32),
 array(2.6828356, dtype=float32),
 array(0.3964409, dtype=float32),
 array(1.0295136, dtype=float32),
 array(0.9016221, dtype=float32),
 array(0.00836937, dtype=float32),
 array(1.5429058, dtype=float32),
 array(0.00620333, dtype=float32),
 array(4.932947, dtype=float32),
 array(8.179395, dtype=float32),
 array(15.834789, dtype=float32),
 array(0.6920675, dtype=float32),
 array(8.106199e-06, dtype=float32),
 array(2.622

In [64]:
accuracies

[array([0.44], dtype=float32),
 array([0.56], dtype=float32),
 array([0.54], dtype=float32),
 array([0.36], dtype=float32),
 array([0.38], dtype=float32),
 array([0.54], dtype=float32),
 array([0.62], dtype=float32),
 array([0.54], dtype=float32),
 array([0.52], dtype=float32),
 array([0.52], dtype=float32),
 array([0.76], dtype=float32),
 array([0.54], dtype=float32),
 array([0.46], dtype=float32),
 array([0.6], dtype=float32),
 array([0.54], dtype=float32),
 array([0.62], dtype=float32),
 array([0.58], dtype=float32),
 array([0.58], dtype=float32),
 array([0.66], dtype=float32),
 array([0.46], dtype=float32),
 array([0.72], dtype=float32),
 array([0.58], dtype=float32),
 array([0.58], dtype=float32),
 array([0.6], dtype=float32),
 array([0.44], dtype=float32),
 array([0.46], dtype=float32),
 array([0.58], dtype=float32),
 array([0.48], dtype=float32),
 array([0.56], dtype=float32),
 array([0.5], dtype=float32),
 array([0.48], dtype=float32),
 array([0.74], dtype=float32),
 array([0.7