In [None]:
%load_ext autoreload
%autoreload 2
import traceback
import xml.etree.ElementTree as ET
from os.path import join

import numpy as np
import pandas as pd
import xmltodict
from scipy.io import arff
from tqdm import tqdm

import imodels
import imodels.util.data_util

In [None]:
# 'multitask.csv' is a saved copy of the dataset overview table. The
# commented-out snippet below shows how it can be regenerated:
#   html = pd.read_html('https://www.uco.es/kdis/mllresources/#EnronDesc')
#   df = html[0].iloc[:, :-1]                            # drop last column
#   df.columns = [col[0] for col in df.columns.values]   # multiindex -> single index
#   df.to_csv('multitask.csv')
ovw = pd.read_csv('multitask.csv')
# Lower-cased dataset names; later cells match folder and file names against these.
vals = ovw['Dataset'].str.lower().values

In [None]:
import os

# Each dataset has to be downloaded by hand, in Mulan format, into its own
# sub-folder of 'dsets'; only directories are picked up here.
dsets = sorted(
    entry for entry in os.listdir('dsets')
    if os.path.isdir(join('dsets', entry))
)
# Destination folder for the converted CSV files.
os.makedirs('processed', exist_ok=True)

In [None]:
def arff_to_df(arff_file, xml_file):
    """Load an ARFF dataset and tag its label columns with a ``__target`` suffix.

    Parameters
    ----------
    arff_file : str or file-like
        ARFF data, passed unchanged to ``scipy.io.arff.loadarff``.
    xml_file : str
        Path of the XML file whose ``labels -> label`` entries carry the
        target column names in their ``name`` attribute.

    Returns
    -------
    pandas.DataFrame
        One column per ARFF attribute. Columns named in the XML are renamed
        to ``<name>__target``; all other columns and all values are left
        exactly as loaded.

    Raises
    ------
    ValueError
        If a label named in the XML is not a column of the ARFF data.
    """
    data, _meta = arff.loadarff(arff_file)

    with open(xml_file, 'r') as file:
        # Parse the XML file into a dictionary
        parsed = xmltodict.parse(file.read())

    labels = parsed['labels']['label']
    # xmltodict yields a single dict (not a list) when an element has only one
    # child of a given name; normalise so a one-label file is handled too.
    if not isinstance(labels, list):
        labels = [labels]
    targets = {label['@name'] for label in labels}

    df = pd.DataFrame(data)
    missing = sorted(targets.difference(df.columns))
    if missing:
        raise ValueError(
            f'labels listed in the XML but absent from the ARFF data: {missing}')

    # append __target to each target column
    df.columns = [
        f'{col}__target' if col in targets else col for col in df.columns]
    return df


# Convert every downloaded dataset folder into 'processed/<name>.csv'.
# A failure in one dataset is reported and the loop moves on to the next.
for dset in tqdm(dsets):
    try:
        # The file lookup lives inside the try so that a folder without an
        # .arff or .xml file is reported and skipped instead of aborting the run.
        files = os.listdir(join('dsets', dset))
        arff_file = [f for f in files if f.endswith('.arff')][0]
        xml_file = [f for f in files if f.endswith('.xml')][0]

        dset_name = dset.replace("_Mulan", '').lower()
        # Use whichever separator spelling ('-' or '_') the overview table uses.
        if dset_name.replace('_', '-') in vals:
            dset_name = dset_name.replace('_', '-')
        elif dset_name.replace('-', '_') in vals:
            dset_name = dset_name.replace('-', '_')

        out_file = join('processed', f'{dset_name}.csv')
        # Skip datasets that were already converted on a previous run.
        if not os.path.exists(out_file):
            df = arff_to_df(join('dsets', dset, arff_file),
                            join('dsets', dset, xml_file))
            df.to_csv(out_file, index=False)
    except Exception:
        # `Exception` (not a bare except) keeps Ctrl-C / kernel interrupt working.
        print(f'Failed to process {dset}')
        # Show the reason so failures can actually be diagnosed.
        traceback.print_exc()

Process byte strings: remove the `b'...'` wrappers from string values in the processed CSVs

In [None]:
# CSV files found in 'processed', in sorted order.
processed_files = sorted(
    fname for fname in os.listdir('processed') if fname.endswith('.csv'))
# Folder that receives the cleaned copies.
os.makedirs('processed_clean', exist_ok=True)


def convert_byte_strings(arr):
    """Remove the textual bytes wrapper (``b'...'``) from string elements.

    Parameters
    ----------
    arr : array-like
        Values to clean, e.g. one DataFrame column.

    Returns
    -------
    numpy.ndarray
        Array in which every string of the form ``b'...'`` or ``b"..."`` has
        been replaced by the text between the quotes. All other elements are
        passed through unchanged. Empty input is returned as an empty array.
    """
    def decode_if_bytes(s):
        # Slice the wrapper off rather than calling str.strip("b'"): strip
        # treats its argument as a set of characters and would also remove
        # leading/trailing b and ' from the payload ("b'bob'" -> "o").
        if (isinstance(s, str) and len(s) >= 3 and s[0] == 'b'
                and s[1] in ('"', "'") and s[-1] == s[1]):
            return s[2:-1]
        return s

    values = np.asarray(arr)
    if values.size == 0:
        # np.vectorize cannot infer an output type from size-0 input.
        return values
    vectorized_func = np.vectorize(decode_if_bytes)
    return vectorized_func(arr)


# Read each converted CSV, clean it column by column, and write the result
# under the same file name in 'processed_clean'.
for csv_name in tqdm(processed_files):
    cleaned = pd.read_csv(join('processed', csv_name)).apply(
        convert_byte_strings)
    cleaned.to_csv(join('processed_clean', csv_name), index=False)

### Manually rename the CSVs, then check that their names match the main CSV

In [None]:
# List the CSVs currently in 'processed' and report how many there are.
processed_files = list(
    filter(lambda fname: fname.endswith('.csv'), os.listdir('processed')))
print(f'Processed {len(processed_files)} datasets')

# Every file name (minus '.csv') must be one of the lower-cased overview names;
# the offending name is used as the assertion message.
dset_names_processed = [fname.replace('.csv', '') for fname in processed_files]
for stem in dset_names_processed:
    assert stem in vals, stem

In [None]:
# Keep only the overview rows whose lower-cased name was processed, drop the
# saved index column, and renumber the rows.
in_processed = ovw['Dataset'].str.lower().isin(dset_names_processed)
ovw_filt = (
    ovw[in_processed]
    .drop(columns=['Unnamed: 0'])
    .reset_index()
    .drop(columns=['index'])
)
# Store the names lower-cased so they line up with the file names.
ovw_filt['Dataset'] = ovw_filt['Dataset'].str.lower()

In [None]:
# Render the filtered overview as a Markdown table and print it as text.
overview_markdown = ovw_filt.to_markdown()
print(overview_markdown)

# See if the new data can be accessed

In [None]:
# Smoke test: load one multitask dataset by name and print the array shapes.
loaded = imodels.get_clean_dataset(
    'water-quality_multitask', return_target_col_names=True)
X, y, feature_names, target_col_names = loaded
print('shapes', X.shape, y.shape)

In [None]:
# Try to load every registered multitask dataset and, for each one, print how
# many distinct values each target column of y contains.
names = imodels.util.data_util.DSET_MULTITASK_NAMES
# names = ['corel16k001']  # uncomment to check a single dataset
for name in tqdm(names):
    try:
        X, y, feature_names, target_col_names = imodels.get_clean_dataset(
            name + '_multitask', return_target_col_names=True, convertna=False)
        print('unique labels in each target of np array y', [
              len(set(y[:, i])) for i in range(y.shape[1])])
    except Exception:
        # `Exception` rather than a bare except, so that interrupting the
        # kernel (KeyboardInterrupt) still stops the loop.
        print('failed', name)
        traceback.print_exc()