In [1]:

from utils.path_utils import project_root

import os

import torch

import numpy as np
import pandas as pd

import tqdm


In [37]:
from utils.pretrain_utils.data import get_pretrain_finetune_test_datasets

# Load the pre-training / fine-tuning / test splits via the project helper.
# NOTE(review): the "Datasetup" cell further down rebinds pt_train, finetune and test (and adds
# pt_val) from the notebook-local loader, so which data is live depends on execution order.
pt_train, finetune, test = get_pretrain_finetune_test_datasets()

In [2]:
# Submit the evaluation job script to SLURM (IPython shell escape; needs `sbatch` on PATH).
!sbatch eval_simmtm.job

In [3]:
-

# Args

In [None]:
import argparse

# Command-line style configuration for the experiment. parse_known_args() is used so that
# unrecognised argv entries (e.g. whatever the kernel process was launched with) are
# collected in `unknown` instead of aborting the parse.
home_dir = os.getcwd()
parser = argparse.ArgumentParser()

# (flag, add_argument keyword arguments), in registration order.
_ARG_SPECS = (
    # Experiment identity / reproducibility
    ('--run_description', dict(default='run1', type=str, help='Experiment Description')),
    ('--seed', dict(default=2023, type=int, help='seed value')),
    # Mode and datasets
    ('--training_mode', dict(default='pre_train', type=str, help='pre_train, fine_tune')),
    ('--pretrain_dataset', dict(default='SleepEEG', type=str,
                                help='Dataset of choice: SleepEEG, FD_A, HAR, ECG')),
    ('--target_dataset', dict(default='Epilepsy', type=str,
                              help='Dataset of choice: Epilepsy, FD_B, Gesture, EMG')),
    # I/O, logging and optimisation
    ('--logs_save_dir', dict(default='experiments_logs', type=str, help='saving directory')),
    ('--device', dict(default='cuda', type=str, help='cpu or cuda')),
    ('--home_path', dict(default=home_dir, type=str, help='Project home directory')),
    ('--subset', dict(action='store_true', default=False, help='use the subset of datasets')),
    ('--log_epoch', dict(default=5, type=int, help='print loss and metrix')),
    ('--draw_similar_matrix', dict(default=10, type=int, help='draw similarity matrix')),
    ('--pretrain_lr', dict(default=0.0001, type=float, help='pretrain learning rate')),
    ('--lr', dict(default=0.0001, type=float, help='learning rate')),
    ('--use_pretrain_epoch_dir', dict(default=None, type=str,
                                      help='choose the pretrain checkpoint to finetune')),
    ('--pretrain_epoch', dict(default=10, type=int, help='pretrain epochs')),
    ('--finetune_epoch', dict(default=300, type=int, help='finetune epochs')),
    # Masking / contrastive objective
    ('--masking_ratio', dict(default=0.5, type=float, help='masking ratio')),
    ('--positive_nums', dict(default=3, type=int, help='positive series numbers')),
    ('--lm', dict(default=3, type=int, help='average masked lenght')),
    ('--finetune_result_file_name', dict(default="finetune_result.json", type=str,
                                         help='finetune result json name')),
    ('--temperature', dict(type=float, default=0.2, help='temperature')),
)

for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)

args, unknown = parser.parse_known_args()


# Convert per-patient tables to padded PyTorch tensors

In [None]:

def csv_to_pt(patient_files, lengths, is_sepsis, desc, max_time_step=336):
    """Zero-pad per-patient time series to a common length and stack them into tensors.

    Args:
        patient_files: sequence of 2-D array-likes (DataFrame or ndarray), one per patient,
            shaped (time_steps, features). Padding is appended along axis 0 (time).
        lengths: per-patient expected lengths, parallel to ``patient_files``. Only used in
            error messages here; returned unchanged.
        is_sepsis: per-patient labels, parallel to ``patient_files``. Each entry is converted
            with ``torch.tensor(..., dtype=float32)`` and gets a leading axis before concatenation.
        desc: label shown on the tqdm progress bar.
        max_time_step: length every series is padded to (default 336).

    Returns:
        ({'samples': float32 tensor (N, max_time_step, features),
          'labels': float32 tensor (N, *label_shape)}, lengths, is_sepsis)

    Raises:
        ValueError: if a series is longer than ``max_time_step`` (it cannot be padded and is
            not silently truncated), or if ``patient_files`` is empty.
    """
    samples = []
    labels = []

    for idx, (file, length, sepsis) in tqdm.tqdm(enumerate(zip(patient_files, lengths, is_sepsis)),
                                                 desc=f"{desc}",
                                                 total=len(patient_files)):
        n_steps = len(file)
        # Validate before padding: np.pad with a negative pad width fails with an opaque
        # numpy error, so an informative message has to be raised up front.
        if n_steps > max_time_step:
            raise ValueError(
                f"Patient {idx} has {n_steps} time steps (expected length {length}), "
                f"which exceeds max_time_step={max_time_step}")

        pad_width = ((0, max_time_step - n_steps), (0, 0))
        padded = np.pad(np.asarray(file), pad_width=pad_width, mode='constant').astype(np.float32)

        samples.append(torch.from_numpy(padded).unsqueeze(0))
        labels.append(torch.tensor(sepsis, dtype=torch.float32).unsqueeze(0))

    if not samples:
        # torch.cat would otherwise fail with a less descriptive error on an empty list.
        raise ValueError(f"{desc}: no patient files to convert")

    return {'samples': torch.cat(samples, dim=0), 'labels': torch.cat(labels, dim=0)}, lengths, is_sepsis

# all_patients, lengths, is_sepsis = csv_to_pt()


In [None]:
from sklearn.model_selection import train_test_split

def get_train_val_test_indices(sepsis_file, dset, save_distributions=True, test_size=0.2, random_state=2024):
    """Randomly split the rows of a per-patient sepsis label file into two index sets.

    Despite the name, only two sets are returned. In this notebook they are used either as
    (train, validation) for pre-training or as (test, fine-tune) for set B.

    Args:
        sepsis_file: file name under <project_root>/data/tl_datasets, read as a header-less CSV
            with one row per patient.
        dset: suffix of the report file results/distributions{dset}.csv.
        save_distributions: if True, write per-label counts and proportions of both sets.
        test_size: size of the second set; forwarded to ``train_test_split``.
        random_state: seed forwarded to ``train_test_split``. The defaults reproduce the
            previously hard-coded split exactly.

    Returns:
        (first_indices, second_indices): numpy arrays of row labels of the label file.
    """
    sepsis = pd.read_csv(os.path.join(project_root(), 'data', 'tl_datasets', f'{sepsis_file}'), header=None)

    first_rows, second_rows = train_test_split(sepsis, test_size=test_size, random_state=random_state)
    train_indices = first_rows.index.values
    val_indices = second_rows.index.values

    if save_distributions:
        train_dist = sepsis.iloc[train_indices].value_counts()
        val_dist = sepsis.iloc[val_indices].value_counts()

        train_dist_percentage = np.round(train_dist / len(train_indices), 2)
        val_dist_percentage = np.round(val_dist / len(val_indices), 2)

        results_dir = os.path.join(project_root(), 'results')
        os.makedirs(results_dir, exist_ok=True)
        # Keep the index: it is the label value each count refers to. With index=False the
        # report lost which class every row described.
        pd.DataFrame(
            {
                'Train Images': train_dist, 'Validation Images': val_dist,
                'Train Distribution Percentage': train_dist_percentage,
                'Validation Distribution Percentage': val_dist_percentage,
            }
        ).to_csv(os.path.join(results_dir, f'distributions{dset}.csv'), index=True, index_label='label')

    return train_indices, val_indices


# Data setup

In [None]:
def _load_checked_split_files(pickle_name, lengths_name, sepsis_name, desc):
    """Load per-patient frames plus their length/label files and verify they line up row by row."""
    data_dir = os.path.join(project_root(), 'data', 'tl_datasets')
    # NOTE: pd.read_pickle can execute arbitrary code; only load trusted, locally produced files.
    frames = pd.read_pickle(os.path.join(data_dir, pickle_name))
    lengths = pd.read_csv(os.path.join(data_dir, lengths_name), header=None)
    sepsis = pd.read_csv(os.path.join(data_dir, sepsis_name), header=None)

    files = []
    for pdata, length in tqdm.tqdm(zip(frames, lengths.values), desc=desc, total=len(frames)):
        plength = len(pdata)
        # Explicit raise instead of `assert`, so the integrity check survives `python -O`.
        if plength != length[0]:
            raise ValueError(f"{plength} doesn't match {length}")
        files.append(pdata.drop(['PatientID', 'SepsisLabel'], axis=1))
    return files, lengths, sepsis


def _select_and_convert(files, lengths, sepsis, indices, desc):
    """Pick the patients at `indices` and convert them to padded tensors with csv_to_pt."""
    subset = [files[i] for i in indices]
    dataset, _, _ = csv_to_pt(subset, lengths.iloc[indices].values, sepsis.iloc[indices].values, desc=desc)
    return dataset


def get_pretrain_finetune_datasets():
    """Build the pre-training (set A) and fine-tuning/test (set B) tensor datasets.

    Returns:
        (pt_train, pt_val, finetune, test): dicts with 'samples' and 'labels' tensors as
        produced by csv_to_pt.
    """
    # Pre-training: set A split into train / validation.
    pt_train_indices, pt_val_indices = get_train_val_test_indices(
        sepsis_file='is_sepsis_pretrain_A.txt', save_distributions=True,
        dset='Aa')
    pt_files, pt_lengths, pt_sepsis = _load_checked_split_files(
        'final_dataset_pretrain_A.pickle', 'lengths_pretrain_A.txt', 'is_sepsis_pretrain_A.txt',
        desc="Checking Pre-training & Validation Files")
    pt_train = _select_and_convert(pt_files, pt_lengths, pt_sepsis, pt_train_indices, desc='PT Train Set')
    pt_val = _select_and_convert(pt_files, pt_lengths, pt_sepsis, pt_val_indices, desc='PT Validation Set')

    # Fine-tuning: the first index set returned for set B is held out as the test set,
    # the second is used for fine-tuning.
    test_indices, finetune_indices = get_train_val_test_indices(
        sepsis_file='is_sepsis_finetune_B.txt', save_distributions=True,
        dset='Bb')
    test_files, test_setB_lengths, test_setB_sepsis = _load_checked_split_files(
        'final_dataset_finetune_B.pickle', 'lengths_finetune_B.txt', 'is_sepsis_finetune_B.txt',
        desc="Checking Fine-tuning & Test Files")
    finetune = _select_and_convert(test_files, test_setB_lengths, test_setB_sepsis, finetune_indices,
                                   desc="Fine-tuning Set")
    test = _select_and_convert(test_files, test_setB_lengths, test_setB_sepsis, test_indices, desc="Test Set")

    print("Pre-training samples: ", pt_train['samples'].shape, "Validation samples: ", pt_val['samples'].shape)
    print("Fine-tuning samples: ", finetune['samples'].shape, "Test samples: ", test['samples'].shape)

    return pt_train, pt_val, finetune, test

pt_train, pt_val, finetune, test = get_pretrain_finetune_datasets()


In [None]:

import math

def geom_noise_mask_single(L, lm, masking_ratio):
    """
    Randomly create a boolean mask of length `L`, consisting of subsequences of average length lm, masking with 0s a `masking_ratio`
    proportion of the sequence L. The length of masking subsequences and intervals follow a geometric distribution.

    The mask is a two-state Markov chain: in the masked state (0) each step leaves with
    probability 1/lm, and in the kept state (1) with probability p_m * r / (1 - r), so the
    long-run masked fraction is `masking_ratio`.

    Args:
        L: length of mask and sequence to be masked
        lm: average length of masking subsequences (streaks of 0s); must be > 0
        masking_ratio: proportion of L to be masked; must be in [0, 1)
    Returns:
        (L, ) boolean numpy array intended to mask ('drop') with 0s a sequence of length L
    Raises:
        ValueError: if `lm` <= 0 or `masking_ratio` is outside [0, 1) (previously a ratio of 1
            raised ZeroDivisionError and other invalid values produced meaningless masks).
    """
    if lm <= 0:
        raise ValueError(f"lm must be positive, got {lm}")
    if not 0 <= masking_ratio < 1:
        raise ValueError(f"masking_ratio must be in [0, 1), got {masking_ratio}")

    keep_mask = np.ones(L, dtype=bool)
    p_m = 1 / lm  # probability of each masking sequence stopping. parameter of geometric distribution.
    p_u = p_m * masking_ratio / (1 - masking_ratio)  # probability of each unmasked sequence stopping.
    p = [p_m, p_u]

    # Start in state 0 with masking_ratio probability
    state = int(np.random.rand() > masking_ratio)  # state 0 means masking, 1 means not masking
    # Draw all transition variates at once: the legacy global RNG yields the same values for
    # rand(L) as for L scalar rand() calls, so masks are unchanged for a given seed, but the
    # loop avoids one RNG call per element (L can be batch * channels * time).
    draws = np.random.rand(L)
    for i in range(L):
        keep_mask[i] = state  # here it happens that state and masking value corresponding to state are identical
        if draws[i] < p[state]:
            state = 1 - state

    return keep_mask


def noise_mask(X, masking_ratio=0.25, lm=3, distribution='geometric', exclude_feats=None):
    """
    Creates a random boolean mask of the same shape as X, with 0s (False) at places that should be masked.

    Args:
        X: array-like exposing `.shape` (numpy array or torch tensor) of any rank >= 1. In this
            notebook it arrives from data_transform_masked4cl as
            (batch_size * positive_nums, channels, time_steps).
        masking_ratio: fraction of entries to mask.
        lm: average length of masking subsequences (streaks of 0s). Used only when `distribution` is 'geometric'.
        distribution:
            'geometric'   - one stateful Markov chain over the flattened X (C order), giving
                            geometrically distributed masked runs of mean length `lm`; runs can
                            continue across row/channel boundaries.
            'masked_tail' - keep the first ceil(T * (1 - masking_ratio)) steps of the last axis
                            and mask the rest.
            'masked_head' - mask the first ceil(T * masking_ratio) steps of the last axis.
            anything else - each entry independently kept with probability 1 - masking_ratio.
        exclude_feats: accepted for API compatibility; currently ignored (no mode here exempts
            features from masking).
    Returns:
        torch bool tensor with the same shape as X, False where an entry should be masked
    """
    shape = tuple(X.shape)

    if distribution == 'geometric':  # stateful (Markov chain)
        mask = geom_noise_mask_single(int(np.prod(shape)), lm, masking_ratio).reshape(shape)

    elif distribution in ('masked_tail', 'masked_head'):
        # Vectorised over every leading axis; the last axis is treated as time.
        n_steps = shape[-1]
        mask = np.zeros(shape, dtype=bool)
        if distribution == 'masked_tail':
            mask[..., :math.ceil(n_steps * (1 - masking_ratio))] = True
        else:
            mask[..., math.ceil(n_steps * masking_ratio):] = True

    else:  # each position is independent Bernoulli with p = 1 - masking_ratio
        mask = np.random.choice(np.array([True, False]), size=shape, replace=True,
                                p=(1 - masking_ratio, masking_ratio))

    return torch.tensor(mask)

def data_transform_masked4cl(sample, masking_ratio, lm, positive_nums=None, distribution='geometric'):
    """Build `positive_nums` randomly masked copies of a batch for masked contrastive pre-training.

    Args:
        sample: tensor of shape (batch_size, time_steps, channels); it is permuted internally
            so that masking is generated on a (batch, channels, time) layout.
        masking_ratio: fraction of entries to mask (see noise_mask).
        lm: mean masked-run length for the 'geometric' distribution.
        positive_nums: number of masked copies; defaults to ceil(1.5 / (1 - masking_ratio)).
        distribution: masking scheme forwarded to noise_mask.

    Returns:
        (x_masked, mask), both shaped (batch_size * positive_nums, time_steps, channels).
        Copies are tiled along the batch axis, so row k * batch_size + b is copy k of sample b.
    """
    if positive_nums is None:
        positive_nums = math.ceil(1.5 / (1 - masking_ratio))

    sample = sample.permute(0, 2, 1)  # (batch_size, channels, time_steps)

    # Creating the batch in #positive_nums sets
    sample_repeat = sample.repeat(positive_nums, 1, 1)  # (batch_size*positive_num, channels, time steps)

    mask = noise_mask(sample_repeat, masking_ratio, lm, distribution=distribution)
    # noise_mask builds the mask from a numpy array (i.e. on the CPU); move it next to the
    # data so the multiplication also works for GPU batches.
    mask = mask.to(sample_repeat.device)
    x_masked = mask * sample_repeat

    return x_masked.permute(0, 2, 1), mask.permute(0, 2, 1)

# data_masked_m, mask = data_transform_masked4cl(all_patients['samples'][:32], 0.5, 3, positive_nums=1, distribution='geometric')


In [None]:
from torch.utils.data import Dataset


class Load_Dataset(Dataset):
    """Map-style dataset over a {'samples', 'labels'} dict, shuffled once at construction.

    Samples are laid out as (N, time_steps, channels), the layout produced by csv_to_pt
    (2-D input (N, time_steps) gets a singleton channel axis). Each series is truncated to
    its first `TSlength_aligned` time steps.

    Args:
        dataset: dict with 'samples' and 'labels' (torch tensors or numpy arrays of equal length).
        TSlength_aligned: number of leading time steps to keep.
        training_mode: stored as an attribute for callers; not used inside this class.

    Raises:
        ValueError: if 'samples' and 'labels' have different lengths (previously `zip`
            silently truncated to the shorter one).
    """

    def __init__(self, dataset, TSlength_aligned, training_mode):
        super(Load_Dataset, self).__init__()
        self.training_mode = training_mode

        X_train = dataset["samples"]
        y_train = dataset["labels"]
        if len(X_train) != len(y_train):
            raise ValueError(f"samples ({len(X_train)}) and labels ({len(y_train)}) differ in length")

        from_numpy = isinstance(X_train, np.ndarray)
        X_train = torch.as_tensor(X_train)
        y_train = torch.as_tensor(y_train)
        if from_numpy:
            # numpy labels are treated as class indices
            y_train = y_train.long()

        # One shared permutation keeps every sample paired with its label.
        order = torch.as_tensor(np.random.permutation(len(X_train)))
        X_train, y_train = X_train[order], y_train[order]

        if X_train.dim() < 3:
            X_train = X_train.unsqueeze(2)  # (N, T) -> (N, T, 1)

        # Align the time-series length. Time is axis 1 in this layout; slicing the last axis
        # (as before) truncated channels instead of time steps.
        X_train = X_train[:, :int(TSlength_aligned)]

        self.x_data = X_train
        self.y_data = y_train
        self.len = X_train.shape[0]

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len
    

In [None]:
import math

from torch import nn
import torch
from torch.nn import ModuleList

import torch.nn.functional as F

from models.simmtm.gtn.encoder import Encoder
from simmtm.loss import ContrastiveWeight, AggregationRebuild, AutomaticWeightedLoss


class TFC(nn.Module):
    """Two-tower transformer encoder with a learned gate and SimMTM pre-training heads.

    Tower 1 embeds every time step from its channels (`embedding_channel`) and runs
    `encoder_list_1` over the time axis. Tower 2 embeds every channel from its time series
    (`embedding_input`) and runs `encoder_list_2` over the channel axis. Both outputs are
    flattened, scaled by a softmax gate and concatenated. With `pre_train=True` that
    representation also drives a contrastive projector and a reconstruction head.

    Args:
        configs: model config. Read here: d_model, d_hidden, q, v, h, N, mask, dropout, device,
            d_channel (channels per step), d_input (time steps), d_output (gate width; only the
            first two outputs are used), pe (add sinusoidal positional encoding to tower 1).
        args: run arguments forwarded to ContrastiveWeight and AggregationRebuild.
    """

    def __init__(self, configs, args):

        super(TFC, self).__init__()
        # Hard-coded, so the pre-training heads are always built: build_model loads
        # pre-training checkpoints with a full state dict and needs these keys to exist.
        self.training_mode = 'pre_train'

        # Projecting input into deep representations
        self.encoder_list_1 = ModuleList([Encoder(d_model=configs.d_model, d_hidden=configs.d_hidden, q=configs.q,
                                                  v=configs.v, h=configs.h, mask=configs.mask, dropout=configs.dropout,
                                                  device=configs.device) for _ in range(configs.N)])

        self.encoder_list_2 = ModuleList([Encoder(d_model=configs.d_model, d_hidden=configs.d_hidden, q=configs.q,
                                                  v=configs.v, h=configs.h, dropout=configs.dropout,
                                                  device=configs.device) for _ in range(configs.N)])

        self.embedding_channel = torch.nn.Linear(configs.d_channel, configs.d_model)
        self.embedding_input = torch.nn.Linear(configs.d_input, configs.d_model)

        # Width of the flattened, concatenated tower outputs, derived from the config instead
        # of the former literal 192512 (= 512 * (336 + 40) for d_model=512, d_input=336,
        # d_channel=40), so other sequence lengths / channel counts work too.
        feat_dim = configs.d_model * configs.d_input + configs.d_model * configs.d_channel

        self.gate = torch.nn.Linear(feat_dim, configs.d_output)

        self.pe = configs.pe
        self._d_input = configs.d_input
        self._d_model = configs.d_model

        # MLP Layer - To generate Projector(.); to Obtain series-wise representations
        self.dense = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

        if self.training_mode == 'pre_train':
            self.awl = AutomaticWeightedLoss(2)
            self.contrastive = ContrastiveWeight(args)
            self.aggregation = AggregationRebuild(args)
            # Reconstructs the whole multivariate input (d_input x d_channel values per sample).
            self.head = nn.Linear(feat_dim, configs.d_input * configs.d_channel)
            self.mse = torch.nn.MSELoss()

    def _positional_encoding(self, encoding):
        """Sinusoidal position table shaped like encoding[0] (d_input x d_model), on its device and dtype."""
        device = encoding.device
        pe = torch.ones_like(encoding[0])
        position = torch.arange(0, self._d_input, device=device).unsqueeze(-1).float()
        div_term = torch.arange(0, self._d_model, 2, device=device, dtype=torch.float32)
        div_term = torch.exp(div_term * -(math.log(10000) / self._d_model)).unsqueeze(0)
        angles = torch.matmul(position, div_term)  # shape: [d_input, d_model / 2]
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)
        return pe

    def forward(self, stage, x_in_t, pre_train=False):
        """Encode a batch and, in pre-training, compute the SimMTM losses.

        Args:
            stage: passed through to every Encoder call.
            x_in_t: tensor (batch, d_input time steps, d_channel channels).
            pre_train: selects the return value (see below).

        Returns:
            pre_train=True: (loss, loss_cl, loss_rb) - weighted total, contrastive and
                reconstruction losses.
            pre_train=False: (encoding, gate_logits) - the gated concatenation
                (batch, feat_dim) and the pre-softmax gate outputs (batch, d_output).
        """
        # Tower 1: one token per time step.
        encoding_1 = self.embedding_channel(x_in_t)  # (batch, d_input, d_model)
        if self.pe:
            encoding_1 = encoding_1 + self._positional_encoding(encoding_1)
        for encoder in self.encoder_list_1:
            encoding_1, _ = encoder(encoding_1, stage)

        # Tower 2: one token per channel.
        encoding_2 = self.embedding_input(x_in_t.transpose(-1, -2))  # (batch, d_channel, d_model)
        for encoder in self.encoder_list_2:
            encoding_2, _ = encoder(encoding_2, stage)

        encoding_1 = encoding_1.reshape(encoding_1.shape[0], -1)
        encoding_2 = encoding_2.reshape(encoding_2.shape[0], -1)

        gate_logits = self.gate(torch.cat([encoding_1, encoding_2], dim=-1))  # (batch, d_output)
        gate = F.softmax(gate_logits, dim=-1)
        # One weight per tower: only the first two gate outputs are used.
        encoding = torch.cat([encoding_1 * gate[:, 0:1], encoding_2 * gate[:, 1:2]], dim=-1)

        # Projections (computed in both modes, so BatchNorm statistics update either way)
        projections = self.dense(encoding)  # (batch, 128)

        if pre_train:
            loss_cl, similarity_matrix, logits, positives_mask = self.contrastive(projections)
            rebuild_weight_matrix, agg_x = self.aggregation(similarity_matrix, encoding)
            pred_x = self.head(agg_x.reshape(agg_x.size(0), -1))  # (batch, d_input * d_channel)
            loss_rb = self.mse(pred_x, x_in_t.reshape(x_in_t.size(0), -1).detach())
            loss = self.awl(loss_cl, loss_rb)

            return loss, loss_cl, loss_rb

        return encoding, gate_logits
    

In [None]:

def model_pretrain(model, model_optimizer, model_scheduler, train_loader, configs, args, device):
    """Run one epoch of SimMTM pre-training (masked contrastive + reconstruction).

    Every batch (batch_size, seqs, channels) is extended with ``args.positive_nums`` masked
    copies (data_transform_masked4cl). Originals and masked copies are concatenated along the
    batch axis and fed to ``model(..., pre_train=True)``. The scheduler steps once, after the epoch.

    Args:
        model: module returning (loss, loss_cl, loss_rb) when called with pre_train=True.
        model_optimizer: optimiser stepped after every batch.
        model_scheduler: LR scheduler stepped once per call.
        train_loader: iterable of (data, labels); labels are not used.
        configs: accepted for interface compatibility; not read here.
        args: supplies masking_ratio, lm and positive_nums.
        device: device the model and the combined batch are moved to.

    Returns:
        Tuple of 0-d tensors: mean total, contrastive and reconstruction loss over the batches.
    """
    loss_names = ('total', 'cl', 'rb')
    history = {name: [] for name in loss_names}

    model.to(device)
    model.train()
    progress = tqdm.tqdm(enumerate(train_loader), desc="Pre-training model", total=len(train_loader))
    for _, (data, labels) in progress:
        model_optimizer.zero_grad()

        masked_copies, _ = data_transform_masked4cl(data, args.masking_ratio, args.lm, args.positive_nums)
        # originals first, then the masked copies, all on the batch axis
        combined = torch.cat([data, masked_copies], 0).float().to(device)

        step_losses = model(stage='train', x_in_t=combined, pre_train=True)
        step_losses[0].backward()
        model_optimizer.step()

        for name, value in zip(loss_names, step_losses):
            history[name].append(value.item())

    model_scheduler.step()

    return tuple(torch.tensor(history[name]).mean() for name in loss_names)


In [None]:
# Loader settings shared by all three loaders.
TS_LENGTH = 336   # time steps kept per series (matches csv_to_pt's padding length)
BATCH_SIZE = 32
NUM_WORKERS = 4

pt_dataset = Load_Dataset(pt_train, TSlength_aligned=TS_LENGTH, training_mode='pretrain')
train_loader = torch.utils.data.DataLoader(dataset=pt_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                           drop_last=True, num_workers=NUM_WORKERS)  # batches: (batch, time, channels)

val_dataset = Load_Dataset(pt_val, TSlength_aligned=TS_LENGTH, training_mode='pretrain')
# Must wrap `val_dataset`: the raw `pt_val` dict ({'samples', 'labels'}) cannot be indexed
# by sample position, so a loader over it fails as soon as it is iterated.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                         drop_last=True, num_workers=NUM_WORKERS)

finetune_dataset = Load_Dataset(finetune, TSlength_aligned=TS_LENGTH, training_mode='finetune')
finetune_loader = torch.utils.data.DataLoader(finetune_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                              drop_last=True, num_workers=NUM_WORKERS)



# Training

In [None]:

def get_model_size(model):
    """Print and return the memory held by a model's parameters and buffers.

    Sizes use binary units (1 GB = 1024**3 bytes).

    Returns:
        float: size in gigabytes.
    """
    all_tensors = list(model.parameters()) + list(model.buffers())
    total_bytes = sum(tensor.nelement() * tensor.element_size() for tensor in all_tensors)

    size_mb = total_bytes / 1024**2
    size_gb = size_mb * (1.0 / 1024)

    print('model size: {:.3f} GB'.format(size_gb))

    return size_gb


In [None]:
from models.simmtm.model import target_classifier

from models.simmtm.config import Config

def build_model(args, lr, configs, device='cuda', chkpoint=None):
    """Create a TFC encoder, a target classifier, their Adam optimisers and an encoder LR schedule.

    Args:
        args: run arguments; passed to TFC, and ``args.finetune_epoch`` is the cosine T_max.
        lr: learning rate shared by both optimisers.
        configs: model config; also provides the Adam betas (beta1, beta2).
        device: device both modules are moved to.
        chkpoint: optional dict whose "model_state_dict" entries overwrite the freshly
            initialised encoder weights (keys it lacks keep their initial values).

    Returns:
        (model, classifier, model_optimizer, classifier_optimizer, model_scheduler)
    """
    model = TFC(configs, args).to(device)
    if chkpoint:
        merged_state = model.state_dict()
        merged_state.update(chkpoint["model_state_dict"])
        model.load_state_dict(merged_state)

    classifier = target_classifier(configs).to(device)

    adam_betas = (configs.beta1, configs.beta2)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, weight_decay=0)
    classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=lr, betas=adam_betas, weight_decay=0)
    model_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=model_optimizer,
                                                                 T_max=args.finetune_epoch)

    return model, classifier, model_optimizer, classifier_optimizer, model_scheduler

# model, classifier, model_optimizer, classifier_optimizer, model_scheduler = build_model(args, Config().lr, Config())


In [None]:

from sklearn.metrics import (roc_auc_score, average_precision_score, accuracy_score, 
                             precision_score, f1_score, recall_score)

def model_finetune(model, val_dl, device, model_optimizer, model_scheduler, classifier=None, classifier_optimizer=None):
    """Run one fine-tuning epoch of encoder + classifier over `val_dl` (both are updated).

    Args:
        model: encoder called as ``model(stage='train', x_in_t=..., pre_train=False)`` that
            returns ``(features, gate_logits)``.
        val_dl: iterable of (data, labels). Labels shaped (B,) or (B, 1) are flattened to
            class indices, as CrossEntropyLoss requires.
        device: device each batch is moved to.
        model_optimizer: encoder optimiser, stepped every batch.
        model_scheduler: stepped once per epoch. ReduceLROnPlateau receives the mean loss;
            any other scheduler is stepped without arguments.
        classifier: head mapping features to class logits. Required despite the None default.
        classifier_optimizer: classifier optimiser, stepped every batch. Required.

    Returns:
        (mean loss, mean accuracy, mean AUROC, mean AUPRC, flattened features of the last
        batch (None if `val_dl` is empty), all targets as a float ndarray, macro F1 over the
        whole epoch). Batches whose AUROC/AUPRC is undefined are skipped in those means; if
        every batch is skipped the mean is NaN.

    Raises:
        ValueError: if `classifier` or `classifier_optimizer` is missing.
    """
    if classifier is None or classifier_optimizer is None:
        raise ValueError("model_finetune requires both `classifier` and `classifier_optimizer`")

    model.train()
    classifier.train()

    total_loss = []
    total_acc = []
    total_auc = []
    total_prc = []

    criterion = nn.CrossEntropyLoss()
    outs = []
    trgs = []
    fea_concat_flat = None

    for data, labels in val_dl:
        model_optimizer.zero_grad()
        classifier_optimizer.zero_grad()

        # Labels arrive as (B, 1) from csv_to_pt; CrossEntropyLoss needs (B,) class indices.
        data, labels = data.float().to(device), labels.long().reshape(-1).to(device)

        # Produce embeddings; the supervised classifier is unique to fine-tuning and is reused at test time.
        h, _ = model(stage='train', x_in_t=data, pre_train=False)
        predictions = classifier(h)
        fea_concat_flat = h.reshape(h.shape[0], -1)
        loss = criterion(predictions, labels)

        acc_bs = labels.eq(predictions.detach().argmax(dim=1)).float().mean().item()
        # Fix num_classes to the logit width so one-hot columns line up with the predictions
        # even when a batch does not contain every class.
        onehot_label = F.one_hot(labels, num_classes=predictions.shape[1]).cpu().numpy()
        pred_numpy = predictions.detach().cpu().numpy()

        try:
            auc_bs = roc_auc_score(onehot_label, pred_numpy, average="macro", multi_class="ovr")
        except ValueError:  # e.g. only one class present in this batch
            auc_bs = 0.0

        try:
            prc_bs = average_precision_score(onehot_label, pred_numpy)
        except ValueError:
            prc_bs = 0.0

        total_acc.append(acc_bs)
        if auc_bs != 0:
            total_auc.append(auc_bs)
        if prc_bs != 0:
            total_prc.append(prc_bs)
        total_loss.append(loss.item())

        loss.backward()
        model_optimizer.step()
        classifier_optimizer.step()

        outs.append(predictions.detach().argmax(dim=1).cpu().numpy())
        trgs.append(labels.cpu().numpy())

    outs = np.concatenate(outs).astype(np.float64) if outs else np.array([])
    trgs = np.concatenate(trgs).astype(np.float64) if trgs else np.array([])
    # Macro F1 over every batch of the epoch (previously only the last batch was scored).
    F1 = f1_score(trgs, outs, average='macro') if trgs.size else 0.0

    total_loss = torch.tensor(total_loss).mean()  # average loss
    total_acc = torch.tensor(total_acc).mean()  # average acc
    total_auc = torch.tensor(total_auc).mean()  # average auc
    total_prc = torch.tensor(total_prc).mean()

    if isinstance(model_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        model_scheduler.step(total_loss)
    else:
        # e.g. CosineAnnealingLR: passing the loss here would be interpreted as an epoch index.
        model_scheduler.step()

    return total_loss, total_acc, total_auc, total_prc, fea_concat_flat, trgs, F1


In [None]:

def train(train_loader, val_loader, finetune_loader, device='cuda'):
    """Pre-train TFC; after every pre-training epoch, fine-tune a fresh copy from that checkpoint.

    Checkpoints are written to <project_root>/results/simmtm/saved_models/ckp_ep{epoch}.pt.

    Args:
        train_loader: pre-training batches.
        val_loader: currently unused; kept for interface compatibility.
        finetune_loader: labelled batches for fine-tuning.
        device: device used for both pre-training and fine-tuning (it used to be ignored by
            pre-training, which always ran on 'cuda').
    """
    configs = Config()

    model = TFC(configs=configs, args=args)
    params_group = [{'params': model.parameters()}]
    model_optimizer = torch.optim.Adam(params_group, lr=args.pretrain_lr,
                                       betas=(configs.beta1, configs.beta2),
                                       weight_decay=0)

    # The cosine period matches the number of pre-training epochs actually run; it previously
    # used args.pretrain_epoch while the loop ran configs.pretrain_epoch epochs.
    n_pretrain_epochs = configs.pretrain_epoch
    model_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=model_optimizer, T_max=n_pretrain_epochs)

    experiment_log_dir = os.path.join(project_root(), 'results', 'simmtm')
    saved_models_dir = os.path.join(experiment_log_dir, "saved_models")
    os.makedirs(saved_models_dir, exist_ok=True)

    seed = 2024  # recorded in the checkpoints only; no RNG is seeded here
    for epoch in range(n_pretrain_epochs):
        total_loss, total_cl_loss, total_rb_loss = model_pretrain(model=model, model_optimizer=model_optimizer,
                                                                  model_scheduler=model_scheduler,
                                                                  train_loader=train_loader,
                                                                  configs=configs, args=args, device=device)

        print(f'Pre-training Epoch: {epoch}\t Train Loss: {total_loss:.4f}\t CL Loss: {total_cl_loss:.4f}\t RB Loss: {total_rb_loss:.4f}\n')

        chkpoint = {'seed': seed, 'epoch': epoch, 'train_loss': total_loss, 'model_state_dict': model.state_dict()}
        torch.save(chkpoint, os.path.join(saved_models_dir, f'ckp_ep{epoch}.pt'))

        # Fine-tune one fresh model, warm-started from this checkpoint, for finetune_epoch epochs.
        # (Previously the model was rebuilt inside an outer epoch loop whose inner loop ran
        # finetune_epoch - 1 epochs, i.e. roughly finetune_epoch**2 epochs per checkpoint.)
        print(f"Fine-tuning started...")
        ft_model, ft_classifier, ft_model_optimizer, ft_classifier_optimizer, ft_scheduler = build_model(
            args, args.lr, configs, device=device, chkpoint=chkpoint)

        for ft_epoch in range(1, configs.finetune_epoch + 1):
            valid_loss, valid_acc, valid_auc, valid_prc, emb_finetune, label_finetune, F1 = model_finetune(
                ft_model, finetune_loader, device, ft_model_optimizer, ft_scheduler, classifier=ft_classifier,
                classifier_optimizer=ft_classifier_optimizer)
            print(f'Fine-tuning Epoch: {ft_epoch}\t Loss: {valid_loss:.4f}\t Acc: {valid_acc:.4f}\t '
                  f'AUROC: {valid_auc:.4f}\t AUPRC: {valid_prc:.4f}\t F1: {F1:.4f}')

train(train_loader, val_loader, finetune_loader)


In [12]:
from utils.evaluate_helper_methods import load_sepsis_model
from utils.path_utils import project_root
import os

# Load a saved fine-tuned SimMTM model through the project's evaluation helper.
# NOTE(review): `train` in this notebook only writes ckp_ep{epoch}.pt, so 'finetune_ep16.pt'
# presumably comes from another script/job — confirm it exists before running this cell.
model_path = os.path.join(project_root(), 'results', 'simmtm', 'saved_models', 'finetune_ep16.pt')
model = load_sepsis_model(d_input=336, d_channel=40, d_output=2, model_name=model_path,
                          pre_model="simmtm")


In [24]:
# Display the 'classifier' entry stored in the checkpoint at model_path.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
torch.load(model_path)['classifier']

In [26]:
# Build an untrained target classifier from the default Config() (kept on the CPU here).
configs = Config()
classifier = target_classifier(configs=configs)

In [15]:
import torch

# Load only the 'samples' entry of the saved SimMTM test split (labels are not kept).
test_data = torch.load(os.path.join(project_root(), 'data', 'test_data', 'simmtm', 'test.pt'))['samples']


In [41]:
# Raw PhysioNet 2019 training-set-B files that belong to the test split.
# test_indices is local to get_pretrain_finetune_datasets, so recompute it here (same seed,
# hence the same split) instead of relying on a leftover kernel variable.
test_indices, _ = get_train_val_test_indices(sepsis_file='is_sepsis_finetune_B.txt',
                                             save_distributions=False, dset='Bb')

test_setB_all_files = os.path.join(project_root(), 'physionet.org', 'files', 'challenge-2019', '1.0.0', 'training',
                                   'training_setB')
# Sorted listing, assumed to match the patient order used when the dataset was built
# (TODO confirm); 'index.html' is not a patient file and is skipped if present.
test_setB_files = sorted(name for name in os.listdir(test_setB_all_files) if name != 'index.html')

test_setB_files = [test_setB_files[i] for i in test_indices]


In [67]:
import shutil

# Copy the raw test-split .psv files byte-for-byte into the SimMTM test-data folder.
# A pandas read/write round-trip re-serialises the data (e.g. to_csv writes missing values
# as empty fields by default), so copies could differ from the originals; it also failed
# when the target folder did not exist yet.
save_path = os.path.join(project_root(), 'data', 'test_data', 'simmtm', 'psv_files')
os.makedirs(save_path, exist_ok=True)
for pidx in test_setB_files:
    shutil.copyfile(os.path.join(test_setB_all_files, pidx), os.path.join(save_path, pidx))


In [1]:
import pandas as pd
import os

from utils.path_utils import project_root
from utils.pretrain_utils.data import get_train_val_test_indices

# Recompute the set-B split and list its raw PhysioNet files in sorted order.
test_indices, finetune_indices = get_train_val_test_indices(
    sepsis_file='is_sepsis_finetune_B.txt', save_distributions=True,
    dset='Bb')

test_setB_all_files = os.path.join(project_root(), 'physionet.org', 'files', 'challenge-2019', '1.0.0', 'training',
                                   'training_setB')
# 'index.html' is not a patient file; skip it if present.
test_setB_files = sorted(name for name in os.listdir(test_setB_all_files) if name != 'index.html')

test_setB_files = [test_setB_files[i] for i in test_indices]

# Preview the first test-split file from the copied SimMTM test-data folder.
pidx = test_setB_files[0]
save_path = os.path.join(project_root(), 'data', 'test_data', 'simmtm', 'psv_files')
pd.read_csv(os.path.join(save_path, pidx), sep='|')

In [7]:
import torch
# torch.load(os.path.join(project_root(), 'results', 'simmtm', 'saved_models', 'ckp_ep9.pt'))['model_state_dict']

In [9]:
from models.simmtm.config import Config
from models.simmtm.model import TFC, target_classifier

# Default SimMTM config; target classifier head moved onto the GPU.
config = Config()
classifier = target_classifier(config)
classifier = classifier.to('cuda')


In [11]:
# Scratch note on epochs of interest. Written as `[8, 15 (0.89), ...]` the cell raised
# TypeError ('int' object is not callable) on Run All, so it is kept as data instead.
# NOTE(review): the parenthesised numbers look like a per-epoch score — confirm what
# they measure (epoch 8 had none recorded).
epoch_notes = {8: None, 15: 0.89, 16: 1, 17: 1}
epoch_notes

In [None]:
import math
from torch import nn
import torch
from torch.autograd import Function
from torch.nn import ModuleList

from models.adatime.gtn.encoder import Encoder
import torch.nn.functional as F


def get_backbone_class(backbone_name):
    """Return the module-level object named ``backbone_name`` (normally a backbone class).

    Raises:
        NotImplementedError: if no module-level name matches.
    """
    namespace = globals()
    if backbone_name in namespace:
        return namespace[backbone_name]
    raise NotImplementedError("Algorithm not found: {}".format(backbone_name))


class GTN(nn.Module):
    """Gated Transformer Network backbone with a linear projection head.

    Two encoder stacks are applied to the input:
    - ``encoder_list_1`` runs over tokens along the ``d_input`` axis. Each token is embedded
      from ``d_channel`` features, optionally with a sinusoidal positional encoding.
    - ``encoder_list_2`` runs over tokens along the ``d_channel`` axis. Each token is
      embedded from ``d_input`` features.

    Both outputs are flattened, weighted by a learned two-way softmax gate and concatenated.
    ``head`` then projects the result to ``int(d_input * d_channel / 4)`` features.
    """

    def __init__(self, configs):
        super(GTN, self).__init__()

        self.encoder_list_1 = ModuleList([Encoder(d_model=configs.d_model, d_hidden=configs.d_hidden, q=configs.q,
                                                  v=configs.v, h=configs.h, mask=configs.mask, dropout=configs.dropout,
                                                  device=configs.device) for _ in range(configs.N)])

        self.encoder_list_2 = ModuleList([Encoder(d_model=configs.d_model, d_hidden=configs.d_hidden, q=configs.q,
                                                  v=configs.v, h=configs.h, dropout=configs.dropout,
                                                  device=configs.device) for _ in range(configs.N)])

        self.embedding_channel = torch.nn.Linear(configs.d_channel, configs.d_model)
        self.embedding_input = torch.nn.Linear(configs.d_input, configs.d_model)

        # Width of the two flattened encodings once concatenated; used by both gate and head.
        concat_dim = configs.d_model * configs.d_input + configs.d_model * configs.d_channel

        self.gate = torch.nn.Linear(concat_dim, configs.d_output)

        self.pe = configs.pe
        self._d_input = configs.d_input
        self._d_model = configs.d_model

        # in_features is derived from the config rather than a fixed 192512. 192512 is only
        # correct when d_model * (d_input + d_channel) == 192512, e.g. 512 * (336 + 40).
        self.head = nn.Linear(concat_dim, int((configs.d_input * configs.d_channel) / 4))

    def forward(self, stage, x_in_t):
        """Encode ``x_in_t`` and return the projected, gated features.

        Args:
            stage: Passed unchanged to every Encoder call; its meaning is defined by Encoder.
            x_in_t: Input whose last axis has ``d_channel`` features. Assumed to be
                (batch, d_input, d_channel) — TODO confirm against callers.

        Returns:
            Tensor of shape (batch, int(d_input * d_channel / 4)).
        """
        encoding_1 = self.embedding_channel(x_in_t)

        if self.pe:
            pe = torch.ones_like(encoding_1[0])
            # Build the sinusoid terms directly on pe's device, avoiding a host-to-device
            # copy on every forward pass when running on GPU.
            position = torch.arange(0, self._d_input, device=pe.device).unsqueeze(-1).float()
            temp = torch.arange(0, self._d_model, 2, device=pe.device, dtype=torch.float32)
            temp = temp * -(math.log(10000) / self._d_model)
            temp = torch.exp(temp).unsqueeze(0)
            temp = torch.matmul(position, temp)
            pe[:, 0::2] = torch.sin(temp)
            pe[:, 1::2] = torch.cos(temp)

            encoding_1 = encoding_1 + pe

        for encoder in self.encoder_list_1:
            encoding_1, score_input = encoder(encoding_1, stage)

        # Second view: swap the last two axes so tokens run along the channel axis.
        encoding_2 = self.embedding_input(x_in_t.transpose(-1, -2))

        for encoder in self.encoder_list_2:
            encoding_2, score_channel = encoder(encoding_2, stage)

        encoding_1 = encoding_1.reshape(encoding_1.shape[0], -1)
        encoding_2 = encoding_2.reshape(encoding_2.shape[0], -1)

        encoding_concat = self.gate(torch.cat([encoding_1, encoding_2], dim=-1))

        # Only the first two gate outputs are used as mixing weights for the two views.
        gate = F.softmax(encoding_concat, dim=-1)
        encoding = torch.cat([encoding_1 * gate[:, 0:1], encoding_2 * gate[:, 1:2]], dim=-1)
        encoding = self.head(encoding)

        return encoding


class classifier(nn.Module):
    """Linear classification head over GTN features.

    Maps ``int(d_input * d_channel / 4)`` input features to ``num_classes`` logits.
    """

    def __init__(self, configs):
        super(classifier, self).__init__()
        in_features = int((configs.d_input * configs.d_channel) / 4)
        self.logits = nn.Linear(in_features, configs.num_classes)
        self.configs = configs

    def forward(self, x):
        return self.logits(x)


#### Components required by DANN ##############
class ReverseLayerF(Function):
    """Gradient reversal layer (used by DANN).

    Forward is the identity. Backward negates the incoming gradient and scales it by
    ``alpha``. ``alpha`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        # view_as gives a new tensor object over the same data, so autograd records this op.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg() * ctx.alpha
        return reversed_grad, None


class Discriminator(nn.Module):
    """Domain discriminator: an MLP from flattened features to two domain logits.

    Input width is ``features_len * final_out_channels``. It has two hidden layers of width
    ``disc_hid_dim`` with ReLU, and outputs raw logits (no softmax).
    """

    def __init__(self, configs):
        super(Discriminator, self).__init__()

        in_dim = configs.features_len * configs.final_out_channels
        hid_dim = configs.disc_hid_dim
        # Layer order is kept fixed so state_dict keys stay layer.0 / layer.2 / layer.4.
        self.layer = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, 2),
        )

    def forward(self, input):
        """Return domain logits of shape (batch, 2)."""
        return self.layer(input)


#### Components required by CDAN ##############
class RandomLayer(nn.Module):
    """Fixed random multilinear projection (used by CDAN).

    Input ``i`` is multiplied by its own fixed Gaussian matrix of shape
    ``(input_dim_list[i], output_dim)``. The resulting projections are multiplied
    element-wise, with the first one scaled by ``output_dim ** (-1 / len(inputs))``.

    The matrices are held in a plain Python list, not as parameters or buffers. They are
    therefore not trained, not included in ``state_dict``, and are moved only by the
    ``cuda()`` override below.
    """

    def __init__(self, input_dim_list=None, output_dim=1024):
        super(RandomLayer, self).__init__()
        # None instead of a shared mutable [] default; any iterable of ints is accepted.
        input_dim_list = [] if input_dim_list is None else list(input_dim_list)
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        self.random_matrix = [torch.randn(dim, output_dim) for dim in input_dim_list]

    def forward(self, input_list):
        """Project each 2-D input with its matrix and combine the results multiplicatively."""
        return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor

    def cuda(self, device=None):
        """Move the module and its random matrices to CUDA.

        Returns self, like ``nn.Module.cuda``, so ``layer = RandomLayer(...).cuda()`` works.
        """
        super(RandomLayer, self).cuda(device)
        self.random_matrix = [val.cuda(device) for val in self.random_matrix]
        return self


class Discriminator_CDAN(nn.Module):
    """CDAN domain discriminator: an MLP from flattened feature-by-prediction products to two logits.

    Input width is ``features_len * final_out_channels * num_classes``. It has two hidden
    layers of width ``disc_hid_dim`` with ReLU, and outputs raw logits (no softmax).
    """

    def __init__(self, configs):
        super(Discriminator_CDAN, self).__init__()

        # Flag initialised here; nothing in this section reads it.
        self.restored = False

        in_dim = configs.features_len * configs.final_out_channels * configs.num_classes
        hid_dim = configs.disc_hid_dim
        # Layer order is kept fixed so state_dict keys stay layer.0 / layer.2 / layer.4.
        self.layer = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, 2),
        )

    def forward(self, input):
        """Return domain logits of shape (batch, 2)."""
        return self.layer(input)


class codats_classifier(nn.Module):
    """CoDATS classification head: a 3-layer MLP from flattened features to class logits.

    Input width is ``features_len * final_out_channels``. It has two hidden layers of width
    ``hidden_dim`` with ReLU, and outputs ``num_classes`` raw logits.
    """

    def __init__(self, configs):
        super(codats_classifier, self).__init__()
        self.hidden_dim = configs.hidden_dim
        in_dim = configs.features_len * configs.final_out_channels
        self.logits = nn.Sequential(
            nn.Linear(in_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, configs.num_classes),
        )

    def forward(self, x_in):
        return self.logits(x_in)


#### Components required by AdvSKM ##############
class Cosine_act(nn.Module):
    """Parameter-free activation that applies cosine element-wise."""

    def __init__(self):
        super(Cosine_act, self).__init__()

    def forward(self, input):
        return input.cos()


# Single shared instance; AdvSKM_Disc places it twice in one Sequential. That is safe
# because the module has no parameters.
cos_act = Cosine_act()


class AdvSKM_Disc(nn.Module):
    """AdvSKM discriminator with a cosine-activated branch and a ReLU branch.

    Both branches read flattened features of width ``features_len * final_out_channels``.
    - ``branch_1`` has widths ``DSKN_disc_hid`` then half of that, using the shared ``cos_act``.
    - ``branch_2`` has widths ``disc_hid_dim`` then half of that, using ReLU.

    The two outputs are concatenated along dim 1, giving
    ``DSKN_disc_hid // 2 + disc_hid_dim // 2`` features.
    """

    def __init__(self, configs):
        super(AdvSKM_Disc, self).__init__()

        self.input_dim = configs.features_len * configs.final_out_channels
        self.hid_dim = configs.DSKN_disc_hid
        cos_hid, cos_half = self.hid_dim, self.hid_dim // 2
        relu_hid, relu_half = configs.disc_hid_dim, configs.disc_hid_dim // 2

        # Module order inside each Sequential is kept fixed so state_dict keys are unchanged.
        self.branch_1 = nn.Sequential(
            nn.Linear(self.input_dim, cos_hid),
            nn.Linear(cos_hid, cos_hid),
            nn.BatchNorm1d(cos_hid),
            cos_act,
            nn.Linear(cos_hid, cos_half),
            nn.Linear(cos_half, cos_half),
            nn.BatchNorm1d(cos_half),
            cos_act,
        )
        self.branch_2 = nn.Sequential(
            nn.Linear(self.input_dim, relu_hid),
            nn.Linear(relu_hid, relu_hid),
            nn.BatchNorm1d(relu_hid),
            nn.ReLU(),
            nn.Linear(relu_hid, relu_half),
            nn.Linear(relu_half, relu_half),
            nn.BatchNorm1d(relu_half),
            nn.ReLU(),
        )

    def forward(self, input):
        """Run both branches on ``input`` and concatenate their outputs along dim 1."""
        return torch.cat((self.branch_1(input), self.branch_2(input)), dim=1)


class attn_network(nn.Module):
    """Project a flat feature vector to query, key and value vectors of the same width.

    Width is ``features_len * final_out_channels``. Q is passed through ELU; K and V are
    passed through LeakyReLU.
    """

    def __init__(self, configs):
        super(attn_network, self).__init__()

        self.h_dim = configs.features_len * configs.final_out_channels
        width = self.h_dim
        self.self_attn_Q = nn.Sequential(nn.Linear(in_features=width, out_features=width), nn.ELU())
        self.self_attn_K = nn.Sequential(nn.Linear(in_features=width, out_features=width), nn.LeakyReLU())
        self.self_attn_V = nn.Sequential(nn.Linear(in_features=width, out_features=width), nn.LeakyReLU())

    def forward(self, x):
        """Return the tuple (Q, K, V), each with the same shape as ``x``."""
        return self.self_attn_Q(x), self.self_attn_K(x), self.self_attn_V(x)


class Sparsemax(nn.Module):
    """Sparsemax activation.

    Projects the logits along ``dim`` onto the probability simplex. As with softmax, the
    outputs are non-negative and sum to 1 along ``dim``, but individual entries can be
    exactly zero.
    """

    def __init__(self, dim=None):
        """Initialize sparsemax activation

        Args:
            dim (int, optional): The dimension over which to apply the sparsemax function.
                Defaults to -1 (the last dimension).
        """
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input):
        """Forward function.
        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size
        Returns:
            torch.Tensor: Output tensor with the same shape as ``input``.
        """
        # Move `dim` to the front, flatten the remaining axes and transpose. The core below
        # then works on a 2-D (rows, number_of_logits) tensor; this is undone at the end.
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        # 1-based positions 1..K. Named `ranks` so the builtin `range` is not shadowed.
        ranks = torch.arange(start=1, end=number_of_logits + 1, step=1, device=input.device,
                             dtype=input.dtype).view(1, -1)
        ranks = ranks.expand_as(zs)

        # Determine sparsity of projection: support size k = max{j : 1 + j * z_(j) > sum_{i<=j} z_(i)}
        bound = 1 + ranks * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).to(input.dtype)
        k = torch.max(is_gt * ranks, dim, keepdim=True)[0]

        # Compute threshold function
        zs_sparse = is_gt * zs

        # Compute taus
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # Sparsemax. The result is kept on the module (2-D layout) for `backward` below.
        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Reshape back to original shape
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)

        return output

    def backward(self, grad_output):
        """Manual sparsemax gradient w.r.t. the input of the last ``forward`` call.

        Autograd never calls this method: ``nn.Module`` has no backward hook of this name,
        so gradients flow through the ops in ``forward`` instead. It works in the 2-D
        (rows, logits) layout of ``self.output``, so ``grad_output`` must use that layout.
        """
        dim = 1

        nonzeros = torch.ne(self.output, 0)
        # keepdim=True keeps the per-row mean as (rows, 1) so it broadcasts across the
        # logits axis, not across rows.
        mean_grad = torch.sum(grad_output * nonzeros, dim=dim, keepdim=True) / torch.sum(nonzeros, dim=dim,
                                                                                         keepdim=True)
        self.grad_input = nonzeros * (grad_output - mean_grad.expand_as(grad_output))

        return self.grad_input


class CNN_ATTN(nn.Module):
    """1-D CNN feature extractor followed by sparse self-attention pooling.

    Three Conv1d blocks and an adaptive average pool produce
    (final_out_channels, features_len) maps per sample. These are flattened and passed to
    ``calculate_attentive_feat``, which returns one L2-normalised vector of
    ``final_out_channels`` features per sample.
    """

    def __init__(self, configs):
        super(CNN_ATTN, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size,
                      stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)),
            nn.BatchNorm1d(configs.mid_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
            nn.Dropout(configs.dropout)
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(configs.mid_channels * 2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False,
                      padding=4),
            nn.BatchNorm1d(configs.final_out_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )

        self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len)
        self.attn_network = attn_network(configs)
        self.sparse_max = Sparsemax(dim=-1)
        self.feat_len = configs.features_len

    def forward(self, x_in):
        """Extract attentive features.

        Args:
            x_in: Conv1d input of shape (batch, input_channels, length).

        Returns:
            Tensor of shape (batch, final_out_channels), L2-normalised along the last axis.
        """
        x = self.conv_block1(x_in)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.adaptive_pool(x)
        x_flat = x.reshape(x.shape[0], -1)
        attentive_feat = self.calculate_attentive_feat(x_flat)
        return attentive_feat

    def _softmax_over_feat_len(self, attention_weight):
        """Dense counterpart of the sparsemax path: softmax over consecutive groups of ``feat_len`` entries."""
        weights = F.softmax(torch.reshape(attention_weight, [-1, self.feat_len]), dim=-1)
        return torch.reshape(weights, attention_weight.shape)

    def self_attention(self, Q, K, scale=True, sparse=True, k=3):
        """Compute one attention weight per feature position.

        Q and K are flat (batch, features_len * C) vectors, as produced for
        ``calculate_attentive_feat``. Q is viewed as (batch, feat_len, C) and K as
        (batch, C, feat_len). Their batched product is averaged over the last axis, giving
        (batch, feat_len, 1). The result is optionally divided by sqrt(K.shape[-1]) and
        then normalised over the feat_len positions with sparsemax (``sparse=True``) or
        softmax. ``k`` is unused.
        """

        attention_weight = torch.bmm(Q.view(Q.shape[0], self.feat_len, -1), K.view(K.shape[0], -1, self.feat_len))

        attention_weight = torch.mean(attention_weight, dim=2, keepdim=True)

        if scale:
            d_k = torch.tensor(K.shape[-1]).float()
            attention_weight = attention_weight / torch.sqrt(d_k)
        if sparse:
            attention_weight_sparse = self.sparse_max(torch.reshape(attention_weight, [-1, self.feat_len]))
            attention_weight = torch.reshape(attention_weight_sparse, [-1, attention_weight.shape[1],
                                                                       attention_weight.shape[2]])
        else:
            # The module defines no `self.softmax`; normalise the same groups the sparse path uses.
            attention_weight = self._softmax_over_feat_len(attention_weight)

        return attention_weight

    def attention_fn(self, Q, K, scaled=False, sparse=True, k=1):
        """Cosine-similarity attention between Q and a reshaped K.

        Both inputs are L2-normalised on their last axis before the matmul. When ``scaled``,
        the result is divided by sqrt(K.shape[-1]) and multiplied by ``k * log(feat_len)``.
        Weights are normalised over groups of ``feat_len`` entries with sparsemax or softmax.
        """

        attention_weight = torch.matmul(F.normalize(Q, p=2, dim=-1),
                                        F.normalize(K, p=2, dim=-1).view(K.shape[0], K.shape[1], -1, self.feat_len))

        if scaled:
            d_k = torch.tensor(K.shape[-1]).float()
            attention_weight = attention_weight / torch.sqrt(d_k)
            attention_weight = k * torch.log(torch.tensor(self.feat_len, dtype=torch.float32)) * attention_weight

        if sparse:
            attention_weight_sparse = self.sparse_max(torch.reshape(attention_weight, [-1, self.feat_len]))

            attention_weight = torch.reshape(attention_weight_sparse, attention_weight.shape)
        else:
            # The module defines no `self.softmax`; normalise the same groups the sparse path uses.
            attention_weight = self._softmax_over_feat_len(attention_weight)

        return attention_weight

    def calculate_attentive_feat(self, candidate_representation_xi):
        """Attention-pool a flat (batch, features_len * final_out_channels) representation.

        Returns:
            Tensor of shape (batch, final_out_channels): the attention-weighted sum of V over
            the feat_len positions, L2-normalised.
        """
        Q_xi, K_xi, V_xi = self.attn_network(candidate_representation_xi)
        intra_attention_weight_xi = self.self_attention(Q=Q_xi, K=K_xi, sparse=True)
        Z_i = torch.bmm(intra_attention_weight_xi.view(intra_attention_weight_xi.shape[0], 1, -1),
                        V_xi.view(V_xi.shape[0], self.feat_len, -1))
        final_feature = F.normalize(Z_i, dim=-1).view(Z_i.shape[0], -1)

        return final_feature



In [None]:
# import torch
# import torch.nn as nn
# import numpy as np
# import itertools
#
# from da.adatime.da import classifier, ReverseLayerF, Discriminator, RandomLayer, Discriminator_CDAN, \
#     codats_classifier, AdvSKM_Disc, CNN_ATTN
#
# from da.adatime.da.loss import MMD_loss, CORAL, ConditionalEntropyLoss, VAT, LMMD_loss, HoMM_loss, NTXentLoss, SupConLoss
# from da.adatime.utils import EMA
# from torch.optim.lr_scheduler import StepLR
# from copy import deepcopy
# import torch.nn.functional as F
#
#
# def get_algorithm_class(algorithm_name):
#     """Return the algorithm class with the given name."""
#     if algorithm_name not in globals():
#         raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))
#     return globals()[algorithm_name]
#
#
# class Algorithm(torch.nn.Module):
#     """
#     A subclass of Algorithm implements a domain adaptation algorithm.
#     Subclasses should implement the update() method.
#     """
#
#     def __init__(self, configs, backbone):
#         super(Algorithm, self).__init__()
#         self.configs = configs
#
#         self.cross_entropy = nn.CrossEntropyLoss()
#         self.feature_extractor = backbone(configs)
#         self.classifier = classifier(configs)
#         self.network = nn.Sequential(self.feature_extractor, self.classifier)
#
#     # update function is common to all algorithms
#     def update(self, src_loader, trg_loader, avg_meter, logger):
#         # defining best and last model
#         best_src_risk = float('inf')
#         best_model = None
#
#         for epoch in range(1, self.hparams["num_epochs"] + 1):
#
#             # training loop
#             self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
#
#             # saving the best model based on src risk
#             if (epoch + 1) % 10 == 0 and avg_meter['Src_cls_loss'].avg < best_src_risk:
#                 best_src_risk = avg_meter['Src_cls_loss'].avg
#                 best_model = deepcopy(self.network.state_dict())
#
#             logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]')
#             for key, val in avg_meter.items():
#                 logger.debug(f'{key}\t: {val.avg:2.4f}')
#             logger.debug(f'-------------------------------------')
#
#         last_model = self.network.state_dict()
#
#         return last_model, best_model
#
#     # train loop vary from one method to another
#     def training_epoch(self, *args, **kwargs):
#         raise NotImplementedError
#
#
# class NO_ADAPT(Algorithm):
#     """
#     Lower bound: train on source and test on target.
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#         for src_x, src_y in src_loader:
#
#             src_x, src_y = src_x.to(self.device), src_y.to(self.device)
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             src_cls_loss = self.cross_entropy(src_pred, src_y)
#
#             loss = src_cls_loss
#
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Src_cls_loss': src_cls_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class TARGET_ONLY(Algorithm):
#     """
#     Upper bound: train on target and test on target.
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         for trg_x, trg_y in trg_loader:
#
#             trg_x, trg_y = trg_x.to(self.device), trg_y.to(self.device)
#
#             trg_feat = self.feature_extractor(trg_x)
#             trg_pred = self.classifier(trg_feat)
#
#             trg_cls_loss = self.cross_entropy(trg_pred, trg_y)
#
#             loss = trg_cls_loss
#
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Trg_cls_loss': trg_cls_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class Deep_Coral(Algorithm):
#     """
#     Deep Coral: https://arxiv.org/abs/1607.01719
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # correlation alignment loss
#         self.coral = CORAL()
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         # add if statement
#
#         if len(src_loader) > len(trg_loader):
#             joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#         else:
#             joint_loader = enumerate(zip(itertools.cycle(src_loader), trg_loader))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
#
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             src_cls_loss = self.cross_entropy(src_pred, src_y)
#
#             trg_feat = self.feature_extractor(trg_x)
#
#             coral_loss = self.coral(src_feat, trg_feat)
#
#             loss = self.hparams["coral_wt"] * coral_loss + \
#                    self.hparams["src_cls_loss_wt"] * src_cls_loss
#
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(), 'Src_cls_loss': src_cls_loss.item(),
#                       'coral_loss': coral_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class MMDA(Algorithm):
#     """
#     MMDA: https://arxiv.org/abs/1901.00282
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Aligment losses
#         self.mmd = MMD_loss()
#         self.coral = CORAL()
#         self.cond_ent = ConditionalEntropyLoss()
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
#
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             src_cls_loss = self.cross_entropy(src_pred, src_y)
#
#             trg_feat = self.feature_extractor(trg_x)
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             src_cls_loss = self.cross_entropy(src_pred, src_y)
#
#             trg_feat = self.feature_extractor(trg_x)
#
#             coral_loss = self.coral(src_feat, trg_feat)
#             mmd_loss = self.mmd(src_feat, trg_feat)
#             cond_ent_loss = self.cond_ent(trg_feat)
#
#             loss = self.hparams["coral_wt"] * coral_loss + \
#                    self.hparams["mmd_wt"] * mmd_loss + \
#                    self.hparams["cond_ent_wt"] * cond_ent_loss + \
#                    self.hparams["src_cls_loss_wt"] * src_cls_loss
#
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(), 'Coral_loss': coral_loss.item(), 'MMD_loss': mmd_loss.item(),
#                       'cond_ent_wt': cond_ent_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class DANN(Algorithm):
#     """
#     DANN: https://arxiv.org/abs/1505.07818
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Domain Discriminator
#         self.domain_classifier = Discriminator(configs)
#         self.optimizer_disc = torch.optim.Adam(
#             self.domain_classifier.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"], betas=(0.5, 0.99)
#         )
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#         # Combine dataloaders
#         # Method 1 (min len of both domains)
#         # joint_loader = enumerate(zip(src_loader, trg_loader))
#
#         # Method 2 (max len of both domains)
#         # joint_loader =enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#         num_batches = max(len(src_loader), len(trg_loader))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
#
#             p = float(step + epoch * num_batches) / self.hparams["num_epochs"] + 1 / num_batches
#             alpha = 2. / (1. + np.exp(-10 * p)) - 1
#
#             # zero grad
#             self.optimizer.zero_grad()
#             self.optimizer_disc.zero_grad()
#
#             domain_label_src = torch.ones(len(src_x)).to(self.device)
#             domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
#
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             trg_feat = self.feature_extractor(trg_x)
#
#             # Task classification  Loss
#             src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
#
#             # Domain classification loss
#             # source
#             src_feat_reversed = ReverseLayerF.apply(src_feat, alpha)
#             src_domain_pred = self.domain_classifier(src_feat_reversed)
#             src_domain_loss = self.cross_entropy(src_domain_pred, domain_label_src.long())
#
#             # target
#             trg_feat_reversed = ReverseLayerF.apply(trg_feat, alpha)
#             trg_domain_pred = self.domain_classifier(trg_feat_reversed)
#             trg_domain_loss = self.cross_entropy(trg_domain_pred, domain_label_trg.long())
#
#             # Total domain loss
#             domain_loss = src_domain_loss + trg_domain_loss
#
#             loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + \
#                    self.hparams["domain_loss_wt"] * domain_loss
#
#             loss.backward()
#             self.optimizer.step()
#             self.optimizer_disc.step()
#
#             losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class CDAN(Algorithm):
#     """
#     CDAN: https://arxiv.org/abs/1705.10667
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Aligment Losses
#         self.criterion_cond = ConditionalEntropyLoss().to(device)
#
#         self.domain_classifier = Discriminator_CDAN(configs)
#         self.random_layer = RandomLayer([configs.features_len * configs.final_out_channels, configs.num_classes],
#                                         configs.features_len * configs.final_out_channels)
#         self.optimizer_disc = torch.optim.Adam(
#             self.domain_classifier.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"])
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
#             # prepare true domain labels
#             domain_label_src = torch.ones(len(src_x)).to(self.device)
#             domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
#             domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0).long()
#
#             # source features and predictions
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             # target features and predictions
#             trg_feat = self.feature_extractor(trg_x)
#             trg_pred = self.classifier(trg_feat)
#
#             # concatenate features and predictions
#             feat_concat = torch.cat((src_feat, trg_feat), dim=0)
#             pred_concat = torch.cat((src_pred, trg_pred), dim=0)
#
#             # Domain classification loss
#             feat_x_pred = torch.bmm(pred_concat.unsqueeze(2), feat_concat.unsqueeze(1)).detach()
#             disc_prediction = self.domain_classifier(feat_x_pred.view(-1, pred_concat.size(1) * feat_concat.size(1)))
#             disc_loss = self.cross_entropy(disc_prediction, domain_label_concat)
#
#             # update Domain classification
#             self.optimizer_disc.zero_grad()
#             disc_loss.backward()
#             self.optimizer_disc.step()
#
#             # prepare fake domain labels for training the feature extractor
#             domain_label_src = torch.zeros(len(src_x)).long().to(self.device)
#             domain_label_trg = torch.ones(len(trg_x)).long().to(self.device)
#             domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0)
#
#             # Repeat predictions after updating discriminator
#             feat_x_pred = torch.bmm(pred_concat.unsqueeze(2), feat_concat.unsqueeze(1))
#             disc_prediction = self.domain_classifier(feat_x_pred.view(-1, pred_concat.size(1) * feat_concat.size(1)))
#             # loss of domain discriminator according to fake labels
#
#             domain_loss = self.cross_entropy(disc_prediction, domain_label_concat)
#
#             # Task classification  Loss
#             src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
#
#             # conditional entropy loss.
#             loss_trg_cent = self.criterion_cond(trg_pred)
#
#             # total loss
#             loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + self.hparams["domain_loss_wt"] * domain_loss + \
#                    self.hparams["cond_ent_wt"] * loss_trg_cent
#
#             # update feature extractor
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item(),
#                       'cond_ent_loss': loss_trg_cent.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#         self.lr_scheduler.step()
#
#
# class DIRT(Algorithm):
#     """
#     DIRT-T: https://arxiv.org/abs/1802.08735
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Aligment losses
#         self.criterion_cond = ConditionalEntropyLoss().to(device)
#         self.vat_loss = VAT(self.network, device).to(device)
#         self.ema = EMA(0.998)
#         self.ema.register(self.network)
#
#         # Discriminator
#         self.domain_classifier = Discriminator(configs)
#         self.optimizer_disc = torch.optim.Adam(
#             self.domain_classifier.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(self.device)
#             # prepare true domain labels
#             domain_label_src = torch.ones(len(src_x)).to(self.device)
#             domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
#             domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0).long()
#
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             # target features and predictions
#             trg_feat = self.feature_extractor(trg_x)
#             trg_pred = self.classifier(trg_feat)
#
#             # concatenate features and predictions
#             feat_concat = torch.cat((src_feat, trg_feat), dim=0)
#
#             # Domain classification loss
#             disc_prediction = self.domain_classifier(feat_concat.detach())
#             disc_loss = self.cross_entropy(disc_prediction, domain_label_concat)
#
#             # update Domain classification
#             self.optimizer_disc.zero_grad()
#             disc_loss.backward()
#             self.optimizer_disc.step()
#
#             # prepare fake domain labels for training the feature extractor
#             domain_label_src = torch.zeros(len(src_x)).long().to(self.device)
#             domain_label_trg = torch.ones(len(trg_x)).long().to(self.device)
#             domain_label_concat = torch.cat((domain_label_src, domain_label_trg), 0)
#
#             # Repeat predictions after updating discriminator
#             disc_prediction = self.domain_classifier(feat_concat)
#
#             # loss of domain discriminator according to fake labels
#             domain_loss = self.cross_entropy(disc_prediction, domain_label_concat)
#
#             # Task classification  Loss
#             src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
#
#             # conditional entropy loss.
#             loss_trg_cent = self.criterion_cond(trg_pred)
#
#             # Virual advariarial training loss
#             loss_src_vat = self.vat_loss(src_x, src_pred)
#             loss_trg_vat = self.vat_loss(trg_x, trg_pred)
#             total_vat = loss_src_vat + loss_trg_vat
#             # total loss
#             loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + self.hparams["domain_loss_wt"] * domain_loss + \
#                    self.hparams["cond_ent_wt"] * loss_trg_cent + self.hparams["vat_loss_wt"] * total_vat
#
#             # update exponential moving average
#             self.ema(self.network)
#
#             # update feature extractor
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item(),
#                       'cond_ent_loss': loss_trg_cent.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class DSAN(Algorithm):
#     """
#     DSAN: https://ieeexplore.ieee.org/document/9085896
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Alignment losses
#         self.loss_LMMD = LMMD_loss(device=device, class_num=configs.num_classes).to(device)
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)  # extract source features
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             # extract target features
#             trg_feat = self.feature_extractor(trg_x)
#             trg_pred = self.classifier(trg_feat)
#
#             # calculate lmmd loss
#             domain_loss = self.loss_LMMD.get_loss(src_feat, trg_feat, src_y,
#                                                   torch.nn.functional.softmax(trg_pred, dim=1))
#
#             # calculate source classification loss
#             src_cls_loss = self.cross_entropy(src_pred, src_y)
#
#             # calculate the total loss
#             loss = self.hparams["domain_loss_wt"] * domain_loss + \
#                    self.hparams["src_cls_loss_wt"] * src_cls_loss
#
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(), 'LMMD_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class HoMM(Algorithm):
#     """
#     HoMM: https://arxiv.org/pdf/1912.11976.pdf
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # aligment losses
#         self.coral = CORAL()
#         self.HoMM_loss = HoMM_loss()
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)  # extract source features
#
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             # extract target features
#             trg_feat = self.feature_extractor(trg_x)
#             trg_pred = self.classifier(trg_feat)
#
#             # calculate source classification loss
#             src_cls_loss = self.cross_entropy(src_pred, src_y)
#
#             # calculate lmmd loss
#             domain_loss = self.HoMM_loss(src_feat, trg_feat)
#
#             # calculate the total loss
#             loss = self.hparams["domain_loss_wt"] * domain_loss + \
#                    self.hparams["src_cls_loss_wt"] * src_cls_loss
#
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(), 'HoMM_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class DDC(Algorithm):
#     """
#     DDC: https://arxiv.org/abs/1412.3474
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Aligment losses
#         self.mmd_loss = MMD_loss()
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)  # extract source features
#             # extract source features
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             # extract target features
#             trg_feat = self.feature_extractor(trg_x)
#
#             # calculate source classification loss
#             src_cls_loss = self.cross_entropy(src_pred, src_y)
#
#             # calculate mmd loss
#             domain_loss = self.mmd_loss(src_feat, trg_feat)
#
#             # calculate the total loss
#             loss = self.hparams["domain_loss_wt"] * domain_loss + \
#                    self.hparams["src_cls_loss_wt"] * src_cls_loss
#
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(), 'MMD_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class CoDATS(Algorithm):
#     """
#     CoDATS: https://arxiv.org/pdf/2005.10996.pdf
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # we replace the original classifier with codats the classifier
#         # remember to use same name of self.classifier, as we use it for the model evaluation
#         self.classifier = codats_classifier(configs)
#         self.network = nn.Sequential(self.feature_extractor, self.classifier)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Domain classifier
#         self.domain_classifier = Discriminator(configs)
#
#         self.optimizer_disc = torch.optim.Adam(
#             self.domain_classifier.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"], betas=(0.5, 0.99)
#         )
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#         # One CoDATS epoch: source classification loss plus a domain-classification loss
#         # computed on gradient-reversed (ReverseLayerF) source/target features, optimized
#         # jointly for the network and the domain classifier; LR scheduler steps once per epoch.
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#         num_batches = max(len(src_loader), len(trg_loader))
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)  # extract source features
#
#             # NOTE(review): operator precedence -- only the first term is divided by
#             # num_epochs, so p grows roughly like epoch * num_batches / num_epochs and is
#             # not a 0..1 training-progress fraction; alpha then saturates near 1 almost
#             # immediately. If the usual DANN schedule was intended this should probably be
#             # float(step + epoch * num_batches) / (self.hparams["num_epochs"] * num_batches)
#             # -- confirm before re-enabling.
#             p = float(step + epoch * num_batches) / self.hparams["num_epochs"] + 1 / num_batches
#             alpha = 2. / (1. + np.exp(-10 * p)) - 1
#
#             # zero grad
#             self.optimizer.zero_grad()
#             self.optimizer_disc.zero_grad()
#
#             domain_label_src = torch.ones(len(src_x)).to(self.device)
#             domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
#
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             trg_feat = self.feature_extractor(trg_x)
#
#             # Task classification  Loss
#             src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
#
#             # Domain classification loss
#             # source
#             src_feat_reversed = ReverseLayerF.apply(src_feat, alpha)
#             src_domain_pred = self.domain_classifier(src_feat_reversed)
#             src_domain_loss = self.cross_entropy(src_domain_pred, domain_label_src.long())
#
#             # target
#             trg_feat_reversed = ReverseLayerF.apply(trg_feat, alpha)
#             trg_domain_pred = self.domain_classifier(trg_feat_reversed)
#             trg_domain_loss = self.cross_entropy(trg_domain_pred, domain_label_trg.long())
#
#             # Total domain loss
#             domain_loss = src_domain_loss + trg_domain_loss
#
#             loss = self.hparams["src_cls_loss_wt"] * src_cls_loss + \
#                    self.hparams["domain_loss_wt"] * domain_loss
#
#             # a single backward updates both the network and the domain classifier
#             loss.backward()
#             self.optimizer.step()
#             self.optimizer_disc.step()
#
#             losses = {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}
#             # NOTE(review): the averaging weight 32 is hard-coded, not len(src_x)
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class AdvSKM(Algorithm):
#     """
#     AdvSKM: https://www.ijcai.org/proceedings/2021/0378.pdf
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Aligment losses
#         self.mmd_loss = MMD_loss()
#         self.AdvSKM_embedder = AdvSKM_Disc(configs).to(device)
#         self.optimizer_disc = torch.optim.Adam(
#             self.AdvSKM_embedder.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#         # One AdvSKM epoch: (1) update AdvSKM_embedder to maximize MMD between detached
#         # source/target features, (2) update the network on source CE + weighted MMD of the
#         # embedded (non-detached) features; LR scheduler steps once per epoch.
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)  # extract source features
#
#             src_feat = self.feature_extractor(src_x)
#             src_pred = self.classifier(src_feat)
#
#             # extract target features
#             trg_feat = self.feature_extractor(trg_x)
#
#             source_embedding_disc = self.AdvSKM_embedder(src_feat.detach())
#             target_embedding_disc = self.AdvSKM_embedder(trg_feat.detach())
#             mmd_loss = - self.mmd_loss(source_embedding_disc, target_embedding_disc)
#             # NOTE(review): PyTorch only allows setting requires_grad on leaf tensors. If this
#             # line runs without error, mmd_loss has no autograd history, so the backward below
#             # cannot reach AdvSKM_embedder and optimizer_disc.step() gets no useful gradient.
#             # Presumably MMD_loss returns a graph-less tensor -- verify MMD_loss and drop
#             # this workaround instead of forcing requires_grad.
#             mmd_loss.requires_grad = True
#
#             # update discriminator
#             self.optimizer_disc.zero_grad()
#             mmd_loss.backward()
#             self.optimizer_disc.step()
#
#             # calculate source classification loss
#             src_cls_loss = self.cross_entropy(src_pred, src_y)
#
#             # domain loss.
#             source_embedding_disc = self.AdvSKM_embedder(src_feat)
#             target_embedding_disc = self.AdvSKM_embedder(trg_feat)
#
#             mmd_loss_adv = self.mmd_loss(source_embedding_disc, target_embedding_disc)
#             # NOTE(review): same issue as above -- a leaf mmd_loss_adv contributes no gradient
#             # to the feature extractor, so only src_cls_loss would actually train the network.
#             mmd_loss_adv.requires_grad = True
#
#             # calculate the total loss
#             loss = self.hparams["domain_loss_wt"] * mmd_loss_adv + \
#                    self.hparams["src_cls_loss_wt"] * src_cls_loss
#
#             # update optimizer
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(), 'MMD_loss': mmd_loss_adv.item(), 'Src_cls_loss': src_cls_loss.item()}
#             # NOTE(review): the averaging weight 32 is hard-coded, not len(src_x)
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#
# class SASA(Algorithm):
#
#     def __init__(self, backbone, configs, hparams, device):
#         # Builds a classifier and a CNN_ATTN feature extractor, an Adam optimizer over both,
#         # and a StepLR scheduler.
#         # NOTE(review): `backbone` is passed to super().__init__ but self.feature_extractor is
#         # then replaced by CNN_ATTN(configs) -- confirm that is intended.
#         super().__init__(configs, backbone)
#
#         # feature_length for classifier
#         # NOTE(review): this mutates the caller's configs object in place, so the change
#         # persists for anything else built from the same configs afterwards.
#         configs.features_len = 1
#         self.classifier = classifier(configs)
#         # feature length for feature extractor
#         # NOTE(review): redundant -- same value as above; if the classifier and extractor
#         # were meant to use different lengths, one of these two assignments is wrong.
#         configs.features_len = 1
#         self.feature_extractor = CNN_ATTN(configs)
#         self.network = nn.Sequential(self.feature_extractor, self.classifier)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)  # extract source features
#
#             # Extract features
#             src_feature = self.feature_extractor(src_x)
#             tgt_feature = self.feature_extractor(trg_x)
#
#             # source classification loss
#             y_pred = self.classifier(src_feature)
#             src_cls_loss = self.cross_entropy(y_pred, src_y)
#
#             # MMD loss
#             domain_loss_intra = self.mmd_loss(src_struct=src_feature,
#                                               tgt_struct=tgt_feature, weight=self.hparams['domain_loss_wt'])
#
#             # total loss
#             total_loss = self.hparams['src_cls_loss_wt'] * src_cls_loss + domain_loss_intra
#
#             # remove old gradients
#             self.optimizer.zero_grad()
#             # calculate gradients
#             total_loss.backward()
#             # update the weights
#             self.optimizer.step()
#
#             losses = {'Total_loss': total_loss.item(), 'MMD_loss': domain_loss_intra.item(),
#                       'Src_cls_loss': src_cls_loss.item()}
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#     def mmd_loss(self, src_struct, tgt_struct, weight):
#         delta = torch.mean(src_struct - tgt_struct, dim=-2)
#         loss_value = torch.norm(delta, 2) * weight
#         return loss_value
#
#
# class CoTMix(Algorithm):
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Aligment losses
#         self.contrastive_loss = NTXentLoss(device, hparams["batch_size"], 0.2, True)
#         self.entropy_loss = ConditionalEntropyLoss()
#         self.sup_contrastive_loss = SupConLoss(device)
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)  # extract source features
#
#             # ====== Temporal Mixup =====================
#             src_dominant, trg_dominant = self.temporal_mixup(src_x, trg_x)
#
#             # ====== Source =====================
#             self.optimizer.zero_grad()
#
#             # Src original features
#             src_orig_feat = self.feature_extractor(src_x)
#             src_orig_logits = self.classifier(src_orig_feat)
#
#             # Target original features
#             trg_orig_feat = self.feature_extractor(trg_x)
#             trg_orig_logits = self.classifier(trg_orig_feat)
#
#             # -----------  The two main losses
#             # Cross-Entropy loss
#             src_cls_loss = self.cross_entropy(src_orig_logits, src_y)
#             loss = src_cls_loss * round(self.hparams["src_cls_weight"], 2)
#
#             # Target Entropy loss
#             trg_entropy_loss = self.entropy_loss(trg_orig_logits)
#             loss += trg_entropy_loss * round(self.hparams["trg_entropy_weight"], 2)
#
#             # -----------  Auxiliary losses
#             # Extract source-dominant mixup features.
#             src_dominant_feat = self.feature_extractor(src_dominant)
#             src_dominant_logits = self.classifier(src_dominant_feat)
#
#             # supervised contrastive loss on source domain side
#             src_concat = torch.cat([src_orig_logits.unsqueeze(1), src_dominant_logits.unsqueeze(1)], dim=1)
#             src_supcon_loss = self.sup_contrastive_loss(src_concat, src_y)
#             loss += src_supcon_loss * round(self.hparams["src_supCon_weight"], 2)
#
#             # Extract target-dominant mixup features.
#             trg_dominant_feat = self.feature_extractor(trg_dominant)
#             trg_dominant_logits = self.classifier(trg_dominant_feat)
#
#             # Unsupervised contrastive loss on target domain side
#             trg_con_loss = self.contrastive_loss(trg_orig_logits, trg_dominant_logits)
#             loss += trg_con_loss * round(self.hparams["trg_cont_weight"], 2)
#
#             loss.backward()
#             self.optimizer.step()
#
#             losses = {'Total_loss': loss.item(),
#                       'src_cls_loss': src_cls_loss.item(),
#                       'trg_entropy_loss': trg_entropy_loss.item(),
#                       'src_supcon_loss': src_supcon_loss.item(),
#                       'trg_con_loss': trg_con_loss.item()
#                       }
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#     def temporal_mixup(self, src_x, trg_x):
#
#         mix_ratio = round(self.hparams["mix_ratio"], 2)
#         temporal_shift = self.hparams["temporal_shift"]
#         h = temporal_shift // 2  # half
#
#         src_dominant = mix_ratio * src_x + (1 - mix_ratio) * \
#                        torch.mean(torch.stack([torch.roll(trg_x, -i, 2) for i in range(-h, h)], 2), 2)
#
#         trg_dominant = mix_ratio * trg_x + (1 - mix_ratio) * \
#                        torch.mean(torch.stack([torch.roll(src_x, -i, 2) for i in range(-h, h)], 2), 2)
#
#         return src_dominant, trg_dominant
#
#
# # Untied Approaches: (MCD)
# class MCD(Algorithm):
#     """
#     Maximum Classifier Discrepancy for Unsupervised Domain Adaptation
#     MCD: https://arxiv.org/pdf/1712.02560.pdf
#     """
#
#     def __init__(self, backbone, configs, hparams, device):
#         super().__init__(configs, backbone)
#
#         self.feature_extractor = backbone(configs)
#         self.classifier = classifier(configs)
#         self.classifier2 = classifier(configs)
#
#         self.network = nn.Sequential(self.feature_extractor, self.classifier)
#
#         # optimizer and scheduler
#         self.optimizer_fe = torch.optim.Adam(
#             self.feature_extractor.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         # optimizer and scheduler
#         self.optimizer_c1 = torch.optim.Adam(
#             self.classifier.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#         # optimizer and scheduler
#         self.optimizer_c2 = torch.optim.Adam(
#             self.classifier2.parameters(),
#             lr=hparams["learning_rate"],
#             weight_decay=hparams["weight_decay"]
#         )
#
#         self.lr_scheduler_fe = StepLR(self.optimizer_fe, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         self.lr_scheduler_c1 = StepLR(self.optimizer_c1, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#         self.lr_scheduler_c2 = StepLR(self.optimizer_c2, step_size=hparams['step_size'], gamma=hparams['lr_decay'])
#
#         # hparams
#         self.hparams = hparams
#         # device
#         self.device = device
#
#         # Aligment losses
#         self.mmd_loss = MMD_loss()
#
#     def update(self, src_loader, trg_loader, avg_meter, logger):
#         # Runs num_epochs of (source pretraining epoch + MCD adversarial epoch), logs the
#         # running averages each epoch, and returns (last_model, best_model) state dicts,
#         # where best_model is chosen by the lowest 'Src_cls_loss' average at checkpoint epochs.
#         # defining best and last model
#         best_src_risk = float('inf')
#         best_model = None
#
#         for epoch in range(1, self.hparams["num_epochs"] + 1):
#
#             # source pretraining loop
#             self.pretrain_epoch(src_loader, avg_meter)
#
#             # training loop
#             self.training_epoch(src_loader, trg_loader, avg_meter, epoch)
#
#             # saving the best model based on src risk
#             # NOTE(review): epochs start at 1, so (epoch + 1) % 10 == 0 fires at epochs
#             # 9, 19, 29, ...; with num_epochs < 9 best_model is returned as None.
#             if (epoch + 1) % 10 == 0 and avg_meter['Src_cls_loss'].avg < best_src_risk:
#                 best_src_risk = avg_meter['Src_cls_loss'].avg
#                 best_model = deepcopy(self.network.state_dict())
#
#             logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]')
#             for key, val in avg_meter.items():
#                 logger.debug(f'{key}\t: {val.avg:2.4f}')
#             logger.debug(f'-------------------------------------')
#
#         # NOTE(review): not deep-copied (unlike best_model); state_dict tensors share storage
#         # with the live parameters, so later training would change this snapshot.
#         last_model = self.network.state_dict()
#
#         return last_model, best_model
#
#     def pretrain_epoch(self, src_loader, avg_meter):
#         for src_x, src_y in src_loader:
#             src_x, src_y = src_x.to(self.device), src_y.to(self.device)
#
#             src_feat = self.feature_extractor(src_x)
#             src_pred1 = self.classifier(src_feat)
#             src_pred2 = self.classifier2(src_feat)
#
#             src_cls_loss1 = self.cross_entropy(src_pred1, src_y)
#             src_cls_loss2 = self.cross_entropy(src_pred2, src_y)
#
#             loss = src_cls_loss1 + src_cls_loss2
#
#             self.optimizer_c1.zero_grad()
#             self.optimizer_c2.zero_grad()
#             self.optimizer_fe.zero_grad()
#
#             loss.backward()
#
#             self.optimizer_c1.step()
#             self.optimizer_c2.step()
#             self.optimizer_fe.step()
#
#             losses = {'Src_cls_loss': loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch):
#         # One MCD epoch. Per batch: (A) with the feature extractor frozen, update C1/C2 on
#         # source CE minus their target discrepancy; (B) with the classifiers frozen, update
#         # the feature extractor to minimize the target discrepancy. Schedulers step per epoch.
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#
#         for step, ((src_x, src_y), (trg_x, _)) in joint_loader:
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)  # extract source features
#
#             # extract source features
#             src_feat = self.feature_extractor(src_x)
#             src_pred1 = self.classifier(src_feat)
#             src_pred2 = self.classifier2(src_feat)
#
#             # source losses
#             src_cls_loss1 = self.cross_entropy(src_pred1, src_y)
#             src_cls_loss2 = self.cross_entropy(src_pred2, src_y)
#             loss_s = src_cls_loss1 + src_cls_loss2
#
#             # Freeze the feature extractor
#             for k, v in self.feature_extractor.named_parameters():
#                 v.requires_grad = False
#             # update C1 and C2 to maximize their difference on target sample
#             trg_feat = self.feature_extractor(trg_x)
#             trg_pred1 = self.classifier(trg_feat.detach())
#             trg_pred2 = self.classifier2(trg_feat.detach())
#
#             loss_dis = self.discrepancy(trg_pred1, trg_pred2)
#
#             loss = loss_s - loss_dis
#
#             # NOTE(review): no zero_grad before this first backward; pretrain_epoch zeroes
#             # grads before (not after) its backward, so the first C1/C2 step here also
#             # includes the last pretraining batch's gradients.
#             loss.backward()
#             self.optimizer_c1.step()
#             self.optimizer_c2.step()
#
#             self.optimizer_c1.zero_grad()
#             self.optimizer_c2.zero_grad()
#             self.optimizer_fe.zero_grad()
#
#             # Freeze the classifiers
#             # NOTE(review): nothing in this class ever sets the classifiers' requires_grad
#             # back to True, so after the first batch C1/C2 receive no gradients here or in
#             # pretrain_epoch and stop learning. Re-enable them at the start of each batch.
#             for k, v in self.classifier.named_parameters():
#                 v.requires_grad = False
#             for k, v in self.classifier2.named_parameters():
#                 v.requires_grad = False
#             # Unfreeze the feature extractor (frozen above for step A)
#             for k, v in self.feature_extractor.named_parameters():
#                 v.requires_grad = True
#             # update feature extractor to minimize the discrepancy on target samples
#             trg_feat = self.feature_extractor(trg_x)
#             trg_pred1 = self.classifier(trg_feat)
#             trg_pred2 = self.classifier2(trg_feat)
#
#             loss_dis_t = self.discrepancy(trg_pred1, trg_pred2)
#             domain_loss = self.hparams["domain_loss_wt"] * loss_dis_t
#
#             domain_loss.backward()
#             self.optimizer_fe.step()
#
#             self.optimizer_fe.zero_grad()
#             self.optimizer_c1.zero_grad()
#             self.optimizer_c2.zero_grad()
#
#             # NOTE(review): 'MMD_loss' actually holds the weighted classifier discrepancy,
#             # and the averaging weight 32 is hard-coded rather than len(src_x).
#             losses = {'Total_loss': loss.item(), 'MMD_loss': domain_loss.item()}
#
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler_fe.step()
#         self.lr_scheduler_c1.step()
#         self.lr_scheduler_c2.step()
#
#     def discrepancy(self, out1, out2):
#         """Mean absolute difference between the softmax outputs of two classifiers.
#
#         Args:
#             out1: logits from the first classifier.
#             out2: logits from the second classifier (same shape as out1).
#
#         Returns:
#             Scalar tensor: mean over all elements of |softmax(out1) - softmax(out2)|.
#
#         NOTE(review): F.softmax is called without ``dim``. PyTorch warns
#         that the implicit dimension choice is deprecated and infers one from
#         the input rank. If the logits are (batch, num_classes), which is
#         assumed but should be confirmed, pass dim=1 explicitly.
#         """
#
#         return torch.mean(torch.abs(F.softmax(out1) - F.softmax(out2)))

# Updated SASA (disabled: the cell below is entirely commented out)

In [1]:
# class SASA(Algorithm):
#
#     def __init__(self, backbone, configs, device):
#         """Build the GTN feature extractor, classifier, Adam optimizer and StepLR scheduler.
#
#         Args:
#             backbone: only forwarded to the Algorithm base class. The
#                 feature extractor below is always ``GTN(configs)``.
#             configs: hyper-parameter object. Reads learning_rate,
#                 weight_decay, step_size and lr_decay here.
#                 NOTE: ``configs.features_len`` is overwritten (set to 1)
#                 on the caller's object.
#             device: stored as ``self.device``; training_epoch moves
#                 batches there.
#         """
#         # NOTE(review): arguments are passed as (configs, backbone), the
#         # reverse of this method's own order; confirm this matches
#         # Algorithm.__init__.
#         super().__init__(configs, backbone)
#
#         # feature_length for classifier
#         configs.features_len = 1
#         self.classifier = classifier(configs)
#
#         # feature length for feature extractor
#         # NOTE(review): same value as above, so this re-assignment is redundant.
#         configs.features_len = 1
#         self.feature_extractor = GTN(configs)
#         # NOTE(review): this Sequential is only used to collect parameters.
#         # training_epoch calls the feature extractor as (stage, x), which
#         # nn.Sequential.forward cannot pass through; confirm that
#         # self.network is never called directly.
#         self.network = nn.Sequential(self.feature_extractor, self.classifier)
#
#         # optimizer and scheduler
#         self.optimizer = torch.optim.Adam(
#             self.network.parameters(),
#             lr=configs.learning_rate,
#             weight_decay=configs.weight_decay
#         )
#         self.lr_scheduler = StepLR(self.optimizer, step_size=configs.step_size, gamma=configs.lr_decay)
#
#         self.configs = configs
#         self.device = device
#
#     def training_epoch(self, src_loader, trg_loader, avg_meter, epoch, stage='train'):
#         """Run one epoch of SASA training on paired source/target batches.
#
#         Each step does the following:
#           - extracts features for a source and a target batch, passing
#             ``stage`` through to the feature extractor;
#           - computes the source cross-entropy and the weighted
#             feature-alignment loss from ``mmd_loss``;
#           - takes one optimizer step on
#             ``sasa_src_cls_loss_wt * src_cls_loss + domain_loss_intra``.
#         The LR scheduler is stepped once after the epoch.
#
#         Args:
#             src_loader: iterable of labelled source batches ``(x, y)``.
#             trg_loader: iterable of target batches ``(x, _)``; labels unused.
#                 It is cycled, so the epoch length is ``len(src_loader)``.
#             avg_meter: mapping from loss name to a meter with
#                 ``update(value, n)``.
#             epoch: not used by this method.
#             stage: passed as the first argument to
#                 ``self.feature_extractor`` (default ``'train'``).
#         """
#
#         # Construct Joint Loaders
#         joint_loader = enumerate(zip(src_loader, itertools.cycle(trg_loader)))
#         for step, ((src_x, src_y), (trg_x, _)) in tqdm.tqdm(joint_loader, desc='Training', total=len(src_loader)):
#             # move the batch to the training device
#             src_x, src_y, trg_x = src_x.to(self.device), src_y.to(self.device), trg_x.to(
#                 self.device)
#
#             # Extract features
#             src_feature = self.feature_extractor(stage, src_x)
#             tgt_feature = self.feature_extractor(stage, trg_x)
#
#             # source classification loss
#             y_pred = self.classifier(src_feature)
#             src_cls_loss = self.cross_entropy(y_pred, src_y)
#
#             # "MMD" loss: weighted norm of the mean feature difference (see mmd_loss).
#             # NOTE(review): mmd_loss subtracts src/tgt features element-wise.
#             # If the leading axis is the batch axis (assumed), a source and
#             # target batch of different sizes may fail to broadcast, e.g. a
#             # short final batch from either loader. TODO confirm the loaders
#             # use drop_last or equal batch sizes.
#             domain_loss_intra = self.mmd_loss(src_struct=src_feature,
#                                               tgt_struct=tgt_feature, weight=self.configs.sasa_domain_loss_wt)
#
#             # total loss
#             total_loss = self.configs.sasa_src_cls_loss_wt * src_cls_loss + domain_loss_intra
#
#             # remove old gradients
#             self.optimizer.zero_grad()
#             # calculate gradients
#             total_loss.backward()
#             # update the weights
#             self.optimizer.step()
#
#             losses = {'Total_loss': total_loss.item(), 'MMD_loss': domain_loss_intra.item(),
#                       'Src_cls_loss': src_cls_loss.item()}
#             # NOTE(review): 32 is a hard-coded update count (presumably the
#             # batch size); src_x.size(0) would be exact for a short last batch.
#             for key, val in losses.items():
#                 avg_meter[key].update(val, 32)
#
#         self.lr_scheduler.step()
#
#     def mmd_loss(self, src_struct, tgt_struct, weight):
#         """Weighted L2 norm of the mean source/target feature difference.
#
#         Averages ``src_struct - tgt_struct`` over dim -2, then takes the
#         2-norm of the result over all remaining elements (torch.norm with
#         no ``dim``) and scales it by ``weight``. This is a difference-of-means
#         discrepancy rather than a kernel MMD. It roughly corresponds to a
#         linear-kernel MMD if dim -2 indexes samples; that is assumed, so
#         TODO confirm the feature layout.
#
#         Args:
#             src_struct: source features.
#             tgt_struct: target features; must be broadcast-compatible with
#                 src_struct, because they are subtracted element-wise.
#             weight: scalar multiplier for the loss.
#
#         Returns:
#             Scalar tensor loss.
#
#         NOTE(review): subtracting before averaging requires matching shapes.
#         Averaging each side first (src.mean(-2) - tgt.mean(-2)) would give
#         the same value for equal shapes and would tolerate different sizes
#         along dim -2.
#         """
#         delta = torch.mean(src_struct - tgt_struct, dim=-2)
#         loss_value = torch.norm(delta, 2) * weight
#
#         return loss_value