In [12]:
import torch
import torch.nn.functional as F
import os
import pandas as pd
import numpy as np
import warnings
import collections
from dataloader.dataloader import data_generator
from algorithms.RAINCOAT import RAINCOAT
from sklearn.metrics import f1_score, accuracy_score

### Dataset loading configurations

In [14]:
class HAR_config():
    """Static dataset/model settings for the HAR dataset used in this notebook.

    An instance is passed to ``data_generator`` and to ``RAINCOAT`` as the
    dataset configuration object. All values are plain attributes set in
    ``__init__``; the class has no other behaviour.
    """

    def __init__(self):
        # The original `super(HAR_config, self)` only built a proxy object and
        # threw it away; actually invoke the parent initialiser.
        super().__init__()

        # --- Label space -------------------------------------------------
        # Six activity names; `num_classes` below matches this list's length.
        self.class_names = ['walk', 'upstairs', 'downstairs', 'sit', 'stand', 'lie']
        self.num_classes = 6

        # --- Data loading ------------------------------------------------
        # NOTE(review): presumably consumed by `data_generator` — confirm there.
        self.sequence_len = 128
        self.shuffle = True
        self.drop_last = True
        self.normalize = True

        # --- Model hyper-parameters --------------------------------------
        # NOTE(review): presumably read by the RAINCOAT backbone/classifier;
        # the exact meaning of each value is defined in `algorithms.RAINCOAT`.
        self.input_channels = 9
        self.kernel_size = 5
        self.stride = 1
        self.dropout = 0.5
        self.fourier_modes = 64
        self.out_dim = 192
        self.mid_channels = 64
        self.final_out_channels = 128
        self.features_len = 1

# Experiment settings: which HAR subjects act as source / target domain,
# where the data lives, and the optimisation hyper-parameters.
dataset_configs = HAR_config()

data_path = './data/HAR'
src_id, trg_id = '2', '11'

hparams = dict(
    batch_size=64,
    learning_rate=5e-4,
    weight_decay=1e-4,
    num_epochs=50,
)

### Load Dataset

In [9]:
# Each `data_generator` call returns a pair that is unpacked into a train
# loader and a test loader; later cells iterate these as (data, labels) batches.
# Source domain (subject id `src_id`):
src_train_dl, src_test_dl = data_generator(data_path, src_id,dataset_configs, hparams)
# Target domain (subject id `trg_id`); the training loop ignores its labels.
trg_train_dl, trg_test_dl = data_generator(data_path, trg_id, dataset_configs,hparams)

In [16]:
def eval(algorithm, loader, final=False):
    """Score ``algorithm`` on ``loader`` and return ``(accuracy_percent, macro_f1)``.

    Note: this function shadows the built-in ``eval`` within the notebook.

    Args:
        algorithm: object exposing ``feature_extractor`` and ``classifier``
            modules. The feature extractor is called on a batch and must
            return a 2-tuple; its first element is fed to the classifier.
        loader: iterable yielding ``(data, labels)`` tensor batches.
        final: accepted for backward compatibility; currently has no effect.

    Returns:
        tuple: ``(accuracy * 100, macro-averaged F1)``.

    Side effects:
        Both modules are moved to the evaluation device (CUDA when available,
        otherwise CPU). Their train/eval mode is restored on exit, so calling
        this between training epochs does not leave dropout/batch-norm layers
        stuck in inference mode.
    """
    # Fall back to CPU instead of crashing on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    feature_extractor = algorithm.feature_extractor.to(device)
    classifier = algorithm.classifier.to(device)

    # Remember the per-module training flag so it can be restored exactly,
    # even if some sub-modules were deliberately kept in a different mode.
    tracked_modules = list(feature_extractor.modules()) + list(classifier.modules())
    previous_modes = [module.training for module in tracked_modules]

    feature_extractor.eval()
    classifier.eval()

    # Collect per-batch arrays and concatenate once; np.append inside the
    # loop re-copies the whole array on every batch.
    pred_batches, true_batches = [], []
    try:
        with torch.no_grad():
            for data, labels in loader:
                data = data.float().to(device)
                labels = labels.view(-1).long()
                features, _ = feature_extractor(data)
                predictions = classifier(features)
                # Index of the largest score per sample is the predicted class.
                pred_batches.append(predictions.argmax(dim=1).cpu().numpy())
                true_batches.append(labels.cpu().numpy())
    finally:
        for module, was_training in zip(tracked_modules, previous_modes):
            module.training = was_training

    # An empty loader (e.g. drop_last=True with fewer samples than one batch)
    # yields empty arrays, matching what the original implementation scored.
    trg_pred_labels = np.concatenate(pred_batches) if pred_batches else np.array([])
    trg_true_labels = np.concatenate(true_batches) if true_batches else np.array([])

    accuracy = accuracy_score(trg_true_labels, trg_pred_labels)
    # scikit-learn's signature is (y_true, y_pred); pos_label is ignored for
    # average="macro", so it is omitted.
    f1 = f1_score(trg_true_labels, trg_pred_labels, average="macro")
    return accuracy * 100, f1

In [19]:

# Train RAINCOAT on paired source/target batches, keep the checkpoint with the
# best source-test accuracy, then report performance on the target test set.

# Use the GPU when present; fall back to CPU so the cell still runs elsewhere.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
backbone_ckpt = 'backbone.pth'
classifier_ckpt = 'classifier.pth'

algorithm = RAINCOAT(dataset_configs, hparams, device)
algorithm.to(device)

best_acc = 0
checkpoint_saved = False  # only reload weights written during *this* run
for i in range(hparams['num_epochs']):
    # The per-epoch evaluation below switches the networks to eval mode;
    # make sure every epoch actually trains with dropout etc. enabled.
    algorithm.train()

    # zip stops at the shorter loader; target labels are deliberately unused.
    for (src_x, src_y), (trg_x, _) in zip(src_train_dl, trg_train_dl):
        src_x = src_x.float().to(device)
        src_y = src_y.long().to(device)
        trg_x = trg_x.float().to(device)
        losses = algorithm.update(src_x, src_y, trg_x)

    acc, f1 = eval(algorithm, src_test_dl)
    if acc >= best_acc:
        best_acc = acc
        torch.save(algorithm.feature_extractor.state_dict(), backbone_ckpt)
        torch.save(algorithm.classifier.state_dict(), classifier_ckpt)
        checkpoint_saved = True
    print(f'Epoch {i}: Validation Accuracy on Source Test is {acc}')

# Evaluate the *selected* model: previously the best checkpoint was written to
# disk but the last-epoch weights were the ones scored on the target domain.
if checkpoint_saved:
    algorithm.feature_extractor.load_state_dict(
        torch.load(backbone_ckpt, map_location=device))
    algorithm.classifier.load_state_dict(
        torch.load(classifier_ckpt, map_location=device))

tar_acc, tar_f1 = eval(algorithm, trg_test_dl, final=True)
print(f'Target Test Accuracy: {tar_acc:.2f} | Macro F1: {tar_f1:.4f}')

Validation Accuracy on Source Test:10.9375
Validation Accuracy on Source Test:25.0
Validation Accuracy on Source Test:34.375
Validation Accuracy on Source Test:48.4375
Validation Accuracy on Source Test:73.4375
Validation Accuracy on Source Test:70.3125
Validation Accuracy on Source Test:79.6875
Validation Accuracy on Source Test:82.8125
Validation Accuracy on Source Test:84.375
Validation Accuracy on Source Test:95.3125
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:95.3125
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0
Validation Accuracy on Source Test:100.0

KeyboardInterrupt: 