<a href="https://colab.research.google.com/github/danielpatrickhug/Sentence_Kernels/blob/main/LAION_OIG_topic_modeling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
%%bash
pip install sentence-transformers
pip install bertopic
pip install openai

In [None]:
from google.colab import drive
import json
from sentence_transformers import SentenceTransformer

import math
import pandas as pd
import numpy as np
import scipy.sparse as sp
from scipy.sparse import csr_matrix
from tqdm import tqdm
import torch
from torch import Tensor
from collections import defaultdict
from typing import List
from getpass import getpass
import openai

In [None]:
# Mount Google Drive (Colab only) so files under /content/drive are readable by later cells.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Prompt for the key interactively (input is not echoed) so it is never stored in the notebook.
openai_secret = getpass('Enter the secret value: ')
# Set up OpenAI API credentials
openai.api_key = openai_secret

Enter the secret value: ··········


In [None]:
def load_jsonl(filepaths):
    """
    Load records from one or more JSON Lines files into a single list.

    :param filepaths: A single path string, or an iterable of path strings.
    :return: List of decoded JSON values, in file order and then line order.
    :raises json.JSONDecodeError: If a non-blank line is not valid JSON.
    :raises OSError: If a file cannot be opened.
    """
    # A bare string would otherwise be iterated character by character,
    # and each character would be opened as a file name.
    if isinstance(filepaths, str):
        filepaths = [filepaths]

    data = []
    for filepath in filepaths:
        # Read as UTF-8 explicitly instead of depending on the platform default encoding.
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                # Blank or whitespace-only lines (e.g. a trailing newline) are not records.
                if not line.strip():
                    continue
                data.append(json.loads(line))
    return data

In [None]:
# Read every line of OIG_test_shuffled.jsonl from Drive into a list of parsed records.
res = load_jsonl(['/content/drive/MyDrive/OIG_Data/OIG_test_shuffled.jsonl'])

In [None]:
# Peek at the first three records only; the full list is too large to display.
res[:3]

[{'text': 'User: Can you write short story and dialog based on the theme:\n  being scared\nAssistant: Sure, a short story and dialog based on the theme "being scared" could be:\n  Blas feels scared when he sees blood in the bathtub. The blood is running down the drain and it\'s a bright red color. Blas doesn\'t know what to do and he feels like he\'s going to faint.\n  Blas: Mom, there\'s blood in the bathtub. It\'s running down the drain and it\'s bright red. I don\'t know what to do.\n  Mother: Oh, my goodness! What happened?\n  Blas: I don\'t know. I was just taking a bath and I noticed the blood.\n  Mother: Let me take a look. Maybe you just got a cut somewhere and didn\'t realize it. But we need to clean it up so it doesn\'t get infected. Do you feel faint? You\'re looking very pale. Maybe you should sit down.\n  Mother: Well, there doesn\'t seem to be any cuts on you that could explain the bleeding. Have you been feeling ill lately? Any pain in your stomach or chest?\n  Blas: No,

In [None]:
def embed_data(data, key='query', model_name='all-MiniLM-L6-v2', cores=1, gpu=False, batch_size=128):
    """
    Encode the text column `key` of `data` with a SentenceTransformer model.

    Each distinct text is encoded once; the result is then expanded so that row i of
    the returned array is the embedding of row i of `data[key]`.

    :param data: Table whose `data[key]` supports `.tolist()` and `.unique()` (e.g. a pandas DataFrame).
    :param key: Name of the text column to embed.
    :param model_name: Model identifier passed to `SentenceTransformer`.
    :param cores: 1 encodes in-process; any other value uses a multi-process pool.
    :param gpu: In multi-process mode, pass `None` as the device list instead of `cores` CPU workers.
    :param batch_size: Batch size forwarded to the encoder.
    :return: numpy array with one embedding per row of `data[key]`, in row order.
    """
    print('Embedding data')
    encoder = SentenceTransformer(model_name)
    print('Model loaded')

    all_texts = data[key].tolist()
    distinct_texts = data[key].unique()
    print('Unique sentences', len(distinct_texts))

    if cores != 1:
        # None means "use all CUDA devices" (per the original author's note); otherwise one CPU worker per core.
        target_devices = None if gpu else ['cpu'] * cores

        print('Multi-process pool starting')
        worker_pool = encoder.start_multi_process_pool(target_devices)
        print('Multi-process pool started')

        # Split the distinct texts evenly, one chunk per requested core.
        per_worker = math.ceil(len(distinct_texts) / cores)
        vectors = encoder.encode_multi_process(
            distinct_texts, worker_pool, batch_size=batch_size, chunk_size=per_worker
        )
        encoder.stop_multi_process_pool(worker_pool)
    else:
        vectors = encoder.encode(distinct_texts, show_progress_bar=True, batch_size=batch_size)

    print("Embeddings computed")

    # Look each original row up by its text so duplicates share one embedding.
    by_text = dict(zip(distinct_texts, vectors))
    return np.array([by_text[text] for text in all_texts])

def gaussian_kernel_torch(embs_a, embs_b, sigma=1.0):
    """
    Gaussian kernel between every row of `embs_a` and every row of `embs_b`.

    Entry [i, j] equals exp(-d(a_i, b_j)^2 / (2 * sigma^2)), where d is the pairwise
    distance returned by `torch.cdist`.

    :param embs_a: Tensor (or array-like convertible by `torch.tensor`) of shape (n_a, dim).
    :param embs_b: Tensor (or array-like convertible by `torch.tensor`) of shape (n_b, dim).
    :param sigma: Width of the Gaussian kernel.
    :return: Tensor of shape (n_a, n_b) holding the kernel values.
    """
    # Accept numpy arrays / nested lists as well as tensors.
    embs_a = embs_a if isinstance(embs_a, torch.Tensor) else torch.tensor(embs_a)
    embs_b = embs_b if isinstance(embs_b, torch.Tensor) else torch.tensor(embs_b)

    pairwise_dist = torch.cdist(embs_a, embs_b)
    denominator = 2 * sigma ** 2
    return torch.exp(-pairwise_dist ** 2 / denominator)

def cos_sim(a: Tensor, b: Tensor):
    """
    Pairwise cosine similarity between the rows of `a` and the rows of `b`.

    Non-tensor inputs are converted through `np.array`; 1-D inputs are treated as a
    single row.

    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    def _as_matrix(values):
        # Convert to a tensor if needed, then promote a single vector to shape (1, dim).
        if not isinstance(values, torch.Tensor):
            values = torch.tensor(np.array(values))
        if len(values.shape) == 1:
            values = values.unsqueeze(0)
        return values

    a_unit = torch.nn.functional.normalize(_as_matrix(a), p=2, dim=1)
    b_unit = torch.nn.functional.normalize(_as_matrix(b), p=2, dim=1)
    # Dot products of unit-length rows are cosine similarities.
    return torch.mm(a_unit, b_unit.transpose(0, 1))

def prune_ref_docs(qa_embs, ref_embs, ref_docs, threshold=0.1, sigma=0.5):
    """
    Drop reference documents that are weakly connected to the QA embeddings, then
    recompute the adjacency matrix against the surviving references.

    A reference is kept when its total kernel weight (summed over all QA rows) exceeds
    `threshold` times the largest such total.

    Parameters:
    qa_embs (numpy array or tensor): The embedding matrix of QA pairs.
    ref_embs (numpy array or tensor): The embedding matrix of reference sentences, one row per entry of `ref_docs`.
    ref_docs (list): The list of reference documents.
    threshold (float): Fraction of the maximum total weight below which a reference is dropped.
    sigma (float): Width of the Gaussian kernel used for both adjacency matrices (previously fixed at 0.5).

    Returns:
    pruned_ref_embs: The rows of `ref_embs` that were kept (same container type as `ref_embs`).
    pruned_ref_docs (list): The kept reference documents, in their original order.
    pruned_A (torch tensor): Kernel matrix between `qa_embs` and `pruned_ref_embs`.
    """
    # Adjacency between every QA row and every reference column.
    A = gaussian_kernel_torch(qa_embs, ref_embs, sigma=sigma)
    print(f'Before: {A.shape}')

    # dim=0 collapses the QA axis: one total per reference document (a column sum).
    col_sum = torch.sum(A, dim=0)

    # Plain Python ints index lists, numpy arrays and tensors alike.
    relevant_idx = torch.where(col_sum > threshold * col_sum.max())[0].tolist()

    pruned_ref_embs = ref_embs[relevant_idx]
    pruned_ref_docs = [ref_docs[i] for i in relevant_idx]

    # Recompute the adjacency matrix with pruned reference embeddings
    pruned_A = gaussian_kernel_torch(qa_embs, pruned_ref_embs, sigma=sigma)
    print(f'After: {pruned_A.shape}')
    return pruned_ref_embs, pruned_ref_docs, pruned_A


def compute_cos_sim_kernel(embs, threshold=0.65, kernel_type="cosine", sigma=1.0):
    """
    Build a binary adjacency matrix by thresholding a self-similarity kernel.

    :param embs: Embedding matrix; each row is compared against every row.
    :param threshold: Entries strictly greater than this become 1, all others 0.
    :param kernel_type: "cosine" (uses `cos_sim`) or "gaussian" (uses `gaussian_kernel_torch`).
    :param sigma: Kernel width; only used when `kernel_type` is "gaussian".
    :return: float32 numpy array with the same shape as the kernel matrix, containing 0.0 / 1.0.
    :raises ValueError: If `kernel_type` is not one of the supported names.
    """
    if kernel_type == "gaussian":
        A = gaussian_kernel_torch(embs, embs, sigma)
    elif kernel_type == "cosine":
        A = cos_sim(embs, embs)
    else:
        # Previously an unknown name fell through and crashed later with an unbound `A`.
        raise ValueError(
            f"Unknown kernel_type {kernel_type!r}; expected 'cosine' or 'gaussian'."
        )

    # A single comparison yields the 0/1 mask; NaN entries compare False and stay 0.
    return (A > threshold).numpy().astype(np.float32)


def k_hop_message_passing_sparse(A, node_features, k):
    """
    Iteratively expand an adjacency matrix and propagate node features through it.

    Starting from A_k = A and agg = node_features, each of the k iterations does:
        agg <- A_k @ agg + A_k @ node_features
        A_k <- A_k + A_k @ A

    Parameters:
    A (numpy array or scipy sparse matrix): The adjacency matrix of the graph.
    node_features (numpy array or scipy sparse matrix): The feature matrix of the nodes.
    k (int): The number of iterations (hops).

    Returns:
    A_k (numpy array): The expanded adjacency matrix, densified.
    agg_features (numpy array): The propagated feature matrix, densified.
    """
    # Work in sparse form throughout; dense inputs are converted to CSR.
    adjacency = A if sp.issparse(A) else sp.csr_matrix(A)
    features = node_features if sp.issparse(node_features) else sp.csr_matrix(node_features)

    A_k = adjacency.copy()
    agg_features = features.copy()

    for _ in tqdm(range(k)):
        # Uses the current A_k, i.e. before this iteration's expansion below.
        hop_message = A_k.dot(features)
        agg_features = A_k.dot(agg_features) + hop_message
        # Add the edges reachable through one more application of the base adjacency.
        A_k += A_k.dot(adjacency)

    return A_k.toarray(), agg_features.toarray()

def compute_kernel(embs, threshold=0.65, sigma=1.0):
    """
    Threshold the Gaussian self-similarity of `embs` into a binary adjacency matrix.

    :param embs: Embedding matrix; each row is compared against every row.
    :param threshold: Kernel values strictly greater than this become 1, all others 0.
    :param sigma: Width of the Gaussian kernel.
    :return: float32 numpy array of 0.0 / 1.0 with the kernel matrix's shape.
    """
    similarity = gaussian_kernel_torch(embs, embs, sigma)
    is_edge = similarity > threshold
    adjacency = torch.where(is_edge, torch.ones_like(similarity), torch.zeros_like(similarity))
    return adjacency.numpy().astype(np.float32)

In [None]:
# Union of the top-level keys seen across all records (iterating a dict yields its keys).
keys = set().union(*res)
keys

{'meta', 'metadata', 'source', 'text'}

In [None]:
def _dataset_key(record):
    """Return the source-dataset label of a record, checking 'meta', then 'metadata', then 'source'."""
    if 'meta' in record:
        return record['meta']['source']
    if 'metadata' in record:
        metadata = record['metadata']
        # 'metadata' is either the label itself (a string) or a mapping that carries it.
        return metadata if isinstance(metadata, str) else metadata['source']
    if 'source' in record:
        return record['source']
    # Records without any provenance field share one fallback label.
    return 'free_agent'


# One (position, text, dataset label) tuple per record.
sents = []
for idx, record in enumerate(res):
    source_label = _dataset_key(record)
    sents.append((idx, f"{record['text']}", source_label))

In [None]:
# This cell used to reassign `sents` to an empty list, which threw away the rows built
# from `res` and left the DataFrame (and everything after it) empty on a fresh
# Restart & Run All. It now only checks that rows were collected and shows the count.
assert sents, "no records were collected from `res`"
len(sents)

In [None]:
# One row per collected tuple: position in `res`, text, and dataset label.
data_base = pd.DataFrame(sents, columns=['id_', 'query', 'dataset_key'])

In [None]:
# Group the rows by their dataset label.
groups = data_base.groupby('dataset_key')

# Print summary statistics (DataFrame.describe) for every dataset's rows.
# Note: `name` and `group` remain bound to the last group after the loop ends.
for name, group in groups:
    print(f"Dataset: {name}")
    print(group.describe())
    

In [None]:
groups = data_base.groupby('dataset_key')

# Select the partition by key instead of breaking out of a loop and relying on the
# leaked loop variables. A missing key now raises KeyError rather than silently
# leaving `group` bound to whichever dataset happened to be iterated last.
# `name` and `group` are read by later cells (partition selection and output file names).
name = 'flanv2_cot_ecqa_train'
group = groups.get_group(name)
print(f"Dataset: {name}")
print(group.describe())

In [140]:
# Keep the selected dataset's rows under a descriptive name; `group` is set by an earlier cell.
data_partition = group

In [141]:
# Display the selected partition (the notebook renders a truncated view).
data_partition

Unnamed: 0,id_,query,dataset_key
237,237,User: Where could you see someone pretending t...,flanv2_cot_ecqa_train
1769,1769,User: Joe got his kid a plush bald Eagle. Wher...,flanv2_cot_ecqa_train
1878,1878,User: Where is a dining area likely to be extr...,flanv2_cot_ecqa_train
1949,1949,User: John thought that using television is di...,flanv2_cot_ecqa_train
1983,1983,User: Where is a monkey likely to enjoy being?...,flanv2_cot_ecqa_train
...,...,...,...
860234,860234,"User: People were talking on the corner, who w...",flanv2_cot_ecqa_train
860488,860488,User: If you must remember facts for a test yo...,flanv2_cot_ecqa_train
860745,860745,User: What kind of entertainment may feature a...,flanv2_cot_ecqa_train
860764,860764,User: Going into a trance is often associated ...,flanv2_cot_ecqa_train


In [142]:
# Use the largest multiple of 64 rows that fits in the partition. This is a floor,
# not a nearest multiple, and it is 0 when the partition has fewer than 64 rows.
max_len = len(data_partition)
closest_multiple = 64 * (max_len // 64)
# Fixed seed so Restart & Run All draws the same rows and the resulting topics are comparable.
data = data_partition.sample(closest_multiple, random_state=42)

In [143]:
import argparse


from bertopic import BERTopic
from bertopic.representation import OpenAI
from bertopic.representation import MaximalMarginalRelevance
from bertopic.vectorizers import ClassTfidfTransformer
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer


def load_topic_model(diversity=0.1, min_topic_size=10, embedding_model_name=None):
    """
    Build an unfitted BERTopic model whose topic labels are generated by OpenAI chat.

    :param diversity: Currently unused; kept for callers that pass it (it belonged to a
        MaximalMarginalRelevance representation that is no longer wired in).
    :param min_topic_size: Forwarded to `BERTopic(min_topic_size=...)`.
    :param embedding_model_name: SentenceTransformer model to attach to BERTopic. When
        None, the notebook-level `MODEL_NAME` is used, which must be defined before the call.
    :return: A configured `BERTopic` instance with `nr_topics="auto"`.
    """
    if embedding_model_name is None:
        # Resolved at call time so the function can be defined before MODEL_NAME is set.
        embedding_model_name = MODEL_NAME

    vectorizer_model = CountVectorizer(stop_words="english")
    ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)
    representation_model = OpenAI(model="gpt-3.5-turbo", delay_in_seconds=1, chat=True)
    model = SentenceTransformer(embedding_model_name)
    topic_model = BERTopic(
        nr_topics="auto",
        min_topic_size=min_topic_size,
        representation_model=representation_model,
        vectorizer_model=vectorizer_model,
        ctfidf_model=ctfidf_model,
        embedding_model=model,
    )
    return topic_model


def fit_topic_model(topic_model, data, embeddings, key="query"):
    """
    Fit `topic_model` on the text column `key` of `data` using precomputed embeddings.

    :param topic_model: Object exposing `fit_transform(documents, embeddings)`.
    :param data: Table whose `data[key]` supports `.to_list()`.
    :param embeddings: Embeddings passed straight through to `fit_transform`.
    :param key: Name of the text column.
    :return: The `(topics, probs)` pair returned by `fit_transform`.
    """
    documents = data[key].to_list()
    fitted = topic_model.fit_transform(documents, embeddings)
    topics, probs = fitted
    return topics, probs


def get_topic_info(topic_model):
    """Return whatever `topic_model.get_topic_info()` returns (the model's topic overview)."""
    info = topic_model.get_topic_info()
    return info


def reduce_topics(topic_model, data, nr_topics, key="query"):
    """
    Ask `topic_model` to reduce its topics to `nr_topics` using the text column `key`.

    :param topic_model: Object exposing `reduce_topics(documents, nr_topics)`.
    :param data: Table whose `data[key]` supports `.to_list()`.
    :param nr_topics: Target forwarded unchanged to the model.
    :param key: Name of the text column.
    :return: The same `topic_model` object that was passed in.
    """
    documents = data[key].to_list()
    topic_model.reduce_topics(documents, nr_topics)
    return topic_model


def get_representative_docs(topic_model):
    """Return whatever `topic_model.get_representative_docs()` returns."""
    representative = topic_model.get_representative_docs()
    return representative


def reduce_outliers(topic_model, data, topics, probs, key="query", strategy="c-tf-idf", threshold=0.1):
    """
    Reassign outlier documents to topics, then refresh the topic representations.

    :param topic_model: Fitted BERTopic model; `update_topics` is called on it and the same object is returned.
    :param data: Table whose `data[key]` supports `.to_list()`.
    :param topics: Current topic assignment, one entry per document.
    :param probs: Topic probabilities; only used by the "distributions" strategy.
    :param key: Name of the text column.
    :param strategy: One of "c-tf-idf", "embeddings", "distributions".
    :param threshold: Minimum similarity for reassignment; applied only with "c-tf-idf"
        (previously fixed at 0.1). The other strategies keep BERTopic's default.
    :return: `(topic_model, new_topics)`.
    :raises ValueError: If `strategy` is not one of the supported names.
    """
    # Validate first so a typo fails fast, before any model objects are constructed.
    if strategy not in ("c-tf-idf", "embeddings", "distributions"):
        raise ValueError("Invalid strategy")

    documents = data[key].to_list()

    # `strategy` is passed by keyword: in BERTopic releases that accept an `images`
    # argument, the third positional slot is `images`, not `strategy`.
    if strategy == "c-tf-idf":
        new_topics = topic_model.reduce_outliers(documents, topics, strategy=strategy, threshold=threshold)
    elif strategy == "embeddings":
        new_topics = topic_model.reduce_outliers(documents, topics, strategy=strategy)
    else:
        new_topics = topic_model.reduce_outliers(documents, topics, probabilities=probs, strategy=strategy)

    # Rebuild topic representations for the new assignment with the same components used at fit time.
    vectorizer_model = CountVectorizer(stop_words="english")
    ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)
    representation_model = OpenAI(model="gpt-3.5-turbo", delay_in_seconds=1, chat=True)
    topic_model.update_topics(
        documents,
        topics=new_topics,
        representation_model=representation_model,
        vectorizer_model=vectorizer_model,
        ctfidf_model=ctfidf_model,
    )
    return topic_model, new_topics


def compute_hierarchical_topic_tree(topic_model, data, key="query"):
    """
    Compute the model's hierarchical topics and their tree rendering.

    :param topic_model: Object exposing `hierarchical_topics(documents)` and `get_topic_tree(hierarchy)`.
    :param data: Table whose `data[key]` supports `.to_list()`.
    :param key: Name of the text column.
    :return: `(hierarchical_topics, tree)` exactly as returned by the two model calls.
    """
    documents = data[key].to_list()
    hierarchy = topic_model.hierarchical_topics(documents)
    return hierarchy, topic_model.get_topic_tree(hierarchy)

In [144]:
MODEL_NAME = "all-MiniLM-L6-v2"
# Embed with the same model name that load_topic_model() attaches to BERTopic, so the
# precomputed embeddings and the model's own embedder cannot drift apart.
qa_embs = embed_data(data, key="query", model_name=MODEL_NAME)
print(qa_embs.shape)
# Binary graph: two documents are connected when their Gaussian kernel value exceeds 0.70.
A = compute_kernel(qa_embs, threshold=0.70)

# Two rounds of propagation; the propagated features replace the raw embeddings below.
A_k, agg_features = k_hop_message_passing_sparse(A, qa_embs, 2)
print(A_k.shape)
topic_model = load_topic_model()
topics, probs = fit_topic_model(topic_model, data, agg_features)
freq = get_topic_info(topic_model)
rep_docs = get_representative_docs(topic_model)
print(freq)



Embedding data
Model loaded
Unique sentences 3264


Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Embeddings computed
(3264, 384)


100%|██████████| 2/2 [00:00<00:00, 31.14it/s]


(3264, 3264)
    Topic  Count                                               Name
0      -1    914             -1_Various scenarios and questions.___
1       0    261  0_the consequences and emotions of committing ...
2       1    217                            1_Storage and Sports___
3       2    111   2_Rural Living and Farming in Various Regions___
4       3    106  3_Shopping and Entertainment in Urban Downtown...
..    ...    ...                                                ...
61     60     12      60_Mental Illness Treatment and Management___
62     61     12              61_Beekeeping and Honey Production___
63     62     12  62_Bald Eagle Sightings in Various Nature Area...
64     63     11              63_Cats and Knights in Fairy Tales___
65     64     11    64_Fighting Inflation and Economic Stability___

[66 rows x 3 columns]


In [145]:
# Re-read the topic overview and representative documents from the fitted model,
# then display the overview as a rich table.
freq = get_topic_info(topic_model)
rep_docs = get_representative_docs(topic_model)
freq

Unnamed: 0,Topic,Count,Name
0,-1,914,-1_Various scenarios and questions.___
1,0,261,0_the consequences and emotions of committing ...
2,1,217,1_Storage and Sports___
3,2,111,2_Rural Living and Farming in Various Regions___
4,3,106,3_Shopping and Entertainment in Urban Downtown...
...,...,...,...
61,60,12,60_Mental Illness Treatment and Management___
62,61,12,61_Beekeeping and Honey Production___
63,62,12,62_Bald Eagle Sightings in Various Nature Area...
64,63,11,63_Cats and Knights in Fairy Tales___


In [146]:
# Hierarchy and tree rendering for the base clustering (the recorded run took about 1.5 minutes).
hr, tree = compute_hierarchical_topic_tree(topic_model=topic_model, data=data)

100%|██████████| 64/64 [01:24<00:00,  1.32s/it]


In [147]:
# print() is used so the tree's line breaks render as text rather than as an escaped repr.
print(tree)

.
├─Feelings and Experiences in Relation to Driving and Waiting in Lines.____
│    ├─Effects of Excessive Alcohol Consumption____
│    │    ├─Fighting inflation, civil unrest, failed bill, and mental illness management____
│    │    │    ├─Conquering opponents, fighting inflation, civil unrest, failed economic measures____
│    │    │    │    ├─■──Fighting Inflation and Economic Stability____ ── Topic: 64
│    │    │    │    └─Fear of War and Civil Unrest____
│    │    │    │         ├─Military leadership and civil unrest____
│    │    │    │         │    ├─■──Dealing with Anger and Civil Action____ ── Topic: 45
│    │    │    │         │    └─■──Military Leadership and Warfare Dynamics____ ── Topic: 41
│    │    │    │         └─■──Weapons and Explosives in War Zones____ ── Topic: 44
│    │    │    └─Mental Illness Treatment____
│    │    │         ├─■──Mental Illness Treatment and Management____ ── Topic: 60
│    │    │         └─■──Medical Training and Practices for Doctors____ ── T

In [148]:
# Documents assigned to the outlier topic (-1), paired with their text.
outliers = [
    (top, sent)
    for top, sent in zip(topics, data["query"].to_list())
    if top == -1
]
out = pd.DataFrame(outliers, columns=['Topic', 'doc'])

# Attach each topic's representative documents to its row of the overview table.
docs = list(rep_docs.items())
rep = pd.DataFrame(docs, columns=["Topic", "Rep_docs"])
topic_info_freq = pd.merge(freq, rep, on="Topic")

# Persist the outliers, the topic table and the tree; `name` is the dataset label set earlier.
out.to_csv(f"/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/base_cluster_{name}_outlier_topics.csv")
topic_info_freq.to_csv(f"/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/base_cluster_{name}_topics.csv")
with open(f'/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/base_cluster_{name}_topic_tree.txt', 'w', encoding='utf-8') as tree_file:
    tree_file.write(tree)

Reduce outliers by c-tf-idf

In [149]:
# reduce_outliers returns the model object it was given (after calling update_topics on it),
# so `cftidf_reduced_topic_model` and `topic_model` refer to the same object from here on.
cftidf_reduced_topic_model, cftidf_reduced_topics = reduce_outliers(topic_model, data, topics, probs, key="query", strategy="c-tf-idf")

In [150]:
# Topic overview and representative documents after the c-TF-IDF outlier reassignment.
cftidf_reduced_freq = get_topic_info(cftidf_reduced_topic_model)
cftidf_reduced_rep_docs = get_representative_docs(cftidf_reduced_topic_model)
cftidf_reduced_freq

Unnamed: 0,Topic,Count,Name
0,-1,650,-1_Common Locations and Origins___
1,0,263,0_Consequences of committing perjury___
2,1,224,1_Everyday Household Items and Sports___
3,2,113,2_Farming and Countryside Life in Various Regi...
4,3,108,3_Shopping and Parking in Urban Areas___
...,...,...,...
61,60,13,60_Mental Illness and Medication Adherence___
62,61,13,61_Beekeeping and Honey Production___
63,62,15,62_Bald Eagle Sightings in Various Nature Area...
64,63,14,63_Characters in fairy tales and historical no...


In [159]:
# Spot-check one topic: print every document currently assigned to topic 64.
docs_in_topic = (
    sent
    for topic, sent in zip(cftidf_reduced_topics, data['query'].to_list())
    if topic == 64
)
for sent in docs_in_topic:
    print(sent)

User: In order to fight inflation you should be doing this?
Options:
- dissatisfaction
- saving money
- not very much
- invest
- higher interest rates
Assistant: Saving money because savings of money is advisable if you want to fight the inflation as the prices of the goods increases in inflation.
User: James wanted to fight inflation because he didn't like losing value when he did what?
Options:
- more money
- save money
- can
- costs money
- pay less
Plus step-by-step reasons.
Assistant: Save money because increase in inflation reduces the value of saved money. Person doesn't like if the value of money that he saved before gets reduced.
User: The bill was passed in hopes of fighting inflation on the products, but it failed to force the companies into what?
Options:
- raising interest rates
- lowering prices
- higher interest rates
- save money
- sell more products
And why?
Assistant: Lowering prices because the prices of the goods has to be lowered in order to fight the inflation of 

In [151]:
# Hierarchy and tree rendering after the c-TF-IDF outlier reassignment.
cftidf_reduced_hr, cftidf_reduced_tree = compute_hierarchical_topic_tree(topic_model=cftidf_reduced_topic_model, data=data)

100%|██████████| 64/64 [01:24<00:00,  1.33s/it]


In [152]:
# print() is used so the tree's line breaks render as text rather than as an escaped repr.
print(cftidf_reduced_tree)

.
├─Personal experiences and emotions____
│    ├─Health and well-being____
│    │    ├─Helping Children Break Boredom with Fun Activities____
│    │    │    ├─Learning and Perjury____
│    │    │    │    ├─Computer Programming and Efficiency____
│    │    │    │    │    ├─■──Memory and Forgetting____ ── Topic: 56
│    │    │    │    │    └─■──Using Computers for Increased Efficiency and Creating Programs Through Coding and Compilation____ ── Topic: 59
│    │    │    │    └─Teaching, Perjury and Frustration____
│    │    │    │         ├─Understanding Perjury and its Consequences____
│    │    │    │         │    ├─■──Learning Strategies____ ── Topic: 31
│    │    │    │         │    └─Perjury and its consequences____
│    │    │    │         │         ├─■──Learning and growth in knowledge acquisition, analysis, and application, and the role of hard scienc ── Topic: 12
│    │    │    │         │         └─■──Consequences of committing perjury____ ── Topic: 0
│    │    │    │         └─■

In [153]:
# Documents still assigned to the outlier topic (-1) after the c-TF-IDF pass.
outliers = [
    (top, sent)
    for top, sent in zip(cftidf_reduced_topics, data["query"].to_list())
    if top == -1
]
out = pd.DataFrame(outliers, columns=['Topic', 'doc'])

# Attach each topic's representative documents to its row of the overview table.
docs = list(cftidf_reduced_rep_docs.items())
rep = pd.DataFrame(docs, columns=["Topic", "Rep_docs"])
topic_info_freq = pd.merge(cftidf_reduced_freq, rep, on="Topic")

# Persist the outliers, the topic table and the tree; `name` is the dataset label set earlier.
out.to_csv(f"/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/ctfidf_reduced_{name}_outlier_topics.csv")
topic_info_freq.to_csv(f"/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/ctfidf_reduced_{name}_topics.csv")
with open(f'/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/ctfidf_reduced_{name}_topic_tree.txt', 'w', encoding='utf-8') as tree_file:
    tree_file.write(cftidf_reduced_tree)

Reduce outliers by embedding distance

In [154]:
# Second pass, starting from the c-TF-IDF-reduced assignment. `probs` is passed but the
# "embeddings" branch of reduce_outliers does not read it. The returned model is the same
# object as `cftidf_reduced_topic_model`.
emb_reduced_topic_model, emb_reduced_new_topics = reduce_outliers(cftidf_reduced_topic_model, data, cftidf_reduced_topics, probs, key="query", strategy="embeddings")

In [155]:
# Topic overview and representative documents after the embedding-based outlier reassignment.
emb_reduced_freq = get_topic_info(emb_reduced_topic_model)
emb_reduced_rep_docs = get_representative_docs(emb_reduced_topic_model)
emb_reduced_freq

Unnamed: 0,Topic,Count,Name
0,0,270,0_Perjury and its consequences___
1,1,252,1_Storage and Sports Equipment___
2,2,123,2_Countryside Farming and Animal Regions___
3,3,136,3_Shopping and Parking Options in Urban Areas___
4,4,104,4_Various Topics on Sexuality and Human Emotio...
...,...,...,...
60,60,22,60_Mental Illness and Treatment___
61,61,16,61_Beekeeping and Honey Production___
62,62,15,62_Bald Eagle Sightings in Various Natural Are...
63,63,15,63_Knights in Fairy Tales and Historical Novel...


In [156]:
# Hierarchy and tree rendering after the embedding-based outlier reassignment.
emb_reduced_hr, emb_reduced_tree = compute_hierarchical_topic_tree(topic_model=emb_reduced_topic_model, data=data)

100%|██████████| 63/63 [01:24<00:00,  1.35s/it]


In [157]:
# Documents still assigned to the outlier topic (-1) after the embedding-based pass.
outliers = []
for top, sent in zip(emb_reduced_new_topics, data["query"].to_list()):
    if top == -1:
        outliers.append((top, sent))
out = pd.DataFrame(outliers, columns=['Topic', 'doc'])

# Attach each topic's representative documents to its row of the overview table.
docs = []
for k, v in emb_reduced_rep_docs.items():
    docs.append((k, v))
rep = pd.DataFrame(docs, columns=["Topic", "Rep_docs"])
topic_info_freq = pd.merge(emb_reduced_freq, rep, on="Topic")


out.to_csv(f"/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/emb_reduced_{name}_outlier_topics.csv")
topic_info_freq.to_csv(f"/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/emb_reduced_{name}_topics.csv")
with open(f'/content/drive/MyDrive/OIG_Data/data_summaries/topic_modeling/emb_reduced_{name}_topic_tree.txt', 'w', encoding='utf-8') as f:
    # Write this stage's own tree; the c-TF-IDF tree was being written here by a copy-paste slip.
    f.write(emb_reduced_tree)

In [None]:
def chat_gpt_inference(messages: list):
    """
    Send `messages` to the "gpt-3.5-turbo" chat completion endpoint.

    :param messages: Chat history as a list of {"role", "content"} dicts.
    :return: The response object returned by `openai.ChatCompletion.create`, unmodified.
    """
    request = {"model": "gpt-3.5-turbo", "messages": messages}
    return openai.ChatCompletion.create(**request)

def create_prompt_message_template(text, role="user"):
    """
    Wrap `text` in a chat message dict.

    :param text: Message content.
    :param role: "user" or "assistant".
    :return: {"role": role, "content": text}
    :raises ValueError: If `role` is anything other than "user" or "assistant".
    """
    allowed_roles = ("user", "assistant")
    if role in allowed_roles:
        return {"role": role, "content": text}
    raise ValueError("Not a valid role. Please use 'user' or 'assistant'.")

def compose_inference(text_block, messages):
    """
    Append `text_block` as a user turn, request a completion, and append the reply.

    `messages` is extended in place with both the user message and the assistant reply.

    :param text_block: Text of the new user message.
    :param messages: Chat history list; mutated by this call.
    :return: `(messages, assistant_message)`, where `messages` is the same list passed in.
    """
    messages.append(create_prompt_message_template(text_block, role="user"))
    completion = chat_gpt_inference(messages)
    # Only the first choice's text is kept.
    reply = create_prompt_message_template(
        completion['choices'][0]['message']['content'], role="assistant"
    )
    messages.append(reply)
    return messages, reply
    


def process_transcript(segments, output_path='conversation.txt'):
    """
    Summarize a transcript in chunks through the chat model, writing each reply to a file.

    Segment texts are buffered. On every 10th segment index, if more than 5 texts are
    buffered, the buffer is sent as one prompt, the reply is written to `output_path`,
    and the buffer is cleared. A running character count of the prompts is kept; once it
    reaches 4000, the oldest user prompts (up to 2000 characters in total) are blanked
    in the history before the next request.

    :param segments: Iterable of dicts, each carrying the segment text under 'text'.
    :param output_path: File that receives one reply per request (previously fixed to 'conversation.txt').
    :return: The full chat history, including the system prompt and any blanked user turns.
    """
    buffer = []
    messages = [{"role": "system", "content": "You are a helpful assistant and you create educational content for college students and PHDs such as question and answer pairs. Your number one priority is to be factually correct"}]
    running_size = 0
    # Model replies can contain any Unicode text, so do not rely on the platform default encoding.
    with open(output_path, 'w', encoding='utf-8') as f:
        for i, sent in enumerate(segments):
            if i % 10 == 0 and len(buffer) > 5:
                block = "\n".join(buffer)
                text_block = f"""Summarize the transcript up to this point with a summary up to 5 question answer pairs 
                                and representative python code. If theres nothing interesting, just reply 'skip'. ```{block}```"""
                running_size += len(text_block)
                if running_size < 4000:
                    print(running_size)
                else:
                    _blank_oldest_user_prompts_result = _blank_oldest_user_prompts(messages)
                    running_size -= _blank_oldest_user_prompts_result
                # Both the in-budget and over-budget paths send the prompt the same way.
                messages, assistant_template = compose_inference(text_block, messages)
                f.write(f"{assistant_template['content']}\n")
                buffer = []
            buffer.append(sent['text'])
    # NOTE(review): texts still buffered after the last segment are never sent for
    # summarization; kept as-is to avoid adding an extra API call — confirm intent.
    return messages


def _blank_oldest_user_prompts(messages, budget=2000):
    """Blank the oldest user messages whose cumulative length stays within `budget`; return the characters freed."""
    to_blank = []
    cumulative = 0
    for j, msg in enumerate(messages):
        if msg["role"] == "user":
            content = msg["content"]
            cumulative += len(content)
            # Stop at the first user message that would push the total past the budget.
            if cumulative > budget:
                break
            to_blank.append((j, content))

    freed = 0
    for j, content in to_blank:
        freed += len(content)
        messages[j]["content"] = ""
    return freed


# NOTE(review): `resp` is not defined in any cell visible in this notebook — presumably a
# transcription result holding a "segments" list of dicts with a 'text' key. Confirm where
# it comes from; as written this cell fails on a fresh kernel.
processed_messages = process_transcript(resp["segments"])