In [2]:
import math

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImprovedGMMClusterAttention(nn.Module):
    """Soft-cluster the time steps of a sequence with a learnable diagonal GMM.

    Every time step of the ``[B, T, D]`` input is embedded
    (Conv1d -> BatchNorm1d -> Linear -> LayerNorm) and scored against ``K``
    Gaussian components with diagonal covariance defined in the embedding
    space.  The resulting responsibilities re-weight the *original* features,
    giving one weighted copy of the sequence per cluster.

    Args:
        input_dim: Number of input channels ``D``.
        embed_dim: Size ``E`` of the embedding in which the mixture lives.
        num_clusters: Number of mixture components ``K``.
        use_positional_encoding: Stored on the module only; ``forward`` does
            not currently add the encoding to the input.
    """

    def __init__(self, input_dim, embed_dim=64, num_clusters=3, use_positional_encoding=True):
        super().__init__()
        self.K = num_clusters
        self.D = input_dim
        self.E = embed_dim
        self.use_positional_encoding = use_positional_encoding

        # Temporal projection to embedding space.
        # NOTE(review): PyTorch's ``momentum`` is the weight given to the *new*
        # batch statistic, so 0.9 makes running stats track the latest batch
        # almost entirely -- confirm this is intended.
        self.temporal_conv = nn.Conv1d(input_dim, embed_dim, kernel_size=5, padding=2)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.9)
        self.project = nn.Linear(embed_dim, embed_dim)
        self.ln = nn.LayerNorm(embed_dim)
        # ``instance_norm`` and ``dropout`` are registered but not used by
        # ``forward``; they are kept so existing checkpoints still load.
        self.instance_norm = nn.InstanceNorm1d(embed_dim, affine=True, momentum=0.9)
        self.dropout = nn.Dropout(0.1)

        # GMM parameters in embedding space
        self.means = nn.Parameter(torch.randn(self.K, embed_dim))            # [K, E]
        self.log_vars = nn.Parameter(torch.zeros(self.K, embed_dim))         # [K, E]
        self.logits_pi = nn.Parameter(torch.zeros(self.K))                   # [K]
        self.temperature = nn.Parameter(torch.tensor(1.0))                   # Learnable

    def get_positional_encoding(self, seq_len, dim, device):
        """Return a sinusoidal positional encoding of shape ``[1, seq_len, dim]``.

        Even columns hold sines and odd columns hold cosines.  ``dim`` may be
        odd: the sine half then has one more column than the cosine half.
        """
        pe = torch.zeros(seq_len, dim, device=device)
        position = torch.arange(0, seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
        angles = position * div_term                     # [T, ceil(dim / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd ``dim`` there is one fewer odd column than even columns, so
        # trim the angle table to match instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return pe.unsqueeze(0)  # [1, T, dim]

    def gmm_prob(self, x_embed):  # [B, T, E]
        """Compute cluster responsibilities for embedded time steps.

        Args:
            x_embed: Embedded sequence of shape ``[B, T, E]``.

        Returns:
            ``(probs, log_prob)``, both ``[B, T, K]``.  ``log_prob`` is the
            per-component log-density plus log-prior, divided by the clamped
            temperature; ``probs`` is its softmax over the cluster axis.
        """
        x = x_embed.unsqueeze(2)                           # [B, T, 1, E]
        mu = self.means.unsqueeze(0).unsqueeze(0)          # [1, 1, K, E]
        log_var = self.log_vars.unsqueeze(0).unsqueeze(0)  # [1, 1, K, E]
        var = torch.exp(log_var)                           # [1, 1, K, E]

        # Diagonal-Gaussian log-density, summed over the embedding axis.
        log_prob = -0.5 * ((x - mu) ** 2 / var + log_var + math.log(2 * math.pi))  # [B, T, K, E]
        log_prob = log_prob.sum(dim=-1)                    # [B, T, K]

        # Add log mixture priors.
        log_pi = F.log_softmax(self.logits_pi, dim=-1)     # [K]
        log_prob = log_prob + log_pi                       # [B, T, K]

        # Temperature scaling; the clamp keeps the divisor in [0.1, 10].
        log_prob = log_prob / torch.clamp(self.temperature, min=0.1, max=10.0)

        probs = F.softmax(log_prob, dim=-1)                # [B, T, K]
        return probs, log_prob

    def forward(self, features):  # [B, T, D]
        """Cluster the sequence and return per-cluster weighted features.

        Args:
            features: Input of shape ``[B, T, D]``.

        Returns:
            ``(cluster_features, weights, log_probs)`` with shapes
            ``[B, K, T, D]``, ``[B, T, K]`` and ``[B, T, K]``.
        """
        # Conv1d / BatchNorm1d are channel-first, hence the transposes.
        x = self.temporal_conv(features.transpose(1, 2))   # [B, E, T]
        x = self.batch_norm(x).transpose(1, 2)             # [B, T, E]
        x_embed = self.ln(self.project(x))                 # [B, T, E]

        # GMM soft assignment
        weights, log_probs = self.gmm_prob(x_embed)        # [B, T, K], [B, T, K]

        # Weight the original (un-embedded) features by each cluster's responsibility.
        cluster_features = torch.einsum('btk,btd->bktd', weights, features)  # [B, K, T, D]

        return cluster_features, weights, log_probs
    
class FuzzyCMeansClustering(nn.Module):
    """Fuzzy c-means style soft clustering of time steps in an embedding space.

    The ``[B, T, D]`` input is embedded (Conv1d -> BatchNorm1d -> Linear ->
    LayerNorm), compared with ``num_clusters`` learnable centres through a
    cosine-based squared distance, and turned into fuzzy memberships.  The
    memberships re-weight the *original* features, one copy per cluster.

    Args:
        input_dim: Number of input channels ``D``.
        embed_dim: Size ``E`` of the embedding holding the cluster centres.
        num_clusters: Number of clusters ``K``.
        fuzziness: Fuzzifier ``m``; must be strictly greater than 1 because
            the membership exponent is ``1 / (m - 1)``.
        epsilon: Lower bound for distances and stabiliser for the
            normalisation denominator.

    Raises:
        ValueError: If ``fuzziness`` is not greater than 1.
    """

    def __init__(self, input_dim, embed_dim=64, num_clusters=3, fuzziness=2.0, epsilon=1e-6):
        super().__init__()
        if fuzziness <= 1:
            # m == 1 divides by zero in forward(); m < 1 flips the exponent's
            # sign so that far-away centres would receive the largest weight.
            raise ValueError(f"fuzziness must be > 1, got {fuzziness}")
        self.num_clusters = num_clusters
        self.fuzziness = fuzziness
        self.epsilon = epsilon

        # Temporal embedding.
        # NOTE(review): PyTorch's ``momentum`` weights the *new* batch
        # statistic; 0.9 is unusually high -- confirm this is intended.
        self.temporal_conv = nn.Conv1d(input_dim, embed_dim, kernel_size=5, padding=2)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.9)
        self.project = nn.Linear(embed_dim, embed_dim)
        self.ln = nn.LayerNorm(embed_dim)
        # Registered but not used by forward(); kept for checkpoint compatibility.
        self.dropout = nn.Dropout(0.1)

        # Cluster centers in embedded space (re-initialised with Xavier uniform).
        self.cluster_centers = nn.Parameter(torch.randn(num_clusters, embed_dim))
        nn.init.xavier_uniform_(self.cluster_centers)

        # Learnable temperature
        self.temperature = nn.Parameter(torch.tensor(1.0))

    def forward(self, features):  # features: [B, T, D]
        """Compute fuzzy memberships and per-cluster weighted features.

        Args:
            features: Input of shape ``[B, T, D]``.

        Returns:
            ``(cluster_features, weights)`` with shapes ``[B, K, T, D]`` and
            ``[B, T, K]``.  ``weights`` are the normalised memberships divided
            by the clamped temperature.
        """
        # Temporal embedding (Conv1d / BatchNorm1d are channel-first).
        x = self.temporal_conv(features.transpose(1, 2))          # [B, E, T]
        x = self.batch_norm(x).transpose(1, 2)                    # [B, T, E]
        x_embed = self.ln(self.project(x))                        # [B, T, E]

        # Unit-normalise both sides so the dot product is a cosine similarity.
        x_norm = F.normalize(x_embed, p=2, dim=-1)                # [B, T, E]
        centers = F.normalize(self.cluster_centers, p=2, dim=-1)  # [K, E]

        # For unit vectors ||a - b||^2 = 2 - 2 * cos(a, b); clamp away from 0
        # so the negative power below stays finite.
        cos_sim = torch.matmul(x_norm, centers.T)                 # [B, T, K]
        dist_sq = torch.clamp(2 - 2 * cos_sim, min=self.epsilon)  # [B, T, K]

        # Fuzzy membership weights: inverse distance raised to 1 / (m - 1),
        # normalised over the cluster axis.
        power = 1.0 / (self.fuzziness - 1)
        inv_dist = torch.pow(dist_sq, -power)                     # [B, T, K]
        weights = inv_dist / (inv_dist.sum(dim=-1, keepdim=True) + self.epsilon)

        # Temperature scaling.
        # NOTE(review): this is applied after normalisation, so memberships
        # sum to 1 / temperature rather than 1 -- confirm that is intended.
        weights = weights / torch.clamp(self.temperature, min=0.1, max=10.0)

        # Weighted feature aggregation using the original features.
        cluster_features = torch.einsum('btk,btd->bktd', weights, features)  # [B, K, T, D]

        return cluster_features, weights

    
class IBNNet(nn.Module):
    """Three-convolution 1-D block with a selectable final normalisation.

    Pipeline: Conv1d(k=5) -> BatchNorm -> ReLU -> Conv1d(k=1) -> BatchNorm ->
    ReLU -> Conv1d(k=5) -> optional residual add of the block input ->
    InstanceNorm or BatchNorm -> ReLU.  All convolutions use ``padding='same'``
    with stride 1, so the time length is preserved.

    Args:
        input_dim: Channels of the input ``[B, input_dim, T]``.
        hidden_dim: Channels produced by the first two convolutions.
        out_dim: Channels produced by the last convolution (default 256, the
            value that was previously hard-coded).
    """

    def __init__(self, input_dim=1, hidden_dim=16, out_dim=256):
        super().__init__()
        # NOTE(review): PyTorch's ``momentum`` weights the *new* batch
        # statistic; 0.9 is unusually high -- confirm this is intended.
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=5, stride=1, padding='same')
        self.bn1 = nn.BatchNorm1d(hidden_dim, momentum=0.9)
        # lr1 / lr2 / lr3 are registered but never called in forward(); they
        # are kept so existing state dicts keep loading.
        self.lr1 = nn.LayerNorm(hidden_dim)

        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1, stride=1, padding='same')
        self.bn2 = nn.BatchNorm1d(hidden_dim, momentum=0.9)
        self.lr2 = nn.LayerNorm(hidden_dim)

        self.conv3 = nn.Conv1d(hidden_dim, out_dim, kernel_size=5, stride=1, padding='same')
        self.bn3 = nn.BatchNorm1d(out_dim, momentum=0.9)

        self.lr3 = nn.LayerNorm(out_dim)
        self.instance_norm = nn.InstanceNorm1d(out_dim, affine=True, momentum=0.9)

    def forward(self, x, res=True, instance_norm=True):
        """Run the block.

        Args:
            x: Input of shape ``[B, input_dim, T]``.
            res: If truthy, add the block input to the output of the last
                convolution.  The input must then be broadcast-compatible
                with ``[B, out_dim, T]``.
            instance_norm: If truthy, normalise with InstanceNorm1d, otherwise
                with the block's final BatchNorm1d.

        Returns:
            Tensor of shape ``[B, out_dim, T]`` after the final ReLU.
        """
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))  # [B, hidden_dim, T]
        x = F.relu(self.bn2(self.conv2(x)))  # [B, hidden_dim, T]
        x = self.conv3(x)                    # [B, out_dim, T]

        if res:
            x = x + residual

        x = self.instance_norm(x) if instance_norm else self.bn3(x)
        return F.relu(x)
    
class UNetDecoder(nn.Module):
    """Decoder that fuses two encoder skip tensors and emits per-step outputs.

    ``forward`` uses only the first IBNNet/transposed-conv/batch-norm stage,
    the second IBNNet, and the per-output linear heads.  The remaining stages
    (``encoderhelper3``-``5``, ``upconv2``-``5``, ``bn3``-``bn6``, ``conv2``)
    are constructed but not called; they stay registered so that parameter
    names and ordering are unchanged.

    Args:
        input_dim: Channel count ``C`` of the incoming features.  The linear
            heads expect ``2 * input_dim`` features after the last skip
            concatenation.
        output_dim: Number of linear heads; also the size of the last output axis.
        hidden_dim: Hidden width passed to every ``IBNNet``.
        kernel_size: Kernel size of the transposed convolutions and ``conv2``.
        padding: Padding of the transposed convolutions and ``conv2``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64, kernel_size=3, padding=1):
        super().__init__()

        doubled = input_dim * 2   # channels after concatenating one skip tensor
        halved = input_dim // 2

        def build_stage(out_channels):
            # One stage = IBNNet + stride-1 transposed conv + batch norm.
            # Stride is 1, so these layers do not change the time resolution
            # the way a strided up-convolution would.
            return (
                IBNNet(input_dim, hidden_dim),
                nn.ConvTranspose1d(doubled, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
                nn.BatchNorm1d(out_channels, momentum=0.9),
            )

        # Attribute names (including the bn1 -> bn3 jump) are kept as-is.
        self.encoderhelper1, self.upconv1, self.bn1 = build_stage(input_dim)
        self.encoderhelper2, self.upconv2, self.bn3 = build_stage(input_dim)
        self.encoderhelper3, self.upconv3, self.bn4 = build_stage(input_dim)
        self.encoderhelper4, self.upconv4, self.bn5 = build_stage(input_dim)
        self.encoderhelper5, self.upconv5, self.bn6 = build_stage(halved)

        self.conv2 = nn.Conv1d(input_dim, halved, kernel_size=kernel_size, padding=padding)

        # One small MLP head per output channel, applied at every time step.
        heads = []
        for _ in range(output_dim):
            heads.append(
                nn.Sequential(
                    nn.Linear(doubled, input_dim),
                    nn.Dropout(0.1),
                    nn.ReLU(),
                    nn.Linear(input_dim, 1),
                )
            )
        self.fc_list = nn.ModuleList(heads)

    def forward(self, x, enc2, enc1):
        """Decode ``x`` using the skip tensors ``enc2`` and ``enc1``.

        Args:
            x: Channel-first features ``[B, C, T]``.
            enc2: Skip tensor concatenated (channel axis) after the first IBNNet.
            enc1: Skip tensor concatenated (channel axis) after the second IBNNet.

        Returns:
            A single tensor whose last axis has one entry per head, i.e.
            ``[B, T, output_dim]``.
        """
        # Stage 1: refine, merge the deeper skip, project back to C channels.
        hidden = self.encoderhelper1(x, res=False)
        hidden = torch.cat([hidden, enc2], dim=1)
        hidden = F.relu(self.bn1(self.upconv1(hidden)))

        # Stage 2: residual refinement with batch norm, then merge the
        # shallower skip.
        hidden = self.encoderhelper2(hidden, res=True, instance_norm=False)
        hidden = torch.cat([hidden, enc1], dim=1)

        # Linear layers act on the last axis, so go time-major: [B, T, channels].
        hidden = hidden.transpose(1, 2)
        per_head = [head(hidden) for head in self.fc_list]
        return torch.cat(per_head, dim=-1)

class ImprovedConv1DEncoder(nn.Module):
    """Stack of IBNNet blocks, each followed by a length-preserving max pool.

    The pools use ``kernel_size=5, stride=1, padding=2``, so they smooth over
    a 5-step window without reducing the time resolution.  ``encoder4`` and
    ``maxpool4`` are constructed but not used by ``forward``.

    Args:
        input_dim: Channel count of the raw input.
        hidden_dim: Hidden width passed to every ``IBNNet``.
    """

    def __init__(self, input_dim=1, hidden_dim=16):
        super().__init__()

        def same_length_pool():
            return nn.MaxPool1d(kernel_size=5, stride=1, padding=2)

        # Later blocks take 256 channels, matching the IBNNet output width.
        self.encoder1, self.maxpool1 = IBNNet(input_dim, hidden_dim), same_length_pool()
        self.encoder2, self.maxpool2 = IBNNet(256, hidden_dim), same_length_pool()
        self.encoder3, self.maxpool3 = IBNNet(256, hidden_dim), same_length_pool()
        self.encoder4, self.maxpool4 = IBNNet(256, hidden_dim), same_length_pool()

    def forward(self, x):
        """Encode a time-major input.

        Args:
            x: Input of shape ``[B, T, input_dim]``.

        Returns:
            ``(enc3, enc2, enc1)`` -- deepest first -- each channel-first
            ``[B, 256, T]``, for use as features and skip connections.
        """
        channel_first = x.transpose(1, 2)  # [B, input_dim, T]

        # The first block has no residual path; the following ones do.
        enc1 = self.maxpool1(self.encoder1(channel_first, res=False))
        enc2 = self.maxpool2(self.encoder2(enc1, res=True))
        enc3 = self.maxpool3(self.encoder3(enc2, res=True))

        return enc3, enc2, enc1

class Transpose(nn.Module):
    """Module wrapper that swaps two fixed dimensions of its input.

    Args:
        dim1: First dimension to swap.
        dim2: Second dimension to swap.
    """

    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        """Return ``x`` with ``dim1`` and ``dim2`` exchanged."""
        return x.transpose(self.dim1, self.dim2)
    
class MCDNILM(nn.Module):
    """Encoder -> fuzzy clustering -> one UNet-style decoder per cluster.

    The deepest encoder output is soft-clustered into three groups; each
    group's weighted features are decoded by its own ``UNetDecoder`` together
    with the two shallower encoder outputs as skip connections.
    ``temporal_lstm`` and ``temporal_ln`` are constructed but not used by
    ``forward``.

    Args:
        input_dim: Channel count of the raw input.
        hidden_dim: Hidden width passed to the encoder.
    """

    def __init__(self, input_dim=1, hidden_dim=16):
        super().__init__()
        self.encoder = ImprovedConv1DEncoder(input_dim, hidden_dim)
        self.temporal_lstm = nn.LSTM(258, 128, batch_first=True, bidirectional=True)
        self.temporal_ln = nn.LayerNorm(256)
        self.attention = FuzzyCMeansClustering(input_dim=256, num_clusters=3)

        # One decoder per cluster; output widths are 3, 1 and 3.
        long_cycle = UNetDecoder(256, 3, kernel_size=5, padding=2)
        hvac = UNetDecoder(256, 1, kernel_size=1, padding=0)
        short_cycle = UNetDecoder(256, 3, kernel_size=5, padding=2, hidden_dim=80)
        self.decoders = nn.ModuleList([long_cycle, hvac, short_cycle])

    def forward(self, x):  # x: [B, T, input_dim]
        """Run the full model.

        Args:
            x: Input of shape ``[B, T, input_dim]``.

        Returns:
            ``(output, weights)``: ``output`` concatenates the decoder outputs
            on the last axis (``[B, T, 7]`` with the decoders above) and
            ``weights`` are the cluster memberships ``[B, T, 3]``.
        """
        enc3, enc2, enc1 = self.encoder(x)  # each channel-first [B, C, T]

        # The clustering module expects time-major features [B, T, C].
        cluster_feats, weights = self.attention(enc3.transpose(1, 2))  # [B, K, T, C]

        # Decode every cluster slice (back in channel-first layout) with its
        # own decoder, sharing the same skip connections.
        decoded = [
            decoder(cluster_feats[:, idx].transpose(1, 2), enc2, enc1)
            for idx, decoder in enumerate(self.decoders)
        ]
        return torch.cat(decoded, dim=-1), weights

# test_data = torch.randn(32, 100, 2)  # Example input: [B, T, D]
# model = MCDNILM(input_dim=2, hidden_dim=64)
# output = model(test_data)
# print(output[0].shape)  # Should print: torch.Size([32, 100, 7]) -- per-time-step outputs of the 3 cluster decoders (widths 3, 1, 3)