In [2]:
import os
import sys
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold

In [2]:
def load_data_for_split(file_path):
    texts, labels = [], []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            text = line.replace(' ', '')
            label = []
            i = 0
            while (i < len(text)):
                if text[i] in ['，', '。', '？', '！']:
                    label.append('1')
                else:
                    label.append('0')
                i += 1
            text_list = list(text)
            # label[-1] = 2
            # print(text, label)
            # break
            texts.append(text_list)
            labels.append(''.join(label))

            if (len(text_list) != len(label)):
                print('Error:', text, label)
    return texts, labels

In [3]:
from sklearn.model_selection import train_test_split

def recursive_split(data, n_splits, idx=0):
    print(f'Splitting data_{idx} with {len(data)} samples')
    if idx >= n_splits:
        with open(f'data/split6/data_{idx - n_splits}.txt', 'w', encoding='utf-8') as file:
            for text in data['sentence'].to_list():
                file.write(''.join(text) + '\n')
        return
    train_data, test_data = train_test_split(data, test_size=0.5, random_state=42, stratify=data['label'])
    # print(f'Lenght of train_data: {len(train_data)}')
    # print(f'Lenght of test_data: {len(test_data)}')
    recursive_split(train_data, n_splits, idx * 2)
    recursive_split(test_data, n_splits, idx * 2 + 1)

In [4]:
sentences, labels = load_data_for_split('data/train_large_2.txt')

data = pd.DataFrame({'sentence': sentences, 'label': labels})

In [5]:
label_counts = data['label'].value_counts()
valid_labels = label_counts[label_counts >= 4096].index
filtered_data = data[data['label'].isin(valid_labels)]

In [6]:
print(f'Original data with {len(data)} samples')
print(f'Filtered data with {len(filtered_data)} samples')

Original data with 592088 samples
Filtered data with 413525 samples


In [None]:
recursive_split(filtered_data, 4096, 1)

In [3]:
# Validate the class distribution in the splits
for i in range(4096):
    data = pd.read_csv(f'data/split6/data_{i}.txt', header=None, names=['sentence'])
    data['label'] = data['sentence'].apply(lambda x: ''.join(['1' if c in ['，', '。', '？', '！'] else '0' for c in x]))
    print(f'Class distribution in data_{i}:')
    print(data['label'].value_counts())
    print('')

Class distribution in data_0:
0000100000001               39
0000100001                  10
00000100000001              10
000010000100000001           8
00001000001                  5
000010000001                 4
0000010000100000001          3
00000100001                  2
000000010000100000001        2
0000010000001                2
0001000100000001             2
000010000100000100000001     2
0000100000100000001          2
0000000100001                2
000001000001                 2
000000100000001              2
0000000100000001             2
00000100000100000001         1
Name: label, dtype: int64

Class distribution in data_1:
0000100000001               39
00000100000001              11
0000100001                  11
000010000100000001           9
00001000001                  5
000010000001                 5
0000010000100000001          3
0000000100001                3
00000100001                  3
0000000100000001             2
00000100000100000001         2
00000010000000

KeyboardInterrupt: 