In [13]:
import numpy as np
import utilities as ut
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import LabelEncoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [14]:
# Change dataset name if needed
dataset = 'UTSv2'
pred_w = 'block' ## 'block' or 'floor'
data = np.load(f'data/dataset_{dataset}.npz', allow_pickle=True)
model_name = 'fmodel/lstm_UTS.pt'

# Extract contents
X_train_seq = data['X_train_seq']
Y_train_seq = data['Y_train_seq']
X_test_seq  = data['X_test_seq']
Y_test_seq  = data['Y_test_seq']
X_test_row = data['X_test']
Y_test_row = data['Y_test']
block_info  = data['block_info'][0]  # it's stored as an object array

In [15]:
# ====================
# Prepare Labels
# ====================

if pred_w == 'block':
    pred_col = -1
elif pred_w == 'floor':
    pred_col = 0

Y_train_seq = data['Y_train_seq']
y_train_ids = Y_train_seq[:, pred_col].astype(int)
y_test_ids  = Y_test_seq[:, pred_col].astype(int)


encoder = LabelEncoder()
encoder.fit(y_train_ids)

y_train = torch.tensor(encoder.transform(y_train_ids), dtype=torch.long)
y_test  = torch.tensor(encoder.transform(y_test_ids), dtype=torch.long)

# ====================
# Prepare Inputs
# ====================
X_train = torch.tensor(X_train_seq, dtype=torch.float32)
X_test  = torch.tensor(X_test_seq, dtype=torch.float32)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
test_loader  = DataLoader(TensorDataset(X_test, y_test), batch_size=64)

In [16]:
# ====================
# LSTM Model
# ====================
class LSTMBlockClassifier(nn.Module):
    def __init__(self, input_size, hidden_size=256, fc_size=256, num_classes=100):
        super(LSTMBlockClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            batch_first=True)
        self.fc1 = nn.Linear(hidden_size, fc_size)
        self.fc2 = nn.Linear(fc_size, num_classes)

    def forward(self, x):
        mean = x.mean(dim=2, keepdim=True)        # mean over features per time step
        std = x.std(dim=2, keepdim=True) + 1e-8   # std over features per time step
        x = (x - mean) / std
        out, (h_n, _) = self.lstm(x)
        h_last = h_n.squeeze(0)
        x = F.relu(self.fc1(h_last))
        x = self.fc2(x)
        return x

In [17]:
# ====================
# Load Trained Model
# ====================
input_size = X_train.shape[2]
num_classes = len(encoder.classes_)

model = LSTMBlockClassifier(input_size=input_size, num_classes=num_classes).to(device)
model.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))
model.eval()

LSTMBlockClassifier(
  (lstm): LSTM(589, 256, batch_first=True)
  (fc1): Linear(in_features=256, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=175, bias=True)
)

In [18]:
# ====================
# Prepare Flattened Features for Student
# ====================
def flatten_lstm_input(X, steps=1):
    """
    Select the first `steps` time steps, then flatten (N, T, F) â†’ (N, T*F)
    """
    # return X[:, :steps, :].reshape(X.shape[0], -1)
    return X[:, :steps, :].mean(axis=1)

X_student_train = flatten_lstm_input(X_train_seq, steps=3)
X_student_test  = X_test_row

# Normalize inputs for student
X_student_train = ut.normalize_zscore_rowwise(X_student_train)
X_student_test  = ut.normalize_zscore_rowwise(X_student_test)

In [19]:
# ====================
# Get Teacher Logits and Soft Labels (Temperature = 2)
# ====================
def get_teacher_logits(model, X_tensor, device, batch_size=64):
    """
    Runs inference through teacher model and returns logits.
    """
    model.eval()
    logits = []
    with torch.no_grad():
        for i in range(0, len(X_tensor), batch_size):
            xb = X_tensor[i:i + batch_size].to(device)
            out = model(xb)
            logits.append(out.cpu())
    return torch.cat(logits, dim=0)

# Temperature hyperparameter
T = 2.0

# Step 1: Get teacher logits (pre-softmax)
X_tensor = torch.tensor(X_train_seq, dtype=torch.float32)
logits_teacher = get_teacher_logits(model, X_tensor, device)

# Step 2: Soft labels (probabilities)
teacher_probs = F.softmax(logits_teacher / T, dim=1)

In [20]:
# ====================
# Student Model Definition (MLP)
# ====================
class StudentMLP(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.net(x)

student_input_dim = X_student_train.shape[1]
student = StudentMLP(input_dim=student_input_dim, num_classes=num_classes).to(device)

In [21]:
alpha = 0.7  # soft label weight
T = 2.0
epochs = 50
criterion_kl = nn.KLDivLoss(reduction='batchmean')
criterion_ce = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

for epoch in range(epochs):
    student.train()
    total_loss = 0

    for i in range(0, len(X_student_train), 64):
        xb = torch.tensor(X_student_train[i:i+64], dtype=torch.float32).to(device)
        soft_yb = teacher_probs[i:i+64].to(device)
        hard_yb = y_train[i:i+64].to(device)

        optimizer.zero_grad()
        student_logits = student(xb)
        log_probs = F.log_softmax(student_logits / T, dim=1)
        probs = F.softmax(student_logits, dim=1)

        # Total loss = weighted KL + CE
        loss = (alpha * criterion_kl(log_probs, soft_yb) * T * T) + \
               ((1 - alpha) * criterion_ce(probs, hard_yb))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"[Hybrid Epoch {epoch+1}] Loss: {total_loss:.4f}")

[Hybrid Epoch 1] Loss: 1535.4407
[Hybrid Epoch 2] Loss: 1501.1917
[Hybrid Epoch 3] Loss: 1433.1939
[Hybrid Epoch 4] Loss: 1410.1412
[Hybrid Epoch 5] Loss: 1373.4161
[Hybrid Epoch 6] Loss: 1353.6587
[Hybrid Epoch 7] Loss: 1359.7862
[Hybrid Epoch 8] Loss: 1369.3010
[Hybrid Epoch 9] Loss: 1390.5202
[Hybrid Epoch 10] Loss: 1359.1615
[Hybrid Epoch 11] Loss: 1415.9242
[Hybrid Epoch 12] Loss: 1357.8120
[Hybrid Epoch 13] Loss: 1331.5165
[Hybrid Epoch 14] Loss: 1290.5261
[Hybrid Epoch 15] Loss: 1281.5550
[Hybrid Epoch 16] Loss: 1253.6949
[Hybrid Epoch 17] Loss: 1207.6277
[Hybrid Epoch 18] Loss: 1186.9655
[Hybrid Epoch 19] Loss: 1161.6037
[Hybrid Epoch 20] Loss: 1136.3213
[Hybrid Epoch 21] Loss: 1116.0588
[Hybrid Epoch 22] Loss: 1099.0876
[Hybrid Epoch 23] Loss: 1074.7043
[Hybrid Epoch 24] Loss: 1061.3066
[Hybrid Epoch 25] Loss: 1044.3156
[Hybrid Epoch 26] Loss: 1032.1249
[Hybrid Epoch 27] Loss: 1010.4622
[Hybrid Epoch 28] Loss: 995.1560
[Hybrid Epoch 29] Loss: 973.1376
[Hybrid Epoch 30] Loss: 9

In [22]:
# ====================
# Evaluate Student
# ====================
X_tensor = torch.tensor(X_student_test, dtype=torch.float32).to(device)
student.eval()
with torch.no_grad():
    logits = student(X_tensor)
    preds = logits.argmax(dim=1).cpu().numpy()

# Decode predicted block IDs
pred_block_ids = encoder.inverse_transform(preds)
true_block_ids = Y_test_row[:, pred_col].astype(int)

In [23]:
# Compute Localization Error
pred_coords = np.array([[block_info[b]['x'], block_info[b]['y']] for b in pred_block_ids])
# true_coords = np.array([[block_info[b]['x'], block_info[b]['y']] for b in true_block_ids])
true_coords = Y_test_row[:, -4:-2]  # assuming these are [x, y] columns
errors = np.linalg.norm(pred_coords - true_coords, axis=1)

# Metrics
acc = np.mean(pred_block_ids == true_block_ids)
mle = np.mean(errors)

print(f" Student MLE: {mle:.2f} meters")
print(f" 75th percentile error: {np.percentile(errors, 75):.2f}m")
print(f" 90th percentile error: {np.percentile(errors, 90):.2f}m")

 Student MLE: 7.72 meters
 75th percentile error: 10.31m
 90th percentile error: 13.78m
