In [None]:
#  GFLOPs calculation for PyTorch model

import torchprofile
import torch

# Path to either a TorchScript archive (torch.jit.save) or a full pickled
# nn.Module (torch.save(model)); the loader below accepts both.
path = "path/to/model.pth"
# Dummy input used for tracing: (batch, channels, height, width).
# NOTE(review): 3x224x224 assumes an image model with that input size — adjust
# to the model's real input shape, since the MAC count depends on it.
input_shape = (1, 3, 224, 224)

# Load exactly once. Calling torch.jit.load and then torch.load unconditionally
# always fails for one of the two file formats. map_location="cpu" lets a
# checkpoint saved on CUDA load on a CPU/MPS machine; it is moved below.
try:
    load_model = torch.jit.load(path, map_location="cpu")
except RuntimeError:
    # Not a TorchScript archive -> treat it as a pickled nn.Module.
    # weights_only=False is needed for full modules on torch >= 2.6 and it
    # unpickles arbitrary objects: only load files you trust.
    load_model = torch.load(path, map_location="cpu", weights_only=False)

device = "mps" if torch.backends.mps.is_available() else "cpu"
# eval() so dropout / batch-norm take their inference path while tracing.
load_model = load_model.to(device).eval()

input_tensor = torch.randn(*input_shape, device=device)
with torch.no_grad():  # no autograd graph is needed just to count operations
    macs = torchprofile.profile_macs(load_model, input_tensor)  # MACs: Multiply-Accumulate operations

# Common convention: 1 MAC = 1 multiply + 1 add, so FLOPs = 2 * MACs.
flops = 2 * macs
gflops = flops / 1e9

print(f'GFLOPs: {gflops:.4f}')

In [None]:
# results saving operation to csv file

import time
import torch
import pandas as pd

# Path to either a TorchScript archive (torch.jit.save) or a full pickled
# nn.Module (torch.save(model)); the loader below accepts both.
path = "path/to/model.pth"
model_name = "model_name"
# Results table shared across runs; one row is appended per evaluated model.
csv_path = '../results.csv'

# Load exactly once. Calling torch.jit.load and then torch.load unconditionally
# always fails for one of the two file formats. map_location="cpu" avoids
# failures when the checkpoint was saved on a GPU that is not present here.
try:
    model = torch.jit.load(path, map_location="cpu")
except RuntimeError:
    # Not a TorchScript archive -> treat it as a pickled nn.Module.
    # weights_only=False is needed for full modules on torch >= 2.6 and it
    # unpickles arbitrary objects: only load files you trust.
    model = torch.load(path, map_location="cpu", weights_only=False)

total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# One-row frame for this model. The commented columns can be enabled once the
# corresponding metrics are computed elsewhere in the notebook.
row = pd.DataFrame({
    'model': [model_name],
    'params': [total_params],
    'tr_params': [total_trainable_params],
    # 'learning_rate': [learning_rate],
    # 'batch': [batch_size],
    # 'accuracy_(Tr)': [train_accuracy],
    # 'accuracy_(Va)': [validation_accuracy],
    # 'precision_(Va)': [validation_precision],
    # 'recall_(Va)': [validation_recall],
    # 'accuracy_(Te)': [test_accuracy],
    # 'precision_(Te)': [test_precision],
    # 'recall_(Te)': [test_recall],
    # 'time_(s)': [computation_time]
})

# Append to the existing results table, or start a new one on the first run
# (pd.read_csv raises FileNotFoundError when the file does not exist yet).
try:
    previous_results = pd.read_csv(csv_path)
except FileNotFoundError:
    df = row
else:
    # ignore_index=True keeps a clean 0..n-1 index instead of duplicate labels.
    df = pd.concat([previous_results, row], ignore_index=True)

df.to_csv(csv_path, index=False)
df.tail()