In [2]:
import pickle as pkl
import spacy
import csv
import json
from copy import deepcopy
from tqdm import tqdm

In [3]:
# Load Parsed Corpus
sm_parser = spacy.load('en_core_web_sm')

with open('tbbt_en_zh.pkl', 'rb') as f_zh:
    with open('tbbt_en_fa.pkl', 'rb') as f_fa:
        zh = pkl.load(f_zh)
        fa = pkl.load(f_fa)
        inter_keys = set(zh.keys()) & set(fa.keys())

data = {}
with open('parsed_corpus.pkl', 'rb') as f:
    parsed = pkl.load(f)
    for item in inter_keys:
        data[item] = parsed[item]

In [4]:
# Regular Candidate Spans

output = []
for epi_key in data:
    if epi_key != (1,1):
        continue
    episode = data[epi_key]
    # Each scene contain on episode
    for scene in tqdm(episode):
        # Collect data to annotate
        all_sentences = []
        all_query_spans = []
        all_candidate_spans = []

        for i, utt in enumerate(scene):
            if "en_subtitles" in utt:
                utterance = " ".join([x.strip().lstrip('-').lstrip().lstrip('.').lstrip() for x in utt['en_subtitles']])
                utterance_tokens = [item.text for item in sm_parser(utterance)]
                speaker = utt['speaker']
                speaker_tokens = [item.text for item in sm_parser(speaker)]
                sentence_tokens = speaker_tokens + [":"] + utterance_tokens

                all_sentences.append(sentence_tokens)

                spans = list(set(utt['sm_noun_chunk']) | set(utt['berkeley_noun_chunk']) | set(utt['trf_noun_chunk']))
                spans.sort(key=lambda x: x[1])

                for span in spans:
                    all_candidate_spans.append({
                        "sentenceIndex": i,
                        "startToken": span[1] + len(speaker_tokens) + 1,
                        "endToken": span[2] + len(speaker_tokens) + 1
                    })
                    all_query_spans.append({
                        "sentenceIndex": i,
                        "startToken": span[1] + len(speaker_tokens) + 1,
                        "endToken": span[2] + len(speaker_tokens) + 1
                    })
            else:
                utterance = utt['utterance']
                utterance_tokens = [item.text for item in sm_parser(utterance)]
                speaker = utt['speaker']
                speaker_tokens = [item.text for item in sm_parser(speaker)]
                sentence_tokens = speaker_tokens + [":"] + utterance_tokens

                all_sentences.append(sentence_tokens)
                all_candidate_spans.append({
                        "sentenceIndex": i,
                        "startToken": 0,
                        "endToken": len(speaker) + 1
                })
        output.append({
            "sentences": all_sentences,
            "querySpans": all_query_spans,
            "candidateSpans": all_candidate_spans,
            "clickSpans": all_query_spans,
        })

100%|██████████| 11/11 [00:01<00:00,  5.96it/s]


In [5]:
def get_all_possible_spans(sentIdx, sentLen, window_size):
    all_possible_spans = []
    for i in range(sentLen-window_size):
        all_possible_spans.append({
            "sentenceIndex": sentIdx,
            "startToken": i,
            "endToken": i+window_size
        })
    return all_possible_spans

In [78]:
# All Spans
# Use Sliding Window to gather all potential spans

output = []
for epi_key in data:
    if epi_key != (1,1):
        continue
    episode = data[epi_key]
    # Each scene contain on episode
    for scene in episode:
        # Collect data to annotate
        all_sentences = []
        all_query_spans = []
        all_candidate_spans = []

        for i, utt in enumerate(scene):
            if "en_subtitles" in utt:
                # Fetch parse Noun Phrases from former parsing result
                utterance = " ".join([x.strip().lstrip('-').lstrip().lstrip('.').lstrip() for x in utt['en_subtitles']])
                utterance_tokens = [item.text for item in sm_parser(utterance)]
                speaker = utt['speaker']
                speaker_tokens = [item.text for item in sm_parser(speaker)]
                sentence_tokens = speaker_tokens + [":"] + utterance_tokens
                all_sentences.append(sentence_tokens)
                spans = list(set(utt['sm_noun_chunk']) | set(utt['berkeley_noun_chunk']) | set(utt['trf_noun_chunk']))
                spans.sort(key=lambda x: x[1])

                # Split NPs with Poesstive Pronoun into two parts
                all_new_spans = []
                for j, token in enumerate(sm_parser(utterance)):
                    if token.tag_=="PRP$":
                        for k, (word, start_idx, end_idx) in enumerate(spans):
                            if start_idx <= j < end_idx:
                                spans.pop(k)
                                new_span_1 = (token.text, j, j+1)
                                new_span_2 = (" ".join(utterance_tokens[j+1: end_idx]), j+1, end_idx)
                                all_new_spans.extend([new_span_1, new_span_2])
                    if token.tag_=="NNPS":
                        print(token)

                for item in all_new_spans:
                    spans.append(item)
                spans.sort(key=lambda x: x[1])

                # Merge overlapping spans into one maximum logical span
                to_pop = []
                for j, (word_0, start_idx_0, end_idx_0) in enumerate(spans):
                    for k, (word_1, start_idx_1, end_idx_1) in enumerate(spans):
                        if k==j:
                            continue
                        if (start_idx_1 >= start_idx_0) and (end_idx_1 <= end_idx_0):
                            to_pop.append(spans[k])
                for item in to_pop:
                    spans.remove(item)

                for span in spans:
                    all_query_spans.append({
                        "sentenceIndex": i,
                        "startToken": span[1] + len(speaker_tokens) + 1,
                        "endToken": span[2] + len(speaker_tokens) + 1
                    })
                # Gather all possible candidate spans
                temp = []
                for window_size in range(10):
                    temp += get_all_possible_spans(i, len(sentence_tokens), window_size)
                all_candidate_spans.extend(temp)
            else:
                utterance = utt['utterance']
                utterance_tokens = [item.text for item in sm_parser(utterance)]
                speaker = utt['speaker']
                speaker_tokens = [item.text for item in sm_parser(speaker)]
                sentence_tokens = speaker_tokens + [":"] + utterance_tokens

                all_sentences.append(sentence_tokens)

                # Gather all possible candidate spans
                temp = []
                for window_size in range(10):
                    temp += get_all_possible_spans(i, len(sentence_tokens), window_size)
                all_candidate_spans.extend(temp)

        print(len(all_candidate_spans))
        print(len(all_query_spans))
        # print(all_candidate_spans)
        print("=="*50)
        output.append({
            "sentences": all_sentences,
            "querySpans": all_query_spans,
            "candidateSpans": all_candidate_spans,
            "clickSpans": all_query_spans,
        })

3485
106
4484
123
12527
379
821
23
2506
86
388
14
1378
47
417
17
168
6
1270
41
Buttons
1934
58


In [74]:
with open('win_15_no_overlap_mar_20.csv', "w", encoding="utf-8") as csv_fh:
        fieldnames = ['json_data']
        writer = csv.DictWriter(csv_fh, fieldnames, lineterminator='\n')
        writer.writeheader()
        for line in output:
            writer.writerow({'json_data': json.dumps(line)})