In [1]:
import random
from collections import namedtuple

import numpy as np
import torch as th
import dgl

from utils import subgraph_extraction_labeling

Using backend: pytorch


In [46]:
class RottenTomatoDataset(th.utils.data.Dataset):
    def __init__(self, links, g_labels, graph, 
                hop=1, sample_ratio=1.0, max_nodes_per_hop=200):
        self.links = links
        self.g_labels = g_labels
        self.graph = graph 

        self.hop = hop
        self.sample_ratio = sample_ratio
        self.max_nodes_per_hop = max_nodes_per_hop

    def __len__(self):
        return len(self.links[0])

    # 배치 단위로 묶어서 출력함
    def __getitem__(self, idx):
        u, v = self.links[0][idx], self.links[1][idx]
        g_label = self.g_labels[idx]

        subgraph = subgraph_extraction_labeling(
            (u, v), self.graph, 
            hop=self.hop, sample_ratio=self.sample_ratio, max_nodes_per_hop=self.max_nodes_per_hop)

        return subgraph, g_label

    # 배치 단위로
def collate_rotten_tomato(data):
    print(data[0])
    g_list, label_list = map(list, zip(*data))
    g = dgl.batch(g_list)
    g_label = th.stack(label_list)
    return g, g_label

In [176]:
class MultiRottenTomatoDataset(th.utils.data.Dataset):
    def __init__(self, links, g_labels, graph, 
                hop, sample_ratio, max_nodes_per_hop):
        # 리스트로 입력받음
        self.links = links
        self.g_labels = g_labels
        self.graph = graph 

        self.hop = hop
        self.sample_ratio = sample_ratio
        self.max_nodes_per_hop = max_nodes_per_hop

    def __len__(self):
        return len(self.links[0][0])

    def __getitem__(self, idx):
        # rating
        u_r, v_r = self.links[0][0][idx], self.links[0][1][idx]
        g_label_r = self.g_labels[0][idx]

        subgraph_r = subgraph_extraction_labeling(
            (u_r, v_r), self.graph[0], 
            hop=self.hop[0], sample_ratio=self.sample_ratio[0], max_nodes_per_hop=self.max_nodes_per_hop[0])
        
        # sentiment
        u_s, v_s = self.links[1][0][idx], self.links[1][1][idx]
        g_label_s = self.g_labels[1][idx]
        
        subgraph_s = subgraph_extraction_labeling(
            (u_s, v_s), self.graph[1], 
            hop=self.hop[1], sample_ratio=self.sample_ratio[1], max_nodes_per_hop=self.max_nodes_per_hop[1])
        
        # emotion
        u_e, v_e = self.links[2][0][idx], self.links[2][1][idx]
        g_label_e = self.g_labels[2][idx]
        
        subgraph_e = subgraph_extraction_labeling(
            (u_e, v_e), self.graph[2], 
            hop=self.hop[2], sample_ratio=self.sample_ratio[2], max_nodes_per_hop=self.max_nodes_per_hop[2])
        
        
        # 통합
        g_label = [g_label_r, g_label_s, g_label_e]
        subgraph = [subgraph_r, subgraph_s, subgraph_e]

        return subgraph, g_label

# data : tuple 형태(subgraph, g_label)
def multi_collate_rotten_tomato(data):
    r_data    = list()
    s_data = list()
    e_data   = list()

    # batch 샘플 순서
    for i in range(len(data)):
        r_data.append((data[i][0][0], data[i][1][0])) # rating
        s_data.append((data[i][0][1], data[i][1][1])) # sentiment
        e_data.append((data[i][0][2], data[i][1][2])) # emotion

   # rating
    g_list_r, label_list_r = map(list, zip(*r_data))
    g_r = dgl.batch(g_list_r)
    g_label_r = th.stack(label_list_r)
    
    # sentiment
    g_list_s, label_list_s = map(list, zip(*s_data))
    g_s = dgl.batch(g_list_s)
    g_label_s = th.stack(label_list_s)
    
    # emotion
    g_list_e, label_list_e = map(list, zip(*e_data))
    g_e = dgl.batch(g_list_e)
    g_label_e = th.stack(label_list_e)
    
    # 리스트로 출력
    g = [g_r, g_s, g_e]
    g_label = [g_label_r, g_label_s, g_label_e]
    
    return g, g_label

# 1. Main()

In [4]:
import easydict

args = easydict.EasyDict({ 
    'data_name':            'rotten',
    'testing':     	        True,
    'device':      	        0,
    'seed':        	        1234,
    'data_test_ratio':      0.1,
    'num_workers':   	    8,
    'data_valid_ratio':     0.2,
    'train_log_interval':   200,
    'valid_log_interval':   10,
    'save_appendix':   	    'debug',
    'hop':   	            1,
    'sample_ratio':    	    1.0,
    'max_nodes_per_hop':    100,
    'edge_dropout':   	    0.2,
    'force_undirected':     False,
    'train_lr':   	        1e-3,
    'train_min_lr':   	    1e-6,
    'train_lr_decay_factor':0.1,
    'train_lr_decay_step':  50,
    'train_epochs':   	    10,
    'batch_size':   	    16,
    'arr_lambda':   	    0.001,
    'num_rgcn_bases':   	4,
    'train_epochs':   	    1
})

In [5]:
random.seed(args.seed)
np.random.seed(args.seed)
th.manual_seed(args.seed)
if th.cuda.is_available():
    th.cuda.manual_seed_all(args.seed)    

In [6]:
import time
from explicit_model_rotten import IGMC
from explicit_data_rotten import RottenTomato

### prepare data and set model
path = './raw_data/rotten_tomato/'
rotten_tomato_r = RottenTomato('rating',    path, testing=args.testing,test_ratio=args.data_test_ratio, valid_ratio=args.data_valid_ratio)
rotten_tomato_s = RottenTomato('sentiment', path, testing=args.testing,test_ratio=args.data_test_ratio, valid_ratio=args.data_valid_ratio)
rotten_tomato_e = RottenTomato('emotion',   path, testing=args.testing,test_ratio=args.data_test_ratio, valid_ratio=args.data_valid_ratio)

Label_type: rating
	Train rating pairs : 216328
	Valid rating pairs : 43266
	Test rating pairs  : 28766
Label_type: sentiment
	Train rating pairs : 216328
	Valid rating pairs : 43266
	Test rating pairs  : 28766
Label_type: emotion
	Train rating pairs : 216328
	Valid rating pairs : 43266
	Test rating pairs  : 28766


## 기존 코드

In [177]:
train_dataset_r = RottenTomatoDataset(
    rotten_tomato_r.train_rating_pairs, rotten_tomato_r.train_rating_values, rotten_tomato_r.train_graph, 
    hop=1, sample_ratio=1.0, max_nodes_per_hop=200)

In [178]:
train_loader_r = th.utils.data.DataLoader(train_dataset_r, batch_size=32, shuffle=True, 
                        num_workers=0, collate_fn=collate_rotten_tomato)

In [179]:
batch = next(iter(train_loader_r))

In [181]:
batch

(Graph(num_nodes=8125, num_edges=303226,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
       edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)}),
 tensor([4., 7., 5., 6., 6., 6., 4., 4., 4., 7., 4., 7., 7., 5., 5., 8., 4., 9.,
         4., 6., 4., 6., 4., 2., 6., 4., 3., 6., 4., 6., 6., 6.]))

In [185]:
batch[0]

Graph(num_nodes=8125, num_edges=303226,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
      edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)})

In [183]:
batch[1]

tensor([4., 7., 5., 6., 6., 6., 4., 4., 4., 7., 4., 7., 7., 5., 5., 8., 4., 9.,
        4., 6., 4., 6., 4., 2., 6., 4., 3., 6., 4., 6., 6., 6.])

In [175]:
batch[:3]

[(Graph(num_nodes=312, num_edges=10492,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
        edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)}),
  tensor(3.)),
 (Graph(num_nodes=257, num_edges=6412,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
        edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)}),
  tensor(2.)),
 (Graph(num_nodes=257, num_edges=9924,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
        edata_schemes={'et

In [172]:
# 32개 샘플에 대한 그래프 1개, 레이블 리스트 1개
batch[0]

(Graph(num_nodes=312, num_edges=10492,
       ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
       edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)}),
 tensor(3.))

In [173]:
batch[0][0]

Graph(num_nodes=312, num_edges=10492,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
      edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)})

In [174]:
batch[0][1]

tensor(3.)

## 수정중

In [145]:
# 수정한 것 (단, 그래프의 모든 행의 길이 동일함)
train_rating_pairs  = [rotten_tomato_r.train_rating_pairs, rotten_tomato_s.train_rating_pairs, rotten_tomato_e.train_rating_pairs]
train_rating_values = [rotten_tomato_r.train_rating_values, rotten_tomato_s.train_rating_values, rotten_tomato_e.train_rating_values]
train_graph         = [rotten_tomato_r.train_graph, rotten_tomato_s.train_graph, rotten_tomato_e.train_graph]
hop = [1, 1, 1]
sample_ratio = [1.0, 1.0, 1.0]
max_nodes_per_hop = [200, 200, 200]

In [146]:
train_dataset = MultiRottenTomatoDataset(
    train_rating_pairs, train_rating_values, train_graph, hop, sample_ratio, max_nodes_per_hop)

train_loader = th.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, 
                        num_workers=0, collate_fn=multi_collate_rotten_tomato)

In [147]:
len(train_loader)

6761

In [148]:
batch = next(iter(train_loader))

In [149]:
batch

([Graph(num_nodes=8069, num_edges=248580,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
        edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)}),
  Graph(num_nodes=8069, num_edges=244618,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
        edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)}),
  Graph(num_nodes=8069, num_edges=241996,
        ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
        edata_schemes={'etype': Scheme(shape=(),

In [157]:
batch[0][0]

Graph(num_nodes=8069, num_edges=248580,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'nlabel': Scheme(shape=(4,), dtype=torch.float32), 'x': Scheme(shape=(4,), dtype=torch.float32)}
      edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64), 'edge_mask': Scheme(shape=(), dtype=torch.float32)})

In [156]:
batch[1][0]

tensor([3., 6., 6., 9., 9., 6., 3., 4., 6., 4., 4., 3., 5., 3., 9., 6., 1., 9.,
        6., 4., 6., 4., 6., 5., 4., 1., 6., 2., 5., 1., 6., 5.])

In [23]:
iter_dur = []
t_epoch = time.time()
for iter_idx, batch in enumerate(train_loader, start=1):
    t_iter = time.time()

    inputs = batch[0] # .to(th.device('cuda:0'))

    iter_dur.append(time.time() - t_iter)
    if iter_idx % 100 == 0:
        print("Iter={}, time={:.4f}".format(
            iter_idx, np.average(iter_dur)))
        iter_dur = []
        
    if iter_idx == 200:
        break
print("Epoch time={:.2f}".format(time.time()-t_epoch))

Iter=100, time=0.0010
Iter=200, time=0.0011
Epoch time=28.49
