In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import pandas as pd
import numpy as np
import math
import os
import inspect
import csv

In [2]:
print("Acquiring datasets\n")

# All .npy inputs come from the dataset-creation step; change the location here only.
DATA_DIR = "../DatasetCreation"

train_query           = np.load(os.path.join(DATA_DIR, "training_dataset.npy"))
train_support         = np.load(os.path.join(DATA_DIR, "training_support_dataset.npy"))
# Per-query row index into train_support (get_batch reads it as query_neighbors[i]).
train_query_neighbors = np.load(os.path.join(DATA_DIR, "train_query_neighbors.npy"))

val_query             = np.load(os.path.join(DATA_DIR, "validation_dataset.npy"))
val_support           = np.load(os.path.join(DATA_DIR, "validation_support_dataset.npy"))
val_query_neighbors   = np.load(os.path.join(DATA_DIR, "val_query_neighbors.npy"))

print("Done. neighbor_array shape:", train_query_neighbors.shape)

Acquiring datasets

Done. neighbor_array shape: (198071,)


In [5]:
class SelfAttention(nn.Module):
    """
    Single-head scaled dot-product self-attention with no masking.

    Every position of a sequence attends over all positions of the same
    sequence, and the result has the same shape as the input:
    (B, T, hidden_dim) -> (B, T, hidden_dim).
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Query / key / value projections. Attribute names and creation order
        # are fixed so parameter initialisation and state_dict keys do not move.
        self.Wq = nn.Linear(hidden_dim, hidden_dim)
        self.Wk = nn.Linear(hidden_dim, hidden_dim)
        self.Wv = nn.Linear(hidden_dim, hidden_dim)
        # 1/sqrt(d) temperature applied to the raw dot-product scores.
        self.scale = math.sqrt(hidden_dim)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch_size, seq_len, hidden_dim).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim): attention-weighted
            sums of the value projections.
        """
        queries, keys, values = self.Wq(x), self.Wk(x), self.Wv(x)

        # Batched (B, T, d) @ (B, d, T) -> (B, T, T) similarity matrix.
        scores = queries @ keys.transpose(1, 2) / self.scale
        weights = scores.softmax(dim=-1)          # each row sums to 1 over keys

        # (B, T, T) @ (B, T, d) -> (B, T, d)
        return weights @ values



class FewShotLSTMAttn(nn.Module):
    """
    A direct few-shot baseline:
    1) difference = test_past - support_past
    2) LSTM hidden init from difference
    3) Self-attention on LSTM outputs
    4) Per-time-step linear read-out

    Args:
        hidden_dim: width of the LSTM hidden state and of the attention layer.
        num_layers: number of stacked LSTM layers. The projected difference
            vector is used as the initial hidden state of every layer.
        past_len: length of the past windows (test_past / support_past).
            Defaults to 192, the size this model was originally hard-coded for.
    """
    def __init__(self, hidden_dim=64, num_layers=1, past_len=192):
        super().__init__()
        # Maps the (B, past_len) difference vector to an initial hidden state.
        self.diff_to_hidden = nn.Linear(past_len, hidden_dim)

        self.lstm = nn.LSTM(
            input_size=1,         # support_future carries one feature per step: (..., 1)
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )

        self.attention = SelfAttention(hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, test_past, support_past, support_future):
        """
        Args:
            test_past: (B, past_len)
            support_past: (B, past_len)
            support_future: (B, T, 1). T is the number of LSTM steps; 192 in this notebook.

        Returns:
            (B, T, 1) prediction, one value per time step of support_future.
        """
        # 1) difference between the query's and the support's past windows
        diff_vec = test_past - support_past                 # (B, past_len)

        # 2) initial hidden/cell state from the difference.
        #    nn.LSTM requires (num_layers, B, hidden_dim). The projection is
        #    repeated for every layer; a bare unsqueeze(0) only works for
        #    num_layers == 1 and crashes otherwise.
        num_layers = self.lstm.num_layers
        h0 = self.diff_to_hidden(diff_vec)                  # (B, hidden_dim)
        h0 = h0.unsqueeze(0).repeat(num_layers, 1, 1)       # (num_layers, B, hidden_dim)
        c0 = torch.zeros_like(h0)                           # (num_layers, B, hidden_dim)

        # 3) run the LSTM over the support's future sequence
        lstm_out, _ = self.lstm(support_future, (h0, c0))   # (B, T, hidden_dim)

        # 4) self-attention over the LSTM outputs
        attn_out = self.attention(lstm_out)                 # (B, T, hidden_dim)

        # 5) final linear at each time step
        return self.fc_out(attn_out)                        # (B, T, 1)


In [None]:
# Example:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import csv

BATCH_SIZE = 16
EPOCHS = 5000
LEARNING_RATE = 1e-3
SEED = 42
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed both RNGs this run draws from: torch (parameter init; torch.manual_seed
# also seeds CUDA) and NumPy's global RNG (used by get_batch's np.random.choice).
# NOTE(review): cuDNN LSTM kernels can still be nondeterministic on GPU.
torch.manual_seed(SEED)
np.random.seed(SEED)

model = FewShotLSTMAttn(hidden_dim=64, num_layers=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# L1 is the loss that is backpropagated; L2 is only logged as a secondary metric.
criterion_mae = nn.L1Loss()
criterion_mse = nn.MSELoss()

def get_batch(dataset, batch_size, query_neighbors, support_dataset,
              target_device=None, past_len=192):
    """
    Sample a random mini-batch of query windows, each paired with its
    precomputed nearest support window.

    Args:
        dataset: 2-D array of query windows, one per row. Each row is split at
            ``past_len`` into past and future. The original comments say rows
            are 384 long; TODO confirm.
        batch_size: number of distinct query rows to draw (sampled without
            replacement from NumPy's global RNG).
        query_neighbors: 1-D integer array; ``query_neighbors[i]`` is the row of
            ``support_dataset`` paired with ``dataset[i]``.
        support_dataset: 2-D array of support windows, split the same way.
        target_device: torch device for the returned tensors. Defaults to the
            notebook-level ``device``.
        past_len: split point between past and future. Defaults to 192.

    Returns:
      x_test_tensor: (B, past_len) -> test_past
      y_test_tensor: (B, L - past_len, 1) -> test_future
      x_support_tensor: (B, past_len) -> support_past
      y_support_tensor: (B, L - past_len, 1) -> support_future
      All tensors are float32.

    Raises:
        ValueError: if ``query_neighbors`` and ``dataset`` differ in length,
            which would silently pair queries with the wrong support windows.
    """
    if target_device is None:
        # Fall back to the device chosen in the notebook's config cell.
        target_device = device
    if len(query_neighbors) != len(dataset):
        raise ValueError(
            f"query_neighbors has {len(query_neighbors)} entries but dataset has "
            f"{len(dataset)} rows; neighbor indices would be misaligned"
        )

    idxs = np.random.choice(len(dataset), batch_size, replace=False)

    # Gather every sampled row at once with integer-array indexing instead of a
    # per-row Python loop.
    query_rows = np.asarray(dataset)[idxs].astype(np.float32)                     # (B, L)
    neighbor_idxs = np.asarray(query_neighbors)[idxs]
    support_rows = np.asarray(support_dataset)[neighbor_idxs].astype(np.float32)  # (B, L)

    def _past(rows):
        # (B, past_len)
        return torch.tensor(rows[:, :past_len]).to(target_device)

    def _future(rows):
        # (B, L - past_len, 1): trailing feature axis expected by the LSTM input.
        return torch.tensor(rows[:, past_len:]).unsqueeze(-1).to(target_device)

    return _past(query_rows), _future(query_rows), _past(support_rows), _future(support_rows)
# Optimizer steps (random mini-batches) per logged "epoch". This is not a full
# pass over train_query.
STEPS_PER_EPOCH = 100
# Checkpoint schedule on the 0-based `epoch`: every CHECKPOINT_EVERY epochs,
# plus every epoch from SAVE_ALL_FROM_EPOCH onwards.
CHECKPOINT_EVERY = 100
SAVE_ALL_FROM_EPOCH = 2000

csv_filename = "few_shot_lstm_attn_log.csv"
# Start a fresh log (truncating any previous run's file) with a header row.
with open(csv_filename, "w", newline='') as f:
    csv.writer(f).writerow(["epoch", "MAE", "MSE"])

# NOTE(review): the val_* arrays loaded earlier are never evaluated in this
# loop, so the saved checkpoints cannot be compared on held-out error.
# TODO: add a validation pass under torch.no_grad() / model.eval().
model.train()
for epoch in range(EPOCHS):
    total_mae = 0.0
    total_mse = 0.0

    for _ in range(STEPS_PER_EPOCH):
        x_test, y_test, x_support, y_support = get_batch(
            train_query,            # query windows
            BATCH_SIZE,
            train_query_neighbors,  # precomputed nearest-support index per query
            train_support,          # support windows
        )

        optimizer.zero_grad()

        # forward
        y_pred = model(x_test, x_support, y_support)  # (B, 192, 1)

        # Only MAE is optimized. MSE is a logged metric, so it is computed on a
        # detached prediction instead of building an unused autograd graph.
        loss_mae = criterion_mae(y_pred, y_test)
        loss_mse = criterion_mse(y_pred.detach(), y_test)

        # backprop
        loss_mae.backward()
        optimizer.step()

        total_mae += loss_mae.item()
        total_mse += loss_mse.item()

    avg_mae = total_mae / STEPS_PER_EPOCH
    avg_mse = total_mse / STEPS_PER_EPOCH
    print(f"Epoch {epoch+1}/{EPOCHS} - MAE: {avg_mae:.4f}, MSE: {avg_mse:.4f}")

    # Append this epoch's averages to the CSV log.
    with open(csv_filename, "a", newline='') as f:
        csv.writer(f).writerow([epoch + 1, avg_mae, avg_mse])

    # NOTE(review): with EPOCHS = 5000 this writes ~3,000 checkpoint files
    # (1-based epochs 1, 101, ..., 1901, then every epoch 2001-5000).
    # Confirm this is intended before a long run.
    if epoch >= SAVE_ALL_FROM_EPOCH or epoch % CHECKPOINT_EVERY == 0:
        torch.save(model.state_dict(), f"few_shot_lstm_attn_epoch_{epoch+1}.pth")


Epoch 1/5000 - MAE: 767.7611, MSE: 947067.9475
Epoch 2/5000 - MAE: 639.6288, MSE: 702158.1728
Epoch 3/5000 - MAE: 513.6589, MSE: 440828.4197
Epoch 4/5000 - MAE: 378.7358, MSE: 246965.9924
Epoch 5/5000 - MAE: 198.6490, MSE: 88748.9530
Epoch 6/5000 - MAE: 127.9689, MSE: 39773.8005
Epoch 7/5000 - MAE: 114.7482, MSE: 33059.4392
Epoch 8/5000 - MAE: 107.3390, MSE: 26810.4762
Epoch 9/5000 - MAE: 109.6244, MSE: 27803.9835
Epoch 10/5000 - MAE: 102.8880, MSE: 24663.0698
Epoch 11/5000 - MAE: 103.7564, MSE: 25952.9757
Epoch 12/5000 - MAE: 106.5233, MSE: 26680.1830
Epoch 13/5000 - MAE: 100.2712, MSE: 23901.5317
Epoch 14/5000 - MAE: 102.7350, MSE: 25787.1285
Epoch 15/5000 - MAE: 104.7859, MSE: 25998.0527
Epoch 16/5000 - MAE: 99.5965, MSE: 23354.1059
Epoch 17/5000 - MAE: 101.1055, MSE: 23364.3953
Epoch 18/5000 - MAE: 102.1044, MSE: 23818.0264
Epoch 19/5000 - MAE: 101.4330, MSE: 25895.1333
Epoch 20/5000 - MAE: 102.4132, MSE: 24396.6098
Epoch 21/5000 - MAE: 102.2964, MSE: 24152.5214
Epoch 22/5000 - MAE