In [1]:
import os
import pickle
from collections import Counter

import pycrfsuite
from tqdm import tqdm

In [2]:
class MyClass:
    """Minimal wrapper holding a single payload in ``param``.

    NOTE(review): the pickled sequence files read by this notebook are accessed
    via ``.param``, so they presumably are instances of this class. Keep the
    class name and the attribute name unchanged, or unpickling/reads will break.
    """

    def __init__(self, param):
        # Keep the payload as-is: no copy, no validation.
        self.param = param


def load_object(filename):
    """Unpickle and return the object stored in ``filename``.

    This is a best-effort loader: any exception raised while opening or
    unpickling the file (missing file, truncated data, a pickled class that
    cannot be resolved, ...) is caught. The error is printed together with the
    file name, and ``None`` is returned. Callers must check for ``None``
    before using the result.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only use this on files you produced yourself or otherwise trust.
    """
    try:
        with open(filename, "rb") as f:
            return pickle.load(f)
    except Exception as ex:
        # Name the file: this is called once per sequence file, so a message
        # without it cannot be traced back to the culprit.
        print("Error during unpickling object from {} (Possibly unsupported):".format(filename), ex)
        return None

In [3]:
def _load_param(path):
    """Unpickle ``path`` via ``load_object`` and return the object's ``.param``.

    ``load_object`` reports failures by printing and returning ``None``. This
    turns that into an immediate error that names the file, instead of an
    opaque ``AttributeError`` on ``None.param``.

    Raises:
        ValueError: if ``load_object`` returned ``None``.
    """
    obj = load_object(path)
    if obj is None:
        raise ValueError("Could not load pickled object: " + path)
    return obj.param


def _iter_batch_sequences(path_batch):
    """Yield ``(feature_seq, label_seq)`` for every sequence of one batch folder.

    Files are visited in sorted order of ``<path_batch>/sequence_multimodal``.
    For a file named ``...multimodal_<suffix>``, the per-position data comes
    from ``sequence_sequential_information/sequence_sequential_information_<suffix>``
    and the labels from ``sequence_labels/sequence_labels_<suffix>``. Each
    feature item is the multimodal dict merged with the sequential-information
    dict at the same position; the latter wins on key clashes.

    Entries whose name does not contain ``multimodal_`` (e.g. ``.DS_Store``,
    ``.ipynb_checkpoints``) are skipped with a warning.

    Raises:
        ValueError: if a file cannot be loaded, or if the three sequences of
            one suffix do not have the same length.
    """
    multimodal_dir = os.path.join(path_batch, "sequence_multimodal")
    for file_name in tqdm(sorted(os.listdir(multimodal_dir))):
        # partition() splits at the FIRST occurrence only, so suffixes that
        # themselves contain "multimodal_" are kept intact.
        _, found, suffix = file_name.partition("multimodal_")
        if not found:
            tqdm.write("Skipping unexpected file: " + os.path.join(multimodal_dir, file_name))
            continue

        multimodal_seq = _load_param(os.path.join(multimodal_dir, file_name))
        sequential_seq = _load_param(os.path.join(
            path_batch, "sequence_sequential_information",
            "sequence_sequential_information_" + suffix))
        label_seq = _load_param(os.path.join(
            path_batch, "sequence_labels", "sequence_labels_" + suffix))

        # Mismatched lengths would otherwise be silently truncated by zip.
        if not len(multimodal_seq) == len(sequential_seq) == len(label_seq):
            raise ValueError(
                "Length mismatch for '" + suffix + "' in " + path_batch
                + ": multimodal=" + str(len(multimodal_seq))
                + ", sequential=" + str(len(sequential_seq))
                + ", labels=" + str(len(label_seq)))

        feature_seq = [{**mm_item, **seq_item}
                       for mm_item, seq_item in zip(multimodal_seq, sequential_seq)]
        yield feature_seq, label_seq


def train_crf_model(path_to_batches, output_file_name, batch_names, crf_params=None):
    """Train a pycrfsuite CRF on pickled sequence batches and save the model.

    Args:
        path_to_batches: Root folder containing one sub-folder per batch.
        output_file_name: Path the trained model is written to by
            ``pycrfsuite.Trainer.train``.
        batch_names: Names of the batch sub-folders to load, in order. Each one
            must contain ``sequence_multimodal/``,
            ``sequence_sequential_information/`` and ``sequence_labels/``.
        crf_params: Optional dict merged over the default trainer parameters
            (``c1=1.0``, ``c2=1e-3``, ``max_iterations=100``,
            ``feature.possible_transitions=True``).

    Labels are counted per ``int(label)`` for the report, and passed to the
    trainer as ``str(label)``.

    Raises:
        ValueError: on unreadable files, mismatched sequence lengths, or when
            no training items were found at all.
    """
    trainer = pycrfsuite.Trainer(verbose=False)
    class_counts = Counter()

    for batch_name in batch_names:
        print("Loading batch : " + batch_name + " ...")
        for feature_seq, label_seq in _iter_batch_sequences(os.path.join(path_to_batches, batch_name)):
            class_counts.update(int(label) for label in label_seq)
            # Build a new list rather than mutating the unpickled label list.
            trainer.append(feature_seq, [str(label) for label in label_seq])

    nb_blocks = sum(class_counts.values())
    if nb_blocks == 0:
        raise ValueError("No training blocks found under " + path_to_batches)

    print("Nb blocks : " + str(nb_blocks))
    # Classes 0-3 are always listed (even when empty), as in the original
    # four-class report; any other label value encountered is listed as well.
    for class_id in sorted(set(range(4)) | set(class_counts)):
        print("\tNumber of blocks for class " + str(class_id) + " : " + str(class_counts[class_id]))

    print("Setting CRF parameters ...")
    params = {
        'c1': 1.0,
        'c2': 1e-3,
        'max_iterations': 100,
        'feature.possible_transitions': True,
    }
    if crf_params:
        params.update(crf_params)
    trainer.set_params(params)

    print("Training and saving model ...")
    trainer.train(output_file_name)

    print("Done.")


In [None]:
# Placeholder locations: edit before running. Every batch folder must contain
# the sequence_multimodal/, sequence_sequential_information/ and
# sequence_labels/ sub-folders read by train_crf_model.
BATCHES_DIR = '/path_to_batches_folder'
MODEL_PATH = '/path_to_models/model.crfsuite'
BATCH_NAMES = ['batch_1']  # TODO: list every batch folder to train on

train_crf_model(BATCHES_DIR, MODEL_PATH, BATCH_NAMES)