In [35]:
import math
import random
import os.path as osp


import os
import pandas as pd

from torch import Tensor

import numpy as np

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import DataLoader

from tqdm import tqdm
from datetime import datetime

In [25]:
root = os.getcwd()

In [2]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [81]:
class Args:
    num_ent = 2708
    embed_dim = 1500
    inp_drop = 0
    feat_drop = 0 # 0.5
    hid_drop = 0 # 0.5
    perm = 1
    k_w = 30
    k_h = 50
    num_filt = 96
    ker_sz = 9

In [82]:
class Dataset(BaseDataset):
    
    def __init__(self, root, split="train"):
        
        # Read node features
        content = pd.read_csv(os.path.join(root, "content.csv"), delimiter="\t", header=None)
        content = content.sort_values(by=[0]).loc[:, 1:].to_numpy()
#         content = torch.from_numpy(content)
        self.features = content
        
        
        self.links = None
        if split == "train":
            self.links = pd.read_csv(os.path.join(root, "train.csv"))
        else:
            self.links = pd.read_csv(os.path.join(root, "test.csv"))
            
        tmp = pd.DataFrame(data={
                                "from": self.links["to"], 
                                "to": self.links["from"],
                                "label": self.links["label"]})
        self.links = pd.concat([self.links, tmp], ignore_index=True)
        self.links = self.links.drop_duplicates()
        
    def __len__(self):
        return len(self.links)
        
    def __getitem__(self, idx):
        
        row = self.links.iloc[idx]

        follower, followee, label = row["from"], row["to"], row["label"]

        follower_feat = self.features[follower]
        followee_feat = self.features[followee]
        
        return follower, followee, follower_feat, followee_feat, label

In [83]:
def get_chequer_perm(perm=1, k_w=30, k_h=50):
    """
    Function to generate the chequer permutation required for InteractE model
    Parameters
    ----------

    Returns
    -------

    """
    embed_dim = k_w * k_h
    ent_perm  = np.int32([np.random.permutation(embed_dim) for _ in range(perm)])
#     print(ent_perm)
    rel_perm  = np.int32([np.random.permutation(embed_dim) for _ in range(perm)])
#     print(rel_perm)

    comb_idx = []
    for k in range(perm):
        temp = []
        ent_idx, rel_idx = 0, 0

        for i in range(embed_dim):
            if k % 2 == 0:
                if i % 2 == 0:
                    temp.append(ent_perm[k, ent_idx]); ent_idx += 1;
                    temp.append(rel_perm[k, rel_idx]+embed_dim); rel_idx += 1;
                else:
                    temp.append(rel_perm[k, rel_idx]+embed_dim); rel_idx += 1;
                    temp.append(ent_perm[k, ent_idx]); ent_idx += 1;
            else:
                if i % 2 == 0:
                    temp.append(rel_perm[k, rel_idx]+embed_dim); rel_idx += 1;
                    temp.append(ent_perm[k, ent_idx]); ent_idx += 1;
                else:
                    temp.append(ent_perm[k, ent_idx]); ent_idx += 1;
                    temp.append(rel_perm[k, rel_idx]+embed_dim); rel_idx += 1;

        comb_idx.append(temp)

    chequer_perm = torch.LongTensor(np.int32(comb_idx))
    return chequer_perm


In [84]:
class InteractE(torch.nn.Module):
    """
    Proposed method in the paper. Refer Section 6 of the paper for mode details 
    Parameters
    ----------
    params:        	Hyperparameters of the model
    chequer_perm:   Reshaping to be used by the model

    Returns
    -------
    The InteractE model instance

    """
    def __init__(self, params, chequer_perm):
        super(InteractE, self).__init__()

        self.p = params
        self.ent_embed = torch.nn.Embedding(self.p.num_ent, self.p.embed_dim-1433, padding_idx=None); # xavier_normal_(self.ent_embed.weight)
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.bn0 = torch.nn.BatchNorm2d(self.p.perm)

        flat_sz_h  = self.p.k_h
        flat_sz_w  = 2*self.p.k_w
        self.padding  = 0

        self.bn1  = torch.nn.BatchNorm2d(self.p.num_filt*self.p.perm)
        self.flat_sz  = flat_sz_h * flat_sz_w * self.p.num_filt*self.p.perm

        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
#         self.fc  = torch.nn.Linear(self.flat_sz, self.p.embed_dim)
        self.fc = torch.nn.Linear(self.flat_sz, 1)
        self.chequer_perm = chequer_perm

#         self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent)))
        self.register_parameter('conv_filt', Parameter(torch.zeros(self.p.num_filt, 1, self.p.ker_sz,  self.p.ker_sz))); # xavier_normal_(self.conv_filt)

    def circular_padding_chw(self, batch, padding):
        upper_pad = batch[..., -padding:, :]
        lower_pad = batch[..., :padding, :]
        temp = torch.cat([upper_pad, batch, lower_pad], dim=2)

        left_pad = temp[..., -padding:]
        right_pad = temp[..., :padding]
        padded = torch.cat([left_pad, temp, right_pad], dim=3)
        return padded

    def forward(self, follower, followee, follower_id, followee_id, strategy='one_to_x'):
#         sub_emb = self.ent_embed(sub)
#         rel_emb = self.rel_embed(rel)

        follower_emb = self.ent_embed(follower_id)
        if isinstance(follower, np.ndarray):
            follower = torch.from_numpy(follower)
        follower_emb = torch.cat((follower, follower_emb), dim=-1)
    
#         print(follower_emb.size())
        
        followee_emb = self.ent_embed(followee_id)
        if isinstance(followee, np.ndarray):
            followee = torch.from_numpy(followee)
        followee_emb = torch.cat((followee, followee_emb), dim=-1)
        
        comb_emb = torch.cat([follower_emb, followee_emb], dim=-1)
        chequer_perm = comb_emb[:, self.chequer_perm] # batch, 1, embed_size
        
#         print(chequer_perm)
#         print(chequer_perm.size())
        
        stack_inp = chequer_perm.reshape((-1, self.p.perm, 2*self.p.k_w, self.p.k_h)) # batch, 1, 2*k_w, k_h
        
#         print(stack_inp.size())
        
        stack_inp = self.bn0(stack_inp)
        
        x = self.inp_drop(stack_inp)
        x = self.circular_padding_chw(x, self.p.ker_sz//2)
#         print(x.size())
        x = F.conv2d(x, self.conv_filt.repeat(self.p.perm, 1, 1, 1), padding=self.padding, groups=self.p.perm)
#         print(x.size())
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x) # batch, flat_sz
        
#         print(x.size())

#         x = self.hidden_drop(x)
#         x = self.bn2(x)
#         x = F.relu(x)

#         if strategy == 'one_to_n':
#             x = torch.mm(x, self.ent_embed.weight.transpose(1,0))
#             x += self.bias.expand_as(x)
#         else:
#             x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1)
#             x += self.bias[neg_ents]

        pred = torch.sigmoid(x)

        return pred


In [43]:
chequer = get_chequer_perm()

In [86]:
def load_checkpoint(filepath, device):
    
    model = InteractE(Args(), chequer).to(device)

    if os.path.exists(filepath):
        print("pretrained finded")
        checkpoint = torch.load(filepath)
        model.load_state_dict(checkpoint['model_stat'])
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        optimizer.load_state_dict(checkpoint['optimizer_stat'])

    else:
        print("use a new optimizer")
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    return model, optimizer

In [87]:
trainset = Dataset(root=osp.join(root, "hw2_data", "dataset1", "raw"))

In [88]:
trainloader = DataLoader(trainset, batch_size=32, shuffle=False, num_workers=2)

In [89]:
model, optimizer = load_checkpoint(os.path.join(root, "results", "dataset1", "conv", "weight.pth"), device=device)

use a new optimizer


In [90]:
def train(model, optim, loader, device):
    
    model.train()
    model = model.to(device)
    
    total_loss = 0
    for index, data in tqdm(enumerate(loader)):
        
        optim.zero_grad()
        
        follower, followee, follower_feat, followee_feat, label = data
        
        follower = follower.to(device)
        followee = followee.to(device)
        follower_feat = follower_feat.to(device)
        followee_feat = followee_feat.to(device)
        label = label.to(device).float()

        pred = model(follower_feat, followee_feat, follower, followee) 
        pred = torch.squeeze(pred)
        
        loss = torch.nn.BCELoss()(pred, label)
#         print(pred, label)
        loss.backward()
        optim.step()
        
        total_loss += loss.item()

    total_loss = total_loss/(index+1)
    return total_loss

In [91]:
@torch.no_grad()
def test(model, loader, device="cpu"):
    
    from sklearn.metrics import roc_auc_score
    
    model.eval()
    model = model.to(device)
    
    y_pred, y_true = [], []
    for data in tqdm(loader):
        follower, followee, follower_feat, followee_feat, label = data
        
        follower = follower.to(device)
        followee = followee.to(device)
        follower_feat = follower_feat.to(device)
        followee_feat = followee_feat.to(device)
        label = label.to(device).float()
        
        pred = model(follower_feat, followee_feat, follower, followee) 
        pred = torch.squeeze(pred)
        
        y_pred.append(pred.view(-1).cpu())
        y_true.append(label.view(-1).cpu())
        
#     print(torch.cat(y_true), torch.cat(y_pred))

    return roc_auc_score(torch.cat(y_true), torch.cat(y_pred))

In [None]:
auc_max = 0
for epoch in range(100):

    loss = train(model, optimizer, trainloader, device)
    auc = test(model, trainloader)
   
    if auc > auc_max:
        auc_max = auc
        
        checkpoint = {
            'model_stat': model.state_dict(),
            'optimizer_stat': optimizer.state_dict(),
        }
        
        torch.save(checkpoint, os.path.join(root, "results", "dataset1", "conv", "weight.pth"))
        print("\nSave Model")
    
    print("Epoch: {}, loss={}, auc={}".format(epoch+1, loss, auc))
    

543it [00:03, 138.04it/s]
100%|██████████| 543/543 [00:22<00:00, 23.93it/s]


Save Model
Epoch: 1, loss=0.6931515957330033, auc=0.5



543it [00:03, 137.35it/s]
100%|██████████| 543/543 [00:23<00:00, 23.23it/s]

Epoch: 2, loss=0.693149410668438, auc=0.5



543it [00:03, 143.23it/s]
100%|██████████| 543/543 [00:22<00:00, 24.48it/s]

Epoch: 3, loss=0.6931480378955328, auc=0.5



543it [00:03, 137.55it/s]
100%|██████████| 543/543 [00:22<00:00, 23.94it/s]

Epoch: 4, loss=0.693147045582478, auc=0.5



543it [00:03, 140.79it/s]
100%|██████████| 543/543 [00:22<00:00, 23.69it/s]

Epoch: 5, loss=0.693146312873447, auc=0.5



543it [00:03, 140.47it/s]
100%|██████████| 543/543 [00:22<00:00, 24.32it/s]

Epoch: 6, loss=0.6931457732483388, auc=0.5



543it [00:03, 138.00it/s]
100%|██████████| 543/543 [00:22<00:00, 24.44it/s]

Epoch: 7, loss=0.6931453705053523, auc=0.5



543it [00:03, 138.68it/s]
100%|██████████| 543/543 [00:21<00:00, 25.01it/s]

Epoch: 8, loss=0.6931450694085923, auc=0.5



543it [00:03, 140.53it/s]
100%|██████████| 543/543 [00:22<00:00, 24.53it/s]

Epoch: 9, loss=0.6931448461381551, auc=0.5



543it [00:03, 136.34it/s]
100%|██████████| 543/543 [00:22<00:00, 24.61it/s]

Epoch: 10, loss=0.6931446772034436, auc=0.5



543it [00:03, 135.91it/s]
100%|██████████| 543/543 [00:21<00:00, 24.78it/s]

Epoch: 11, loss=0.6931445460293174, auc=0.5



543it [00:03, 136.35it/s]
100%|██████████| 543/543 [00:22<00:00, 23.65it/s]

Epoch: 12, loss=0.6931444487738565, auc=0.5



543it [00:03, 141.42it/s]
100%|██████████| 543/543 [00:22<00:00, 23.64it/s]

Epoch: 13, loss=0.6931443754480688, auc=0.5



543it [00:03, 138.16it/s]
100%|██████████| 543/543 [00:22<00:00, 23.92it/s]

Epoch: 14, loss=0.693144314965271, auc=0.5



543it [00:03, 137.34it/s]
100%|██████████| 543/543 [00:22<00:00, 23.69it/s]

Epoch: 15, loss=0.6931442748995337, auc=0.5



543it [00:03, 139.87it/s]
100%|██████████| 543/543 [00:22<00:00, 23.69it/s]

Epoch: 16, loss=0.6931442442739427, auc=0.5



543it [00:04, 128.71it/s]
100%|██████████| 543/543 [00:21<00:00, 24.97it/s]

Epoch: 17, loss=0.6931442098064318, auc=0.5



543it [00:03, 143.89it/s]
100%|██████████| 543/543 [00:21<00:00, 24.96it/s]

Epoch: 18, loss=0.6931441888405253, auc=0.5



543it [00:03, 136.30it/s]
100%|██████████| 543/543 [00:22<00:00, 23.99it/s]

Epoch: 19, loss=0.6931441723751539, auc=0.5



543it [00:03, 146.54it/s]
100%|██████████| 543/543 [00:22<00:00, 23.90it/s]

Epoch: 20, loss=0.6931441581051653, auc=0.5



543it [00:04, 134.23it/s]
100%|██████████| 543/543 [00:22<00:00, 24.23it/s]

Epoch: 21, loss=0.6931441461403286, auc=0.5



543it [00:04, 132.56it/s]
100%|██████████| 543/543 [00:22<00:00, 24.07it/s]

Epoch: 22, loss=0.6931441388955651, auc=0.5



543it [00:03, 137.21it/s]
100%|██████████| 543/543 [00:22<00:00, 24.50it/s]

Epoch: 23, loss=0.6931441300042646, auc=0.5



543it [00:03, 144.42it/s]
100%|██████████| 543/543 [00:22<00:00, 24.68it/s]

Epoch: 24, loss=0.693144124954884, auc=0.5



543it [00:03, 140.85it/s]
 21%|██        | 115/543 [00:04<00:17, 25.02it/s]