In [1]:
import numpy as np
import os
import pandas as pd
import torch
from gmf import GMF
from data import *
import torch.nn as nn

In [28]:
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Precision, Recall, Loss, RunningAverage
from ignite.contrib.handlers import ProgressBar

In [3]:
USE_GPU = True

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cpu


# Data

In [4]:
path = './data/'

In [5]:
interactions = pd.read_csv(path + 't1_train.csv')
human_feats = np.load(path + 'human_feats.npy')
virus_feats = np.load(path + 'virus_feats.npy')

In [6]:
interactions.head()

Unnamed: 0,node1,node2,edge
0,151,2841,1.0
1,151,2874,1.0
2,151,780,1.0
3,151,1183,1.0
4,155,2346,1.0


In [7]:
human_feats.shape

(7209, 2799)

In [8]:
virus_idxs = sorted(interactions['node1'].unique())
human_idxs = sorted(interactions['node2'].unique())
len(virus_idxs), len(human_idxs)

(172, 5225)

In [9]:
vtoi = {v : i for i, v in enumerate(virus_idxs)}
itov = {i : v for i, v in enumerate(virus_idxs)}
htoi = {h : i for i, h in enumerate(human_idxs)}
itoh = {i : h for i, h in enumerate(human_idxs)}

In [10]:
gen = SingleTaskGenerator(interactions, human_feats, virus_feats, .10)

In [11]:
train_loader = gen.create_train_loader(3)
val_loader = gen.create_val_loader(3)
test_loader = gen.create_test_loader(3)

In [12]:
next(iter(train_loader))

[tensor([[  42, 2564],
         [  45, 4288],
         [   1, 3381]]),
 tensor([[0.5690, 0.6897, 0.8793,  ..., 0.0000, 0.0000, 1.0000],
         [0.4636, 0.6689, 0.5629,  ..., 0.0000, 0.0000, 0.0000],
         [0.5294, 0.9020, 0.6667,  ..., 0.0000, 0.0000, 1.0000]],
        dtype=torch.float64),
 tensor([[0.6515, 0.7576, 0.5909,  ..., 0.0000, 0.0000, 1.0000],
         [0.9375, 0.9375, 0.5000,  ..., 0.0000, 0.0000, 1.0000],
         [0.2273, 0.5303, 0.0758,  ..., 0.0000, 0.0000, 1.0000]],
        dtype=torch.float64),
 tensor([0., 0., 0.], dtype=torch.float64)]

# Model

In [13]:
n_virus, n_human = len(virus_idxs), len(human_idxs)

In [14]:
config = {
    'num_virus': n_virus,
    'num_human': n_human,
    'latent_dim': 2799,
    'sparse': False
}
model = GMF(config)
model.to(device)

GMF(
  (virus): Embedding(172, 2799)
  (human): Embedding(5225, 2799)
)

In [15]:
optimizer = torch.optim.SGD(model.parameters(), 
                            lr = 1e-3,  
                            momentum=0.9, 
                            weight_decay=1e-5)
criterion = nn.MSELoss()

In [16]:
threshhold = .50

### Trainer

In [36]:
debug_loader = gen.create_debug_loader(3)

In [37]:
next(iter(debug_loader))

[tensor([[  56, 2067],
         [  37, 1832],
         [  35,  143]]),
 tensor([[0.3261, 0.4348, 0.3913,  ..., 0.0000, 0.0000, 1.0000],
         [0.5152, 0.8788, 0.6970,  ..., 0.0000, 0.0000, 1.0000],
         [0.9577, 0.9155, 0.9014,  ..., 0.0000, 0.0000, 1.0000]],
        dtype=torch.float64),
 tensor([[0.6122, 0.9796, 0.7551,  ..., 0.0000, 0.0000, 1.0000],
         [0.6939, 0.8163, 0.7959,  ..., 0.0000, 0.0000, 1.0000],
         [0.2667, 0.5333, 0.2667,  ..., 0.0000, 0.0000, 1.0000]],
        dtype=torch.float64),
 tensor([0., 0., 0.], dtype=torch.float64)]

In [38]:
def train_batch(engine, batch):
    model.train()
    optimizer.zero_grad()
    
    x_pairs, human_feats, virus_feats, ys = batch
    v_idxs, h_idxs = x_pairs[:,0], x_pairs[:,1]
    pred = model(h_idxs, v_idxs, human_feats, virus_feats)
#     pred = pred_help(pred)
    loss = criterion(pred, ys)
#     print(loss)
    loss.backward()
    optimizer.step()
        
    return loss.item()

In [39]:
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

In [40]:
pbar = ProgressBar(persist=True)
pbar.attach(trainer, ['loss'])

### train eval 

In [41]:
trainer = Engine(train_batch)
trainer.run(debug_loader)

<ignite.engine.engine.State at 0x7f96f24d1470>

# evaluation

In [84]:
trainer = Engine(train_batch)

In [85]:
# round probabilities
def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

In [86]:
def eval_fn(engine, batch):
    model.eval()
    with torch.no_grad():
        x_pairs, human_feats, virus_feats, ys = batch
        v_idxs, h_idxs = x_pairs[:,0], x_pairs[:,1]
        pred = model(h_idxs, v_idxs, human_feats, virus_feats)
        return pred, ys

In [87]:
train_evaluator = Engine(eval_fn)

In [88]:
Accuracy(output_transform=thresholded_output_transform).attach(train_evaluator, 'accuracy')
Precision(output_transform=thresholded_output_transform).attach(train_evaluator, 'precision')
Recall(output_transform=thresholded_output_transform).attach(train_evaluator, 'recall')
Loss(criterion).attach(train_evaluator, 'loss')

In [89]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    train_evaluator.run(debug_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    prec = metrics['precision']
    rec = metrics['recall']
    pbar.log_message(
        "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f} Prec: {:.2f} Rec: {:.2f}"
        .format(engine.state.epoch, avg_accuracy, avg_loss, prec, rec))

In [90]:
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

In [91]:
pbar = ProgressBar(persist=True)
pbar.attach(trainer, ['loss'])

In [92]:
trainer.run(debug_loader)



[0/4]   0%|           [00:00<?][A[A

Epoch [1/1]: [0/4]   0%|           [00:00<?][A[A

Epoch [1/1]: [0/4]   0%|          , loss=3.33e-05 [00:00<?][A[A

Epoch [1/1]: [1/4]  25%|██▌       , loss=3.33e-05 [00:00<00:00][A[A

Epoch [1/1]: [1/4]  25%|██▌       , loss=3.34e-05 [00:00<00:00][A[A

Epoch [1/1]: [2/4]  50%|█████     , loss=3.34e-05 [00:00<00:00][A[A

Epoch [1/1]: [2/4]  50%|█████     , loss=3.34e-05 [00:00<00:00][A[A

Epoch [1/1]: [2/4]  50%|█████     , loss=3.30e-05 [00:00<00:00][A[A

Epoch [1/1]: [3/4]  75%|███████▌  , loss=3.30e-05 [00:00<00:00][A[A

Epoch [1/1]: [3/4]  75%|███████▌  , loss=3.30e-05 [00:00<00:00][A[A

Epoch [1/1]: [3/4]  75%|███████▌  , loss=3.29e-05 [00:00<00:00][A[A

                                                                   
[A

Epoch [1/1]: [63/2512]   3%|▎         , loss=1.51e-02 [00:59<05:23]
Epoch [1/1]: [4/4] 100%|██████████, loss=3.40e-05 [00:32<00:00][A

Epoch [1/1]: [4/4] 100%|██████████, loss=3.29e-05 [00:00<00:00]

Training Results - Epoch: 1  Avg accuracy: 1.00 Avg loss: 0.00 Prec: 0.00 Rec: 0.00


<ignite.engine.engine.State at 0x7f96f24d1320>