# Cluster-based Summarization

## Import packages

In [1]:
import json
from itertools import product
import numpy as np
import os
from time import time
from tqdm.notebook import tqdm
from argsum import load_test_df, get_summetix_cluster_sums, get_t5_cluster_sums, get_llm_cluster_sums
import pandas as pd

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.
loading configuration file config.json from cache at /Users/moritz/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/12040accade4e8a0f71eabdb258fecc2e7e948be/config.json
Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_attentions": true,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.33.3",
  "voc

## Load data

In [2]:
# Load the two test sets that every summarization run below operates on.
ArgKP21 = load_test_df("ArgKP21")
Debate_test = load_test_df("Debate_test")

## Define functions

In [3]:
def _empty_summary_grid(clu_keys, topics, stances, sum_keys):
    # Build the placeholder tree: cluster params -> topic -> stance -> summarization params -> {'sums', 'runtime'}.
    return {
        clu_key: {
            topic: {
                stance: {sum_key: {'sums': None, 'runtime': None} for sum_key in sum_keys}
                for stance in stances
            }
            for topic in topics
        }
        for clu_key in clu_keys
    }


def get_cluster_sums(df, cluster_dict, get_cluster_sums_callable, parameter_dict, output_dir = 'investigations/2_cluster_summaries', file_name = None):
    """Summarize every argument cluster for each topic, stance and parameter combination.

    Parameters
    ----------
    df
        Data frame with the columns 'topic', 'stance' and 'argument'.
    cluster_dict
        Clustering results. Must contain 'parameter_names', 'parameter_values'
        and either 'iterative_clustering' (preferred when present) or
        'clustering'. The latter are indexed by
        ``[str(parameter_combination)][topic][stance]`` and hold 'cluster_ids'
        and 'runtime'.
    get_cluster_sums_callable
        Called as ``callable(arguments, cluster_ids, topic=..., stance=...,
        **parameters)`` and expected to return the cluster summaries.
    parameter_dict
        Keyword arguments for the callable. Every value that is a list is
        treated as a grid axis; all combinations of these lists are evaluated.
    output_dir
        Directory the JSON file is written to (created if missing).
    file_name
        Name of the JSON file. Nothing is written when it is None.

    Returns
    -------
    dict
        ``{'summaries': ..., 'clu_parameter_names': ..., 'clu_parameter_values': ...,
        'sum_parameter_names': ..., 'sum_parameter_values': ...}``.
        ``results['summaries'][str(clu_comb)][topic][stance][str(sum_comb)]``
        holds 'sums' and 'runtime'. Both stay None when the clustering was
        skipped or the summarization callable raised.
    """
    import warnings

    # Get cluster parameter names and values
    clu_parameter_names = cluster_dict['parameter_names']
    clu_parameter_values = cluster_dict['parameter_values']
    clu_parameter_combinations = list(product(*clu_parameter_values))

    # Get unique topics and stances (stances are used as string keys, e.g. '-1', '1')
    topics = df['topic'].unique().tolist()
    stances = [str(int(sta)) for sta in sorted(df['stance'].unique())]

    # Parameters given as lists span the summarization grid
    iterate_parameter_names = [name for name, value in parameter_dict.items() if isinstance(value, list)]
    iterate_parameter_values = [parameter_dict[name] for name in iterate_parameter_names]
    iter_parameter_value_combinations = list(product(*iterate_parameter_values))

    # Explicit key/value pairs: the former zip() of 5 keys against 6 values
    # silently shifted every metadata entry by one position.
    results = {
        'summaries': _empty_summary_grid(
            [str(comb) for comb in clu_parameter_combinations],
            topics,
            stances,
            [str(comb) for comb in iter_parameter_value_combinations],
        ),
        'clu_parameter_names': clu_parameter_names,
        'clu_parameter_values': clu_parameter_values,
        'sum_parameter_names': iterate_parameter_names,
        'sum_parameter_values': iterate_parameter_values,
    }

    # Iterative clustering results take precedence when both are available
    clustering_key = 'iterative_clustering' if 'iterative_clustering' in cluster_dict else 'clustering'

    ################################
    ### Iterate: topic & stance ####
    ################################

    for topic, stance in tqdm([(topic, stance) for topic in topics for stance in stances], leave = True, desc = 'topic + stance'):

        mask_topic_stance = (df['topic'] == topic) & (df['stance'] == int(stance))
        arguments = df[mask_topic_stance]['argument'].to_list()

        ##########################################
        ### Iterate: cluster parameter values ####
        ##########################################

        for clu_parameter in tqdm(clu_parameter_combinations, leave = True, desc = 'clustering parameter'):

            clustering = cluster_dict[clustering_key][str(clu_parameter)][topic][stance]
            cluster_ids = clustering['cluster_ids']
            clustering_runtime = clustering['runtime']

            # Cluster id -1 marks noise; drop those arguments before summarizing
            keep = [i for i, cluster_id in enumerate(cluster_ids) if cluster_id != -1]
            cluster_ids_no_noise = [cluster_ids[i] for i in keep]
            arguments_no_noise = [arguments[i] for i in keep]

            # Require more than one cluster ...
            if len(set(cluster_ids_no_noise)) <= 1:
                continue
            # ... and more than 50% clustered arguments (multiplication avoids
            # a division by zero for empty groups)
            if 2 * len(cluster_ids_no_noise) <= len(cluster_ids):
                continue

            ############################
            ### Iterate: parameter #####
            ############################

            result_slot = results['summaries'][str(clu_parameter)][topic][stance]

            for comb in tqdm(iter_parameter_value_combinations, leave = False, disable = True, desc = 'summarization parameter'):
                iterate_parameter_dict = {**parameter_dict, **dict(zip(iterate_parameter_names, comb))}

                ########################
                ### Get summaries ######
                ########################

                try:
                    start_time = time()
                    cluster_sums = get_cluster_sums_callable(arguments_no_noise, cluster_ids_no_noise, topic = topic, stance = int(stance), **iterate_parameter_dict)
                    runtime = time() - start_time
                except Exception as exc:
                    # Best effort: keep going with the remaining combinations, but
                    # report the failure and let KeyboardInterrupt/SystemExit through.
                    warnings.warn(
                        f'Summarization failed for topic={topic!r}, stance={stance!r}, '
                        f'clustering={clu_parameter!r}, parameters={comb!r}: {exc!r}'
                    )
                    cluster_sums = None
                    runtime = None

                result_slot[str(comb)]['sums'] = cluster_sums
                if runtime is not None:
                    # Reported runtime covers clustering plus summarization
                    result_slot[str(comb)]['runtime'] = np.round(clustering_runtime + runtime, 3)

    ########################
    ### Save results #######
    ########################

    if file_name is not None:
        os.makedirs(output_dir, exist_ok = True)
        with open(os.path.join(output_dir, file_name), 'w') as file:
            json.dump(results, file)

    return results

## Summetix

In [None]:
# Summetix on ArgKP21: read the cluster file, then summarize each cluster.
with open('investigations/1_argument_clusters/ArgKP21_Summetix.json') as f:
    summetix_cluster_dict = json.load(f)

# No extra keyword arguments are forwarded to the Summetix callable.
summetix_parameter_dict = dict()

summetix_results = get_cluster_sums(
    df=ArgKP21,
    cluster_dict=summetix_cluster_dict,
    get_cluster_sums_callable=get_summetix_cluster_sums,
    parameter_dict=summetix_parameter_dict,
    file_name='ArgKP21_Summetix.json',
)

In [None]:
# Summetix on Debate_test: read the cluster file, then summarize each cluster.
with open('investigations/1_argument_clusters/Debate_test_Summetix.json') as f:
    summetix_cluster_dict = json.load(f)

# No extra keyword arguments are forwarded to the Summetix callable.
summetix_parameter_dict = dict()

summetix_results = get_cluster_sums(
    df=Debate_test,
    cluster_dict=summetix_cluster_dict,
    get_cluster_sums_callable=get_summetix_cluster_sums,
    parameter_dict=summetix_parameter_dict,
    file_name='Debate_test_Summetix.json',
)

## USKPM

In [None]:
# USKPM on ArgKP21: read the cluster file, then summarize each cluster with the T5 callable.
with open('investigations/1_argument_clusters/ArgKP21_USKPM.json') as f:
    uskpm_cluster_dict = json.load(f)

# Keyword arguments forwarded to get_t5_cluster_sums (no list values, so a single run per cluster set).
uskpm_parameter_dict = {
    'max_length_text': 512,
    'max_length_label': 128,
    'n': 5,
    'num_beams': 6,
    'temperature': None,
    'do_sample': False,
    'p': None,
}

uskpm_results = get_cluster_sums(
    df=ArgKP21,
    cluster_dict=uskpm_cluster_dict,
    get_cluster_sums_callable=get_t5_cluster_sums,
    parameter_dict=uskpm_parameter_dict,
    file_name='ArgKP21_USKPM.json',
)

In [None]:
# USKPM on Debate_test: read the cluster file, then summarize each cluster with the T5 callable.
with open('investigations/1_argument_clusters/Debate_test_USKPM.json') as f:
    uskpm_cluster_dict = json.load(f)

# Keyword arguments forwarded to get_t5_cluster_sums (no list values, so a single run per cluster set).
uskpm_parameter_dict = {
    'max_length_text': 512,
    'max_length_label': 128,
    'n': 5,
    'num_beams': 6,
    'temperature': None,
    'do_sample': False,
    'p': None,
}

uskpm_results = get_cluster_sums(
    df=Debate_test,
    cluster_dict=uskpm_cluster_dict,
    get_cluster_sums_callable=get_t5_cluster_sums,
    parameter_dict=uskpm_parameter_dict,
    file_name='Debate_test_USKPM.json',
)

## MCArgSum

### Local

In [None]:
# MCArgSum (local optimization) on ArgKP21: read the cluster file, then summarize with the LLM callable.
with open('investigations/1_argument_clusters/ArgKP21_MCArgSum_SBERT_all_mpnet_base.json') as f:
    mc_argsum_cluster_dict = json.load(f)

# Keyword arguments forwarded to get_llm_cluster_sums (no list values, so a single run per cluster set).
mc_argsum_local_parameter_dict = {
    'llm': 'gpt-3.5-turbo',
    'optimization': 'local',
    'sum_token_length': 8,
    'sum_min_num': 1,
    'sum_max_num': 1,
    'few_shot': True,
    'exclude_topic': False,
    'generate_less': False,
    'temperature': 0.5,
    'frequency_penalty': None,
    'n': 5,
    'p': None,
}

mc_argsum_local_results = get_cluster_sums(
    df=ArgKP21,
    cluster_dict=mc_argsum_cluster_dict,
    get_cluster_sums_callable=get_llm_cluster_sums,
    parameter_dict=mc_argsum_local_parameter_dict,
    file_name='ArgKP21_MCArgSum_SBERT_all_mpnet_base_local.json',
)

In [None]:
# MCArgSum (local optimization) on Debate_test: read the cluster file, then summarize with the LLM callable.
with open('investigations/1_argument_clusters/Debate_test_MCArgSum_SBERT_all_mpnet_base.json') as f:
    mc_argsum_cluster_dict = json.load(f)

# Keyword arguments forwarded to get_llm_cluster_sums (no list values, so a single run per cluster set).
mc_argsum_local_parameter_dict = {
    'llm': 'gpt-3.5-turbo',
    'optimization': 'local',
    'sum_token_length': 8,
    'sum_min_num': 1,
    'sum_max_num': 1,
    'few_shot': True,
    'exclude_topic': False,
    'generate_less': False,
    'temperature': 0.5,
    'frequency_penalty': None,
    'n': 5,
    'p': None,
}

mc_argsum_local_results = get_cluster_sums(
    df=Debate_test,
    cluster_dict=mc_argsum_cluster_dict,
    get_cluster_sums_callable=get_llm_cluster_sums,
    parameter_dict=mc_argsum_local_parameter_dict,
    file_name='Debate_test_MCArgSum_SBERT_all_mpnet_base_local.json',
)

### Global

In [None]:
# MCArgSum (global optimization) on ArgKP21: read the cluster file, then summarize with the LLM callable.
with open('investigations/1_argument_clusters/ArgKP21_MCArgSum_SBERT_all_mpnet_base.json') as f:
    mc_argsum_cluster_dict = json.load(f)

# NOTE(review): the variable names below still say "local" although 'optimization' is 'global';
# they are left unchanged here in case other cells rely on them.
mc_argsum_local_parameter_dict = {
    'llm': 'gpt-3.5-turbo',
    'optimization': 'global',
    'sum_token_length': 8,
    'sum_min_num': 1,
    'sum_max_num': 1,
    'few_shot': True,
    'exclude_topic': False,
    'generate_less': False,
    'temperature': 0.5,
    'frequency_penalty': None,
    'n': 5,
    'p': None,
}

mc_argsum_local_results = get_cluster_sums(
    df=ArgKP21,
    cluster_dict=mc_argsum_cluster_dict,
    get_cluster_sums_callable=get_llm_cluster_sums,
    parameter_dict=mc_argsum_local_parameter_dict,
    file_name='ArgKP21_MCArgSum_SBERT_all_mpnet_base_global.json',
)

In [None]:
# MCArgSum (global optimization) on Debate_test: read the cluster file, then summarize with the LLM callable.
with open('investigations/1_argument_clusters/Debate_test_MCArgSum_SBERT_all_mpnet_base.json') as f:
    mc_argsum_cluster_dict = json.load(f)

# NOTE(review): the variable names below still say "local" although 'optimization' is 'global';
# they are left unchanged here in case other cells rely on them.
mc_argsum_local_parameter_dict = {
    'llm': 'gpt-3.5-turbo',
    'optimization': 'global',
    'sum_token_length': 8,
    'sum_min_num': 1,
    'sum_max_num': 1,
    'few_shot': True,
    'exclude_topic': False,
    'generate_less': False,
    'temperature': 0.5,
    'frequency_penalty': None,
    'n': 5,
    'p': None,
}

mc_argsum_local_results = get_cluster_sums(
    df=Debate_test,
    cluster_dict=mc_argsum_cluster_dict,
    get_cluster_sums_callable=get_llm_cluster_sums,
    parameter_dict=mc_argsum_local_parameter_dict,
    file_name='Debate_test_MCArgSum_SBERT_all_mpnet_base_global.json',
)