In [7]:
import os

import pandas as pd
import torch

from money_counter.models import VersionManager

In [32]:
# Root directory of the saved model states, given relative to this notebook.
MODEL_STATE_DIR = '../model_state'

In [9]:
# Instantiate the project's VersionManager on the model-state directory.
# NOTE(review): `version_manager` is not referenced by any other cell visible
# in this part of the notebook — confirm it is still needed.
version_manager = VersionManager(MODEL_STATE_DIR)

In [10]:
def add(df: pd.DataFrame, data) -> pd.DataFrame:
    """
    Write ``data`` as a row of ``df`` under the label ``len(df)``.

    The frame is modified in place and the very same object is returned.
    If the index already contains the label ``len(df)``, that row is
    overwritten instead of a new row being appended.

    :param df: The frame to extend (mutated in place).
    :param data: The row values, in any form accepted by ``df.loc[label] = ...``.
    :return: ``df`` itself.
    """
    df.loc[len(df)] = data  # type: ignore
    return df


In [26]:
import re


def enumerate_states(model_state_dir):
    """
    Enumerates all the model states in the model_state_dir.

    The expected layout is ``<model_state_dir>/<model name>/<...>epoch_<N>.pth``.
    Entries of ``model_state_dir`` that are not directories, and files whose
    name does not end in ``epoch_<digits>.pth``, are skipped. Models and
    their files are visited in sorted name order.

    :param model_state_dir: The directory where the model states are stored.
    :return: A generator of tuples (model name, epoch number, full path)
    """
    # Walk the directory that was passed in, not the notebook-level constant.
    for model_name in sorted(os.listdir(model_state_dir)):
        model_dir = os.path.join(model_state_dir, model_name)
        if not os.path.isdir(model_dir):
            # Stray files next to the model folders are not model states.
            continue

        for model_state in sorted(os.listdir(model_dir)):
            # Match the file name only, with the dot escaped and the pattern
            # anchored at the end, so e.g. "epoch_12.pth.tmp" is not taken
            # for a finished state.
            match = re.fullmatch(r".*epoch_(\d+)\.pth", model_state)
            if match is None:
                continue

            yield model_name, int(match.group(1)), os.path.join(model_dir, model_state)


In [36]:
import re
import time

# Read the stored loss of every saved state and gather the results in `df`,
# ordered from the lowest loss to the highest.
# Records are collected in a list and the DataFrame is built once at the end;
# growing a DataFrame one row at a time gets slower with every row.
records = []

started_at = time.time()

for model_name, epoch_number, path in enumerate_states(MODEL_STATE_DIR):
    # map_location='cpu' keeps checkpoints that were saved from a GPU off the
    # GPU: only the stored 'loss' entry is read here.
    # NOTE(review): torch.load unpickles the file — only run this on
    # checkpoints from a trusted source.
    loaded = torch.load(path, map_location='cpu')

    records.append({
        'model': model_name,
        'epoch': epoch_number,
        'loss': loaded['loss'],
        'path': path,
    })

    # Release the checkpoint before the next one is loaded.
    del loaded

    if len(records) % 20 == 0:
        print(f'Processed {len(records)} states. Current speed: {len(records) / (time.time() - started_at)} states per second')


df = pd.DataFrame(records, columns=['model', 'loss', 'epoch', 'path'])
df = df.sort_values(by=['loss'])


Processed 20 states. Current speed: 0.9094231909283957 states per second


In [38]:
# Per model, keep the states with the lowest losses.
# The sort is done here so the cell does not depend on `df` having been sorted
# by an earlier cell; it is stable, so an already-sorted `df` keeps its order.
TOP_STATES_PER_MODEL = 5

top_states = (
    df.sort_values(by='loss', kind='stable')
    .groupby('model')
    .head(TOP_STATES_PER_MODEL)
)
top_states

Unnamed: 0,model,loss,epoch,path
16,fasterrcnn_resnet50_fpn_v2-pretrained.bak,3.248029,91,../model_state\fasterrcnn_resnet50_fpn_v2-pret...
14,fasterrcnn_resnet50_fpn_v2-pretrained.bak,3.248029,70,../model_state\fasterrcnn_resnet50_fpn_v2-pret...
11,fasterrcnn_resnet50_fpn_v2-pretrained.bak,3.267834,18,../model_state\fasterrcnn_resnet50_fpn_v2-pret...
17,fasterrcnn_resnet50_fpn_v2-pretrained.bak,3.284652,111,../model_state\fasterrcnn_resnet50_fpn_v2-pret...
20,fasterrcnn_resnet50_fpn_v2-pretrained.bak,3.291085,149,../model_state\fasterrcnn_resnet50_fpn_v2-pret...
0,fasterrcnn_resnet50_fpn-untrained,5.303829,0,../model_state\fasterrcnn_resnet50_fpn-untrain...
10,fasterrcnn_resnet50_fpn_v2-pretrained,97.271071,54,../model_state\fasterrcnn_resnet50_fpn_v2-pret...
5,fasterrcnn_resnet50_fpn_v2-pretrained,97.706293,48,../model_state\fasterrcnn_resnet50_fpn_v2-pret...
9,fasterrcnn_resnet50_fpn_v2-pretrained,98.288877,53,../model_state\fasterrcnn_resnet50_fpn_v2-pret...
3,fasterrcnn_resnet50_fpn_v2-pretrained,98.938254,45,../model_state\fasterrcnn_resnet50_fpn_v2-pret...


In [39]:
# DESTRUCTIVE: deletes every saved state whose path is not listed in `top_states`.
# The paths to keep are collected once into a set instead of scanning the
# DataFrame column again for every file.
paths_to_keep = set(top_states['path'])

for _model_name, _epoch, path in enumerate_states(MODEL_STATE_DIR):
    if path in paths_to_keep:
        continue

    print(f'Removing {path}...', end='')
    os.remove(path)
    print('removed.')


Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained\epoch_043.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained\epoch_044.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained\epoch_047.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained\epoch_051.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained\epoch_052.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained.bak\epoch_038.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained.bak\epoch_043.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained.bak\epoch_081.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained.bak\epoch_113.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-pretrained.bak\epoch_116.pth...removed.
Removing ../model_state\fasterrcnn_resnet50_fpn_v2-untrained\epoch_069.pth...removed.
Removing ../model_state\