In [1]:
import re
import os
from nltk.corpus import stopwords

import glob
import copy
import random
import time
import json
import pickle
import nltk
import collections
from collections import Counter
from itertools import combinations
import numpy as np
from random import shuffle
import torch
import argparse
import time

from module.vocabulary import Vocab

In [2]:
data_dir = '../cnndm'
cache_dir = '../cache/CNNDM'
DATA_FILE = os.path.join(data_dir, "index_to_file_mapping_train.json")
VALID_FILE = os.path.join(data_dir, "index_to_file_mapping_val.json")
VOCAL_FILE = os.path.join(cache_dir, "vocab")
FILTER_WORD = os.path.join(cache_dir, "filter_word.txt")
train_w2s_path = os.path.join(cache_dir, "index_to_file_mapping_train.json")
val_w2s_path = os.path.join(cache_dir, "index_to_file_mapping_val.json")
    
# defaults
vocab_size = 50000
doc_max_timesteps = 50
sent_max_len = 100

# vocab
vocab = Vocab(VOCAL_FILE, vocab_size)

# filterwords
FILTERWORD = stopwords.words('english')
punctuations = [',', '.', ':', ';', '?', '(', ')', '[', ']', '&', '!', '*', '@', '#', '$', '%', '\'\'', '\'', '`', '``',
                '-', '--', '|', '\/']
FILTERWORD.extend(punctuations)

In [3]:
# utils
def readJson(fname):
    data = []
    with open(fname, encoding="utf-8") as f:
        for line in f:
            data.append(json.loads(line))
    return data


def readText(fname):
    data = []
    with open(fname, encoding="utf-8") as f:
        for line in f:
            data.append(line.strip())
    return data

In [18]:
class Example(object):
    def __init__(self, article_sents, abstract_sents, vocab, sent_max_len, label):
        self.sent_max_len = sent_max_len
        self.enc_sent_len = []
        self.enc_sent_input = []
        self.enc_sent_input_pad = []

        # Store the original strings
        self.original_article_sents = article_sents
        self.original_abstract = "\n".join(abstract_sents)

        # Process the article
        if isinstance(article_sents, list) and isinstance(article_sents[0], list):  # multi document
            self.original_article_sents = []
            for doc in article_sents:
                self.original_article_sents.extend(doc)
        for sent in self.original_article_sents:
            article_words = sent.split()
            self.enc_sent_len.append(len(article_words))  # store the length before padding
            self.enc_sent_input.append([vocab.word2id(w.lower()) for w in article_words])  # list of word ids; OOVs are represented by the id for UNK token
        self._pad_encoder_input(vocab.word2id('[PAD]'))

        # Store the label
        self.label = label
        label_shape = (len(self.original_article_sents), len(label))  # [N, len(label)]
        # label_shape = (len(self.original_article_sents), len(self.original_article_sents))
        self.label_matrix = np.zeros(label_shape, dtype=int)
        if label != []:
            self.label_matrix[np.array(label), np.arange(len(label))] = 1  # label_matrix[i][j]=1 indicate the i-th sent will be selected in j-th step

    def _pad_encoder_input(self, pad_id):
        """
        :param pad_id: int; token pad id
        :return: 
        """
        max_len = self.sent_max_len
        for i in range(len(self.enc_sent_input)):
            article_words = self.enc_sent_input[i].copy()
            if len(article_words) > max_len:
                article_words = article_words[:max_len]
            if len(article_words) < max_len:
                article_words.extend([pad_id] * (max_len - len(article_words)))
            self.enc_sent_input_pad.append(article_words)

In [26]:
class ExampleSet(torch.utils.data.Dataset):
    def __init__(self, data_path, vocab, doc_max_timesteps, sent_max_len, filter_word_path, w2s_path):

        self.vocab = vocab
        self.sent_max_len = sent_max_len
        self.doc_max_timesteps = doc_max_timesteps

        with open(data_path, "r", encoding="utf-8") as f:
            self.example_list = json.load(f)
        self.size = len(self.example_list)

        tfidf_w = readText(filter_word_path)
        self.filterwords = FILTERWORD
        self.filterids = [vocab.word2id(w.lower()) for w in FILTERWORD]
        self.filterids.append(vocab.word2id("[PAD]"))   # keep "[UNK]" but remove "[PAD]"
        lowtfidf_num = 0
        for w in tfidf_w:
            if vocab.word2id(w) != vocab.word2id('[UNK]'):
                self.filterwords.append(w)
                self.filterids.append(vocab.word2id(w))
                lowtfidf_num += 1
            if lowtfidf_num > 5000:
                break

        with open(w2s_path, "r") as f:
            self.w2s_tfidf = json.load(f)
            
    def get_w2s(self,index):
        index = str(index)
        file_name, new_index = self.w2s_tfidf[index]
        ws = readJson(file_name)
        return ws[new_index]
    
    def pad_label_m(self, label_matrix):
        label_m = label_matrix[:self.doc_max_timesteps, :self.doc_max_timesteps]
        N, m = label_m.shape
        if m < self.doc_max_timesteps:
            pad_m = np.zeros((N, self.doc_max_timesteps - m))
            return np.hstack([label_m, pad_m])
        return label_m
            
    def get_example(self, index):
        file_name, new_index  = self.example_list[str(index)]
        e = readJson(file_name)
        e = e[new_index]
        e["summary"] = e.setdefault("summary", [])
        example = Example(e["text"], e["summary"], self.vocab, self.sent_max_len, e["label"])
        return example
    
    # def checker(self, index):
    #     item = self.get_example(index)
    #     input_pad = item.enc_sent_input_pad[:self.doc_max_timesteps]
    #     label = self.pad_label_m(item.label_matrix)
    #     # w2s_w = self.w2s_tfidf[index]
    #     w2s_w = self.get_w2s(index)
    #     return input_pad, label, w2s_w, index
    
    def __getitem__(self, index):
        item = self.get_example(index)
        input_pad = item.enc_sent_input_pad[:self.doc_max_timesteps]
        label = self.pad_label_m(item.label_matrix)
        # w2s_w = self.w2s_tfidf[index]
        w2s_w = self.get_w2s(index)
        return input_pad, label, w2s_w, index
        

In [27]:
dataset = ExampleSet(data_path=DATA_FILE, vocab=vocab, doc_max_timesteps=doc_max_timesteps, sent_max_len=sent_max_len, filter_word_path=FILTER_WORD, w2s_path=train_w2s_path)

In [29]:
dataset.checker(0)

['/scratch/hitesh.goel/cnndm/train/0.jsonl', 0]
[[21, 13, 8, 694, 10, 11130, 15, 592, 440, 45, 238, 12, 144, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [16235, 7, 5, 2080, 6, 39, 73, 34, 210, 19404, 9, 16959, 17, 8, 1609, 990, 27, 12861, 3108, 12157, 60, 29, 41814, 4575, 17818, 5, 16790, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [33, 72, 309, 662, 8, 287, 61, 1264, 9, 1037, 8, 26514, 6213, 27, 5, 6820, 4579, 11, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [None]:
'''
notes
- dont make 2d matrix for labels
- only 1D with [:max_timesteps] length?
    - reduces space?
- remove tf-idf scores
- use nx to make the graphs
- then use torch_geometric to write the code
'''

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
import networkx as nx
import numpy as np

# Generate sample bipartite graph data
def generate_bipartite_data(num_nodes_src, num_nodes_dst, num_edges):
    G = nx.complete_bipartite_graph(num_nodes_src, num_nodes_dst)
    edge_index = np.array(G.edges()).T
    src_features = torch.randn(num_nodes_src, 16)
    dst_features = torch.randn(num_nodes_dst, 16)
    return Data(edge_index=torch.tensor(edge_index, dtype=torch.long), 
                x_src=src_features, x_dst=dst_features)


# Define the neural network model
class GATModel(nn.Module):
    def __init__(self):
        super(GATModel, self).__init__()
        self.conv1 = GATConv(in_channels=16, out_channels=8, heads=2, dropout=0.6)
        self.fc = nn.Linear(16, 1)

    def forward(self, data):
        src, dst = data.x_src, data.x_dst
        x_src = self.conv1(src, data.edge_index[:2])  # Pass source node features and edge indices
        x_dst = self.conv1(dst, data.edge_index[::-1][:2])  # Pass destination node features and edge indices
        x = torch.cat([x_src, x_dst], dim=0)  # Concatenate source and destination node features
        x = F.relu(x)
        x = self.fc(x)
        return x


# Generate sample data
data = generate_bipartite_data(10, 10, 50)

# Create model, optimizer and loss function
model = GATModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Training loop
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, torch.randn_like(output))  # Random target for demonstration
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
