In [13]:
import torch
import torch.nn.functional as F
import os
import pandas as pd
import numpy as np
import warnings
import collections
from dataloader.dataloader import data_generator
from algorithms.RAINCOAT import RAINCOAT
from sklearn.metrics import f1_score, accuracy_score

### Dataset loading configurations

In [14]:
class WISDM_config():
    def __init__(self):
        super(WISDM_config, self)
        self.class_names = ['walk', 'jog', 'sit', 'stand', 'upstairs', 'downstairs']
        self.sequence_len = 128
        # Closed Set DA
        self.scenarios = [("2", "32"), ("4", "15"),("7", "30"),('12','7'), ('12','19'),('18','20'),\
                          ('20','30'), ("21", "31"),("25", "29"), ('26','2')]

        self.num_classes = 6
        self.shuffle = True
        self.drop_last = False
        self.normalize = True

        # model configs
        self.input_channels = 3
        self.kernel_size = 5
        self.stride = 1
        self.dropout = 0.5
        self.num_classes = 6
        self.width = 64  # for FNN
        self.fourier_modes = 64
        # features
        self.mid_channels = 64
        self.final_out_channels = 128
        self.out_dim = 192
        self.features_len = 1


data_path = './data/WISDM'
src_id = '2'
trg_id = '32'
dataset_configs = WISDM_config()
hparams = {"batch_size":64, 'learning_rate':1e-3,'weight_decay': 1e-4,'num_epochs': 100}

### Load Dataset

In [15]:
src_train_dl, src_test_dl = data_generator(data_path, src_id,dataset_configs, hparams)
trg_train_dl, trg_test_dl = data_generator(data_path, trg_id, dataset_configs,hparams)

In [16]:
def eval(algorithm, loader, final=False):
    device = 'cuda'
    feature_extractor = algorithm.feature_extractor.to(device)
    classifier = algorithm.classifier.to(device)
    if final == True:
        feature_extractor.load_state_dict(torch.load('backbone.pth'))
        classifier.load_state_dict(torch.load('classifier.pth'))
    feature_extractor.eval()
    classifier.eval()
    trg_pred_labels = np.array([])
    trg_true_labels = np.array([])
    with torch.no_grad():
        for data, labels in loader:
            data = data.float().to(device)
            labels = labels.view((-1)).long().to(device)
            features,_ = feature_extractor(data)
            predictions = classifier(features)

            pred = predictions.detach().argmax(dim=1)  # get the index of the max log-probability

            trg_pred_labels = np.append(trg_pred_labels, pred.cpu().numpy())
            trg_true_labels = np.append(trg_true_labels, labels.data.cpu().numpy())
    accuracy = accuracy_score(trg_true_labels, trg_pred_labels)
    f1 = f1_score(trg_pred_labels, trg_true_labels, pos_label=None, average="macro")
    return accuracy*100, f1

In [17]:

device = 'cuda'
algorithm = RAINCOAT(dataset_configs, hparams,device)
algorithm.to(device)
best_f1 = 0
for i in range(hparams['num_epochs']):
    joint_loaders = enumerate(zip(src_train_dl, trg_train_dl))
    for step, ((src_x, src_y), (trg_x, _)) in joint_loaders:
        src_x, src_y, trg_x = src_x.float().to(device), src_y.long().to(device), \
                                trg_x.float().to(device)
        losses = algorithm.update(src_x, src_y, trg_x)

    acc, f1 = eval(algorithm, src_train_dl)
    if f1>=best_f1:
        best_f1 = f1
        torch.save(algorithm.feature_extractor.state_dict(), 'backbone.pth')
        torch.save(algorithm.classifier.state_dict(), 'classifier.pth')
    if i %10==0:
        print(f'Epoch {i}: Validation Accuracy on Source Test is {acc}')

tar_acc, tar_f1 = eval(algorithm, trg_test_dl,final=True)
print(f'Target Accuracy before correction:{tar_acc}, Target F1:{tar_f1}')
### Correction 
for i in range(10):
    joint_loaders = enumerate(zip(src_train_dl, trg_train_dl))
    for step, ((src_x, src_y), (trg_x, _)) in joint_loaders:
        src_x, src_y, trg_x = src_x.float().to(device), src_y.long().to(device), \
                                trg_x.float().to(device)
        losses = algorithm.correct(src_x, src_y, trg_x)

    acc, f1 = eval(algorithm, src_train_dl)
    if f1>=best_f1:
        best_f1 = f1
        torch.save(algorithm.feature_extractor.state_dict(), 'backbone.pth')
        torch.save(algorithm.classifier.state_dict(), 'classifier.pth')
    if i %10==0:
        print(f'Epoch {i}: Validation Accuracy on Source Test is {acc}')
tar_acc, tar_f1 = eval(algorithm, trg_train_dl,final=True)
print(f'Target Accuracy after correction:{tar_acc}, Target F1:{tar_f1}')

Epoch 0: Validation Accuracy on Source Test is 25.78125
Epoch 10: Validation Accuracy on Source Test is 57.8125
Epoch 20: Validation Accuracy on Source Test is 68.75
Epoch 30: Validation Accuracy on Source Test is 86.71875
Epoch 40: Validation Accuracy on Source Test is 93.75
Epoch 50: Validation Accuracy on Source Test is 100.0
Epoch 60: Validation Accuracy on Source Test is 100.0
Epoch 70: Validation Accuracy on Source Test is 100.0
Epoch 80: Validation Accuracy on Source Test is 100.0
Epoch 90: Validation Accuracy on Source Test is 100.0
Target Accuracy before correction:72.46376811594203, Target F1:0.5899758454106281
Epoch 0: Validation Accuracy on Source Test is 100.0
Target Accuracy after correction:78.125, Target F1:0.7123949579831933
