In [None]:
import os
from glob import glob

import pandas as pd
import rootutils
import torch
from torch_geometric.data import Data
from tqdm import tqdm

rootutils.setup_root("../", indicator=".project-root", pythonpath=True)
import data

In [None]:
# Subreddit groupings used to label the raw data. Each top-level key is a topic
# and maps to two lists of subreddit names under the keys "a" and "b"; a later
# cell looks each subreddit up in these lists to decide whether it is labelled
# "<topic> A" or "<topic> B".
# NOTE(review): some names occur under more than one topic (e.g. "AskMen" is in
# both Gender "a" and Age "a", "AskMenOver30" in Gender "a" and Age "b") --
# confirm that this overlap is intentional.
# The "Edgy" topic at the bottom is commented out, so it is not part of `groups`.
groups = {
    "Partisan A": {
        "a": [
            "democrats",
            "OpenChristian",
            "GamerGhazi",
            "excatholic",
            "EnoughLibertarianSpam",
            "AskAnAmerican",
            "lastweektonight",
        ],
        "b": [
            "Conservative",
            "progun",
            "TrueChristian",
            "Catholicism",
            "AskTrumpSupporters",
            "CGPGrey",
        ],
    },
    "Partisan B": {
        "a": [
            "hillaryclinton",
            "SandersForPresident",
            "askhillarysupporters",
            "BlueMidterm2018",
            "badwomensanatomy",
            "PoliticalVideo",
            "liberalgunowners",
            "GrassrootsSelect",
            "GunsAreCool",
        ],
        "b": [
            "The_Donald",
            "KotakuInAction",
            "HillaryForPrison",
            "AskThe_Donald",
            "PoliticalHumor",
            "ChoosingBeggars",
            "uncensorednews",
            "Firearms",
            "DNCleaks",
            "dgu",
        ],
    },
    "Affluence": {
        "a": [
            "vagabond",
            "hitchhiking",
            "DumpsterDiving",
            "almosthomeless",
            "AskACountry",
            "KitchenConfidential",
            "Nightshift",
            "alaska",
            "fuckolly",
            "FolkPunk",
        ],
        "b": [
            "backpacking",
            "hiking",
            "Frugal",
            "personalfinance",
            "travel",
            "Cooking",
            "fitbit",
            "CampingandHiking",
            "gameofthrones",
            "IndieFolk",
        ],
    },
    "Gender": {
        "a": [
            "AskMen",
            "TrollYChromosome",
            "AskMenOver30",
            "OneY",
            "TallMeetTall",
            "daddit",
            "ROTC",
            "FierceFlow",
            "malelivingspace",
            "predaddit",
        ],
        "b": [
            "AskWomen",
            "CraftyTrolls",
            "AskWomenOver30",
            "women",
            "bigboobproblems",
            "Mommit",
            "USMilitarySO",
            "HaircareScience",
            "InteriorDesign",
            "BabyBumps",
        ],
    },
    "Age": {
        "a": [
            "teenagers",
            "youngatheists",
            "teenrelationships",
            "AskMen",
            "saplings",
            "hsxc",
            "trackandfield",
            "bapccanada",
            "RedHotChiliPeppers",
        ],
        "b": [
            "RedditForGrownups",
            "TrueAtheism",
            "relationship_advice",
            "AskMenOver30",
            "eldertrees",
            "running",
            "trailrunning",
            "MaleFashionMarket",
            "canadacordcutters",
            "pearljam",
        ],
    },
    # "Edgy": {
    #     "a": [
    #         "memes",
    #         "watchpeoplesurvive",
    #         "MissingPersons",
    #         "twinpeaks",
    #         "pickuplines",
    #         "texts",
    #         "startrekgifs",
    #         "subredditoftheday",
    #         "peeling",
    #         "rapbattles",
    #     ],
    #     "b": [
    #         "ImGoingToHellForThis",
    #         "watchpeopledie",
    #         "MorbidReality",
    #         "TrueDetective",
    #         "MeanJokes",
    #         "FiftyFifty",
    #         "DaystromInstitute",
    #         "SRSsucks",
    #         "bestofworldstar",
    #     ],
    # },
}

# Fold both "Partisan B" lists into "Partisan A", then drop the "Partisan B"
# key so that `groups` holds a single partisan topic.
for side in ("a", "b"):
    groups["Partisan A"][side] += groups["Partisan B"][side]
del groups["Partisan B"]

In [None]:
from pprint import pprint

# Accumulators filled while scanning the raw files:
#   sizes_subreddit: "<topic> <A|B>" label -> {subreddit: line count of its
#                    "-data.json" file}
#   sizes_global:    "<topic> <A|B>" label -> sum of those counts
#   size_numbers:    subreddit -> list returned by get_true_size(subreddit)
sizes_subreddit, sizes_global, size_numbers = {}, {}, {}


def get_true_size(
    subreddit,
    processed_root="/scratch/rcohen/l2hebert/data/processed_files/processed",
):
    """Return the size of every processed graph stored for ``subreddit``.

    Each file under ``{processed_root}/{subreddit}/`` is loaded with
    ``torch.load`` and the length of its ``"in_degree"`` entry is recorded
    (later cells treat this value as the number of nodes in the graph).

    Args:
        subreddit (str): Name of the sub-folder holding the processed graphs.
        processed_root (str): Directory containing one sub-folder per
            subreddit. Defaults to the location that was previously
            hard-coded, so existing calls behave as before.

    Returns:
        list[int]: One size per file, ordered by sorted file path. The list
        is empty when the folder is missing or contains no files.
    """
    # glob() yields paths in arbitrary, filesystem-dependent order, but callers
    # use positions in the returned list as indices, so fix the order.
    # NOTE(review): sorting is lexicographic; if the file names carry a
    # non-zero-padded numeric index, sort on that number instead -- confirm.
    graph_files = sorted(glob(f"{processed_root}/{subreddit}/*"))

    sizes = []
    for file in tqdm(graph_files):
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only point this at files produced by a trusted pipeline.
        graph = torch.load(file, weights_only=False)
        sizes.append(len(graph["in_degree"]))
    return sizes


# Scan every raw "<subreddit>-data.json" file, record how many lines it has and
# the sizes of its processed graphs, and accumulate the counts under the
# "<topic> A" / "<topic> B" label the subreddit belongs to.
for file in glob("raw_files/*/*-data.json"):
    # Path layout is raw_files/<topic>/<subreddit>-data.json.
    topic_dir, file_name = os.path.split(file)
    topic = os.path.basename(topic_dir)
    subreddit = file_name.split("-")[0]

    # Every partisan folder is looked up under the merged "Partisan A" key.
    if "Partisan" in topic:
        topic = "Partisan A"

    # A folder whose topic is absent from `groups` (for example a topic that is
    # commented out there) would raise KeyError below; report it and move on.
    if topic not in groups:
        print(f"Topic {topic} of subreddit {subreddit} not found in groups")
        continue

    # Count lines without leaving the file handle open.
    with open(file) as handle:
        count = sum(1 for _ in handle)
    size_numbers[subreddit] = get_true_size(subreddit)

    if subreddit in groups[topic]["a"]:
        topic = f"{topic} A"
    elif subreddit in groups[topic]["b"]:
        topic = f"{topic} B"
    else:
        print(f"Subreddit {subreddit} in Topic {topic} not found in any group")
        continue

    # "Partisan A" was only the lookup key: turn "Partisan A A" / "Partisan A B"
    # into the final labels "Partisan A" / "Partisan B". The result must not be
    # rewritten a second time, since "Partisan B" is already the class-B label.
    if "Partisan A" in topic:
        topic = topic.replace("Partisan A", "Partisan")

    if topic not in sizes_subreddit:
        sizes_subreddit[topic] = {}
        sizes_global[topic] = 0

    sizes_subreddit[topic][subreddit] = count
    sizes_global[topic] += count


pprint(sizes_global)
pprint(sizes_subreddit)

In [None]:
# Subtract, for every subreddit, the number of graphs with fewer than 5 nodes
# from both the per-subreddit and the per-class counts.
# NOTE: this mutates the counts in place, so running the cell twice subtracts
# twice.
for subreddit, sizes in size_numbers.items():
    total_size = len(sizes)
    num_to_remove = sum(size < 5 for size in sizes)
    print(
        f"Subreddit {subreddit} ({total_size}) has {num_to_remove} graphs with less than 5 nodes"
    )
    # Adjust every class that lists this subreddit.
    for topic, subreddits in sizes_subreddit.items():
        if subreddit not in subreddits:
            continue
        subreddits[subreddit] -= num_to_remove
        sizes_global[topic] -= num_to_remove

In [None]:
# For each subreddit, the positions (within its size list) of the graphs that
# have fewer than 5 nodes.
indices_to_remove = {
    subreddit: [idx for idx, size in enumerate(sizes) if size < 5]
    for subreddit, sizes in size_numbers.items()
}

In [None]:
# Display the per-class totals as they stand after the small-graph counts were
# subtracted in the earlier cell.
sizes_global

In [None]:
import math
from copy import deepcopy


def calculate_balanced_ratios(
    top_level_sizes, sub_dataset_sizes, holdout_ratio=0.2
):
    """
    Calculates balancing ratios (upsampling or downsampling) for sub-datasets to balance a hierarchical dataset,
    accounting for imbalance between sub-datasets as well.

    The ratio of each sub-dataset is the product of three factors:
    1) A class factor that moves classes 'A' and 'B' of a topic towards their average size.
    2) A topic factor that moves the total size of every topic towards the average topic size.
    3) A sub-dataset factor that moves each sub-dataset towards the average sub-dataset size
       of its class, clipped to the range [0.75, 5] to avoid extreme resampling.

    Before anything else every size is scaled by ``1 - holdout_ratio`` and entries whose
    size is not positive are dropped. A factor whose denominator would be zero is taken
    to be 1.0 (no change).

    Args:
        top_level_sizes (dict): Dictionary of top-level dataset sizes, keyed by
                                 "<topic> <class>" where the class is the last
                                 whitespace-separated word ('A' or 'B').
                                 e.g., {'Affluence A': 9254, 'Affluence B': 36844, ...}
        sub_dataset_sizes (dict): Nested dictionary of sub-dataset sizes.
                                   e.g., {'Partisan A': {'liberalgunowners': 2389, 'excatholic': 260, ...},
                                         'Partisan B': {'Conservative': 5000, 'Republican': 6000, ...}, ...}
        holdout_ratio (float): Fraction of every dataset that is held out; the ratios
                               are computed on the remaining ``1 - holdout_ratio`` share.

    Returns:
        dict: A nested dictionary with the same structure as sub_dataset_sizes,
              but containing balancing ratios for each sub-dataset instead of sizes
              (sub-datasets with non-positive size are omitted).
              Ratios > 1 indicate upsampling, ratios < 1 indicate downsampling, ratio = 1 means no change.
              e.g., {'Partisan A': {'liberalgunowners': 0.8, 'excatholic': 1.5, ...}, ...}
    """

    train_fraction = 1 - holdout_ratio

    # Work on scaled copies; the caller's dictionaries are never modified.
    top_level_sizes = {
        key: value * train_fraction
        for key, value in top_level_sizes.items()
        if value > 0
    }
    sub_dataset_sizes = {
        key: {k: v * train_fraction for k, v in value.items() if v > 0}
        for key, value in sub_dataset_sizes.items()
    }

    # 1. Balance A and B classes within each top-level dataset towards the average size
    balanced_top_level_sizes = {}
    for topic_class, total_size in top_level_sizes.items():
        # Split on the last space only so topic names may contain spaces.
        topic, class_label = topic_class.rsplit(maxsplit=1)
        topic_data = balanced_top_level_sizes.setdefault(topic, {"A": 0, "B": 0})
        topic_data[class_label] = total_size

    for topic_data in balanced_top_level_sizes.values():
        combined_size = topic_data["A"] + topic_data["B"]
        avg_class_size = combined_size / 2 if combined_size > 0 else 0
        topic_data["A_target_size"] = avg_class_size
        topic_data["B_target_size"] = avg_class_size

    # 2. Balance top-level datasets to each other towards the average total size
    total_topic_sizes = [
        topic_data["A_target_size"] + topic_data["B_target_size"]
        for topic_data in balanced_top_level_sizes.values()
    ]
    avg_total_topic_size = (
        sum(total_topic_sizes) / len(total_topic_sizes)
        if total_topic_sizes
        else 0
    )
    for topic_data in balanced_top_level_sizes.values():
        topic_data["topic_target_size"] = avg_total_topic_size

    # 3. Calculate balancing ratios for sub-datasets, considering sub-dataset balance
    balancing_ratios = {}
    for topic_class, sub_datasets in sub_dataset_sizes.items():
        topic, class_label = topic_class.rsplit(maxsplit=1)
        balancing_ratios[topic_class] = {}

        # A class (or a whole topic) that was filtered out above because its
        # size was not positive has no entry here; treat its size as zero so
        # the "no change" fallbacks below apply instead of raising KeyError.
        current_class_size = top_level_sizes.get(topic_class, 0)
        topic_data = balanced_top_level_sizes.get(topic)
        if topic_data is None:
            target_class_size = 0
            topic_target_size = 0
            current_topic_size = 0
        else:
            target_class_size = topic_data[f"{class_label}_target_size"]
            topic_target_size = topic_data["topic_target_size"]
            # Topic size after A/B balancing.
            current_topic_size = (
                topic_data["A_target_size"] + topic_data["B_target_size"]
            )

        # Class level balancing ratio (for A/B balancing within topic)
        class_balancing_factor = (
            target_class_size / current_class_size
            if current_class_size > 0
            else 1.0
        )

        # Topic level balancing ratio (for balancing topics to each other)
        topic_balancing_factor = (
            topic_target_size / current_topic_size
            if current_topic_size > 0
            else 1.0
        )

        # Sub-dataset level balancing ratio (balance sub-datasets within the class).
        # The guard must test `values`, the list that is actually divided by.
        values = [size for size in sub_datasets.values() if size > 0]
        avg_sub_dataset_size = sum(values) / len(values) if values else 0

        for sub_dataset_name, sub_dataset_size in sub_datasets.items():
            if avg_sub_dataset_size == 0 or sub_dataset_size == 0:
                # Nothing to balance against (or nothing to scale): no change.
                sub_dataset_balancing_factor = 1.0
            else:
                # Clip the balancing factor to [0.75, 5] to avoid extreme
                # upsampling/downsampling.
                sub_dataset_balancing_factor = max(
                    0.75,
                    min(avg_sub_dataset_size / sub_dataset_size, 5),
                )

            balancing_ratios[topic_class][sub_dataset_name] = (
                class_balancing_factor
                * topic_balancing_factor
                * sub_dataset_balancing_factor
            )

    return balancing_ratios


# Compute the balancing ratios from the sizes collected so far. Deep copies are
# passed so the function only ever sees duplicates of the size dictionaries.
balancing_ratios_result = calculate_balanced_ratios(
    top_level_sizes=deepcopy(sizes_global),
    sub_dataset_sizes=deepcopy(sizes_subreddit),
)
pprint(balancing_ratios_result)

In [None]:
# Spot-check the ratios computed for the "Partisan A" class.
balancing_ratios_result["Partisan A"]

In [None]:
# Display the per-subreddit sizes that were passed to calculate_balanced_ratios.
sizes_subreddit

In [None]:
# Per-subreddit sizes of the "Partisan A" class, for comparison with its ratios.
sizes_subreddit["Partisan A"]

In [None]:
import json
import random

from sklearn.model_selection import train_test_split


def apply_balancing(sizes, ratios):
    """Build, resample and persist train/test/validation index splits.

    For every sub-dataset named in ``ratios`` the function

    1. rebuilds the list of usable sample indices: ``range(size + removed)``
       minus the indices listed for that sub-dataset in the module-level
       ``indices_to_remove`` (this assumes ``sizes`` already had the number of
       removed samples subtracted -- confirm against the cell that fills it);
    2. splits the indices 80/10/10 into train/test/validation with a fixed
       seed; when the 20% hold-out has fewer than 20 samples (or there are
       fewer than two samples in total) everything goes to train and the
       test/validation lists stay empty;
    3. resamples the train indices by the ratio, clipped to [0.5, 5]: with
       replacement when upsampling, without when downsampling;
    4. writes the result to ``raw_files/<topic folder>/<name>-split.json``.

    Args:
        sizes (dict): ``{topic_class: {sub_dataset_name: size}}``.
        ratios (dict): ``{topic_class: {sub_dataset_name: ratio}}``, e.g. the
            output of ``calculate_balanced_ratios``.

    Returns:
        dict: ``{topic_class: {sub_dataset_name: splits}}`` where ``splits``
        has the keys ``"old_train_idx"``, ``"train_idx"``, ``"test_idx"`` and
        ``"valid_idx"``, each mapping to a list of indices.

    Side effects:
        Writes one JSON file per sub-dataset, prints progress, and reseeds the
        global ``random`` generator with 42 for every sub-dataset.
    """
    balanced_data = {}
    for topic_class, sub_ratios in ratios.items():
        print(f"Balancing {topic_class}")
        print(sub_ratios.keys())
        balanced_data[topic_class] = {}

        # The output folder is the class label without its " A"/" B" suffix;
        # partisan data is stored under the "Partisan A" folder.
        topic_folder = topic_class[:-2]
        if topic_folder == "Partisan":
            topic_folder = "Partisan A"

        for sub_dataset_name, ratio in sub_ratios.items():
            print(
                f"Balancing {topic_class} {sub_dataset_name} with ratio {ratio}"
            )
            removed_indices = indices_to_remove[sub_dataset_name]
            # Set lookup keeps the filter below linear in the number of samples.
            removed_lookup = set(removed_indices)
            total_samples = sizes[topic_class][sub_dataset_name] + len(
                removed_indices
            )
            original_data = [
                idx for idx in range(total_samples) if idx not in removed_lookup
            ]

            if len(original_data) < 2:
                # train_test_split raises when it cannot leave at least one
                # sample on each side; keep everything for training instead.
                train, test_validation = list(original_data), []
            else:
                train, test_validation = train_test_split(
                    original_data, test_size=0.2, random_state=42
                )

            if len(test_validation) < 20:
                print(
                    f"Skipping validation split for {topic_class} {sub_dataset_name} due to insufficient data ({len(test_validation)} samples)"
                )
                train += test_validation
                test = []
                validation = []
            else:
                test, validation = train_test_split(
                    test_validation, test_size=0.5, random_state=42
                )

            original_size = len(train)
            ratio = max(
                0.5, min(ratio, 5)
            )  # Clip ratio to be between 0.5 and 5
            target_size = int(original_size * ratio)

            # Reseed per sub-dataset so each resampling is reproducible on its own.
            random.seed(42)

            if ratio > 1:  # Upsampling: draw with replacement
                new_train = random.choices(train, k=target_size)
            elif ratio < 1:  # Downsampling: draw without replacement
                new_train = random.sample(train, target_size)
            else:  # ratio == 1, no change
                new_train = train

            splits = {
                "old_train_idx": train,
                "train_idx": new_train,
                "test_idx": test,
                "valid_idx": validation,
            }
            # Context manager guarantees the file is flushed and closed.
            with open(
                f"raw_files/{topic_folder}/{sub_dataset_name}-split.json",
                "w",
            ) as split_file:
                json.dump(splits, split_file)
            balanced_data[topic_class][sub_dataset_name] = splits
    return balanced_data


# Build the index splits for every sub-dataset and write them to disk.
balanced_dataset = apply_balancing(
    sizes=sizes_subreddit, ratios=balancing_ratios_result
)

In [None]:
# Print how many indices each split holds rather than the indices themselves.
pprint(
    {
        topic_class: {
            sub_dataset_name: {
                split_name: len(indices)
                for split_name, indices in splits.items()
            }
            for sub_dataset_name, splits in sub_datasets.items()
        }
        for topic_class, sub_datasets in balanced_dataset.items()
    }
)

In [None]:
# NOTE(review): this displays every index list in full, which can produce very
# large output; the preceding cell already prints the size of each split.
balanced_dataset