reference:  
> [TS-TCC: 通过Temporal和Contextual对比学习的经典之作](https://mp.weixin.qq.com/s/rsSBssZK4iWDt-aZwhAUaQ)

In [None]:
import numpy as np
import torch
from torch import nn

In [None]:
# 数据增强（TS-TCC）：
# 1、弱增强：jitter-and-scale，jitter是通过在数据点上添加小幅度的随机变化，scale是随机调整数据点的数值幅度；
# 2、强增强：permutation-and-jitter，permutation是将信号分割成随机数量的片段（最多M个），然后随机打乱这些片段的顺序，再叠加jitter。


def jitter(x, sigma=0.8):
    """Return ``x`` with i.i.d. Gaussian noise (mean 0, std ``sigma``) added element-wise.

    Reference: Um et al., https://arxiv.org/pdf/1706.00527.pdf

    Args:
        x: NumPy array of any shape.
        sigma: standard deviation of the additive noise.

    Returns:
        A new array with the same shape as ``x``.
    """
    noise = np.random.normal(0.0, sigma, x.shape)
    return x + noise

def scaling(x, sigma=1.1):
    """Randomly rescale each (sample, time step) by a Gaussian factor shared across channels.

    Reference: Um et al., https://arxiv.org/pdf/1706.00527.pdf

    Args:
        x: array of shape (N, C, L); axis 1 is treated as channels.
        sigma: standard deviation of the scaling factor, drawn from N(2, sigma).

    Returns:
        NumPy array of shape (N, C, L).
    """
    # One factor per (sample, time step). This is the same draw shape as before,
    # broadcast over every channel.
    factor = np.random.normal(loc=2., scale=sigma, size=(x.shape[0], x.shape[2]))
    # The previous per-channel loop returned from inside the loop, so only channel 0
    # came back. A single broadcast multiply scales all channels.
    return np.multiply(x, factor[:, np.newaxis, :])

def permutation(x, max_segments=5, seg_mode="random"):
    """Cut each series along the time axis into segments and shuffle their order.

    Each sample gets ``np.random.randint(1, max_segments)`` segments, i.e. between
    1 and ``max_segments - 1``. A sample with a single segment is returned unchanged.

    Args:
        x: NumPy array of shape (N, C, L); the last axis is time.
        max_segments: exclusive upper bound on the number of segments per sample.
        seg_mode: ``"random"`` picks random split points; any other value splits
            into nearly equal-length segments via ``np.array_split``.

    Returns:
        ``torch.Tensor`` of the same shape and dtype as ``x``. The same time
        permutation is applied to every channel of a sample.
    """
    n_steps = x.shape[2]
    orig_steps = np.arange(n_steps)
    num_segs = np.random.randint(1, max_segments, size=(x.shape[0]))
    ret = np.zeros_like(x)
    for i, pat in enumerate(x):
        if num_segs[i] > 1:
            if seg_mode == "random":
                split_points = np.random.choice(n_steps - 2, num_segs[i] - 1, replace=False)
                split_points.sort()
                splits = np.split(orig_steps, split_points)
            else:
                splits = np.array_split(orig_steps, num_segs[i])
            # Shuffle segment *indices*. np.random.permutation on a ragged list of
            # arrays fails on recent NumPy.
            order = np.random.permutation(len(splits))
            warp = np.concatenate([splits[j] for j in order])
            # Reorder time for every channel. The old `pat[0, warp]` broadcast channel 0
            # into all channels.
            ret[i] = pat[:, warp]
        else:
            ret[i] = pat
    return torch.from_numpy(ret)

In [None]:
import torch.nn.functional as F  # needed for F.normalize below; the cell never imported it

# TS-TCC loss weights: total = lambda1 * temporal contrasting + lambda2 * contextual contrasting.
lambda1 = 1
lambda2 = 0.7
if training_mode == "self_supervised":
    # Build the contextual-contrasting (NT-Xent) criterion once, not once per batch.
    nt_xent_criterion = NTXentLoss(device, config.batch_size, config.Context_Cont.temperature, config.Context_Cont.use_cosine_similarity)

for batch_idx, (data, labels, aug1, aug2) in enumerate(train_loader):
    # Move the raw batch and its two augmented views to the device.
    data, labels = data.float().to(device), labels.long().to(device)
    aug1, aug2 = aug1.float().to(device), aug2.float().to(device)

    model_optimizer.zero_grad()
    temp_cont_optimizer.zero_grad()

    if training_mode == "self_supervised":
        # Encode both augmented views.
        predictions1, features1 = model(aug1)
        predictions2, features2 = model(aug2)

        # L2-normalise along dim 1. This is presumably the channel axis, if features are
        # (batch, channels, seq_len) as the temporal-contrasting module expects.
        features1 = F.normalize(features1, dim=1)
        features2 = F.normalize(features2, dim=1)

        # Cross-view temporal contrasting: each view's context predicts the other view's future.
        temp_cont_loss1, temp_cont_lstm_feat1 = temporal_contr_model(features1, features2)
        temp_cont_loss2, temp_cont_lstm_feat2 = temporal_contr_model(features2, features1)

        # Projected context vectors of the two views, contrasted by NT-Xent.
        zis = temp_cont_lstm_feat1
        zjs = temp_cont_lstm_feat2

        loss = (temp_cont_loss1 + temp_cont_loss2) * lambda1 + nt_xent_criterion(zis, zjs) * lambda2

        # The original loop zeroed gradients and computed the loss but never updated anything.
        loss.backward()
        model_optimizer.step()
        temp_cont_optimizer.step()
    else:
        output = model(data)
        # TODO(review): the supervised / fine-tuning loss and optimizer step are not part of this excerpt.

In [None]:
class Seq_Transformer(nn.Module):
    """Summarise a sequence of feature vectors into a single context vector.

    A learnable token (``c_token``) is prepended to the embedded sequence. Its hidden
    state at the transformer output (position 0) is returned as the context.
    """

    def __init__(self, *, patch_size, dim, depth, heads, mlp_dim, channels=1, dropout=0.1):
        super().__init__()
        patch_dim = channels * patch_size
        # Project each time step's features to the transformer width.
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.c_token = nn.Parameter(torch.randn(1, 1, dim))
        # NOTE(review): `Transformer` is defined elsewhere in the project.
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.to_c_token = nn.Identity()

    def forward(self, forward_seq):
        """Return the context vector, shape (batch, dim), for ``forward_seq`` of shape (batch, steps, patch_dim)."""
        x = self.patch_to_embedding(forward_seq)
        b = x.shape[0]
        # Broadcast the single learnable token across the batch. This replaces
        # einops.repeat, which this file never imports.
        c_tokens = self.c_token.expand(b, -1, -1)
        x = torch.cat((c_tokens, x), dim=1)
        x = self.transformer(x)
        # The hidden state of the prepended token is the sequence summary.
        c_t = self.to_c_token(x[:, 0])
        return c_t

In [None]:
class TC(nn.Module):
    """Temporal-contrasting module (TS-TCC).

    A context vector is built from one view's history up to a random time step t.
    That context predicts the *other* view's features at t+1 .. t+timestep, scored
    with an InfoNCE-style loss over the batch.
    """

    def __init__(self, configs, device):
        super(TC, self).__init__()
        self.num_channels = configs.final_out_channels
        self.timestep = configs.TC.timesteps
        # One linear predictor per future time step to be predicted.
        self.Wk = nn.ModuleList([nn.Linear(configs.TC.hidden_dim, self.num_channels) for _ in range(self.timestep)])
        # dim=1 is what the implicit default chose for 2-D input; stating it avoids the deprecation warning.
        self.lsoftmax = nn.LogSoftmax(dim=1)
        self.device = device
        # Non-linear projection head applied to the context vector c_t.
        self.projection_head = nn.Sequential(
            nn.Linear(configs.TC.hidden_dim, configs.final_out_channels // 2),
            nn.BatchNorm1d(configs.final_out_channels // 2),
            nn.ReLU(inplace=True),
            nn.Linear(configs.final_out_channels // 2, configs.final_out_channels // 4),
        )
        # Transformer that summarises the history into the context vector c_t.
        self.seq_transformer = Seq_Transformer(patch_size=self.num_channels, dim=configs.TC.hidden_dim, depth=4, heads=4, mlp_dim=64)

    def forward(self, features_aug1, features_aug2):
        """Compute the temporal-contrasting loss of view 1's context against view 2's future.

        Args:
            features_aug1: features of view 1, shape (batch_size, #channels, seq_len).
            features_aug2: features of view 2, same shape.

        Returns:
            ``(nce, projection)``: the scalar loss (averaged over time steps and batch)
            and ``projection_head(c_t)``.

        ``seq_len`` must exceed ``self.timestep``, otherwise ``torch.randint`` fails.
        """
        seq_len = features_aug1.shape[2]
        z_aug1 = features_aug1.transpose(1, 2)  # (batch, seq_len, channels)
        z_aug2 = features_aug2.transpose(1, 2)
        batch = z_aug1.shape[0]

        # Random anchor t in [0, seq_len - timestep); t + timestep is then still a valid index.
        t_samples = int(torch.randint(seq_len - self.timestep, size=(1,)))

        # Targets: view 2's features at t+1 .. t+timestep, shape (timestep, batch, channels).
        encode_samples = torch.stack(
            [z_aug2[:, t_samples + i, :] for i in range(1, self.timestep + 1)]
        ).float()

        # Context from view 1's history up to and including t.
        forward_seq = z_aug1[:, :t_samples + 1, :]
        c_t = self.seq_transformer(forward_seq)

        # Predictions for every future step, shape (timestep, batch, channels).
        pred = torch.stack([linear(c_t) for linear in self.Wk]).float()

        # InfoNCE: for each step, score every target against every prediction; the
        # matching pairs sit on the diagonal. This loop used to be nested inside the
        # prediction loop and read unfilled predictions; the division now happens once.
        nce = 0
        for i in range(self.timestep):
            total = torch.mm(encode_samples[i], torch.transpose(pred[i], 0, 1))
            nce += torch.sum(torch.diag(self.lsoftmax(total)))
        nce /= -1. * batch * self.timestep
        return nce, self.projection_head(c_t)