In [6]:
import json
from collections import Counter
from tenacity import retry, stop_after_attempt, wait_random_exponential
import time
import openai
from openai import OpenAI
import os

# Only fall back to the placeholder when no key is configured, so a real
# OPENAI_API_KEY exported in the environment is never overwritten.
# Replace 'sk-' with a real key (or export OPENAI_API_KEY before starting).
os.environ.setdefault('OPENAI_API_KEY', 'sk-')

# Shared client used by gpt_request. It is built without arguments, so it
# presumably picks up the key from OPENAI_API_KEY -- confirm against the
# openai client documentation.
client = OpenAI()
# Not referenced by the code visible in this notebook cell.
log = []

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(10))
def gpt_request(**kwargs):
    """Forward ``kwargs`` to the shared client's chat-completions endpoint.

    The ``retry`` decorator re-invokes this function when the call raises,
    giving up after 10 attempts and waiting a randomised exponential
    interval (configured with ``min=1`` and ``max=60``) between attempts.

    Returns:
        Whatever ``client.chat.completions.create`` returns.
    """
    completion = client.chat.completions.create(**kwargs)
    return completion

def get_filtered_transcriptions(sentence, min_length=5):
    """Gather one record's per-model transcriptions, dropping very short ones.

    Args:
        sentence: Mapping that holds one transcription string under each of
            the model keys listed below. A missing key raises ``KeyError``.
        min_length: A transcription is kept only when its length is strictly
            greater than this value.

    Returns:
        The transcriptions longer than ``min_length``, in model-key order.
        When none qualifies, every transcription is returned unfiltered.
    """
    model_keys = (
        'hubertlarge',
        'w2v2100',
        'w2v2960',
        'w2v2960large',
        'w2v2960largeself',
        'wavlmplus',
        'whisperbase',
        'whisperlarge',
        'whispermedium',
        'whispersmall',
        'whispertiny',
    )
    candidates = [sentence[key] for key in model_keys]

    long_enough = [text for text in candidates if len(text) > min_length]

    # An empty filtered list is falsy, so this falls back to all candidates.
    return long_enough or candidates

def refine_transcription(transcriptions):
    """Ask the chat model to pick the best sentence among ``transcriptions``.

    The candidates are sent as a numbered list. The model's answer is then
    cleaned: a known lead-in phrase is removed, a numbering prefix echoed
    from the list is stripped, and backslashes are deleted.

    Args:
        transcriptions: List of candidate sentences (strings).

    Returns:
        The cleaned sentence chosen by the model. If the request fails for
        any reason, the first candidate is used instead (and cleaned the
        same way). An empty candidate list yields ``""``.
    """
    import re  # local import: the file's import block does not provide it

    # Nothing to choose from: skip the request, and avoid the IndexError the
    # first-candidate fallback would otherwise raise.
    if not transcriptions:
        return ""

    prompt = "Choose the most comprehensive and coherent sentence from the following options. If impossible to decide, choose the longest option available. Output only the selected sentence without any additional explanation or phrases:\n"
    prompt += "\n".join(f"{i}. {trans}" for i, trans in enumerate(transcriptions, start=1))

    try:
        completion = gpt_request(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a text refinement assistant."},
                {"role": "user", "content": prompt}
            ]
        )
        response = completion.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error with prompt: {prompt}\nError: {e}")
        response = transcriptions[0]  # Default to the first option if there's an error

    # Drop the lead-in phrase the model sometimes prepends
    response = response.replace("The most comprehensive and coherent sentence is:", "").strip()

    # Strip a numbering prefix echoed from the option list ("3. sentence").
    # It is removed only when the answer is not itself a verbatim candidate
    # and the number is a valid option index, so sentences that genuinely
    # start with a number are left alone.
    numbered = re.match(r"(\d+)\.\s+", response)
    if (
        numbered
        and response not in transcriptions
        and 1 <= int(numbered.group(1)) <= len(transcriptions)
    ):
        response = response[numbered.end():]

    # Remove any backslashes
    response = response.replace("\\", "")

    return response

def gpt_ensemble_transcription(sentence, min_length=5):
    """Produce a single ensemble transcription for one record.

    Filters the record's per-model transcriptions with
    ``get_filtered_transcriptions`` and returns the sentence that
    ``refine_transcription`` selects from the survivors.
    """
    return refine_transcription(get_filtered_transcriptions(sentence, min_length))

def process_session(data, session_name, output_dir='F:\\SLT_2024\\train_ensemble\\'):
    """Add an ``'ensemble'`` transcription to every record of one session and save them.

    Args:
        data: List of record dicts; each must have an ``'id'`` string.
        session_name: Records whose ``'id'`` starts with this prefix are
            processed; it is also used to name the output file.
        output_dir: Directory that receives ``<session_name>_processed.json``.
            Defaults to the location this notebook has always written to.

    Side effects:
        * Each selected record is mutated in place: ``item['ensemble']`` is
          set to the ensemble sentence, or ``None`` if processing it raised.
        * Progress is printed every 10 records and after the last one.
        * The selected records are written as indented JSON to ``output_dir``.
    """
    # Filter data for the specified session
    session_data = [item for item in data if item['id'].startswith(session_name)]
    total = len(session_data)

    # Process each script within the session
    processed_data = []
    for i, item in enumerate(session_data, start=1):
        try:
            item['ensemble'] = gpt_ensemble_transcription(item)
        except Exception as e:
            # Best effort: one bad record must not abort the whole session
            print(f"Error processing record {i} in session {session_name}: {e}")
            item['ensemble'] = None  # Indicate failure to process this item
        processed_data.append(item)

        if i % 10 == 0 or i == total:
            print(f"Processed {i}/{total} records in session {session_name}")

    # Make sure the target directory exists so the results of a long run are
    # not lost to a FileNotFoundError at the very end.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the updated JSON file for the session
    output_file = os.path.join(output_dir, f'{session_name}_processed.json')
    with open(output_file, 'w', encoding='utf-8') as file:
        json.dump(processed_data, file, indent=4)

def apply_ensemble_method_for_session(file_path, session_name):
    """Load the records in ``file_path`` and run ``process_session`` on one session.

    Args:
        file_path: Path of the JSON file holding the list of records.
        session_name: Session prefix handed to ``process_session``.
    """
    # Load the JSON file. The encoding is given explicitly so decoding does
    # not depend on the platform's default code page.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Process the specified session
    process_session(data, session_name)

# Run the ensemble step for a single session of the input JSON file.
session_name = 'Ses05'  # Edit this value to process another session
file_path = 'F:\\SLT_2024\\train_ensemble\\iemocap_script_allemo.json'
apply_ensemble_method_for_session(file_path=file_path, session_name=session_name)


Processed 10/1034 records in session Ses05
Processed 20/1034 records in session Ses05
Processed 30/1034 records in session Ses05
Processed 40/1034 records in session Ses05
Processed 50/1034 records in session Ses05
Processed 60/1034 records in session Ses05
Processed 70/1034 records in session Ses05
Processed 80/1034 records in session Ses05
Processed 90/1034 records in session Ses05
Processed 100/1034 records in session Ses05
Processed 110/1034 records in session Ses05
Processed 120/1034 records in session Ses05
Processed 130/1034 records in session Ses05
Processed 140/1034 records in session Ses05
Processed 150/1034 records in session Ses05
Processed 160/1034 records in session Ses05
Processed 170/1034 records in session Ses05
Processed 180/1034 records in session Ses05
Processed 190/1034 records in session Ses05
Processed 200/1034 records in session Ses05
Processed 210/1034 records in session Ses05
Processed 220/1034 records in session Ses05
Processed 230/1034 records in session Ses05

In [None]:
### Combine all processed sessions

In [None]:
import json
import os

# Define the order of the sessions; the combined file keeps this order
session_order = [
    "Ses01", "Ses02", "Ses03", "Ses04", "Ses05"
]

# Define the directory where the processed JSON files are stored
input_dir = "F:\\SLT_2024\\train_ensemble\\"

# Initialize a list to hold all the combined data
combined_data = []

for session_name in session_order:
    file_path = os.path.join(input_dir, f"{session_name}_processed.json")

    # Load the JSON data for the current session
    with open(file_path, 'r', encoding='utf-8') as file:
        session_data = json.load(file)

    # Add the session data to the combined data list
    combined_data.extend(session_data)

# Define the output file path for the combined JSON file. It is built from
# input_dir rather than repeating the absolute path: joining onto an absolute
# path discards the first argument, so input_dir was silently ignored here.
output_file_path = os.path.join(input_dir, "combined", "combined.json")

# Create the "combined" sub-directory if needed so the write cannot fail on
# a missing folder.
os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

with open(output_file_path, 'w', encoding='utf-8') as output_file:
    json.dump(combined_data, output_file, indent=4)

print(f"Combined JSON file created at: {output_file_path}")
