In [1]:
import torch
import torch.nn as nn
import math
import os
import copy
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 256

In [3]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [4]:
def load_data_set(batch_size=64):
    # Old data
    train_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=tf)
    test_dataset = datasets.CIFAR10(root='./data', train=False,download=True, transform=tf)

    source_size = int(0.7 * len(train_dataset))
    target_size = int(0.1 * len(train_dataset))
    val_size =int(0.2 * len(train_dataset))

    train_dataset, target_dataset, val_dataset = random_split(train_dataset, [source_size, target_size, val_size])

    source_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    target_dl = DataLoader(target_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last = True)
    
    test_dl =  DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last = True)

    return source_dl, target_dl, test_dl, val_dl

In [5]:
source_dl, target_dl, test_dl, val_dl = load_data_set(batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class VGG(nn.Module):
    def __init__(self, out_dim = 10):
        super(VGG, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, out_dim),
            nn.Softmax(dim=1)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512)
        x = self.classifier(x)
        return x

In [7]:
def cal_acc(model, dataloader, device):
    model.eval()
    correct, total = 0., 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [8]:
def pre_train(criterion, optimizer, model, num_epochs, trainloader, testloader, valloader, device):
    for epoch in range(num_epochs):
        model.train()
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        test_acc = cal_acc(model, testloader, device)
        val_acc = cal_acc(model, valloader, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Test Acc: {test_acc:.4f}, Val Acc: {val_acc:.4f}")

In [9]:
model = VGG()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 0.0005)

num_epochs = 30
pre_train(criterion, optimizer, model, num_epochs, source_dl, test_dl, val_dl, device)

with torch.no_grad():
    torch.save(model.state_dict(), 'source.pth')

Epoch [1/30], Loss: 1.8270, Test Acc: 0.4903, Val Acc: 0.4921
Epoch [2/30], Loss: 1.7598, Test Acc: 0.5096, Val Acc: 0.5178
Epoch [3/30], Loss: 1.7122, Test Acc: 0.5774, Val Acc: 0.5908
Epoch [4/30], Loss: 1.6877, Test Acc: 0.6051, Val Acc: 0.6126
Epoch [5/30], Loss: 1.6384, Test Acc: 0.6614, Val Acc: 0.6709
Epoch [6/30], Loss: 1.6385, Test Acc: 0.6940, Val Acc: 0.6984
Epoch [7/30], Loss: 1.6048, Test Acc: 0.7291, Val Acc: 0.7369
Epoch [8/30], Loss: 1.6055, Test Acc: 0.7258, Val Acc: 0.7389
Epoch [9/30], Loss: 1.5618, Test Acc: 0.7406, Val Acc: 0.7508
Epoch [10/30], Loss: 1.5853, Test Acc: 0.7010, Val Acc: 0.7157
Epoch [11/30], Loss: 1.5336, Test Acc: 0.6760, Val Acc: 0.6856
Epoch [12/30], Loss: 1.5474, Test Acc: 0.7439, Val Acc: 0.7568
Epoch [13/30], Loss: 1.5300, Test Acc: 0.7484, Val Acc: 0.7598
Epoch [14/30], Loss: 1.5149, Test Acc: 0.7812, Val Acc: 0.7875
Epoch [15/30], Loss: 1.5395, Test Acc: 0.7652, Val Acc: 0.7715
Epoch [16/30], Loss: 1.5240, Test Acc: 0.7745, Val Acc: 0.7864
E

In [10]:
def easy_fine_tune(criterion, optimizer, model, num_epochs, trainloader, testloader, device):
    total, correct = 0., 0.
    for epoch in range(num_epochs):
        for images, labels in trainloader:
            total = 0.
            correct = 0.
            model.train()
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            
            classification_loss = criterion(outputs, labels)
            loss = classification_loss

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss.backward()
            optimizer.step()

        test_acc = cal_acc(model, testloader, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Train Acc: {correct/total:.4f}, Test Accuracy: {test_acc:.4f}")

In [11]:
model = VGG().to(device)
model.load_state_dict(torch.load('source.pth'))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 0.0005)
num_epochs = 30
easy_fine_tune(criterion, optimizer, model, num_epochs, target_dl, test_dl, device)

Epoch [1/30], Loss: 1.7492, Train Acc: 0.7109, Test Accuracy: 0.7553
Epoch [2/30], Loss: 1.6922, Train Acc: 0.7734, Test Accuracy: 0.7840
Epoch [3/30], Loss: 1.6593, Train Acc: 0.8008, Test Accuracy: 0.7671
Epoch [4/30], Loss: 1.6545, Train Acc: 0.8008, Test Accuracy: 0.7982
Epoch [5/30], Loss: 1.6295, Train Acc: 0.8281, Test Accuracy: 0.7760
Epoch [6/30], Loss: 1.6094, Train Acc: 0.8477, Test Accuracy: 0.7925
Epoch [7/30], Loss: 1.6028, Train Acc: 0.8594, Test Accuracy: 0.7844
Epoch [8/30], Loss: 1.6082, Train Acc: 0.8516, Test Accuracy: 0.7815
Epoch [9/30], Loss: 1.5865, Train Acc: 0.8750, Test Accuracy: 0.7992
Epoch [10/30], Loss: 1.5850, Train Acc: 0.8789, Test Accuracy: 0.7958
Epoch [11/30], Loss: 1.5689, Train Acc: 0.8906, Test Accuracy: 0.7944
Epoch [12/30], Loss: 1.5565, Train Acc: 0.9023, Test Accuracy: 0.7956
Epoch [13/30], Loss: 1.5559, Train Acc: 0.9062, Test Accuracy: 0.7861
Epoch [14/30], Loss: 1.5582, Train Acc: 0.9062, Test Accuracy: 0.7906
Epoch [15/30], Loss: 1.5406, 