In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision.transforms import Resize
from sklearn.metrics import accuracy_score
import numpy as np
import os

# Define FeatureExtractor
class FeatureExtractor(nn.Module):
    """EfficientNet-B0 backbone with its classification head removed.

    The pretrained network is put in eval mode and its ``classifier`` is
    replaced by ``nn.Identity``, so a forward pass returns whatever the
    backbone would otherwise have fed into the classifier.
    """

    def __init__(self):
        super().__init__()
        backbone = models.efficientnet_b0(pretrained=True)
        backbone.eval()
        # Swap the head only after eval(), mirroring the original setup order.
        backbone.classifier = nn.Identity()
        self.efficientnet = backbone

    def forward(self, x):
        """Resize ``x`` to 160x160 and return the backbone output for it."""
        resized = Resize((160, 160))(x)
        return self.efficientnet(resized)


# Function to compute prototypes
def compute_prototypes(features, labels, n_classes):
    """Return the stacked per-class mean feature vectors.

    For every class index in ``range(n_classes)`` the rows of ``features``
    whose label equals that index are averaged along dimension 0.

    NOTE: classes without any sample are skipped, so the result can have
    fewer than ``n_classes`` rows and row ``k`` is then not guaranteed to
    belong to class ``k``. ``torch.stack`` raises if no class has samples.
    """
    per_class = (features[labels == class_idx] for class_idx in range(n_classes))
    class_means = [group.mean(0) for group in per_class if len(group) > 0]
    return torch.stack(class_means)


# Continual Learner Class
class ContinualLearner:
    """Nearest-prototype classifier whose prototypes are adapted with pseudo-labels.

    Each class is represented by one prototype vector. Samples are assigned to
    the class of the closest prototype (``torch.cdist`` + ``argmin``). Extracted
    features are cached as ``.pth`` files under ``output_dir`` so that repeated
    runs can skip feature extraction.

    Args:
        feature_extractor: Callable mapping a batch of images to a feature
            tensor with one row per sample.
        n_classes: Number of prototype rows produced by ``update_prototypes``.
        output_dir: Directory for cached features and saved prototypes. It is
            created if it does not exist.
    """

    def __init__(self, feature_extractor, n_classes=10, output_dir="saved_data2"):
        self.feature_extractor = feature_extractor
        self.n_classes = n_classes
        # Filled by load_initial_prototypes() / train_iteration().
        self.prototypes = None
        self.output_dir = output_dir

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.output_dir, exist_ok=True)

    def _require_prototypes(self):
        """Raise a descriptive error when no prototypes are available yet."""
        if self.prototypes is None:
            raise RuntimeError(
                "Prototypes are not initialised; call load_initial_prototypes() "
                "or train_iteration(..., initial=True) first."
            )

    def extract_features(self, dataloader):
        """Run the feature extractor over ``dataloader`` without gradients.

        Batches that are tuples/lists are unpacked as ``(images, labels)``;
        any other batch is treated as images only.

        Returns:
            ``(features, labels)`` where ``features`` is the concatenation of
            all batch outputs and ``labels`` is the concatenation of all batch
            labels, or ``None`` when no batch carried labels.
        """
        all_features = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                if isinstance(batch, (tuple, list)):
                    images, labels = batch
                    all_labels.append(labels)
                else:
                    images = batch
                all_features.append(self.feature_extractor(images))

        features = torch.cat(all_features)
        labels = torch.cat(all_labels) if all_labels else None
        return features, labels

    def save_features(self, features, labels, dataset_id):
        """Save extracted training features (and labels, if any) to a file."""
        torch.save(
            {"features": features.cpu(), "labels": labels.cpu() if labels is not None else None},
            os.path.join(self.output_dir, f"features_dataset_{dataset_id}.pth"),
        )

    def save_prototypes(self):
        """Save the current prototypes to ``final_prototypes.pth``.

        Raises:
            RuntimeError: If no prototypes have been loaded or computed yet.
        """
        self._require_prototypes()
        torch.save(self.prototypes.cpu(), os.path.join(self.output_dir, "final_prototypes.pth"))

    def assign_pseudo_labels(self, features):
        """Assign each sample the index of its nearest prototype.

        Returns:
            ``(pseudo_labels, confidence_scores)``; the confidence is the
            largest softmax probability over the negated distances.

        Raises:
            RuntimeError: If no prototypes have been loaded or computed yet.
        """
        self._require_prototypes()
        distances = torch.cdist(features, self.prototypes)
        pseudo_labels = torch.argmin(distances, dim=1)
        confidence_scores = torch.softmax(-distances, dim=1).max(dim=1)[0]
        return pseudo_labels, confidence_scores

    def update_prototypes(self, features, pseudo_labels, confidence_scores, threshold=0):
        """Compute candidate prototypes from confidently pseudo-labelled samples.

        For each class the mean of the samples with that pseudo-label and a
        confidence strictly above ``threshold`` is used; a class without such
        samples keeps its current prototype. With the default ``threshold=0``
        every sample qualifies, because the maximum of a softmax is positive.

        Returns:
            Tensor with ``n_classes`` rows. ``self.prototypes`` is not modified.
        """
        new_prototypes = []
        for class_idx in range(self.n_classes):
            mask = (pseudo_labels == class_idx) & (confidence_scores > threshold)
            if mask.sum() > 0:
                new_prototypes.append(features[mask].mean(0))
            else:
                new_prototypes.append(self.prototypes[class_idx])
        return torch.stack(new_prototypes)

    def load_initial_prototypes(self, path="saved_data/final_prototypes.pth"):
        """Load initial prototypes from task 1"""
        # NOTE: torch.load unpickles the file; only load files you trust.
        self.prototypes = torch.load(path)
        print("Initial prototypes loaded from task 1")

    def _load_or_extract(self, dataloader, dataset_id, evaluation):
        """Return ``(features, labels)``, served from the on-disk cache if possible.

        Shared by ``train_iteration`` (``evaluation=False``) and ``evaluate``
        (``evaluation=True``). With a ``dataset_id`` the matching cache file is
        loaded when present; otherwise features are extracted and then saved
        through ``save_features`` / ``save_eval_features``. Without a
        ``dataset_id`` features are always extracted and never cached.

        Raises:
            ValueError: If ``evaluation`` is true and no labels are available.
        """
        saved_path = None
        if dataset_id is not None:
            prefix = "eval_features" if evaluation else "features"
            saved_path = os.path.join(self.output_dir, f"{prefix}_dataset_{dataset_id}.pth")

        if saved_path is not None and os.path.exists(saved_path):
            if evaluation:
                print(f"Loading saved features for evaluation dataset {dataset_id} from {saved_path}")
            else:
                print(f"Loading saved training features for dataset {dataset_id} from {saved_path}")
            # NOTE: torch.load unpickles the file; the cache is assumed trusted.
            saved_data = torch.load(saved_path)
            features, labels = saved_data["features"], saved_data["labels"]
            needs_save = False
        else:
            if saved_path is None:
                print("Dataset ID not provided, extracting features directly...")
            elif evaluation:
                print(f"Saved features for dataset {dataset_id} not found. Extracting features...")
            else:
                print(f"Saved features for training dataset {dataset_id} not found. Extracting features...")
            features, labels = self.extract_features(dataloader)
            needs_save = saved_path is not None

        # Check before saving so an unusable evaluation cache is never written.
        if evaluation and labels is None:
            raise ValueError(
                "Evaluation requires labels; the dataloader must yield (images, labels) batches."
            )

        if needs_save:
            # Save extracted features for future use.
            if evaluation:
                self.save_eval_features(features, labels, dataset_id)
            else:
                self.save_features(features, labels, dataset_id)
        return features, labels

    def train_iteration(self, train_loader, dataset_id=None, initial=False, prototypes_path=None):
        """Train on a new dataset, using saved features if available.

        Args:
            train_loader: Iterable of image batches (labels are optional).
            dataset_id: Cache key; when given, features are loaded from or
                saved to ``features_dataset_{dataset_id}.pth``.
            initial: If true, prototypes are loaded from disk instead of being
                updated from this dataset.
            prototypes_path: Optional prototype file used when ``initial`` is
                true; ``None`` keeps the default of ``load_initial_prototypes``.

        Returns:
            ``(features, labels)`` of the training data.

        Raises:
            RuntimeError: If ``initial`` is false and no prototypes exist yet.
        """
        if not initial:
            # Fail fast instead of extracting features and then crashing in cdist.
            self._require_prototypes()

        features, labels = self._load_or_extract(train_loader, dataset_id, evaluation=False)

        if initial:
            # For the first dataset in task 2, use initial prototypes from task 1.
            if prototypes_path is None:
                self.load_initial_prototypes()
            else:
                self.load_initial_prototypes(prototypes_path)
        else:
            # For subsequent datasets, blend the old prototypes (weight 0.9)
            # with the pseudo-label estimate from this dataset (weight 0.1).
            pseudo_labels, confidence_scores = self.assign_pseudo_labels(features)
            self.prototypes = 0.9 * self.prototypes + 0.1 * self.update_prototypes(
                features, pseudo_labels, confidence_scores
            )

        return features, labels

    def save_eval_features(self, features, labels, dataset_id):
        """Save extracted evaluation features and labels to a file."""
        torch.save(
            {"features": features.cpu(), "labels": labels.cpu()},
            os.path.join(self.output_dir, f"eval_features_dataset_{dataset_id}.pth"),
        )

    def evaluate(self, eval_loader, dataset_id=None):
        """Evaluate on a labeled dataset, using saved features if available.

        Args:
            eval_loader: Iterable of ``(images, labels)`` batches.
            dataset_id: Cache key; when given, features are loaded from or
                saved to ``eval_features_dataset_{dataset_id}.pth``.

        Returns:
            Accuracy of nearest-prototype predictions as computed by
            ``accuracy_score``.

        Raises:
            RuntimeError: If no prototypes have been loaded or computed yet.
            ValueError: If the evaluation data carries no labels.
        """
        self._require_prototypes()
        features, labels = self._load_or_extract(eval_loader, dataset_id, evaluation=True)

        # Perform evaluation using extracted or loaded features.
        distances = torch.cdist(features, self.prototypes)
        predictions = torch.argmin(distances, dim=1)
        accuracy = accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())
        return accuracy


# Function to load dataset
def load_dataset(path, has_labels=False, batch_size=64):
    """Load a saved dataset, normalise the images and wrap them in a DataLoader.

    The file at ``path`` is read with ``torch.load`` and must provide a
    ``"data"`` entry that is permuted from channels-last to channels-first
    (``permute(0, 3, 1, 2)``), scaled by 1/255 and normalised per channel with
    the ImageNet mean/std. Pixel values are presumably on a 0-255 scale;
    verify against the dataset files.

    Args:
        path: Location of the saved dataset.
        has_labels: If true, ``"targets"`` is read as well and the loader
            yields shuffled ``(images, labels)`` batches; otherwise it yields
            image batches in file order.
        batch_size: Number of samples per batch (defaults to the previously
            hard-coded 64).

    Returns:
        A ``DataLoader`` over the normalised images (and labels).
    """
    # Local import keeps this function self-contained.
    from torch.utils.data import TensorDataset

    # NOTE: torch.load unpickles the file; only load datasets you trust.
    data = torch.load(path)
    # as_tensor avoids the copy-construct warning when "data" is already a tensor.
    images = torch.as_tensor(data["data"], dtype=torch.float32).permute(0, 3, 1, 2)

    # Normalize using ImageNet statistics
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    images = (images / 255.0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)

    if has_labels:
        labels = torch.as_tensor(data["targets"], dtype=torch.long)
        # TensorDataset indexes the tensors lazily and rejects mismatched
        # lengths, where list(zip(...)) built N tuples and silently truncated.
        return DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=True)
    return DataLoader(images, batch_size=batch_size, shuffle=False)


# Training on Subsequent Datasets
def train_subsequent_datasets(base_path, eval_base_path, learner, num_datasets=10):
    """Train on datasets D1..D{num_datasets} in order and record accuracies.

    For each dataset ``i`` the learner is updated from the training split
    (``initial=True`` only for ``i == 1``, which loads the starting
    prototypes) and then evaluated on the evaluation splits of datasets
    ``1..i``. The final prototypes are saved once all datasets are processed.

    Args:
        base_path: Directory containing ``{i}_train_data.tar.pth`` files.
        eval_base_path: Directory containing ``{j}_eval_data.tar.pth`` files.
        learner: Object providing ``train_iteration``, ``evaluate`` and
            ``save_prototypes``.
        num_datasets: Number of datasets to process.

    Returns:
        Array of shape ``(num_datasets, num_datasets)`` where entry
        ``[i - 1, j - 1]`` is the accuracy on Dj after training on Di;
        entries with ``j > i`` remain 0.
    """
    # One row per trained model, one column per evaluation dataset. Sizing
    # both from num_datasets keeps the matrix and the loops in agreement.
    accuracies = np.zeros((num_datasets, num_datasets))

    for i in range(1, num_datasets + 1):
        print(f"\nProcessing dataset D{i}")

        # Load current training dataset (first iteration uses initial prototypes from task 1)
        train_data = load_dataset(f"{base_path}/{i}_train_data.tar.pth", has_labels=False)

        # Train on current dataset (initial is True only for the first dataset).
        # With a dataset_id, train_iteration itself loads or saves the features
        # file, so no additional save_features call is needed here.
        learner.train_iteration(train_data, dataset_id=i, initial=(i == 1))

        # Evaluate on every dataset seen so far, including the current one.
        for j in range(1, i + 1):
            eval_data = load_dataset(f"{eval_base_path}/{j}_eval_data.tar.pth", has_labels=True)
            accuracy = learner.evaluate(eval_data, dataset_id=j)
            accuracies[i - 1, j - 1] = accuracy
            print(f"Model {i} accuracy on D{j}: {accuracy:.4f}")

    # Save final prototypes after training on all datasets
    learner.save_prototypes()

    return accuracies


# Main Script
if __name__ == "__main__":
    # Locations of the part-two evaluation and training splits.
    eval_base_path = "dataset/part_two_dataset/eval_data"
    base_path = "dataset/part_two_dataset/train_data"

    # Build the feature extractor, then the learner that caches into "saved_data2".
    feature_extractor = FeatureExtractor()
    learner = ContinualLearner(feature_extractor, output_dir="saved_data2")

    num_datasets = 10
    print("Starting continual learning training on subsequent datasets...")
    accuracies = train_subsequent_datasets(
        base_path,
        eval_base_path,
        learner,
        num_datasets,
    )

    # Rows are models, columns are evaluation datasets.
    print("\nFinal Accuracy Matrix:")
    print(accuracies)



Starting continual learning training on subsequent datasets...

Processing dataset D1
Loading saved training features for dataset 1 from saved_data2\features_dataset_1.pth
Initial prototypes loaded from task 1
Loading saved features for evaluation dataset 1 from saved_data2\eval_features_dataset_1.pth
Model 1 accuracy on D1: 0.6836

Processing dataset D2
Loading saved training features for dataset 2 from saved_data2\features_dataset_2.pth
Loading saved features for evaluation dataset 1 from saved_data2\eval_features_dataset_1.pth
Model 2 accuracy on D1: 0.6820
Loading saved features for evaluation dataset 2 from saved_data2\eval_features_dataset_2.pth
Model 2 accuracy on D2: 0.5348

Processing dataset D3
Loading saved training features for dataset 3 from saved_data2\features_dataset_3.pth
Loading saved features for evaluation dataset 1 from saved_data2\eval_features_dataset_1.pth
Model 3 accuracy on D1: 0.6804
Loading saved features for evaluation dataset 2 from saved_data2\eval_featur