In [None]:
import csv
import ast
from tqdm import tqdm
from collections import defaultdict

# Input / output locations
loader_csv_path = "../csv/train_loader.csv"
orig_csv_path = "../csv/train_orig.csv"
output_csv_path = "../csv/train_loader_audio.csv"

# 1. Index every annotation row of train_orig.csv by its frame, i.e. by the
#    pair (video_id, timestamp rounded to 2 decimals). Each frame maps to the
#    list of annotations found for it, in file order.
frame_dict = defaultdict(list)
with open(orig_csv_path, newline='') as orig_file:
    orig_rows = csv.reader(orig_file)
    header = next(orig_rows)  # first line is the header; it is not used afterwards
    for row in orig_rows:
        # Columns read: 0 = video id, 1 = timestamp, 7 = entity id, 8 = label id
        frame_key = (row[0], round(float(row[1]), 2))
        frame_dict[frame_key].append(
            dict(entity_id=row[7], label_id=int(row[8]), row=row)
        )

# 2. Number of lines in train_loader.csv; used as the progress-bar total below
with open(loader_csv_path, 'r') as loader_file:
    total_lines = 0
    for _ in loader_file:
        total_lines += 1

# 3. Process train_loader.csv
#
# Build two lookups from frame_dict once, up front, instead of rescanning
# every annotation for every loader row (loader rows x annotations work):
#   * frame_label:   frame key -> 1 if any annotation in that frame has
#                    label_id == 1, otherwise 0
#   * entity_frames: entity_id -> frame keys of its annotations, in the order
#                    they appear when iterating frame_dict
# A frame key is the same (video_id, rounded timestamp) tuple used by frame_dict.
frame_label = {}
entity_frames = defaultdict(list)
for frame_key, frame_entries in frame_dict.items():
    frame_label[frame_key] = int(
        any(entry["label_id"] == 1 for entry in frame_entries)
    )
    for entry in frame_entries:
        entity_frames[entry["entity_id"]].append(frame_key)

rebuilt_rows = []
with open(loader_csv_path, newline='') as f:
    reader = csv.reader(f, delimiter='\t')
    for row in tqdm(reader, total=total_lines, desc="Processing entities"):
        entity_id = row[0]
        # Column 3 holds the label list as a Python literal; only its length
        # is needed here, to sanity-check the rebuilt label list.
        num_frames_loader = len(ast.literal_eval(row[3]))

        # .get() so that unknown entities do not get inserted into the defaultdict
        entity_keys = entity_frames.get(entity_id)
        if not entity_keys:
            print(f"Warning: {entity_id} not found in train_orig.csv")
            continue

        # Order this entity's frames by timestamp. sorted() is stable, so
        # frames with equal timestamps keep their original relative order.
        ordered_keys = sorted(entity_keys, key=lambda key: key[1])

        # One label per annotation of the entity: the frame-level label
        new_labels = [frame_label[key] for key in ordered_keys]

        # Check if frame counts match (the row is still written on mismatch)
        if len(new_labels) != num_frames_loader:
            print(f"Frame count mismatch for {entity_id}: loader={num_frames_loader}, rebuilt={len(new_labels)}")

        # Build new row: same columns as the loader row, with labels replaced
        new_row = [
            entity_id,
            row[1],          # keep original column
            row[2],          # keep another column
            new_labels,
            row[4] if len(row) > 4 else ""  # optional column
        ]
        rebuilt_rows.append(new_row)

# 4. Write all rebuilt rows to the output file, tab-separated
with open(output_csv_path, 'w', newline='') as out_file:
    csv.writer(out_file, delimiter='\t').writerows(rebuilt_rows)

print(f"\nRebuilt CSV saved to {output_csv_path}")


In [None]:
import pandas as pd
import ast

# Input / output locations
vad_path = "../csv/val_loader_vad.csv"
audio_path = "../csv/val_loader_audio.csv"
output_path = "../csv/val_loader_merged.csv"

# Both inputs are headerless, tab-separated files
vad_df = pd.read_csv(vad_path, sep="\t", header=None)
audio_df = pd.read_csv(audio_path, sep="\t", header=None)

# Attach the same five column names to each frame
loader_columns = ["trackid", "col2", "col3", "labels", "col5"]
for frame in (vad_df, audio_df):
    frame.columns = list(loader_columns)

# The labels column is stored as a Python list literal; parse it into a list
for frame in (vad_df, audio_df):
    frame["labels"] = frame["labels"].apply(ast.literal_eval)

# Lookup table: trackid -> parsed audio label list
audio_dict = {
    tid: labels for tid, labels in zip(audio_df["trackid"], audio_df["labels"])
}

# Update VAD labels
def merge_labels(row):
    """Merge one track's VAD labels with its audio labels.

    For every frame where the VAD label is 0 and the audio label is 1, the
    result is 1; every other frame keeps its VAD label.

    Args:
        row: Mapping-like row with a "trackid" key and a "labels" key holding
            the VAD label list.

    Returns:
        The VAD label list itself when the track has no entry in the
        module-level ``audio_dict``; otherwise a new list with the same length
        as the VAD label list.
    """
    trackid = row["trackid"]
    vad_labels = row["labels"]
    if trackid not in audio_dict:
        return vad_labels

    audio_labels = audio_dict[trackid]
    merged = [1 if (v == 0 and a == 1) else v for v, a in zip(vad_labels, audio_labels)]
    # zip() stops at the shorter sequence. When the audio list is shorter,
    # keep the remaining VAD labels instead of silently dropping those frames.
    merged.extend(vad_labels[len(merged):])
    return merged

# Merge row by row, then overwrite the labels column with the result
merged_series = vad_df.apply(merge_labels, axis=1)
vad_df["labels"] = merged_series

# Write out tab-separated, without header or index (same layout as the inputs)
vad_df.to_csv(output_path, sep="\t", header=False, index=False)
print(f"Merging completed, result saved to {output_path}")


In [None]:
import pandas as pd
import ast
from tqdm import tqdm

# Input / output locations
merged_path = "../csv/val_loader_merged.csv"
audio_path = "../csv/val_loader_audio.csv"
output_path = "../csv/val_loader_merged_refined.csv"

# Both inputs are headerless, tab-separated files
merged_df = pd.read_csv(merged_path, sep="\t", header=None)
audio_df = pd.read_csv(audio_path, sep="\t", header=None)

# Name the columns so they can be addressed by label
loader_columns = ["trackid", "col2", "col3", "labels", "col5"]
for frame in (merged_df, audio_df):
    frame.columns = list(loader_columns)

# The labels column is stored as a Python list literal; parse it into a list
for frame in (merged_df, audio_df):
    frame["labels"] = frame["labels"].apply(ast.literal_eval)

# Lookup table: trackid -> parsed audio label list
audio_dict = {
    tid: labels for tid, labels in zip(audio_df["trackid"], audio_df["labels"])
}

# Running total of labels flipped by refine_labels (updated via `global`)
adjust_count = 0

def refine_labels(row, max_seg_len=6):
    """Clear short runs where the merged label is 1 but the audio label is 0.

    Scans the track's labels for maximal runs of consecutive frames with
    merged label 1 and audio label 0. Runs of at most ``max_seg_len`` frames
    are set to 0; longer runs are left as they are. The number of labels
    changed is added to the module-level ``adjust_count``.

    Only frames present in both sequences are compared. If the audio list is
    shorter than the merged list, the extra merged frames are left untouched
    (a run is measured only up to the end of the shorter list).

    Args:
        row: Mapping-like row with a "trackid" key and a "labels" key holding
            the merged label list.
        max_seg_len: Longest run, in frames, that is still cleared.

    Returns:
        The original label list when the track has no entry in the
        module-level ``audio_dict``; otherwise a refined copy (the input list
        is not modified).
    """
    global adjust_count
    tid = row["trackid"]
    merged_labels = row["labels"]
    if tid not in audio_dict:
        return merged_labels

    audio_labels = audio_dict[tid]
    merged_labels = merged_labels[:]  # Copy, so the caller's list stays intact

    # Bound the scan to the overlap; indexing audio_labels past its end
    # would raise IndexError when it is shorter than merged_labels.
    limit = min(len(merged_labels), len(audio_labels))

    i = 0
    while i < limit:
        if merged_labels[i] == 1 and audio_labels[i] == 0:
            # Found the start of a mismatch segment; advance to its end
            start = i
            while i < limit and merged_labels[i] == 1 and audio_labels[i] == 0:
                i += 1
            seg_len = i - start  # i is the exclusive end of the segment
            if seg_len <= max_seg_len:
                merged_labels[start:i] = [0] * seg_len
                adjust_count += seg_len
        else:
            i += 1
    return merged_labels

# Register progress_apply on pandas objects, then refine every row with a
# progress bar and overwrite the labels column with the result
tqdm.pandas(desc="Refining labels")
refined_series = merged_df.progress_apply(refine_labels, axis=1)
merged_df["labels"] = refined_series

# Write out tab-separated, without header or index (same layout as the inputs)
merged_df.to_csv(output_path, sep="\t", header=False, index=False)

print(f"Refinement completed, new file saved to {output_path}")
print(f"A total of {adjust_count} labels were adjusted")
