# Importing Functions 

In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
import time 
from torch.autograd import grad
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch import Tensor
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fixed seed for the CPU generator and, when available, every CUDA generator.
seed = 0
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Models

In [2]:
class Encoder_AMDA(nn.Module):
    """1-D convolutional feature extractor followed by a linear projection.

    Five ``Conv1d`` stages (stride 2, padding 1), each followed by
    ``MaxPool1d(2)``; the first four stages also apply ``ReLU`` before the
    pooling, the last one does not. The resulting map is flattened and fed
    to ``fc1``.

    ``fc1`` accepts exactly 256 flattened features, i.e. 64 channels times a
    final length of 4. An input of length 5120 produces that; the length
    after each stage is then 1273, 315, 77, 18 and 4.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, relu_before_pool)
        stage_specs = (
            (1, 8, 32, True),
            (8, 16, 16, True),
            (16, 32, 8, True),
            (32, 32, 8, True),
            (32, 64, 3, False),
        )
        stages = []
        for c_in, c_out, k_size, with_relu in stage_specs:
            stages.append(
                nn.Conv1d(c_in, c_out, kernel_size=k_size, stride=2, padding=1))
            if with_relu:
                stages.append(nn.ReLU())
            stages.append(nn.MaxPool1d(2))
        self.encoder = nn.Sequential(*stages)
        # Linear projection of the flattened conv output.
        # Author's original note: "optimal when 0 source domain".
        self.fc1 = nn.Linear(256, 256)

    def forward(self, input):
        """Map a ``(batch, 1, length)`` signal to ``(batch, 256)`` features."""
        feature_map = self.encoder(input)
        # Keep the batch axis and collapse channels x length into one axis.
        flattened = feature_map.view(feature_map.shape[0], -1)
        return self.fc1(flattened)
class Classifier(nn.Module):
    """Classifier model for AMDA.

    Applies ReLU and dropout to a 256-dimensional feature vector and maps it
    to 3 output scores with a single linear layer.
    """

    def __init__(self):
        super().__init__()
        # Author's original note: "this with 0 as source it give optimal results".
        self.fc2 = nn.Linear(256, 3)

    def forward(self, feat):
        """Return ``(batch, 3)`` scores for ``(batch, 256)`` features."""
        activated = F.relu(feat)
        # Functional dropout does not follow train()/eval() on its own, so the
        # module's mode is passed explicitly.
        dropped = F.dropout(activated, training=self.training)
        return self.fc2(dropped)
class Discriminator(nn.Module):
    """Discriminator model for source domain.

    A three-layer MLP (Linear-ReLU-Linear-ReLU-Linear) followed by a
    log-softmax over the output units.

    Args:
        input_dims: size of each input feature vector.
        hidden_dims: width of the two hidden layers.
        output_dims: number of output units (domain classes).
    """

    def __init__(self, input_dims, hidden_dims, output_dims):
        super(Discriminator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, output_dims),
            # Normalise over the output units explicitly. Leaving ``dim``
            # unset relies on a deprecated implicit choice and emits a
            # warning on every forward pass. For the (batch, features)
            # inputs used here the result is unchanged.
            nn.LogSoftmax(dim=-1))

    def forward(self, input):
        """Return log-probabilities of shape ``(..., output_dims)``."""
        out = self.layer(input)
        return out

# Data Loading 

### Real KAt Data 

In [5]:
# Load the serialized dataset object from disk.
# NOTE(review): torch.load unpickles the file, so only load files from a
# trusted source.
real=torch.load('../data_5120L_new/real_domains_raw.pt')
# The loaded object unpacks into exactly four entries, one per domain
# (presumably the four working conditions A-D -- TODO confirm).
wk_cond_a_full,wk_cond_b_full,wk_cond_c_full,wk_cond_d_full=real

### Domain Shift Scenario

In [6]:
# Select the source and target domains.
# Each domain entry unpacks into four objects; judging by the names these are
# (train data, train labels, test data, test labels) -- TODO confirm.
# Domain A is used as the source; B, C and D are used as targets.
source_domain,source_labels,test_data,test_labels=wk_cond_a_full
target_domain,target_labels,target_test,target_test_labels=wk_cond_b_full
target_domain_1,target_labels_1,target_test_1,target_test_labels_1=wk_cond_c_full
target_domain_2,target_labels_2,target_test_2,target_test_labels_2=wk_cond_d_full

# Sizes read by the training / evaluation functions below.
# sample_length: size of axis 1 of the source training tensor.
sample_length=source_domain.size(1)
# num_samples: number of rows in the source training tensor.
num_samples=source_domain.size(0)
# num_test_samples: number of rows in the source test tensor.
num_test_samples=test_data.size(0)
# num_target_samples: number of rows in the domain-B test tensor. It is
# computed once here and is NOT refreshed if target_test is rebound later.
num_target_samples= target_test.size(0)

# Parameters

In [7]:
"""Params for AMDA."""
# params for setting up models
# Sizes passed to Discriminator(input_dims, hidden_dims, output_dims).
d_input_dims = 256
d_hidden_dims = 256
d_output_dims = 2
# params for training network
# *_pre values are read by train_src (source pre-training):
#   num_epochs_pre - number of pre-training epochs
#   log_step_pre   - print loss/accuracy every this many epochs
#   eval_step_pre  - call eval_src every this many epochs
#   save_step_pre  - not read anywhere in the code shown in this notebook
num_epochs_pre = 30 
log_step_pre = 5
eval_step_pre = 20
save_step_pre = 100
# Values read by train_tgt (adversarial adaptation):
#   num_epochs - number of adaptation epochs
#   log_step   - print/evaluate every this many epochs
#   save_step  - not read anywhere in the code shown in this notebook
num_epochs = 5 
log_step = 1
save_step = 100
# params for optimizing models
# d_learning_rate: Adam lr for the critic in train_tgt.
d_learning_rate = 1e-4
# c_learning_rate: Adam lr for the target encoder in train_tgt.
c_learning_rate = 1e-4
# c_init_learning_rate: Adam lr for encoder + classifier in train_src.
c_init_learning_rate = 1e-4
# Adam betas used by every optimizer in this notebook.
beta1 = 0.5
beta2 = 0.9
# Mini-batch size used by all training and evaluation loops.
bs=100 #20

# Pre-Training 

In [8]:
def train_src(encoder, classifier):
    """Train classifier for source domain.

    Jointly optimises ``encoder`` and ``classifier`` with Adam on the
    cross-entropy loss over the module-level ``source_domain`` /
    ``source_labels`` tensors, for ``num_epochs_pre`` epochs with mini-batches
    of size ``bs``. Every ``log_step_pre`` epochs the mean loss and accuracy
    are printed; every ``eval_step_pre`` epochs ``eval_src`` is called.

    Args:
        encoder: feature extractor to train (updated in place).
        classifier: classification head to train (updated in place).

    Returns:
        The same ``(encoder, classifier)`` objects that were passed in.
    """
    ####################
    # 1. setup network #
    ####################
    # setup criterion and optimizer
    optimizer = optim.Adam(
        list(encoder.parameters()) + list(classifier.parameters()),
        lr=c_init_learning_rate,
        betas=(beta1, beta2))
    criterion = nn.CrossEntropyLoss()
    ####################
    # 2. train network #
    ####################
    t0 = time.time()
    for epoch in range(num_epochs_pre):
        # Set train state at the start of every epoch and on the modules that
        # were actually passed in. eval_src() below leaves both modules in
        # eval mode, which would otherwise keep dropout disabled for all
        # remaining epochs.
        encoder.train()
        classifier.train()

        running_loss = 0
        running_accuracy = 0
        num_batches = 0
        # fresh shuffle of the source samples for this epoch
        shuffled_indices = torch.randperm(num_samples)
        for step in range(0, num_samples, bs):
            indices = shuffled_indices[step:step + bs]
            # add the channel axis expected by Conv1d and move to the device
            minibatch_data = source_domain[indices].unsqueeze(dim=1).to(device)
            minibatch_label = source_labels[indices].squeeze().to(device)

            # zero gradients for optimizer
            optimizer.zero_grad()

            # forward pass and loss
            preds = classifier(encoder(minibatch_data.float()))
            loss = criterion(preds, minibatch_label)

            # optimize encoder and source classifier
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_accuracy += (preds.max(1)[1] == minibatch_label).float().mean().item()
            num_batches += 1

        # print epoch info
        if epoch % log_step_pre == 0:
            # avoid a ZeroDivisionError when there are no source samples
            denom = max(num_batches, 1)
            print("Epoch [{}/{}] : loss={} train_accuracy={}"
                  .format(epoch,
                          num_epochs_pre,
                          running_loss / denom,
                          running_accuracy * 100 / denom))

        print('{} seconds'.format(time.time() - t0))
        # eval model on test set
        if (epoch + 1) % eval_step_pre == 0:
            eval_src(encoder, classifier)
    return encoder, classifier


def eval_src(encoder, classifier):
    """Evaluate classifier for source domain.

    Runs ``classifier(encoder(...))`` over the module-level ``test_data`` /
    ``test_labels`` tensors in shuffled mini-batches of size ``bs`` and prints
    the mean per-batch loss and accuracy.

    Side effect: both modules are switched to eval mode and left there.

    Args:
        encoder: feature extractor to evaluate.
        classifier: classification head to evaluate.

    Returns:
        ``(x, y)``: two lists with one tensor per batch, holding the predicted
        class indices and the corresponding labels.
    """
    # set eval state for Dropout and BN layers
    encoder.eval()
    classifier.eval()
    # running statistics
    acc = 0
    run_loss = 0
    num_batches = 0
    # set loss function
    criterion = nn.CrossEntropyLoss()
    x = []
    y = []
    shuffled_indices = torch.randperm(num_test_samples)

    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and matches eval_tgt.
    with torch.no_grad():
        for i in range(0, num_test_samples, bs):
            indices = shuffled_indices[i:i + bs]
            minibatch_data = test_data[indices].unsqueeze(dim=1).to(device)
            minibatch_label = test_labels[indices].squeeze().to(device)

            # forward pass
            scores = classifier(encoder(minibatch_data.float()))
            predicted = scores.max(1)[1]

            # accumulate accuracy and loss
            acc += (predicted == minibatch_label).float().mean().item()
            run_loss += criterion(scores, minibatch_label).item()
            x.append(predicted)
            y.append(minibatch_label)
            num_batches += 1

    # avoid a ZeroDivisionError when the test set is empty
    denom = max(num_batches, 1)
    mean_accuracy = acc / denom
    mean_loss = run_loss / denom

    print("Avg Loss = {}, Avg Accuracy = {}".format( mean_loss, mean_accuracy*100))
    return (x,y)

### Adaptation Function

In [9]:
def train_tgt(src_encoder, tgt_encoder, critic):
    """Train encoder for target domain.

    Adversarial adaptation loop. For every mini-batch:

    1. The critic is trained to tell source features (label 1, produced by
       ``src_encoder``) from target features (label 0, produced by
       ``tgt_encoder`` on three target domains).
    2. ``tgt_encoder`` is trained with flipped labels (all 1) so that its
       features are classified as "source" by the critic.

    Data comes from the module-level tensors ``source_domain``,
    ``target_domain``, ``target_domain_1`` and ``target_domain_2``. Every
    ``log_step`` epochs the running statistics are printed and ``eval_tgt`` is
    run for both encoders with the module-level ``src_classifier``.

    Args:
        src_encoder: source encoder; only used for forward passes here,
            no optimizer updates it.
        tgt_encoder: target encoder, optimised in place.
        critic: domain discriminator, optimised in place.

    Returns:
        The trained ``tgt_encoder`` (same object that was passed in).
    """
    ####################
    # 1. setup network #
    ####################
    # setup criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer_tgt = optim.Adam(tgt_encoder.parameters(),
                               lr=c_learning_rate,
                               betas=(beta1, beta2))
    optimizer_critic = optim.Adam(critic.parameters(),
                                  lr=d_learning_rate,
                                  betas=(beta1, beta2))
    # number of target-encoder updates per critic update
    tgt_steps_per_batch = 1
    t1 = time.time()

    ####################
    # 2. train network #
    ####################
    for epoch in range(num_epochs):
        # Set train state every epoch: eval_tgt() at the end of a logged
        # epoch switches tgt_encoder to eval mode.
        tgt_encoder.train()
        critic.train()

        # initialize running statistics
        run_critic_loss = 0
        run_tgt_loss = 0
        acc = 0
        num_batches = 0
        shuffled_indices = torch.randperm(num_samples)
        for start in range(0, num_samples, bs):
            ###########################
            # 2.1 train discriminator #
            ###########################
            indices = shuffled_indices[start:start + bs]
            # Source and first target domain are sampled with the shuffled
            # indices; the other two target domains are read as contiguous
            # slices (kept exactly as in the original code).
            sample_src = source_domain[indices].float().unsqueeze(dim=1).to(device)
            sample_tgt_1 = target_domain[indices].float().unsqueeze(dim=1).to(device)
            sample_tgt_2 = target_domain_1[start:start + bs].float().unsqueeze(dim=1).to(device)
            sample_tgt_3 = target_domain_2[start:start + bs].float().unsqueeze(dim=1).to(device)

            # zero gradients for optimizer
            optimizer_critic.zero_grad()

            # Extract and concat features. The critic step must not update
            # the encoders, so the features are computed without a graph
            # (equivalent to computing them and then calling .detach()).
            with torch.no_grad():
                feat_src = src_encoder(sample_src)
                feat_tgt = torch.cat((tgt_encoder(sample_tgt_1),
                                      tgt_encoder(sample_tgt_2),
                                      tgt_encoder(sample_tgt_3)), 0)
                feat_concat = torch.cat((feat_src, feat_tgt), 0)

            # predict on discriminator
            pred_concat = critic(feat_concat)

            # domain labels: 1 = source, 0 = target
            label_src = torch.ones(feat_src.size(0), dtype=torch.long, device=device)
            label_tgt = torch.zeros(feat_tgt.size(0), dtype=torch.long, device=device)
            label_concat = torch.cat((label_src, label_tgt), 0)

            # compute loss for critic
            loss_critic = criterion(pred_concat, label_concat)
            loss_critic.backward()

            # optimize critic
            optimizer_critic.step()

            pred_cls = pred_concat.max(1)[1]
            # Accumulate a Python float; indexing a 0-dim tensor with
            # ``.data[0]`` raises on current PyTorch versions.
            acc += (pred_cls == label_concat).float().mean().item()
            run_critic_loss += loss_critic.item()

            ############################
            # 2.2 train target encoder #
            ############################
            # Gradients are zeroed once before the encoder steps, as in the
            # original code.
            optimizer_tgt.zero_grad()
            for _ in range(tgt_steps_per_batch):
                # extract target features (this time with gradients)
                feat_tgt = torch.cat((tgt_encoder(sample_tgt_1),
                                      tgt_encoder(sample_tgt_2),
                                      tgt_encoder(sample_tgt_3)), 0)

                # predict on discriminator
                pred_tgt = critic(feat_tgt)

                # prepare fake labels to enforce the feature extractor to confuse the critic
                flipped_labels = torch.ones(feat_tgt.size(0), dtype=torch.long, device=device)

                # compute loss for target encoder
                loss_tgt = criterion(pred_tgt, flipped_labels)
                loss_tgt.backward()

                # optimize target encoder
                optimizer_tgt.step()

                run_tgt_loss += loss_tgt.item()
            num_batches += 1

        ########################
        # 2.3 print epoch info #
        ########################
        print('{} seconds'.format(time.time() - t1))
        if epoch % log_step == 0:
            # avoid a ZeroDivisionError when there are no samples
            denom = max(num_batches, 1)
            print("Epoch [{}/{}] :"
                  "discriminator_loss={:.5f} target_loss={:.5f} discriminator_acc={:.5f}"
                  .format(epoch,
                          num_epochs,
                          run_critic_loss / denom,
                          # average over the encoder steps actually taken
                          run_tgt_loss / (denom * tgt_steps_per_batch),
                          acc / denom))
            print("=== Evaluating classifier for encoded target domain ===")
            print(">>> source only <<<")
            # NOTE: src_classifier is read from module scope, not passed in.
            eval_tgt(src_encoder, src_classifier)
            print(">>> domain adaption <<<")
            eval_tgt(tgt_encoder, src_classifier)
    return tgt_encoder

def eval_tgt(encoder, classifier):
    """Evaluation for target encoder by source classifier on target dataset.

    Runs ``classifier(encoder(...))`` over the module-level ``target_test`` /
    ``target_test_labels`` tensors in shuffled mini-batches of size ``bs`` and
    prints the mean per-batch accuracy.

    Side effect: both modules are switched to eval mode and left there.

    Args:
        encoder: feature extractor to evaluate.
        classifier: classification head to evaluate.

    Returns:
        ``(x, y)``: two lists with one tensor per batch, holding the predicted
        class indices and the corresponding labels.
    """
    # set eval state for Dropout and BN layers
    encoder.eval()
    classifier.eval()

    # Size the evaluation from the tensor that is actually evaluated. The
    # global ``num_target_samples`` is computed once for one domain and goes
    # stale when ``target_test`` is rebound to another domain, which would
    # evaluate only part of a larger set or index past the end of a smaller
    # one.
    n_eval = target_test.size(0)

    acc = 0
    num_batches = 0
    x = []
    y = []
    shuffled_indices_T = torch.randperm(n_eval)
    # evaluate network
    with torch.no_grad():
        for i in range(0, n_eval, bs):
            indices_T = shuffled_indices_T[i:i + bs]
            minibatch_data = target_test[indices_T].unsqueeze(dim=1).to(device)
            minibatch_label = target_test_labels[indices_T].squeeze().to(device)
            # cast to float32 like every other forward pass in this notebook
            scores = classifier(encoder(minibatch_data.float()))
            predicted = scores.max(1)[1]
            # calculate accuracy
            acc += (predicted == minibatch_label).float().mean().item()
            x.append(predicted)
            y.append(minibatch_label)
            num_batches += 1

    # report NaN instead of raising ZeroDivisionError on an empty test set
    mean_accuracy = acc / num_batches if num_batches else float('nan')
    print("Avg Accuracy = {:2%}".format(mean_accuracy))
    return x,y

#  Main Code  

In [None]:
# Build the models and move them to the selected device.
# Source encoder + classifier: trained on the labelled source domain.
src_encoder = Encoder_AMDA().to(device)

src_classifier= Classifier().to(device)

# Target encoder: same architecture, adapted adversarially below.
tgt_encoder = Encoder_AMDA().to(device)

# Domain critic used by train_tgt.
critic = Discriminator(input_dims=d_input_dims,
                                  hidden_dims=d_hidden_dims,
                                  output_dims=d_output_dims).to(device)
# Step 1: supervised pre-training on the source domain.
src_encoder, src_classifier = train_src(
        src_encoder, src_classifier)
# Alternative to step 1: restore previously saved source weights.
# src_encoder.load_state_dict(torch.load('src_enc_wk_d_temp.pt'))
# src_classifier.load_state_dict(torch.load('classifier_wk_d_temp.pt'))

# Step 2: start the target encoder from the trained source encoder weights,
# then run the adversarial adaptation.
tgt_encoder.load_state_dict(src_encoder.state_dict())
tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic)

# Step 3: evaluate on the target test set, once with the source encoder and
# once with the adapted target encoder, both using the source classifier.
print("=== Evaluating classifier for encoded target domain ===")
print(">>> source only <<<")
eval_tgt(src_encoder, src_classifier)
print(">>> domain adaption <<<")
eval_tgt(tgt_encoder, src_classifier)


Epoch [0/50] : loss=1.6868433492879074 train_accuracy=45.362846087664366
1.1918480396270752 seconds
2.310789108276367 seconds
3.4039793014526367 seconds
4.45129132270813 seconds
5.5062174797058105 seconds
Epoch [5/50] : loss=0.2215828502861162 train_accuracy=91.22829648355643
6.593669414520264 seconds
7.643713474273682 seconds
8.696155548095703 seconds
9.748639583587646 seconds
10.799995183944702 seconds
Epoch [10/50] : loss=0.006018737720296485 train_accuracy=99.84374927977721
11.857589960098267 seconds
12.909178018569946 seconds
13.962198495864868 seconds
15.02686619758606 seconds
16.080912351608276 seconds
Epoch [15/50] : loss=0.0010680420701968767 train_accuracy=99.97916656235854
17.131296634674072 seconds
18.20425057411194 seconds
19.25678515434265 seconds
20.310431718826294 seconds
21.36260151863098 seconds
Avg Loss = 1.1473339294220417e-05, Avg Accuracy = 100.0
Epoch [20/50] : loss=0.00014957914743263245 train_accuracy=100.0
22.609604835510254 seconds


### Test on Multiple Domains

In [14]:
# Restore previously saved weights. ``map_location`` remaps the stored tensors
# onto the device selected at the top of the notebook, so checkpoints written
# on a GPU also load on a CPU-only machine.
src_encoder.load_state_dict(
    torch.load('src_enc_wk_b_M_ADDA_new.pt', map_location=device))
# NOTE(review): "_mew" may be a typo for "_new"; the file name is kept as
# written because the file on disk cannot be checked from here.
src_classifier.load_state_dict(
    torch.load('classifier_wk_b_M_ADDA_mew.pt', map_location=device))
tgt_encoder.load_state_dict(
    torch.load('tgt_enc_wk_b_M_ADDA_new.pt', map_location=device))

# Evaluate both encoders on each target domain in turn. eval_tgt reads the
# module-level target_test / target_test_labels, so those are rebound per
# domain, and num_target_samples is refreshed with them so that it always
# matches the tensor being evaluated.
for banner, domain_split in (
        (">>> domain adaption_B <<<", wk_cond_b_full),
        (">>> domain adaption _C<<<", wk_cond_c_full),
        (">>> domain adaption _D<<<", wk_cond_d_full)):
    target_domain,target_labels,target_test,target_test_labels=domain_split
    num_target_samples = target_test.size(0)

    print(banner)
    print(">>> source only <<<")
    x,y=eval_tgt(src_encoder, src_classifier)
    print(">>> domain adaption <<<")
    x,y=eval_tgt(tgt_encoder, src_classifier)


>>> domain adaption_B <<<
>>> source only <<<
Avg Accuracy = 81.833508%
>>> domain adaption <<<
Avg Accuracy = 98.308587%
>>> domain adaption _C<<<
>>> source only <<<
Avg Accuracy = 85.876986%
>>> domain adaption <<<
Avg Accuracy = 96.997876%
>>> domain adaption _D<<<
>>> source only <<<
Avg Accuracy = 86.635205%
>>> domain adaption <<<
Avg Accuracy = 100.000000%
