In [1]:
import random
from collections import namedtuple

import numpy as np
import torch as th
import dgl

from utils import subgraph_extraction_labeling

Using backend: pytorch


In [16]:
class RottenTomatoDataset(th.utils.data.Dataset):
    def __init__(self, links, g_labels, graph, 
                hop=1, sample_ratio=1.0, max_nodes_per_hop=200):
        self.links = links
        self.g_labels = g_labels
        self.graph = graph 

        self.hop = hop
        self.sample_ratio = sample_ratio
        self.max_nodes_per_hop = max_nodes_per_hop

    def __len__(self):
        return len(self.links[0])

    def __getitem__(self, idx):
        u, v = self.links[0][idx], self.links[1][idx]
        g_label = self.g_labels[idx]

        subgraph = subgraph_extraction_labeling(
            (u, v), self.graph, 
            hop=self.hop, sample_ratio=self.sample_ratio, max_nodes_per_hop=self.max_nodes_per_hop)

        return subgraph, g_label

In [27]:
def collate_rotten_tomato(data):
    g_list, label_list = map(list, zip(*data))
    g = dgl.batch(g_list)
    g_label = th.stack(label_list)
    return g, g_label

# 1. Main()

In [10]:
import easydict

args = easydict.EasyDict({ 
    'data_name':            'rotten',
    'testing':     	        True,
    'device':      	        0,
    'seed':        	        1234,
    'data_test_ratio':      0.1,
    'num_workers':   	    8,
    'data_valid_ratio':     0.2,
    'train_log_interval':   200,
    'valid_log_interval':   10,
    'save_appendix':   	    'debug',
    'hop':   	            1,
    'sample_ratio':    	    1.0,
    'max_nodes_per_hop':    100,
    'edge_dropout':   	    0.2,
    'force_undirected':     False,
    'train_lr':   	        1e-3,
    'train_min_lr':   	    1e-6,
    'train_lr_decay_factor':0.1,
    'train_lr_decay_step':  50,
    'train_epochs':   	    10,
    'batch_size':   	    16,
    'arr_lambda':   	    0.001,
    'num_rgcn_bases':   	4,
    'train_epochs':   	    1
})

In [11]:
random.seed(args.seed)
np.random.seed(args.seed)
th.manual_seed(args.seed)
if th.cuda.is_available():
    th.cuda.manual_seed_all(args.seed)    

In [15]:
import time
from explicit_model_rotten import IGMC
from explicit_data_rotten import RottenTomato

### prepare data and set model
path = './raw_data/rotten_tomato/'
rotten_tomato_r = RottenTomato('rating',    path, testing=args.testing,test_ratio=args.data_test_ratio, valid_ratio=args.data_valid_ratio)
rotten_tomato_s = RottenTomato('sentiment', path, testing=args.testing,test_ratio=args.data_test_ratio, valid_ratio=args.data_valid_ratio)
rotten_tomato_e = RottenTomato('emotion',   path, testing=args.testing,test_ratio=args.data_test_ratio, valid_ratio=args.data_valid_ratio)

Label_type: rating
	Train rating pairs : 216328
	Valid rating pairs : 43266
	Test rating pairs  : 28766
Label_type: sentiment
	Train rating pairs : 216328
	Valid rating pairs : 43266
	Test rating pairs  : 28766
Label_type: emotion
	Train rating pairs : 216328
	Valid rating pairs : 43266
	Test rating pairs  : 28766


In [28]:
train_dataset = RottenTomatoDataset(
    rotten_tomato_r.train_rating_pairs, rotten_tomato_r.train_rating_values, rotten_tomato_r.train_graph, 
    hop=1, sample_ratio=1.0, max_nodes_per_hop=200)

train_loader = th.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, 
                        num_workers=0, collate_fn=collate_rotten_tomato)
# batch = next(iter(train_loader))

In [29]:
batch = next(iter(train_loader))

In [30]:
batch

(Graph(num_nodes=7847, num_edges=258656,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
       edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)}),
 tensor([6., 0., 6., 2., 7., 7., 3., 6., 5., 6., 4., 0., 4., 5., 9., 5., 6., 4.,
         5., 6., 3., 4., 0., 7., 6., 7., 6., 5., 6., 5., 5., 0.]))

In [23]:
iter_dur = []
t_epoch = time.time()
for iter_idx, batch in enumerate(train_loader, start=1):
    t_iter = time.time()

    inputs = batch[0] # .to(th.device('cuda:0'))

    iter_dur.append(time.time() - t_iter)
    if iter_idx % 100 == 0:
        print("Iter={}, time={:.4f}".format(
            iter_idx, np.average(iter_dur)))
        iter_dur = []
        
    if iter_idx == 200:
        break
print("Epoch time={:.2f}".format(time.time()-t_epoch))

Iter=100, time=0.0010
Iter=200, time=0.0011
Epoch time=28.49


## 여러개 데이터셋 처리

In [24]:
train_dataset_r = RottenTomatoDataset(
    rotten_tomato_r.train_rating_pairs, rotten_tomato_r.train_rating_values, rotten_tomato_r.train_graph, 
    args.hop, args.sample_ratio, args.max_nodes_per_hop)

train_dataset_s = RottenTomatoDataset(
    rotten_tomato_s.train_rating_pairs, rotten_tomato_s.train_rating_values, rotten_tomato_s.train_graph, 
    args.hop, args.sample_ratio, args.max_nodes_per_hop)

train_dataset_e = RottenTomatoDataset(
    rotten_tomato_e.train_rating_pairs, rotten_tomato_e.train_rating_values, rotten_tomato_e.train_graph, 
    args.hop, args.sample_ratio, args.max_nodes_per_hop)

In [None]:
class MultiGraphDataset(Dataset):
    def __init__(self, graph_r, graph_s, graph_e):
        super(MultiGraphDataset, self).__init__()

        self.graph_r = graph_r
        self.graph_s = graph_s
        self.graph_e = graph_e

    def __len__(self):
        return len(self.graph_r)

    def __getitem__(self, idx):
        return self.graph_r, self.graph_s, self.graph_e

In [None]:
multi_graph_dataset = MultiGraphDataset(train_dataset_r, train_dataset_s, train_dataset_e)

In [None]:
train_loader = th.utils.data.DataLoader(dataset=multi_graph_dataset, batch_size=args.batch_size, 
                                        num_workers=args.num_workers, collate_fn=collate_rotten_tomato)

In [14]:
# iter_dur = []
# t_epoch = time.time()
# for iter_idx, batch in enumerate(train_loader, start=1):
#     t_iter = time.time()

#     inputs = batch[0] # .to(th.device('cuda:0'))

#     iter_dur.append(time.time() - t_iter)
#     if iter_idx % 100 == 0:
#         print("Iter={}, time={:.4f}".format(
#             iter_idx, np.average(iter_dur)))
#         iter_dur = []
# print("Epoch time={:.2f}".format(time.time()-t_epoch))

RuntimeError: DataLoader worker (pid(s) 60704, 50908, 55072, 5860, 57008, 52056, 60396, 21256) exited unexpectedly