In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm  # Import tqdm for progress bars

# Set device to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using device:', device)

Using device: cuda:0


In [7]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the full CIFAR-10 dataset
full_trainset = datasets.CIFAR10(root='./cifar10', train=True, download=True, transform=transform)
full_testset = datasets.CIFAR10(root='./cifar10', train=False, download=True, transform=transform)


Files already downloaded and verified
Files already downloaded and verified


In [20]:
# Get labels from datasets
train_labels = full_trainset.targets
test_labels = full_testset.targets

# Class indices for animals and machines
animal_classes = [2, 3, 4, 5, 6, 7]  # Corresponds to Bird, Cat, Deer, Dog, Frog, Horse
machine_classes = [0, 1, 8, 9]       # Corresponds to Airplane, Automobile, Ship, Truck

# Prepare animal training dataset
animal_indices = [i for i, label in enumerate(train_labels) if label in animal_classes]
animal_dataset = Subset(full_trainset, animal_indices)
animal_loader = DataLoader(animal_dataset, batch_size=64, shuffle=True, num_workers=2)

# Prepare machine training dataset
machine_indices = [i for i, label in enumerate(train_labels) if label in machine_classes]
machine_dataset = Subset(full_trainset, machine_indices)
machine_loader = DataLoader(machine_dataset, batch_size=64, shuffle=True, num_workers=2)

# Prepare animal test dataset
animal_test_indices = [i for i, label in enumerate(test_labels) if label in animal_classes]
animal_testset = Subset(full_testset, animal_test_indices)
animal_test_loader = DataLoader(animal_testset, batch_size=64, shuffle=False, num_workers=2)

# Full test dataset
full_test_loader = DataLoader(full_testset, batch_size=64, shuffle=False, num_workers=2)

In [21]:
class SimpleCNN(nn.Module):
    """ Basic CNN model """
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1) # 3x3x32
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1) # 32x32x64
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [22]:
# Instantiate the model
simple_CNN = SimpleCNN()
simple_CNN.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()            # Cross-entropy loss for classification
optimizer = optim.SGD(simple_CNN.parameters(), lr=0.001, momentum=0.9)

In [24]:
# Training loop with tqdm progress bars
num_epochs = 10
for epoch in range(num_epochs):  # Loop over the dataset multiple times
    simple_CNN.train()  # Set the model to training mode
    running_loss = 0.0

    # Create a progress bar for batches
    batch_bar = tqdm(enumerate(animal_loader), total=len(animal_loader), desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    for i, data in batch_bar:
        # Get inputs and labels
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = simple_CNN(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update running loss
        running_loss += loss.item()

        # Update tqdm progress bar every batch with current loss
        batch_bar.set_postfix(loss=loss.item())

    # Calculate and print epoch-level average loss
    epoch_loss = running_loss / len(animal_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}] - Average Loss: {epoch_loss:.4f}')
        
print('Finished Training on animal dataset')

# Save the trained model
PATH = './simple_CNN_1.pth'
torch.save(simple_CNN.state_dict(), PATH)

                                                                        

Epoch [1/10] - Average Loss: 1.4717


                                                                        

Epoch [2/10] - Average Loss: 1.3894


                                                                        

Epoch [3/10] - Average Loss: 1.3316


                                                                         

Epoch [4/10] - Average Loss: 1.2698


                                                                         

Epoch [5/10] - Average Loss: 1.2297


                                                                         

Epoch [6/10] - Average Loss: 1.1899


                                                                         

Epoch [7/10] - Average Loss: 1.1507


                                                                         

Epoch [8/10] - Average Loss: 1.1175


                                                                         

Epoch [9/10] - Average Loss: 1.0797


                                                                          

Epoch [10/10] - Average Loss: 1.0416
Finished Training on animal dataset


In [26]:
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(loader, desc="Testing"):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)  # Predicted class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [27]:
model = simple_CNN
task1_test_loader = animal_test_loader
task2_test_loader = machine_test_loader
print("Accuracy on Animal Test Data after train on Animal:", evaluate(model, task1_test_loader))
print("Accuracy on Machine Test Data after train on Animal:", evaluate(model, task2_test_loader))


Testing: 100%|██████████| 94/94 [00:01<00:00, 82.34it/s]


Accuracy on Animal Test Data after train on Animal: 59.86666666666667


Testing: 100%|██████████| 63/63 [00:00<00:00, 83.92it/s]

Accuracy on Machine Test Data after train on Animal: 0.0





In [28]:
# Training loop with tqdm progress bars
num_epochs = 10
for epoch in range(num_epochs):  # Loop over the dataset multiple times
    simple_CNN.train()  # Set the model to training mode
    running_loss = 0.0

    # Create a progress bar for batches
    batch_bar = tqdm(enumerate(machine_loader), total=len(machine_loader), desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    for i, data in batch_bar:
        # Get inputs and labels
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = simple_CNN(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update running loss
        running_loss += loss.item()

        # Update tqdm progress bar every batch with current loss
        batch_bar.set_postfix(loss=loss.item())

    # Calculate and print epoch-level average loss
    epoch_loss = running_loss / len(animal_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}] - Average Loss: {epoch_loss:.4f}')
    
print('Finished Training on machine dataset without EWC')


# Save the trained model
PATH = './simple_CNN_2.pth'
torch.save(simple_CNN.state_dict(), PATH)

                                                                         

Epoch [1/10] - Average Loss: 0.8206


                                                                         

Epoch [2/10] - Average Loss: 0.5710


                                                                         

Epoch [3/10] - Average Loss: 0.5019


                                                                         

Epoch [4/10] - Average Loss: 0.4585


                                                                         

Epoch [5/10] - Average Loss: 0.4260


                                                                         

Epoch [6/10] - Average Loss: 0.4032


                                                                         

Epoch [7/10] - Average Loss: 0.3864


                                                                         

Epoch [8/10] - Average Loss: 0.3674


                                                                         

Epoch [9/10] - Average Loss: 0.3540


                                                                          

Epoch [10/10] - Average Loss: 0.3414
Finished Training on machine dataset without EWC




In [53]:
model = simple_CNN

print("Accuracy on Animal Test Data:", evaluate(model, animal_test_loader))
print("Accuracy on Machine Test Data:", evaluate(model, machine_test_loader))
print("Accuracy on Full Test Data:", evaluate(model, full_test_loader))

Testing: 100%|██████████| 94/94 [00:01<00:00, 86.61it/s]


Accuracy on Animal Test Data: 0.0


Testing: 100%|██████████| 63/63 [00:00<00:00, 85.51it/s]


Accuracy on Machine Test Data: 76.15


Testing: 100%|██████████| 157/157 [00:01<00:00, 80.42it/s]

Accuracy on Full Test Data: 30.46





In [49]:
# Implement EWC
class EWC(object):
    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader
        self.params = {n: p for n, p in model.named_parameters() if p.requires_grad}
        self._means = {}
        self._precision_matrices = self._diag_fisher()

        for n, p in self.params.items():
            self._means[n] = p.clone().detach()

    def _diag_fisher(self):
        precision_matrices = {}
        for n, p in self.params.items():
            precision_matrices[n] = torch.zeros_like(p)

        self.model.eval()
        for data in self.dataloader:
            self.model.zero_grad()
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = self.model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            for n, p in self.params.items():
                if p.grad is not None:
                    precision_matrices[n] += p.grad.detach() ** 2

        # Normalize the precision matrices
        for n in precision_matrices:
            precision_matrices[n] /= len(self.dataloader)

        return precision_matrices

    def penalty(self, model):
        loss = 0
        for n, p in model.named_parameters():
            if n in self._precision_matrices:
                _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
                loss += _loss.sum()
        return loss

In [50]:
# Reload the model to the state after training on animal dataset
simple_CNN_EWC = SimpleCNN()
simple_CNN_EWC.to(device)
PATH = './simple_CNN_1.pth'
simple_CNN_EWC.load_state_dict(torch.load(PATH))

  simple_CNN_EWC.load_state_dict(torch.load(PATH))


<All keys matched successfully>

In [51]:
# Compute the EWC importance
ewc = EWC(simple_CNN_EWC, animal_loader)

# Define optimizer for EWC model
optimizer = optim.SGD(simple_CNN_EWC.parameters(), lr=0.001, momentum=0.9)
lambda_ewc = 20000  # EWC regularization strength

# Training loop on machine dataset with EWC
num_epochs = 10

for epoch in range(num_epochs):  # Loop over the dataset multiple times
    simple_CNN_EWC.train()  # Set the model to training mode
    running_loss = 0.0

    # Create a progress bar for batches
    batch_bar = tqdm(enumerate(machine_loader), total=len(machine_loader), desc=f"EWC Epoch {epoch+1}/{num_epochs}", leave=False)
    for i, data in batch_bar:
        # Get inputs and labels
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = simple_CNN_EWC(inputs)
        loss = criterion(outputs, labels) + (lambda_ewc / 2) * ewc.penalty(simple_CNN_EWC)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update running loss
        running_loss += loss.item()

        # Update tqdm progress bar every batch with current loss
        batch_bar.set_postfix(loss=loss.item())

    # Calculate and print epoch-level average loss
    epoch_loss = running_loss / len(machine_loader)
    print(f'EWC Epoch [{epoch+1}/{num_epochs}] - Average Loss: {epoch_loss:.4f}')
    
print('Finished Training on Machine Dataset with EWC')


                                                                             

EWC Epoch [1/10] - Average Loss: 1.3768


                                                                             

EWC Epoch [2/10] - Average Loss: 0.9465


                                                                             

EWC Epoch [3/10] - Average Loss: 0.8671


                                                                             

EWC Epoch [4/10] - Average Loss: 0.8175


                                                                             

EWC Epoch [5/10] - Average Loss: 0.7877


                                                                             

EWC Epoch [6/10] - Average Loss: 0.7697


                                                                             

EWC Epoch [7/10] - Average Loss: 0.7511


                                                                             

EWC Epoch [8/10] - Average Loss: 0.7259


                                                                             

EWC Epoch [9/10] - Average Loss: 0.7054


                                                                              

EWC Epoch [10/10] - Average Loss: 0.7011
Finished Training on Machine Dataset with EWC




In [52]:
model = simple_CNN_EWC

print("Accuracy on Animal Test Data:", evaluate(model, animal_test_loader))
print("Accuracy on Machine Test Data:", evaluate(model, machine_test_loader))
print("Accuracy on Full Test Data:", evaluate(model, full_test_loader))

Testing: 100%|██████████| 94/94 [00:01<00:00, 77.62it/s]


Accuracy on Animal Test Data: 0.08333333333333333


Testing: 100%|██████████| 63/63 [00:00<00:00, 79.95it/s]


Accuracy on Machine Test Data: 72.3


Testing: 100%|██████████| 157/157 [00:01<00:00, 94.96it/s] 

Accuracy on Full Test Data: 28.97



