In [1]:
import pandas as pd
import wikipedia
from bs4 import BeautifulSoup as soup
import re
import pickle
import time

In [2]:
# Cast List code

def get_cast_list(page):
    """Extract candidate character names from the page's "Cast" section.

    Two layouts are handled:

    * The whole section on one line: it is split on ":" and, for every piece
      containing " as ", the text after the last " as " is kept.
    * One entry per line: a first line whose text after its last ":" is empty
      (e.g. a header ending in ":") is skipped. Each remaining line is cut at
      its first ":" (or at its first "," when it has no ":"); if that leading
      piece contains " as ", the text between the first and second " as " is
      kept, otherwise the piece itself is kept unless it is empty.

    The raw names are then expanded with cast_list_all_combinations and
    filtered with clean_up_cast_list.

    Args:
        page: object with a ``section(title)`` method returning the section
            text or None (a ``wikipedia.WikipediaPage`` in this notebook).

    Returns:
        [] (a list) when ``page.section("Cast")`` is None, otherwise whatever
        clean_up_cast_list returns (a set of names). Note the mixed types.
    """
    cast_text = page.section("Cast")
    if cast_text is None:
        return []

    lines = cast_text.split("\n")
    raw_names = []
    if len(lines) == 1:
        # Single-line section: every ":"-separated piece may hold "Actor as Character".
        raw_names.extend(
            piece.split(" as ")[-1]
            for piece in lines[0].split(":")
            if " as " in piece
        )
    else:
        # rpartition(":")[2] is the text after the last ":" (or the whole line
        # when there is no ":"), so an empty value marks a label-style header.
        header_is_label = lines[0].rpartition(":")[2] == ""
        for line in lines[1 if header_is_label else 0:]:
            fields = line.split(":")
            if len(fields) == 1:
                fields = line.split(",")
            entry = fields[0]
            if " as " in entry:
                raw_names.append(entry.split(" as ")[1])
            elif entry:
                raw_names.append(entry)

    expanded = cast_list_all_combinations(raw_names)
    return clean_up_cast_list(expanded)

def cast_list_all_combinations(cast_list):
    """Expand each cast name into the name variants used for matching.

    Names are split on single spaces. For a two-word name each word is added;
    for names of three or more words the quoted nickname (if the name
    contains a double quote), the first word, the last word and
    "first last" are added. The full name itself is always added last.

    Args:
        cast_list: iterable of name strings.

    Returns:
        list of variants in input order; duplicates are kept.
    """
    variants = []
    for full_name in cast_list:
        words = full_name.split(" ")
        word_count = len(words)
        if word_count == 2:
            variants.extend(words)
        elif word_count >= 3:
            first, last = words[0], words[-1]
            if '"' in full_name:
                # Text after the first double quote, up to the next one.
                variants.append(full_name.split('"')[1])
            variants.extend([first, last, first + " " + last])
        variants.append(full_name)
    return variants

def clean_up_cast_list(cast_list):
    """Filter candidate names down to plausible ones and de-duplicate them.

    A candidate survives when it is 2-49 characters long, starts with an
    uppercase character, is not exactly "The", does not contain "also"
    (case-sensitive), and contains none of "/", "(" or ")".

    Args:
        cast_list: iterable of candidate name strings.

    Returns:
        set of the surviving names.
    """
    banned_chars = ("/", "(", ")")
    return {
        candidate
        for candidate in cast_list
        # Length check first so candidate[0] is safe.
        if 1 < len(candidate) < 50
        and candidate[0].isupper()
        and candidate != "The"
        and "also" not in candidate
        and not any(ch in candidate for ch in banned_chars)
    }

In [3]:
# Code for tagging sentences

def tag_all_sentences(sentences,cast_list,debug=False):
    """Tag cast names in each sentence, keeping only sentences with a tag.

    Args:
        sentences: iterable of sentence strings.
        cast_list: names to look for (passed to tag_cast_in_sentence).
        debug: forwarded to tag_cast_in_sentence.

    Returns:
        list of ``(sentence, tags)`` tuples for sentences where
        tag_cast_in_sentence found at least one tag.
    """
    result = []
    for text in sentences:
        tags = tag_cast_in_sentence(text,cast_list,debug)
        # NOTE: "\'s" and "'s" are the same Python string, so this replace is
        # a no-op; kept as-is pending clarification of its intent.
        text = text.replace("\'s","'s")
        if tags:
            result.append((text, tags))
    return result

def tag_cast_in_sentence(sentence,cast_list,debug=False):
    """Tag every occurrence of a cast name in ``sentence`` as a PERSON span.

    A name only counts when it is immediately followed by one of the
    delimiters space, apostrophe, period or comma. A space is appended to the
    sentence so a name at the very end still matches.

    Args:
        sentence: the sentence text to search.
        cast_list: iterable of name strings (e.g. the set from get_cast_list).
        debug: forwarded to the clean-up helpers for diagnostic printing.

    Returns:
        list of ``(start, end, "PERSON")`` tuples after clean_up_tag_list and
        combine_connected_tags. ``start`` is the index of the name's first
        character and ``end`` is the index of the delimiter that follows it,
        i.e. one past the name's last character.
    """
    tag_list = []
    sentence = sentence + " "
    for name in cast_list:
        for delimiter in (" ", "'", ".", ","):
            candidate = name + delimiter
            # Names are literal text: escape them so "." does not match any
            # character and names containing "*", "?", "[" etc. neither change
            # the pattern nor raise re.error.
            for match in re.finditer(re.escape(candidate), sentence):
                start = match.start()
                tag_list.append((start, start + len(candidate) - 1, "PERSON"))

    final_tag_list = clean_up_tag_list(tag_list,debug)
    final_tag_list = combine_connected_tags(final_tag_list,debug)
    return final_tag_list

def clean_up_tag_list(tag_list,debug=False):
    """Resolve conflicts among raw name-match tags.

    Steps:
      1. de-duplicate the tags;
      2. repeatedly drop the tag chosen by find_tag_to_remove (from pairs
         sharing a start or end index) until it returns "";
      3. merge adjacent tags with combine_connected_tags;
      4. drop overlapping tags with remove_overlapping_tags.

    Args:
        tag_list: iterable of ``(start, end, label)`` tuples.
        debug: when True, print the input and each removed tag.

    Returns:
        the list produced by remove_overlapping_tags.
    """
    if debug:
        print("Input: {}".format(tag_list))
    remaining = list(set(tag_list))
    while True:
        victim = find_tag_to_remove(remaining,debug)
        if victim == "":
            break
        if debug:
            print("Removing {}".format(victim))
        remaining.remove(victim)
    combined = combine_connected_tags(remaining,debug)
    return remove_overlapping_tags(combined,debug)
        
def find_tag_to_remove(tag_list,debug=False):
    """Pick one tag to discard from the first pair sharing a start or an end.

    Tags are scanned pairwise (skipping a tag compared with itself by
    identity). For the first pair with an equal start index or an equal end
    index, the shorter tag is returned; on equal length the second tag of
    the pair is returned.

    Args:
        tag_list: list of ``(start, end, label)`` tuples.
        debug: when True, print the start/end of the matching pair.

    Returns:
        the tag to remove, or "" when no such pair exists.
    """
    for current in tag_list:
        cur_start, cur_end = current[0], current[1]
        cur_len = cur_end - cur_start
        for other in tag_list:
            if other is current:
                continue
            oth_start, oth_end = other[0], other[1]
            if oth_start != cur_start and oth_end != cur_end:
                continue
            if debug:
                print("A: {} - {}, B: {} - {}".format(cur_start,cur_end,oth_start,oth_end))
            oth_len = oth_end - oth_start
            return other if cur_len >= oth_len else current
    return ""

def combine_connected_tags(tag_list,debug=False):
    """Merge tags whose spans touch end-to-start into single PERSON tags.

    Tags ``a`` and ``b`` are connected when ``a[1] + 1 == b[0]``. Because a
    tag's end index points at the delimiter after the name, this joins
    consecutive name tokens ("Andy " + "Dufresne "). Merging repeats until no
    connected pair remains, so chains of three or more tokens collapse into
    one tag regardless of their order in the input. This replaces an earlier
    version that removed items from the list it was iterating. That version
    could raise ValueError when a tag had neighbours on both sides.

    Args:
        tag_list: iterable of ``(start, end, label)`` tuples; not modified.
        debug: when True, print each merge.

    Returns:
        a new list of tags in which every connected pair has been merged;
        merged tags carry the label "PERSON".
    """
    tags = list(tag_list)
    while True:
        pair = _find_connected_pair(tags)
        if pair is None:
            return tags
        left, right = pair
        tags.remove(left)
        tags.remove(right)
        new_tag = (left[0], right[1], "PERSON")
        tags.append(new_tag)
        if debug:
            print("Combined {} and {} into {}".format(left, right, new_tag))


def _find_connected_pair(tags):
    """Return the first ``(left, right)`` with ``left[1] + 1 == right[0]``, else None."""
    for i, left in enumerate(tags):
        for j, right in enumerate(tags):
            if i != j and left[1] + 1 == right[0]:
                return left, right
    return None

def find_overlapping_tag(tag_list,debug=False):
    """Pick one tag to discard from the first overlapping pair.

    Two inclusive spans overlap when either one's end index lies inside the
    other span. For the first overlapping pair found, the shorter tag is
    returned; on equal length the second tag of the pair is returned.

    Args:
        tag_list: list of ``(start, end, label)`` tuples.
        debug: accepted for interface symmetry; not used.

    Returns:
        the tag to remove, or "" when no two tags overlap.
    """
    for current in tag_list:
        lo, hi = current[0], current[1]
        for other in tag_list:
            if other is current:
                continue
            other_lo, other_hi = other[0], other[1]
            overlaps = lo <= other_hi <= hi or other_lo <= hi <= other_hi
            if overlaps:
                return other if hi - lo >= other_hi - other_lo else current
    return ""

def remove_overlapping_tags(tag_list,debug=False):
    """De-duplicate tags, then drop tags until no two of them overlap.

    The tag to drop at each step is chosen by find_overlapping_tag.

    Args:
        tag_list: iterable of ``(start, end, label)`` tuples.
        debug: when True, print each removed tag.

    Returns:
        list of the remaining, mutually non-overlapping tags.
    """
    survivors = list(set(tag_list))
    while True:
        loser = find_overlapping_tag(survivors,debug)
        if loser == "":
            return survivors
        if debug:
            print("Removing {}".format(loser))
        survivors.remove(loser)

In [4]:
# Code for getting plot sentences

# A page without a "Plot" section (page.section returns None) yields no
# sentences instead of failing with "'NoneType' object has no attribute 'replace'".

def get_plot_sentences(page,debug=False):
    """Split the page's "Plot" section into rough sentences.

    Newlines are replaced by spaces and the text is split on ". ". This is a
    heuristic: abbreviations such as "Dr. " also cause a split.

    Args:
        page: object with a ``section(title)`` method returning text or None
            (a ``wikipedia.WikipediaPage`` in this notebook).
        debug: accepted for interface symmetry; not used.

    Returns:
        list of sentence strings; [] when the section is missing or empty.
    """
    plot_text = page.section("Plot")
    if not plot_text:
        return []
    plot_text_split = plot_text.replace("\n"," ")
    # NOTE(review): "\'s" and "'s" are the same Python string, so this replace
    # is a no-op. If the intent was to normalise typographic apostrophes, it
    # should replace "\u2019s" instead — confirm.
    plot_text_final = plot_text_split.replace("\'s","'s")
    plot_sentences = plot_text_final.split(". ")
    return plot_sentences

In [5]:
# Code for collecting training data from pages

def get_training_data_from_all_pages(list_of_titles,debug=False):
    """Fetch each title from Wikipedia and pool the tagged sentences.

    Disambiguation errors and missing pages during the lookup are reported
    and the title is skipped. Any exception raised while processing a fetched
    page is also reported and that title skipped. A page that yields no
    tagged sentences (no cast, or no cast mention in the plot) prints
    "No cast found".

    Args:
        list_of_titles: iterable of page titles.
        debug: forwarded to get_training_data_from_page.

    Returns:
        list of ``(sentence, tags)`` pairs concatenated across all pages.
    """
    pooled = []
    for title in list_of_titles:
        try:
            page = wikipedia.WikipediaPage(title=title)
        except wikipedia.exceptions.DisambiguationError:
            print("Disambiguous Title Name, Skipping {}".format(title))
            continue
        except wikipedia.exceptions.PageError:
            print("No page found with title {}".format(title))
            continue

        # Best effort: a failure while parsing one page only skips that page.
        try:
            examples = get_training_data_from_page(page,debug)
        except Exception as exc:
            print("Error for Title: {}: {}".format(title,exc))
            continue

        if examples == []:
            print("No cast found from {}".format(title))
            continue
        pooled.extend(examples)
        print("Added {} new training examples successfully from {}".format(len(examples),title))
    return pooled

def get_training_data_from_page(page,debug=False):
    """Build tagged training sentences for a single page.

    Args:
        page: object passed to get_cast_list and get_plot_sentences
            (a ``wikipedia.WikipediaPage`` in this notebook).
        debug: when True, print the sorted cast list; also forwarded.

    Returns:
        [] when no cast names were found, otherwise the list of
        ``(sentence, tags)`` pairs from tag_all_sentences.
    """
    cast_list = get_cast_list(page)
    # get_cast_list returns either [] or a (possibly empty) set. An empty set
    # never compares equal to [], so test truthiness instead.
    if not cast_list:
        return []
    if debug:
        print("Cast List:{}".format(sorted(cast_list)))
    plot_sentences = get_plot_sentences(page,debug)
    return tag_all_sentences(plot_sentences,cast_list,debug)

In [6]:
# Top function

def create_training_test_sets(csv_file,number_of_titles,training_file,test_file,debug=False,train_fraction=0.8):
    """Scrape tagged sentences for the first titles of a CSV and pickle a train/test split.

    Titles come from the ``primaryTitle`` column of ``csv_file``. The first
    ``int(number_of_titles * train_fraction)`` titles form the training set.
    The remaining titles up to ``number_of_titles`` form the test set.

    Args:
        csv_file: path of the CSV to read with ``pd.read_csv``.
        number_of_titles: how many titles (from the top of the CSV) to use.
        training_file: output path for the pickled training data.
        test_file: output path for the pickled test data.
        debug: forwarded to get_training_data_from_all_pages.
        train_fraction: share of titles used for training (default 0.8,
            the original 80:20 split).
    """
    t1 = time.time()
    titles_df = pd.read_csv(csv_file)
    all_titles = titles_df["primaryTitle"].tolist()

    train_limit = int(number_of_titles*train_fraction)
    list_of_titles_train = all_titles[:train_limit]
    training_data = get_training_data_from_all_pages(list_of_titles_train,debug)
    print("Training Data Collected")
    # Use a context manager so the file is flushed and closed deterministically.
    with open(training_file, 'wb') as train_fh:
        pickle.dump(training_data, train_fh)
    print("Training Data Stored")

    list_of_titles_test = all_titles[train_limit:number_of_titles]
    test_data = get_training_data_from_all_pages(list_of_titles_test,debug)
    print("Test Data Collected")
    with open(test_file, 'wb') as test_fh:
        pickle.dump(test_data, test_fh)
    print("Test Data Stored")

    elapsed = time.time() - t1
    # Guard against number_of_titles == 0 (previously a ZeroDivisionError).
    per_page = elapsed / number_of_titles if number_of_titles else 0.0
    print("Took {} seconds = {} seconds per page".format(round(elapsed,2),round(per_page,2)))

In [7]:
# Scrape the first 100 titles of the IMDb CSV (80:20 train/test split by
# default) and pickle the tagged sentences.
create_training_test_sets(
    csv_file="data/imdb_popular_content.csv",
    number_of_titles=100,
    training_file="data/train_100_3",
    test_file="data/test_100_3",
)

Added 32 new training examples successfully from The Shawshank Redemption




  lis = BeautifulSoup(html).find_all('li')


Disambiguous Title Name, Skipping The Dark Knight
Added 30 new training examples successfully from Inception
Added 29 new training examples successfully from Fight Club
No cast found from Pulp Fiction
Added 29 new training examples successfully from Forrest Gump
No cast found from Game of Thrones
Added 28 new training examples successfully from The Matrix
Added 31 new training examples successfully from The Lord of the Rings: The Fellowship of the Ring
Added 42 new training examples successfully from The Lord of the Rings: The Return of the King
Added 1 new training examples successfully from The Godfather
Added 33 new training examples successfully from The Dark Knight Rises
Disambiguous Title Name, Skipping Interstellar
Added 34 new training examples successfully from The Lord of the Rings: The Two Towers
Disambiguous Title Name, Skipping Se7en
No cast found from Breaking Bad
Added 4 new training examples successfully from Django Unchained
No cast found from Gladiator
Added 28 new tr