Skip to content

Commit

Permalink
fix: transfer net code
Browse files Browse the repository at this point in the history
  • Loading branch information
jindongwang committed Oct 16, 2019
1 parent e68354a commit 1411a1c
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 113 deletions.
196 changes: 196 additions & 0 deletions code/deep/DDC_DeepCoral/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable


# convnet without the last layer
class AlexNetFc(nn.Module):
    """AlexNet feature extractor: pretrained AlexNet minus its last fc layer.

    forward() returns the activations that originally fed the final
    classification layer; output_num() reports their dimensionality.
    """

    def __init__(self):
        super(AlexNetFc, self).__init__()
        pretrained = models.alexnet(pretrained=True)
        self.features = pretrained.features
        # Keep every classifier layer except the final one (index 6),
        # which is the class-score fc layer we want to drop.
        self.classifier = nn.Sequential()
        for idx in range(6):
            self.classifier.add_module(
                "classifier" + str(idx), pretrained.classifier[idx])
        self.__in_features = pretrained.classifier[6].in_features

    def forward(self, x):
        x = self.features(x)
        # Flatten the conv feature map (256 channels of 6x6) for the fc stack.
        x = x.view(x.size(0), 256 * 6 * 6)
        return self.classifier(x)

    def output_num(self):
        """Dimensionality of the feature vector produced by forward()."""
        return self.__in_features


class _ResNetFc(nn.Module):
    """Shared base for ResNet feature extractors.

    The five public ResNet*Fc classes below were previously five identical
    copies differing only in which torchvision constructor they called; the
    common logic now lives here once.

    Wraps an already-constructed torchvision ResNet, keeping every layer
    except the final fully connected classifier, so forward() yields the
    pooled feature vector that originally fed that classifier.

    Args:
        model_resnet: a torchvision ResNet instance,
            e.g. ``models.resnet50(pretrained=True)``.
    """

    def __init__(self, model_resnet):
        super(_ResNetFc, self).__init__()
        # Reuse the pretrained modules directly; only the final fc is dropped.
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        # Width of the dropped fc layer == dimensionality of our features.
        self._in_features = model_resnet.fc.in_features

    def forward(self, x):
        """Return flattened pooled features of shape (batch, output_num())."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        return x.view(x.size(0), -1)

    def output_num(self):
        """Dimensionality of the feature vector produced by forward()."""
        return self._in_features


class ResNet18Fc(_ResNetFc):
    """ResNet-18 feature extractor (ImageNet-pretrained, final fc removed)."""

    def __init__(self):
        super(ResNet18Fc, self).__init__(models.resnet18(pretrained=True))


class ResNet34Fc(_ResNetFc):
    """ResNet-34 feature extractor (ImageNet-pretrained, final fc removed)."""

    def __init__(self):
        super(ResNet34Fc, self).__init__(models.resnet34(pretrained=True))


class ResNet50Fc(_ResNetFc):
    """ResNet-50 feature extractor (ImageNet-pretrained, final fc removed)."""

    def __init__(self):
        super(ResNet50Fc, self).__init__(models.resnet50(pretrained=True))


class ResNet101Fc(_ResNetFc):
    """ResNet-101 feature extractor (ImageNet-pretrained, final fc removed)."""

    def __init__(self):
        super(ResNet101Fc, self).__init__(models.resnet101(pretrained=True))


class ResNet152Fc(_ResNetFc):
    """ResNet-152 feature extractor (ImageNet-pretrained, final fc removed)."""

    def __init__(self):
        super(ResNet152Fc, self).__init__(models.resnet152(pretrained=True))


# Registry mapping a backbone name (as used in the config) to its
# feature-extractor class.
network_dict = dict(
    alexnet=AlexNetFc,
    resnet18=ResNet18Fc,
    resnet34=ResNet34Fc,
    resnet50=ResNet50Fc,
    resnet101=ResNet101Fc,
    resnet152=ResNet152Fc,
)
9 changes: 4 additions & 5 deletions code/deep/DDC_DeepCoral/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
CFG = {
'data_path': 'D:/data/Office31/Original_images/',
'kwargs' : {'num_workers': 4},
'kwargs': {'num_workers': 4},
'batch_size': 32,
'epoch': 200,
'epoch': 100,
'lr': 1e-3,
'momentum': .9,
'seed': 200,
'log_interval': 1,
'log_interval': 10,
'l2_decay': 0,
'lambda': 10,
'backbone': 'alexnet',
'n_class': 31,
}
}
6 changes: 3 additions & 3 deletions code/deep/DDC_DeepCoral/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torchvision import datasets, transforms
import torch

def load_data(root_path, dir, batch_size, train, kwargs):
def load_data(data_folder, batch_size, train, kwargs):
transform = {
'train': transforms.Compose(
[transforms.Resize([256, 256]),
Expand All @@ -16,7 +16,7 @@ def load_data(root_path, dir, batch_size, train, kwargs):
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
}
data = datasets.ImageFolder(root = root_path + dir, transform=transform['train' if train else 'test'])
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, **kwargs, drop_last = True)
data = datasets.ImageFolder(root = data_folder, transform=transform['train' if train else 'test'])
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, **kwargs, drop_last = True if train else False)
return data_loader

98 changes: 54 additions & 44 deletions code/deep/DDC_DeepCoral/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,38 @@
import data_loader
import models
from config import CFG
import utils

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def test(model, target_test_loader):
model.eval()
test_loss = utils.AverageMeter()
correct = 0
criterion = torch.nn.CrossEntropyLoss()
len_target_dataset = len(target_test_loader.dataset)
with torch.no_grad():
for data, target in target_test_loader:
data, target = data.to(DEVICE), target.to(DEVICE)
s_output = model.predict(data)
loss = criterion(s_output, target)
test_loss.update(loss.item())
pred = torch.max(s_output, 1)[1]
correct += torch.sum(pred == target)

print('{} --> {}: max correct: {}, accuracy{: .2f}%\n'.format(
source_name, target_name, correct, 100. * correct / len_target_dataset))


def train(source_loader, target_train_loader, target_test_loader, model, optimizer, CFG):
len_source_loader = len(source_loader)
len_target_loader = len(target_train_loader)
train_loss_clf = utils.AverageMeter()
train_loss_transfer = utils.AverageMeter()
train_loss_total = utils.AverageMeter()
for e in range(CFG['epoch']):
# Train
model.train()
model.isTrain = True
iter_source, iter_target = iter(
source_loader), iter(target_train_loader)
n_batch = min(len_source_loader, len_target_loader)
Expand All @@ -27,65 +48,54 @@ def train(source_loader, target_train_loader, target_test_loader, model, optimiz
data_target = data_target.to(DEVICE)

optimizer.zero_grad()
label_source_pred, loss_coral = model(data_source, data_target)
loss_cls = criterion(label_source_pred, label_source)
loss = loss_cls + CFG['lambda'] * loss_coral
label_source_pred, transfer_loss = model(data_source, data_target)
clf_loss = criterion(label_source_pred, label_source)
loss = clf_loss + CFG['lambda'] * transfer_loss
loss.backward()
optimizer.step()
train_loss_clf.update(clf_loss.item())
train_loss_transfer.update(transfer_loss.item())
train_loss_total.update(loss.item())
if i % CFG['log_interval'] == 0:
print('Train Epoch: [{}/{} ({:.0f}%)], \
total_Loss: {:.6f}, \
cls_Loss: {:.6f}, \
adapt_Loss: {:.6f}'.format(
print('Train Epoch: [{}/{} ({:02d}%)], cls_Loss: {:.6f}, transfer_loss: {:.6f}, total_Loss: {:.6f}'.format(
e + 1,
CFG['epoch'],
100. * i / len_source_loader, loss.item(), loss_cls.item(), loss_coral.item()))
int(100. * i / n_batch), train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg))

# Test
model.eval()
test_loss = 0
correct = 0
criterion = torch.nn.CrossEntropyLoss()
len_target_dataset = len(target_test_loader.dataset)
with torch.no_grad():
model.isTrain = False
for data, target in target_test_loader:
data, target = data.to(DEVICE), target.to(DEVICE)
s_output, _ = model(data, None)
test_loss += criterion(s_output, target)
pred = torch.max(s_output, 1)[1]
correct += torch.sum(pred == target.data)

test_loss /= len_target_dataset
print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
target_name, test_loss, correct, len_target_dataset,
100. * correct / len_target_dataset))
print('source: {} to target: {} max correct: {} max accuracy{: .2f}%\n'.format(
source_name, target_name, correct, 100. * correct / len_target_dataset))
test(model, target_test_loader)


def load_data(src, tar, root_dir):
folder_src = root_dir + src + '/images/'
folder_tar = root_dir + tar + '/images/'
source_loader = data_loader.load_data(
root_dir, src, CFG['batch_size'], True, CFG['kwargs'])
folder_src, CFG['batch_size'], True, CFG['kwargs'])
target_train_loader = data_loader.load_data(
root_dir, tar, CFG['batch_size'], False, CFG['kwargs'])
folder_tar, CFG['batch_size'], True, CFG['kwargs'])
target_test_loader = data_loader.load_data(
root_dir, tar, CFG['batch_size'], False, CFG['kwargs'])
return source_loader, target_train_loader, target_test_loader
folder_tar, CFG['batch_size'], False, CFG['kwargs'])
return source_loader, target_train_loader, target_test_loader


if __name__ == '__main__':
torch.manual_seed(CFG['seed'])
torch.manual_seed(10)

source_name = "amazon"
target_name = "webcam"
source_name = "dslr"
target_name = "amazon"

print('Src: %s, Tar: %s' % (source_name, target_name))

source_loader, target_train_loader, target_test_loader = load_data(source_name, target_name, CFG['data_path'])
source_loader, target_train_loader, target_test_loader = load_data(
source_name, target_name, CFG['data_path'])

model = models.DeepCoral(CFG['n_class'],adapt_loss='mmd', backbone='alexnet').to(DEVICE)
model = models.Transfer_Net(
CFG['n_class'], transfer_loss='coral', base_net='alexnet').to(DEVICE)
optimizer = torch.optim.SGD([
{'params': model.sharedNet.parameters()},
{'params': model.fc.parameters()},
{'params': model.cls_fc.parameters(), 'lr': 10 * CFG['lr']},
{'params': model.base_network.parameters()},
{'params': model.bottleneck_layer.parameters(), 'lr': 10 * CFG['lr']},
{'params': model.classifier_layer.parameters(), 'lr': 10 * CFG['lr']},
], lr=CFG['lr'], momentum=CFG['momentum'], weight_decay=CFG['l2_decay'])

train(source_loader, target_train_loader, target_test_loader, model, optimizer, CFG)
train(source_loader, target_train_loader,
target_test_loader, model, optimizer, CFG)

0 comments on commit 1411a1c

Please sign in to comment.