In [1]:
import torch
import torch.nn as nn
import math
import os
import copy
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 256
alpha = 0.01

In [3]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [4]:
def load_data_set(batch_size=64):
    # Old data
    train_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=tf)
    test_dataset = datasets.CIFAR10(root='./data', train=False,download=True, transform=tf)

    source_size = int(0.7 * len(train_dataset))
    target_size = int(0.1 * len(train_dataset))
    val_size =int(0.2 * len(train_dataset))

    train_dataset, target_dataset, val_dataset = random_split(train_dataset, [source_size, target_size, val_size])

    source_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    target_dl = DataLoader(target_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last = True)
    
    test_dl =  DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last = True)

    return source_dl, target_dl, test_dl, val_dl

In [5]:
source_dl, target_dl, test_dl, val_dl = load_data_set(batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


# Model

In [6]:
class VGG(nn.Module):
    def __init__(self, out_dim = 10):
        super(VGG, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, out_dim),
            nn.Softmax(dim=1)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512)
        x = self.classifier(x)
        return x

# Pre-Train

In [7]:
def cal_acc(model, dataloader, device):
    model.eval()
    correct, total = 0., 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

def pre_train(criterion, optimizer, model, num_epochs, trainloader, testloader, valloader, device):
    for epoch in range(num_epochs):
        model.train()
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        test_acc = cal_acc(model, testloader, device)
        val_acc = cal_acc(model, valloader, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Test Acc: {test_acc:.4f}, Val Acc: {val_acc:.4f}")

In [8]:
if not os.path.exists('source.pth'):
        model = VGG()
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), 0.0005)

        num_epochs = 30
        pre_train(criterion, optimizer, model, num_epochs, source_dl, test_dl, val_dl, device)

        with torch.no_grad():
                torch.save(model.state_dict(), 'source.pth')
else:
        model = VGG().to(device)
        model.load_state_dict(torch.load('source.pth'))
        print("Loaded model from file.")

Loaded model from file.


In [11]:
def fine_tune(criterion, optimizer, model, num_epochs, trainloader, testloader, valloader, device):
    best_model_wts = None
    leader = VGG().to(device)
    best_loss = float('inf')

    for epoch in range(num_epochs):
        batch_num = 0
        total, correct = 0., 0.
        # if (best_model_wts):
            # model.load_state_dict(best_model_wts)
        
        for images, labels in trainloader:
            model.train()
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)

            reg_loss = 0
            for lead_para, follower_para in zip(leader.parameters(), model.parameters()):
                reg_loss += torch.norm(follower_para - lead_para, p = 2)
            
            classification_loss = criterion(outputs, labels)
            loss = classification_loss + alpha * reg_loss

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss.backward()
            optimizer.step()

            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for val_inputs, val_labels in valloader:
                    val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                    outputs = model(val_inputs)
                    batch_loss = criterion(outputs, val_labels)
                    val_loss += batch_loss.item()

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    leader.load_state_dict(best_model_wts)
            
            print(f"Batch num: {batch_num}, c_loss: {classification_loss.item():.4f}, Val Loss: {val_loss:.4f}, loss : {loss.item():.4f}")
            batch_num += 1

        test_acc = cal_acc(model, testloader, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Train Acc: {correct/total:.4f}, Test Accuracy: {test_acc:.4f}")

In [12]:
model = VGG().to(device)
model.load_state_dict(torch.load('source.pth'))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 0.0005)
num_epochs = 30
fine_tune(criterion, optimizer, model, num_epochs, target_dl, test_dl, val_dl, device)

Batch num: 0, c_loss: 1.5263, Val Loss: 62.5958, loss : 5.7539
Batch num: 1, c_loss: 1.5618, Val Loss: 63.5568, loss : 1.5618
Batch num: 2, c_loss: 1.5836, Val Loss: 63.9667, loss : 1.6168
Batch num: 3, c_loss: 1.6129, Val Loss: 63.9123, loss : 1.6605
Batch num: 4, c_loss: 1.5978, Val Loss: 64.0024, loss : 1.6561
Batch num: 5, c_loss: 1.5620, Val Loss: 63.5031, loss : 1.6279
Batch num: 6, c_loss: 1.5753, Val Loss: 63.2309, loss : 1.6465
Batch num: 7, c_loss: 1.5581, Val Loss: 62.9128, loss : 1.6337
Batch num: 8, c_loss: 1.5871, Val Loss: 62.3527, loss : 1.6667
Batch num: 9, c_loss: 1.5621, Val Loss: 62.1038, loss : 1.5621
Batch num: 10, c_loss: 1.5870, Val Loss: 62.1098, loss : 1.5870
Batch num: 11, c_loss: 1.5598, Val Loss: 62.0841, loss : 1.5721
Batch num: 12, c_loss: 1.5599, Val Loss: 62.1422, loss : 1.5599
Batch num: 13, c_loss: 1.5541, Val Loss: 62.0479, loss : 1.5647
Batch num: 14, c_loss: 1.5575, Val Loss: 62.0243, loss : 1.5575
Batch num: 15, c_loss: 1.5810, Val Loss: 61.9960, 

KeyboardInterrupt: 