In [62]:
import os
import csv
import json
import numpy as np
import pandas as pd

from tqdm import tqdm_notebook
from collections import Counter

In [63]:
root_dir = "/media/iiit/EXTERNAL/hashtag_data"
post_set = set()
hashtags_list = []
for file in os.listdir(root_dir):
    if ".json" in file:
        file_name, ext = os.path.splitext(file)
        json_path = os.path.join(root_dir, file)

        with open(json_path, 'r') as f2:
            json_data = json.load(f2)
            for post in json_data:
                post_id = json_data[post]["id"]
                post_set.add(post_id)
                text = json_data[post]['text_des'].lower()
                words = text.split(" ")
                for cur_word in words:
                    if "#" in cur_word:
                        hashtags_list.append(cur_word)

In [64]:
hashtags_counter = Counter(hashtags_list)
top_100_hashtags = [hashtag for (hashtag, val) in hashtags_counter.most_common(100)]
top_100_hashtags[:5]

['#travel',
 '#travelphotography',
 '#travelgram',
 '#travelblogger',
 '#instatravel']

In [65]:
root_dir = "/media/iiit/EXTERNAL/hashtag_data"
num_labels_list = []
with open("data.csv", "w") as f1:
    writer = csv.writer(f1, delimiter=" ")
    writer.writerow(["image_path", "hashtags"])
    
    for file in os.listdir(root_dir):
        if ".json" in file:
            file_name, ext = os.path.splitext(file)
            json_path = os.path.join(root_dir, file)
            
            with open(json_path, 'r') as f2:
                json_data = json.load(f2)
                for post in json_data:
                    post_id = json_data[post]["id"]
                    text = json_data[post]['text_des'].lower()
                    list_tags = []
                    for tag in top_100_hashtags:
                        if tag in text:
                            list_tags.append(tag.replace('#',''))
                    
                    img_path = os.path.join(file_name, post_id)
                    img_complete_path = os.path.join(root_dir, img_path)
                    if os.path.exists(img_complete_path):
                        if len(list_tags) <= 7 and len(list_tags) > 0:
                            num_labels_list.append(len(list_tags))
                            writer.writerow([img_path, ','.join(list_tags)])

In [66]:
print("Mean number of labels: {}".format(np.mean(num_labels_list)))

Mean number of labels: 4.060648089774303


In [67]:
from numpy.random import RandomState

df = pd.read_csv("data.csv", sep=" ")
df.head()

Unnamed: 0,image_path,hashtags
0,Citytravel/B3WUl6BoAHt,"citytravel,city,architecture,art"
1,Citytravel/B3R4txlIkdR,"citytravel,city,architecture,art"
2,Citytravel/B3ME05mIpMS,"citytravel,city,architecture,art"
3,Citytravel/B3XTYbvlpZa,"travel,love,explore,citytravel,travelgirl,city"
4,Citytravel/B3UEHqIBWyr,"travel,traveladdict,travels,citytravel,europe,..."


In [68]:
rng = RandomState()
train_df = df.sample(frac=0.8, random_state=rng)
val_df = df.loc[~df.index.isin(train_df.index)]

In [69]:
train_df.to_csv("train.csv", index=False, sep=" ")
val_df.to_csv("val.csv", index=False, sep=" ")

In [70]:
train_df.iloc[0]["image_path"]

'Travellove/B3ZQp7mhUHs'