In [13]:
import wandb
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt



In [14]:
# Initialize the API
api = wandb.Api()

# Specify your project details
entity = "haeri-hsn"  # Replace with your wandb entity
project = "stream_learning"  # Replace with your wandb project name

# Define filters
filters = {
    'state': 'finished',  # Only fetch finished runs
    'tags': {'$in': ['test_num_anchors']}  # Runs containing a specific tag msmsa_horizon_analysis_melbourne_housing \ msmsa_anchor_analysis_melbourne_housing
}

# Query runs with filters
runs = api.runs(f"{entity}/{project}", filters=filters)

# print number of runs
print(f"Number of runs: {len(runs)}")

Number of runs: 60


In [15]:
run_data = []
for run in runs:
    # if ('msmsa_horizon_analysis_melbourne_housing' in run.tags) and len(run.config['hor_candids']) == 7:
    #     continue
    # print(run.summary)
    run_dict = {

        
        'id': run.id,
        'name': run.name,
        'method': run.config['method'],
        'dataset': run.config['dataset'],
        'base_learner': run.config['base_learner'],
        'hor_candids': run.config['hor_candids'],
        'num_anchors': run.config['num_anchors'],

        'MAE': run.config['MAE'],
        'RMSE': run.config['RMSE'],
        'R2': run.config['R2'],
        'base_learner_params': run.config['base_learner_params'],

        # print(run.summary)
        'runtime': run.summary['_runtime'],  # Add runtime in seconds
        'num_timesteps': run.summary['_step'],  # Add runtime in seconds
        'num_train_samples': run.summary['num_train_samples'],
        'run_abs_error': run.summary['run_abs_error'],
        'run_y': run.summary['run_y'],
        'run_y_pred': run.summary['run_y_pred'],
        'tags': run.tags,

        # 'hor_candids': run.config.hor_candids,
        # 'num_anchors': run.config.num_anchors,
        # 'updated_at': run.config.updated_at,
        # 'MAE': run.tags,
        # 'RMSE': run.notes,
        # 'R2': run.config.R2,
        # 'base_learner_params': run.config.base_learner_params,

        # 'num_train_samples': run.config.num_train_samples,
        # 'run_abs_error': run.configrun_abs_error,
        # 'run_y': run.config.run_y,
        # 'run_y_pred': run.config.run_y_pred,

    }
    run_data.append(run_dict)

# Convert to DataFrame
# df_normal = pd.DataFrame(run_data)
df = pd.DataFrame(run_data)
df['num_hor_candids'] = df['hor_candids'].apply(len)

# print distinct num_hor_candids
print(f"Distinct num_hor_candids: {df['num_hor_candids'].unique()}")

Distinct num_hor_candids: [16]


In [12]:
# combine df_uniform df_normal and df_exact into one dataframe add include a column for the type of data
df_normal['type'] = 'normal'
df_uniform['type'] = 'uniform'
df_exact['type'] = 'exact'

df = pd.concat([df_normal, df_uniform, df_exact])

# pickle the dataframe
df.to_pickle('df_anchor_distributions.pkl')


## Horizon Analysis

In [29]:
%matplotlib qt

plt.close('all')
# set theme for seaborn
sns.set_theme(style='whitegrid')


# assuming hor_candids is a list of integers, make another column (num_hor_candids) with the length of the list
df['num_hor_candids'] = df['hor_candids'].apply(len)

# remove rows with 
df = df.sort_values(by='num_hor_candids')
# Now create another column where if num_hor_candids is 7, then the value is 'exponential(^2)', if  37, then 'exponential(^1.15)' and if 991 then 'full'
df['hor_candids_type'] = df['num_hor_candids'].apply(lambda x: 'b=1.50' if x == 15 else 'b=1.15' if x == 47 else 'b=1.10' if x == 68 else 'Full')
colors = ['#E57439', '#EDB732', '#A0C75C', '#5387DD']
# create a sns barplot of MAE, RMSE, R2 for each hor_candids configuration
def plot_metrics(df, metric, title):

    plt.figure(figsize=(4, 4))
    sns.barplot(x='hor_candids_type', y=metric, data=df, hue='hor_candids_type', palette=colors, width=0.5)
    # plt.title(f'{metric} for different horizon candidates setting')
    # remove x label
    plt.xlabel(title)
    # if metric is runtime add a [s] to the y label
    plt.ylabel(f'Runtime [sec]' if metric == 'runtime' else metric)
    plt.tight_layout()
    plt.show()


plot_metrics(df, 'MAE', '(a)')
plot_metrics(df, 'RMSE', '(b)')
plot_metrics(df, 'R2', '(c)')
plot_metrics(df, 'runtime', '(d)')



  sns.barplot(x='hor_candids_type', y=metric, data=df, hue='hor_candids_type', palette=colors, width=0.5)
  sns.barplot(x='hor_candids_type', y=metric, data=df, hue='hor_candids_type', palette=colors, width=0.5)
  sns.barplot(x='hor_candids_type', y=metric, data=df, hue='hor_candids_type', palette=colors, width=0.5)
  sns.barplot(x='hor_candids_type', y=metric, data=df, hue='hor_candids_type', palette=colors, width=0.5)


## Num Anchor points analysis

In [31]:
%matplotlib qt

plt.close('all')
# set theme for seaborn
sns.set_theme(style='whitegrid')

colors = ['#57A6A1', '#577B8D', '#344C64', '#2D3047']

# create a sns barplot of MAE, RMSE, R2 for each hor_candids configuration
def plot_metrics(df, metric):

    plt.figure(figsize=(3, 4))
    sns.barplot(x='num_anchors', y=metric, data=df, hue='num_anchors', palette='coolwarm', width=0.5)
    # plt.title(f'{metric} for different distribution of anchor points')
    # remove x label
    plt.xlabel('Number of Anchor Points')
    # if metric is runtime add a [s] to the y label
    plt.ylabel(f'Runtime [sec]' if metric == 'runtime' else metric)
    plt.tight_layout()
    # replace the legend title with 'Number of anchor points'
    plt.legend(title='Number of anchor points')
    plt.show()


plot_metrics(df, 'MAE')
plot_metrics(df, 'RMSE')
plot_metrics(df, 'R2')
plot_metrics(df, 'runtime')

## Anchor Distribution Analysis

In [10]:
%matplotlib qt

plt.close('all')
# set theme for seaborn
sns.set_theme(style='whitegrid')

colors = ['#FF5733', '#33FF57', '#3357FF']

# create a sns barplot of MAE, RMSE, R2 for each hor_candids configuration
def plot_metrics(df, metric):

    plt.figure(figsize=(3, 4))
    sns.barplot(x='type', y=metric, data=df, hue='type', palette=colors, width=0.5, order=['uniform', 'normal', 'exact'])
    # plt.title(f'{metric} for different distribution of anchor points')
    # remove x label
    plt.xlabel('Distribution of Anchor Points')
    # if metric is runtime add a [s] to the y label
    plt.ylabel(f'Runtime [sec]' if metric == 'runtime' else metric)
    plt.tight_layout()
    plt.show()


plot_metrics(df, 'MAE')
plot_metrics(df, 'RMSE')
plot_metrics(df, 'R2')
plot_metrics(df, 'runtime')

ValueError: Could not interpret value `type` for `x`. An entry with this name does not appear in `data`.

In [14]:
# open melbourne_housing_clean.csv from datasets folder into df

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib qt
df = pd.read_csv('datasets/melbourne_housing_clean.csv')

# # plot the distribution of Latitude and Longitude
# plt.figure(figsize=(6, 6))
# sns.scatterplot(x='Lattitude', y='Longtitude', data=df, alpha=0.05, s=10, color='blue')
# plt.title('Distribution of Latitude and Longitude')
# plt.tight_layout()
# plt.show()

# # plot the distribution of Price
# plt.figure(figsize=(6, 4))
# sns.histplot(df['Price'], bins=100)
# plt.title('Distribution of Price')
# plt.xlabel('Price')
# plt.ylabel('Frequency')
# plt.tight_layout()


# lineplot the average monthly price over Date



df['Date'] = pd.to_datetime(df['Date'])
df['Month'] = df['Date'].dt.month
df['Year'] = df['Date'].dt.year
df['MonthYear'] = df['Date'].dt.to_period('M')

df['MonthYear'] = df['MonthYear'].astype(str)

df = df.groupby('MonthYear').mean().reset_index()


df['MonthYear'] = pd.to_datetime(df['MonthYear'])


plt.figure(figsize=(10, 4))
sns.lineplot(x='MonthYear', y='Price', data=df)
plt.title('Average Monthly Price')
plt.xlabel('Date')
plt.ylabel('Price (x1000 AUD)')
# make sure xlim is set to min and max of the Date column
plt.xlim(df['MonthYear'].min(), df['MonthYear'].max())
plt.grid(True)
plt.tight_layout()
plt.show()





In [11]:
df['MonthYear']

0    2016-02-01
1    2016-04-01
2    2016-05-01
3    2016-06-01
4    2016-07-01
5    2016-08-01
6    2016-09-01
7    2016-10-01
8    2016-11-01
9    2016-12-01
10   2017-02-01
11   2017-03-01
12   2017-04-01
13   2017-05-01
14   2017-06-01
15   2017-07-01
16   2017-08-01
17   2017-09-01
18   2017-10-01
19   2017-11-01
20   2017-12-01
21   2018-01-01
22   2018-02-01
23   2018-03-01
Name: MonthYear, dtype: datetime64[ns]