In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
from torchvision.transforms import Resize
from scipy.optimize import linear_sum_assignment

class FeatureExtractor(nn.Module):
    """EfficientNet-B0 backbone used as a fixed feature extractor.

    The model's ``classifier`` head is replaced with ``nn.Identity`` so a
    forward pass returns the features that would otherwise feed the classifier.
    The backbone is put in eval mode at construction time.

    Attributes:
        efficientnet: The torchvision EfficientNet-B0 model without its head.
        resize: Transform that resizes inputs to 160x160 before the backbone.
        feature_dim: Size of dim 1 of the backbone output, measured once with a
            dummy 1x3x224x224 input.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # NOTE(review): newer torchvision releases deprecate `pretrained=True`
        # in favour of the `weights=` argument; kept as-is so the notebook
        # still runs on the torchvision version it was written against.
        self.efficientnet = models.efficientnet_b0(pretrained=True)
        self.efficientnet.eval()
        self.efficientnet.classifier = nn.Identity()

        # Build the resize transform once instead of on every forward() call.
        self.resize = Resize((160, 160))

        # Probe the output width once and keep it, so callers can size
        # downstream buffers without running their own dummy pass.
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, 224, 224)
            self.feature_dim = self.efficientnet(dummy_input).shape[1]

    def forward(self, x):
        """Resize ``x`` to 160x160 and return the backbone features."""
        x = self.resize(x)
        return self.efficientnet(x)

def compute_prototypes(features, labels, n_classes):
    """Compute one mean-feature prototype per class.

    Args:
        features: Tensor of feature vectors, one row per sample.
        labels: Tensor of integer class ids aligned with the rows of
            ``features``.
        n_classes: Number of classes; ids ``0 .. n_classes - 1`` are looked up.

    Returns:
        Tensor with ``n_classes`` rows, where row ``c`` is the mean of the
        feature rows labelled ``c``.

    Raises:
        ValueError: If one or more classes in ``range(n_classes)`` have no
            samples. Skipping such a class would shift every later row up by
            one, so row ``c`` would no longer be the prototype of class ``c``
            for code that indexes the result by class id.
    """
    prototypes = []
    missing_classes = []
    for class_idx in range(n_classes):
        class_features = features[labels == class_idx]
        if len(class_features) > 0:
            prototypes.append(class_features.mean(0))
        else:
            missing_classes.append(class_idx)

    if missing_classes:
        raise ValueError(
            f"Cannot compute prototypes: no samples for class(es) {missing_classes}; "
            f"every class in range({n_classes}) needs at least one labeled sample."
        )
    return torch.stack(prototypes)

class ContinualLearner:
    """Nearest-prototype classifier that is updated one dataset at a time.

    One prototype (a feature vector) is kept per class. The first dataset sets
    the prototypes from true labels; later datasets are pseudo-labelled with
    the current prototypes, which are then blended with the new estimates.

    Attributes:
        feature_extractor: Callable mapping an image batch to a feature tensor.
        n_classes: Number of classes / prototype rows.
        prototypes: Tensor of prototypes, or ``None`` until the initial
            training step has run.
    """

    def __init__(self, feature_extractor, n_classes=10):
        self.feature_extractor = feature_extractor
        self.n_classes = n_classes
        self.prototypes = None

    def _require_prototypes(self):
        """Raise a clear error when prototypes have not been initialised yet."""
        if self.prototypes is None:
            raise RuntimeError(
                "Prototypes are not initialised; call train_iteration(..., initial=True) first."
            )

    def extract_features(self, dataloader):
        """Run the feature extractor over every batch of ``dataloader``.

        Batches that are tuples/lists are unpacked as ``(images, labels)``;
        any other batch is treated as images only.

        Returns:
            ``(features, labels)`` where ``features`` is the concatenation of
            all batch outputs and ``labels`` is the concatenation of all batch
            labels, or ``None`` when no batch carried labels.
        """
        all_features = []
        all_labels = []

        # Gradients are never needed here; only the extractor outputs are kept.
        with torch.no_grad():
            for batch in dataloader:
                if isinstance(batch, (tuple, list)):
                    images, labels = batch
                    all_labels.append(labels)
                else:
                    images = batch
                all_features.append(self.feature_extractor(images))

        features = torch.cat(all_features)
        labels = torch.cat(all_labels) if all_labels else None
        return features, labels

    def assign_pseudo_labels(self, features):
        """Assign each feature row to its nearest prototype.

        Returns:
            ``(pseudo_labels, confidence_scores)``: the index of the closest
            prototype per row, and the largest softmax value over the negated
            distances per row.

        Raises:
            RuntimeError: If prototypes have not been initialised.
        """
        self._require_prototypes()
        distances = torch.cdist(features, self.prototypes)
        pseudo_labels = torch.argmin(distances, dim=1)
        confidence_scores = torch.softmax(-distances, dim=1).max(dim=1)[0]
        return pseudo_labels, confidence_scores

    def update_prototypes(self, features, pseudo_labels, confidence_scores, threshold=0):
        """Re-estimate prototypes from pseudo-labelled samples.

        For each class, the new prototype is the mean of the rows assigned to
        that class whose confidence is strictly greater than ``threshold``.
        Classes with no such rows keep their current prototype.

        Returns:
            Tensor of ``n_classes`` stacked prototypes.

        Raises:
            RuntimeError: If prototypes have not been initialised.
        """
        self._require_prototypes()
        new_prototypes = []
        for class_idx in range(self.n_classes):
            mask = (pseudo_labels == class_idx) & (confidence_scores > threshold)
            if mask.any():
                new_prototypes.append(features[mask].mean(0))
            else:
                # No confident sample for this class: carry the old prototype over.
                new_prototypes.append(self.prototypes[class_idx])
        return torch.stack(new_prototypes)

    def train_iteration(self, train_loader, initial=False, update_rate=0.2, threshold=0):
        """Train on a new dataset.

        Args:
            train_loader: Iterable of batches (see ``extract_features``).
            initial: When true, prototypes are computed from the true labels in
                ``train_loader``. Otherwise the data is pseudo-labelled and the
                prototypes are blended with the new estimates.
            update_rate: Weight of the newly estimated prototypes in the blend;
                the current prototypes get ``1 - update_rate``. The default
                reproduces the previous fixed 0.8 / 0.2 blend.
            threshold: Confidence threshold forwarded to ``update_prototypes``.

        Raises:
            ValueError: If ``update_rate`` is outside ``[0, 1]``, or if
                ``initial`` is true but the loader yields no labels.
            RuntimeError: If ``initial`` is false and prototypes have not been
                initialised.
        """
        if not 0.0 <= update_rate <= 1.0:
            raise ValueError(f"update_rate must be in [0, 1], got {update_rate}")
        if not initial:
            # Fail before the (expensive) feature extraction pass.
            self._require_prototypes()

        features, labels = self.extract_features(train_loader)

        if initial:
            # For the first dataset, use true labels
            if labels is None:
                raise ValueError("Initial training requires a dataloader that yields (images, labels).")
            self.prototypes = compute_prototypes(features, labels, self.n_classes)
        else:
            # For subsequent datasets, use pseudo-labels
            pseudo_labels, confidence_scores = self.assign_pseudo_labels(features)
            updated = self.update_prototypes(
                features, pseudo_labels, confidence_scores, threshold
            )
            self.prototypes = (1.0 - update_rate) * self.prototypes + update_rate * updated

    def evaluate(self, eval_loader):
        """Return nearest-prototype accuracy on a labeled dataset.

        Raises:
            RuntimeError: If prototypes have not been initialised.
            ValueError: If ``eval_loader`` yields no labels.
        """
        self._require_prototypes()
        features, labels = self.extract_features(eval_loader)
        if labels is None:
            raise ValueError("Evaluation requires a dataloader that yields (images, labels).")
        distances = torch.cdist(features, self.prototypes)
        predictions = torch.argmin(distances, dim=1)
        return accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())

def load_dataset(path, has_labels=False, batch_size=64, shuffle=None):
    """Load a saved dataset and wrap it in a DataLoader of normalized images.

    The file is expected to hold a mapping with a ``'data'`` entry (and a
    ``'targets'`` entry when ``has_labels`` is true). ``'data'`` must be 4-D
    with the channel axis last, since it is permuted to channel-first and then
    normalized with 3-element per-channel statistics.

    Args:
        path: Location of the file passed to ``torch.load``.
        has_labels: When true, the loader yields ``(images, labels)`` batches;
            otherwise it yields image batches only.
        batch_size: Batch size of the returned loader.
        shuffle: Whether the loader shuffles. ``None`` (default) keeps the
            original behavior: labeled loaders shuffle, unlabeled ones do not.

    Returns:
        A ``DataLoader`` over the normalized images (and labels, if requested).
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only point this at trusted dataset files.
    data = torch.load(path)

    # as_tensor accepts arrays, lists and tensors alike and avoids the
    # copy-construct warning torch.tensor() emits for tensor input.
    images = torch.as_tensor(data['data'], dtype=torch.float32).permute(0, 3, 1, 2)

    # Normalize using ImageNet statistics (pixel values presumably in 0-255,
    # given the division by 255 - verify against the dataset).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    images = (images / 255.0 - mean) / std

    if shuffle is None:
        shuffle = bool(has_labels)

    if has_labels:
        labels = torch.as_tensor(data['targets'], dtype=torch.long)
        return DataLoader(list(zip(images, labels)), batch_size=batch_size, shuffle=shuffle)
    return DataLoader(images, batch_size=batch_size, shuffle=shuffle)

def train_on_d1(base_path, eval_base_path):
    """
    Bootstrap the learner on the labeled dataset D1 and score it on D1's
    evaluation split. Run this before any other training step.

    Args:
        base_path: Directory containing ``1_train_data.tar.pth``.
        eval_base_path: Directory containing ``1_eval_data.tar.pth``.

    Returns:
        ``(learner, accuracy)``: the trained ``ContinualLearner`` and its
        accuracy on the D1 evaluation data.
    """
    # A fresh backbone wrapped in a fresh learner.
    learner = ContinualLearner(FeatureExtractor())

    print("Training on D1...")
    train_loader = load_dataset(f"{base_path}/1_train_data.tar.pth", has_labels=True)
    eval_loader = load_dataset(f"{eval_base_path}/1_eval_data.tar.pth", has_labels=True)

    # D1 is the only dataset whose true labels are used to build prototypes.
    learner.train_iteration(train_loader, initial=True)

    d1_accuracy = learner.evaluate(eval_loader)
    print(f"Initial accuracy on D1: {d1_accuracy:.4f}")

    return learner, d1_accuracy


def train_subsequent_datasets(base_path, eval_base_path, learner, num_datasets=10):
    """
    Continue training on the unlabeled datasets D2..Dn, re-evaluating on every
    dataset seen so far after each step. Run this after training on D1.

    Args:
        base_path: Directory containing ``{i}_train_data.tar.pth`` files.
        eval_base_path: Directory containing ``{j}_eval_data.tar.pth`` files.
        learner: Learner that has already been trained on D1.
        num_datasets: Total number of datasets, D1 included.

    Returns:
        A ``num_datasets x num_datasets`` array where entry ``[i-1, j-1]`` is
        the accuracy on D``j`` after training through D``i`` (``j <= i``);
        all other entries stay zero.
    """
    results = np.zeros((num_datasets, num_datasets))

    def _accuracy_on(dataset_no):
        # Load the labeled evaluation split of D<dataset_no> and score it.
        loader = load_dataset(f"{eval_base_path}/{dataset_no}_eval_data.tar.pth", has_labels=True)
        return learner.evaluate(loader)

    # Row 0: the model as it stands after D1.
    results[0, 0] = _accuracy_on(1)

    for step in range(2, num_datasets + 1):
        print(f"\nProcessing dataset D{step}")

        # D2 onwards are unlabeled; the learner pseudo-labels them itself.
        unlabeled_loader = load_dataset(f"{base_path}/{step}_train_data.tar.pth", has_labels=False)
        learner.train_iteration(unlabeled_loader, initial=False)

        # Evaluate on every dataset seen so far (D1..D<step>).
        row = step - 1
        for seen in range(1, step + 1):
            score = _accuracy_on(seen)
            results[row, seen - 1] = score
            print(f"Model {step} accuracy on D{seen}: {score:.4f}")

    return results

# Example usage (run in the notebook cells that follow):

# Step 1: train on D1 with train_on_d1; step 2: continue with train_subsequent_datasets.


In [2]:
if __name__ == "__main__":
    # Dataset roots, relative to the notebook's working directory.
    # Later cells read base_path, eval_base_path and learner.
    base_path, eval_base_path = (
        "dataset/part_one_dataset/train_data",
        "dataset/part_one_dataset/eval_data",
    )

    print("Starting training on D1...")
    # Builds the learner from D1's true labels and reports its D1 accuracy.
    learner, initial_accuracy = train_on_d1(base_path, eval_base_path)

Starting training on D1...




Training on D1...
Initial accuracy on D1: 0.8468


In [3]:
if __name__ == "__main__":
    # Relies on base_path, eval_base_path and learner set by the D1 cell.
    num_datasets = 10
    print("Starting continual learning training on subsequent datasets...")
    accuracies = train_subsequent_datasets(
        base_path, eval_base_path, learner, num_datasets
    )

    # Header and matrix on consecutive lines, as two separate prints would give.
    print("\nFinal Accuracy Matrix:", accuracies, sep="\n")

Starting continual learning training on subsequent datasets...

Processing dataset D2
Model 2 accuracy on D1: 0.8428
Model 2 accuracy on D2: 0.8472

Processing dataset D3
Model 3 accuracy on D1: 0.8444
Model 3 accuracy on D2: 0.8452
Model 3 accuracy on D3: 0.8344

Processing dataset D4
Model 4 accuracy on D1: 0.8396
Model 4 accuracy on D2: 0.8440
Model 4 accuracy on D3: 0.8340
Model 4 accuracy on D4: 0.8384

Processing dataset D5
Model 5 accuracy on D1: 0.8356
Model 5 accuracy on D2: 0.8400
Model 5 accuracy on D3: 0.8344
Model 5 accuracy on D4: 0.8388
Model 5 accuracy on D5: 0.8416

Processing dataset D6
Model 6 accuracy on D1: 0.8336
Model 6 accuracy on D2: 0.8356
Model 6 accuracy on D3: 0.8268
Model 6 accuracy on D4: 0.8336
Model 6 accuracy on D5: 0.8388
Model 6 accuracy on D6: 0.8308

Processing dataset D7
Model 7 accuracy on D1: 0.8324
Model 7 accuracy on D2: 0.8352
Model 7 accuracy on D3: 0.8228
Model 7 accuracy on D4: 0.8320
Model 7 accuracy on D5: 0.8376
Model 7 accuracy on D6: 