In [19]:
%matplotlib inline

import plotly.express as px


import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
from dgl.nn.pytorch import GraphConv

from collections import Counter

In [2]:
class GCN(nn.Module):
    """Bag-of-embeddings classifier: mean-pooled token embeddings -> linear layer.

    Despite the name, this version applies no graph-convolution or dropout
    stage; the forward pass is only ``apply_embs`` followed by ``fc1``.

    Args:
        in_feats: Width of each token embedding and input width of ``fc1``.
        n_classes: Number of logits produced per sample.
        activation: Accepted for interface compatibility; currently unused.
        n_tokens: Number of rows in the embedding table. Must be an integer in
            practice: ``nn.Embedding`` cannot be built from the ``None`` default.
        pad_ix: Id of the padding token, or ``None`` if inputs carry no padding.
    """

    def __init__(self,
                 in_feats,
                 n_classes,
                 activation=F.relu,
                 n_tokens=None,
                 pad_ix=None):
        super().__init__()

        self.pad_ix = pad_ix
        self.n_tokens = n_tokens

        # With padding_idx set, nn.Embedding initialises that row to zeros and
        # never updates it, so padding contributes nothing to the pooled sum.
        self.emb = nn.Embedding(n_tokens, in_feats, padding_idx=pad_ix)

        self.fc1 = nn.Linear(in_feats, n_classes)

    def apply_embs(self, features):
        """Embed token ids and average them over the non-padding positions.

        Args:
            features: Integer tensor of token ids, indexed as (sample, position).

        Returns:
            Tensor with one ``in_feats``-wide row per sample: the sum of the
            token embeddings divided by the number of non-padding tokens.
        """
        if self.pad_ix is None:
            # No padding token configured: every position is a real token.
            # (Comparing a tensor with None does not yield a mask, so this
            # case needs its own branch.)
            seq_len = features.new_full((features.shape[0], 1), features.shape[1])
        else:
            seq_len = torch.sum(features != self.pad_ix, dim=1, keepdim=True)

        # Guard against division by zero for rows made only of padding.
        seq_len = seq_len.clamp(min=1)

        h = self.emb(features)
        return h.sum(dim=1) / seq_len

    def forward(self, inx):
        """Return class logits for a batch of token-id rows."""
        h = self.apply_embs(inx)
        return self.fc1(h)

In [88]:
def as_matrix(sequences, token_to_id, unk_ix, pad_ix, max_len=None):
    """Convert texts or token lists into a left-padded ``int32`` id matrix.

    Args:
        sequences: List of token lists, or list of strings (split on
            whitespace). Only the first element is inspected to decide which.
        token_to_id: Mapping from token to integer id.
        unk_ix: Id used for tokens missing from ``token_to_id``.
        pad_ix: Id used to fill unused positions.
        max_len: Optional cap on the number of columns; longer sequences keep
            their first ``max_len`` tokens. ``None`` (or 0) means no cap.

    Returns:
        ``np.ndarray`` of dtype ``int32`` with one row per sequence. Ids are
        right-aligned, so padding sits at the start of each row. An empty
        ``sequences`` yields an empty ``(0, 0)`` matrix.
    """
    if len(sequences) == 0:
        # Nothing to encode; avoid indexing sequences[0] / max() on empty input.
        return np.full((0, 0), pad_ix, dtype=np.int32)

    if isinstance(sequences[0], str):
        sequences = [text.split() for text in sequences]

    longest = max(map(len, sequences))
    width = min(longest, max_len) if max_len else longest

    matrix = np.full((len(sequences), width), pad_ix, dtype=np.int32)
    for i, seq in enumerate(sequences):
        row_ix = [token_to_id.get(word, unk_ix) for word in seq[:width]]
        if row_ix:
            # Right-align: a negative start index leaves the padding in front.
            matrix[i, -len(row_ix):] = row_ix

    return matrix


def texts_to_ind(texts):
    """Build a vocabulary from whitespace-tokenised texts and encode them.

    Ids 0 and 1 are reserved for the special tokens ``"UNK"`` and ``"PAD"``;
    corpus tokens follow in order of first appearance.

    Args:
        texts: List of strings.

    Returns:
        Tuple ``(matrix, token_to_id)`` where ``matrix`` is the padded id
        matrix produced by ``as_matrix`` and ``token_to_id`` maps every token
        (including the two special ones) to its id.
    """
    unk, pad = "UNK", "PAD"

    counter = Counter()
    for text in texts:
        counter.update(text.split())

    # A literal "UNK"/"PAD" in the corpus must not be appended again: the
    # duplicate would overwrite the reserved id in the dict below and leave
    # id 0 or 1 unused.
    corpus_tokens = [t for t in counter if t not in (unk, pad)]
    unique_tokens = [unk, pad] + corpus_tokens

    token_to_id = {t: i for i, t in enumerate(unique_tokens)}

    matrix = as_matrix(texts, token_to_id, token_to_id[unk], token_to_id[pad])

    return matrix, token_to_id

In [93]:
# Toy corpus of six single-token texts; `target` holds the class label of each
# text in the same order (texts ending in "1" -> 0, ending in "2" -> 1).
texts = ['aaaa1', 'bbbb1', 'cccc1', 'aaaa2', 'bbbb2', 'cccc2']
target = [0, 0, 0, 1, 1, 1]

# Encode the texts as an id matrix and keep the token -> id vocabulary.
matrix, token_to_ind =  texts_to_ind(texts)

In [94]:
# Inspect the encoded ids: one row per text (ids 0 and 1 are reserved for UNK/PAD).
matrix

array([[2],
       [3],
       [4],
       [5],
       [6],
       [7]], dtype=int32)

In [102]:
# Seed the global RNG so the randomly initialised weights (and the plots and
# losses derived from them) are reproducible after Restart & Run All.
torch.manual_seed(0)

# Take the vocabulary size and the padding id from the vocabulary itself
# instead of hard-coding them, so this cell stays correct if `texts` changes.
model = GCN(in_feats=8,
            n_classes=2,
            activation=F.relu,
            pad_ix=token_to_ind["PAD"],
            n_tokens=len(token_to_ind))

# Snapshot of the untrained embedding table as a NumPy array, one row per token.
embs = model.emb.weight.detach().cpu().numpy()
embs.shape

(8, 8)

In [103]:
# Scatter the first two embedding dimensions for every vocabulary entry.
# The two leading -1 colour values stand for the special UNK/PAD tokens,
# which have no class label; the remaining colours are the text labels.
fig = px.scatter(x=embs[:, 0],
                 y=embs[:, 1],
                 color=[-1, -1] + target,
                 text=list(token_to_ind.keys()),
                 size_max=60)

fig.update_traces(textposition='top center',
                  marker=dict(size=12,
                              line=dict(width=2, color='DarkSlateGrey')))

# A figure should stand on its own when the notebook is skimmed.
fig.update_layout(title='Token embeddings (initial weights)',
                  xaxis_title='embedding dim 0',
                  yaxis_title='embedding dim 1')
fig.show()

In [104]:
# Encoded texts and class labels as int64 tensors (copied from the NumPy
# matrix / Python list).
matrix_t = torch.tensor(matrix, dtype=torch.long)
target_t = torch.tensor(target, dtype=torch.long)

# Multi-class objective computed on the raw logits returned by the model.
loss_fcn = nn.CrossEntropyLoss()

# Adam over every trainable parameter of the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [117]:
# One optimisation step on the full toy batch; re-run this cell to take
# further steps.
optimizer.zero_grad()  # clear gradients accumulated by the previous step

logits = model(matrix_t)
loss = loss_fcn(logits, target_t)
loss.backward()
optimizer.step()

# Loss measured before this step's parameter update.
print(loss.item())

0.7039206027984619


In [118]:
# Inspect the classifier weights: one row per class, one column per embedding dimension.
model.fc1.weight

Parameter containing:
tensor([[ 0.1257, -0.1868, -0.2718, -0.0465,  0.2113,  0.2253,  0.1248, -0.1106],
        [-0.2693, -0.1495,  0.0099,  0.3768, -0.0549,  0.2714, -0.2322,  0.1358]],
       requires_grad=True)

In [119]:
# Re-read the embedding table so the plot reflects the optimisation steps
# taken so far.
embs = model.emb.weight.detach().cpu().numpy()

# Same view as the initial plot: first two embedding dimensions per token,
# with -1 as the colour value for the unlabeled UNK/PAD entries.
fig = px.scatter(x=embs[:, 0],
                 y=embs[:, 1],
                 color=[-1, -1] + target,
                 text=list(token_to_ind.keys()),
                 size_max=60)

fig.update_traces(textposition='top center',
                  marker=dict(size=12,
                              line=dict(width=2, color='DarkSlateGrey')))

# A figure should stand on its own when the notebook is skimmed.
fig.update_layout(title='Token embeddings (after training steps)',
                  xaxis_title='embedding dim 0',
                  yaxis_title='embedding dim 1')
fig.show()

In [None]:
def train_gcn(dataset,
              test_ratio=0.5,
              val_ratio=0.2,
              seed=1,
              n_hidden=64,
              n_epochs=200,
              lr=1e-2,
              weight_decay=5e-4,
              dropout=0.5,
              use_embs=False,
              verbose=True,
              cuda=False,
              mask_rate=0.2):
    """Train a GCN on ``dataset`` and return the score on the test split.

    The model with the best validation score is checkpointed to
    ``best_model.pt`` in the working directory and reloaded before the final
    test evaluation.

    NOTE(review): this function appears to target a different ``GCN`` than the
    class defined earlier in this notebook -- it passes ``g``, ``n_hidden``,
    ``dropout``, ``use_embs`` and ``pretrained_embs``, which that constructor
    does not accept. It also relies on ``get_masks``, ``evaluate``, ``dgl`` and
    ``DGLGraph``, none of which are defined or imported in the visible cells.
    Confirm where these come from before running.

    Args:
        dataset: Object whose ``get_data()`` returns a dict with the keys
            ``'features'``, ``'labels'``, ``'ids'``, ``'main_ids'``,
            ``'main_labels'``, ``'graph'`` and ``'n_classes'``.
        test_ratio: Forwarded to ``get_masks``.
        val_ratio: Forwarded to ``get_masks``.
        seed: Forwarded to ``get_masks``.
        n_hidden: Forwarded to the ``GCN`` constructor.
        n_epochs: Number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        dropout: Forwarded to the ``GCN`` constructor.
        use_embs: If true, ``data['features']`` is unpacked as
            ``(pad_ix, n_tokens, matrix, pretrained_embs)`` and the matrix is
            fed to the model as integer ids; otherwise ``data['features']`` is
            used directly as a float feature matrix.
        verbose: Print per-epoch progress and the final test score.
        cuda: Move tensors and the model to ``cuda:0``.
        mask_rate: Threshold for the per-epoch input masking: every feature
            entry whose uniform(0, 1) draw is not above this value is replaced
            by zero for that training step.

    Returns:
        The value returned by ``evaluate`` on the test mask (called F1 here).
    """
    # `time` is used for epoch timing below but is not imported by the
    # notebook's import cell, so bring it into scope here.
    import time

    data = dataset.get_data()
    if use_embs:
        pad_ix, n_tokens, matrix, pretrained_embs = data['features']
        if pretrained_embs is not None:
            pretrained_embs = torch.FloatTensor(pretrained_embs)
        features = torch.LongTensor(matrix)
    else:
        pad_ix = None
        n_tokens = None
        pretrained_embs = None
        features = torch.FloatTensor(data['features'])

    labels = torch.LongTensor(data['labels'])
    n = len(data['ids'])
    train_mask, val_mask, test_mask = get_masks(n,
                                                data['main_ids'],
                                                data['main_labels'],
                                                test_ratio=test_ratio,
                                                val_ratio=val_ratio,
                                                seed=seed)

    train_mask = torch.BoolTensor(train_mask)
    val_mask = torch.BoolTensor(val_mask)
    test_mask = torch.BoolTensor(test_mask)

    if cuda:
        torch.cuda.set_device("cuda:0")
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    g = DGLGraph(data['graph'])
    g = dgl.transform.add_self_loop(g)
    n_edges = g.number_of_edges()

    # Per-node normalisation deg^-0.5; zero-degree nodes would give inf, so
    # they are reset to 0.
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0

    if cuda:
        norm = norm.cuda()

    g.ndata['norm'] = norm.unsqueeze(1)

    if use_embs:
        # NOTE(review): 300 / 64 are presumably the pretrained and trainable
        # embedding widths -- TODO confirm against the embedding source.
        if pretrained_embs is not None:
            in_feats = 300
        else:
            in_feats = 64
    else:
        in_feats = features.shape[1]

    # + 1 for unknown class
    n_classes = data['n_classes'] + 1
    model = GCN(g,
                in_feats=in_feats,
                n_hidden=n_hidden,
                n_classes=n_classes,
                activation=F.relu,
                dropout=dropout,
                use_embs=use_embs,
                pretrained_embs=pretrained_embs,
                pad_ix=pad_ix,
                n_tokens=n_tokens)

    if cuda:
        model.cuda()

    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)

    best_f1 = -100
    # Durations of the timed epochs; the first three epochs are not timed.
    dur = []
    for epoch in range(n_epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        # Draw a fresh uniform(0, 1) value per feature entry and zero out the
        # entries whose draw does not exceed `mask_rate`.
        mask_probs = torch.empty(features.shape).uniform_(0, 1)
        if cuda:
            mask_probs = mask_probs.cuda()

        mask_features = torch.where(mask_probs > mask_rate, features, torch.zeros_like(features))
        logits = model(mask_features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        # Validation uses the unmasked features.
        f1 = evaluate(model, features, labels, val_mask)

        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), 'best_model.pt')

        if verbose:
            # `dur` is still empty during the untimed warm-up epochs; report
            # NaN explicitly instead of letting np.mean([]) emit a
            # RuntimeWarning.
            mean_dur = np.mean(dur) if dur else float('nan')
            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | F1 {:.4f} | "
                  "ETputs(KTEPS) {:.2f}".format(epoch, mean_dur, loss.item(),
                                                f1, n_edges / mean_dur / 1000))

    # Restore the checkpoint with the best validation score before testing.
    model.load_state_dict(torch.load('best_model.pt'))
    f1 = evaluate(model, features, labels, test_mask)

    if verbose:
        print()
        print("Test F1 {:.2}".format(f1))

    return f1