In [1]:
!nvidia-smi

Tue Dec 21 16:00:36 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.80       Driver Version: 460.80       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:02:00.0 Off |                    0 |
| N/A   30C    P0    34W / 250W |   1071MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000000:03:00.0 Off |                    0 |
| N/A   30C    P0    31W / 250W |  10905MiB / 16280MiB |      0%      Defaul

In [2]:
import json
import pickle
import re
import nltk
from collections import Counter
import pandas as pd
import random
import heapq
import csv
from tqdm import tqdm
import os
import numpy as np
import time
import math
import lmdb
import gensim

import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchfile
import dgl.function as fn
from dgl.utils import expand_as_pair
from dgl.nn.functional import edge_softmax

from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, accuracy_score

# from dgl.nn.pytorch.conv import GINConv
from dgl.sampling import RandomWalkNeighborSampler
from torch.utils.data import DataLoader

Using backend: pytorch


In [3]:
device = ("cuda:3" if torch.cuda.is_available() else "cpu")
print('device: ', device)

dataset_folder = '../data/'
ckpt_folder = '../data/'
tuned_emb_file = 'pretrained_image_emb.pt'

device:  cuda:3


## build graph

In [4]:
def get_graph():
    print('generating graph ...')
    all_r2i_src_dst, train_r2i_src_dst, val_r2i_src_dst, test_r2i_src_dst = torch.load(dataset_folder+'r2i_all_train_val_test_src_dst.pt')
    r2i_edge_src, r2i_edge_dst = all_r2i_src_dst
    r2r_edge_src, r2r_edge_dst, r2r_edge_weight = torch.load(dataset_folder+'r2r_edge_src_dst_weight.pt')
    i2i_edge_src, i2i_edge_dst, i2i_edge_weight = torch.load(dataset_folder+'/i2i_edge_src_dst_weight.pt')
    all_u2r_src_dst_weight, train_u2r_src_dst_weight, val_u2r_src_dst_weight, test_u2r_src_dst_weight = torch.load(dataset_folder+'u2r_all_train_val_test_src_dst_weight.pt')
    u2r_edge_src, u2r_edge_dst, u2r_edge_weight = all_u2r_src_dst_weight

    # nodes and edges
    graph = dgl.heterograph({
        ('recipe', 'r-i', 'ingredient'): (r2i_edge_src, r2i_edge_dst),
        ('ingredient', 'i-r', 'recipe'): (r2i_edge_dst, r2i_edge_src),
        ('recipe', 'r-r', 'recipe'): (r2r_edge_src, r2r_edge_dst),
        ('ingredient', 'i-i', 'ingredient'): (i2i_edge_src, i2i_edge_dst),
        ('user', 'u-r', 'recipe'): (u2r_edge_src, u2r_edge_dst),
        ('recipe', 'r-u', 'user'): (u2r_edge_dst, u2r_edge_src)
    })

    # edge weight
    graph.edges['r-r'].data['weight'] = torch.FloatTensor(r2r_edge_weight)
    graph.edges['i-i'].data['weight'] = torch.FloatTensor(i2i_edge_weight)
    graph.edges['u-r'].data['weight'] = torch.FloatTensor(u2r_edge_weight)
    graph.edges['r-u'].data['weight'] = torch.FloatTensor(u2r_edge_weight)
    graph.edges['r-i'].data['weight'] = torch.ones(4440820)
    graph.edges['i-r'].data['weight'] = torch.ones(4440820)
    
    # node features
    recipe_nodes_avg_instruction_features = torch.load(dataset_folder+'/recipe_nodes_instruction_features.pt')
    ingredient_nodes_nutrient_features = torch.load(dataset_folder+'/ingredient_nodes_nutrient_features.pt')
    recipe_nodes_pretraind_image_features = torch.load(ckpt_folder + tuned_emb_file)
    graph.nodes['recipe'].data['avg_instr_feature'] = recipe_nodes_avg_instruction_features
    graph.nodes['recipe'].data['random_instr'] = torch.nn.init.xavier_normal_(torch.ones(472515, 512))
    graph.nodes['ingredient'].data['nutrient_feature'] = ingredient_nodes_nutrient_features
    graph.nodes['recipe'].data['resnet_image'] = recipe_nodes_pretraind_image_features
    graph.nodes['user'].data['random_feature'] = torch.nn.init.xavier_normal_(torch.ones(38624, 300))
    graph.nodes['recipe'].data['nodeID'] = graph.nodes('recipe')
    
    # labels and masks
    train_mask = torch.load(dataset_folder+'/train_cuisine_mask.pt')
    val_mask = torch.load(dataset_folder+'/val_cuisine_mask.pt')
    test_mask = torch.load(dataset_folder+'/test_cuisine_mask.pt')
    recipe_nodes_labels = torch.load(dataset_folder+'/recipe_nodes_cuisine_labels.pt')
    graph.nodes['recipe'].data['train_mask'] = train_mask
    graph.nodes['recipe'].data['val_mask'] = val_mask
    graph.nodes['recipe'].data['test_mask'] = test_mask
    graph.nodes['recipe'].data['label'] = recipe_nodes_labels.long()

    return graph

graph = get_graph()
print('graph: ', graph)

generating graph ...
graph:  Graph(num_nodes={'ingredient': 22186, 'recipe': 472515, 'user': 38624},
      num_edges={('ingredient', 'i-i', 'ingredient'): 170642, ('ingredient', 'i-r', 'recipe'): 4440820, ('recipe', 'r-i', 'ingredient'): 4440820, ('recipe', 'r-r', 'recipe'): 644256, ('recipe', 'r-u', 'user'): 1193179, ('user', 'u-r', 'recipe'): 1193179},
      metagraph=[('ingredient', 'ingredient', 'i-i'), ('ingredient', 'recipe', 'i-r'), ('recipe', 'ingredient', 'r-i'), ('recipe', 'recipe', 'r-r'), ('recipe', 'user', 'r-u'), ('user', 'recipe', 'u-r')])


In [5]:
def get_train_val_test_idx():
    train_mask = graph.nodes['recipe'].data['train_mask'].to(device)
    val_mask = graph.nodes['recipe'].data['val_mask'].to(device)
    test_mask = graph.nodes['recipe'].data['test_mask'].to(device)
    labels = graph.nodes['recipe'].data['label'].to(device)

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze()
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = get_train_val_test_idx()

print('length of train_idx: ', len(train_idx))
print('length of val_idx: ', len(val_idx))
print('length of test_idx: ', len(test_idx))


length of train_idx:  330911
length of val_idx:  70448
length of test_idx:  71156


## Adversarial Learning

In [6]:
def get_PGD_inputs(model, blocks, inputs, labels, seeds, eps=0.02, alpha=0.005, iters=5): # 8. / 255.
    # init
    user, instr, ingredient, image = inputs
    
    delta_user = torch.rand([len(user), 128]) * eps * 2 - eps
    delta_instr = torch.rand([len(instr), 128]) * eps * 2 - eps
    delta_ingredient = torch.rand([len(ingredient), 128]) * eps * 2 - eps
    delta_image = torch.rand([len(image), 128]) * eps * 2 - eps
    
    delta_user = delta_user.to(device)
    delta_instr = delta_instr.to(device)
    delta_ingredient = delta_ingredient.to(device)
    delta_image = delta_image.to(device)
    
    delta_user = torch.nn.Parameter(delta_user)
    delta_instr = torch.nn.Parameter(delta_instr)
    delta_ingredient = torch.nn.Parameter(delta_ingredient)
    delta_image = torch.nn.Parameter(delta_image)
    
    for i in range(iters):
        h1_instr = model.encoder_instr(global_h_list_instr[blocks[0].srcdata['_ID']]) # hs[0]
        h1_instr = norm(h1_instr)
        h1_instr = model.mp_gin(blocks[0], h1_instr)
        h1_instr = h1_instr.unsqueeze(1)
        
        h1_image = model.encoder_image(global_h_list_image[blocks[0].srcdata['_ID']]) # hs[0]
        h1_image = norm(h1_image)
        h1_image = model.mp_gin(blocks[0], h1_image)
        h1_image = h1_image.unsqueeze(1)
    
        p_user = model.user_embedding(user)
        p_mid_instr = model.encoder_instr(instr)
        p_ingredient = model.ingredient_embedding(ingredient)
        p_mid_image = model.encoder_image(image)
        
        p_user = p_user + delta_user
        p_mid_instr = p_mid_instr + delta_instr
        p_ingredient = p_ingredient + delta_ingredient
        p_mid_image = p_mid_image + delta_image
        
        p_user = norm(p_user)
        p_mid_instr = norm(p_mid_instr)
        p_ingredient = norm(p_ingredient)
        p_mid_image = norm(p_mid_image)
        
        x1 = model.rgcn(blocks[-1:], {'user': p_user, 'recipe': p_mid_instr, 'ingredient': p_ingredient})
        x2 = model.rgcn(blocks[-1:], {'user': p_user, 'recipe': p_mid_image, 'ingredient': p_ingredient})
        
        x1 = torch.cat([x1, h1_instr], dim=1)
        x2 = torch.cat([x2, h1_image], dim=1)
        x = model.cross_view_out(torch.cat([x1, x2], dim=2))
        x = model.relation_attention(x)
        logits = model.out(x)

        # update
        model.zero_grad()
        loss = criterion(logits, labels)
        loss.backward()
        
        # --- delta update ---
        # user
        delta_user.data = delta_user.data + alpha * delta_user.grad.sign()
        delta_user.grad = None
        delta_user.data = torch.clamp(delta_user.data, min=-eps, max=eps)
        # instr
        delta_instr.data = delta_instr.data + alpha * delta_instr.grad.sign()
        delta_instr.grad = None
        delta_instr.data = torch.clamp(delta_instr.data, min=-eps, max=eps)
        # ingredient
        delta_ingredient.data = delta_ingredient.data + alpha * delta_ingredient.grad.sign()
        delta_ingredient.grad = None
        delta_ingredient.data = torch.clamp(delta_ingredient.data, min=-eps, max=eps)
        # image
        delta_image.data = delta_image.data + alpha * delta_image.grad.sign()
        delta_image.grad = None
        delta_image.data = torch.clamp(delta_image.data, min=-eps, max=eps)

    output = [(delta_user).detach(), (delta_instr).detach(), (delta_ingredient).detach(), (delta_image).detach()]
    return output


## model

In [7]:
class custom_GINConv(nn.Module):
    def __init__(self,
                 apply_func,
                 aggregator_type,
                 init_eps=0,
                 learn_eps=False):
        super(custom_GINConv, self).__init__()
        self.apply_func = apply_func
        self._aggregator_type = aggregator_type
        if aggregator_type == 'sum':
            self._reducer = fn.sum
        elif aggregator_type == 'max':
            self._reducer = fn.max
        elif aggregator_type == 'mean':
            self._reducer = fn.mean
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
        # to specify whether eps is trainable or not.
        if learn_eps:
            self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps]))
        else:
            self.register_buffer('eps', torch.FloatTensor([init_eps]))
        
        self.fc_src = nn.Linear(128, 128, bias=False)
        self.fc_dst = nn.Linear(128, 128, bias=False)
        self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, 128)))
        self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, 128)))
        self.negative_slope = 0.2
        self.leaky_relu = nn.LeakyReLU(self.negative_slope)
        self.attn_drop = nn.Dropout(0)
        self.fc_src2 = nn.Linear(128, 128, bias=False)
        self.fc_dst2 = nn.Linear(128, 128, bias=False)
        
        # init
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
        nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        
    def forward(self, graph, feat, edge_weight=None):
        with graph.local_scope():
            aggregate_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            feat_src_original, feat_dst_original = expand_as_pair(feat, graph)
            
            # add W
            feat_src1 = self.fc_src(feat_src_original)
            feat_dst1 = self.fc_dst(feat_dst_original)
            
            # add a*W*hi*hj
            feat_src2 = self.fc_src2(feat_src_original)
            feat_dst2 = self.fc_dst2(feat_dst_original)
            el = (feat_src2 * self.attn_l).sum(dim=-1)
            er = (feat_dst2 * self.attn_r).sum(dim=-1)
            graph.srcdata.update({'el': el, 'feat_src2': feat_src2})
            graph.dstdata.update({'er': er, 'feat_dst2': feat_dst2})
            graph.apply_edges(fn.u_add_v('el', 'er', 'e')) # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            e = self.leaky_relu(graph.edata.pop('e'))
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            
            graph.apply_edges(fn.u_mul_e('feat_src2', 'a', 'a_and_src'))
            graph.update_all(fn.v_mul_e('feat_dst2', 'a_and_src', 'm'),
                             fn.sum('m', 'add_ft'))
            
            # out
            graph.srcdata['h'] = feat_src1
            graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
            rst = (1 + self.eps) * feat_dst1 + graph.dstdata['neigh'] + graph.dstdata['add_ft']
            if self.apply_func is not None:
                rst = self.apply_func(rst)
            return rst

In [8]:
class StochasticTwoLayerRGCN(nn.Module):
    def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
        super().__init__()
        self.gnn_dict = {
                rel : custom_GINConv(torch.nn.Linear(128, 128), 'max')
                for rel in rel_names
            }
        self.conv1 = dglnn.HeteroGraphConv(self.gnn_dict, aggregate='stack')

    def forward(self, blocks, x):
#         edge_weight = blocks[0].edata['weight']
        x = self.conv1(blocks[0], x) # edge_weight
        return x['recipe']

In [9]:
# def node_drop(feats, drop_rate, training):
#     n = feats.shape[0]
#     drop_rates = torch.FloatTensor(np.ones(n) * drop_rate)
    
#     if training:
#         masks = torch.bernoulli(1. - drop_rates).unsqueeze(1)
#         feats = masks.to(feats.device) * feats / (1. - drop_rate)
#     else:
#         feats = feats
#     return feats

In [10]:
# 不同type neighbor的attention
class RelationAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(RelationAttention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )

    def forward(self, z):
        w = self.project(z).mean(0)                    # (M, 1)
        beta = torch.softmax(w, dim=0)                 # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)
        out = (beta * z).sum(1)                        # (N, D * K)
        return out
    

In [11]:
global_h_list_instr = graph.ndata['avg_instr_feature']['recipe'].to(device)
global_h_list_image = graph.ndata['resnet_image']['recipe'].to(device)

In [12]:
class ScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x
            edge_subgraph.apply_edges(dgl.function.u_dot_v('x', 'x', 'score'), etype='u-r')
            return edge_subgraph.edata['score'][('user', 'u-r', 'recipe')].squeeze()

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        # original
        self.user_embedding = nn.Sequential(
            nn.Linear(300, 128),
            nn.ReLU(),
        )
        self.ingredient_embedding = nn.Sequential(
            nn.Linear(46, 128),
            nn.ReLU()
        )

        self.mp_gin = custom_GINConv(torch.nn.Linear(128, 128), 'max')
        self.rgcn = StochasticTwoLayerRGCN(128, 128, 128, graph.etypes)

        self.encoder_instr = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        self.encoder_image = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        
        self.cross_view_out = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        self.out = nn.Sequential(
            nn.Linear(128, 9)
        )

        self.relation_attention = RelationAttention(128)

    def forward(self, blocks, input_features, seeds, adv_deltas):
        user, avg_instr, ingredient, image = input_features

        if adv_deltas:
            # print('adv ...')
            delta_user, delta_instr, delta_ingredient, delta_image = adv_deltas
            
        # metapath
        h1_instr = self.encoder_instr(global_h_list_instr[blocks[0].srcdata['_ID']])
        h1_image = self.encoder_image(global_h_list_image[blocks[0].srcdata['_ID']])
        h1_instr = norm(h1_instr)
        h1_image = norm(h1_image)
        h1_instr = self.mp_gin(blocks[0], h1_instr)
        h1_image = self.mp_gin(blocks[0], h1_image)
        h1_instr = h1_instr.unsqueeze(1)
        h1_image = h1_image.unsqueeze(1)
        
        # schema
        user = self.user_embedding(user)
        ingredient = self.ingredient_embedding(ingredient)
        mid_instr = self.encoder_instr(avg_instr)
        mid_image = self.encoder_image(image)
        
        if adv_deltas:
            user += delta_user
            ingredient += delta_ingredient
            mid_instr += delta_instr
            mid_image += delta_image
        
        user = norm(user)
        ingredient = norm(ingredient)
        mid_instr = norm(mid_instr)
        mid_image = norm(mid_image)
    
        x1 = self.rgcn(blocks[-1:], {'user': user, 'recipe': mid_instr, 'ingredient': ingredient})
        x2 = self.rgcn(blocks[-1:], {'user': user, 'recipe': mid_image, 'ingredient': ingredient})

        x1 = torch.cat([x1, h1_instr], dim=1)
        x2 = torch.cat([x2, h1_image], dim=1)
        x = self.cross_view_out(torch.cat([x1, x2], dim=2))
        x = self.relation_attention(x)
  
        return self.out(x), x


## Data Loader

In [13]:
metapath_list = [['r-u', 'u-r']]

class HANSampler(object):
    def __init__(self, g, metapath_list, num_neighbors):
        self.sampler_list = []
        for metapath in metapath_list:
            # note: random walk may get same route(same edge), which will be removed in the sampled graph.
            # So the sampled graph's edges may be less than num_random_walks(num_neighbors).
            self.sampler_list.append(RandomWalkNeighborSampler(G=g,
                                                               num_traversals=2,
                                                               termination_prob=0,
                                                               num_random_walks=num_neighbors,
                                                               num_neighbors=num_neighbors,
                                                               metapath=metapath))
            
            self.schema_sampler = dgl.dataloading.MultiLayerNeighborSampler([20])

    def sample_blocks(self, seeds):
        block_list = []
        seeds = torch.stack(seeds)
        
        sampled_mp_src_nodes = []
        sampled_mp_dst_nodes = []
        for sampler in self.sampler_list:
            frontier = sampler(seeds)
            frontier = dgl.remove_self_loop(frontier)
            block = dgl.to_block(frontier, seeds)
            block_list.append(block)
        
        # schema
        schema_NodeDataLoader = dgl.dataloading.NodeDataLoader(graph, {'recipe': seeds}, self.schema_sampler,
                                        batch_size=4096, shuffle=False, drop_last=False, num_workers=0)
        for input_nodes, output_nodes, blocks in schema_NodeDataLoader:
            block_list.append(blocks[0])
            break

        return seeds, block_list
    
han_sampler = HANSampler(graph, metapath_list, num_neighbors=10)
train_dataloader = DataLoader(
                    dataset=train_idx.cpu(),
                    batch_size=4096,
                    collate_fn=han_sampler.sample_blocks,
                    shuffle=True,
                    drop_last=False,
                    num_workers=4)

for step, (seeds, blocks) in enumerate(train_dataloader):
    print(blocks)
    break
print()
    
han_val_sampler = HANSampler(graph, metapath_list, num_neighbors=10)
val_dataloader = DataLoader(
                dataset=val_idx.cpu(),
                batch_size=4096,
                collate_fn=han_val_sampler.sample_blocks,
                shuffle=False,
                drop_last=False,
                num_workers=0)
    
han_test_sampler = HANSampler(graph, metapath_list, num_neighbors=10)
test_dataloader = DataLoader(
                    dataset=test_idx.cpu(),
                    batch_size=4096,
                    collate_fn=han_test_sampler.sample_blocks,
                    shuffle=False,
                    drop_last=False,
                    num_workers=0)


[Block(num_src_nodes=39876, num_dst_nodes=4096, num_edges=39629), Block(num_src_nodes={'ingredient': 3947, 'recipe': 7207, 'user': 4561},
      num_dst_nodes={'ingredient': 0, 'recipe': 4096, 'user': 0},
      num_edges={('ingredient', 'i-i', 'ingredient'): 0, ('ingredient', 'i-r', 'recipe'): 38179, ('recipe', 'r-i', 'ingredient'): 0, ('recipe', 'r-r', 'recipe'): 3652, ('recipe', 'r-u', 'user'): 0, ('user', 'u-r', 'recipe'): 10092},
      metagraph=[('ingredient', 'ingredient', 'i-i'), ('ingredient', 'recipe', 'i-r'), ('recipe', 'ingredient', 'r-i'), ('recipe', 'recipe', 'r-r'), ('recipe', 'user', 'r-u'), ('user', 'recipe', 'u-r')])]



## helper functions

In [14]:
def norm(input, p=1, dim=1, eps=1e-12):
    return input / input.norm(p, dim, keepdim=True).clamp(min=eps).expand_as(input)


def get_score(y_pred, y_true):
    total_acc = accuracy_score(y_true, y_pred)
    score = {
        "f1": f1_score(y_true, y_pred, labels=[1, 2, 3, 4, 5, 6, 7, 8], average='micro'),
        "acc": total_acc
    }
    
    # detailed score
    matrix = confusion_matrix(y_true, y_pred)
    detailed_acc = matrix.diagonal()/matrix.sum(axis=1)
    detailed_score = {
        "f1": f1_score(y_true, y_pred, labels=[0, 1, 2, 3, 4, 5, 6, 7, 8], average=None, zero_division=0),
        "acc": detailed_acc
    }
    return score, detailed_score


def evaluate(model, dataloader, device):
    # print('evaluating ... ')
    evaluate_start = time.time()
    model.eval()
    total_loss = 0
    cosine_total_loss = 0
    link_prediction_total_loss = 0
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    
    detailed_precision = 0
    detailed_recall = 0
    detailed_f1 = 0
    count = 0
    
    all_y_preds = None
    all_labels = None
    
    with torch.no_grad():
        for step, (seeds, blocks) in enumerate(dataloader):
            blocks = [blk.to(device) for blk in blocks]
        
            # input
            input_user = blocks[-1].srcdata['random_feature']['user']
            input_instr = blocks[-1].srcdata['avg_instr_feature']['recipe'] # avg_instr_feature
            input_ingredient = blocks[-1].srcdata['nutrient_feature']['ingredient']
            input_image = blocks[-1].srcdata['resnet_image']['recipe']
            labels = blocks[-1].dstdata['label']['recipe']

            inputs = [input_user, input_instr, input_ingredient, input_image]
            logits, _ = model(blocks, inputs, seeds, None)
            y_pred = np.argmax(logits.cpu(), axis=1)
            
            if all_y_preds is None:
                all_y_preds = y_pred
                all_labels = labels.cpu().numpy()
            else:
                all_y_preds = np.append(all_y_preds, y_pred, axis=0)
                all_labels = np.append(all_labels, labels.cpu().numpy(), axis=0)
            
            # Loss
            loss = criterion(logits, labels)
            total_loss += loss.item()

            count += len(labels)
#             break
        
        total_score, detailed_score = get_score(all_y_preds, all_labels)
        total_f1 = total_score['f1']
        total_acc = total_score['acc']
        detailed_f1 = detailed_score['f1']
        detailed_acc = detailed_score['acc']
        
        total_loss /= count
        link_prediction_total_loss /= count
        evalutate_time = time.strftime("%M:%S min", time.gmtime(time.time()-evaluate_start))
        
    return total_loss, total_f1, total_acc, evalutate_time, detailed_f1, detailed_acc, link_prediction_total_loss



## Training

In [15]:
model = Model().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)

weights_class = torch.Tensor(9).fill_(1)
criterion = nn.CrossEntropyLoss(weight=weights_class).to(device)

print('start ... ')
for epoch in range(100):
    train_start = time.time()
    epoch_loss = 0
    epoch_adversarial_loss = 0
    iteration_cnt = 0

    for batch, (seeds, blocks) in enumerate(train_dataloader):
        model.train()
        blocks = [b.to(device) for b in blocks]

        # input
        input_user = blocks[-1].srcdata['random_feature']['user']
        input_instr = blocks[-1].srcdata['avg_instr_feature']['recipe']
        input_ingredient = blocks[-1].srcdata['nutrient_feature']['ingredient']
        input_image = blocks[-1].srcdata['resnet_image']['recipe']
        labels = blocks[-1].dstdata['label']['recipe'] 
        inputs = [input_user, input_instr, input_ingredient, input_image]
        
        logits, x = model(blocks, inputs, seeds, None)
        
        # adversarial learning
        adv_deltas = get_PGD_inputs(model, blocks, inputs, labels, seeds)
        adv_logits, adv_x = model(blocks, inputs, seeds, adv_deltas)
        
        # compute loss
        adversarial_loss = criterion(adv_logits, labels)
        loss = criterion(logits, labels) + 0.1*adversarial_loss
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        epoch_loss += loss.item()
        epoch_adversarial_loss += adversarial_loss.item()
        iteration_cnt += 1 
#         break
        
    epoch_loss /= iteration_cnt
    epoch_adversarial_loss /= iteration_cnt
    train_end = time.strftime("%M:%S min", time.gmtime(time.time()-train_start))
    
    print('Epoch: {0}, L: {l:.4f}, adv_l: {adv_l:.4f}, T: {t}, LR: {lr:.6f}'
          .format(epoch, l=epoch_loss, adv_l=epoch_adversarial_loss, t=train_end, lr=opt.param_groups[0]['lr']))

    scheduler.step()

    # Evaluation
    # For demonstration purpose, only test set result is reported here. Please use val_dataloader for comprehensiveness.
    test_loss, test_f1, test_acc, test_time, test_detailed_f1, test_detailed_acc, link_prediction_test_loss \
    = evaluate(model, test_dataloader, device)
    
    print('Testing: ')
    print('Total Loss: {l:.4f},  F1: {f1:.6f}, Acc: {acc:.6f},  Time: {t}, LR: {lr:.6f}'
          .format(l=test_loss, f1=test_f1, acc=test_acc, t=test_time, lr=opt.param_groups[0]['lr']))
    
    # show in overleaf table structure
    test_detailed_f1 = [str('{:.1f}'.format(i*100)) for i in list(test_detailed_f1)]
    test_detailed_acc = [str('{:.1f}'.format(i*100)) for i in list(test_detailed_acc)]
    print('detailed_f1: ', ' & '.join(test_detailed_f1[1:]) + ' & ' + test_detailed_f1[0] + ' & ' + '{:.1f}'.format(test_f1*100))
    print('detailed_acc: ', ' & '.join(test_detailed_acc[1:]) + ' & ' + test_detailed_acc[0] + ' & ' + '{:.1f}'.format(test_acc*100))
    print()
    

start ... 
Epoch: 0, L: 1.2557, adv_l: 1.2512, T: 01:23 min, LR: 0.005000
Testing: 
Total Loss: 0.0002,  F1: 0.819867, Acc: 0.804542,  Time: 00:06 min, LR: 0.004750
detailed_f1:  70.2 & 87.9 & 75.6 & 77.6 & 78.3 & 89.9 & 64.9 & 85.8 & 72.4 & 82.0
detailed_acc:  64.2 & 88.8 & 69.4 & 70.6 & 78.4 & 92.3 & 67.7 & 88.8 & 71.1 & 80.5

Epoch: 1, L: 0.6031, adv_l: 0.6262, T: 01:22 min, LR: 0.004750
Testing: 
Total Loss: 0.0001,  F1: 0.861190, Acc: 0.843555,  Time: 00:05 min, LR: 0.004513
detailed_f1:  77.4 & 91.7 & 83.9 & 82.3 & 82.7 & 92.3 & 77.7 & 87.1 & 74.2 & 86.1
detailed_acc:  73.1 & 94.5 & 79.0 & 82.8 & 86.9 & 95.4 & 74.0 & 91.9 & 67.4 & 84.4

Epoch: 2, L: 0.5238, adv_l: 0.5459, T: 01:20 min, LR: 0.004513
Testing: 
Total Loss: 0.0001,  F1: 0.874189, Acc: 0.856274,  Time: 00:05 min, LR: 0.004287
detailed_f1:  78.7 & 93.8 & 85.6 & 85.5 & 85.1 & 93.1 & 79.0 & 88.1 & 76.0 & 87.4
detailed_acc:  75.3 & 93.4 & 83.4 & 82.5 & 85.1 & 94.8 & 74.8 & 92.2 & 73.4 & 85.6

Epoch: 3, L: 0.4898, adv_l: 0