Model & Method

In [1]:
import torch.nn as nn
import torch
from torch.autograd import Function
import torch.nn.functional as F
from torch.functional import tensordot
import numpy as np

# ========== 新增交叉注意力模块 ==========
class CrossAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x1, x2):
        q = self.query(x1)
        k = self.key(x2)
        v = self.value(x2)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        return out + x1  # 残差连接


class CNN(nn.Module):
    def __init__(self, configs):
        super(CNN, self).__init__()
        self.modality_num = configs.modality_nums  # 模态数
        configs = configs.model_configs['CNN']

        self.conv1_blocks = nn.ModuleList([nn.Sequential(
            nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size,
                      stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)),
            nn.BatchNorm1d(configs.mid_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
            nn.Dropout(configs.dropout)
        ) for i in range(self.modality_num)])

        self.conv2_blocks = nn.ModuleList([nn.Sequential(
            nn.Conv1d(configs.mid_channels, configs.mid_channels * 2,
                      kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(configs.mid_channels * 2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        ) for i in range(self.modality_num)])

        self.conv3_blocks = nn.ModuleList([nn.Sequential(
            nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False,
                      padding=4),
            nn.BatchNorm1d(configs.final_out_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        ) for i in range(self.modality_num)])

        self.adaptive_pools = nn.ModuleList([nn.AdaptiveAvgPool1d(
            configs.features_len) for i in range(self.modality_num)])

    def forward(self, k_x_in):
        '''x: [k*(N, 3, 300)]
            系列卷积不论dim=-1为多少，128，200，300输出都是128
        '''
        k_out = []
        for modality_idx in range(self.modality_num):
            x = k_x_in[modality_idx]
            x = self.conv1_blocks[modality_idx](x)
            # print(f"conv1 x_out: {x.shape}")
            x = self.conv2_blocks[modality_idx](x)
            # print(f"conv2 x_out: {x.shape}")
            x = self.conv3_blocks[modality_idx](x)
            # print(f"conv3 x_out: {x.shape}")
            x = self.adaptive_pools[modality_idx](x)
            # print(f"adaptive x_out: {x.shape}")
            x_flat = x.reshape(x.shape[0], 1, -1)
            # print(f"sensor conv flat: {x_flat.shape}")  # (N, 1, 128)
            k_out.append(x_flat)

        k_out = torch.cat(k_out, dim=1)
        # print(f"conv sensors cat: {k_out.shape}")  # (N, k, 128)
        # print(f"conv out: {k_out.shape}")  # (N, k, 128)
        return k_out


class c_fusion(nn.Module):
    def __init__(self, configs) -> None:
        super(c_fusion, self).__init__()
        self.modality_num = configs.modality_nums
        self.fusion_cfg = configs.model_configs['fusion']
        self.d_W1 = nn.Parameter(torch.normal(
            mean=0, std=0.1, size=[self.fusion_cfg.final_out_channels, 1], requires_grad=True))
        self.d_b1 = nn.Parameter(torch.normal(mean=0, std=0.1, size=[
                                 1, self.modality_num, 1]), requires_grad=True)  # (1, k, 1)
        self.d_w1 = nn.Parameter(torch.normal(
            mean=0, std=0.1, size=[self.modality_num, 1]), requires_grad=True)

    def forward(self, x):
        '''x: (N, k, len)'''
        # print(f"attn x_unsqueeze: {x.shape}")  # (N, k, 128)

        MLP_input = tensordot(x, self.d_W1, dims=1)
        MLP_input += self.d_b1
        # print(f"MLP input: {MLP_input.shape}")
        miu = torch.tanh(MLP_input)  # [BN, k, 1]
        softmax_input = tensordot(miu, self.d_w1, dims=1)
        # print(f"softmax_input: {softmax_input.shape}")
        alpha = F.softmax(softmax_input, dim=1)  # [BN, k, 1]
        f_per_sensor = alpha * x  # [BN, 1, 64]
        # print(f"f_modality: {f_per_sensor.shape}")
        out = torch.sum(f_per_sensor, dim=1)  # [BN, len]
        # print(f"c_fusion: {out.shape}")
        return out


class FE(nn.Module):
    def __init__(self, configs) -> None:
        '''input:
                x_k_t: [k*(N, 3, 300)]
                x_k_f: [k*(N, 3, 151)]
            return:
                merge_tf: (N, 2*128)'''
        super(FE, self).__init__()
        self.cnn_t = CNN(configs=configs)
        self.fusion_t = c_fusion(configs=configs)
        self.cnn_f = CNN(configs=configs)
        self.fusion_f = c_fusion(configs=configs)
        # ========== 新增模块 ==========
        # self.cross_attention_t = CrossAttention(128)  # 时域交叉注意力
        # self.cross_attention_f = CrossAttention(128)  # 频域交叉注意力
        self.contrastive_proj = nn.Sequential(        # 对比学习投影层
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )
        
    def forward(self, x_k_t, x_k_f):
        x_k_t = self.cnn_t(x_k_t)
        merge_t = self.fusion_t(x_k_t)  # (N, 300)

        x_k_f = self.cnn_f(x_k_f)  # (N, k, 151)
        merge_f = self.fusion_f(x_k_f)

        # print(x_k_f.shape)  # (N, k, 151)
        # print(merge_f.shape)  # (N, 128)

        # # ========== 新增交叉交互 ==========
        # merge_t = self.cross_attention_t(merge_t, merge_f)  # 时域融合频域信息
        # merge_f = self.cross_attention_f(merge_f, merge_t)  # 频域融合时域信息
        
        # 合并时频特征
        merge_tf = torch.cat([merge_t, merge_f], dim=1)
        return merge_tf , merge_t, merge_f  # 返回三个特征用于对比学习


class Classifier(nn.Module):
    def __init__(self, configs):
        super(Classifier, self).__init__()
        label_num = configs.num_classes
        configs = configs.model_configs['Classifier']
        model_output_dim = configs.features_len
        self.hidden_dim = configs.hidden_dim
        self.logits = nn.Sequential(
            nn.Linear(model_output_dim * configs.final_out_channels * 2, self.hidden_dim * 2),  # 2 为 temporal + frequency
            nn.ReLU(),
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, label_num),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, x_in):
        '''x: (N, 2*128)'''
        # print(f"Codats_Classifier x_in: {x_in.shape}")
        predictions = self.logits(x_in)
        return predictions


class Discriminator(nn.Module):
    def __init__(self, configs):
        """Init discriminator."""
        super(Discriminator, self).__init__()
        configs = configs.model_configs['Discriminator']
        self.layer = nn.Sequential(
            nn.Linear(configs.features_len * configs.final_out_channels * 2,
                      configs.disc_hid_dim * 2),  # 2 为 temporal + frequency
            nn.ReLU(),
            nn.Linear(configs.disc_hid_dim * 2, configs.disc_hid_dim),
            nn.ReLU(),
            nn.Linear(configs.disc_hid_dim, 2),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, input):
        """Forward the discriminator.
            x: (N, k*len)"""
        out = self.layer(input)
        return out


class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None






class ATFA(nn.Module):
    """
    CoDATS: https://arxiv.org/pdf/2005.10996.pdf
    """

    def __init__(self, configs, hparams, device):
        super(ATFA, self).__init__()
        self.configs = configs
        self.cross_entropy = nn.CrossEntropyLoss()
        self.channel_num = configs.channel_nums  # 不同模态的通道数e.g. [3,3,1]
        self.feature_extractor = FE(configs=configs)
        self.classifier_tf = Classifier(configs)
        self.domain_discriminator_tf = Discriminator(configs)

        self.optimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters()) +
            list(self.classifier_tf.parameters()),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"], betas=(0.5, 0.99)
        )
        self.optimizer_disc = torch.optim.Adam(
            self.domain_discriminator_tf.parameters(),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"], betas=(0.5, 0.99)
        )
        self.hparams = hparams
        self.device = device

    def get_f_x(self, x):
        x_f = torch.fft.rfft(x)
        return x_f.real


    def update(self, src_x, src_y, trg_x, step, epoch, len_dataloader):
        '''x: (N, 6, 128)
            y: (N,)'''
        p = float(step + epoch * len_dataloader) / \
            self.hparams["num_epochs"] + 1 / len_dataloader
        alpha = 2. / (1. + np.exp(-10 * p)) - 1
    
        # split_modality_x ------------------------------------------------------
        modality_src_x_t = torch.split(src_x, split_size_or_sections=self.channel_num, dim=1)
        modality_trg_x_t = torch.split(trg_x, split_size_or_sections=self.channel_num, dim=1)
    
        # get t/f x -------------------------------------------------------------
        modality_src_x_f = [self.get_f_x(src_x_t) for src_x_t in modality_src_x_t]
        modality_trg_x_f = [self.get_f_x(trg_x_t) for trg_x_t in modality_trg_x_t]
    
        # zero grad -------------------------------------------------------------
        self.optimizer.zero_grad()
        self.optimizer_disc.zero_grad()
    
        # domain label ----------------------------------------------------------
        domain_label_src = torch.ones(len(src_x)).to(self.device)
        domain_label_trg = torch.zeros(len(trg_x)).to(self.device)
    
        # ========== 核心修改点 ==========
        # src features ----------------------------------------------------------
        src_feat_tf, src_t, src_f = self.feature_extractor(modality_src_x_t, modality_src_x_f)  # 修改点1：获取时频单独特征
        src_pred = self.classifier_tf(src_feat_tf)
    
        # trg features ----------------------------------------------------------
        trg_feat_tf, trg_t, trg_f = self.feature_extractor(modality_trg_x_t, modality_trg_x_f)
    
        # ========== 新增对比损失计算 ==========
        src_proj_t = self.feature_extractor.contrastive_proj(src_t)  # 修改点2：投影时域特征
        src_proj_f = self.feature_extractor.contrastive_proj(src_f)  # 修改点3：投影频域特征
        contrastive_loss = self.nt_xent_loss(src_proj_t, src_proj_f)  # 修改点4：计算对比损失
    
        # Task classification Loss ----------------------------------------------
        src_cls_loss = self.cross_entropy(src_pred.squeeze(), src_y)
    
        # Domain classification loss --------------------------------------------
        # source
        src_feat_reversed = ReverseLayerF.apply(src_feat_tf, alpha)
        src_domain_pred = self.domain_discriminator_tf(src_feat_reversed)
        src_domain_loss = self.cross_entropy(src_domain_pred, domain_label_src.long())
    
        # target
        trg_feat_reversed = ReverseLayerF.apply(trg_feat_tf, alpha)
        trg_domain_pred = self.domain_discriminator_tf(trg_feat_reversed)
        trg_domain_loss = self.cross_entropy(trg_domain_pred, domain_label_trg.long())
    
        # Total domain loss -----------------------------------------------------
        domain_loss = src_domain_loss + trg_domain_loss
    
        # ========== 修改总损失公式 ==========
        loss = (
            self.hparams["src_cls_loss_wt"] * src_cls_loss 
            + self.hparams["domain_loss_wt"] * domain_loss
            + self.hparams["contrastive_loss_wt"] * contrastive_loss  # 修改点5：添加加权对比损失
        )
    
        # Backpropagation -------------------------------------------------------
        loss.backward()
        self.optimizer.step()
        self.optimizer_disc.step()
    
        return {'Total_loss': loss.item(), 'Domain_loss': domain_loss.item(), 'Src_cls_loss': src_cls_loss.item()}

    def test_batch(self, trg_x):
        modality_trg_x_t = torch.split(
            trg_x, split_size_or_sections=self.channel_num, dim=1)
        modality_trg_x_f = [self.get_f_x(trg_x_t)
                            for trg_x_t in modality_trg_x_t]

        # 关键修复：解包返回值的三个元素，只取第一个
        trg_feat_tf, _, _ = self.feature_extractor(modality_trg_x_t, modality_trg_x_f)

        trg_pred = self.classifier_tf(trg_feat_tf)
        return trg_pred
    # ========== 新增对比损失函数 ==========
    def nt_xent_loss(self, z1, z2, temperature=0.07):
        batch_size = z1.size(0)
        z = torch.cat([z1, z2], dim=0)
        z = F.normalize(z, dim=1)
        sim_matrix = torch.mm(z, z.T) / temperature
        mask = torch.eye(2*batch_size, device=z.device).bool()
        sim_matrix = sim_matrix.masked_fill(mask, -1e12)
        labels = torch.arange(2*batch_size, device=z.device)
        labels = (labels + batch_size) % (2*batch_size)
        return F.cross_entropy(sim_matrix, labels)

Data & Utils

In [2]:
import wandb
from kaggle_secrets import UserSecretsClient
from sklearn.metrics import classification_report, accuracy_score
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

from sklearn.model_selection import train_test_split

import os
import numpy as np
import random


def split_valid(dataset):
    data = dataset["samples"]
    label = dataset["labels"]
    train_data, valid_data, train_labels, valid_labels = train_test_split(
        data, label, test_size=0.25, shuffle=True, random_state=42)
    train_dataset = {
        "samples": train_data,
        "labels": train_labels,
    }
    valid_dataset = {
        "samples": valid_data,
        "labels": valid_labels,
    }
    return train_dataset, valid_dataset


class Load_Dataset(Dataset):
    def __init__(self, dataset, normalize):
        super(Load_Dataset, self).__init__()

        data = dataset["samples"]
        label = dataset["labels"]

        if len(data.shape) < 3:
            data = data.unsqueeze(2)

        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
            label = torch.from_numpy(label).long()

        # make sure the Channels in second dim
        if data.shape.index(min(data.shape[1], data.shape[2])) != 1:
            # 数最小（channel）的dim不是1，要将channel转换到dim1来
            data = data.permute(0, 2, 1)  # (N, 128, 6)=>(N, 6, 128)

        self.data = data
        self.label = label

        self.num_channels = data.shape[1]

        if normalize:
            # Assume datashape: num_samples, num_channels, seq_length
            data_mean = torch.FloatTensor(self.num_channels).fill_(
                0).tolist()  # assume min= number of channels
            data_std = torch.FloatTensor(self.num_channels).fill_(
                1).tolist()  # assume min= number of channels
            data_transform = transforms.Normalize(mean=data_mean, std=data_std)
            self.transform = data_transform
        else:
            self.transform = None

        self.len = data.shape[0]

    def __getitem__(self, index):
        if self.transform is not None:
            output = self.transform(
                self.data[index].view(self.num_channels, -1, 1))
            self.data[index] = output.view(self.data[index].shape)

        return self.data[index].float(), self.label[index].long()

    def __len__(self):
        return self.len


def data_generator(data_path, domain_id, dataset_configs, hparams):
    """
        Args:
            data_path      : 数据文件夹位置
            domain_id      : 用户的id
            dataset_configs: 数据集配置
            hparams        : 超参数配置
    """
    # loading path
    train_dataset = torch.load(os.path.join(
        data_path, "train_" + str(domain_id) + ".pt"))
    test_dataset = torch.load(os.path.join(
        data_path, "test_" + str(domain_id) + ".pt"))
    train_dataset, valid_dataset = split_valid(train_dataset)

    # Loading datasets
    train_dataset = Load_Dataset(train_dataset, dataset_configs.normalize)
    valid_dataset = Load_Dataset(valid_dataset, dataset_configs.normalize)
    test_dataset = Load_Dataset(test_dataset, dataset_configs.normalize)

    # Dataloaders
    batch_size = hparams["batch_size"]
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=True, num_workers=2)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=True, num_workers=2)

    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             shuffle=False, drop_last=dataset_configs.drop_last, num_workers=2)
    return train_loader, valid_loader, test_loader


def fix_randomness(SEED):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


"""
  一个epoch内，记录不同batch的测试指标，并返回平均值
"""


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """n是batch增加的个数"""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def _calc_metrics(pred_labels, true_labels, target_names):
    """
        在log_dir下保存.csv，记录衡量指标，返回百分制的acc和maf1
    """
    # print(pred_labels.shape)
    pred_labels = np.array(pred_labels).astype(int)
    true_labels = np.array(true_labels).astype(int)

    # precisiion, recall, f1, support
    r = classification_report(
        true_labels, pred_labels, labels=range(len(target_names)), target_names=target_names, digits=6, output_dict=True)
    df = pd.DataFrame(r)

    # acc
    accuracy = accuracy_score(true_labels, pred_labels)
    df["accuracy"] = accuracy  # 一列都是重复的值，没有办法

    # 转换为百分制
    df = df * 100

    # 保存结果
#     file_name = "classification_report.csv"
#     report_Save_path = os.path.join(log_dir, file_name)
#     df.to_csv(report_Save_path)
#     df.to_excel(report_Save_path+".xlsx")

    return accuracy * 100, r["macro avg"]["f1-score"] * 100


class EarlyStopping:
    def __init__(self, patience=7, verbose=True, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model=None, path=None):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            print(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).\n')
            self.val_loss_min = val_loss
        elif score <= self.best_score + self.delta:
            self.counter += 1
            print(
                f'EarlyStopping counter: {self.counter} out of {self.patience}\n')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0

    def refresh(self):
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf


user_secrets = UserSecretsClient()

# I have saved my API token with "wandb_api" as Label.
# If you use some other Label make sure to change the same below.
wandb_api = user_secrets.get_secret(
    "wandb_api")  # Add-ons => Secrets 选中 keys 才能用

wandb.login(key=wandb_api)


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

Trainer

In [3]:
import collections
from tqdm import tqdm
import torch
import os
import numpy as np
import torch.nn.functional as F
import pandas as pd
import wandb


class trainer():
    def __init__(self, args) -> None:
        self.method_cls = args.method_cls
        self.data_path = args.data_path
        self.ds_config = args.ds_configs
        self.hparams = args.hparams
        self.method_name = args.method_name
        self.dataset_name = args.dataset_name
        self.default_hparams = {
            **self.hparams.alg_hparams[self.method_name],
            **self.hparams.train_params,
        }
        self.es = EarlyStopping()
        self.num_runs = args.num_runs  # 每个任务重复次数（不同seed）
        self.device = torch.device(args.device)
        # 记录结果

    def get_dataloader(self, uid):
        """
            trian_dl, valid_dl, test_dl
        """
        return data_generator(self.data_path, uid, self.ds_config, self.default_hparams)

    def evaluate(self, valid=False):
        """
            验证模型性能
            
            return:
                metric: {'accuracy': str, 'maf1_score': str}
                loss: str
        """
        method = self.method.to(self.device).eval()

        total_loss_ = []

        self.trg_pred_labels = []
        self.trg_true_labels = []

        if valid:
            dataloader = self.trg_valid_dl
        else:
            dataloader = self.trg_test_dl

        with torch.no_grad():
            for data, labels in dataloader:
                data = data.float().to(self.device)
                labels = labels.view((-1)).long().to(self.device)

                # forward pass
                predictions = method.test_batch(trg_x=data)

                # compute loss
                loss = F.cross_entropy(predictions, labels)
                total_loss_.append(loss.item())
                # get the index of the max log-probability
                pred = predictions.detach().argmax(dim=1)

                self.trg_pred_labels.append(pred.cpu())
                self.trg_true_labels.append(labels.cpu())

        self.trg_loss = torch.tensor(total_loss_).mean()  # average loss

        # n*[BN,]->[n*BN]
        self.trg_pred_labels = torch.concat(self.trg_pred_labels, dim=0)
        self.trg_true_labels = torch.concat(self.trg_true_labels, dim=0)

        # 计算结果
        self.acc, self.maf1 = _calc_metrics(self.trg_pred_labels, self.trg_true_labels,
                                            self.ds_config.class_names)

        self.run_metrics = {'accuracy': self.acc, 'maf1_score': self.maf1}

        return self.run_metrics, self.trg_loss

    def save_results(self):
        """
            将运行结果汇总生成csv，保存在实验目录下（scenario的上一层）
                            acc,    maf1
            scenario_run
        """
        self.all_scenario_results = {
            "acc": {},
            "maf1": {}
        }
        for run_name in self.metrics.keys():
            self.all_scenario_results["acc"][run_name] = self.metrics[run_name]['accuracy']
            self.all_scenario_results["maf1"][run_name] = self.metrics[run_name]['maf1_score']

        # all_run_accs = np.array(self.metrics['accuracy'])
        # all_run_maf1s = np.array(self.metrics['maf1_score'])

        # acc_mean = all_run_accs.mean()
        # acc_std = all_run_accs.std()

        # maf1_mean = all_run_maf1s.mean()
        # maf1_std = all_run_maf1s.std()

        # self.all_scenario_results["acc_mean"][self.tmp_scenario] = acc_mean
        # self.all_scenario_results["acc_std"][self.tmp_scenario] = acc_std
        # self.all_scenario_results["maf1_mean"][self.tmp_scenario] = maf1_mean
        # self.all_scenario_results["maf1_std"][self.tmp_scenario] = maf1_std

        # 所有scenario_run结果，保存在Exp_logs/Exp_name/exp_id下
        self.rs_df = pd.DataFrame(self.all_scenario_results)
        file_save_pth = os.path.join(
            "/kaggle/working/", f"{self.dataset_name}_{self.method_name}_all_scenario_run_results.csv")
        self.rs_df.to_csv(file_save_pth)
        return

    def train(self):
        scenarios = self.ds_config.scenarios
        self.metrics = {}
        for i in scenarios:
            self.tmp_scenario = f"{i[0]}_to_{i[1]}"
            src_id = i[0]
            trg_id = i[1]

            # 获取dataloader
            print(f'-----获取dataloader-----')
            self.src_train_dl, self.src_valid_dl, self.src_test_dl = self.get_dataloader(
                src_id)
            self.trg_train_dl, self.trg_valid_dl, self.trg_test_dl = self.get_dataloader(
                trg_id)

            log_flag = True

            for run_id in range(self.num_runs):
                seed = run_id
                if log_flag == True:
                    # 每个scenario只记录一次run（看收敛）
                    self.run = wandb.init(project='UDA',
                                          #                                  group='AdvSKM',
                                          tags=[
                                              f'seed_{seed}', self.tmp_scenario, self.dataset_name, self.method_name],
                                          #                                  job_type='train',
                                          reinit=True,
                                          )

                # 固定随机种子
                fix_randomness(seed)

                # 类实例化
                method = self.method_cls(
                    configs=self.ds_config, hparams=self.default_hparams, device=self.device).to(self.device)

                # AvgMeters
                avg_meters = collections.defaultdict(lambda: AverageMeter())

                # 训练
                for epoch in range(1, self.default_hparams["num_epochs"]+1):

                    joint_loaders = enumerate(
                        zip(self.src_train_dl, self.trg_train_dl))
                    len_dataloader = min(
                        len(self.src_train_dl), len(self.trg_train_dl))

                    method.train()
                    for step, ((src_x, src_y), (trg_x, _)) in tqdm(joint_loaders):
                        src_x, src_y, trg_x = src_x.float().to(self.device), src_y.long().to(
                            self.device), trg_x.float().to(self.device)

                        if self.method_name == "CoDATS" or self.method_name == 'CoDATS_tf' or self.method_name == 'ATFA':
                            loss_dict = method.update(
                                src_x=src_x, src_y=src_y, trg_x=trg_x, step=step, epoch=epoch, len_dataloader=len_dataloader)
                        else:
                            loss_dict = method.update(src_x, src_y, trg_x)

                        for key, val in loss_dict.items():
                            # batch loss，n=1
                            avg_meters[f"Loss/{key}"].update(val, n=1)

                    # 每个epoch输出结果
                    print(
                        f'------------ train epoch: {epoch} ------------------')
                    log_dict = {}
                    for key, val in avg_meters.items():
                        if "Loss/" in key:
                            print(f'{key}\t: {val.avg:2.4f}')
                            log_dict[key] = val.avg
                            if log_flag:
                                self.run.log(
                                    {f'train/{key}': val.avg}, step=epoch)
                    print(f'---------------------------------------------------')

                    # Valid
                    if epoch % self.default_hparams["valid_interval"] == 0:
                        print(
                            f'------------- valid epoch: {epoch} ------------------')
                        self.method = method
                        _, valid_loss = self.evaluate(
                            valid=True)  # run_metrics

                        self.es(val_loss=valid_loss)

                        if self.es.early_stop:
                            print("Early stopping")
                            break

                self.es.refresh()
                # Train Done, Test
                self.method = method
                metric, _ = self.evaluate()
                self.metrics[f"{self.tmp_scenario}_run_{run_id}"] = metric

            # scenario_run Done
        # scenario Done
        self.save_results()
    # end train
# end class


Configs & Main

In [None]:

import os
class CNN_configs(object):
    def __init__(self) -> None:
        super(CNN_configs, self).__init__()
        self.input_channels = 3  # 模态通道总数
        self.kernel_size = 5  # 第一层kernel_size
        self.stride = 1  # 第一层stride
        self.dropout = 0.5  # 第一层dropout

        self.mid_channels = 64  # 中间通道数
        self.final_out_channels = 128  # 最后一层输出通道数
        self.features_len = 1  # avgpool 平均结果长度


class Classifier_configs(object):
    def __init__(self) -> None:
        super(Classifier_configs, self).__init__()
        self.features_len = 1  # CNN avgpool 平均结果长度
        self.hidden_dim = 500
        self.final_out_channels = 128  # CNN 最后一层输出通道数


class Discriminator_configs(object):
    def __init__(self) -> None:
        super(Discriminator_configs, self).__init__()
        self.final_out_channels = 128  # CNN最后一层输出通道数
        self.features_len = 1  # CNN avgpool 平均结果长度
        self.disc_hid_dim = 64


class fusion_configs(object):
    def __init__(self) -> None:
        super(fusion_configs, self).__init__()
        self.final_out_channels = 128  # cnn_rs.shape[-1]
        self.hidden_size = 500


class PAMAP2_configs(object):  # HHAR dataset, SAMSUNG device.
    def __init__(self):
        super(PAMAP2_configs, self).__init__()
        self.sequence_len = 200
        self.num_users = 9  # 1~9

        self.scenarios = [("1", "7"), ("2", "5"), ("5", "7"),
                          ("6", "5"), ("7", "2")]  # select 5
        self.num_classes = 18  # 0~17
        self.class_names = ['lying' 'sitting', 'standing', 'walking', 'running',
                            'cycling', 'Nordic walking', 'watching TV', 'computer work', 'car driving',
                            'ascending stairs', 'descending stairs', 'vacuum cleaning', 'ironing', 'folding laundry',
                            'house cleaning', 'playing soccer', 'rope jumping']

        self.shuffle = True
        self.drop_last = True
        self.normalize = True
        self.modality_nums = 3 * 3  # 传感器模态数
        self.channel_nums = [3, 3, 3] + [3, 3, 3] + [3, 3, 3]  # 各个模态通道数

        self.model_configs = {
            'CNN': CNN_configs(),
            'fusion': fusion_configs(),
            'Classifier': Classifier_configs(),
            'Discriminator': Discriminator_configs(),
        }


# -------------------- 训练配置 ---------------------------
class PAMAP2_hparams():
    def __init__(self):
        super(PAMAP2_hparams, self).__init__()
        self.train_params = {
            'num_epochs': 100,  # 总训练轮数
            'batch_size': 128,  # 每个域的batch数
            'weight_decay': 1e-4,
            'valid_interval': 2,
        }
        self.alg_hparams = {
            'ATFA':       {'learning_rate': 0.0005,   'src_cls_loss_wt': 7.737,  'domain_loss_wt': 3.369,  'contrastive_loss_wt': 0.5,  },#'contrastive_loss_wt': 0.5,  # 新增参数
        }


class args():
    def __init__(self) -> None:
        self.method_cls = ATFA
        self.hparams = PAMAP2_hparams()
        self.ds_configs = PAMAP2_configs()
        self.data_path = "/kaggle/input/pamap2accgyromag/PAMAP2_data"
        self.num_runs = 5
        self.device = "cuda:0"
        self.method_name = "ATFA"
        self.dataset_name = "PAMAP2"


args = args()
trainer = trainer(args)
trainer.train()


-----获取dataloader-----


[34m[1mwandb[0m: Currently logged in as: [33mgdrcosg[0m ([33mgdrcosg-guet[0m). Use [1m`wandb login --relogin`[0m to force relogin


9it [00:03,  2.32it/s]

------------ train epoch: 1 ------------------
Loss/Total_loss	: 25.7311
Loss/Domain_loss	: 1.3859
Loss/Src_cls_loss	: 2.4598
---------------------------------------------------



9it [00:01,  5.11it/s]

------------ train epoch: 2 ------------------
Loss/Total_loss	: 21.6077
Loss/Domain_loss	: 1.3808
Loss/Src_cls_loss	: 1.9741
---------------------------------------------------
------------- valid epoch: 2 ------------------





Validation loss decreased (inf --> 1.572308).



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.09it/s]

------------ train epoch: 3 ------------------
Loss/Total_loss	: 18.3663
Loss/Domain_loss	: 1.3712
Loss/Src_cls_loss	: 1.5839
---------------------------------------------------



9it [00:01,  4.98it/s]

------------ train epoch: 4 ------------------
Loss/Total_loss	: 16.0357
Loss/Domain_loss	: 1.3600
Loss/Src_cls_loss	: 1.3032
---------------------------------------------------
------------- valid epoch: 4 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.88it/s]

------------ train epoch: 5 ------------------
Loss/Total_loss	: 14.3789
Loss/Domain_loss	: 1.3547
Loss/Src_cls_loss	: 1.1035
---------------------------------------------------



9it [00:01,  4.98it/s]

------------ train epoch: 6 ------------------
Loss/Total_loss	: 13.1583
Loss/Domain_loss	: 1.3534
Loss/Src_cls_loss	: 0.9542
---------------------------------------------------
------------- valid epoch: 6 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.08it/s]

------------ train epoch: 7 ------------------
Loss/Total_loss	: 12.2450
Loss/Domain_loss	: 1.3537
Loss/Src_cls_loss	: 0.8437
---------------------------------------------------



9it [00:01,  5.08it/s]

------------ train epoch: 8 ------------------
Loss/Total_loss	: 11.4906
Loss/Domain_loss	: 1.3538
Loss/Src_cls_loss	: 0.7523
---------------------------------------------------
------------- valid epoch: 8 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 1 out of 7



9it [00:01,  5.01it/s]

------------ train epoch: 9 ------------------
Loss/Total_loss	: 10.8970
Loss/Domain_loss	: 1.3542
Loss/Src_cls_loss	: 0.6813
---------------------------------------------------



9it [00:01,  4.83it/s]

------------ train epoch: 10 ------------------
Loss/Total_loss	: 10.3952
Loss/Domain_loss	: 1.3540
Loss/Src_cls_loss	: 0.6204
---------------------------------------------------
------------- valid epoch: 10 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.07it/s]

------------ train epoch: 11 ------------------
Loss/Total_loss	: 9.9852
Loss/Domain_loss	: 1.3538
Loss/Src_cls_loss	: 0.5719
---------------------------------------------------



9it [00:01,  5.04it/s]

------------ train epoch: 12 ------------------
Loss/Total_loss	: 9.6302
Loss/Domain_loss	: 1.3545
Loss/Src_cls_loss	: 0.5291
---------------------------------------------------
------------- valid epoch: 12 ------------------





EarlyStopping counter: 1 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.09it/s]

------------ train epoch: 13 ------------------
Loss/Total_loss	: 9.3310
Loss/Domain_loss	: 1.3558
Loss/Src_cls_loss	: 0.4934
---------------------------------------------------



9it [00:01,  5.02it/s]

------------ train epoch: 14 ------------------
Loss/Total_loss	: 9.0636
Loss/Domain_loss	: 1.3567
Loss/Src_cls_loss	: 0.4615
---------------------------------------------------
------------- valid epoch: 14 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 2 out of 7



9it [00:01,  4.89it/s]

------------ train epoch: 15 ------------------
Loss/Total_loss	: 8.8230
Loss/Domain_loss	: 1.3573
Loss/Src_cls_loss	: 0.4331
---------------------------------------------------



9it [00:01,  5.02it/s]

------------ train epoch: 16 ------------------
Loss/Total_loss	: 8.6092
Loss/Domain_loss	: 1.3581
Loss/Src_cls_loss	: 0.4075
---------------------------------------------------
------------- valid epoch: 16 ------------------





EarlyStopping counter: 3 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.02it/s]

------------ train epoch: 17 ------------------
Loss/Total_loss	: 8.4258
Loss/Domain_loss	: 1.3591
Loss/Src_cls_loss	: 0.3857
---------------------------------------------------



9it [00:01,  5.01it/s]

------------ train epoch: 18 ------------------
Loss/Total_loss	: 8.2558
Loss/Domain_loss	: 1.3600
Loss/Src_cls_loss	: 0.3658
---------------------------------------------------
------------- valid epoch: 18 ------------------





EarlyStopping counter: 4 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.90it/s]

------------ train epoch: 19 ------------------
Loss/Total_loss	: 8.1032
Loss/Domain_loss	: 1.3607
Loss/Src_cls_loss	: 0.3478
---------------------------------------------------



9it [00:01,  4.89it/s]

------------ train epoch: 20 ------------------
Loss/Total_loss	: 7.9652
Loss/Domain_loss	: 1.3613
Loss/Src_cls_loss	: 0.3316
---------------------------------------------------
------------- valid epoch: 20 ------------------





EarlyStopping counter: 5 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.98it/s]

------------ train epoch: 21 ------------------
Loss/Total_loss	: 7.8360
Loss/Domain_loss	: 1.3619
Loss/Src_cls_loss	: 0.3166
---------------------------------------------------



9it [00:01,  5.01it/s]

------------ train epoch: 22 ------------------
Loss/Total_loss	: 7.7200
Loss/Domain_loss	: 1.3625
Loss/Src_cls_loss	: 0.3031
---------------------------------------------------
------------- valid epoch: 22 ------------------





EarlyStopping counter: 6 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.04it/s]

------------ train epoch: 23 ------------------
Loss/Total_loss	: 7.6206
Loss/Domain_loss	: 1.3630
Loss/Src_cls_loss	: 0.2914
---------------------------------------------------



9it [00:01,  4.90it/s]

------------ train epoch: 24 ------------------
Loss/Total_loss	: 7.5250
Loss/Domain_loss	: 1.3635
Loss/Src_cls_loss	: 0.2801
---------------------------------------------------
------------- valid epoch: 24 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 7 out of 7

Early stopping


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


0,1
train/Loss/Domain_loss,█▇▅▂▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃
train/Loss/Src_cls_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train/Loss/Total_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
train/Loss/Domain_loss,1.36355
train/Loss/Src_cls_loss,0.28013
train/Loss/Total_loss,7.52504


9it [00:01,  4.89it/s]

------------ train epoch: 1 ------------------
Loss/Total_loss	: 25.9684
Loss/Domain_loss	: 1.3851
Loss/Src_cls_loss	: 2.4850
---------------------------------------------------



9it [00:01,  4.99it/s]

------------ train epoch: 2 ------------------
Loss/Total_loss	: 21.6467
Loss/Domain_loss	: 1.3804
Loss/Src_cls_loss	: 1.9758
---------------------------------------------------
------------- valid epoch: 2 ------------------





Validation loss decreased (inf --> 1.589469).



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.98it/s]

------------ train epoch: 3 ------------------
Loss/Total_loss	: 18.4013
Loss/Domain_loss	: 1.3720
Loss/Src_cls_loss	: 1.5859
---------------------------------------------------



9it [00:01,  4.93it/s]

------------ train epoch: 4 ------------------
Loss/Total_loss	: 16.0349
Loss/Domain_loss	: 1.3599
Loss/Src_cls_loss	: 1.3013
---------------------------------------------------
------------- valid epoch: 4 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.06it/s]

------------ train epoch: 5 ------------------
Loss/Total_loss	: 14.3474
Loss/Domain_loss	: 1.3523
Loss/Src_cls_loss	: 1.0986
---------------------------------------------------



9it [00:01,  4.99it/s]

------------ train epoch: 6 ------------------
Loss/Total_loss	: 13.1451
Loss/Domain_loss	: 1.3517
Loss/Src_cls_loss	: 0.9528
---------------------------------------------------
------------- valid epoch: 6 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.94it/s]

------------ train epoch: 7 ------------------
Loss/Total_loss	: 12.2033
Loss/Domain_loss	: 1.3527
Loss/Src_cls_loss	: 0.8381
---------------------------------------------------



9it [00:01,  5.04it/s]

------------ train epoch: 8 ------------------
Loss/Total_loss	: 11.4662
Loss/Domain_loss	: 1.3532
Loss/Src_cls_loss	: 0.7486
---------------------------------------------------
------------- valid epoch: 8 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.90it/s]

------------ train epoch: 9 ------------------
Loss/Total_loss	: 10.8639
Loss/Domain_loss	: 1.3539
Loss/Src_cls_loss	: 0.6765
---------------------------------------------------



9it [00:01,  4.95it/s]

------------ train epoch: 10 ------------------
Loss/Total_loss	: 10.3646
Loss/Domain_loss	: 1.3540
Loss/Src_cls_loss	: 0.6169
---------------------------------------------------
------------- valid epoch: 10 ------------------





EarlyStopping counter: 1 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.99it/s]

------------ train epoch: 11 ------------------
Loss/Total_loss	: 9.9521
Loss/Domain_loss	: 1.3548
Loss/Src_cls_loss	: 0.5671
---------------------------------------------------



9it [00:01,  4.83it/s]

------------ train epoch: 12 ------------------
Loss/Total_loss	: 9.6019
Loss/Domain_loss	: 1.3555
Loss/Src_cls_loss	: 0.5251
---------------------------------------------------
------------- valid epoch: 12 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.98it/s]

------------ train epoch: 13 ------------------
Loss/Total_loss	: 9.2976
Loss/Domain_loss	: 1.3562
Loss/Src_cls_loss	: 0.4890
---------------------------------------------------



9it [00:01,  4.96it/s]

------------ train epoch: 14 ------------------
Loss/Total_loss	: 9.0292
Loss/Domain_loss	: 1.3568
Loss/Src_cls_loss	: 0.4567
---------------------------------------------------
------------- valid epoch: 14 ------------------





EarlyStopping counter: 1 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.96it/s]

------------ train epoch: 15 ------------------
Loss/Total_loss	: 8.7978
Loss/Domain_loss	: 1.3580
Loss/Src_cls_loss	: 0.4293
---------------------------------------------------



9it [00:01,  4.98it/s]

------------ train epoch: 16 ------------------
Loss/Total_loss	: 8.5977
Loss/Domain_loss	: 1.3578
Loss/Src_cls_loss	: 0.4059
---------------------------------------------------
------------- valid epoch: 16 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.92it/s]

------------ train epoch: 17 ------------------
Loss/Total_loss	: 8.4118
Loss/Domain_loss	: 1.3580
Loss/Src_cls_loss	: 0.3842
---------------------------------------------------



9it [00:02,  4.46it/s]

------------ train epoch: 18 ------------------
Loss/Total_loss	: 8.2418
Loss/Domain_loss	: 1.3589
Loss/Src_cls_loss	: 0.3642
---------------------------------------------------
------------- valid epoch: 18 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 1 out of 7



9it [00:01,  5.21it/s]

------------ train epoch: 19 ------------------
Loss/Total_loss	: 8.0892
Loss/Domain_loss	: 1.3597
Loss/Src_cls_loss	: 0.3464
---------------------------------------------------



9it [00:01,  5.11it/s]

------------ train epoch: 20 ------------------
Loss/Total_loss	: 7.9513
Loss/Domain_loss	: 1.3602
Loss/Src_cls_loss	: 0.3302
---------------------------------------------------
------------- valid epoch: 20 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 2 out of 7



9it [00:01,  5.02it/s]

------------ train epoch: 21 ------------------
Loss/Total_loss	: 7.8256
Loss/Domain_loss	: 1.3608
Loss/Src_cls_loss	: 0.3155
---------------------------------------------------



9it [00:01,  4.97it/s]

------------ train epoch: 22 ------------------
Loss/Total_loss	: 7.7088
Loss/Domain_loss	: 1.3617
Loss/Src_cls_loss	: 0.3019
---------------------------------------------------
------------- valid epoch: 22 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 3 out of 7



9it [00:01,  5.02it/s]

------------ train epoch: 23 ------------------
Loss/Total_loss	: 7.6008
Loss/Domain_loss	: 1.3623
Loss/Src_cls_loss	: 0.2894
---------------------------------------------------



9it [00:01,  4.83it/s]

------------ train epoch: 24 ------------------
Loss/Total_loss	: 7.5074
Loss/Domain_loss	: 1.3630
Loss/Src_cls_loss	: 0.2786
---------------------------------------------------
------------- valid epoch: 24 ------------------





EarlyStopping counter: 4 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.96it/s]

------------ train epoch: 25 ------------------
Loss/Total_loss	: 7.4206
Loss/Domain_loss	: 1.3635
Loss/Src_cls_loss	: 0.2685
---------------------------------------------------



9it [00:01,  4.84it/s]

------------ train epoch: 26 ------------------
Loss/Total_loss	: 7.3351
Loss/Domain_loss	: 1.3639
Loss/Src_cls_loss	: 0.2586
---------------------------------------------------
------------- valid epoch: 26 ------------------





EarlyStopping counter: 5 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.96it/s]

------------ train epoch: 27 ------------------
Loss/Total_loss	: 7.2551
Loss/Domain_loss	: 1.3646
Loss/Src_cls_loss	: 0.2493
---------------------------------------------------



9it [00:01,  4.99it/s]

------------ train epoch: 28 ------------------
Loss/Total_loss	: 7.1819
Loss/Domain_loss	: 1.3650
Loss/Src_cls_loss	: 0.2408
---------------------------------------------------
------------- valid epoch: 28 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 6 out of 7



9it [00:01,  4.98it/s]

------------ train epoch: 29 ------------------
Loss/Total_loss	: 7.1147
Loss/Domain_loss	: 1.3655
Loss/Src_cls_loss	: 0.2330
---------------------------------------------------



9it [00:01,  5.11it/s]

------------ train epoch: 30 ------------------
Loss/Total_loss	: 7.0519
Loss/Domain_loss	: 1.3659
Loss/Src_cls_loss	: 0.2257
---------------------------------------------------
------------- valid epoch: 30 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 7 out of 7

Early stopping


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


0,1
train/Loss/Domain_loss,█▇▅▃▁▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄
train/Loss/Src_cls_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/Loss/Total_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train/Loss/Domain_loss,1.36594
train/Loss/Src_cls_loss,0.22566
train/Loss/Total_loss,7.05194


9it [00:01,  5.10it/s]

------------ train epoch: 1 ------------------
Loss/Total_loss	: 26.1002
Loss/Domain_loss	: 1.3862
Loss/Src_cls_loss	: 2.4983
---------------------------------------------------



9it [00:01,  5.02it/s]

------------ train epoch: 2 ------------------
Loss/Total_loss	: 21.9101
Loss/Domain_loss	: 1.3822
Loss/Src_cls_loss	: 2.0069
---------------------------------------------------
------------- valid epoch: 2 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation loss decreased (inf --> 1.573044).



9it [00:01,  5.12it/s]

------------ train epoch: 3 ------------------
Loss/Total_loss	: 18.6166
Loss/Domain_loss	: 1.3716
Loss/Src_cls_loss	: 1.6121
---------------------------------------------------



9it [00:01,  5.01it/s]

------------ train epoch: 4 ------------------
Loss/Total_loss	: 16.2666
Loss/Domain_loss	: 1.3607
Loss/Src_cls_loss	: 1.3315
---------------------------------------------------
------------- valid epoch: 4 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.08it/s]

------------ train epoch: 5 ------------------
Loss/Total_loss	: 14.5472
Loss/Domain_loss	: 1.3573
Loss/Src_cls_loss	: 1.1233
---------------------------------------------------



9it [00:01,  5.07it/s]

------------ train epoch: 6 ------------------
Loss/Total_loss	: 13.2923
Loss/Domain_loss	: 1.3574
Loss/Src_cls_loss	: 0.9697
---------------------------------------------------
------------- valid epoch: 6 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.18it/s]

------------ train epoch: 7 ------------------
Loss/Total_loss	: 12.3462
Loss/Domain_loss	: 1.3567
Loss/Src_cls_loss	: 0.8554
---------------------------------------------------



9it [00:01,  5.05it/s]

------------ train epoch: 8 ------------------
Loss/Total_loss	: 11.6005
Loss/Domain_loss	: 1.3576
Loss/Src_cls_loss	: 0.7642
---------------------------------------------------
------------- valid epoch: 8 ------------------





EarlyStopping counter: 1 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.72it/s]

------------ train epoch: 9 ------------------
Loss/Total_loss	: 10.9921
Loss/Domain_loss	: 1.3598
Loss/Src_cls_loss	: 0.6901
---------------------------------------------------



9it [00:01,  5.05it/s]

------------ train epoch: 10 ------------------
Loss/Total_loss	: 10.4862
Loss/Domain_loss	: 1.3607
Loss/Src_cls_loss	: 0.6289
---------------------------------------------------
------------- valid epoch: 10 ------------------





EarlyStopping counter: 2 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.03it/s]

------------ train epoch: 11 ------------------
Loss/Total_loss	: 10.0610
Loss/Domain_loss	: 1.3610
Loss/Src_cls_loss	: 0.5783
---------------------------------------------------



9it [00:01,  5.08it/s]

------------ train epoch: 12 ------------------
Loss/Total_loss	: 9.6945
Loss/Domain_loss	: 1.3616
Loss/Src_cls_loss	: 0.5345
---------------------------------------------------
------------- valid epoch: 12 ------------------





EarlyStopping counter: 3 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.12it/s]

------------ train epoch: 13 ------------------
Loss/Total_loss	: 9.3833
Loss/Domain_loss	: 1.3623
Loss/Src_cls_loss	: 0.4976
---------------------------------------------------



9it [00:01,  4.94it/s]

------------ train epoch: 14 ------------------
Loss/Total_loss	: 9.1221
Loss/Domain_loss	: 1.3617
Loss/Src_cls_loss	: 0.4672
---------------------------------------------------
------------- valid epoch: 14 ------------------





EarlyStopping counter: 4 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.13it/s]

------------ train epoch: 15 ------------------
Loss/Total_loss	: 8.8751
Loss/Domain_loss	: 1.3605
Loss/Src_cls_loss	: 0.4383
---------------------------------------------------



9it [00:01,  5.13it/s]

------------ train epoch: 16 ------------------
Loss/Total_loss	: 8.6612
Loss/Domain_loss	: 1.3604
Loss/Src_cls_loss	: 0.4134
---------------------------------------------------
------------- valid epoch: 16 ------------------





EarlyStopping counter: 5 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  5.04it/s]

------------ train epoch: 17 ------------------
Loss/Total_loss	: 8.4658
Loss/Domain_loss	: 1.3606
Loss/Src_cls_loss	: 0.3907
---------------------------------------------------



9it [00:01,  5.07it/s]

------------ train epoch: 18 ------------------
Loss/Total_loss	: 8.3021
Loss/Domain_loss	: 1.3601
Loss/Src_cls_loss	: 0.3718
---------------------------------------------------
------------- valid epoch: 18 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 6 out of 7



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
9it [00:01,  4.99it/s]

------------ train epoch: 19 ------------------
Loss/Total_loss	: 8.1428
Loss/Domain_loss	: 1.3588
Loss/Src_cls_loss	: 0.3537
---------------------------------------------------



9it [00:01,  5.03it/s]

------------ train epoch: 20 ------------------
Loss/Total_loss	: 8.0000
Loss/Domain_loss	: 1.3591
Loss/Src_cls_loss	: 0.3371
---------------------------------------------------
------------- valid epoch: 20 ------------------



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


EarlyStopping counter: 7 out of 7

Early stopping


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


0,1
train/Loss/Domain_loss,█▇▅▂▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂
train/Loss/Src_cls_loss,█▆▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
train/Loss/Total_loss,█▆▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
train/Loss/Domain_loss,1.3591
train/Loss/Src_cls_loss,0.33708
train/Loss/Total_loss,7.99998
