In [None]:
# Paper: Link Prediction Based on Graph Neural Networks (NeurIPS 2018)
# Example: https://github.com/rusty1s/pytorch_geometric/blob/99a496e077a4d41417c7d927df7730fd984004b9/examples/seal_link_pred.py#L90

In [54]:
import math
import random
import os.path as osp
from itertools import chain

import os
import pandas as pd
from sklearn.model_selection import train_test_split
from datetime import datetime

import numpy as np
from sklearn.metrics import roc_auc_score
from scipy.sparse.csgraph import shortest_path
import torch
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss
from torch.nn import ModuleList, Linear, Conv1d, MaxPool1d

import torch_geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, global_sort_pool
from torch_geometric.data import Data, InMemoryDataset, DataLoader, Dataset
from torch_geometric.utils import (negative_sampling, add_self_loops,
                                   train_test_split_edges, k_hop_subgraph,
                                   to_scipy_sparse_matrix, to_undirected, 
                                   contains_self_loops, remove_isolated_nodes, is_undirected)

# Define

In [55]:
def to_list(x):
    if not isinstance(x, (tuple, list)):
        x = [x]
    return x


def files_exist(files):
    return len(files) != 0 and all(osp.exists(f) for f in files)

In [56]:
class SEALDataset(InMemoryDataset):
    def __init__(self, dataset, num_hops, split='train'):
        self.data = dataset[0]
        self.num_hops = num_hops
        super(SEALDataset, self).__init__(dataset.root)
        index = ['train', 'val', 'test'].index(split)
        self.data, self.slices = torch.load(self.processed_paths[index])

    @property
    def processed_file_names(self):
        return ['SEAL_train_data.pt', 'SEAL_val_data.pt', 'SEAL_test_data.pt']

    def process(self):
        random.seed(12345)
        torch.manual_seed(12345)

        print(self.data.edge_index)
        print(is_undirected(self.data.edge_index))
        print(contains_self_loops(self.data.edge_index))
        
        data = train_test_split_edges(self.data)
        

        edge_index, _ = add_self_loops(data.train_pos_edge_index)
        
        data.train_neg_edge_index = negative_sampling(
            edge_index, num_nodes=data.num_nodes,
            num_neg_samples=data.train_pos_edge_index.size(1))

        self.__max_z__ = 0

        
        print(is_undirected(data.train_pos_edge_index))
        print(contains_self_loops(data.train_pos_edge_index))

        # Collect a list of subgraphs for training, validation and test.
        train_pos_list = self.extract_enclosing_subgraphs(
            data.train_pos_edge_index, data.train_pos_edge_index, 1)
        
        print(is_undirected(data.train_neg_edge_index))
        print(contains_self_loops(data.train_neg_edge_index))
        
        train_neg_list = self.extract_enclosing_subgraphs(
            data.train_neg_edge_index, data.train_pos_edge_index, 0)

        val_pos_list = self.extract_enclosing_subgraphs(
            data.val_pos_edge_index, data.train_pos_edge_index, 1)
        val_neg_list = self.extract_enclosing_subgraphs(
            data.val_neg_edge_index, data.train_pos_edge_index, 0)

        test_pos_list = self.extract_enclosing_subgraphs(
            data.test_pos_edge_index, data.train_pos_edge_index, 1)
        test_neg_list = self.extract_enclosing_subgraphs(
            data.test_neg_edge_index, data.train_pos_edge_index, 0)

        # Convert labels to one-hot features.
        for data in chain(train_pos_list, train_neg_list, val_pos_list,
                          val_neg_list, test_pos_list, test_neg_list):
            print(data.z.size())
            z = F.one_hot(data.z, self.__max_z__ + 1).to(torch.float)
            print(data.x.size(), z.size())
            data.x = torch.cat([z, data.x], 1)

            
        torch.save(self.collate(train_pos_list + train_neg_list),
                   self.processed_paths[0])
        torch.save(self.collate(val_pos_list + val_neg_list),
                   self.processed_paths[1])
        torch.save(self.collate(test_pos_list + test_neg_list),
                   self.processed_paths[2])

    def extract_enclosing_subgraphs(self, link_index, edge_index, y):
        data_list = []
        for src, dst in link_index.t().tolist():
            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
                [src, dst], self.num_hops, edge_index, relabel_nodes=True)
            src, dst = mapping.tolist()

            # Remove target link from the subgraph.
            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
            sub_edge_index = sub_edge_index[:, mask1 & mask2]

            # Calculate node labeling.
            z = self.drnl_node_labeling(sub_edge_index, src, dst,
                                        num_nodes=sub_nodes.size(0))
            
            data = Data(x=self.data.x[sub_nodes], z=z,
                        edge_index=sub_edge_index, y=y)
        
            data_list.append(data)

        return data_list

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self.__max_z__ = max(int(z.max()), self.__max_z__)

        return z.to(torch.long)


In [114]:
class MyDataset(InMemoryDataset):
    
    def __init__(self, num_hops, root=None, split="train"):
        
        self.num_hops = num_hops
        super().__init__(root=root)
        
        index = ['train', 'valid', "test"].index(split)
        self.data, self.slices = torch.load(self.processed_paths[index])

        
    @property
    def raw_file_names(self):  
        return ["train.csv"]
    
    
    @property
    def processed_file_names(self):
        return ['SEAL_data_train.pt', 'SEAL_data_valid.pt', 'SEAL_data_test.pt']
    
    
    def process(self):
        # Read node features
        content = pd.read_csv(os.path.join(self.raw_dir, "content.csv"), delimiter="\t", header=None)
        content = content.sort_values(by=[0]).loc[:, 1:].to_numpy()
        content = torch.from_numpy(content)
        num_nodes = content.size(0)
    
        
        # Read train edge list
        train = pd.read_csv(os.path.join(self.raw_dir, "train.csv"))
        
        train_pos = train[ train["label"] == 1]
        train_neg = train[ train["label"] == 0]

        follower_pos = train_pos["from"].to_numpy().tolist()
        followee_pos = train_pos["to"].to_numpy().tolist()
        train_pos_edge = torch.tensor([follower_pos, followee_pos], dtype=torch.long)
        train_pos_edge, _ = torch_geometric.utils.remove_self_loops(train_pos_edge)


        follower_neg = train_neg["from"].to_numpy().tolist()
        followee_neg = train_neg["to"].to_numpy().tolist()
        train_neg_edge = torch.tensor([follower_neg, followee_neg], dtype=torch.long)
        train_neg_edge, _ = torch_geometric.utils.remove_self_loops(train_neg_edge)
    
        train_pos_edge = train_pos_edge.t()
        train_neg_edge = train_neg_edge.t()

        train_pos_edge, valid_pos_edge = train_test_split(train_pos_edge, shuffle=True)
        train_neg_edge, valid_neg_edge = train_test_split(train_neg_edge, shuffle=True)
        
        train_pos_edge = train_pos_edge.t()
        valid_pos_edge = valid_pos_edge.t()
        train_neg_edge = train_neg_edge.t()
        valid_neg_edge = valid_neg_edge.t()
        
        train_pos_edge, _ = add_self_loops(train_pos_edge, num_nodes=num_nodes)
        train_pos_edge = to_undirected(train_pos_edge)

        self.data = Data(x=content, edge_index=train_pos_edge, num_nodes=num_nodes)
        self.__max_z__ = 0
        
        
        train_pos_list = self.extract_enclosing_subgraphs(
            train_pos_edge, train_pos_edge, 1, num_nodes=num_nodes)
        train_neg_list = self.extract_enclosing_subgraphs(
            train_neg_edge, train_pos_edge, 0, num_nodes=num_nodes)
    

        val_pos_list = self.extract_enclosing_subgraphs(
            valid_pos_edge, train_pos_edge, 1, num_nodes=num_nodes)
        val_neg_list = self.extract_enclosing_subgraphs(
            valid_neg_edge, train_pos_edge, 0, num_nodes=num_nodes)
        
        
        # Read test edge list
        test = pd.read_csv(os.path.join(self.raw_dir, "test.csv"))

        test_follower = test["from"].to_numpy().tolist()
        test_followee = test["to"].to_numpy().tolist()
        test_edge = torch.tensor([test_follower, test_followee], dtype=torch.long)
        
        test_list = self.extract_enclosing_subgraphs(
            test_edge, train_pos_edge, 0, num_nodes=num_nodes)


        # Convert labels to one-hot features.
        for data in chain(train_pos_list, train_neg_list, 
                          val_pos_list, val_neg_list, test_list):
           
            if data.z is not None:
                z = F.one_hot(data.z, self.__max_z__ + 1).to(torch.float)
                data.x = torch.cat([z, data.x], 1)

            
        torch.save(self.collate(train_pos_list + train_neg_list),
                   self.processed_paths[0])
        torch.save(self.collate(val_pos_list + val_neg_list),
                   self.processed_paths[1])
        torch.save(self.collate(test_list), self.processed_paths[2])

    
    
    def extract_enclosing_subgraphs(self, link_index, edge_index, y, num_nodes):
        data_list = []
        for src, dst in link_index.t().tolist():
            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
                [src, dst], self.num_hops, edge_index, relabel_nodes=True, num_nodes=num_nodes)
            src, dst = mapping.tolist()
            
            # Remove target link from the subgraph.
            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
            sub_edge_index = sub_edge_index[:, mask1 & mask2]

            # Calculate node labeling.
            if src != dst:
                z = self.drnl_node_labeling(sub_edge_index, src, dst,
                                            num_nodes=sub_nodes.size(0))

                data = Data(x=self.data.x[sub_nodes], z=z,
                            edge_index=sub_edge_index, y=y)
                data_list.append(data)
            else:
                data = Data(x=self.data.x[sub_nodes], z=None,
                            edge_index=sub_edge_index, y=1)
                data_list.append(data)

        return data_list

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()
        
        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]
        
        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        
#         z = torch.zeros(num_nodes)
        
#         if src != dst:
        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)


        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self.__max_z__ = max(int(z.max()), self.__max_z__)

            
        return z.to(torch.long)


# Model

In [58]:
class DGCNN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, GNN=GCNConv, k=0.6):
        super(DGCNN, self).__init__()

        if k < 1:  # Transform percentile to number.
            num_nodes = sorted([data.num_nodes for data in train_dataset])
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            k = max(10, k)
        self.k = int(k)

        self.convs = ModuleList()
        self.convs.append(GNN(train_dataset.num_features, hidden_channels))
        for i in range(0, num_layers - 1):
            self.convs.append(GNN(hidden_channels, hidden_channels))
        self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
                            conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
                            conv1d_kws[1], 1)
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.lin1 = Linear(dense_dim, 128)
        self.lin2 = Linear(128, 1)

    def forward(self, x, edge_index, batch):
        xs = [x]
        for conv in self.convs:
            xs += [torch.tanh(conv(xs[-1], edge_index))]
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.

        x = global_sort_pool(x, batch, self.k)

        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)

        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        
        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)

        return x


In [88]:
def train(loader):
    model.train()

    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index, data.batch)
        loss = BCEWithLogitsLoss()(logits.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs

    return total_loss / len(train_dataset)

In [89]:
@torch.no_grad()
def test(loader):
    model.eval()

    y_pred, y_true = [], []
    for data in loader:
        data = data.to(device)
        logits = model(data.x, data.edge_index, data.batch)
    
        pred = torch.nn.Sigmoid()(logits)
        
        y_pred.append(pred.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))

    return roc_auc_score(torch.cat(y_true), torch.cat(y_pred))

In [90]:
@torch.no_grad()
def predict(loader):
    model.eval()
    
    y_pred = []
    for data in loader:
        data = data.to(device)
        logits = model(data.x, data.edge_index, data.batch)
        
        pred = torch.nn.Sigmoid()(logits)
        
        y_pred.append(pred.view(-1).cpu().numpy())

    return y_pred

# Usage

In [25]:
from torch_geometric.datasets import Planetoid

In [39]:
cora = Planetoid(root=os.getcwd(), name='Cora')

cora_train_dataset = SEALDataset(cora, num_hops=2, split='train')
cora_val_dataset = SEALDataset(cora, num_hops=2, split='val')
cora_test_dataset = SEALDataset(cora, num_hops=2, split='test')

cora_train_loader = DataLoader(cora_train_dataset, batch_size=32, shuffle=True)
cora_val_loader = DataLoader(cora_val_dataset, batch_size=32)
cora_test_loader = DataLoader(cora_test_dataset, batch_size=32)

Processing...
tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])
True
False
True
False
False
False
torch.Size([8])
torch.Size([8, 1433]) torch.Size([8, 136])
torch.Size([66])
torch.Size([66, 1433]) torch.Size([66, 136])
torch.Size([8])
torch.Size([8, 1433]) torch.Size([8, 136])
torch.Size([68])
torch.Size([68, 1433]) torch.Size([68, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([68])
torch.Size([68, 1433]) torch.Size([68, 136])
torch.Size([76])
torch.Size([76, 1433]) torch.Size([76, 136])
torch.Size([67])
torch.Size([67, 1433]) torch.Size([67, 136])
torch.Size([75])
torch.Size([75, 1433]) torch.Size([75, 136])
torch.Size([163])
torch.Size([163, 1433]) torch.Size([163, 136])
torch.Size([2])
torch.Size([2, 1433]) torch.Size([2, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([160])
torch.Size([160, 1433]) torch.Si

torch.Size([109])
torch.Size([109, 1433]) torch.Size([109, 136])
torch.Size([56])
torch.Size([56, 1433]) torch.Size([56, 136])
torch.Size([54])
torch.Size([54, 1433]) torch.Size([54, 136])
torch.Size([51])
torch.Size([51, 1433]) torch.Size([51, 136])
torch.Size([5])
torch.Size([5, 1433]) torch.Size([5, 136])
torch.Size([44])
torch.Size([44, 1433]) torch.Size([44, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([63])
torch.Size([63, 1433]) torch.Size([63, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([18])
torch.Size([18, 1433]) torch.Size([18, 136])
torch.Size([7])
torch.Size([7, 1433]) torch.Size([7, 136])
torch.Size([6])
torch.Size([6, 1433]) torch.Size([6, 136])
torch.Size([6])
torch.Size([6, 1433]) torch.Size([6, 136])
torch.Size([6])
torch.Size([6, 1433]) torch.Size([6, 136])
torch.Size([6])
torc

torch.Size([70])
torch.Size([70, 1433]) torch.Size([70, 136])
torch.Size([63])
torch.Size([63, 1433]) torch.Size([63, 136])
torch.Size([104])
torch.Size([104, 1433]) torch.Size([104, 136])
torch.Size([104])
torch.Size([104, 1433]) torch.Size([104, 136])
torch.Size([102])
torch.Size([102, 1433]) torch.Size([102, 136])
torch.Size([81])
torch.Size([81, 1433]) torch.Size([81, 136])
torch.Size([108])
torch.Size([108, 1433]) torch.Size([108, 136])
torch.Size([102])
torch.Size([102, 1433]) torch.Size([102, 136])
torch.Size([100])
torch.Size([100, 1433]) torch.Size([100, 136])
torch.Size([100])
torch.Size([100, 1433]) torch.Size([100, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([16])
torch.Size([16, 1433]) torch.Size([16, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([8])
torch.Size([8, 1433]) torch.Size([8, 136])
torch.Size([8])
torch.Size([8, 1433]) torch.Size([8, 136])
torch.Size([6])
torch.Size([6, 1433]) torch.Size([6, 13

torch.Size([91])
torch.Size([91, 1433]) torch.Size([91, 136])
torch.Size([82])
torch.Size([82, 1433]) torch.Size([82, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([130])
torch.Size([130, 1433]) torch.Size([130, 136])
torch.Size([39])
torch.Size([39, 1433]) torch.Size([39, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([161])
torch.Size([161, 1433]) torch.Size([161, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([6])
torch.Size([6, 1433]) torch.Size([6, 136])
torch.Size([37])
torch.Size([37, 1433]) torch.Size([37, 136])
torch.Size([72])
torch.Size([72, 1433]) torch.Size([72, 136])
torch.Size([10])
torch.Size([10, 1433]) torch.Size([10, 136])
torch.Size([10])
torch.Size([10, 1433]) torch.Size([10, 136])
torch.Size([10])
torch.Size([10, 1433]) torch.Size([10, 136])
torch

torch.Size([77])
torch.Size([77, 1433]) torch.Size([77, 136])
torch.Size([74])
torch.Size([74, 1433]) torch.Size([74, 136])
torch.Size([91])
torch.Size([91, 1433]) torch.Size([91, 136])
torch.Size([74])
torch.Size([74, 1433]) torch.Size([74, 136])
torch.Size([83])
torch.Size([83, 1433]) torch.Size([83, 136])
torch.Size([88])
torch.Size([88, 1433]) torch.Size([88, 136])
torch.Size([51])
torch.Size([51, 1433]) torch.Size([51, 136])
torch.Size([40])
torch.Size([40, 1433]) torch.Size([40, 136])
torch.Size([42])
torch.Size([42, 1433]) torch.Size([42, 136])
torch.Size([44])
torch.Size([44, 1433]) torch.Size([44, 136])
torch.Size([40])
torch.Size([40, 1433]) torch.Size([40, 136])
torch.Size([44])
torch.Size([44, 1433]) torch.Size([44, 136])
torch.Size([72])
torch.Size([72, 1433]) torch.Size([72, 136])
torch.Size([43])
torch.Size([43, 1433]) torch.Size([43, 136])
torch.Size([42])
torch.Size([42, 1433]) torch.Size([42, 136])
torch.Size([41])
torch.Size([41, 1433]) torch.Size([41, 136])
torch.Si

torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([3])
torch.Size([3, 1433]) torch.Size([3, 136])
torch.Size([22])
torch.Size([22, 1433]) torch.Size([22, 136])
torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([2])
torch.Size([2, 1433]) torch.Size([2, 136])
torch.Size([296])
torch.Size([296, 1433]) torch.Size([296, 136])
torch.Size([43])
torch.Size([43, 1433]) torch.Size([43, 136])
torch.Size([44])
torch.Size([44, 1433]) torch.Size([44, 136])
torch.Size([74])
torch.Size([74, 1433]) torch.Size([74, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([46])
torch.Size([46, 1433]) torch.Size([46, 136])
torch.Size(

torch.Size([93])
torch.Size([93, 1433]) torch.Size([93, 136])
torch.Size([101])
torch.Size([101, 1433]) torch.Size([101, 136])
torch.Size([103])
torch.Size([103, 1433]) torch.Size([103, 136])
torch.Size([6])
torch.Size([6, 1433]) torch.Size([6, 136])
torch.Size([93])
torch.Size([93, 1433]) torch.Size([93, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([39])
torch.Size([39, 1433]) torch.Size([39, 136])
torch.Size([14])
torch.Size([14, 1433]) torch.Size([14, 136])
torch.Size([55])
torch.Size([55, 1433]) torch.Size([55, 136])
torch.Size([14])
torch.Size([14, 1433]) torch.Size([14, 136])
torch.Size([34])
torch.Size([34, 1433]) torch.Size([34, 136])
torch.Size([33])
torch.Size([33, 1433]) torch.Size([33, 136])
torch.Size([40])
torch.Size([40, 1433]) torch.Size([40, 136])
torch.Size([43])
torch.Size([43, 1433]) torch.Size([43, 136])
torch.Size([38])
torch.Size([38, 1433]) torch.Size([38, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch

torch.Size([106])
torch.Size([106, 1433]) torch.Size([106, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([5])
torch.Size([5, 1433]) torch.Size([5, 136])
torch.Size([7])
torch.Size([7, 1433]) torch.Size([7, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([5])
torch.Size([5, 1433]) torch.Size([5, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([7])
torch.Size([7, 1433]) torch.Size([7, 136])
torch.Size([16])
torch.Size([16, 1433]) torch.Size([16, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([64])
torch.Size([64, 1433]) torch.Size([64, 136])
torch.Size([54])
torch.Size([54, 1433]) torch.Size([54, 136])
torch.Size([111])
torch.Size([111, 1433]) torch.Size([111, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([26])
tor

torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([366])
torch.Size([366, 1433]) torch.Size([366, 136])
torch.Size([146])
torch.Size([146, 1433]) torch.Size([146, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([32])
torch.Size([32, 1433]) torch.Size([32, 136])
torch.Size([34])
torch.Size([34, 1433]) torch.Size([34, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
to

torch.Size([8])
torch.Size([8, 1433]) torch.Size([8, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([73])
torch.Size([73, 1433]) torch.Size([73, 136])
torch.Size([23])
torch.Size([23, 1433]) torch.Size([23, 136])
torch.Size([23])
torch.Size([23, 1433]) torch.Size([23, 136])
torch.Size([186])
torch.Size([186, 1433]) torch.Size([186, 136])
torch.Size([132])
torch.Size([132, 1433]) torch.Size([132, 136])
torch.Size([153])
torch.Size([153, 1433]) torch.Size([153, 136])
torch.Size([111])
torch.Size([111, 1433]) torch.Size([111, 136])
torch.Size([46])
torch.Size([46, 1433]) torch.Size([46, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([25])
torch.Size([25, 1433]) torch.Size([25, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
to

torch.Size([172, 1433]) torch.Size([172, 136])
torch.Size([98])
torch.Size([98, 1433]) torch.Size([98, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([10])
torch.Size([10, 1433]) torch.Size([10, 136])
torch.Size([10])
torch.Size([10, 1433]) torch.Size([10, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([11])
torch.Size([11, 1433]) torch.Size([11, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([11])
torch.Size([11, 1433]) torch.Size([11, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([10])
torch.Siz

torch.Size([68, 1433]) torch.Size([68, 136])
torch.Size([89])
torch.Size([89, 1433]) torch.Size([89, 136])
torch.Size([136])
torch.Size([136, 1433]) torch.Size([136, 136])
torch.Size([36])
torch.Size([36, 1433]) torch.Size([36, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([55])
torch.Size([55, 1433]) torch.Size([55, 136])
torch.Size([51])
torch.Size([51, 1433]) torch.Size([51, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([54])
torch.Size([54, 1433]) torch.Size([54, 136])
torch.Size([70])
torch.Size([70, 1433]) torch.Size([70, 136])
torch.Size([81])
torch.Size([81, 1433]) torch.Size([81, 136])
torch.Size([55])
torch.Size([55, 1433]) torch.Size([55, 136])
torch.Size([54])
torch.Size([54, 1433]) torch.Size([54, 136])
torch.Size([50])
torch

torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([45])
torch.Size([45, 1433]) torch.Size([45, 136])
torch.Size([36])
torch.Size([36, 1433]) torch.Size([36, 136])
torch.Size([51])
torch.Size([51, 1433]) torch.Size([51, 136])
torch.Size([39])
torch.Size([39, 1433]) torch.Size([39, 136])
torch.Size([46])
torch.Size([46, 1433]) torch.Size([46, 136])
torch.Size([68])
torch.Size([68, 1433]) torch.Size([68, 136])
torch.Size([34])
torch.Size([34, 1433]) torch.Size([34, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([42])
torch.Size([42, 1433]) torch.Size([42, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([37])
torch.Size([37, 1433]) torch.Size([37, 136])
torch.Size([99])
torch.Size([99, 1433]) torch.Size([99, 136])
torch.Size([80])
torch.Size([80, 1433]) torch.Size([80, 136])
torch.Si

torch.Size([83])
torch.Size([83, 1433]) torch.Size([83, 136])
torch.Size([43])
torch.Size([43, 1433]) torch.Size([43, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([18])
torch.Size([18, 1433]) torch.Size([18, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([75])
torch.Size([75, 1433]) torch.Size([75, 136])
torch.Size([54])
torch.Size([54, 1433]) torch.Size([54, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([71])
torch.Size([71, 1433]) torch.Size([71, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([12])
torch.Size([12, 1433]) torch.Size([12, 136])
torch.Size([84])
torch.Size([84, 1433]) torch.Size([84, 136])
torch.Size([45])
torch.Size([45, 1433]) torch.Size([45, 136])
torch.Size([50])
torch.Size([50, 1433]) torch.Size([50, 136])
torch.Si

torch.Size([4, 1433]) torch.Size([4, 136])
torch.Size([4])
torch.Size([4, 1433]) torch.Size([4, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([2])
torch.Size([2, 1433]) torch.Size([2, 136])
torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([23])
torch.Size([23, 1433]) torch.Size([23, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([4])
torch.Size([4, 1433]) torch.Size([4, 136])
torch.Size([7])
torch.Size([7, 1433]) torch.Size([7, 136])
torch.Size([12])
torch.Size([12, 1433]) torch.Size([12, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([75])
torch.Size([75, 1433]) torch.Size([75, 136])
torch.Size([12])
torch.Size([12, 1433]) to

torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([39])
torch.Size([39, 1433]) torch.Size([39, 136])
torch.Size([23])
torch.Size([23, 1433]) torch.Size([23, 136])
torch.Size([113])
torch.Size([113, 1433]) torch.Size([113, 136])
torch.Size([11])
torch.Size([11, 1433]) torch.Size([11, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([3])
torch.Size([3, 1433]) torch.Size([3, 136])
torch.Size([61])
torch.Size([61, 1433]) torch.Size([61, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([11])
torch.Size([11, 1433]) torch.Size([11, 136])
torch.Size([33])
torch.Size([33, 1433]) torch.Size([33, 136])
torch.Size([73])
torch.Size([73, 1433]) torch.Size([73, 136])
torch.Size([14])
torch.Size([14, 1433]) torch.Size([14, 136])
torch.Size([47])
torch.Size([47, 1433]) torch.Size([47, 136])
torch.Size([295])
torch.Size([295, 1433]) torch.Size([295, 136])
torch.Size([27])
torch

torch.Size([5, 1433]) torch.Size([5, 136])
torch.Size([105])
torch.Size([105, 1433]) torch.Size([105, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([44])
torch.Size([44, 1433]) torch.Size([44, 136])
torch.Size([251])
torch.Size([251, 1433]) torch.Size([251, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([164])
torch.Size([164, 1433]) torch.Size([164, 136])
torch.Size([40])
torch.Size([40, 1433]) torch.Size([40, 136])
torch.Size([6])
torch.Size([6, 1433]) torch.Size([6, 136])
torch.Size([151])
torch.Size([151, 1433]) torch.Size([151, 136])
torch.Size([58])
torch.Size([58, 1433]) torch.Size([58, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([68])
torch.Size([68, 1433]) torch.Size([68, 136])
torch.Size([7])
torch.Size([7, 1433]) torch.Size([7, 136])
torch.Size([33])
torch.Size([33, 1433]) torch.Size([33, 136])
torch.Size([60])
torc

torch.Size([83])
torch.Size([83, 1433]) torch.Size([83, 136])
torch.Size([41])
torch.Size([41, 1433]) torch.Size([41, 136])
torch.Size([65])
torch.Size([65, 1433]) torch.Size([65, 136])
torch.Size([25])
torch.Size([25, 1433]) torch.Size([25, 136])
torch.Size([60])
torch.Size([60, 1433]) torch.Size([60, 136])
torch.Size([162])
torch.Size([162, 1433]) torch.Size([162, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([47])
torch.Size([47, 1433]) torch.Size([47, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([81])
torch.Size([81, 1433]) torch.Size([81, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([14])
torch.Size([14, 1433]) torch.Size([14, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([87])
torch.Size([87, 1433]) torch.Size([87, 136])
torch

torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([144])
torch.Size([144, 1433]) torch.Size([144, 136])
torch.Size([166])
torch.Size([166, 1433]) torch.Size([166, 136])
torch.Size([25])
torch.Size([25, 1433]) torch.Size([25, 136])
torch.Size([51])
torch.Size([51, 1433]) torch.Size([51, 136])
torch.Size([48])
torch.Size([48, 1433]) torch.Size([48, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([33])
torch.Size([33, 1433]) torch.Size([33, 136])
torch.Size([25])
torch.Size([25, 1433]) torch.Size([25, 136])
torch.Size([11])
torch.Size([11, 1433]) torch.Size([11, 136])
torch.Size([66])
torch.Size([66, 1433]) torch.Size([66, 136])
torch.Size([90])
torch.Size([90, 1433]) torch.Size([90, 136])
torch.Size([23])
torch.Size([23, 1433]) torch.Size([23, 136])
torch.Size([162])
torch.Size([162, 1433]) torch.Size([162, 136])
torch.Size([42])
torch.Size([42, 1433]) torch.Size([42, 136])

torch.Size([71])
torch.Size([71, 1433]) torch.Size([71, 136])
torch.Size([61])
torch.Size([61, 1433]) torch.Size([61, 136])
torch.Size([22])
torch.Size([22, 1433]) torch.Size([22, 136])
torch.Size([27])
torch.Size([27, 1433]) torch.Size([27, 136])
torch.Size([167])
torch.Size([167, 1433]) torch.Size([167, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([36])
torch.Size([36, 1433]) torch.Size([36, 136])
torch.Size([91])
torch.Size([91, 1433]) torch.Size([91, 136])
torch.Size([120])
torch.Size([120, 1433]) torch.Size([120, 136])
torch.Size([76])
torch.Size([76, 1433]) torch.Size([76, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([23])
torch.Size([23, 1433]) torch.Size([23, 136])
torch.Size([95])
torch.Size([95, 1433]) torch.Size([95, 136])
torch.Size([54])
torch.Size([54, 1433]) torch.Size([54, 136])
torch.Size([16])
torch.Size([16, 1433]) torch.Size([16, 136])
torch.Size([70])
torch.Size([70, 1433]) torch.Size([70, 136])
to

torch.Size([32])
torch.Size([32, 1433]) torch.Size([32, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([58])
torch.Size([58, 1433]) torch.Size([58, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([92])
torch.Size([92, 1433]) torch.Size([92, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([46])
torch.Size([46, 1433]) torch.Size([46, 136])
torch.Size([46])
torch.Size([46, 1433]) torch.Size([46, 136])
torch.Size([148])
torch.Size([148, 1433]) torch.Size([148, 136])
torch.Size([182])
torch.Size([182, 1433]) torch.Size([182, 136])
torch.Size([34])
torch.Size([34, 1433]) torch.Size([34, 136])
torch.Size([5])
torch.Size([5, 1433]) torch.Size([5, 136])
torch.Size([120])
torch.Size([120, 1433]) torch.Size([120, 136])
torch.Size([161])
torch.Size([161, 1433]) torch.Size([161, 136])

torch.Size([59])
torch.Size([59, 1433]) torch.Size([59, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([77])
torch.Size([77, 1433]) torch.Size([77, 136])
torch.Size([54])
torch.Size([54, 1433]) torch.Size([54, 136])
torch.Size([217])
torch.Size([217, 1433]) torch.Size([217, 136])
torch.Size([166])
torch.Size([166, 1433]) torch.Size([166, 136])
torch.Size([34])
torch.Size([34, 1433]) torch.Size([34, 136])
torch.Size([68])
torch.Size([68, 1433]) torch.Size([68, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([45])
torch.Size([45, 1433]) torch.Size([45, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([61])
torch.Size([61, 1433]) torch.Size([61, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([14])
torch.Size([14, 1433]) torch.Size([14, 136])
torch.Size([165])
torch.Size([165, 1433]) torch.Size([165, 136])
torch.Size([43])
torch.Size([43, 1433]) torch.Size([43, 136])
to

torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([50])
torch.Size([50, 1433]) torch.Size([50, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([22])
torch.Size([22, 1433]) torch.Size([22, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([8])
torch.Size([8, 1433]) torch.Size([8, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([3])
torch.Size([3, 1433]) torch.Size([3, 136])
torch.Size([57])
torch.Size([57, 1433]) torch.Size([57, 136])
torch.Size([5])
torch.Size([5, 1433]) torch.Size([5, 136])
torch.Size([22])
torch.Size([22, 1433]) torch.Size([22, 136])
torch.Size([14])
torch.Size([14, 1433]) torch.Size([14, 136])
torch.Size([4])
torch.Size([4, 1433]) torch.Size([4, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([18])
torch.Size([18, 1433]) torch.Size([18, 136])
torch.Size([150])
torch.Size([150, 1433]

torch.Size([60])
torch.Size([60, 1433]) torch.Size([60, 136])
torch.Size([47])
torch.Size([47, 1433]) torch.Size([47, 136])
torch.Size([116])
torch.Size([116, 1433]) torch.Size([116, 136])
torch.Size([207])
torch.Size([207, 1433]) torch.Size([207, 136])
torch.Size([32])
torch.Size([32, 1433]) torch.Size([32, 136])
torch.Size([55])
torch.Size([55, 1433]) torch.Size([55, 136])
torch.Size([6])
torch.Size([6, 1433]) torch.Size([6, 136])
torch.Size([44])
torch.Size([44, 1433]) torch.Size([44, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([48])
torch.Size([48, 1433]) torch.Size([48, 136])
torch.Size([47])
torch.Size([47, 1433]) torch.Size([47, 136])
torch.Size([169])
torch.Size([169, 1433]) torch.Size([169, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([25])
torch.Size([25, 1433]) torch.Size([25, 136])
torch.Size([169])
torch.Size([169, 1433]) torch.Size([169, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])

torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([256])
torch.Size([256, 1433]) torch.Size([256, 136])
torch.Size([56])
torch.Size([56, 1433]) torch.Size([56, 136])
torch.Size([173])
torch.Size([173, 1433]) torch.Size([173, 136])
torch.Size([52])
torch.Size([52, 1433]) torch.Size([52, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([12])
torch.Size([12, 1433]) torch.Size([12, 136])
torch.Size([8])
torch.Size([8, 1433]) torch.Size([8, 136])
torch.Size([38])
torch.Size([38, 1433]) torch.Size([38, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([50])
torch.Size([50, 1433]) torch.Size([50, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([104])
torch.Size([104, 1433]) torch.Size([104, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([43])
to

torch.Size([73])
torch.Size([73, 1433]) torch.Size([73, 136])
torch.Size([45])
torch.Size([45, 1433]) torch.Size([45, 136])
torch.Size([35])
torch.Size([35, 1433]) torch.Size([35, 136])
torch.Size([16])
torch.Size([16, 1433]) torch.Size([16, 136])
torch.Size([53])
torch.Size([53, 1433]) torch.Size([53, 136])
torch.Size([40])
torch.Size([40, 1433]) torch.Size([40, 136])
torch.Size([14])
torch.Size([14, 1433]) torch.Size([14, 136])
torch.Size([43])
torch.Size([43, 1433]) torch.Size([43, 136])
torch.Size([69])
torch.Size([69, 1433]) torch.Size([69, 136])
torch.Size([18])
torch.Size([18, 1433]) torch.Size([18, 136])
torch.Size([27])
torch.Size([27, 1433]) torch.Size([27, 136])
torch.Size([23])
torch.Size([23, 1433]) torch.Size([23, 136])
torch.Size([147])
torch.Size([147, 1433]) torch.Size([147, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([7])
torch.Size([7, 1433]) torch.Size([7, 136])
torch.Size([176])
torch.Size([176, 1433]) torch.Size([176, 136])
torch

torch.Size([73, 1433]) torch.Size([73, 136])
torch.Size([32])
torch.Size([32, 1433]) torch.Size([32, 136])
torch.Size([51])
torch.Size([51, 1433]) torch.Size([51, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([89])
torch.Size([89, 1433]) torch.Size([89, 136])
torch.Size([173])
torch.Size([173, 1433]) torch.Size([173, 136])
torch.Size([10])
torch.Size([10, 1433]) torch.Size([10, 136])
torch.Size([54])
torch.Size([54, 1433]) torch.Size([54, 136])
torch.Size([62])
torch.Size([62, 1433]) torch.Size([62, 136])
torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([52])
torch.Size([52, 1433]) torch.Size([52, 136])
torch.Size([141])
torch.Size([141, 1433]) torch.Size([141, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([53])
torch.Size([53, 1433]) torch.Size([53, 136])
torch.Size([52])
torch.Size([52, 1433]) torch.Size([52, 136])
torch.Size([10])
to

torch.Size([47, 1433]) torch.Size([47, 136])
torch.Size([15])
torch.Size([15, 1433]) torch.Size([15, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([89])
torch.Size([89, 1433]) torch.Size([89, 136])
torch.Size([48])
torch.Size([48, 1433]) torch.Size([48, 136])
torch.Size([152])
torch.Size([152, 1433]) torch.Size([152, 136])
torch.Size([72])
torch.Size([72, 1433]) torch.Size([72, 136])
torch.Size([65])
torch.Size([65, 1433]) torch.Size([65, 136])
torch.Size([31])
torch.Size([31, 1433]) torch.Size([31, 136])
torch.Size([99])
torch.Size([99, 1433]) torch.Size([99, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([66])
torch.Size([66, 1433]) torch.Size([66, 136])
torch.Size([27])
torch.Size([27, 1433]) torch.Size([27, 136])
torch.Size([44])
torch.Size([44, 1433]) torch.Size([44, 136])
torch.Size([45])
torch.Size([45, 1433]) torch.Size([45, 136])
torch.Size([165])
torch.Size([165, 1433]) torch.Size([165, 136])
torch.Size([96])
to

torch.Size([16, 1433]) torch.Size([16, 136])
torch.Size([12])
torch.Size([12, 1433]) torch.Size([12, 136])
torch.Size([193])
torch.Size([193, 1433]) torch.Size([193, 136])
torch.Size([111])
torch.Size([111, 1433]) torch.Size([111, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([39])
torch.Size([39, 1433]) torch.Size([39, 136])
torch.Size([76])
torch.Size([76, 1433]) torch.Size([76, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([23])
torch.Size([23, 1433]) torch.Size([23, 136])
torch.Size([36])
torch.Size([36, 1433]) torch.Size([36, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([43])
torch.Size([43, 1433]) torch.Size([43, 136])
torch.Size([25])
torch.Size([25, 1433]) torch.Size([25, 136])
torch.Size([13])
torch.Size([13, 1433]) torch.Size([13, 136])
torch.Size([70])
to

torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([171])
torch.Size([171, 1433]) torch.Size([171, 136])
torch.Size([113])
torch.Size([113, 1433]) torch.Size([113, 136])
torch.Size([74])
torch.Size([74, 1433]) torch.Size([74, 136])
torch.Size([22])
torch.Size([22, 1433]) torch.Size([22, 136])
torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([223])
torch.Size([223, 1433]) torch.Size([223, 136])
torch.Size([60])
torch.Size([60, 1433]) torch.Size([60, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([67])
torch.Size([67, 1433]) torch.Size([67, 136])
torch.Size([10])
torch.Size([10, 1433]) torch.Size([10, 136])
torch.Size([26])
torch.Size([26, 1433]) torch.Size([26, 136])
torch.Size([56])
torch.Size([56, 1433]) torch.Size([56, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([16])
torch.Size([16, 1433]) torch.Size([16, 136])
torch.Size([32])
torch.Size([32, 1433]) torch.Size([32, 136])

torch.Size([40])
torch.Size([40, 1433]) torch.Size([40, 136])
torch.Size([45])
torch.Size([45, 1433]) torch.Size([45, 136])
torch.Size([165])
torch.Size([165, 1433]) torch.Size([165, 136])
torch.Size([221])
torch.Size([221, 1433]) torch.Size([221, 136])
torch.Size([79])
torch.Size([79, 1433]) torch.Size([79, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([30])
torch.Size([30, 1433]) torch.Size([30, 136])
torch.Size([24])
torch.Size([24, 1433]) torch.Size([24, 136])
torch.Size([63])
torch.Size([63, 1433]) torch.Size([63, 136])
torch.Size([183])
torch.Size([183, 1433]) torch.Size([183, 136])
torch.Size([25])
torch.Size([25, 1433]) torch.Size([25, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])
torch.Size([33])
torch.Size([33, 1433]) torch.Size([33, 136])
torch.Size([53])
torch.Size([53, 1433]) torch.Size([53, 136])
torch.Size([20])
torch.Size([20, 1433]) torch.Size([20, 136])

torch.Size([179])
torch.Size([179, 1433]) torch.Size([179, 136])
torch.Size([177])
torch.Size([177, 1433]) torch.Size([177, 136])
torch.Size([77])
torch.Size([77, 1433]) torch.Size([77, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([10])
torch.Size([10, 1433]) torch.Size([10, 136])
torch.Size([217])
torch.Size([217, 1433]) torch.Size([217, 136])
torch.Size([28])
torch.Size([28, 1433]) torch.Size([28, 136])
torch.Size([67])
torch.Size([67, 1433]) torch.Size([67, 136])
torch.Size([140])
torch.Size([140, 1433]) torch.Size([140, 136])
torch.Size([172])
torch.Size([172, 1433]) torch.Size([172, 136])
torch.Size([12])
torch.Size([12, 1433]) torch.Size([12, 136])
torch.Size([217])
torch.Size([217, 1433]) torch.Size([217, 136])
torch.Size([81])
torch.Size([81, 1433]) torch.Size([81, 136])
torch.Size([34])
torch.Size([34, 1433]) torch.Size([34, 136])
torch.Size([56])
torch.Size([56, 1433]) torch.Size([56, 136])
torch.Size([75])
torch.Size([75, 1433]) torch.Size([

torch.Size([65, 1433]) torch.Size([65, 136])
torch.Size([9])
torch.Size([9, 1433]) torch.Size([9, 136])
torch.Size([33])
torch.Size([33, 1433]) torch.Size([33, 136])
torch.Size([55])
torch.Size([55, 1433]) torch.Size([55, 136])
torch.Size([159])
torch.Size([159, 1433]) torch.Size([159, 136])
torch.Size([40])
torch.Size([40, 1433]) torch.Size([40, 136])
torch.Size([29])
torch.Size([29, 1433]) torch.Size([29, 136])
torch.Size([74])
torch.Size([74, 1433]) torch.Size([74, 136])
torch.Size([4])
torch.Size([4, 1433]) torch.Size([4, 136])
torch.Size([21])
torch.Size([21, 1433]) torch.Size([21, 136])
torch.Size([36])
torch.Size([36, 1433]) torch.Size([36, 136])
torch.Size([17])
torch.Size([17, 1433]) torch.Size([17, 136])
torch.Size([66])
torch.Size([66, 1433]) torch.Size([66, 136])
torch.Size([151])
torch.Size([151, 1433]) torch.Size([151, 136])
torch.Size([19])
torch.Size([19, 1433]) torch.Size([19, 136])
torch.Size([22])
torch.Size([22, 1433]) torch.Size([22, 136])
torch.Size([35])
torch.Si

Done!


In [12]:
cora_train_dataset

SEALDataset(17952)

In [94]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model = DGCNN(hidden_channels=32, num_layers=3).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

In [None]:
best_val_auc = test_auc = 0
for epoch in range(1, 51):
    loss = train(cora_train_loader)
    val_auc = test(cora_val_loader)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = test(cora_test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

# My Usage

In [115]:
def load_checkpoint(filepath, device):
    
    model = DGCNN(hidden_channels=32, num_layers=3).to(device)

    if os.path.exists(filepath):
        print("pretrained finded")
        checkpoint = torch.load(filepath)
        model.load_state_dict(checkpoint['model_stat'])
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        optimizer.load_state_dict(checkpoint['optimizer_stat'])

    else:
        print("use a new optimizer")
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    return model, optimizer

In [116]:
root = os.getcwd()

In [117]:
dataset = "dataset1"

In [118]:
date_time = datetime.strftime(datetime.now(), "%Y-%m-%d_%H-%M")

save = os.path.join(root, "results", dataset, date_time)

if os.path.exists(save):
    pass
else:
    os.makedirs(save)

In [119]:
train_dataset =  MyDataset(num_hops=2, root=os.path.join(os.getcwd(), "hw2_data", dataset), split="train")
val_dataset =  MyDataset(num_hops=2, root=os.path.join(os.getcwd(), "hw2_data", dataset), split="valid")
test_dataset = MyDataset(num_hops=2, root=os.path.join(os.getcwd(), "hw2_data", dataset), split="test")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

Processing...


RuntimeError: Sizes of tensors must match except in dimension 0. Got 1433 and 1547 in dimension 1 (The offending index is 1)

In [None]:
train_dataset

In [None]:
train_dataset[0]

In [None]:
test_dataset

In [None]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

model, optimizer = load_checkpoint(os.path.join(save, "{}.pth".format(date_time)), device)

In [None]:
best_val_auc = test_auc = 0
for epoch in range(1, 101):
    loss = train(train_loader)
    val_auc = test(val_loader)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        

        checkpoint = {
            'model_stat': model.state_dict(),
            'optimizer_stat': optimizer.state_dict(),
        }
        
        torch.save(checkpoint, os.path.join(save, 
                                            "{}.pth".format(date_time)))
        print("\nSave Model")
        
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}')

In [None]:
model, optimizer = load_checkpoint(os.path.join(save, "{}.pth".format(date_time)), device)

In [None]:
pred = predict(test_loader)
pred = np.concatenate(pred)
pred = np.round(pred,3)
pred

In [None]:
test_pred = pd.read_csv(os.path.join(os.getcwd(), "hw2_data", dataset, "raw", "test.csv"))

test_pred["prob"] = pred

# test.loc[(test["from"] == test["to"])]

In [None]:
upload = test_pred[["id", "prob"]]
upload

In [None]:
test_pred.to_csv(os.path.join(save, "{}.csv".format(date_time)), index=False)

In [None]:
upload.to_csv(os.path.join(save, "upload.csv"), index=False)