In [10]:
import numpy as np
import polars as pl

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from transformers import BertTokenizer, BertModel

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity


The following required CPU features were not detected:
    ssse3, sse4.1, sse4.2, popcnt
Continuing to use this version of Polars on this processor will likely result in a crash.
Install the `polars-lts-cpu` package instead of `polars` to run Polars with better compatibility.

Hint: If you are on an Apple ARM machine (e.g. M1) this is likely due to running Python under Rosetta.
It is recommended to install a native version of Python that does not run under Rosetta x86-64 emulation.




In [11]:
class MatchingNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MatchingNetwork, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, support_set, query_set):
        # Encode support and query sets
        support_encoded = self.encoder(support_set)
        query_encoded = self.encoder(query_set)

        # Calculate cosine similarity
        similarity = F.cosine_similarity(query_encoded.unsqueeze(1), support_encoded.unsqueeze(0), dim=2)

        # Compute attention weights
        attention_weights = F.softmax(similarity, dim=1)

        # Weighted sum of support outputs
        weighted_output = torch.bmm(attention_weights.unsqueeze(1), support_set.unsqueeze(0)).squeeze(1)

        # Final classification
        output = self.output_layer(weighted_output)
        return output

In [12]:
# Generating a synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=5, n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [13]:
# Function to create support and query sets
def create_few_shot_sets(X, y, n_support=5, n_query=15):
    unique_classes = list(set(y))
    support_set = []
    query_set = []
    support_labels = []
    query_labels = []

    for cls in unique_classes:
        cls_indices = [i for i, label in enumerate(y) if label == cls]
        selected_indices = np.random.choice(cls_indices, size=n_support + n_query, replace=False)
        support_set.extend(X[selected_indices[:n_support]])
        query_set.extend(X[selected_indices[n_support:]])
        support_labels.extend([cls] * n_support)
        query_labels.extend([cls] * n_query)

    return torch.tensor(support_set, dtype=torch.float32), torch.tensor(query_set, dtype=torch.float32), torch.tensor(support_labels), torch.tensor(query_labels)

In [14]:
def train_matching_network(model, support_set, query_set, support_labels, query_labels, epochs=100, lr=0.001):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Forward pass
        outputs = model(support_set, query_set)

        # Compute loss
        loss = criterion(outputs, query_labels)
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            print(f'Epoch [{epoch}/{epochs}], Loss: {loss.item():.4f}')

In [15]:
def evaluate(model, support_set, query_set, support_labels, query_labels):
    model.eval()
    with torch.no_grad():
        outputs = model(support_set, query_set)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == query_labels).float().mean()
        print(f'Accuracy: {accuracy.item():.4f}')

In [16]:
# Create support and query sets
support_set, query_set, support_labels, query_labels = create_few_shot_sets(X_train, y_train)

# Instantiate the model
model = MatchingNetwork(input_dim=20, hidden_dim=64, output_dim=5)

  return torch.tensor(support_set, dtype=torch.float32), torch.tensor(query_set, dtype=torch.float32), torch.tensor(support_labels), torch.tensor(query_labels)


In [17]:
# Train the model
train_matching_network(model, support_set, query_set, support_labels, query_labels)

# Evaluate the model
evaluate(model, support_set, query_set, support_labels, query_labels)

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [75, 25] but got: [1, 25].