# Import

In [20]:
import pandas as pd
import numpy as np
import nltk
import torch
from transformers import pipeline
from nltk import sent_tokenize
from glob import glob

# Load the model

In [21]:
model_name = "facebook/bart-large-mnli"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

In [None]:
def load_model(device="cpu", model_name=model_name):
    theme_classifier = pipeline(
        task="zero-shot-classification",
        model=model_name,
        device=device
    )
    return theme_classifier

In [9]:
theme_classifier = load_model(device=device)

In [12]:
theme_list = ["friendship","hope","sacrifice","battle","self development","betrayal","love","dialogue", "pain", "hatred", "dream", "hard work", "war"]

In [11]:
theme_classifier(
    "I gave him a right hook then a left jab",
    theme_list,
    multi_label=True
)

{'sequence': 'I gave him a right hook then a left jab',
 'labels': ['battle',
  'hard word',
  'self development',
  'pain',
  'war',
  'hope',
  'hatred',
  'dream',
  'sacrifice',
  'dialogue',
  'betrayal',
  'love',
  'friendship'],
 'scores': [0.9121251702308655,
  0.5533947348594666,
  0.4749987721443176,
  0.4389338493347168,
  0.35820794105529785,
  0.08781738579273224,
  0.0643683597445488,
  0.04519854485988617,
  0.04500005766749382,
  0.020132755860686302,
  0.012040241621434689,
  0.0042922901920974255,
  0.0028172011952847242]}

# Working with our dataset

In [23]:
files = glob("../data/subtitles/*.ass")
files[:5]

['../data/subtitles/Naruto Season 4 - 94.ass',
 '../data/subtitles/Naruto Season 4 - 80.ass',
 '../data/subtitles/Naruto Season 2 - 32.ass',
 '../data/subtitles/Naruto Season 8 - 185.ass',
 '../data/subtitles/Naruto Season 8 - 191.ass']

In [43]:
with open(files[0], "r") as file:
    lines = file.readlines()
    lines = lines[27:] # read only the lines after the metadata
    lines = [line.split(",")[9:] for line in lines] # remove all data with that is not part of the text column
    # some sentences may have , and we don't want a 2d list 
    lines = ["".join(line) for line in lines] # we will combine this in the main file

In [44]:
lines[:10]

['We are Fighting Dreamers aiming high\n',
 "Fighting Dreamers\\Ndon't care what people think about them\n",
 'Fighting Dreamers\\Nfollow what they believe\n',
 'Oli Oli Oli Oh! Just go my way\n',
 'Right here right now (Bang)\\NHit it straight like a line drive!\n',
 'Right here right now (Burn)\n',
 'Down a difficult road\\Nfilled with endless struggles\n',
 "Where do you think you are going\\Nfollowing someone else's map?\n",
 'An insightful crow comes along\\Nto tear up the map\n',
 'Now open your eyes and\\Ntake a look at the truth (Yeah!)\n']

*Note:* looks good however we have \\N character showing up

In [46]:
lines = [line.replace('\\N', '') for line in lines]

In [49]:
print(" ".join(lines[:10]))

We are Fighting Dreamers aiming high
 Fighting Dreamersdon't care what people think about them
 Fighting Dreamersfollow what they believe
 Oli Oli Oli Oh! Just go my way
 Right here right now (Bang)Hit it straight like a line drive!
 Right here right now (Burn)
 Down a difficult roadfilled with endless struggles
 Where do you think you are goingfollowing someone else's map?
 An insightful crow comes alongto tear up the map
 Now open your eyes andtake a look at the truth (Yeah!)



*Good*: now let us work with anotating the episode in question

In [64]:
files[2]

'../data/subtitles/Naruto Season 2 - 32.ass'

In [65]:
int(files[2].split('-')[-1].split('.')[0].strip())

32

In [67]:
def load_subtitles_dataset(dataset_path="../data/subtitles/"):
    subtitles_data = glob(dataset_path+"/*.ass")
    
    scripts = []
    episode_nums = []
    
    for path in subtitles_data:
        
        # read file
        with open(path, "r") as file:
            lines = file.readlines()
            
        lines = lines[27:]
        lines = [line.split(",")[9:] for line in lines]
        lines = ["".join(line) for line in lines]
        lines = [line.replace('\\N', '') for line in lines]
        script = " ".join(lines)
        scripts.append(script)
        
        # getting the episode in question
        episode_num = int(path.split('-')[-1].split('.')[0].strip())
        episode_nums.append(episode_num)
    
    df = pd.DataFrame.from_dict({"episode": episode_nums, "script": scripts})
    return df

In [68]:
df = load_subtitles_dataset()

In [70]:
df.head()

Unnamed: 0,episode,script
0,94,We are Fighting Dreamers aiming high\n Fightin...
1,80,We are Fighting Dreamers aiming high\n Fightin...
2,32,Press down hard on the gas\n That’s right ther...
3,185,Rock away your existence\n Shouting that you a...
4,191,Rock away your existence\n Shouting that you a...


## Test the model

In [71]:
script = df.iloc[0]["script"]

In [None]:
script_sentences = sent_tokenize(script)
script_sentences[:10]


["We are Fighting Dreamers aiming high\n Fighting Dreamersdon't care what people think about them\n Fighting Dreamersfollow what they believe\n Oli Oli Oli Oh!",
 'Just go my way\n Right here right now (Bang)Hit it straight like a line drive!',
 "Right here right now (Burn)\n Down a difficult roadfilled with endless struggles\n Where do you think you are goingfollowing someone else's map?",
 'An insightful crow comes alongto tear up the map\n Now open your eyes andtake a look at the truth (Yeah!)',
 "There's nothing to loseso let's GO!!!",
 "We are Fighting Dreamers aiming high\n Fighting Dreamersdon't care what people think about them\n Fighting Dreamersfollow what they believe\n Oli Oli Oli Oh!Just go my way\n Right here right now (Bang)Hit it straight like a line drive!",
 "Right here right now (Burn)We're gonna do it and do our best!",
 'Right here right now (Bang)Hit it straight like a line drive!',
 "Right here right now (Burn)We're gonna do it and do our best!",
 'BANG!']

In [76]:
print(len(script_sentences))

199


In [77]:
# create batces of sentences
sentence_batch = 32
script_batches = []
for index in range(0, len(script_sentences), sentence_batch):
    sent = ""
    if index+sentence_batch <= len(script_sentences): 
        sent = " ".join(script_sentences[index:index+sentence_batch])
    else:
        sent = " ".join(script_sentences[index:-1])
    script_batches.append(sent)


In [81]:
def load_model(device="cpu", model_name=model_name):
    theme_classifier = pipeline(
        task="zero-shot-classification",
        model=model_name,
        device=device
    )
    return theme_classifier

In [82]:
model_name = "facebook/bart-large-mnli"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
theme_list = ["friendship","hope","sacrifice","battle","self development","betrayal","love","dialogue", "pain", "hatred", "dream", "hard work", "war"]

In [83]:
theme_classifier = load_model(device=device)
theme_output = theme_classifier(
    script_batches[:2],
    theme_list,
    multi_label=True
)



In [None]:
theme_output