In [1]:
import os
import numpy as np
import pandas as pd
import altair as alt
# Turn off Altair's max-rows check so the charts below can be built directly
# from the full `db` frame (every epoch of every model).
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

# Carregando dados

In [2]:
# Load the per-epoch loss history of every trained model into one long-format
# frame `db` with columns: epoch (row position in the file), loss, and
# model (name of the model's folder).
MODELS_DIR = '../data/models'

# Keep only model folders. Filtering (instead of `list.remove`) avoids the
# ValueError raised when '.ipynb_checkpoints' is absent and skips stray files;
# sorting makes the row order independent of the filesystem listing order.
lista = sorted(
    name
    for name in os.listdir(MODELS_DIR)
    if name != '.ipynb_checkpoints'
    and os.path.isdir(os.path.join(MODELS_DIR, name))
)

curves = []
for model_name in lista:
    # ndmin=1 keeps a single-value file as a 1-element array instead of a 0-d
    # scalar, which pd.DataFrame(..., columns=[...]) cannot handle.
    losses = np.loadtxt(
        os.path.join(MODELS_DIR, model_name, 'learning_process.txt'),
        ndmin=1,
    )
    curves.append(
        pd.DataFrame(losses, columns=['loss'])
        .reset_index()
        .rename(columns={'index': 'epoch'})
        .assign(model=model_name)
    )

# Concatenate once: growing the frame inside the loop re-copies all previous
# rows on every iteration. Fall back to an empty, correctly-shaped frame when
# no model folder was found (pd.concat rejects an empty list).
db = pd.concat(curves) if curves else pd.DataFrame(columns=['epoch', 'loss', 'model'])

# Curva de aprendizado

In [4]:
import altair as alt

# Learning curve over all models: confidence-interval band of the loss at each
# epoch, overlaid with the mean loss as a red line.
error_band = (
    alt.Chart(db)
    .properties(title='Error Band with Mean Line Plot', width=600, height=200)
    .mark_errorband(extent='ci', opacity=0.8)
    .encode(
        x=alt.X('epoch:Q'),
        y=alt.Y('loss:Q', title='Loss'),
    )
)

mean_line = (
    alt.Chart(db)
    .mark_line(color='red')
    .encode(
        x=alt.X('epoch:Q'),
        y=alt.Y('mean(loss):Q', title='Mean Loss'),
    )
)

# Layer the two charts and apply the shared look-and-feel; the layered chart is
# the cell's last expression, so it is what gets rendered.
(
    alt.layer(error_band, mean_line)
    .configure_title(fontSize=16)
    .configure_axis(labelFontSize=12, titleFontSize=14)
    .configure_mark(opacity=0.8)
)

In [5]:
# Distribution of the loss values recorded after epoch 380, pooled over all
# models: loss is binned on x (at most 10 bins) and rows are counted on y.
(
    alt.Chart(db[db['epoch'] > 380])
    .properties(title='Histogram of Loss', width=300, height=200)
    .mark_area(point=True)
    .encode(
        alt.X('loss:Q', title='Loss', bin=alt.Bin(maxbins=10)),
        alt.Y('count()'),
    )
    .configure_title(fontSize=16)
    .configure_axis(labelFontSize=12, titleFontSize=14)
)

# Melhores modelos

In [10]:
# Per-model mean and standard deviation of the loss after epoch 300,
# listed from lowest to highest mean loss.
(
    db.loc[db['epoch'] > 300]
    .groupby('model')['loss']
    .agg(['mean', 'std'])
    .sort_values(by='mean')
)

Unnamed: 0_level_0,mean,std
model,Unnamed: 1_level_1,Unnamed: 2_level_1
2025-03-24_02-00-01.091682,0.399471,0.389711
2025-03-24_01-51-10.011417,0.405073,0.396757
2025-03-24_02-00-03.258666,0.434256,0.456004
2025-03-24_01-59-11.012045,0.451154,0.439005
2025-03-24_01-53-53.982726,0.457657,0.436402
2025-03-24_02-11-58.510897,0.46713,0.449582
2025-03-24_02-20-45.611328,0.473669,0.470798
2025-03-24_01-59-51.746760,0.478098,0.447331
2025-03-24_02-14-36.188079,0.480067,0.496847
2025-03-24_02-03-48.977940,0.496464,0.522504


In [11]:
import altair as alt

# Compare models on their loss after epoch 300: one confidence-interval error
# bar per model with the mean loss marked as a red dot, ordered from the
# lowest to the highest mean loss.
filtered_db = db[db['epoch'] > 300]

# The sort definition must name the raw field and the aggregate separately.
# field='mean(loss)' points at a column that does not exist in the data, so
# the intended ordering by mean loss is not applied.
by_mean_loss = alt.EncodingSortField(field='loss', op='mean', order='ascending')

# NOTE: the names `error_band` / `mean_line` are kept as they were in this
# notebook, although here they hold an error-bar layer and a point layer.
error_band = alt.Chart(filtered_db).mark_errorbar(extent='ci', opacity=0.8).encode(
    x=alt.X('model:O', sort=by_mean_loss),
    y=alt.Y('loss:Q', title='Loss'),
).properties(
    # The previous title ('Error Band with Mean Line Plot') was copied from the
    # learning-curve cell and did not describe this chart.
    title='Mean Loss per Model with Error Bars',
    width=600,
    height=200
)

mean_line = alt.Chart(filtered_db).mark_circle(color='red', size=100).encode(
    x=alt.X('model:O', sort=by_mean_loss),
    y=alt.Y('mean(loss):Q', title='Mean Loss'),
)

# Layer bars and dots; the layered chart is the cell's last expression, so it
# is what gets rendered.
(error_band + mean_line).configure_mark(
    opacity=0.8
).configure_axis(
    labelFontSize=12,
    titleFontSize=14
).configure_title(
    fontSize=16
)
