In [4]:
!pip install flair



In [5]:
# Imports
import csv, os
import pandas as pd
import numpy as np
from flair.data import Corpus
from flair.datasets import CSVClassificationCorpus
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Path

from sklearn.model_selection import train_test_split

In [8]:
# Get datasets: train / validation / test splits from the noharm-ai/substance-prediction
# repository. Each CSV has the input text column `medicamento` followed by the label
# column `substancia`.
# NOTE(review): `refs/heads/main` tracks a moving branch; pin a commit SHA in the URL
# if results must be reproducible.
DATASET_BASE_URL = "https://raw.githubusercontent.com/noharm-ai/substance-prediction/refs/heads/main/datasets"
EXPECTED_COLUMNS = ["medicamento", "substancia"]

train_file_path = f"{DATASET_BASE_URL}/med_all_bio_train.csv"  # ARTHUR / OLIMAR
val_file_path = f"{DATASET_BASE_URL}/med_all_bio_val.csv"  # ARTHUR / OLIMAR
test_file_path = f"{DATASET_BASE_URL}/med_all_bio_test.csv"  # DÉBORA / OLIMAR / FILIPE

train_file = pd.read_csv(train_file_path)
val_file = pd.read_csv(val_file_path)
test_file = pd.read_csv(test_file_path)

# The Flair corpus built later maps CSV columns by position (0 -> text, 1 -> label),
# so fail fast if any split does not start with the expected columns.
for split_name, split_frame in {"train": train_file, "val": val_file, "test": test_file}.items():
    assert list(split_frame.columns[:2]) == EXPECTED_COLUMNS, (
        f"{split_name} split has columns {list(split_frame.columns)}; "
        f"expected {EXPECTED_COLUMNS} first"
    )

# Last expression -> rich notebook display instead of plain-text print().
train_file.head()

                                         medicamento  \
0  NPT - ZINCO+CUPRICO+MANGANES+CROMICO SOL INJ A...   
1                      AMOXICILINA SODICA 500MG CAPS   
2                              ROSUVASTATINA 10MG CP   
3                          Indometacina Cápsula 25mg   
4                             Iopamiron 370mg - 50ml   

                                          substancia  
0  SULFATO DE ZINCO HEPTAIDRATADO + SULFATO CUPRI...  
1                                        AMOXICILINA  
2                              ROSUVASTATINA CALCICA  
3                                       INDOMETACINA  
4                                          IOPAMIDOL  


In [10]:
# Re-split the original validation file. A small held-out share becomes the new dev
# set (`val_file_2`); the remainder (`train_data_2`) is appended to the training data
# in a later cell.
VAL_HOLDOUT_FRACTION = 0.067
SPLIT_SEED = 42

train_data_2, val_file_2 = train_test_split(
    val_file,
    test_size=VAL_HOLDOUT_FRACTION,
    random_state=SPLIT_SEED,
    shuffle=True,
)
train_data_2.describe()

Unnamed: 0,medicamento,substancia
count,14176,14176
unique,13743,1511
top,"CLONAZEPAM 2,5MG/ML","CLORETO DE SODIO 0,9% (SORO FISIOLOGICO)"
freq,8,206


In [11]:
# Sanity-check the reduced dev set: row count, distinct values, most frequent entries.
dev_split_summary = val_file_2.describe()
dev_split_summary

Unnamed: 0,medicamento,substancia
count,1019,1019
unique,1017,531
top,MICOFENOLATO DE MOFETILA 50MG/ML SOL ORAL 50ML,"CLORETO DE SODIO 0,9% (SORO FISIOLOGICO)"
freq,2,14


In [20]:
# Persist the splits to disk so Flair's CSVClassificationCorpus can read them.
# Files go to a directory relative to the working directory instead of the filesystem
# root (`/`), which usually requires root privileges to write to.
DATA_DIR = Path("data")
DATA_DIR.mkdir(parents=True, exist_ok=True)

train_file_path_2 = DATA_DIR / "bio_train_2.csv"  # NOTE(review): never written or read below
train_file_path_all = DATA_DIR / "bio_train_all.csv"
val_file_path_2 = DATA_DIR / "bio_val_2.csv"
test_file_path_2 = DATA_DIR / "bio_test_2.csv"

# Training set = original train split + the part of the original val split that was
# not kept as the new dev set.
train_file_full = pd.concat([train_file, train_data_2], ignore_index=True)

# index=False: the files must contain only the data columns, in order, because the
# corpus maps columns by position.
train_file_full.to_csv(train_file_path_all, index=False)
val_file_2.to_csv(val_file_path_2, index=False)
test_file.to_csv(test_file_path_2, index=False)

In [18]:
# Summary of the combined training set (original train + re-purposed part of val).
combined_train_summary = train_file_full.describe()
combined_train_summary

Unnamed: 0,medicamento,substancia
count,78972,78972
unique,69314,2173
top,"CLONAZEPAM 2,5MG/ML","CLORETO DE SODIO 0,9% (SORO FISIOLOGICO)"
freq,39,1110


In [None]:
# 1. Label type, defined once and reused for the corpus, dictionary and classifier.
label_type = "substancia"

# 2. Build the Flair corpus.
#    Columns: text (medicamento) and label (substancia), mapped by column index.
column_name_map = {0: "text", 1: "label"}

corpus = CSVClassificationCorpus(
    data_folder=".",  # split paths below are joined onto this folder (absolute paths win)
    column_name_map=column_name_map,
    train_file=train_file_path_all,
    test_file=test_file_path_2,
    dev_file=val_file_path_2,
    label_type=label_type,
    # The CSVs were written by pandas with a header row; without this, the header
    # line "medicamento,substancia" is read as a sample labelled "substancia".
    skip_header=True,
    delimiter=",",
)

# 3. Build the label dictionary for `label_type` from the corpus.
label_dict = corpus.make_label_dictionary(label_type=label_type)

# 4. Portuguese BERT document embeddings; fine_tune=True makes the transformer
#    weights trainable.
document_embeddings = TransformerDocumentEmbeddings("neuralmind/bert-base-portuguese-cased", fine_tune=True)

# 5. Text classifier on top of the document embeddings.
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)

# 6. Trainer.
trainer = ModelTrainer(classifier, corpus)

# 7. Fine-tune. Because the transformer weights are trainable, use Flair's documented
#    fine-tuning entry point with a small learning rate. The previous train() call used
#    SGD at lr=0.1 for up to 200 epochs, which is far too aggressive for pretrained
#    transformer weights. (`patience` is specific to train()'s annealing and is not
#    used by fine_tune().)
MODEL_DIR = Path("substances-classification-with-flair")
LEARNING_RATE = 5e-5
MINI_BATCH_SIZE = 64
MAX_EPOCHS = 10

trainer.fine_tune(
    MODEL_DIR,
    learning_rate=LEARNING_RATE,
    mini_batch_size=MINI_BATCH_SIZE,
    max_epochs=MAX_EPOCHS,
    main_evaluation_metric=("macro avg", "f1-score"),
)