In [None]:
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader

In [None]:
# Detect whether this notebook is running in a Google Colab session.
# The google.colab package is already imported by the Colab runtime, so checking
# sys.modules avoids having to flip a flag by hand. Set runningOnColab explicitly
# to override the detection.
runningOnColab = 'google.colab' in sys.modules

# Mount point for Google Drive; the project path below is built from it so it does
# not depend on the current working directory.
COLAB_DRIVE_MOUNT = '/content/drive'

if runningOnColab:
    from google.colab import drive
    drive.mount(COLAB_DRIVE_MOUNT)
    # Make the project's local modules (CNN_model, train, data, ...) importable.
    sys.path.append(os.path.join(COLAB_DRIVE_MOUNT, 'My Drive', 'Colab Notebooks',
                                 'electricity-theft-detection-with-self-attention'))

In [None]:
from CNN_model import CNNModel
from Att_Augmented_Google_model import GoogleFullModel
from Hybrid_Attn import HybridAttentionModel
from train import perform_kfold_cv
from radam import RAdam
from data import download_data,get_processed_dataset

In [None]:
# Name of the dataset file handed to get_processed_dataset.
# NOTE(review): presumably this is the file download_data() writes — confirm in data.py.
DATA_CSV = 'data.csv'

download_data()
df = get_processed_dataset(DATA_CSV)
df.head()

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device}')

In [None]:
# Reproducibility settings: seed the global RNGs so that stochastic steps drawing
# from them repeat across runs.
# NOTE(review): random_state is only printed in this notebook's visible code; it is
# not passed to perform_kfold_cv — confirm whether the CV split should receive it.
random_state   = 12
reproductivity = True

if reproductivity:
    manualSeed = 13

    # Seed every global RNG the pipeline may draw from: Python's `random`, NumPy,
    # and torch (torch.manual_seed covers the CPU generator).
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    # If using GPU, seed the CUDA generators on the current and all other devices.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(manualSeed)
        torch.cuda.manual_seed_all(manualSeed)

    # Turn cuDNN off entirely (falls back to native kernels, typically slower),
    # disable its algorithm autotuner, and request deterministic algorithms.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print('Reproducing experiment with seed:',manualSeed)
    print('Using random state:',random_state)

else:
    print('Random Experiment')

In [None]:
# Cross-validation setup: each fold gets its own model instance (on `device`) and
# its own RAdam optimizer over that instance's parameters.
k_folds = 5
lr = 0.001

models = [HybridAttentionModel().to(device) for _fold in range(k_folds)]
optims = [RAdam(fold_model.parameters(), lr) for fold_model in models]

# Shared classification loss used for every fold.
criterion = nn.CrossEntropyLoss()

In [None]:
# Number of training epochs per fold.
# NOTE(review): a single epoch looks like a smoke-test setting — raise it for real runs.
n_epochs = 1

# Train/evaluate one model per fold; the result holds one entry per fold.
# NOTE(review): random_state from the seeding cell is not passed here — confirm
# whether perform_kfold_cv accepts one.
f1_per_fold = perform_kfold_cv(
    df,
    models,
    optims,
    criterion,
    k_folds,
    n_epochs=n_epochs,
    device=device,
)

In [None]:
# Choose the fold whose entry has the highest F1 (element 0 of each per-fold tuple).
# Ties resolve to the earliest fold. best_fold is 1-based, matching the
# fold_<k> checkpoint directories loaded below.
if not f1_per_fold:
    raise ValueError('perform_kfold_cv returned no per-fold results')
best_fold = max(range(len(f1_per_fold)), key=lambda i: f1_per_fold[i][0]) + 1

# Each entry unpacks as (f1, epoch, _, _); the last two fields are not used here.
best_f1, best_epoch, _, _ = f1_per_fold[best_fold - 1]
print(f'The best fold was {best_fold} with F1 of {best_f1} at epoch {best_epoch}')

In [None]:
# Rebuild the architecture and restore the weights of the best fold/epoch.
# NOTE(review): assumes perform_kfold_cv writes checkpoints as
# att_models/fold_<k>/epoch_<e>.pth — confirm in train.py.
# map_location=device lets a checkpoint saved on GPU be reloaded on a CPU-only kernel.
# NOTE(review): call model.eval() before inference if the model uses dropout/batch-norm.
checkpoint_path = os.path.join('att_models', f'fold_{best_fold}', f'epoch_{best_epoch}.pth')
model = HybridAttentionModel().to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))