In [None]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os

class FeatureExtractor:
    """Embed single images with an ImageNet-pretrained ResNet-34 backbone.

    The backbone's final fully connected layer is removed, so each image is
    mapped to the globally average-pooled feature vector (512 values for
    ResNet-34) instead of ImageNet class logits.
    """

    def __init__(self):
        # ResNet-34 backbone (an earlier comment said ResNet18; ResNet-34 is what is loaded).
        backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        # Drop the last fully connected layer; keep everything up to global pooling.
        self.model = nn.Sequential(*list(backbone.children())[:-1])
        self.model.eval()  # inference mode: BatchNorm uses running statistics

        # Standard ImageNet preprocessing (images are resized directly, no crop).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def extract_features(self, image):
        """Return the feature vector of one image as a 1-D NumPy array.

        Args:
            image: a PIL image or a NumPy array accepted by
                ``PIL.Image.fromarray``.

        Returns:
            A 1-D ``numpy.ndarray`` with the pooled backbone features.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Normalize() expects exactly 3 channels; grayscale, RGBA or palette
        # images would otherwise fail. For images already in RGB mode this
        # conversion is a plain copy.
        image = image.convert('RGB')

        batch = self.transform(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            features = self.model(batch)
        # (1, C, 1, 1) -> (C,): drop the batch and spatial dimensions.
        return torch.flatten(features, start_dim=1)[0].numpy()

class LwPClassifier:
    """Learning-with-Prototypes classifier with one L2-normalized prototype per class.

    ``fit`` sets each class prototype to the (optionally weighted) mean of the
    feature vectors labeled with that class and then L2-normalizes every
    prototype. ``predict`` picks the prototype with the largest dot product.

    Args:
        feature_dim: length of each feature vector.
        num_classes: number of class prototypes (labels are ``0..num_classes-1``).
        lambda_reg: softmax temperature used by ``predict_proba``; smaller
            values give sharper probabilities.
    """

    def __init__(self, feature_dim=512, num_classes=10, lambda_reg=0.1):
        self.weights = np.zeros((num_classes, feature_dim))
        self.lambda_reg = lambda_reg
        # The classifier never uses the feature extractor itself; build it on
        # first access so constructing/copying a classifier doesn't load a ResNet.
        self._feature_extractor = None

    @property
    def feature_extractor(self):
        """Lazily constructed ``FeatureExtractor`` (kept for backward compatibility)."""
        if getattr(self, '_feature_extractor', None) is None:
            self._feature_extractor = FeatureExtractor()
        return self._feature_extractor

    @feature_extractor.setter
    def feature_extractor(self, value):
        self._feature_extractor = value

    def fit(self, X, y, sample_weights=None):
        """Recompute the prototypes of the classes present in ``y``.

        Args:
            X: array-like of shape (n_samples, feature_dim).
            y: integer class labels, length n_samples (list, array or tensor).
            sample_weights: optional per-sample weights; defaults to all ones.

        Classes with no samples in ``y`` keep their previous prototype.
        """
        X = np.asarray(X)
        # Normalize label containers: comparing a Python list to an int yields a
        # single False, which would silently skip every class.
        y = np.asarray(y)
        if sample_weights is None:
            sample_weights = np.ones(len(X))
        else:
            sample_weights = np.asarray(sample_weights)

        for c in range(self.weights.shape[0]):
            mask = y == c
            if not np.any(mask):
                continue
            # Weighted average of this class's features.
            self.weights[c] = np.average(X[mask], axis=0, weights=sample_weights[mask])

        # L2-normalize each prototype; the epsilon guards all-zero rows.
        norms = np.linalg.norm(self.weights, axis=1, keepdims=True)
        self.weights = self.weights / (norms + 1e-8)

    def predict(self, X):
        """Return the index of the most similar prototype for each row of ``X``."""
        similarities = np.dot(X, self.weights.T)
        return np.argmax(similarities, axis=1)

    def predict_proba(self, X):
        """Return softmax(similarity / lambda_reg) over classes for each row of ``X``."""
        logits = np.dot(X, self.weights.T) / self.lambda_reg
        # Subtracting the row max leaves the softmax unchanged but prevents
        # exp() overflow (inf / inf = NaN) when similarities are large.
        logits = logits - logits.max(axis=1, keepdims=True)
        exp_sim = np.exp(logits)
        return exp_sim / exp_sim.sum(axis=1, keepdims=True)

def update_model(current_model, new_data, predicted_labels, confidence_threshold=0.8):
    """Return a copy of ``current_model`` refit on confidently pseudo-labeled data.

    Samples whose maximum predicted probability is at least
    ``confidence_threshold`` are kept. If none qualify, the threshold falls
    back to the 70th percentile of the confidences. The kept samples are
    passed to ``fit`` with their confidences as sample weights.

    Args:
        current_model: model exposing ``weights``, ``predict_proba`` and ``fit``.
        new_data: array-like of feature vectors, one row per sample.
        predicted_labels: pseudo-labels for ``new_data`` (same length).
        confidence_threshold: minimum max-probability for a sample to be used.

    Returns:
        A new model object; ``current_model`` is left unchanged.
    """
    import copy

    # Work on a copy so the caller's model (e.g. the previous entry in a list of
    # per-step models) is not overwritten. A shallow copy shares heavy members;
    # only the prototype matrix, which ``fit`` writes in place, is duplicated.
    new_model = copy.copy(current_model)
    new_model.weights = np.copy(current_model.weights)

    new_data = np.asarray(new_data)
    if len(new_data) == 0:
        # Nothing to learn from; np.max / np.percentile on empty input would raise.
        return new_model

    predicted_labels = np.asarray(predicted_labels)
    probs = new_model.predict_proba(new_data)
    max_probs = np.max(probs, axis=1)

    # Filter samples based on confidence.
    confident_mask = max_probs >= confidence_threshold
    if not np.any(confident_mask):
        # No confident samples: lower the threshold adaptively (keep roughly the top 30%).
        confidence_threshold = np.percentile(max_probs, 70)
        confident_mask = max_probs >= confidence_threshold

    # Weight each pseudo-labeled sample by its confidence.
    new_model.fit(
        new_data[confident_mask],
        predicted_labels[confident_mask],
        sample_weights=max_probs[confident_mask],
    )
    return new_model

# Define the feature extraction process
def process_dataset(data, feature_extractor):
    """Extract a feature vector for every image in ``data``.

    Each item of ``data`` is passed to ``feature_extractor.extract_features``
    in iteration order, and the per-image results are combined with
    ``np.array``. For equal-length 1-D feature vectors this gives one row per
    image.
    """
    return np.array([feature_extractor.extract_features(sample) for sample in data])

# Task 1: Same distribution datasets
def _load_split(path, feature_extractor, labeled=True):
    """Load one serialized split and embed its images.

    Args:
        path: file written by ``torch.save`` holding a mapping with a
            ``'data'`` entry and, for labeled splits, a ``'targets'`` entry.
        feature_extractor: object whose ``extract_features`` embeds one image.
        labeled: when True, ``'targets'`` is required (KeyError if missing).

    Returns:
        ``(features, targets)``. ``targets`` is a NumPy array, or ``None``
        when ``labeled`` is False.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects (as the old
    # implicit default did); only load trusted files. It is passed explicitly to
    # silence torch's FutureWarning and so torch >= 2.6, where the default became
    # True, can still load these dicts.
    dataset = torch.load(path, weights_only=False)
    print(dataset.keys())  # expected: 'data' and, for labeled splits, 'targets'
    raw_targets = dataset['targets'] if labeled else None
    features = process_dataset(dataset['data'], feature_extractor)
    targets = np.asarray(raw_targets) if labeled else None
    return features, targets


def task1(base_path, feature_extractor,
          eval_path='./drive/MyDrive/mini-project-2/dataset/part_one_dataset/eval_data',
          num_datasets=10):
    """Continual learning over same-distribution datasets D1..D{num_datasets}.

    A prototype classifier is fit on the labeled D1. Each later training set
    Di, used without its labels, is pseudo-labeled by the previous model and
    passed to ``update_model``. After step i, the model is evaluated on
    held-out eval sets 1..i.

    Args:
        base_path: directory containing ``{i}_train_data.tar.pth``.
        feature_extractor: object whose ``extract_features`` embeds one image.
        eval_path: directory containing ``{j}_eval_data.tar.pth`` (defaults to
            the previously hard-coded location).
        num_datasets: number of train/eval splits.

    Returns:
        ``(models, accuracies)``: the list of models produced at each step, and
        a ``(num_datasets, num_datasets)`` array where ``accuracies[i, j]`` is
        the accuracy of step ``i``'s model on eval set ``j`` (0 for ``j > i``).
    """
    trained_models = []  # not named `models`, which would shadow torchvision.models
    accuracies = np.zeros((num_datasets, num_datasets))

    # Step 1: initialize from the labeled dataset D1.
    train_features, train_labels = _load_split(
        os.path.join(base_path, '1_train_data.tar.pth'), feature_extractor)
    model = LwPClassifier()
    model.fit(train_features, train_labels)
    trained_models.append(model)

    # Embed every held-out eval set once, up front.
    eval_sets = [
        _load_split(os.path.join(eval_path, f'{j}_eval_data.tar.pth'), feature_extractor)
        for j in range(1, num_datasets + 1)
    ]

    def accuracy(clf, j):
        """Fraction of eval set ``j`` (0-based) that ``clf`` classifies correctly."""
        eval_features, eval_labels = eval_sets[j]
        return (clf.predict(eval_features) == eval_labels).mean()

    accuracies[0, 0] = accuracy(model, 0)

    # Steps 2..N: pseudo-label Di with the previous model, then update it.
    for i in range(2, num_datasets + 1):
        prev_model = trained_models[-1]
        current_features, _ = _load_split(
            os.path.join(base_path, f'{i}_train_data.tar.pth'),
            feature_extractor, labeled=False)

        predictions = prev_model.predict(current_features)
        new_model = update_model(prev_model, current_features, predictions)
        trained_models.append(new_model)

        # Evaluate on all eval sets seen so far.
        for j in range(i):
            accuracies[i - 1, j] = accuracy(new_model, j)

    return trained_models, accuracies




# Single shared feature extractor used for every dataset
feature_extractor = FeatureExtractor()

# Directory holding the serialized training splits ({i}_train_data.tar.pth)
base_path1 = "./drive/MyDrive/mini-project-2/dataset/part_one_dataset/train_data"

# Task 1: same-distribution datasets
print("Running Task 1...")
task1_models, task1_accuracies = task1(base_path1, feature_extractor)
print("\nTask 1 Accuracy Matrix:", task1_accuracies, sep="\n")

# Optional: persist the accuracy matrix for later analysis
np.save('task1_accuracies.npy', task1_accuracies)


Running Task 1...
dict_keys(['data', 'targets'])


  dataset = torch.load(dataset_path)  # Load the serialized dataset
  dataset = torch.load(dataset_path)  # Load the serialized dataset


dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])
dict_keys(['data', 'targets'])


  dataset = torch.load(dataset_path)  # Load the serialized dataset


dict_keys(['data'])
dict_keys(['data'])
dict_keys(['data'])
dict_keys(['data'])
dict_keys(['data'])
dict_keys(['data'])
dict_keys(['data'])
dict_keys(['data'])
dict_keys(['data'])

Task 1 Accuracy Matrix:
[[0.8152 0.     0.     0.     0.     0.     0.     0.     0.     0.    ]
 [0.7916 0.8    0.     0.     0.     0.     0.     0.     0.     0.    ]
 [0.7752 0.7916 0.7932 0.     0.     0.     0.     0.     0.     0.    ]
 [0.774  0.7816 0.7896 0.7828 0.     0.     0.     0.     0.     0.    ]
 [0.7744 0.7908 0.784  0.788  0.7848 0.     0.     0.     0.     0.    ]
 [0.7656 0.7808 0.7788 0.78   0.7756 0.778  0.     0.     0.     0.    ]
 [0.7684 0.7768 0.778  0.778  0.7788 0.776  0.7728 0.     0.     0.    ]
 [0.7636 0.7764 0.7736 0.774  0.78   0.7748 0.7636 0.7648 0.     0.    ]
 [0.762  0.772  0.7788 0.7796 0.7796 0.7764 0.77   0.7664 0.7576 0.    ]
 [0.7572 0.7744 0.776  0.7764 0.7752 0.774  0.7688 0.7644 0.7596 0.7868]]
