In [1]:
import numpy as np
import pandas as pd
import altair as alt
import os
# Lift Altair's row limit so every (model, epoch) row can be embedded in the
# charts below; the full data frame is then inlined into the notebook output.
# The trailing ';' suppresses the "DataTransformerRegistry.enable(...)" repr.
alt.data_transformers.disable_max_rows();

DataTransformerRegistry.enable('default')

# Carregando dados

In [2]:
MODELS_DIR = '../models'
LOSS_FILE = 'learning_process.txt'


def load_learning_curves(models_dir=MODELS_DIR, loss_file=LOSS_FILE):
    """Load the per-epoch loss of every model run stored under ``models_dir``.

    Each entry of ``models_dir`` that contains ``loss_file`` is treated as one
    model run; the file is read with ``np.loadtxt`` and each value becomes one
    epoch. Entries without ``loss_file`` (stray files, unfinished runs) are
    skipped.

    Returns:
        A long-format DataFrame with columns ``epoch`` (0-based position of the
        value in the file), ``loss`` and ``model`` (the entry name), ordered by
        model name and then epoch.

    Raises:
        FileNotFoundError: if no entry of ``models_dir`` contains ``loss_file``.
    """
    frames = []
    # Sort so row order (and every downstream table/plot) does not depend on
    # the arbitrary order returned by os.listdir.
    for model in sorted(os.listdir(models_dir)):
        path = os.path.join(models_dir, model, loss_file)
        if not os.path.isfile(path):
            continue
        # ndmin=1 keeps a one-value file as a 1-element array instead of a 0-d
        # scalar, which pd.DataFrame(..., columns=['loss']) would reject.
        losses = np.loadtxt(path, ndmin=1)
        frames.append(pd.DataFrame({
            'epoch': range(len(losses)),
            'loss': losses,
            'model': model,
        }))
    if not frames:
        raise FileNotFoundError(f'no {loss_file} found under {models_dir!r}')
    # One concat at the end is linear; concatenating inside the loop is quadratic.
    return pd.concat(frames, ignore_index=True)


db = load_learning_curves()

# Curva de aprendizado

In [3]:
# Learning curve: the red line is the mean loss over all model runs at each
# epoch; the band around it is Vega-Lite's 'ci' extent (95% confidence
# interval of that mean), computed per epoch across models.
error_band = alt.Chart(db).mark_errorband(extent='ci', opacity=0.8).encode(
    x=alt.X('epoch:Q', title='Epoch'),
    y=alt.Y('loss:Q', title='Loss'),
).properties(
    title='Learning curve: mean loss across models (95% CI band)',
    width=600,
    height=200
)

# Both layers share the y axis; giving them the same title keeps Vega-Lite
# from merging two different titles into "Loss, Mean Loss".
mean_line = alt.Chart(db).mark_line(color='red').encode(
    x=alt.X('epoch:Q', title='Epoch'),
    y=alt.Y('mean(loss):Q', title='Loss')
)

(error_band + mean_line).configure_mark(
    opacity=0.8
).configure_axis(
    labelFontSize=12,
    titleFontSize=14
).configure_title(
    fontSize=16
)

In [4]:
# Distribution of the loss over rows with epoch > 380, pooled across every
# model run (each (model, epoch) row is counted once), drawn as a filled area
# with points over at most 10 loss bins.
# NOTE(review): the ranking cells below use epoch > 300; confirm the different
# cutoff here is intentional.
(
    alt.Chart(db[db['epoch'] > 380])
    .mark_area(point=True)
    .encode(
        x=alt.X('loss:Q', title='Loss', bin=alt.Bin(maxbins=10)),
        y=alt.Y('count()'),
    )
    .properties(title='Histogram of Loss', width=300, height=200)
    .configure_axis(titleFontSize=14, labelFontSize=12)
    .configure_title(fontSize=16)
)



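# Melhores modelos

Modelos ordenados pela perda média nas épocas posteriores à 300 (menor é melhor); a coluna `std` indica o quanto a perda ainda oscila nesse intervalo.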

In [5]:
# Rank model runs by their mean loss after epoch 300 (ascending, so the
# lowest-loss run comes first); std is the spread of the loss in that window.
(
    db.loc[db['epoch'].gt(300)]
    .groupby('model')['loss']
    .agg(['mean', 'std'])
    .sort_values(by='mean')
)

Unnamed: 0_level_0,mean,std
model,Unnamed: 1_level_1,Unnamed: 2_level_1
2025-03-23_04-02-15.866905,0.366110,0.392264
2025-03-23_09-52-17.692189,0.374574,0.413589
2025-03-22_21-16-15.989736,0.384999,0.345205
2025-03-23_12-04-38.190381,0.389730,0.385753
2025-03-23_09-01-54.680769,0.404880,0.459231
...,...,...
2025-03-23_07-37-24.155796,0.855297,0.851543
2025-03-23_04-58-21.580986,0.858699,0.974858
2025-03-23_07-09-54.352628,0.872724,0.915112
2025-03-22_23-21-55.790968,0.904050,0.859076


In [6]:
filtered_db = db[db['epoch'] > 300]

# Fix the x order in pandas. The previous EncodingSortField(field='mean(loss)')
# named a field that does not exist (the aggregate belongs in op=), and even
# op='mean' is dropped by Vega-Lite when the x domain is unioned across layers.
# An explicit list of categories is honoured by both layers and matches the
# ranking table above (mean loss over epoch > 300, ascending).
model_order = (
    filtered_db.groupby('model')['loss'].mean().sort_values().index.tolist()
)

# Error bars: Vega-Lite 'ci' extent (95% CI of the mean) of each model's loss
# over epochs > 300.
# NOTE(review): per-epoch losses within a run are likely autocorrelated, so
# this interval probably understates the real uncertainty — confirm before
# reading small gaps between models as meaningful.
error_band = alt.Chart(filtered_db).mark_errorbar(extent='ci', opacity=0.8).encode(
    x=alt.X('model:O', sort=model_order, title='Model'),
    y=alt.Y('loss:Q', title='Loss'),
).properties(
    title='Mean loss per model over epochs > 300 (95% CI)',
    width=600,
    height=200
)

# Red points mark each model's mean loss; same y title as the error bars so
# the shared axis is not labelled "Loss, Mean Loss".
mean_line = alt.Chart(filtered_db).mark_circle(color='red', size=100).encode(
    x=alt.X('model:O', sort=model_order, title='Model'),
    y=alt.Y('mean(loss):Q', title='Loss'),
)

(error_band + mean_line).configure_mark(
    opacity=0.8
).configure_axis(
    labelFontSize=12,
    titleFontSize=14
).configure_title(
    fontSize=16
)
