In [28]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import f1_score, roc_auc_score

import torch as th
import json
import pandas as pd
import numpy as np

from tqdm import tqdm

In [2]:
import sys
HOME = '/srv/home/christinedk/wp_internship/'
sys.path.append(HOME + 'collaboration/')

from utils import load_all

In [3]:
# graph

In [49]:
class EditorGraph(object):
    def __init__(self):
        self.num_editor_features = 10
        self.num_collab_dir_feat = 2

    def construct_graph(self, sample):
        self.editor_nodes = [d['event_user_id'] for d in sample['editor']]
        self.editor_to_ind = {j:i for i,j in enumerate(self.editor_nodes)}
        self.ind_to_editor = {i:j for j,i in self.editor_to_ind.items()}
        self.collab_links_directed = [(self.editor_to_ind[pair['event_user_id']],self.editor_to_ind[pair['event_user_id_r']]) 
                                     for pair in sample['collaboration']['directed']]        
        g = dgl.graph(self.collab_links_directed,num_nodes=len(self.editor_nodes))
        return g

    def format_editor_features(self, sample):
        editor_features_lookup = pd.DataFrame(sample['editor'])\
                                            .set_index('event_user_id').to_dict('index')
        null_dict = {i: 0 for i in range(self.num_editor_features)}
        editor_features = [list(editor_features_lookup.get(i,null_dict).values()) 
                                     for i in self.editor_nodes]
        editor_features = th.tensor([[i if not np.isnan(i) else 0 for i in sample_feat]
                                    for sample_feat in editor_features]) # this is awful, please rewrite
        return editor_features
    
    def format_edge_features(self,sample):
        null_dict = {i: 0 for i in range(self.num_collab_dir_feat)}
        if len(sample['collaboration']['directed']) > 0:
            collab_dir_lookup = pd.DataFrame(sample['collaboration']['directed'])\
                                    .set_index(['event_user_id','event_user_id_r']).to_dict('index')
            features = [list(collab_dir_lookup.get((self.ind_to_editor[i],self.ind_to_editor[j]),null_dict).values())
                                  for i,j in self.collab_links_directed]
            features = th.tensor([[i if not np.isnan(i) else 0 for i in sample_feat]
                                    for sample_feat in features])
            
        else:
            features = th.tensor([list(null_dict.values())])
                            
        return features
    
    def make_graph(self, sample):
        graph = self.construct_graph(sample) 
        if graph.num_edges() > 0:
            graph.ndata['attr'] = self.format_editor_features(sample)
            graph.edata['attr'] = self.format_edge_features(sample)
            graph = dgl.add_self_loop(graph)
            return graph
        else:
            return None
        

    #def format_features(graph):
        

In [5]:
#with open('/srv/home/christinedk/wp_internship/features/editorsfanpov_v2.json','rb') as f:
#    samples_pos = json.load(f)
    
samples_pos = load_all('/srv/home/christinedk/wp_internship/features/editors{}_v2.json')

In [166]:
pos = []
for sample in tqdm(samples_pos):
    egraph = EditorGraph()
    g = egraph.make_graph(sample)
    if g:
        pos.append((g,1))

100%|██████████| 11543/11543 [03:20<00:00, 57.61it/s] 


In [167]:
pos[0][0]

Graph(num_nodes=231, num_edges=8839,
      ndata_schemes={'attr': Scheme(shape=(10,), dtype=torch.float32)}
      edata_schemes={'attr': Scheme(shape=(2,), dtype=torch.float32)})

In [168]:
#with open('/srv/home/christinedk/wp_internship/negative_features/editorsfanpov_v2.json','rb') as f:
#    samples_neg = json.load(f)
samples_neg = load_all('/srv/home/christinedk/wp_internship/negative_features/editors{}_v2.json') 

In [169]:
neg = []
for sample in tqdm(samples_neg):
    egraph = EditorGraph()
    g = egraph.make_graph(sample)
    if g:
        neg.append((g,0))

100%|██████████| 14264/14264 [04:01<00:00, 59.03it/s] 


In [170]:
dataset = pos + neg
np.random.shuffle(dataset)

In [172]:
dataset[0][0]

Graph(num_nodes=8, num_edges=38,
      ndata_schemes={'attr': Scheme(shape=(10,), dtype=torch.float32)}
      edata_schemes={'attr': Scheme(shape=(2,), dtype=torch.float32)})

In [173]:
len(dataset)

25504

In [12]:
# node features only

In [185]:
import dgl.nn.pytorch as dglnn
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
        self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.conv3 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)
        self.dropout = nn.Dropout(p=0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, g, h):
        # Apply graph convolution and activation.
        h = F.relu(self.conv1(g, h))
        h = self.dropout(h)
        h = F.relu(self.conv2(g, h))
        h = self.dropout(h)
        h = F.relu(self.conv3(g, h))
        #h = F.relu(self.conv1(g, h))
        
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = dgl.mean_nodes(g, 'h')
            return self.classify(hg)

In [186]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_test = int(num_examples * 0.2)

test_sampler = SubsetRandomSampler(torch.arange(num_test))
train_sampler = SubsetRandomSampler(torch.arange(num_test, num_examples))

BATCH_SIZE=64
train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=BATCH_SIZE, drop_last=False)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=BATCH_SIZE, drop_last=False)

In [187]:
import torch.nn.functional as F

num_correct = 0
num_tests = 0

model = Classifier(10, 512, 2)
opt = torch.optim.Adam(model.parameters(),lr=0.01)

for epoch in range(30):
    for batched_graph, labels in tqdm(train_dataloader):
        feats = batched_graph.ndata['attr']
        pred = model(batched_graph, feats)
        loss = F.cross_entropy(pred, labels)
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)
        #labels = labels.unsqueeze(1)
        #loss = F.binary_cross_entropy(pred, labels.float())
        opt.zero_grad()
        loss.backward()
        opt.step()
    
    print('Epoch {}    Train accuracy: {}'.format(epoch, num_correct / num_tests))

100%|██████████| 638/638 [00:39<00:00, 16.01it/s]
  0%|          | 1/638 [00:00<01:21,  7.86it/s]

Epoch 0    Train accuracy: 0.5572926877082925


100%|██████████| 638/638 [00:39<00:00, 15.99it/s]
  0%|          | 2/638 [00:00<00:57, 11.09it/s]

Epoch 1    Train accuracy: 0.5579298176828072


100%|██████████| 638/638 [00:40<00:00, 15.76it/s]
  0%|          | 1/638 [00:00<01:41,  6.26it/s]

Epoch 2    Train accuracy: 0.5581912043390185


100%|██████████| 638/638 [00:39<00:00, 16.04it/s]
  0%|          | 1/638 [00:00<01:56,  5.45it/s]

Epoch 3    Train accuracy: 0.5582973926681043


100%|██████████| 638/638 [00:40<00:00, 15.76it/s]
  0%|          | 1/638 [00:00<01:04,  9.83it/s]

Epoch 4    Train accuracy: 0.5578023916879044


100%|██████████| 638/638 [00:39<00:00, 16.13it/s]
  0%|          | 2/638 [00:00<00:43, 14.73it/s]

Epoch 5    Train accuracy: 0.557978827680847


100%|██████████| 638/638 [00:39<00:00, 16.05it/s]
  0%|          | 1/638 [00:00<01:09,  9.22it/s]

Epoch 6    Train accuracy: 0.5580558433920519


100%|██████████| 638/638 [00:39<00:00, 16.14it/s]
  0%|          | 2/638 [00:00<00:53, 11.82it/s]

Epoch 7    Train accuracy: 0.5580707214271712


100%|██████████| 638/638 [00:39<00:00, 15.97it/s]
  0%|          | 1/638 [00:00<01:12,  8.80it/s]

Epoch 8    Train accuracy: 0.5581421943409789


100%|██████████| 638/638 [00:40<00:00, 15.90it/s]
  0%|          | 0/638 [00:00<?, ?it/s]

Epoch 9    Train accuracy: 0.5582483826700647


100%|██████████| 638/638 [00:39<00:00, 15.97it/s]
  0%|          | 2/638 [00:00<00:39, 16.22it/s]

Epoch 10    Train accuracy: 0.55826397676035


100%|██████████| 638/638 [00:39<00:00, 16.01it/s]
  0%|          | 2/638 [00:00<00:55, 11.48it/s]

Epoch 11    Train accuracy: 0.5582892243350976


100%|██████████| 638/638 [00:39<00:00, 16.07it/s]
  0%|          | 3/638 [00:00<00:36, 17.24it/s]

Epoch 12    Train accuracy: 0.5583030476678781


100%|██████████| 638/638 [00:40<00:00, 15.85it/s]
  0%|          | 1/638 [00:00<02:00,  5.28it/s]

Epoch 13    Train accuracy: 0.5583148962388327


100%|██████████| 638/638 [00:39<00:00, 16.01it/s]
  0%|          | 2/638 [00:00<00:35, 17.82it/s]

Epoch 14    Train accuracy: 0.5583349669999347


100%|██████████| 638/638 [00:39<00:00, 16.11it/s]
  0%|          | 2/638 [00:00<00:44, 14.36it/s]

Epoch 15    Train accuracy: 0.5583555920407763


100%|██████████| 638/638 [00:40<00:00, 15.91it/s]
  0%|          | 2/638 [00:00<00:37, 16.80it/s]

Epoch 16    Train accuracy: 0.5583709076651637


100%|██████████| 638/638 [00:39<00:00, 16.11it/s]
  0%|          | 2/638 [00:00<00:35, 18.04it/s]

Epoch 17    Train accuracy: 0.5583872443311769


100%|██████████| 638/638 [00:39<00:00, 16.21it/s]
  0%|          | 2/638 [00:00<00:36, 17.54it/s]

Epoch 18    Train accuracy: 0.5584044408217171


100%|██████████| 638/638 [00:39<00:00, 16.05it/s]
  0%|          | 2/638 [00:00<00:41, 15.17it/s]

Epoch 19    Train accuracy: 0.5584174671633013


100%|██████████| 638/638 [00:39<00:00, 16.10it/s]
  0%|          | 1/638 [00:00<01:27,  7.28it/s]

Epoch 20    Train accuracy: 0.5584292529009252


100%|██████████| 638/638 [00:39<00:00, 16.04it/s]
  0%|          | 0/638 [00:00<?, ?it/s]

Epoch 21    Train accuracy: 0.5584399672078558


100%|██████████| 638/638 [00:39<00:00, 15.98it/s]
  0%|          | 2/638 [00:00<00:38, 16.56it/s]

Epoch 22    Train accuracy: 0.5584497498359231


100%|██████████| 638/638 [00:39<00:00, 15.96it/s]
  0%|          | 0/638 [00:00<?, ?it/s]

Epoch 23    Train accuracy: 0.5584587172449846


100%|██████████| 638/638 [00:39<00:00, 16.01it/s]
  0%|          | 1/638 [00:00<01:18,  8.15it/s]

Epoch 24    Train accuracy: 0.5584669672613213


100%|██████████| 638/638 [00:39<00:00, 16.23it/s]
  0%|          | 1/638 [00:00<01:23,  7.59it/s]

Epoch 25    Train accuracy: 0.5584745826610167


100%|██████████| 638/638 [00:39<00:00, 16.01it/s]
  0%|          | 1/638 [00:00<01:03,  9.97it/s]

Epoch 26    Train accuracy: 0.558481633957031


100%|██████████| 638/638 [00:43<00:00, 14.61it/s]
  0%|          | 2/638 [00:00<00:39, 16.31it/s]

Epoch 27    Train accuracy: 0.5584881815890441


100%|██████████| 638/638 [00:40<00:00, 15.94it/s]
  0%|          | 1/638 [00:00<01:09,  9.18it/s]

Epoch 28    Train accuracy: 0.5584942776602289


100%|██████████| 638/638 [00:39<00:00, 16.01it/s]

Epoch 29    Train accuracy: 0.5584983336600666





In [189]:
predictions = []
y_test = []

for batched_graph, labels in tqdm(test_dataloader):
    feats = batched_graph.ndata['attr']
    pred = model(batched_graph, feats)
    predictions.extend(pred.tolist())
    y_test.extend(labels.tolist())
    
y_pred = np.argmax(predictions,axis=1)
roc = roc_auc_score(y_score=np.array(predictions)[:,1],y_true=y_test)
f1 = f1_score(y_pred=y_pred, y_true=y_test)


print('f1: {} roc auc: {}'.format(f1,roc))

100%|██████████| 160/160 [00:04<00:00, 34.34it/s]

f1: 0.0 roc auc: 0.4993958943595225





In [190]:
predictions

[[0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904377490282059, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084440469742],
 [0.11904376745223999, -0.06564084

In [None]:
# with edge features

In [193]:
class GNNLayer(nn.Module):
    def __init__(self, ndim, edim, hidden_dim): 
        super(GNNLayer, self).__init__()
        self.W_msg = nn.Linear(ndim + edim, hidden_dim)
        self.W_apply = nn.Linear(ndim + hidden_dim, hidden_dim)  

    def message_func(self, edges):
        return {'m': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], dim=1)))}

    def forward(self, g_dgl, nfeats, efeats):
        with g_dgl.local_scope():
            g = g_dgl
            g.ndata['h'] = nfeats 
            g.edata['h'] = efeats  
            g.update_all(self.message_func, fn.sum(msg='m', out='h_neigh'))
            g.ndata['h'] = F.relu(self.W_apply(torch.cat([g.ndata['h'], g.ndata['h_neigh']], dim=1)))
            return g.ndata['h']

class GNN(nn.Module):
    def __init__(self, ndim, edim, hidden_dim, n_classes, dropout):
        super(GNN, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GNNLayer(ndim, edim, hidden_dim))
        self.layers.append(GNNLayer(hidden_dim, edim, hidden_dim))
        self.layers.append(GNNLayer(hidden_dim, edim, hidden_dim))
        self.fc = nn.Linear(hidden_dim, n_classes)   # added
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, nfeats, efeats):
        for i, layer in enumerate(self.layers):
            if i != 0:
                nfeats = self.dropout(nfeats)
            nfeats = layer(g, nfeats, efeats)
        g.ndata['h'] = nfeats                  # added
        h = dgl.mean_nodes(g, 'h')             # added
        return self.fc(h) 

In [194]:
g = pos[0][0]
nfeats = g.ndata['attr']
efeats = g.edata['attr']

In [195]:
model = GNN(10, 2, 64, 2, 0.5)

In [196]:
g

Graph(num_nodes=231, num_edges=8839,
      ndata_schemes={'attr': Scheme(shape=(10,), dtype=torch.float32)}
      edata_schemes={'attr': Scheme(shape=(2,), dtype=torch.float32)})

In [197]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_test = int(num_examples * 0.2)

test_sampler = SubsetRandomSampler(torch.arange(num_test))
train_sampler = SubsetRandomSampler(torch.arange(num_test, num_examples))

BATCH_SIZE=64
train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=BATCH_SIZE, drop_last=False)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=BATCH_SIZE, drop_last=False)

In [198]:
import torch.nn.functional as F

num_correct = 0
num_tests = 0

#model = Classifier(10, 512, 2)
opt = torch.optim.Adam(model.parameters(),lr=0.01)

for epoch in range(30):
    for batched_graph, labels in tqdm(train_dataloader):
        nfeats = batched_graph.ndata['attr']
        efeats = batched_graph.edata['attr']
        pred = model(batched_graph, nfeats, efeats)
        loss = F.cross_entropy(pred, labels)
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)
        #labels = labels.unsqueeze(1)
        #loss = F.binary_cross_entropy(pred, labels.float())
        opt.zero_grad()
        loss.backward()
        opt.step()    
    print('Epoch {}    Train accuracy: {}'.format(epoch, num_correct / num_tests))

100%|██████████| 319/319 [00:38<00:00,  8.27it/s]
  0%|          | 1/319 [00:00<00:33,  9.58it/s]

Epoch 0    Train accuracy: 0.5531268378749264


100%|██████████| 319/319 [00:34<00:00,  9.31it/s]
  1%|          | 2/319 [00:00<00:21, 14.81it/s]

Epoch 1    Train accuracy: 0.5860615565575378


100%|██████████| 319/319 [00:37<00:00,  8.60it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 2    Train accuracy: 0.5982323727373717


100%|██████████| 319/319 [00:36<00:00,  8.82it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 3    Train accuracy: 0.6047956283081749


100%|██████████| 319/319 [00:36<00:00,  8.81it/s]
  0%|          | 1/319 [00:00<00:52,  6.02it/s]

Epoch 4    Train accuracy: 0.6090080376396785


100%|██████████| 319/319 [00:36<00:00,  8.63it/s]
  0%|          | 1/319 [00:00<00:32,  9.80it/s]

Epoch 5    Train accuracy: 0.6117591321963014


100%|██████████| 319/319 [00:37<00:00,  8.54it/s]
  1%|          | 2/319 [00:00<00:22, 13.91it/s]

Epoch 6    Train accuracy: 0.6140532668664408


100%|██████████| 319/319 [00:37<00:00,  8.44it/s]
  1%|          | 2/319 [00:00<00:19, 16.22it/s]

Epoch 7    Train accuracy: 0.6159515291119388


100%|██████████| 319/319 [00:37<00:00,  8.60it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 8    Train accuracy: 0.6176838964037552


100%|██████████| 319/319 [00:36<00:00,  8.75it/s]
  1%|          | 2/319 [00:00<00:22, 14.12it/s]

Epoch 9    Train accuracy: 0.618991374240345


100%|██████████| 319/319 [00:37<00:00,  8.55it/s]
  0%|          | 1/319 [00:00<00:32,  9.83it/s]

Epoch 10    Train accuracy: 0.6198606333873928


100%|██████████| 319/319 [00:36<00:00,  8.73it/s]
  0%|          | 1/319 [00:00<00:43,  7.28it/s]

Epoch 11    Train accuracy: 0.6208014768346076


100%|██████████| 319/319 [00:36<00:00,  8.81it/s]
  0%|          | 1/319 [00:00<00:43,  7.33it/s]

Epoch 12    Train accuracy: 0.6212281151508754


100%|██████████| 319/319 [00:36<00:00,  8.65it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 13    Train accuracy: 0.6219473772649621


100%|██████████| 319/319 [00:35<00:00,  9.05it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 14    Train accuracy: 0.6223583611056656


100%|██████████| 319/319 [00:37<00:00,  8.49it/s]
  0%|          | 1/319 [00:00<00:48,  6.49it/s]

Epoch 15    Train accuracy: 0.622794550088218


100%|██████████| 319/319 [00:35<00:00,  8.94it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 16    Train accuracy: 0.6232341986000438


100%|██████████| 319/319 [00:37<00:00,  8.44it/s]
  1%|          | 2/319 [00:00<00:23, 13.77it/s]

Epoch 17    Train accuracy: 0.6238128689363741


100%|██████████| 319/319 [00:36<00:00,  8.72it/s]
  0%|          | 1/319 [00:00<00:37,  8.51it/s]

Epoch 18    Train accuracy: 0.6242532423982914


100%|██████████| 319/319 [00:36<00:00,  8.68it/s]
  0%|          | 1/319 [00:00<01:01,  5.18it/s]

Epoch 19    Train accuracy: 0.6245000980199961


100%|██████████| 319/319 [00:37<00:00,  8.55it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 20    Train accuracy: 0.6249124821463579


100%|██████████| 319/319 [00:36<00:00,  8.73it/s]
  1%|          | 2/319 [00:00<00:20, 15.70it/s]

Epoch 21    Train accuracy: 0.625180445901873


100%|██████████| 319/319 [00:36<00:00,  8.76it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 22    Train accuracy: 0.6255124741099358


100%|██████████| 319/319 [00:37<00:00,  8.61it/s]
  0%|          | 1/319 [00:00<00:55,  5.70it/s]

Epoch 23    Train accuracy: 0.62581683330066


100%|██████████| 319/319 [00:36<00:00,  8.67it/s]
  0%|          | 1/319 [00:00<00:53,  5.98it/s]

Epoch 24    Train accuracy: 0.62612624975495


100%|██████████| 319/319 [00:36<00:00,  8.64it/s]
  1%|          | 2/319 [00:00<00:22, 13.92it/s]

Epoch 25    Train accuracy: 0.6262968799481248


100%|██████████| 319/319 [00:37<00:00,  8.55it/s]
  0%|          | 0/319 [00:00<?, ?it/s]

Epoch 26    Train accuracy: 0.6264766530890821


100%|██████████| 319/319 [00:37<00:00,  8.51it/s]
  0%|          | 1/319 [00:00<00:39,  8.05it/s]

Epoch 27    Train accuracy: 0.6266645895762736


100%|██████████| 319/319 [00:37<00:00,  8.61it/s]
  1%|          | 2/319 [00:00<00:28, 10.99it/s]

Epoch 28    Train accuracy: 0.6267719649291214


100%|██████████| 319/319 [00:36<00:00,  8.74it/s]

Epoch 29    Train accuracy: 0.6270535189178592





In [199]:
predictions = []
y_test = []

for batched_graph, labels in tqdm(test_dataloader):
    nfeats = batched_graph.ndata['attr']
    efeats = batched_graph.edata['attr']
    pred = model(batched_graph, nfeats, efeats)
    predictions.extend(pred.tolist())
    y_test.extend(labels.tolist())
    
y_pred = np.argmax(predictions,axis=1)
roc = roc_auc_score(y_score=np.array(predictions)[:,1],y_true=y_test)
f1 = f1_score(y_pred=y_pred, y_true=y_test)


print('f1: {} roc auc: {}'.format(f1,roc))

100%|██████████| 80/80 [00:04<00:00, 19.75it/s]

f1: 0.49873488895136353 roc auc: 0.6742754281746945





In [200]:
predictions

[[-0.12900593876838684, -1.1936659812927246],
 [-0.4199986457824707, -0.522402822971344],
 [-0.29522114992141724, -0.052160173654556274],
 [0.22752206027507782, -0.7758320569992065],
 [-0.14670753479003906, -0.167902410030365],
 [0.19554226100444794, -0.6547897458076477],
 [-0.07882875204086304, -0.781407356262207],
 [-0.19589650630950928, -0.2629874348640442],
 [0.24909427762031555, -0.7706218957901001],
 [-0.05590280890464783, -0.7472918033599854],
 [-0.23983091115951538, -0.23361551761627197],
 [-0.2444164752960205, -0.14763537049293518],
 [-0.16304826736450195, -0.4495096802711487],
 [-0.2876926064491272, -0.6273097991943359],
 [-1.5658903121948242, -2.160853862762451],
 [-0.34874236583709717, 0.14272496104240417],
 [-0.07351434230804443, -0.9449288845062256],
 [0.2502477169036865, -0.8115764856338501],
 [-0.3503492474555969, -0.14305095374584198],
 [-0.08220794796943665, -0.8555036783218384],
 [-0.3936017155647278, 0.2178218960762024],
 [-0.08996143937110901, -0.9728506207466125],