In [None]:
import pandas as pd
import pickle
from typing import Set, List

In [None]:
class TranscriptObject():
    """One utterance in a transcript: its text, who said it, and the keywords it contains.

    NOTE(review): instances appear to be restored with ``pickle.load`` rather than
    built via ``__init__``, so keep the class name and attribute names stable or
    existing pickle files will stop loading.
    """

    keywords: Set[str]  # keywords detected in ``text`` (mutated by the tagging cell)
    text: str           # the utterance text
    speaker: str        # speaker label, e.g. "Joe Biden"

    def __init__(self, keywords, text, speaker) -> None:
        self.keywords, self.text, self.speaker = keywords, text, speaker

    def __str__(self) -> str:
        # One "label: value" line per field, in a fixed order.
        field_pairs = (
            ("keywords", self.keywords),
            ("speaker", self.speaker),
            ("text", self.text),
        )
        return "".join(f"{label}: {value}\n" for label, value in field_pairs)

    def __repr__(self) -> str:
        # Same field dump as __str__, prefixed with a class header line.
        return "TranscriptObject()\n" + self.__str__()

class TranscriptPage():
    """A transcript page: its title, keyword and speaker sets, and its utterances.

    NOTE(review): pages appear to be restored with ``pickle.load`` rather than
    built via ``__init__``, so keep the class name and attribute names stable or
    existing pickle files will stop loading.
    """

    title: str                                  # page/speech title
    keywords: Set[str]                          # keywords associated with the page
    speakers: Set[str]                          # speaker labels for the page
    transcript_objects: List[TranscriptObject]  # the utterances on this page

    def __init__(self, title, keywords, speakers, transcript_objects) -> None:
        self.title = title
        self.keywords = keywords
        self.speakers = speakers
        self.transcript_objects = transcript_objects

    def __str__(self) -> str:
        # One "label: value" line per field, in a fixed order; the utterance
        # list is rendered via its own repr.
        field_pairs = (
            ("title", self.title),
            ("keywords", self.keywords),
            ("speakers", self.speakers),
            ("transcript_objects", self.transcript_objects),
        )
        return "".join(f"{label}: {value}\n" for label, value in field_pairs)

    def __repr__(self) -> str:
        # Same field dump as __str__, prefixed with a class header line.
        return "TranscriptPage()\n" + self.__str__()


In [None]:
# Load the pre-built Trump transcript pages (annotated as a list of TranscriptPage).
# NOTE(review): pickle.load can execute arbitrary code; only load pickles you produced.
trump_pickle_path = '../data/trump_transcripts.pkl'
with open(trump_pickle_path, 'rb') as pkl_file:
    trump_data: List[TranscriptPage] = pickle.load(pkl_file)

In [None]:
# Load the pre-built Biden transcript pages (annotated as a list of TranscriptPage).
# NOTE(review): pickle.load can execute arbitrary code; only load pickles you produced.
biden_pickle_path = "../data/biden_transcripts.pkl"
with open(biden_pickle_path, "rb") as pkl_file:
    biden_data: List[TranscriptPage] = pickle.load(pkl_file)

In [None]:
# Tag every Biden utterance (and its page) with each known keyword it contains.
# The vocabulary is taken from the Trump transcript pages. It is built here, rather
# than relying on a later cell, so the notebook survives Restart & Run All.
keywords = {kw for page in trump_data for kw in page.keywords}

for transcript_page in biden_data:
    for transcript_object in transcript_page.transcript_objects:
        # split() with no argument also breaks on newlines/tabs and drops empty
        # strings, so it finds every match split(" ") would, plus more.
        # NOTE(review): matching is exact and case-sensitive, and punctuation stuck
        # to a word (e.g. "taxes,") prevents a match; consider normalizing tokens.
        found = keywords.intersection(transcript_object.text.split())
        transcript_page.keywords.update(found)
        transcript_object.keywords.update(found)

In [None]:
# Preview a few pages only; displaying the whole corpus floods the notebook output.
biden_data[:3]

In [None]:
# Vocabulary: every keyword attached to any Trump transcript page.
keywords = set().union(*(page.keywords for page in trump_data))

In [None]:
# Sanity check: with an empty vocabulary no utterance can be tagged, and no
# conversations would be generated downstream.
assert keywords, "keyword vocabulary is empty; check trump_data"
len(keywords)

In [None]:
# Question prefixes for synthetic prompts; one is picked at random and the
# keyword is appended, e.g. "what is your opinion on <keyword>".
ways_to_ask_questions = [
    "how do you feel about",
    "what is your opinion on",
    # "viewpoint on" (not "for") so the generated prompts are grammatical.
    "what is your viewpoint on",
]

In [None]:
# Every speaker label in the Biden transcripts. Used to confirm the exact label
# that the conversation builder filters on.
speakers = {obj.speaker for page in biden_data for obj in page.transcript_objects}

# Fail loudly: a mismatched label would silently yield zero conversations below.
assert "Joe Biden" in speakers, f"'Joe Biden' not found among {len(speakers)} speaker labels"
len(speakers)

In [None]:
import random

# Seed so the question prefixes chosen below are reproducible across runs.
random.seed(42)

# Training records: one per (statement, keyword) pair.
conversations = []

# Parallel lists for semantic search: convo_texts[i] is a statement and
# convo_speeches[i] is the title of the page it came from. Each statement is
# stored once, even when it matches several keywords. Otherwise it would be
# embedded repeatedly and show up several times in the top-k results.
convo_texts = []
convo_speeches = []

# Target record format:
#   {
#     "Context": "Please remind me of calling to Jessie at 2PM.",
#     "Knowledge": "reminder_contact_name is Jessie, reminder_time is 2PM",
#     "Response": "Sure, set the reminder: call to Jesse at 2PM"
#   },

for page in biden_data:
    for obj in page.transcript_objects:
        # Keep only Biden's own statements that mention at least one keyword.
        if not obj.keywords or obj.speaker != "Joe Biden":
            continue
        # sorted() makes the order independent of set iteration order, which
        # varies between interpreter runs because of string hash randomization.
        for kw in sorted(obj.keywords):
            question_start = random.choice(ways_to_ask_questions)
            conversations.append({
                "Context": f"{question_start} {kw}",
                "Knowledge": "",
                "Response": obj.text,
            })
        convo_texts.append(obj.text)
        convo_speeches.append(page.title)

In [None]:
# Sanity check before splitting: an empty list means tagging or the speaker
# filter matched nothing.
assert conversations, "no conversations generated; check keyword tagging and the speaker label"
len(conversations)

In [None]:
# Preview a handful of records rather than dumping the full list.
conversations[:5]

In [None]:
import random

# Shuffle with a dedicated, seeded RNG so the train/val/test split is reproducible.
# NOTE: this shuffles in place, so re-running this cell on its own reshuffles the list.
random.Random(42).shuffle(conversations)

In [None]:
# Split sizes. The list was shuffled above, so contiguous, non-overlapping
# slices give random, disjoint splits.
TRAIN_SIZE = 50_000
VAL_SIZE = 10_000
TEST_SIZE = 10_000

# Fail loudly rather than silently producing empty or short splits. Hard-coded
# offsets past the end of the list would just return empty slices.
assert len(conversations) >= TRAIN_SIZE + VAL_SIZE + TEST_SIZE, (
    f"only {len(conversations)} conversations; need {TRAIN_SIZE + VAL_SIZE + TEST_SIZE}"
)

train_end = TRAIN_SIZE
val_end = train_end + VAL_SIZE
test_end = val_end + TEST_SIZE

train_convos = conversations[:train_end]
val_convos = conversations[train_end:val_end]
test_convos = conversations[val_end:test_end]

# NOTE(review): a statement yields one record per matched keyword, so the same
# "Response" can land in more than one split (train/test leakage). Consider
# splitting by unique Response before expanding into records.

In [None]:
import json
import jsonlines

# Write each split as JSON Lines (one conversation dict per line).
# NOTE(review): the content is JSON Lines despite the ".json" extension; the paths
# are kept unchanged because downstream consumers presumably expect them.
split_outputs = {
    "../data/biden_convos_train.json": train_convos,
    "../data/biden_convos_val.json": val_convos,
    "../data/biden_convos_test.json": test_convos,
}

for out_path, records in split_outputs.items():
    with jsonlines.open(out_path, mode="w") as writer:
        writer.write_all(records)

In [None]:
from sentence_transformers import SentenceTransformer, util
import torch

# Sentence-embedding model used for semantic search over Biden statements.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Embed at most this many statements. Taking a prefix keeps embedding row i
# aligned with convo_texts[i] and convo_speeches[i].
max_statements = 100000
statements_to_embed = convo_texts[:max_statements]
embeddings = model.encode(statements_to_embed, convert_to_tensor=True)

In [None]:
# Embed the probe question as a one-item batch with the same model, then score
# it against every statement embedding. Row i of the result corresponds to
# statement i.
query = model.encode(
    ["what is your view on abortion?"],
    convert_to_tensor=True,
)
cosine_scores = util.cos_sim(embeddings, query)

In [None]:
# Print the statements most similar to the query, best match first.
# Clamp k so topk does not fail when fewer than 5 statements were embedded.
top_k = min(5, cosine_scores.numel())
top_scores, indexes = torch.topk(cosine_scores.flatten(), top_k)

# Convert to plain Python numbers so the lists are indexed with ints, not 0-d tensors.
for score, index in zip(top_scores.tolist(), indexes.tolist()):
    print(f"SCORE: {score:.3f}")
    print("SPEECH:", convo_speeches[index])
    print("STATEMENT:", convo_texts[index])
    print()