# Quilt on the Synthetic and Real Datasets
### This Jupyter notebook runs the holistic Quilt method on the synthetic and real datasets.

## Import libraries

In [1]:
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler

In [2]:
import EarlyStopping
from utils import prepare_data
from bayes_opt import BayesianOptimization

from DSS.simpleNN_net import TwoLayerNet
from DSS.dataloader.quiltdataloader import QuiltDataLoader
from dotmap import DotMap

## Set CUDA

In [4]:
# Whether a CUDA device is available on this machine.
cuda = torch.cuda.is_available()

# Device string used by every `.to(device)` call in this notebook.
# Fall back to the CPU so `device` is always defined; previously it was only
# assigned when CUDA was available, which made CPU-only runs fail with NameError.
# NOTE(review): the GPU index 5 is hard-coded — adjust it for machines with fewer GPUs.
device = 'cuda:5' if cuda else 'cpu'

In [3]:
# Loss functions used during model training.
# Per-sample cross entropy (no reduction): lets the training loop weight each sample.
criterion_nored = nn.CrossEntropyLoss(reduction='none')
# Batch-averaged cross entropy ('mean' is the default reduction): used for evaluation.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [5]:
def Quilt_cv(gain_th, disp_th):
    """Quilt performance function for bayesian optimization.

    Trains a freshly initialized model with the Quilt dataloader configured by
    the given thresholds and scores it by its validation loss.

    The experiment configuration is read from notebook-level globals, which must
    be set before this function is called: ``n_feature``, ``n_class``,
    ``n_hidden``, ``s``, ``device``, ``dataset``, ``split1``, ``group``,
    ``x_all_scale``, ``y_all``, ``trainloader``, ``valloader``, ``num_epochs``,
    ``criterion`` and ``criterion_nored``.

    Args:
        gain_th: Gain threshold value.
        disp_th: Disparity threshold value.
    Returns:
        Minus value of minimum validation loss (negated because the bayesian
        optimizer maximizes its objective).
    """
    # initialize model
    model = TwoLayerNet(n_feature, n_class, n_hidden, s)
    model = model.to(device)

    # set model optimizer, scheduler, and earlystop path
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
    earlystop_path = f'./ckpt/{dataset}_Quilt.pt'

    # set dataloader parameters
    dss_args = dict(model=model,
                    loss=criterion_nored,
                    eta=0.001,
                    num_classes=n_class,
                    # keep in sync with the training loop bound below
                    num_epochs=num_epochs,
                    device=device,
                    fraction=1.0,
                    init_budget=split1,
                    select_every=1,
                    kappa=0,
                    linear_layer=True,
                    selection_type='PerBatch',
                    groups=group,
                    x_all=x_all_scale,
                    y_all=y_all,
                    gain_th=gain_th,
                    disp_th=disp_th,
                    )
    dss_args = DotMap(dss_args)

    # define Quilt dataloader
    dataloader = QuiltDataLoader(trainloader, valloader, dss_args, batch_size=128, shuffle=True, pin_memory=False)

    # per-epoch mean validation loss
    val_losses = []

    early_stopping = EarlyStopping.EarlyStopping(patience=10, delta=0.0001, path=earlystop_path)

    # model training with data segment selection
    for epoch in range(num_epochs):
        model.train()
        for inputs, targets, weights in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device)
            optimizer.zero_grad()
            outputs = model(inputs, last=False, freeze=False)
            losses = criterion_nored(outputs, targets)
            # weighted average of the per-sample losses (weights normalized to sum to 1)
            loss = torch.dot(losses, weights / (weights.sum()))
            loss.backward()
            optimizer.step()

        model.eval()
        val_loss = 0
        val_total = 0

        # model evaluation using validation set
        with torch.no_grad():
            for inputs, targets in valloader:
                inputs, targets = inputs.to(device), targets.to(device, non_blocking=True, dtype=torch.int64)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # `criterion` returns the batch mean; re-weight by batch size so
                # the epoch value is a true per-sample mean
                val_loss += loss.item() * targets.size(0)
                val_total += targets.size(0)

        epoch_val_loss = val_loss / val_total
        val_losses.append(epoch_val_loss)

        scheduler.step(epoch_val_loss)
        early_stopping(epoch_val_loss, model)
        if early_stopping.early_stop:
            break

    return -np.min(val_losses)

In [6]:
def run_dss(dataloader, trainloader, valloader, testloader, model, optimizer, scheduler, num_epochs, earlystop_path):
    """Run data segment selection algorithm.

    Trains ``model`` on batches drawn from ``dataloader``, and after every epoch
    evaluates it on the validation and test sets. Training stops when the
    early-stopping criterion on the validation loss fires or after
    ``num_epochs`` epochs. The reported test metrics are those of the epoch
    with the lowest validation loss.

    Reads the notebook-level globals ``device``, ``criterion``,
    ``criterion_nored`` and ``n_class``.

    Args:
        dataloader: Dataloader with data subset selection method; yields
            ``(inputs, targets, weights)`` batches.
        trainloader: Train dataloader (not used inside this function; kept for
            interface compatibility).
        valloader: Validation dataloader.
        testloader: Test dataloader.
        model: Train model.
        optimizer: Model optimizer. If ``None``, Adam with ``lr=1e-3`` is created.
        scheduler: Learning rate scheduler, stepped with the validation loss.
            If ``None``, ``ReduceLROnPlateau(optimizer, 'min', patience=10)``
            is created.
        num_epochs: Number of maximum epochs.
        earlystop_path: Earlystop model checkpoint path.
    Returns:
        Earlystop epoch, best accuracy, best f1 score, and runtime (total
        training time in seconds, excluding evaluation).
    """
    val_losses = []
    tst_acc = []
    tst_f1 = []
    timing = []

    early_stopping = EarlyStopping.EarlyStopping(patience=10, delta=0.0001, path=earlystop_path)

    # Honor the optimizer/scheduler supplied by the caller; previously both
    # arguments were silently replaced by freshly built objects. The defaults
    # below are the configuration that used to be hard-coded here.
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)

    # weighted F1 for multi-class problems, binary F1 otherwise
    f1_average = 'weighted' if n_class > 2 else 'binary'

    # model training with data segment selection
    for epoch in range(num_epochs):
        model.train()
        start_time = time.time()
        for inputs, targets, weights in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device)
            optimizer.zero_grad()
            outputs = model(inputs, last=False, freeze=False)
            losses = criterion_nored(outputs, targets)
            # weighted average of the per-sample losses (weights normalized to sum to 1)
            loss = torch.dot(losses, weights / (weights.sum()))
            loss.backward()
            optimizer.step()
        # only the training pass is timed, not the evaluation below
        timing.append(time.time() - start_time)

        val_loss = 0
        val_total = 0

        model.eval()

        # model evaluation using validation set
        with torch.no_grad():
            for inputs, targets in valloader:
                inputs, targets = inputs.to(device), targets.to(device, non_blocking=True, dtype=torch.int64)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # `criterion` returns the batch mean; re-weight by batch size
                val_loss += loss.item() * targets.size(0)
                val_total += targets.size(0)

        epoch_val_loss = val_loss / val_total
        val_losses.append(epoch_val_loss)

        scheduler.step(epoch_val_loss)

        y_pred = []
        y_truth = []

        # model evaluation using test set
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device, non_blocking=True, dtype=torch.int64)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                y_pred.extend(predicted.cpu().tolist())
                y_truth.extend(targets.cpu().tolist())

        tst_acc.append(accuracy_score(y_truth, y_pred))
        tst_f1.append(f1_score(y_truth, y_pred, average=f1_average))

        early_stopping(epoch_val_loss, model)

        if early_stopping.early_stop:
            break

    # total training time = last entry of the cumulative per-epoch timings
    total_time = np.cumsum(timing)[-1]

    # epoch with the lowest validation loss
    best_ind = np.argmin(val_losses)

    return best_ind, tst_acc[best_ind], tst_f1[best_ind], total_time

In [7]:
# Datasets to evaluate; each is expected under ./dataset/<name>/ as
# data.npy, label.npy and concept_drifts.npy.
dataset_li = ['SEA', 'Hyperplane', 'RandomRBF', 'Sine', 'Electricity', 'Weather', 'Spam', 'Usenet1', 'Usenet2']

for dataset in dataset_li:
    # load data, label, and concept drift points
    x_all = np.load(f'./dataset/{dataset}/data.npy')
    y_all = np.load(f'./dataset/{dataset}/label.npy')
    concept_drifts = np.load(f'./dataset/{dataset}/concept_drifts.npy')

    # number of classes in dataset
    n_class = 5 if dataset == 'RandomRBF' else 2

    # number of nodes in hidden layer
    n_hidden = 256

    # number of maximum epochs
    num_epochs = 2000

    # number of input features
    n_feature = x_all.shape[1]

    # nominal number of samples per segment (one segment per concept drift point)
    segment_len = int(len(x_all) / len(concept_drifts))

    # number of available data in current segment
    if dataset in ['Spam', 'Usenet1', 'Usenet2']:
        n_train = int((len(x_all) / len(concept_drifts)) * 0.2)
    else:
        n_train = int((len(x_all) / len(concept_drifts)) * 0.1)

    # split available data into train and valid
    split1 = int(n_train * 0.5)

    # data preprocessing (scaling); the scaler is fitted on the whole dataset and
    # does not depend on the segment index, so it is computed once per dataset
    scaler = StandardScaler()
    x_all_scale = scaler.fit_transform(x_all)

    print('dataset: ', dataset)
    print('concept drifts: ', concept_drifts)
    print("----------------------------------------------------------------------------")
    print('method: Quilt')

    # per-segment statistics, each aggregated over the seeds
    all_time = []
    all_acc = []
    all_acc_std = []
    all_f1 = []
    all_f1_std = []

    # consecutive training and evaluation
    for n in range(len(concept_drifts)):
        n_dataset = n + 1

        all_time_li = []
        all_acc_li = []
        all_f1_li = []

        # repeat experiments with 5 different seeds
        for s in range(5):
            # initialize model
            # NOTE(review): this instance is replaced before training below and looks
            # redundant; kept in case constructing TwoLayerNet with seed `s` has
            # seeding side effects — confirm before removing.
            model = TwoLayerNet(n_feature, n_class, n_hidden, s)
            model = model.to(device)

            # split data into segments: index arrays of the n previous full segments
            # plus the first `split1` samples of the current segment
            if n == 0:
                group = [np.arange(split1)]
            else:
                arr = np.arange(n * segment_len)
                group = np.split(arr, n)
                last_batch = np.arange(n * segment_len, n * segment_len + split1)
                group.append(last_batch)

            # prepare data for train, valid, and test
            dataset_all = range(n_dataset)
            train_ds, valid_ds, test_ds, _ = prepare_data(n, n_train, x_all_scale, y_all, concept_drifts,
                                                          dataset_all, n_feature, device)

            trainloader = DataLoader(train_ds, batch_size=segment_len, shuffle=False)
            valloader = DataLoader(valid_ds, batch_size=128, shuffle=True)
            testloader = DataLoader(test_ds, batch_size=128, shuffle=True)

            # bayesian optimization to find disparity threshold value
            # (gain_th is pinned to 0 by its degenerate bounds; only disp_th is searched).
            # Quilt_cv reads its configuration from the globals assigned above.
            pbounds = {'gain_th': (0, 0), 'disp_th': (0, 2)}
            optimizer_bo = BayesianOptimization(f=Quilt_cv, pbounds=pbounds, random_state=s,
                                                allow_duplicate_points=True)
            optimizer_bo.maximize(init_points=10, n_iter=10)
            best_params = optimizer_bo.max['params']

            # initialize model
            model = TwoLayerNet(n_feature, n_class, n_hidden, s)
            model = model.to(device)

            # set model optimizer, scheduler, and earlystop path
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
            earlystop_path = f'./ckpt/{dataset}_Quilt.pt'

            # set dataloader parameters
            dss_args = dict(model=model,
                            loss=criterion_nored,
                            eta=0.001,
                            num_classes=n_class,
                            num_epochs=num_epochs,
                            device=device,
                            fraction=1.0,
                            init_budget=split1,
                            select_every=1,
                            kappa=0,
                            linear_layer=True,
                            selection_type='PerBatch',
                            groups=group,
                            x_all=x_all_scale,
                            y_all=y_all,
                            gain_th=best_params['gain_th'],
                            disp_th=best_params['disp_th'],
                            )
            dss_args = DotMap(dss_args)

            # define Quilt dataloader
            dataloader = QuiltDataLoader(trainloader, valloader, dss_args, batch_size=128, shuffle=True,
                                         pin_memory=False)

            # model training and evaluation
            best_ind, best_acc, best_f1, runtime = run_dss(dataloader, trainloader, valloader, testloader,
                                                           model, optimizer, scheduler, num_epochs,
                                                           earlystop_path=earlystop_path)

            all_time_li.append(runtime)
            all_acc_li.append(best_acc)
            all_f1_li.append(best_f1)

        # aggregate over seeds for this segment
        all_time.append(np.mean(all_time_li))
        all_acc.append(np.mean(all_acc_li))
        all_acc_std.append(np.std(all_acc_li))
        all_f1.append(np.mean(all_f1_li))
        all_f1_std.append(np.std(all_f1_li))

    # print runtime, accuracy, and F1 score
    # (the reported std is the across-seed std averaged over segments)
    print('overall train time: %.3f' %(np.mean(all_time)))
    print('overall test acc: avg %.3f, std %.3f' %(np.mean(all_acc), np.mean(all_acc_std)))
    print('overall test f1: avg %.3f, std %.3f' %(np.mean(all_f1), np.mean(all_f1_std)))

    print('\n')

dataset:  SEA
concept drifts:  [ 2000  4000  6000  8000 10000 12000 14000 16000]
----------------------------------------------------------------------------
method: Quilt
overall train time: 2.092
overall test acc: avg 0.889, std 0.004
overall test f1: avg 0.910, std 0.003


dataset:  Hyperplane
concept drifts:  [ 2000  4000  6000  8000 10000 12000 14000 16000]
----------------------------------------------------------------------------
method: Quilt
overall train time: 1.871
overall test acc: avg 0.910, std 0.010
overall test f1: avg 0.911, std 0.010


dataset:  RandomRBF
concept drifts:  [ 2000  4000  6000  8000 10000 12000 14000 16000]
----------------------------------------------------------------------------
method: Quilt
overall train time: 2.775
overall test acc: avg 0.832, std 0.007
overall test f1: avg 0.831, std 0.008


dataset:  Sine
concept drifts:  [ 2000  4000  6000  8000 10000 12000 14000 16000]
--------------------------------------------------------------------------