In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision.transforms import Resize
from sklearn.metrics import accuracy_score
import numpy as np
import os

class FeatureExtractor(nn.Module):
    """EfficientNet-B0 backbone with its classifier head replaced by identity.

    The backbone is put in eval mode at construction time and its
    ``classifier`` is swapped for ``nn.Identity``, so ``forward`` returns
    whatever the network produces just before the original classifier.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): recent torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True`` - confirm against the installed
        # version before changing it.
        self.efficientnet = models.efficientnet_b0(pretrained=True)
        self.efficientnet.eval()
        self.efficientnet.classifier = nn.Identity()
        # Build the resize transform once instead of on every forward pass.
        self.resize = Resize((160, 160))

    def forward(self, x):
        """Resize ``x`` to 160x160 and return the backbone's output for it."""
        return self.efficientnet(self.resize(x))


def compute_prototypes(features, labels, n_classes):
    """Return one mean-feature prototype per class, indexed by class id.

    Args:
        features: Tensor whose first dimension indexes samples.
        labels: Tensor of integer class ids, one per row of ``features``.
        n_classes: Number of classes; the result always has this many rows.

    Returns:
        Tensor of shape ``(n_classes, *features.shape[1:])`` where row ``c``
        is the mean of the features labelled ``c``.

    A class with no samples gets an all-zero placeholder row. Skipping such
    classes would shift every later prototype to a smaller row index, so the
    row picked by a nearest-prototype ``argmin`` would no longer be the class
    id, and the result could not be combined with other ``n_classes``-row
    prototype tensors.
    """
    placeholder_shape = features.shape[1:]
    prototypes = []
    for class_idx in range(n_classes):
        class_features = features[labels == class_idx]
        if len(class_features) > 0:
            prototypes.append(class_features.mean(0))
        else:
            # Keep row index == class index even when the class is absent.
            prototypes.append(features.new_zeros(placeholder_shape))
    return torch.stack(prototypes)


class ContinualLearner:
    """Nearest-prototype classifier that is updated dataset by dataset.

    The first dataset initialises one prototype per class from true labels
    (via ``compute_prototypes``). Later datasets are pseudo-labelled with the
    current prototypes, and the prototypes are blended with the pseudo-label
    class means. Extracted features can be cached on disk under
    ``output_dir`` so the feature extractor need not be re-run.

    Attributes:
        feature_extractor: Callable applied to each image batch; its outputs
            are concatenated along dim 0.
        n_classes: Number of classes / prototype rows.
        prototypes: Current prototype tensor, or ``None`` before the initial
            training iteration.
        output_dir: Directory used for cached features and saved prototypes.
    """

    def __init__(self, feature_extractor, n_classes=10, output_dir="saved_data"):
        self.feature_extractor = feature_extractor
        self.n_classes = n_classes
        self.prototypes = None
        self.output_dir = output_dir

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _require_prototypes(self):
        """Return the current prototypes or raise if none have been computed."""
        if self.prototypes is None:
            raise RuntimeError(
                "Prototypes are not initialised; call train_iteration(..., initial=True) first."
            )
        return self.prototypes

    @staticmethod
    def _save_tensors(features, labels, path):
        """Write features (and labels, which may be ``None``) to ``path`` on CPU."""
        torch.save(
            {"features": features.cpu(), "labels": labels.cpu() if labels is not None else None},
            path,
        )

    def _load_or_extract(self, dataloader, dataset_id, cache_prefix, require_labels=False):
        """Return ``(features, labels)`` from the on-disk cache or the dataloader.

        With ``dataset_id=None`` no cache is read or written. Otherwise the
        file ``{cache_prefix}_dataset_{dataset_id}.pth`` in ``output_dir`` is
        loaded when present, or created after extraction when absent.

        Raises:
            ValueError: if ``require_labels`` is set and no labels are
                available. Raised before anything is written to the cache.
        """
        cache_path = None
        if dataset_id is not None:
            cache_path = os.path.join(
                self.output_dir, f"{cache_prefix}_dataset_{dataset_id}.pth"
            )

        from_cache = cache_path is not None and os.path.exists(cache_path)
        if from_cache:
            saved_data = torch.load(cache_path)
            features, labels = saved_data["features"], saved_data["labels"]
        else:
            features, labels = self.extract_features(dataloader)

        if require_labels and labels is None:
            raise ValueError("This operation needs labels, but the data provides none.")

        if cache_path is not None and not from_cache:
            self._save_tensors(features, labels, cache_path)
        return features, labels

    # ------------------------------------------------------------------
    # Feature extraction and persistence
    # ------------------------------------------------------------------
    def extract_features(self, dataloader):
        """Run the feature extractor over every batch with gradients disabled.

        A batch that is a tuple or list is unpacked as ``(images, labels)``;
        any other batch is treated as images only.

        Returns:
            ``(features, labels)`` where ``labels`` is ``None`` if no batch
            carried labels.
        """
        all_features = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                if isinstance(batch, (tuple, list)):
                    images, labels = batch
                    all_labels.append(labels)
                else:
                    images = batch
                all_features.append(self.feature_extractor(images))

        features = torch.cat(all_features)
        labels = torch.cat(all_labels) if all_labels else None
        return features, labels

    def save_features(self, features, labels, dataset_id):
        """Save training features (labels may be ``None``) for ``dataset_id``."""
        self._save_tensors(
            features,
            labels,
            os.path.join(self.output_dir, f"features_dataset_{dataset_id}.pth"),
        )

    def save_eval_features(self, features, labels, dataset_id):
        """Save evaluation features (labels may be ``None``) for ``dataset_id``."""
        self._save_tensors(
            features,
            labels,
            os.path.join(self.output_dir, f"eval_features_dataset_{dataset_id}.pth"),
        )

    def save_prototypes(self):
        """Save the current prototypes to a file."""
        prototypes = self._require_prototypes()
        torch.save(prototypes.cpu(), os.path.join(self.output_dir, "final_prototypes.pth"))

    # ------------------------------------------------------------------
    # Prototype logic
    # ------------------------------------------------------------------
    def assign_pseudo_labels(self, features):
        """Assign pseudo-labels using current prototypes.

        Returns:
            ``(pseudo_labels, confidence_scores)``: the index of the nearest
            prototype per sample, and the largest softmax value over the
            negated distances.
        """
        distances = torch.cdist(features, self._require_prototypes())
        pseudo_labels = torch.argmin(distances, dim=1)
        confidence_scores = torch.softmax(-distances, dim=1).max(dim=1)[0]
        return pseudo_labels, confidence_scores

    def update_prototypes(self, features, pseudo_labels, confidence_scores, threshold=0):
        """Compute per-class means from samples whose confidence exceeds ``threshold``.

        A class with no qualifying sample keeps its current prototype.
        ``self.prototypes`` is not modified; the new tensor is returned.
        """
        new_prototypes = []
        for class_idx in range(self.n_classes):
            mask = (pseudo_labels == class_idx) & (confidence_scores > threshold)
            if mask.any():
                new_prototypes.append(features[mask].mean(0))
            else:
                new_prototypes.append(self._require_prototypes()[class_idx])
        return torch.stack(new_prototypes)

    # ------------------------------------------------------------------
    # Training / evaluation entry points
    # ------------------------------------------------------------------
    def train_iteration(self, train_loader, dataset_id=None, initial=False):
        """Train on a new dataset, using saved features if available.

        Args:
            train_loader: Iterable of batches, used only when no cached
                features exist for ``dataset_id``.
            dataset_id: Cache key; ``None`` disables caching.
            initial: If true, build prototypes from the true labels;
                otherwise pseudo-label the data and blend the prototypes.

        Returns:
            The ``(features, labels)`` used for this iteration.

        Raises:
            ValueError: ``initial`` is true but the data has no labels.
            RuntimeError: ``initial`` is false and no prototypes exist yet.
        """
        features, labels = self._load_or_extract(
            train_loader, dataset_id, "features", require_labels=initial
        )

        if initial:
            # For the first dataset, use true labels
            self.prototypes = compute_prototypes(features, labels, self.n_classes)
        else:
            # For subsequent datasets, use pseudo-labels and keep 80% of the
            # old prototypes, mixing in 20% of the newly estimated ones.
            pseudo_labels, confidence_scores = self.assign_pseudo_labels(features)
            self.prototypes = 0.8 * self.prototypes + 0.2 * self.update_prototypes(
                features, pseudo_labels, confidence_scores
            )

        return features, labels

    def evaluate(self, eval_loader, dataset_id=None):
        """Return nearest-prototype classification accuracy on labelled data.

        Cached evaluation features for ``dataset_id`` are used when present;
        otherwise features are extracted from ``eval_loader`` (and cached if
        ``dataset_id`` is given).

        Raises:
            ValueError: the evaluation data has no labels.
            RuntimeError: no prototypes exist yet.
        """
        features, labels = self._load_or_extract(
            eval_loader, dataset_id, "eval_features", require_labels=True
        )

        distances = torch.cdist(features, self._require_prototypes())
        predictions = torch.argmin(distances, dim=1)
        return accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())




def load_dataset(path, has_labels=False, batch_size=64):
    """Load a saved dataset file and wrap it in a ``DataLoader``.

    The file is read with ``torch.load`` and must contain a ``"data"`` entry
    (and a ``"targets"`` entry when ``has_labels`` is true). The image array
    is converted to float32, its last axis is moved to position 1, and it is
    scaled by 1/255 and normalised with the ImageNet mean/std.

    Args:
        path: Location of the dataset file.
        has_labels: If true, yield shuffled ``(images, labels)`` batches;
            otherwise yield image-only batches in file order.
        batch_size: Number of samples per batch (default 64).

    Returns:
        A ``DataLoader`` over the normalised images (and labels).
    """
    from torch.utils.data import TensorDataset

    # NOTE: torch.load unpickles the file - only use it on trusted data.
    data = torch.load(path)
    # presumably data["data"] is (N, H, W, 3) with values in 0..255 - TODO confirm.
    # as_tensor avoids the extra copy (and warning) torch.tensor incurs on
    # inputs that are already arrays/tensors.
    images = torch.as_tensor(data["data"], dtype=torch.float32).permute(0, 3, 1, 2)

    # Normalize using ImageNet statistics
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    images = (images / 255.0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)

    if has_labels:
        labels = torch.as_tensor(data["targets"], dtype=torch.long)
        # TensorDataset indexes the tensors lazily instead of materialising a
        # Python list holding one tuple per sample.
        return DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=True)
    return DataLoader(images, batch_size=batch_size, shuffle=False)


# Training on D1
def train_on_d1(base_path, eval_base_path):
    """Train the model on dataset D1 and evaluate it on D1.

    This should be run first: it creates the learner and initialises its
    prototypes from D1's true labels.

    Args:
        base_path: Directory containing ``1_train_data.tar.pth``.
        eval_base_path: Directory containing ``1_eval_data.tar.pth``.

    Returns:
        ``(learner, accuracy)``: the initialised learner and its accuracy on
        the D1 evaluation split.
    """
    # Initialize feature extractor and learner
    feature_extractor = FeatureExtractor()
    learner = ContinualLearner(feature_extractor)

    # Load D1 training and evaluation data
    print("Training on D1...")
    d1_train = load_dataset(f"{base_path}/1_train_data.tar.pth", has_labels=True)
    d1_eval = load_dataset(f"{eval_base_path}/1_eval_data.tar.pth", has_labels=True)

    # Train on D1. Because a dataset_id is passed, train_iteration itself
    # caches the extracted features (or reuses the cached file), so no
    # separate save_features call is needed afterwards.
    learner.train_iteration(d1_train, dataset_id=1, initial=True)

    accuracy = learner.evaluate(d1_eval, dataset_id=1)
    print(f"Initial accuracy on D1: {accuracy:.4f}")

    return learner, accuracy


# Training on Subsequent Datasets
def train_subsequent_datasets(base_path, eval_base_path, learner, num_datasets=10):
    """Continue training on D2..Dn and record an accuracy matrix.

    Args:
        base_path: Directory containing ``{i}_train_data.tar.pth`` files.
        eval_base_path: Directory containing ``{j}_eval_data.tar.pth`` files.
        learner: A learner whose prototypes were already initialised on D1.
        num_datasets: Total number of datasets, including D1.

    Returns:
        A ``(num_datasets, num_datasets)`` array where entry ``[i-1, j-1]``
        is the accuracy on evaluation set ``j`` after training on dataset
        ``i`` (for ``j <= i``); the remaining entries stay zero.

    Side effects:
        Saves the final prototypes through ``learner.save_prototypes()``.
    """
    # Initialize results matrix
    accuracies = np.zeros((num_datasets, num_datasets))

    # Record D1 accuracy. Passing dataset_id=1 lets the learner reuse the
    # cached D1 evaluation features instead of re-running the backbone.
    d1_eval = load_dataset(f"{eval_base_path}/1_eval_data.tar.pth", has_labels=True)
    accuracies[0, 0] = learner.evaluate(d1_eval, dataset_id=1)

    # Train on subsequent datasets (D2 to Dn)
    for i in range(2, num_datasets + 1):
        print(f"\nProcessing dataset D{i}")

        train_data = load_dataset(f"{base_path}/{i}_train_data.tar.pth", has_labels=False)

        # Train on current dataset. With a dataset_id, train_iteration already
        # caches the extracted features, so they are not saved again here.
        learner.train_iteration(train_data, dataset_id=i, initial=False)

        # Evaluate on all datasets up to current
        for j in range(1, i + 1):
            eval_data = load_dataset(f"{eval_base_path}/{j}_eval_data.tar.pth", has_labels=True)
            accuracy = learner.evaluate(eval_data, dataset_id=j)
            accuracies[i - 1, j - 1] = accuracy
            print(f"Model {i} accuracy on D{j}: {accuracy:.4f}")

    learner.save_prototypes()

    return accuracies


# Main Script
if __name__ == "__main__":
    # Where the per-dataset training and evaluation files live, and how many
    # datasets (D1 included) the run covers.
    base_path = "dataset/part_one_dataset/train_data"
    eval_base_path = "dataset/part_one_dataset/eval_data"
    num_datasets = 10

    # Stage 1: build the learner and its prototypes from labelled D1.
    print("Starting training on D1...")
    learner, initial_accuracy = train_on_d1(base_path, eval_base_path)

    # Stage 2: continue on the remaining datasets and collect accuracies.
    print("Starting continual learning training on subsequent datasets...")
    accuracies = train_subsequent_datasets(
        base_path, eval_base_path, learner, num_datasets=num_datasets
    )

    print("\nFinal Accuracy Matrix:")
    print(accuracies)


Starting training on D1...




Training on D1...
Loading saved training features for dataset 1 from saved_data\features_dataset_1.pth
Loading saved features for evaluation dataset 1 from saved_data\eval_features_dataset_1.pth
Initial accuracy on D1: 0.8468
Starting continual learning training on subsequent datasets...
Dataset ID not provided, extracting features directly...

Processing dataset D2
Loading saved training features for dataset 2 from saved_data\features_dataset_2.pth
Loading saved features for evaluation dataset 1 from saved_data\eval_features_dataset_1.pth
Model 2 accuracy on D1: 0.8428
Loading saved features for evaluation dataset 2 from saved_data\eval_features_dataset_2.pth
Model 2 accuracy on D2: 0.8472

Processing dataset D3
Loading saved training features for dataset 3 from saved_data\features_dataset_3.pth
Loading saved features for evaluation dataset 1 from saved_data\eval_features_dataset_1.pth
Model 3 accuracy on D1: 0.8444
Loading saved features for evaluation dataset 2 from saved_data\eval_