# Load Forecasting on Smart Meters

In [None]:
# Notebook setup: restrict the visible GPU, import dependencies, pick the torch device.
import os
# Expose only GPU 0 to this process. This is set before `import torch` so the
# restriction is already in the environment when CUDA gets initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import torch
# Project-local modules (Data, Client, Server); their sources are not shown here.
from Data import load_data
from Data import setup_seed
from Client import Client
from Server import Server

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any other cell visible in this
# notebook -- confirm whether Client/Server select their own device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# --- Experiment configuration ---------------------------------------------
# Seed handed to setup_seed() before the clients and server are built.
seed = 99

# `learning_rate` is given to every Client and to the Server (as `lr`);
# `rounds` is given to the Server.
learning_rate = 5e-4
rounds = 100

# Local fine-tuning rounds for federated learning.
# NOTE(review): not passed to Client or Server in the visible cells -- confirm
# whether it is still needed.
finetuning_rounds = 30

# Forwarded unchanged to Server as `mu` and `gamma`; their meaning is defined
# by the Server implementation, which is not visible here.
mu, gamma = 0.5, 0.5

### Choose Dataset

In [None]:
# Location of the BDG2 data, relative to the notebook's working directory.
path_BDG2 = 'BDG2_dataset/'

# Names of the series to use, one per line. `len(choose)` later determines how
# many clients are created, so the order here is significant.
# NOTE(review): the entries look like `<site>_<usage>_<name>` building IDs;
# 'Hog_office_Shawna' and 'Hog_office_Shawnna' are two distinct entries.
choose = [
    'Hog_office_Bessie',
    'Hog_office_Napoleon',
    'Hog_office_Sydney',
    'Hog_office_Mike',
    'Hog_office_Merilyn',
    'Hog_education_Donnie',
    'Hog_office_Shon',
    'Hog_education_Jewel',
    'Hog_education_Jordan',
    'Hog_office_Myles',
    'Hog_office_Almeda',
    'Hog_office_Denita',
    'Hog_office_Lizzie',
    'Hog_office_Mary',
    'Hog_office_Betsy',
    'Hog_office_Bill',
    'Hog_office_Miriam',
    'Hog_office_Valda',
    'Hog_office_Shawna',
    'Hog_office_Shawnna',
    'Hog_office_Sherrie',
    'Hog_education_Madge',
    'Robin_office_Maryann',
    'Hog_education_Rachael',
    'Robin_office_Antonina',
    'Robin_office_Victor',
    'Robin_office_Zelma',
    'Robin_office_Serena',
    'Robin_office_Sammie',
    'Robin_office_Addie',
]

# Load the selected series. load_data is defined in the Data module; presumably
# it matches files under `path` with `postfix` and keeps those named in
# `choose` -- confirm against Data.load_data.
datas = load_data(path=path_BDG2, postfix='*.csv', choose=choose)

### Import Data into Smart Meters

In [None]:
# Seed before any Client/Server is constructed. setup_seed is defined in the
# Data module; presumably it fixes the RNG state -- confirm there.
setup_seed(seed)

# One Client per entry in `choose`: client k receives datas[k] as its own data
# together with the full `datas` collection and the shared learning rate.
clients = [
    Client(data=datas[k], datas=datas, lr=learning_rate)
    for k in range(len(choose))
]

# The Server gets the first dataset as its `data`, plus all datasets, all
# clients, and the hyperparameters configured earlier in the notebook.
server = Server(
    data=datas[0],
    datas=datas,
    clients=clients,
    lr=learning_rate,
    rounds=rounds,
    mu=mu,
    gamma=gamma,
)

### Train and Test Models

#### Centralized

In [None]:
# "Centralized" setting: delegates to Server.centralized_train(). What it
# trains and reports is defined in the Server module (not shown here).
server.centralized_train()

#### Local

In [None]:
# "Local" setting: delegates to Server.local_train(). Presumably each client
# is trained on its own data only -- confirm in the Server module.
server.local_train()

#### FedAvg

In [None]:
# "FedAvg" setting: delegates to Server.fed_train(); the aggregation details
# live in the Server module (not shown here).
server.fed_train()

#### Split

In [None]:
# "Split" setting: delegates to Server.split_train(); see the Server module
# for how the model is partitioned (not shown here).
server.split_train()

#### SFLV1

In [None]:
# "SFLV1" setting: delegates to Server.sflv1_train(), implemented in the
# Server module (not shown here).
server.sflv1_train()

#### SFLV2

In [None]:
# "SFLV2" setting: delegates to Server.sflv2_train(), implemented in the
# Server module (not shown here).
server.sflv2_train()

#### Proposed

In [None]:
# "Proposed" method: run through Server.distillation_train(). Presumably this
# is where `mu` and `gamma` are used -- confirm in the Server module.
server.distillation_train()