In [1]:
import torch
import torch.nn as nn
import math
import copy
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 256
alpha = 0.001

In [3]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [4]:
def load_data_set(batch_size=64):
    # Old data
    train_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=tf)

    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    source_indices = [i for i, (_, label) in enumerate(train_dataset) if label < 8]
    target_indices = [i for i, (_, label) in enumerate(train_dataset) if (label == 9)]

    source_set = Subset(train_dataset, source_indices)
    target_set = Subset(train_dataset, target_indices)

    source_dl_train = DataLoader(source_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)
    target_dl_train = DataLoader(target_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)
    val_dl= DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last = True)

    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=tf)
    
    test_dl =  DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last = True)

    return source_dl_train, target_dl_train, test_dl, val_dl

# Model

In [5]:
class VGG(nn.Module):
    def __init__(self, out_dim = 10):
        super(VGG, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, out_dim),
            nn.Softmax(dim=1)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512)
        x = self.classifier(x)
        return x

# Pre-Train

In [6]:
def pre_train(criterion, optimizer, model, num_epochs, trainloader, testloader, device):
    train_loss_history = []
    test_acc_history = []

    for epoch in range(num_epochs):
        model.train()
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        correct = 0
        total = 0
        model.eval()

        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_acc = correct / total

            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Test Accuracy: {epoch_acc:.4f}")
            test_acc_history.append(epoch_acc)
            train_loss_history.append(loss.item())


    return train_loss_history, test_acc_history

In [7]:
import os
if not os.path.exists('source.pth'):
        source_dl_train, target_dl_train, test_dl, val_dl = load_data_set(batch_size=batch_size)
        model = VGG()
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), 0.0005)

        num_epochs = 30
        loss_history, test_acc_history = pre_train(criterion, optimizer, model, num_epochs, source_dl_train, test_dl, device)

        with torch.no_grad():
                torch.save(model.state_dict(), 'source.pth')
else:
        model = VGG().to(device)
        model.load_state_dict(torch.load('source.pth'))
        print("Loaded model from file.")


Loaded model from file.


# Fine tune

In [8]:
def fine_tune(criterion, optimizer, model, num_epochs, trainloader, testloader, valloader, device):
    best_model_wts = None
    leader = VGG().to(device)
    best_loss = float('inf')

    for epoch in range(num_epochs):
        batch_num = 0
        if (best_model_wts):
            model.load_state_dict(best_model_wts)
        
        for images, labels in trainloader:
            model.train()
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)

            '''reg_loss = 0
            for lead_para, follower_para in zip(leader.parameters(), model.parameters()):
                reg_loss += torch.norm(follower_para - lead_para, p = 2)'''
            
            classification_loss = criterion(outputs, labels)
            loss = classification_loss #+ reg_loss

            loss.backward()
            optimizer.step()

            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for val_inputs, val_labels in valloader:
                    val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                    outputs = model(val_inputs)
                    batch_loss = criterion(outputs, val_labels)
                    val_loss += batch_loss.item()

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    leader.load_state_dict(best_model_wts)
            
            print(f"Batch num: {batch_num}, classification_loss: {classification_loss.item()}, Val Loss: {val_loss}, loss : {loss.item()}")
            batch_num += 1

        correct = 0
        total = 0
        model.eval()

        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_acc = correct / total

            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Test Accuracy: {epoch_acc:.4f}")

In [9]:
source_dl_train, target_dl_train, test_dl, val_dl = load_data_set(batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 0.0005)
num_epochs = 30
fine_tune(criterion, optimizer, model, num_epochs, target_dl_train, test_dl, val_dl, device)

Batch num: 0, classification_loss: 2.4564502239227295, Val Loss: 69.08209574222565, loss : 2.4564502239227295
Batch num: 1, classification_loss: 2.456946611404419, Val Loss: 69.77781546115875, loss : 2.456946611404419
Batch num: 2, classification_loss: 2.4577531814575195, Val Loss: 70.71295011043549, loss : 2.4577531814575195
Batch num: 3, classification_loss: 2.4576566219329834, Val Loss: 71.45107674598694, loss : 2.4576566219329834
Batch num: 4, classification_loss: 2.4570491313934326, Val Loss: 72.54166460037231, loss : 2.4570491313934326
Batch num: 5, classification_loss: 2.4578373432159424, Val Loss: 73.22082340717316, loss : 2.4578373432159424
Batch num: 6, classification_loss: 2.456967830657959, Val Loss: 73.78667771816254, loss : 2.456967830657959
Batch num: 7, classification_loss: 2.456411123275757, Val Loss: 74.47651863098145, loss : 2.456411123275757
Batch num: 8, classification_loss: 2.4575998783111572, Val Loss: 74.75889432430267, loss : 2.4575998783111572
Batch num: 9, cl

KeyboardInterrupt: 