In [1]:
# default_exp series.train

# series.train
> Methods for training a `RandomForestClassifier` from `scikit-learn` to classify MRI series types.

In [3]:
#export
from dicomtools.basics import *
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, GroupKFold, GroupShuffleSplit, cross_val_score, RandomizedSearchCV, GridSearchCV
from pprint import pprint

# Seed NumPy's global RNG once, as a side effect of executing/importing this cell.
# The estimator and splitter built further down pass their own explicit `random_state`,
# so this only influences code that draws from the global NumPy RNG.
# NOTE(review): `np` is not imported explicitly in this cell — presumably it is
# re-exported by the star import from `dicomtools.basics`; confirm.
np.random.seed(42)

In [3]:
#export
def train_setup(df, preproc=True):
    """Split `df` into labeled training rows and 'unknown' rows, and integer-encode the training labels.

    With `preproc=True` the frame is first passed through `preprocess`, and the
    'plane', 'contrast' and 'seq_label' columns returned by `extract_labels` are
    joined onto it. With `preproc=False` a copy of `df` is used as-is, so it must
    already carry a 'seq_label' column.

    Returns `(train, test, y, y_names)`:
    - `train`: rows whose 'seq_label' is not 'unknown', index reset to 0..n-1
    - `test`: rows whose 'seq_label' is 'unknown', index reset to 0..n-1
    - `y`: integer codes from `pd.factorize(train['seq_label'])`
    - `y_names`: the unique labels, positioned so that `y_names[code]` is the label
    """
    if not preproc:
        frame = df.copy()
    else:
        frame = preprocess(df)
        extracted = extract_labels(frame)
        frame = frame.join(extracted[['plane', 'contrast', 'seq_label']])

    # Rows labeled 'unknown' have no usable target and are handed back as the test set.
    is_unknown = frame['seq_label'] == 'unknown'
    labeled = frame[~is_unknown].copy().reset_index(drop=True)
    unlabeled = frame[is_unknown].copy().reset_index(drop=True)

    codes, uniques = pd.factorize(labeled['seq_label'])
    return labeled, unlabeled, codes, uniques


In [10]:
#export
def train_fit(train, y, features, fname='model-run.skl'):
    """Fit a `RandomForestClassifier` on `train[features]` against `y`, save it to `fname`, and return it.

    - `train`: frame holding the feature columns
    - `y`: target values aligned with the rows of `train`
    - `features`: column selection used to index `train`
    - `fname`: destination handed to `dump` for the fitted model

    The classifier is built with `n_jobs=2` and `random_state=0`, so repeated
    fits on the same data use the same seed.
    """
    X = train[features]
    model = RandomForestClassifier(n_jobs=2, random_state=0)
    model.fit(X, y)

    # NOTE(review): `dump` is not imported in this cell — presumably joblib's `dump`
    # re-exported by the star import from `dicomtools.basics`; confirm.
    dump(model, fname)
    return model


In [1]:
#export
def train_setup_abdomen(df, cols=('patientID','exam','series'), preproc=False, need_labels=False,
                        test_size=.20, random_state=42):
    """Prepare `df` and split it into train/validation sets grouped by 'patientID'.

    - `df`: input frame; it is never modified (a copy or a preprocessed frame is used).
    - `cols`: key column(s) used to merge the extracted labels when `need_labels=True`.
    - `preproc`: if True, run `preprocess(df)` first; otherwise work on `df.copy()`.
    - `need_labels`: if True, merge the output of `extract_labels` onto the frame on `cols`.
    - `test_size`: share of the groups held out for validation (default 0.20).
    - `random_state`: seed for the shuffle split (default 42, giving a repeatable split).

    The frame that reaches the split must contain 'patientID', 'label_code' and
    'GT label' columns; a missing column raises `KeyError` from pandas.

    Returns `(train, val, y, y_names)`:
    - `train`, `val`: independent copies of the selected rows. The original index
      is kept (it is NOT reset), and rows are grouped by 'patientID' for the split.
    - `y`: the `train['label_code']` column.
    - `y_names`: the `train['GT label']` column (one entry per training row, not
      a list of unique names).
    """
    df1 = preprocess(df) if preproc else df.copy()

    if need_labels:
        labels = extract_labels(df1)
        # Normalize the keys to a list (a lone string stays a single key) so any
        # sequence passed as `cols` is handed to `merge` in its documented form.
        on = [cols] if isinstance(cols, str) else list(cols)
        df1 = df1.merge(labels, on=on)

    # Single grouped shuffle split: rows are assigned to train/val by 'patientID' group.
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    train_idx, val_idx = next(splitter.split(df1, groups=df1['patientID']))

    # The splitter yields positional indices, hence `.iloc`. `.copy()` hands callers
    # frames they can freely add columns to without SettingWithCopyWarning.
    train = df1.iloc[train_idx].copy()
    val = df1.iloc[val_idx].copy()
    y, y_names = train['label_code'], train['GT label']

    return train, val, y, y_names


