In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import torchvision.models as models
from collections import defaultdict

# Feature extractor using pretrained ResNet18
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-1])  # Remove FC layer

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # Flatten output
        return x

# Linear classifier with softmax prediction
class LwPClassifier(nn.Module):
    def __init__(self, feature_dim=512, num_classes=10):
        super(LwPClassifier, self).__init__()
        self.classifier = nn.Linear(feature_dim, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        return self.classifier(x)
    
    def predict_proba(self, x):
        with torch.no_grad():
            return self.softmax(self.forward(x))

# Sequential learning pipeline
class SequentialLearner:
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.feature_extractor = FeatureExtractor().to(self.device)
        self.feature_extractor.eval()  # Freeze feature extractor
        self.models = {}
        self.results = defaultdict(dict)

    def preprocess_data(self, data):
        # Normalize and preprocess for ImageNet format
        data = data.astype(np.float32) / 255.0
        data = (data - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
        data = data.transpose(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return torch.FloatTensor(data)

    def extract_features(self, data):
        data = self.preprocess_data(data)
        data = data.to(self.device)
        with torch.no_grad():
            features = self.feature_extractor(data)
        return features

    def train_initial_model(self, data, targets, model_id=1):
        features = self.extract_features(data)
        targets = torch.LongTensor(targets)

        # Initialize model
        model = LwPClassifier().to(self.device)
        optimizer = optim.Adam(model.parameters())
        criterion = nn.CrossEntropyLoss()

        # Dataset and DataLoader
        dataset = TensorDataset(features, targets)
        dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

        # Training loop
        model.train()
        for epoch in range(10):  # Adjust epochs if needed
            for batch_features, batch_targets in dataloader:
                batch_features, batch_targets = batch_features.to(self.device), batch_targets.to(self.device)
                optimizer.zero_grad()
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()

        self.models[model_id] = model
        return model

    def update_model(self, data, prev_model_id, new_model_id):
        features = self.extract_features(data)

        # Get predictions from previous model
        prev_model = self.models[prev_model_id]
        prev_model.eval()
        with torch.no_grad():
            pseudo_labels_prob = prev_model.predict_proba(features)
            pseudo_labels = torch.argmax(pseudo_labels_prob, dim=1)

        # Initialize new model and load weights from the previous model
        new_model = LwPClassifier().to(self.device)
        new_model.load_state_dict(prev_model.state_dict())

        optimizer = optim.Adam(new_model.parameters())
        criterion = nn.CrossEntropyLoss()

        # Dataset and DataLoader
        dataset = TensorDataset(features, pseudo_labels)
        dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

        # Knowledge distillation
        temperature = 2.0
        alpha = 0.5

        new_model.train()
        for epoch in range(5):  # Fewer epochs for updates
            for batch_features, batch_pseudo_labels in dataloader:
                batch_features = batch_features.to(self.device)
                batch_pseudo_labels = batch_pseudo_labels.to(self.device)

                optimizer.zero_grad()

                # Soft targets from previous model
                with torch.no_grad():
                    soft_targets = prev_model.predict_proba(batch_features)

                # Current model predictions
                outputs = new_model(batch_features)
                outputs_soft = new_model.softmax(outputs / temperature)

                # Combine hard and soft losses
                loss_hard = criterion(outputs, batch_pseudo_labels)
                loss_soft = nn.KLDivLoss(reduction='batchmean')(
                    torch.log(outputs_soft),
                    soft_targets
                ) * (temperature ** 2)

                loss = alpha * loss_hard + (1 - alpha) * loss_soft
                loss.backward()
                optimizer.step()

        self.models[new_model_id] = new_model
        return new_model

    def evaluate_model(self, model_id, eval_data, eval_targets):
        model = self.models[model_id]
        features = self.extract_features(eval_data)
        targets = torch.LongTensor(eval_targets).to(self.device)

        model.eval()
        with torch.no_grad():
            outputs = model(features.to(self.device))
            predictions = torch.argmax(outputs, dim=1)
            accuracy = (predictions == targets).float().mean().item()

        return accuracy * 100

    def run_sequential_learning(self, num_models=10):
        print("Training initial model...")
        initial_data = torch.load('part_one_dataset/train_data/1_train_data.tar.pth')
        self.train_initial_model(initial_data['data'], initial_data['targets'])

        for i in range(2, num_models + 1):
            print(f"Processing dataset {i}...")
            train_data = torch.load(f'part_one_dataset/train_data/{i}_train_data.tar.pth')
            self.update_model(train_data['data'], i - 1, i)

            for j in range(1, i + 1):
                eval_data = torch.load(f'part_one_dataset/eval_data/{j}_eval_data.tar.pth')
                accuracy = self.evaluate_model(i, eval_data['data'], eval_data['targets'])
                self.results[i][j] = accuracy
                print(f"Model {i}, Dataset {j}: Accuracy = {accuracy:.2f}%")

        return self.results

def print_results_matrix(results, num_models=10):
    print("\nAccuracy Matrix (%):")
    print("Model ID | " + " ".join(f"Dataset {i:2d}" for i in range(1, num_models + 1)))
    print("-" * (9 + num_models * 11))

    for model_id in range(1, num_models + 1):
        row = [f"Model {model_id:2d} |"]
        for dataset_id in range(1, num_models + 1):
            accuracy = results.get(model_id, {}).get(dataset_id, 0.0)
            if dataset_id <= model_id:
                row.append(f"{accuracy:8.2f}")
            else:
                row.append(" " * 8)
        print(" ".join(row))

# Usage example
if __name__ == "__main__":
    learner = SequentialLearner()
    results = learner.run_sequential_learning()
    print_results_matrix(results)


Training initial model...


  initial_data = torch.load('part_one_dataset/train_data/1_train_data.tar.pth')


Processing dataset 2...


  train_data = torch.load(f'part_one_dataset/train_data/{i}_train_data.tar.pth')
  eval_data = torch.load(f'part_one_dataset/eval_data/{j}_eval_data.tar.pth')


Model 2, Dataset 1: Accuracy = 58.20%
Model 2, Dataset 2: Accuracy = 58.80%
Processing dataset 3...
Model 3, Dataset 1: Accuracy = 57.00%
Model 3, Dataset 2: Accuracy = 58.56%
Model 3, Dataset 3: Accuracy = 57.28%
Processing dataset 4...
Model 4, Dataset 1: Accuracy = 56.32%
Model 4, Dataset 2: Accuracy = 58.12%
Model 4, Dataset 3: Accuracy = 56.80%
Model 4, Dataset 4: Accuracy = 56.28%
Processing dataset 5...
Model 5, Dataset 1: Accuracy = 55.76%
Model 5, Dataset 2: Accuracy = 57.56%
Model 5, Dataset 3: Accuracy = 56.20%
Model 5, Dataset 4: Accuracy = 55.76%
Model 5, Dataset 5: Accuracy = 56.20%
Processing dataset 6...
Model 6, Dataset 1: Accuracy = 55.36%
Model 6, Dataset 2: Accuracy = 56.92%
Model 6, Dataset 3: Accuracy = 55.60%
Model 6, Dataset 4: Accuracy = 55.24%
Model 6, Dataset 5: Accuracy = 56.24%
Model 6, Dataset 6: Accuracy = 55.80%
Processing dataset 7...
Model 7, Dataset 1: Accuracy = 55.16%
Model 7, Dataset 2: Accuracy = 56.52%
Model 7, Dataset 3: Accuracy = 55.08%
Model 