In [4]:
import torch
import torch.nn.functional as F
import numpy as np

def _as_detached_tensor(values, dtype):
    """Convert `values` (tensor, array or sequence) to a detached tensor of `dtype`.

    Unlike ``torch.tensor(existing_tensor)``, this does not emit the
    "To copy construct from a tensor..." UserWarning.
    """
    if isinstance(values, torch.Tensor):
        return values.detach().to(dtype)
    return torch.as_tensor(values, dtype=dtype)


def load_dataset(path, labeled=True):
    """Load a saved feature dataset, optionally together with its labels.

    Args:
        path: File written with ``torch.save`` holding a dict-like object with a
            ``'features'`` entry and, for labeled data, a ``'labels'`` entry.
        labeled: If True, also load and return the labels.

    Returns:
        ``features`` as a float32 tensor when ``labeled`` is False; otherwise the
        tuple ``(features, labels)`` with ``labels`` as an int64 tensor.

    Raises:
        ValueError: If ``'features'`` (or ``'labels'`` when ``labeled``) is missing.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted data.
    data = torch.load(path)
    features = data.get('features', None)
    if features is None:
        raise ValueError(f"Features missing in dataset: {path}")
    features = _as_detached_tensor(features, torch.float32)

    if labeled:
        labels = data.get('labels', None)
        if labels is None:
            raise ValueError(f"Labels missing in labeled dataset: {path}")
        labels = _as_detached_tensor(labels, torch.long)
        return features, labels
    return features

def calculate_prototypes(features, labels, num_classes=10):
    """Compute one mean-feature prototype per class.

    Args:
        features: Tensor whose first dimension indexes samples; row ``k``
            belongs to class ``labels[k]``.
        labels: 1-D integer tensor of class ids aligned with ``features``.
        num_classes: Number of prototypes to produce (classes ``0..num_classes-1``).

    Returns:
        Tensor of shape ``(num_classes, *features.shape[1:])``. Row ``c`` is the
        mean of the samples labeled ``c``, or all zeros if class ``c`` has no samples.
    """
    prototypes = []
    for c in range(num_classes):
        class_features = features[labels == c]
        if class_features.shape[0] > 0:
            prototypes.append(class_features.mean(dim=0))
        else:
            # Same dtype/device/trailing shape as the real means, so torch.stack
            # neither fails (e.g. CUDA features) nor silently changes dtype.
            # NOTE(review): an all-zero prototype is still a cdist target in
            # predict() and can attract low-norm samples.
            prototypes.append(features.new_zeros(features.shape[1:]))
    return torch.stack(prototypes)

def predict(features, prototypes):
    """Assign each sample to its nearest prototype under Euclidean distance.

    Args:
        features: 2-D tensor of sample features, presumably ``(N, D)``.
        prototypes: 2-D tensor of class prototypes ``(C, D)``; may be an
            ``nn.Parameter``.

    Returns:
        Int64 tensor of length ``N`` with the index of the closest prototype
        for each sample.
    """
    # argmin is non-differentiable, so tracking gradients through cdist for a
    # trainable `prototypes` would only build an unused autograd graph.
    with torch.no_grad():
        # NOTE(review): prototype_contrastive_loss trains with cosine similarity,
        # but prediction uses L2 distance; the two can disagree when feature
        # norms vary. Confirm this mismatch is intended.
        distances = torch.cdist(features, prototypes, p=2)
    return torch.argmin(distances, dim=1)

def prototype_contrastive_loss(features, pseudo_labels, prototypes, temperature=0.1):
    """Mean cross-entropy over temperature-scaled cosine similarities to prototypes.

    Every sample is compared with every prototype by cosine similarity. The
    similarities divided by ``temperature`` act as class logits and are scored
    against ``pseudo_labels`` with ``F.cross_entropy``.
    """
    # Broadcast samples against prototypes (assuming 2-D inputs: N x 1 x D vs 1 x C x D).
    per_sample = features.unsqueeze(1)
    per_prototype = prototypes.unsqueeze(0)
    similarity = F.cosine_similarity(per_sample, per_prototype, dim=2)
    scaled_logits = similarity / temperature
    return F.cross_entropy(scaled_logits, pseudo_labels)

def knowledge_distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Temperature-softened KL divergence KL(teacher || student).

    Both logit sets are divided by ``temperature`` and softmaxed over dim 1. The
    KL is averaged per batch row (``reduction="batchmean"``) and multiplied by
    ``temperature ** 2`` (the usual Hinton-style scaling).
    """
    soft_targets = F.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    divergence = F.kl_div(student_log_probs, soft_targets, reduction="batchmean")
    scale = temperature ** 2
    return divergence * scale

def train_labeled_dataset(features, labels, num_classes=10):
    """Build the initial trainable prototypes from a labeled dataset.

    No gradient training happens here. The prototypes are the per-class feature
    means from ``calculate_prototypes``, wrapped in an ``nn.Parameter`` so later
    stages can optimize them.
    """
    class_means = calculate_prototypes(features, labels, num_classes=num_classes)
    trainable = torch.nn.Parameter(class_means)
    return trainable

def continual_learning(features, prototypes, prev_model, pseudo_labels, num_classes=10, lr=0.01, num_iters=10):
    """Refine prototypes on one task using pseudo-label and distillation losses.

    Runs ``num_iters`` full-batch SGD steps on
    ``prototype_contrastive_loss(features, pseudo_labels, prototypes)``. When
    ``prev_model`` is given, it adds ``knowledge_distillation_loss`` between the
    cosine-similarity logits of the current and the previous prototypes.

    Args:
        features: Features of the current task's samples.
        prototypes: Trainable prototypes (a leaf tensor requiring grad, e.g. an
            ``nn.Parameter``). Updated in place by the optimizer and also returned.
        prev_model: Frozen prototypes from the previous task (distillation
            teacher), or ``None`` to skip distillation.
        pseudo_labels: Target class ids for ``features``.
        num_classes: Unused; kept for signature compatibility.
        lr: SGD learning rate.
        num_iters: Number of SGD steps (default 10, the original fixed count).

    Returns:
        The same ``prototypes`` object after optimization.
    """
    optimizer = torch.optim.SGD([prototypes], lr=lr)

    # The teacher depends only on `features` and the frozen `prev_model`, neither
    # of which changes inside the loop, so compute its logits once.
    teacher_logits = None
    if prev_model is not None:
        with torch.no_grad():
            teacher_logits = F.cosine_similarity(features.unsqueeze(1), prev_model.unsqueeze(0), dim=2)

    for _ in range(num_iters):
        optimizer.zero_grad()
        loss = prototype_contrastive_loss(features, pseudo_labels, prototypes)
        if teacher_logits is not None:
            # NOTE(review): raw cosine similarities lie in [-1, 1]. With the default
            # KD temperature of 2.0 the teacher distribution is fairly flat, so
            # this term is weak. Confirm this is intended.
            student_logits = F.cosine_similarity(features.unsqueeze(1), prototypes.unsqueeze(0), dim=2)
            loss = loss + knowledge_distillation_loss(student_logits, teacher_logits)
        loss.backward()
        optimizer.step()
    return prototypes

def evaluate(features, labels, prototypes):
    """Return nearest-prototype accuracy: the fraction of samples whose predicted class equals `labels`."""
    correct = predict(features, prototypes).eq(labels)
    return correct.float().mean().item()

def continual_learning_unlabeled(data_path, initial_prototypes, num_datasets=10, eval_data_path=None):
    """Continual learning on a sequence of unlabeled datasets using pseudo-labels.

    For each task i = 1..num_datasets:
      1. load ``{data_path}/features_dataset_{i}.pth`` (features only),
      2. pseudo-label it with the current prototypes via ``predict``,
      3. refine the prototypes with ``continual_learning``, distilling from the
         prototypes saved after task i-1 (no distillation for task 1),
      4. evaluate on every labeled eval set seen so far (j = 1..i).

    Args:
        data_path: Directory with the unlabeled training files.
        initial_prototypes: Starting prototypes. They are copied, so the
            caller's tensor is not modified.
        num_datasets: Number of tasks to process.
        eval_data_path: Directory with ``eval_features_dataset_{j}.pth``.
            Defaults to ``data_path``.

    Returns:
        ``(num_datasets, num_datasets)`` numpy array. Entry ``[i-1, j-1]`` is the
        accuracy on eval set j after training on task i. Entries with j > i are
        never evaluated and stay 0.
    """
    if eval_data_path is None:
        eval_data_path = data_path

    # continual_learning updates its prototypes in place; work on a private copy.
    prototypes = torch.nn.Parameter(initial_prototypes.detach().clone())
    prev_prototypes = None
    accuracy_matrix = np.zeros((num_datasets, num_datasets))
    # Each eval set is re-evaluated after every later task. Load it from disk
    # once and keep it (trades memory for avoiding O(n^2) file loads).
    eval_sets = {}

    for i in range(1, num_datasets + 1):
        train_features = load_dataset(f"{data_path}/features_dataset_{i}.pth", labeled=False)
        pseudo_labels = predict(train_features, prototypes)
        prototypes = continual_learning(train_features, prototypes, prev_prototypes, pseudo_labels)
        prev_prototypes = prototypes.detach().clone()

        for j in range(1, i + 1):
            if j not in eval_sets:
                eval_sets[j] = load_dataset(f"{eval_data_path}/eval_features_dataset_{j}.pth")
            eval_features, eval_labels = eval_sets[j]
            accuracy_matrix[i - 1, j - 1] = evaluate(eval_features, eval_labels, prototypes)

    return accuracy_matrix

def continual_learning_pipeline(data_path_1, data_path_2, eval_data_path, num_datasets=10):
    """Initialize prototypes from labeled data, then adapt them on unlabeled tasks.

    Args:
        data_path_1: Directory with the labeled ``features_dataset_1.pth`` used to
            compute the initial prototypes.
        data_path_2: Directory with the unlabeled ``features_dataset_{i}.pth`` tasks.
        eval_data_path: Directory with the labeled ``eval_features_dataset_{j}.pth``
            files; forwarded to ``continual_learning_unlabeled``.
        num_datasets: Number of unlabeled tasks.

    Returns:
        The accuracy matrix produced by ``continual_learning_unlabeled``.
    """
    print(f"Training on labeled data from {data_path_1}...")
    train_features, train_labels = load_dataset(f"{data_path_1}/features_dataset_1.pth")
    initial_prototypes = train_labeled_dataset(train_features, train_labels)

    print(f"Continual learning on {data_path_2} (unlabeled)...")
    accuracy_matrix = continual_learning_unlabeled(
        data_path_2, initial_prototypes, num_datasets, eval_data_path=eval_data_path
    )

    return accuracy_matrix

# Experiment configuration. The labeled set seeds the prototypes; the unlabeled
# tasks and their eval files share one directory.
data_path_1 = "saved_data"
data_path_2 = "saved_data2"
eval_data_path = data_path_2
num_datasets = 10

accuracy_matrix = continual_learning_pipeline(
    data_path_1,
    data_path_2,
    eval_data_path,
    num_datasets=num_datasets,
)
print("Accuracy Matrix for Saved Data 2:\n", accuracy_matrix)


Training on labeled data from saved_data...
Continual learning on saved_data2 (unlabeled)...


  data = torch.load(path)
  features = torch.tensor(features, dtype=torch.float32)
  labels = torch.tensor(labels, dtype=torch.long)


Accuracy Matrix for Saved Data 2:
 [[0.69440001 0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.69400001 0.53839999 0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.69400001 0.53839999 0.74720001 0.         0.         0.
  0.         0.         0.         0.        ]
 [0.69400001 0.53839999 0.74720001 0.79159999 0.         0.
  0.         0.         0.         0.        ]
 [0.69400001 0.53839999 0.74720001 0.79159999 0.83039999 0.
  0.         0.         0.         0.        ]
 [0.69400001 0.53839999 0.74720001 0.79159999 0.82999998 0.71359998
  0.         0.         0.         0.        ]
 [0.69400001 0.53839999 0.74720001 0.79159999 0.82999998 0.71359998
  0.72320002 0.         0.         0.        ]
 [0.69400001 0.53839999 0.74720001 0.79159999 0.82999998 0.71280003
  0.72280002 0.7252     0.         0.        ]
 [0.69400001 0.53839999 0.74720001 0.79159999 0.82999998 0.71280003
  0.72280002 0.72