In [None]:
import json
import praw
import os
import pandas as pd

import re
import nltk
from nltk.tokenize import TreebankWordTokenizer

import jupyterannotate

In [None]:
# Load Reddit API app credentials from ../reddit_api.json (path is relative
# to the notebook's working directory) and close the file before use.
with open('../reddit_api.json') as credentials_file:
    reddit_api_credentials = json.load(credentials_file)

# Client built from app credentials only (no username/password).
reddit_read_only = praw.Reddit(
    client_id=reddit_api_credentials['client_id'],
    client_secret=reddit_api_credentials['secret'],
    user_agent=reddit_api_credentials['user_agent'],
)

# All posts below are sampled from this subreddit.
subreddit = reddit_read_only.subreddit("AskDocs")

In [None]:
# How many random posts to draw from `subreddit` for this annotation session.
N_RANDOM_POSTS = 1

# One dict per sampled post; 'post_text' (title + body) is what gets annotated.
random_posts = []
_seen_post_ids = set()
for _ in range(N_RANDOM_POSTS):
    post = subreddit.random()
    # PRAW returns None when the subreddit does not support random
    # submissions; skip rather than crash on `post.id`.
    if post is None:
        continue
    # Independent random draws can return the same submission twice.
    if post.id in _seen_post_ids:
        continue
    _seen_post_ids.add(post.id)
    random_posts.append(
        {
            "id": post.id,
            'post_text': f"{post.title}\n{post.selftext}",
            "score": post.score,
            'total_comments': post.num_comments,
            'post_url': post.url,
        }
    )

In [None]:
# Entity labels the annotator can choose from in the widget.
labels = ['age', 'age_unit']

# One document per fetched post, in the same order as random_posts.
reddit_posts = [record['post_text'] for record in random_posts]

annotation_widget = jupyterannotate.AnnotateWidget(docs=reddit_posts, labels=labels)
annotation_widget  # last expression: render the widget

In [None]:
def convert_char_to_token_labels(text, spans, tokenizer=None):
    """Convert character-offset span annotations into token-level BIO tags.

    Parameters
    ----------
    text : str
        The annotated document.
    spans : list of dict
        Annotations for ``text``. Each dict needs ``'start'`` and ``'end'``
        character offsets into ``text`` and a ``'label'``. ``'end'`` is
        treated as exclusive (Python slice convention) -- NOTE(review):
        confirm against jupyterannotate's span format.
    tokenizer : object, optional
        Any object exposing ``tokenize(text)`` and ``span_tokenize(text)``
        that yield the same number of tokens in the same order. Defaults to
        a new ``TreebankWordTokenizer``.

    Returns
    -------
    tokens : list of str
        ``tokenizer.tokenize(text)``.
    token_labels : list of str
        One tag per token. The first token overlapping a span gets
        ``'B-<label>'``, following overlapping tokens get ``'I-<label>'``,
        and all other tokens get ``'O'``. If spans overlap, the later span
        in ``spans`` wins.
    """
    if tokenizer is None:
        tokenizer = TreebankWordTokenizer()

    tokens = tokenizer.tokenize(text)
    token_labels = ['O'] * len(tokens)  # 'O' = outside any span
    if not spans:
        return tokens, token_labels

    # (start, end) character offsets of every token, computed once.
    # span_tokenize undoes Treebank's quote rewriting, so offsets align
    # one-to-one with `tokens`.
    token_offsets = list(tokenizer.span_tokenize(text))

    for span in spans:
        span_start, span_end = span['start'], span['end']
        # Half-open interval overlap: [token_start, token_end) vs
        # [span_start, span_end). Tokens are ordered and disjoint, so the
        # covered indices are contiguous.
        covered = [
            index
            for index, (token_start, token_end) in enumerate(token_offsets)
            if token_start < span_end and token_end > span_start
        ]
        if not covered:
            continue
        token_labels[covered[0]] = 'B-' + span['label']  # beginning of span
        for index in covered[1:]:
            token_labels[index] = 'I-' + span['label']  # inside the span

    return tokens, token_labels

In [None]:
# Build exactly one token-level record per annotated post. Posts with no
# spans are skipped, as before. The original inner loop over spans appended
# the same record once per span and overwrote the global `labels` list.
new_labels = []
for post, post_text, post_spans in zip(random_posts, reddit_posts, annotation_widget.spans):
    if not post_spans:
        continue
    tokens, token_labels = convert_char_to_token_labels(post_text, post_spans)
    new_labels.append(
        {
            'post_id': post['id'],
            'context': post['post_text'],
            'tokens': tokens,
            'token_labels': token_labels,
            'score': post['score'],
            'total_comments': post['total_comments'],
            'post_url': post['post_url'],
        }
    )

In [None]:
# Merge this session's records with any previously saved ones, then drop
# None entries and exact duplicates, and write the result back to disk.
# `new_labels` is not mutated, so re-running this cell is idempotent.
NER_LABELS_DIR = '../data/reddit_scraping/'
NER_LABELS_PATH = os.path.join(NER_LABELS_DIR, 'ner_labels.json')

if os.path.isfile(NER_LABELS_PATH):
    with open(NER_LABELS_PATH) as json_file:
        saved_labels = json.load(json_file)
else:
    saved_labels = []

# Order is preserved: this session's records first, then saved ones.
# Records are dicts (unhashable), so dedupe on a canonical JSON encoding.
ner_labels = []
seen_label_keys = set()
for label in new_labels + saved_labels:
    if label is None:
        continue
    label_key = json.dumps(label, sort_keys=True)
    if label_key in seen_label_keys:
        continue
    seen_label_keys.add(label_key)
    ner_labels.append(label)

os.makedirs(NER_LABELS_DIR, exist_ok=True)
with open(NER_LABELS_PATH, 'w') as reddit_file:
    json.dump(ner_labels, reddit_file, indent=4)

In [None]:
# Number of records written to ner_labels.json by the previous cell.
len(ner_labels)