In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

from collections import Counter
import random
import numpy as np

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# --- 1. Synthetic Dataset ---
class SeqLabelingDataset(Dataset):
    def __init__(self, n_samples=100, min_len=5, max_len=15, dyn_feat_dim=4, static_feat_dim=2, num_classes=3):
        self.samples = []
        for _ in range(n_samples):
            seq_len = random.randint(min_len, max_len)
            dyn_feat = torch.randn(seq_len, dyn_feat_dim)
            static_feat = torch.randn(static_feat_dim)
            labels = torch.randint(0, num_classes, (seq_len,))
            self.samples.append((dyn_feat, static_feat, labels))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# --- 2. Collate function ---
def collate_fn(batch):
    dyn_seqs, stat_feats, label_seqs = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in dyn_seqs])

    padded_dyn = pad_sequence(dyn_seqs, batch_first=True)  # [B, T, F]
    padded_labels = pad_sequence(label_seqs, batch_first=True, padding_value=-100)  # [B, T]
    static_feats = torch.stack(stat_feats)  # [B, S]

    return padded_dyn, static_feats, lengths, padded_labels

# --- 3. LSTM Model ---
class LSTMTagger(nn.Module):
    def __init__(self, dyn_feat_dim, static_feat_dim, hidden_dim, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(dyn_feat_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim + static_feat_dim, num_classes)

    def forward(self, x_dyn, x_static, lengths):
        packed_input = pack_padded_sequence(x_dyn, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        lstm_out, _ = pad_packed_sequence(packed_output, batch_first=True)  # [B, T, H]

        # Expand static features to [B, T, S]
        B, T, _ = lstm_out.shape
        x_static_exp = x_static.unsqueeze(1).expand(B, T, -1)

        combined = torch.cat([lstm_out, x_static_exp], dim=-1)  # [B, T, H+S]
        logits = self.classifier(combined)  # [B, T, C]
        return logits

# --- 4. Accuracy with masking ---
def masked_accuracy(logits, labels, pad_val=-100):
    preds = logits.argmax(dim=-1)
    mask = labels != pad_val
    correct = (preds == labels) & mask
    acc = correct.sum().float() / mask.sum()
    return acc.item()

def compute_class_weights(dataset, num_classes, pad_val=-100):
    all_labels = []
    for _, _, label_seq in dataset:
        all_labels.extend(label_seq.tolist())
    filtered_labels = [lb for lb in all_labels if lb != pad_val]
    label_counts = Counter(filtered_labels)

    # Inverse frequency
    counts = torch.tensor([label_counts.get(i, 1) for i in range(num_classes)], dtype=torch.float)
    weights = 1.0 / counts
    weights = weights / weights.sum()  # normalize
    return weights


def evaluate(model, dataloader, pad_val=-100):
    model.eval()
    all_preds, all_labels = [], []
    total_acc = 0.0
    batches = 0

    with torch.no_grad():
        for x_dyn, x_static, lengths, labels in dataloader:
            x_dyn = x_dyn.to(device)
            x_static = x_static.to(device)
            lengths = lengths.to(device)
            labels = labels.to(device)

            logits = model(x_dyn, x_static, lengths)
            acc = masked_accuracy(logits, labels, pad_val)
            total_acc += acc
            batches += 1

            ###
            # preds = logits.argmax(dim=-1)
            # mask = labels != pad_val
            # all_preds.extend(preds[mask].cpu().numpy())
            # all_labels.extend(labels[mask].cpu().numpy())

    return total_acc / batches

    # Generate confusion matrix
    # cm = confusion_matrix(all_labels, all_preds)
    # report = classification_report(all_labels, all_preds, digits=4)
    # return accuracy_score(all_labels, all_preds)

In [8]:
dyn_feat_dim = 4
static_feat_dim = 2
hidden_dim = 32
num_classes = 3
batch_size = 16
epochs = 200

dataset = SeqLabelingDataset()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [9]:
d = dataset.samples[3]
d[0].shape, d[1].shape, d[2].shape

(torch.Size([14, 4]), torch.Size([2]), torch.Size([14]))

In [13]:
# --- 5. Training setup ---
class_weights = compute_class_weights(dataset, num_classes).to(device)
model = LSTMTagger(dyn_feat_dim, static_feat_dim, hidden_dim, num_classes).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# --- 6. Training loop ---
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    batches = 0

    for x_dyn, x_static, lengths, labels in dataloader:
        x_dyn = x_dyn.to(device)
        x_static = x_static.to(device)
        lengths = lengths.to(device)
        labels = labels.to(device)

        logits = model(x_dyn, x_static, lengths)
        loss = criterion(logits.view(-1, num_classes), labels.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = masked_accuracy(logits, labels)
        total_loss += loss.item()
        total_acc += acc
        batches += 1

    test_acc = evaluate(model, dataloader)
    print(f"Epoch {epoch+1} | Train Loss: {total_loss/batches:.4f} | Train Accuracy: {total_acc/batches:.4f} | Test Acc: {test_acc:.4f}\n\n")

Epoch 1 | Train Loss: 1.1020 | Train Accuracy: 0.3229 | Test Acc: 0.3416


Epoch 2 | Train Loss: 1.0996 | Train Accuracy: 0.3290 | Test Acc: 0.3518


Epoch 3 | Train Loss: 1.0963 | Train Accuracy: 0.3595 | Test Acc: 0.3509


Epoch 4 | Train Loss: 1.0994 | Train Accuracy: 0.3412 | Test Acc: 0.3361


Epoch 5 | Train Loss: 1.0980 | Train Accuracy: 0.3458 | Test Acc: 0.3795


Epoch 6 | Train Loss: 1.0955 | Train Accuracy: 0.3727 | Test Acc: 0.3548


Epoch 7 | Train Loss: 1.0950 | Train Accuracy: 0.3601 | Test Acc: 0.3803


Epoch 8 | Train Loss: 1.0958 | Train Accuracy: 0.3490 | Test Acc: 0.3712


Epoch 9 | Train Loss: 1.0928 | Train Accuracy: 0.3875 | Test Acc: 0.3686


Epoch 10 | Train Loss: 1.0942 | Train Accuracy: 0.3700 | Test Acc: 0.3717


Epoch 11 | Train Loss: 1.0946 | Train Accuracy: 0.3681 | Test Acc: 0.3820


Epoch 12 | Train Loss: 1.0943 | Train Accuracy: 0.3697 | Test Acc: 0.3943


Epoch 13 | Train Loss: 1.0906 | Train Accuracy: 0.4033 | Test Acc: 0.3968


Epoch 14 | Train Loss