# GTZAN dataset split
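This notebook builds a stratified, track-level train/val/test split of the GTZAN dataset and saves it as CSV files. Assigning whole tracks to a split up front means that, when the audio is later cut into segments, every segment of a track stays in the same split and nothing leaks between train, validation and test.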


In [10]:
import os
from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split

# --- Configuration -----------------------------------------------------------
# DATA_ROOT must contain one sub-directory per genre, each holding .wav tracks.
# Set the GTZAN_ROOT environment variable to point at your own copy of the
# dataset; the original author's local path is kept only as the fallback.
DATA_ROOT = Path(os.environ.get(
    'GTZAN_ROOT',
    r'C:\sem 07\sem07 frontend development\e20-co542-classical-music-classification\code\GTZN\Data\genres_original',
))
SPLIT_DIR = Path('..') / 'data' / 'splits'
SPLIT_DIR.mkdir(parents=True, exist_ok=True)

SEED = 42            # used for both splits so the partition is reproducible
TEST_FRACTION = 0.2  # share of all tracks held out for testing
VAL_FRACTION = 0.2   # share of the remaining (non-test) tracks used for validation

# --- Inventory: one row per .wav file ----------------------------------------
if not DATA_ROOT.is_dir():
    raise FileNotFoundError(f'Dataset root not found: {DATA_ROOT}')

items = []
for genre_dir in sorted(p for p in DATA_ROOT.iterdir() if p.is_dir()):
    label = genre_dir.name  # the genre is the name of the containing folder
    for wav in sorted(genre_dir.glob('*.wav')):
        items.append({'path': str(wav.resolve()), 'label': label, 'track_id': wav.stem})

df = pd.DataFrame(items)
if df.empty:
    # A regular exception (instead of SystemExit) gives a normal traceback in
    # the notebook and stops "Run All" cleanly.
    raise FileNotFoundError(f'No files found under {DATA_ROOT}')
if df['path'].duplicated().any():
    # Rows are later assigned to splits by path, so the same resolved path
    # appearing twice could place one file in two splits.
    raise ValueError(f'Duplicate track paths found under {DATA_ROOT}')

# --- Stratified split: hold out test first, then carve val out of the rest ---
trainval_paths, test_paths, y_trainval, y_test = train_test_split(
    df['path'], df['label'],
    test_size=TEST_FRACTION, random_state=SEED, stratify=df['label'],
)
train_paths, val_paths, y_train, y_val = train_test_split(
    trainval_paths, y_trainval,
    test_size=VAL_FRACTION, random_state=SEED, stratify=y_trainval,
)

def build_split(paths, labels, name, frame=None):
    """Select the rows of the track table that belong to one split.

    Parameters
    ----------
    paths : list-like of str
        Values of the ``path`` column that make up this split.
    labels : object
        Not used by this function. It is accepted so existing call sites that
        pass a label collection alongside ``paths`` keep working; the labels
        in the result come from ``frame`` itself.
    name : str
        Value written to the new ``split`` column (e.g. ``'train'``).
    frame : pandas.DataFrame, optional
        Table to select from; must have a ``path`` column. Defaults to the
        module-level ``df`` so existing three-argument calls behave as before.

    Returns
    -------
    pandas.DataFrame
        A copy of the matching rows, in the same order as in ``frame``, with
        an added ``split`` column. ``frame`` itself is not modified.
    """
    if frame is None:
        frame = df  # fall back to the notebook-wide track table
    split = frame[frame['path'].isin(paths)].copy()
    split['split'] = name
    return split

# Materialise each split as a DataFrame carrying a 'split' column.
df_train = build_split(train_paths, y_train, 'train')
df_val = build_split(val_paths, y_val, 'val')
df_test = build_split(test_paths, y_test, 'test')

splits = {'train': df_train, 'val': df_val, 'test': df_test}

# Leakage guard: no track path may appear in more than one split, and every
# track in df must be assigned to exactly one of them.
seen_paths = set()
for split_name, split_df in splits.items():
    overlap = seen_paths.intersection(split_df['path'])
    assert not overlap, f'{len(overlap)} track(s) in {split_name} also appear in an earlier split'
    seen_paths.update(split_df['path'])
assert sum(len(split_df) for split_df in splits.values()) == len(df), \
    'splits do not cover every track exactly once'

# One CSV per split: gtzan_train.csv, gtzan_val.csv, gtzan_test.csv.
for split_name, split_df in splits.items():
    split_df.to_csv(SPLIT_DIR / f'gtzan_{split_name}.csv', index=False)

print('Saved splits to', SPLIT_DIR)
# Per-genre counts for each split, to confirm the stratification.
for split_df in splits.values():
    print(split_df['label'].value_counts())


Saved splits to ..\data\splits
label
blues        64
classical    64
country      64
disco        64
hiphop       64
jazz         64
metal        64
pop          64
reggae       64
rock         64
Name: count, dtype: int64
label
blues        16
classical    16
country      16
disco        16
hiphop       16
jazz         16
metal        16
pop          16
reggae       16
rock         16
Name: count, dtype: int64
label
blues        20
classical    20
country      20
disco        20
hiphop       20
jazz         20
metal        20
pop          20
reggae       20
rock         20
Name: count, dtype: int64
