In [14]:
import pandas as pd
import os

def load_and_prepare_data(
    train_path,
    test_path,
    drop_cols=('tconst', 'originalTitle', 'directors', 'writers', 'top_actors'),
):
    """Load the train and test CSVs, stack them, and clean the release-year column.

    Parameters
    ----------
    train_path, test_path : str, path-like, or file-like
        Anything accepted by ``pd.read_csv``.
    drop_cols : sequence of str, optional
        Columns removed from the combined frame. Defaults to the columns this
        notebook does not use downstream. ``DataFrame.drop`` raises ``KeyError``
        if any of them is missing.

    Returns
    -------
    pd.DataFrame
        Train rows followed by test rows. Rows with a missing ``startYear`` are
        removed (the surviving rows keep their original positional labels), and
        ``startYear`` is cast to ``int``.
    """
    train_df = pd.read_csv(train_path)
    test_df = pd.read_csv(test_path)
    combined = pd.concat([train_df, test_df], ignore_index=True)
    # Build the result as one chain instead of assigning into a boolean-filtered
    # frame, which is the chained-assignment pattern pandas flags with
    # SettingWithCopyWarning. list(...) is needed because drop(columns=<tuple>)
    # would treat the tuple as a single label.
    return (
        combined
        .drop(columns=list(drop_cols))
        .loc[lambda d: d['startYear'].notna()]
        .assign(startYear=lambda d: d['startYear'].astype(int))
    )

def get_balanced_year_partitions(df, target_size=44226):
    """Greedily group consecutive release years into ranges of roughly equal row count.

    Years are visited in ascending order and their row counts accumulated. As
    soon as the running count reaches ``target_size``, the current range is
    closed at that year. Years left over at the end form one final range,
    which may be smaller than ``target_size``.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a ``startYear`` column.
    target_size : int, optional
        Minimum number of rows before a range is closed. The default 44226 is
        hard-coded; for the data in this notebook it is about one third of
        the rows.

    Returns
    -------
    list of tuple
        Inclusive ``(start_year, end_year)`` ranges in ascending order.
        Returns an empty list when ``df`` has no rows.
    """
    year_counts = df['startYear'].value_counts().sort_index()
    partitions = []
    range_start = None
    running_total = 0

    for year, count in year_counts.items():
        if range_start is None:
            range_start = year
        running_total += count
        if running_total >= target_size:
            partitions.append((range_start, year))
            range_start = None
            running_total = 0

    # Years that never reached target_size become a final, smaller range.
    if range_start is not None:
        partitions.append((range_start, year_counts.index[-1]))

    return partitions

def split_and_save(df, partitions, output_dir="", random_state=42):
    """Write one shuffled CSV per year range and print a summary line for each.

    For the ``i``-th range (1-based), the rows whose ``startYear`` lies in the
    inclusive interval ``[start, end]`` are shuffled and written to
    ``os.path.join(output_dir, f"movies_v{i}.csv")`` without the index.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a ``startYear`` column.
    partitions : iterable of (start, end)
        Inclusive year ranges, e.g. from ``get_balanced_year_partitions``.
    output_dir : str, optional
        Destination directory, created if it does not exist. ``""`` (the
        default) writes into the current working directory.
    random_state : int, optional
        Seed for the row shuffle, so the written files are reproducible.
        Defaults to 42, the value previously hard-coded.
    """
    if output_dir:
        # to_csv does not create missing parent directories.
        os.makedirs(output_dir, exist_ok=True)
    for i, (start, end) in enumerate(partitions, 1):
        partition_df = (
            df[df['startYear'].between(start, end)]
            .sample(frac=1, random_state=random_state)  # shuffle all rows
            .reset_index(drop=True)
        )
        out_path = os.path.join(output_dir, f"movies_v{i}.csv")
        partition_df.to_csv(out_path, index=False)
        print(f"Saved Version {i}: {start}–{end} → {out_path} ({len(partition_df)} movies)")

# Run the process: stack train/test, cut the years into N_VERSIONS ranges of
# roughly equal row count, and write one shuffled CSV per range.
TRAIN_PATH = "imdb_train_data.csv"
TEST_PATH = "imdb_test_data.csv"
N_VERSIONS = 3  # number of dataset versions to produce

df = load_and_prepare_data(TRAIN_PATH, TEST_PATH)
# Derive the per-version row target from the data instead of relying on the
# hard-coded default (44226). Ceiling division: every closed range holds at
# least this many rows, so the greedy split can never yield more than
# N_VERSIONS ranges (a plain floor could leave a tiny extra trailing range).
partitions = get_balanced_year_partitions(df, target_size=-(-len(df) // N_VERSIONS))
split_and_save(df, partitions)


Saved Version 1: 1894–1995 → movies_v1.csv (44893 movies)
Saved Version 2: 1996–2014 → movies_v2.csv (44806 movies)
Saved Version 3: 2015–2025 → movies_v3.csv (42981 movies)


In [15]:
# Inspect which columns remain after load_and_prepare_data dropped its unused columns.
df.columns

Index(['primaryTitle', 'startYear', 'runtimeMinutes', 'genres',
       'averageRating', 'numVotes'],
      dtype='object')