In [2]:
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import math

from torch.utils.data import DataLoader, Dataset

In [2]:
class global_config:
    """Plain container of hyper-parameters read by the model-building code.

    All values are fixed instance attributes assigned in ``__init__``; no
    validation is performed.
    """

    def __init__(self):
        # --- task / bookkeeping -------------------------------------------
        self.task_name = 'classification'   # example task name
        self.output_attention = False       # whether attention weights are returned
        self.epoch = 20                     # train epoch count
        self.epochs = 5                     # second epoch field; kept separate from ``epoch``

        # --- transformer geometry -----------------------------------------
        self.enc_in = 30                    # encoder input dimension (example value)
        self.d_model = 120                  # model (embedding) dimension
        self.n_heads = 4                    # number of attention heads
        self.enc_layers = 1                 # number of encoder layers
        self.d_ff = 256                     # feed-forward hidden dimension

        # --- regularisation / non-linearity -------------------------------
        self.dropout = 0.25                 # dropout rate
        self.factor = 1                     # attention scaling factor
        self.activation = 'gelu'            # activation function name

In [3]:
class PositionalEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional table.

    The table has shape ``(1, max_len, d_model)`` and is stored as the buffer
    ``positional_embedding``: even feature columns hold sines, odd columns
    hold cosines.

    Args:
        d_model: Number of feature columns. Both even and odd values work.
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(self, d_model, max_len=10000):
        super().__init__()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        # torch.zeros already yields a tensor that does not require grad.
        table = torch.zeros(max_len, d_model).float()
        table[:, 0::2] = torch.sin(angles)
        # For odd d_model the odd-column slice has one column fewer than
        # `angles`; trim so the assignment shapes match.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('positional_embedding', table.unsqueeze(0))

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table, shape ``(1, x.size(1), d_model)``."""
        return self.positional_embedding[:, :x.size(1), :]

class DataEmbedding(nn.Module):
    """Embed every scalar sample of a multi-channel series and add positions.

    Each scalar is projected to ``d_model`` features by a shared
    ``nn.Linear(1, d_model)``; a positional embedding indexed by the time
    axis is then added to every channel, followed by dropout.

    Args:
        d_model: Size of the embedding produced for each scalar.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, d_model, dropout=0.05):
        super().__init__()

        self.projection = nn.Linear(1, d_model)
        self.positional_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def reshape(self, x):
        """Project ``x`` of shape ``(batch, channels, seq_len)`` to ``(batch, channels, seq_len, d_model)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, channels, seq_len), got {tuple(x.shape)}"
            )
        # nn.Linear acts on the last axis and keeps all leading axes, so a
        # trailing singleton axis is enough. Unlike `.view`, this also works
        # for non-contiguous inputs (e.g. after a permute/transpose).
        return self.projection(x.unsqueeze(-1))

    def forward(self, x):
        """Return the embedded input, shape ``(batch, channels, seq_len, d_model)``."""
        x = self.reshape(x)

        # Positional table for one channel: (1, seq_len, d_model).
        pos_embedding = self.positional_embedding(x[:, 0, :, :])
        # Insert a channel axis; broadcasting then shares the same table
        # across every batch element and channel.
        x = x + pos_embedding.unsqueeze(1)
        return self.dropout(x)

In [4]:
class Encoder(nn.Module):
    """Stack of attention layers, optionally interleaved with conv layers.

    Args:
        attn_layers: Iterable of modules called as
            ``layer(x, attn_mask=..., tau=..., delta=...)`` and returning
            ``(x, attn)``.
        conv_layers: Optional iterable of modules applied to ``x`` after the
            attention layer it is paired with.
        norm_layer: Optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run ``x`` through the stack; return ``(x, list_of_attn_per_layer)``."""
        collected = []

        if self.conv_layers is None:
            # Plain stack: every layer sees the same mask / tau / delta.
            for layer in self.attn_layers:
                x, weights = layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                collected.append(weights)
        else:
            paired = zip(self.attn_layers, self.conv_layers)
            for index, (layer, conv) in enumerate(paired):
                # delta is only forwarded to the first paired layer.
                if index != 0:
                    delta = None
                x, weights = layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                x = conv(x)
                collected.append(weights)
            # The last attention layer runs on its own, without mask or delta.
            x, weights = self.attn_layers[-1](x, tau=tau, delta=None)
            collected.append(weights)

        if self.norm is not None:
            x = self.norm(x)

        return x, collected


class EncoderLayer(nn.Module):
    """Attention sub-layer followed by a position-wise two-conv feed-forward.

    Args:
        attention: Module called as ``attention(x, x, x, attn_mask=, tau=, delta=)``
            and returning ``(new_x, attn)``.
        d_model: Feature size of the input and output.
        d_ff: Hidden size of the feed-forward part; falls back to
            ``4 * d_model`` when falsy.
        dropout: Dropout probability shared by all dropout applications.
        activation: ``"relu"`` selects ``F.relu``; anything else selects ``F.gelu``.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.attention = attention
        # Kernel-size-1 convolutions act independently on every position.
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Return ``(output, attn)`` where ``attn`` is whatever the attention module returned."""
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)

        # Residual connection around the attention, then the first norm.
        x = self.norm1(x + self.dropout(attended))

        # Conv1d expects features on axis 1, so swap it with the last axis
        # before the first conv and swap back after the second.
        hidden = self.conv1(x.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        y = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(x + y), attn


In [5]:
class FullAttention(nn.Module):
    """Scaled dot-product attention over multi-head tensors.

    Args:
        mask_flag: When True, masked positions are excluded from the softmax.
        factor: Accepted for signature compatibility; not used by this class.
        scale: Multiplier applied to the scores; defaults to ``1 / sqrt(E)``
            where ``E`` is the per-head query size.
        attention_dropout: Dropout probability applied to the attention weights.
        output_attention: When True, ``forward`` also returns the weights.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Attend ``queries`` (B, L, H, E) over ``keys`` (B, S, H, E) / ``values`` (B, S, H, D).

        Args:
            attn_mask: Used only when ``mask_flag`` is True. Either an object
                exposing a boolean ``.mask`` tensor or a boolean tensor
                itself (True = masked out); ``None`` selects a causal mask.
            tau, delta: Accepted for interface compatibility; unused here.

        Returns:
            ``(V, A)`` with ``V`` of shape (B, L, H, D) and ``A`` the attention
            weights of shape (B, H, L, S), or ``None`` when
            ``output_attention`` is False.
        """
        _, L, _, E = queries.shape
        S = keys.shape[1]
        # math.sqrt instead of a bare `sqrt`, which is not defined in this file.
        scale = self.scale or 1.0 / math.sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                # Causal mask: position l may not attend to any s > l.
                # Shape (L, S) broadcasts over batch and heads.
                mask = torch.triu(
                    torch.ones(L, S, dtype=torch.bool, device=queries.device),
                    diagonal=1,
                )
            else:
                mask = getattr(attn_mask, "mask", attn_mask)
            scores = scores.masked_fill(mask, float("-inf"))

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        return V.contiguous(), None
        

class AttentionLayer(nn.Module):
    """Multi-head wrapper: project Q/K/V, run the inner attention, project back.

    Args:
        attention: Inner attention module called with 4-D per-head tensors.
        d_model: Feature size of the layer's inputs and output.
        n_heads: Number of heads the projected features are split into.
        d_keys: Per-head query/key size; defaults to ``d_model // n_heads``.
        d_values: Per-head value size; defaults to ``d_model // n_heads``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super().__init__()

        per_head = d_model // n_heads
        key_dim = d_keys or per_head
        value_dim = d_values or per_head

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, key_dim * n_heads)
        self.key_projection = nn.Linear(d_model, key_dim * n_heads)
        self.value_projection = nn.Linear(d_model, value_dim * n_heads)
        self.out_projection = nn.Linear(value_dim * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Return ``(output, attn)``; ``output`` has shape (batch, query_len, d_model)."""
        batch, query_len, _ = queries.shape
        _, source_len, _ = keys.shape
        heads = self.n_heads

        # Split the projected feature axis into (heads, per-head size).
        q_heads = self.query_projection(queries).view(batch, query_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, source_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, source_len, heads, -1)

        context, attn = self.inner_attention(
            q_heads, k_heads, v_heads, attn_mask, tau=tau, delta=delta
        )

        # Merge the heads back into a single feature axis before projecting.
        merged = context.view(batch, query_len, -1)
        return self.out_projection(merged), attn


In [6]:

class Channel_wise_transformer(nn.Module):
    """Transformer encoder that attends across channels at every time step.

    The input is embedded per scalar sample, then reshaped so that the
    channel axis plays the role of the attention sequence axis while batch
    and time are merged into one leading axis.

    Args:
        global_config: Object providing ``task_name``, ``output_attention``,
            ``d_model``, ``dropout``, ``factor``, ``n_heads``, ``d_ff``,
            ``activation`` and ``enc_layers``.
    """

    def __init__(self, global_config):
        super().__init__()
        self.task_name = global_config.task_name
        self.output_attention = global_config.output_attention

        self.encoder_embedding = DataEmbedding(d_model=global_config.d_model, dropout=global_config.dropout)

        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(
                            False,
                            global_config.factor,
                            # The dropout *rate* belongs here; the boolean
                            # output_attention flag goes to its own keyword.
                            attention_dropout=global_config.dropout,
                            output_attention=global_config.output_attention,
                        ),
                        global_config.d_model, global_config.n_heads
                    ),
                    global_config.d_model,
                    global_config.d_ff,
                    dropout=global_config.dropout,
                    activation=global_config.activation
                ) for _ in range(global_config.enc_layers)
            ],
            norm_layer=torch.nn.LayerNorm(global_config.d_model)
        )

    def forward(self, x):
        """Encode ``x`` of shape (batch, channels, seq_len).

        Returns:
            Tensor of shape (batch, channels, seq_len, d_model). Attention
            weights produced by the encoder are discarded.
        """
        enc_out = self.encoder_embedding(x)
        batch_size, channels, seq_length, d_model = enc_out.shape

        # (B, C, T, D) -> (B, T, C, D) -> (B*T, C, D): each time step becomes
        # an independent "sequence" whose tokens are the channels.
        enc_out = enc_out.permute(0, 2, 1, 3)
        enc_out = enc_out.reshape(batch_size * seq_length, channels, d_model)

        enc_out, _ = self.encoder(enc_out, attn_mask=None)

        # Undo the merge and restore the (B, C, T, D) layout.
        enc_out = enc_out.reshape(batch_size, seq_length, channels, d_model)
        enc_out = enc_out.permute(0, 2, 1, 3)
        return enc_out

In [7]:
class Spatial_temporal_Conv(nn.Module):
    """Spatial then temporal 2-D convolutions, flattened to one vector per sample.

    With an input of ``channels`` rows and ``time`` columns the flattened
    output has ``((channels + 7) // 2) * ((time - 29) // 2)`` features
    (derived from the kernel, padding and pooling sizes below); ``time`` must
    therefore be at least 31.

    Args:
        d_model: Feature size of the incoming embedding; used as the number
            of input planes of the first convolution.
    """

    def __init__(self, d_model):
        super().__init__()

        self.stconv = nn.Sequential(
            # spatial conv: mixes neighbouring channels (kernel spans axis 2 only)
            nn.Conv2d(in_channels=d_model,
                      out_channels=d_model * 2,
                      kernel_size=(6, 1),
                      padding=(6, 0)),
            nn.AvgPool2d(kernel_size=(2, 1),
                         stride=(2, 1),
                         padding=0),
            nn.BatchNorm2d(d_model * 2),
            nn.ReLU(),
            # temporal conv: mixes neighbouring time steps (kernel spans axis 3 only)
            nn.Conv2d(in_channels=d_model * 2,
                      out_channels=d_model * 2,
                      kernel_size=(1, 30),
                      padding=(0, 0)),
            nn.AvgPool2d(kernel_size=(1, 2),
                         stride=(1, 2),
                         padding=0),
            nn.BatchNorm2d(d_model * 2),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        # 1x1 convolution collapsing all feature planes into a single one.
        self.conv1d = nn.Conv2d(
            in_channels=d_model * 2,
            out_channels=1,
            kernel_size=(1, 1),
            padding=0
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, channels, time, d_model) to (batch, features)."""
        x = x.permute(0, 3, 1, 2)  # -> (batch, d_model, channels, time)
        output = self.stconv(x)
        output = self.conv1d(output)  # -> (batch, 1, H, W)
        # Flatten everything but the batch axis. The previous
        # `torch.squeeze(1)` call passed no tensor and raised a TypeError.
        return output.flatten(start_dim=1)


In [3]:
class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x) = fn(x) + x``.

    Args:
        fn: Callable (typically an ``nn.Module``) whose output has a shape
            that can be added to its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x`` without modifying ``x`` in place."""
        # Out-of-place add: the previous in-place `+=` followed by
        # `return res` handed back the untouched input, silently dropping
        # the wrapped branch.
        return self.fn(x, **kwargs) + x
    
class Project_Layer(nn.Sequential):
    """Projection head assembled as an ``nn.Sequential`` of three stages.

    Stages, in order: a linear map ``embedding_dim -> project_dim``; a
    ``ResidualAdd`` wrapping ReLU -> Linear -> Dropout; a LayerNorm.

    Args:
        embedding_dim: Size of the incoming flat embedding.
        project_dim: Size of the projected output.
        project_drop: Dropout probability inside the wrapped branch.
    """

    def __init__(self, embedding_dim = 5000, project_dim = 1024, project_drop = 0.5):
        input_projection = nn.Linear(embedding_dim, project_dim)
        wrapped_branch = ResidualAdd(
            nn.Sequential(
                nn.ReLU(),
                nn.Linear(project_dim, project_dim),
                nn.Dropout(project_drop),
            )
        )
        output_norm = nn.LayerNorm(project_dim)
        # Positional order fixes the sub-module indices (0, 1, 2).
        super().__init__(input_projection, wrapped_branch, output_norm)

- Complete EP encoder model: channel-wise transformer → spatial-temporal convolution → projection layer

In [None]:
class EP_encoder(nn.Module):
    """Full EP encoder: channel-wise transformer -> spatial-temporal conv -> projection.

    Args:
        global_config: Configuration object passed to
            ``Channel_wise_transformer``; its ``d_model`` also sizes
            ``Spatial_temporal_Conv``.
        embedding_dim: Keyword-only. Length of the flat vector produced by
            ``Spatial_temporal_Conv`` for the data being used. When ``None``
            the default of ``Project_Layer`` is used.
            NOTE(review): this must match the real flattened size of the
            conv output for the actual input shape — confirm before training.
        *args, **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, global_config, *args, embedding_dim=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.channel_wise_transformer = Channel_wise_transformer(global_config=global_config)
        self.spatial_temporal_conv = Spatial_temporal_Conv(d_model=global_config.d_model)
        # The original call `Project_Layer(embedding_dim=)` had no value and
        # did not parse; fall back to Project_Layer's own default when the
        # caller gives none.
        if embedding_dim is None:
            self.project_layer = Project_Layer()
        else:
            self.project_layer = Project_Layer(embedding_dim=embedding_dim)

    def forward(self, x):
        """Encode ``x`` and return the projected embedding."""
        x = self.channel_wise_transformer(x)
        embedding = self.spatial_temporal_conv(x)
        return self.project_layer(embedding)
    
# Build the encoder from the default configuration and place it on the GPU
# when one is available; fall back to the CPU instead of crashing otherwise.
EP_encoder_v1 = EP_encoder(global_config=global_config())
EP_encoder_v1 = EP_encoder_v1.to(device='cuda' if torch.cuda.is_available() else 'cpu')

- Training: loss function, train/evaluation loops, and the main training run

In [10]:
def loss_func(image_feature, EP_feature, t):
    """Symmetric soft-target contrastive loss between EP and image features.

    Args:
        image_feature: Tensor of shape (N, D) with one image feature per row.
        EP_feature: Tensor of shape (N, D) with the matching EP feature per row.
        t: Temperature. Logits are divided by it; the averaged
            self-similarities used as soft targets are multiplied by it.

    Returns:
        Scalar tensor: mean over the batch of the average of the EP->image
        and image->EP soft cross-entropies.
    """
    def cross_entropy_loss(preds, targets, reduction="none"):
        # Cross-entropy against soft (probability) targets, row by row.
        loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
        if reduction == 'none':
            return loss
        if reduction == 'mean':
            return loss.mean()
        raise ValueError(f"unsupported reduction: {reduction!r}")

    logits = (EP_feature @ image_feature.T) / t

    image_similarity = image_feature @ image_feature.T
    EP_similarity = EP_feature @ EP_feature.T
    targets = F.softmax(
        (image_similarity + EP_similarity) / 2 * t, dim=-1
    )
    EP_loss = cross_entropy_loss(logits, targets, reduction='none')
    # `.T` transposes; the original `.t` (no call) passed a bound method.
    Image_loss = cross_entropy_loss(logits.T, targets.T, reduction='none')
    loss = (EP_loss + Image_loss) / 2
    return loss.mean()

In [11]:
def train_model(encoder_modle, dataloader, optimizer, device):
    """Train the encoder for one full pass over ``dataloader``.

    Accuracy is an in-batch retrieval score: each EP feature is assigned the
    label of the image feature in the same batch it is most similar to
    (largest dot product), and that label is compared with its own.
    NOTE(review): the earlier code looked candidates up via
    ``encoder_modle.image_dataset`` / ``encoder_modle.loss_func``, which the
    encoder in this notebook does not define; confirm in-batch retrieval is
    the intended metric.

    Args:
        encoder_modle: Model mapping a batch of EP data to features.
        dataloader: Yields ``(EP_data, image, label, image_feature)`` batches.
            Assumes ``label`` is a 1-D tensor with one entry per sample —
            TODO confirm against the dataset.
        optimizer: Optimizer over the model's parameters.
        device: Device the batch tensors are moved to.

    Returns:
        ``(average_loss, accuracy, features)`` where ``features`` is the
        concatenation of all (detached) EP features produced this epoch.
        For an empty dataloader ``(0.0, 0.0, torch.empty(0))`` is returned.
    """
    encoder_modle.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    EP_feature_list = []

    for EP_data, image, label, image_feature in dataloader:
        EP_data = EP_data.to(device)
        label = label.to(device)
        image_feature = image_feature.to(device)

        # Clear gradients left over from the previous step; without this
        # they accumulate across batches.
        optimizer.zero_grad()

        EP_feature = encoder_modle(EP_data).float()
        loss = loss_func(image_feature, EP_feature, t=0.05)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        with torch.no_grad():
            # Detach before storing so the autograd graph of every batch is
            # not kept alive for the whole epoch.
            detached = EP_feature.detach()
            EP_feature_list.append(detached)

            nearest = (detached @ image_feature.T).argmax(dim=1)
            correct += (label[nearest] == label).sum().item()

        total += EP_data.size(0)

    if num_batches == 0:
        return 0.0, 0.0, torch.empty(0)

    # Return only after every batch has been processed.
    average_loss = total_loss / num_batches
    accuracy = correct / total
    return average_loss, accuracy, torch.cat(EP_feature_list, dim=0)

In [12]:
def evaluate_model(encoder_modle, dataloader, device):
    """Evaluate the encoder on every batch of ``dataloader`` without gradients.

    Uses the same loss call (``loss_func(image_feature, EP_feature, t=0.05)``)
    and the same in-batch retrieval accuracy as training: each EP feature
    is assigned the label of the most similar image feature (largest dot
    product) within its batch.
    NOTE(review): the earlier code relied on ``encoder_modle.loss_func`` and
    ``encoder_modle.image_dataset``, which the encoder in this notebook does
    not define; confirm in-batch retrieval is the intended metric.

    Args:
        encoder_modle: Model mapping a batch of EP data to features.
        dataloader: Yields ``(EP_data, image, label, image_feature)`` batches.
            Assumes ``label`` is a 1-D tensor with one entry per sample —
            TODO confirm against the dataset.
        device: Device the batch tensors are moved to.

    Returns:
        ``(average_loss, accuracy, predict_label, actual_label)``; the two
        label lists hold one plain Python value per evaluated sample.
        For an empty dataloader ``(0.0, 0.0, [], [])`` is returned.
    """
    encoder_modle.eval()

    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    predict_label = []
    actual_label = []

    with torch.no_grad():
        for EP_data, image, label, image_feature in dataloader:
            EP_data = EP_data.to(device)
            label = label.to(device)
            image_feature = image_feature.to(device)

            EP_feature = encoder_modle(EP_data).float()

            total_loss += loss_func(image_feature, EP_feature, t=0.05).item()
            num_batches += 1

            # Score every batch, not just the last one.
            nearest = (EP_feature @ image_feature.T).argmax(dim=1)
            predicted = label[nearest]
            correct += (predicted == label).sum().item()
            total += EP_data.size(0)

            predict_label.extend(predicted.tolist())
            actual_label.extend(label.tolist())

    if num_batches == 0:
        return 0.0, 0.0, predict_label, actual_label

    average_loss = total_loss / num_batches
    accuracy = correct / total
    return average_loss, accuracy, predict_label, actual_label

In [13]:
import torch.optim as optim

# Adam over every parameter of the EP encoder, learning rate 1e-4.
optimizer = optim.Adam(params=EP_encoder_v1.parameters(), lr=1e-4)

In [14]:
def main_train_loop(date, encoder_modle, train_loader, test_loader, optimizer, device, config):
    """Train and evaluate for several epochs, checkpoint, and plot the curves.

    Args:
        date: Label used as the checkpoint sub-directory name.
        encoder_modle: Model to train.
        train_loader: Loader passed to ``train_model``.
        test_loader: Loader passed to ``evaluate_model``.
        optimizer: Optimizer passed to ``train_model``.
        device: Device passed to both loops.
        config: Configuration object. ``config.epochs`` sets the number of
            epochs (default 5 when absent) and ``config.encoder_type`` names
            the checkpoint directory (default: the model's class name).

    Returns:
        List with one dict per epoch holding the train/test loss, accuracy
        and the test predictions and ground-truth labels.
    """
    # getattr with defaults: the configuration in this notebook defines no
    # `encoder_type`, and `config` may be passed as a class, where attributes
    # assigned in __init__ are not visible.
    encoder_type = getattr(config, "encoder_type", type(encoder_modle).__name__)
    num_epochs = getattr(config, "epochs", 5)

    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []
    results = []

    for epoch in range(num_epochs):
        train_loss, train_accuracy, _ = train_model(encoder_modle, train_loader, optimizer, device)

        # Save a checkpoint every fifth epoch.
        if (epoch + 1) % 5 == 0:
            os.makedirs(f"./models/{encoder_type}/{date}", exist_ok=True)
            file_path = f"./models/{encoder_type}/{date}/{epoch+1}.pth"
            torch.save(encoder_modle.state_dict(), file_path)

            print(f"model saved in {file_path}!")
            print(f"train loss: {train_loss} at epoch: {epoch + 1}")
            print(f"train accuracy: {train_accuracy} at epoch: {epoch + 1}")
            print("----------")

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        test_loss, test_accuracy, test_predict, test_actual = evaluate_model(encoder_modle, test_loader, device)
        # Record the test curves so the plots below are not empty.
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        # Append inside the loop so every epoch is reported, not just the last.
        results.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "test_loss": test_loss,
            "test_accuracy": test_accuracy,
            "test_predict": test_predict,
            "test_actual": test_actual,
        })

    # plt.subplots(1, 2) returns a 1-D array of axes, so index with one integer.
    fig, axs = plt.subplots(1, 2, figsize=(10, 15))
    loss_ax, accuracy_ax = axs[0], axs[1]

    # Loss curve
    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot(test_losses, label='Test Loss')
    loss_ax.legend()
    loss_ax.set_title("Loss Curve")

    # Overall accuracy curve
    accuracy_ax.plot(train_accuracies, label='Train Accuracy')
    accuracy_ax.plot(test_accuracies, label='Test Accuracy')
    accuracy_ax.legend()
    accuracy_ax.set_title("Accuracy Curve")

    return results

In [15]:
import pickle
# SECURITY: unpickling runs arbitrary code stored in the file; only load
# dataset files that come from a trusted source.
with open('train_dataset.pkl', mode='rb') as train_file:
    train_dataset = pickle.load(train_file)

with open('test_dataset.pkl', mode='rb') as test_file:
    test_dataset = pickle.load(test_file)

# Batches of 15 samples; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=15, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=15, shuffle=False)

In [16]:
# Pass a configuration *instance* (attributes are set in __init__, so the bare
# class has none) and reuse whichever device the encoder already lives on so
# the batches and the model always end up on the same device.
results = main_train_loop(
    date='20250215',
    encoder_modle=EP_encoder_v1,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    device=next(EP_encoder_v1.parameters()).device,
    config=global_config(),
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.01 GiB. GPU 0 has a total capacity of 11.69 GiB of which 947.25 MiB is free. Including non-PyTorch memory, this process has 7.28 GiB memory in use. Of the allocated memory 5.56 GiB is allocated by PyTorch, and 1.53 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)