In [1]:
import os
import json
import sys
sys.path.append("../")

from tqdm import tqdm
import numpy as np
import utils
import pandas as pd
import ast

In [2]:
# Root folder of the Reuters corpus, relative to this notebook.
data_dir = "../../data/reuters"

# Seeded generator so the shuffle used for the train/dev split is reproducible.
random_state = np.random.RandomState(1)

# NOTE(review): `data_ratio` is not referenced by any cell visible in this
# notebook; the dev split is instead carved out as a fixed fraction of the
# "training" subset. Confirm whether this mapping is still needed.
data_ratio = dict(train=0.8, dev=0.1, test=0.1)

# Output artifacts: the per-document samples and the label -> id encoder.
save_data_path, save_label_path = (
    os.path.join(data_dir, fname)
    for fname in ("data.ndjson", "doc_label_encoder.json")
)

In [3]:
doc_fname = os.path.join(data_dir, "doc", "reuters.csv")

# Load the corpus and turn the "topic" column from its literal string form
# back into Python objects (ast.literal_eval never executes arbitrary code).
df = (
    pd.read_csv(doc_fname)
    .assign(topic=lambda frame: frame["topic"].map(ast.literal_eval))
)
df.head()

Unnamed: 0,path,topic,subset,index,content,lead,tin,retail,fuel,propane,...,soy-meal,earn,sun-oil,instal-debt,cotton,heat,trade,dfl,palladium,iron-steel
0,test/14826,[trade],test,14826,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RI...,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,0
1,test/14828,[grain],test,14828,CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STO...,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,test/14829,"[nat-gas, crude]",test,14829,JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWA...,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,test/14832,"[rubber, tin, sugar, corn, rice, grain, trade]",test,14832,THAI TRADE DEFICIT WIDENS IN FIRST QUARTER\n ...,0,1,0,0,0,...,0,0,0,0,0,0,1,0,0,0
4,test/14833,"[palm-oil, veg-oil]",test,14833,INDONESIA SEES CPO PRICE RISING SHARPLY\n Ind...,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [4]:
# Row labels of the corpus' predefined "test" and "training" subsets.
test_ids = df.loc[df["subset"] == "test"].index.values
training_pool = df.loc[df["subset"] == "training"].index.values

# NOTE: despite its name, `num_test` is the size of the held-out dev split
# (20% of the training subset), not the size of the test set.
num_test = int(0.2 * len(training_pool))

# Shuffle once with the seeded generator, then carve the dev ids off the front;
# whatever remains is the final training set.
train_ids = random_state.permutation(training_pool)
val_ids, train_ids = train_ids[:num_test], train_ids[num_test:]

In [5]:
def get_content(text):
    """Split a raw newline-separated document into its body and its title.

    Every line is stripped of surrounding whitespace. The first non-blank
    line becomes the title; all later non-blank lines are joined with a
    single space to form the content. Blank lines are ignored, so the
    content carries no leading/trailing whitespace and no runs of spaces
    caused by empty lines.

    Args:
        text: Raw document text (str) whose lines are separated by newlines.

    Returns:
        A ``(content, title)`` tuple -- note that content comes first.
        Both elements are empty strings when ``text`` has no non-blank line.
    """
    non_blank = [line for line in (raw.strip() for raw in text.split("\n")) if line]
    if not non_blank:
        return "", ""

    title, *body = non_blank
    return " ".join(body), title

def get_sample(doc, doc_id):
    """Build the serialisable sample record for one document.

    Args:
        doc: Mapping-like row exposing the raw text under "content" and the
            document's labels under "topic".
        doc_id: Identifier stored under the "id" key of the sample.

    Returns:
        A dict with the keys "id", "content", "title" and "labels".
    """
    body, headline = get_content(doc["content"])
    sample = dict(id=doc_id, content=body, title=headline, labels=doc["topic"])
    return sample

In [6]:
# One sample dict per row of the corpus frame, keyed by the row's index label.
samples = []
for row_id, row in df.iterrows():
    samples.append(get_sample(row, row_id))

In [7]:
# `x in ndarray` scans the whole array on every lookup, which makes tagging
# all samples quadratic. Build hash sets once so each lookup is O(1).
train_id_set = set(train_ids.tolist())
val_id_set = set(val_ids.tolist())
test_id_set = set(test_ids.tolist())

# Tag every sample with the split it belongs to.
for sample in samples:
    sample_id = sample["id"]  # local name avoids shadowing the built-in `id`
    sample["is_train"] = sample_id in train_id_set
    sample["is_dev"] = sample_id in val_id_set
    sample["is_test"] = sample_id in test_id_set

In [8]:
labels = set()

# json.dumps(..., ensure_ascii=False) emits raw non-ASCII characters, so the
# file encoding is pinned to UTF-8 instead of relying on the platform default.
with open(save_data_path, "w", encoding="utf-8") as f:
    for sample in tqdm(samples):
        # Documents without any label are left out of the ndjson file and do
        # not contribute to the label vocabulary.
        if sample["labels"]:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
            labels.update(sample["labels"])

# Deterministic label -> integer id mapping: ids follow the sorted label order.
label_ids = {label: idx for idx, label in enumerate(sorted(labels))}
utils.dump_json(label_ids, save_label_path)

100%|██████████| 10788/10788 [00:00<00:00, 84650.35it/s]


In [9]:
# pandas is already imported in the notebook's first cell, so it is not
# re-imported here.
# NOTE: this rebinds `df` from the raw corpus frame to the processed samples;
# the cells below read the sample columns (labels, content, is_* flags).
df = pd.DataFrame(samples)
df

Unnamed: 0,id,content,title,labels,is_train,is_dev,is_test
0,0,Mounting trade friction between the U.S. And J...,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RIFT,[trade],False,False,True
1,1,A survey of 19 provinces and seven cities show...,CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STOCKS,[grain],False,False,True
2,2,The Ministry of International Trade and Indust...,JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWARDS,"[nat-gas, crude]",False,False,True
3,3,Thailand's trade deficit widened to 4.5 billio...,THAI TRADE DEFICIT WIDENS IN FIRST QUARTER,"[rubber, tin, sugar, corn, rice, grain, trade]",False,False,True
4,4,Indonesia expects crude palm oil (CPO) prices ...,INDONESIA SEES CPO PRICE RISING SHARPLY,"[palm-oil, veg-oil]",False,False,True
...,...,...,...,...,...,...,...
10783,10783,The Bank of Japan bought a small amount of dol...,BANK OF JAPAN INTERVENES SOON AFTER TOKYO OPENING,"[money-fx, dlr]",True,False,False
10784,10784,"Japan's rubber stocks fell to 44,980 tonnes in...",JAPAN RUBBER STOCKS FALL IN MARCH,[rubber],True,False,False
10785,10785,THE BANK OF KOREA SAID IT FIXED THE MIDRATE OF...,SOUTH KOREAN WON FIXED AT 25-MONTH HIGH,[money-fx],True,False,False
10786,10786,Nippon Mining Co Ltd said it lowered its selli...,NIPPON MINING LOWERS COPPER PRICE,[copper],True,False,False


In [10]:
# Number of samples flagged for each split, in (train, dev, test) order.
tuple(int(df[flag].sum()) for flag in ("is_train", "is_dev", "is_test"))

(6216, 1553, 3019)

In [11]:
# Bar chart of how many documents carry 0, 1, 2, ... labels.
labels_per_doc = df["labels"].map(len)
ax = labels_per_doc.value_counts().plot.bar();
ax.set_title("Number of labels per document");

In [12]:
# Average number of whitespace-separated tokens in a document's content.
df["content"].str.split().str.len().mean()

120.45059325176122

In [13]:
# Number of distinct labels in the label -> id encoder built above.
len(label_ids)

90